Using the Hugging Face Trainer with DistributedDataParallel

To speed up training I looked into PyTorch's DistributedDataParallel and tried to apply it to the transformers Trainer.

The PyTorch examples for DDP state that this should at least be faster:

DataParallel is single-process, multi-thread, and only works on a single machine, while DistributedDataParallel is multi-process and works for both single- and multi-machine training. DataParallel is usually slower than DistributedDataParallel even on a single machine due to GIL contention across threads, per-iteration replicated model, and additional overhead introduced by scattering inputs and gathering outputs.
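To make the quoted difference concrete, here is a minimal, self-contained sketch of what DDP itself does per process: each process holds its own model replica, and gradients are all-reduced during `backward()`. This is not the Trainer code from the question; it uses a single process (`world_size=1`), the CPU-only `gloo` backend, and a toy linear model purely for illustration.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Rendezvous settings; with world_size=1 this is a trivial "group" of one.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 2)
ddp_model = DDP(model)  # one replica per process, no per-iteration re-scatter

out = ddp_model(torch.randn(4, 8))
loss = out.sum()
loss.backward()  # gradients are all-reduced across processes here

dist.destroy_process_group()
```

In a real multi-GPU run, one such process would be spawned per GPU, each constructing its own replica, instead of one process scattering inputs to all GPUs every step as DataParallel does.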

My DataParallel trainer looks like this:

import os
from datetime import datetime
import sys
import torch
from transformers import Trainer, TrainingArguments, BertConfig

training_args = TrainingArguments(
        output_dir=os.path.join(path_storage, 'results', "mlm"),  # output directory
        num_train_epochs=1,  # total # of training epochs
        gradient_accumulation_steps=2,  # for accumulation over multiple steps
        per_device_train_batch_size=4,  # batch size per device during training
        per_device_eval_batch_size=4,  # batch size for evaluation
        logging_dir=os.path.join(path_storage, 'logs', "mlm"),  # directory for storing logs
        evaluate_during_training=False,
        max_steps=20,
    )

mlm_train_dataset = ProteinBertMaskedLMDataset(
        path_vocab, os.path.join(path_storage, "data", "uniparc", "uniparc_train_sorted.h5"),
)

mlm_config = BertConfig(
        vocab_size=mlm_train_dataset.tokenizer.vocab_size,
        max_position_embeddings=mlm_train_dataset.input_size
)
mlm_model = ProteinBertForMaskedLM(mlm_config)
trainer = Trainer(
   model=mlm_model,  # the instantiated 🤗 Transformers model to be trained
   args=training_args,  # training arguments, defined above
   train_dataset=mlm_train_dataset,  # training dataset
   data_collator=mlm_train_dataset.collate_fn,
)
print("build trainer with on device:", training_args.device, "with n gpus:", training_args.n_gpu)
start = datetime.now()
trainer.train()
print(f"finished in {datetime.now() - start} seconds")

The output:

build trainer with on device: cuda:0 with n gpus: 4
finished in 0:02:47.537038 seconds

My DistributedDataParallel trainer is built like this:

def create_transformer_trainer(rank, world_size, train_dataset, model):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    training_args = TrainingArguments(
        output_dir=os.path.join(path_storage, 'results', "mlm"),  # output directory
        num_train_epochs=1,  # total # of training epochs
        gradient_accumulation_steps=2,  # for accumulation over multiple steps
        per_device_train_batch_size=4,  # batch size per device during training
        per_device_eval_batch_size=4,  # batch size for evaluation
        logging_dir=os.path.join(path_storage, 'logs', "mlm"),  # directory for storing logs
        local_rank=rank,
        max_steps=20,
    )

    trainer = Trainer(
        model=model,  # the instantiated 🤗 Transformers model to be trained
        args=training_args,  # training arguments, defined above
        train_dataset=train_dataset,  # training dataset
        data_collator=train_dataset.collate_fn,
    )
    print("build trainer with on device:", training_args.device, "with n gpus:", training_args.n_gpu)
    start = datetime.now()
    trainer.train()
    print(f"finished in {datetime.now() - start} seconds")


mlm_train_dataset = ProteinBertMaskedLMDataset(
    path_vocab, os.path.join(path_storage, "data", "uniparc", "uniparc_train_sorted.h5"))

mlm_config = BertConfig(
    vocab_size=mlm_train_dataset.tokenizer.vocab_size,
    max_position_embeddings=mlm_train_dataset.input_size
)
mlm_model = ProteinBertForMaskedLM(mlm_config)
torch.multiprocessing.spawn(create_transformer_trainer,
     args=(4, mlm_train_dataset, mlm_model),
     nprocs=4,
     join=True)

The output:

The current process just got forked. Disabling parallelism to avoid deadlocks...
To disable this warning, please explicitly set TOKENIZERS_PARALLELISM=(true | false)
The current process just got forked. Disabling parallelism to avoid deadlocks...
To disable this warning, please explicitly set TOKENIZERS_PARALLELISM=(true | false)
The current process just got forked. Disabling parallelism to avoid deadlocks...
To disable this warning, please explicitly set TOKENIZERS_PARALLELISM=(true | false)
The current process just got forked. Disabling parallelism to avoid deadlocks...
To disable this warning, please explicitly set TOKENIZERS_PARALLELISM=(true | false)
The current process just got forked. Disabling parallelism to avoid deadlocks...
To disable this warning, please explicitly set TOKENIZERS_PARALLELISM=(true | false)
build trainer with on device: cuda:1 with n gpus: 1
build trainer with on device: cuda:2 with n gpus: 1
build trainer with on device: cuda:3 with n gpus: 1
build trainer with on device: cuda:0 with n gpus: 1
finished in 0:04:15.937331 seconds
finished in 0:04:16.899411 seconds
finished in 0:04:16.938141 seconds
finished in 0:04:17.391887 seconds

About the initial forking warning: what exactly is forked, and is this expected?

And about the resulting time: am I using the Trainer incorrectly, since it seems to be a lot slower than the DataParallel approach?

Asked by sheming on May 02 '26

1 Answer

Kind of late to the party, but anyway: I'm leaving this here to help anyone wondering whether it is possible to keep the parallelism in the tokenizer.

According to this comment on GitHub, the fast tokenizers seem to be the issue. And according to another comment on gitmemory, you shouldn't use the tokenizer before forking the process (which basically means: before iterating through your dataloader).

So the solution is either to not use fast tokenizers before training/fine-tuning, or to use the normal (slow) tokenizers instead.
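Both options can be sketched briefly. Setting the environment variable silences the warning and disables the fast tokenizer's internal thread pool before any fork happens (i.e. before `torch.multiprocessing.spawn`); the commented-out lines show how a slow tokenizer would be requested instead (the model name is only an illustrative placeholder):

```python
import os

# Option 1: disable tokenizer parallelism before forking, so the forked
# children don't inherit a tokenizer whose thread pool could deadlock.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Option 2: load a slow, pure-Python tokenizer instead of the Rust-backed
# fast one, so there is no internal thread pool to worry about at all:
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)
```

Note that option 1 trades away the fast tokenizer's parallelism; if tokenization happens once up front (before spawning) rather than inside the dataloader, that cost is usually negligible.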

Check the Hugging Face documentation to find out whether you really need a fast tokenizer.

Answered by mingaflo on May 04 '26