
Why can't I set TrainingArguments.device in Hugging Face?

Question

When I try to set the .device attribute of my TrainingArguments to torch.device('cpu'), I get an error. How am I supposed to set the device, then?

Python Code

from transformers import TrainingArguments
from transformers import Trainer
import torch

training_args = TrainingArguments(
    output_dir="./some_local_dir",
    overwrite_output_dir=True,

    per_device_train_batch_size=4,
    dataloader_num_workers=2,

    max_steps=500,
    logging_steps=1,

    evaluation_strategy="steps",
    eval_steps=5
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

training_args.device = torch.device('cpu')

Python Error

AttributeError                            Traceback (most recent call last)
<ipython-input-11-30a92c0570b8> in <cell line: 28>()
     26 )
     27 
---> 28 training_args.device = torch.device('cpu')

AttributeError: can't set attribute
AlanSTACK asked Sep 11 '25 03:09


2 Answers

TrainingArguments has a parameter called no_cuda. If you set it to True, training takes place on the CPU even if a GPU is available. Note that in recent versions of transformers, no_cuda is deprecated in favor of use_cpu=True. For example, the following code worked for me:

    from transformers import Trainer, TrainingArguments

    # configure training; no_cuda=True forces the CPU
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=1,
        per_device_train_batch_size=1,
        fp16=False,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        prediction_loss_only=True,
        no_cuda=True
    )
    trainer = Trainer(
        model,
        training_args,
        train_dataset=tokenized_dataset["train"],
    )

    # execute the training!
    trainer.train()
Jason D answered Sep 13 '25 18:09


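You do not set the device through training_args.device. You can move the model to the CPU explicitly with the code below. Be aware, though, that Trainer normally moves the model to training_args.device when it is constructed, so on a machine with a GPU this alone may not keep training on the CPU; the no_cuda option is the more reliable way to force it.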

import torch

# move the model's parameters to the CPU
device = torch.device('cpu')
model = model.to(device)

training_args.device is a read-only property: it is computed from the other arguments and has no setter, so assigning to it raises the AttributeError you see.
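
To illustrate why the assignment fails, here is a minimal sketch of a read-only property in plain Python. The Args class and the try_set_device helper are hypothetical stand-ins made up for this example, not the real TrainingArguments implementation; the only assumption is that device is exposed through @property without a setter and derived from a constructor flag.

```python
class Args:
    """Hypothetical stand-in for a config object; NOT the real TrainingArguments."""

    def __init__(self, use_cpu: bool = False):
        self.use_cpu = use_cpu

    @property
    def device(self) -> str:
        # Derived from the configuration on every read; no setter is defined.
        return "cpu" if self.use_cpu else "cuda"


def try_set_device(args: Args, value: str) -> bool:
    """Return True if the assignment worked, False if it raised AttributeError."""
    try:
        args.device = value
    except AttributeError:
        return False
    return True


args = Args(use_cpu=False)
assigned = try_set_device(args, "cpu")  # fails: the property has no setter
cpu_args = Args(use_cpu=True)           # configure through the constructor instead
```

The same pattern applies here: rather than assigning to the derived attribute, you change the constructor argument that it is computed from.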

DareGhost answered Sep 13 '25 17:09