diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py
index 4686692585dfa..32a965551e65f 100644
--- a/examples/seq2seq/seq2seq_trainer.py
+++ b/examples/seq2seq/seq2seq_trainer.py
@@ -20,6 +20,7 @@
 from transformers import PreTrainedModel, Trainer, logging
 from transformers.file_utils import is_torch_tpu_available
+from transformers.integrations import is_fairscale_available
 from transformers.models.fsmt.configuration_fsmt import FSMTConfig
 from transformers.optimization import (
     Adafactor,
@@ -35,6 +36,10 @@
 from transformers.training_args import ParallelMode
 
 
+if is_fairscale_available():
+    from fairscale.optim import OSS
+
+
 logger = logging.get_logger(__name__)
 
 arg_to_scheduler = {
@@ -99,18 +104,25 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
                     "weight_decay": 0.0,
                 },
             ]
+            optimizer_cls = Adafactor if self.args.adafactor else AdamW
             if self.args.adafactor:
-                self.optimizer = Adafactor(
-                    optimizer_grouped_parameters,
-                    lr=self.args.learning_rate,
-                    scale_parameter=False,
-                    relative_step=False,
-                )
-
+                optimizer_cls = Adafactor
+                optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
             else:
-                self.optimizer = AdamW(
-                    optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon
+                optimizer_cls = AdamW
+                optimizer_kwargs = {
+                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
+                    "eps": self.args.adam_epsilon,
+                }
+            optimizer_kwargs["lr"] = self.args.learning_rate
+            if self.sharded_ddp:
+                self.optimizer = OSS(
+                    params=optimizer_grouped_parameters,
+                    optim=optimizer_cls,
+                    **optimizer_kwargs,
                 )
+            else:
+                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
 
         if self.lr_scheduler is None:
             self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 78495590b7d52..ecc2a9f635a68 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -92,6 +92,13 @@
 except ImportError:
     _has_mlflow = False
 
+try:
+    import fairscale  # noqa: F401
+
+    _has_fairscale = True
+except ImportError:
+    _has_fairscale = False
+
 # No transformer imports above this point
 from .file_utils import is_torch_tpu_available  # noqa: E402
@@ -128,6 +135,10 @@ def is_mlflow_available():
     return _has_mlflow
 
 
+def is_fairscale_available():
+    return _has_fairscale
+
+
 def hp_params(trial):
     if is_optuna_available():
         if isinstance(trial, optuna.Trial):
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 102bf090ceb3f..d6c5a23d293fb 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -33,6 +33,7 @@
     hp_params,
     is_azureml_available,
     is_comet_available,
+    is_fairscale_available,
     is_mlflow_available,
     is_optuna_available,
     is_ray_available,
@@ -153,6 +154,11 @@
     DEFAULT_CALLBACKS.append(AzureMLCallback)
 
+if is_fairscale_available():
+    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
+    from fairscale.optim import OSS
+    from fairscale.optim.grad_scaler import ShardedGradScaler
+
 logger = logging.get_logger(__name__)
 
@@ -285,6 +291,16 @@ def __init__(
             if isinstance(eval_dataset, datasets.Dataset):
                 self._remove_unused_columns(self.eval_dataset, description="evaluation")
 
+        # Setup Sharded DDP training
+        self.sharded_ddp = False
+        if args.sharded_ddp:
+            if args.local_rank == -1:
+                raise ValueError("Using sharded DDP only works in distributed training.")
+            elif not is_fairscale_available():
+                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
+            else:
+                self.sharded_ddp = True
+
         # Mixed precision setup
         self.use_apex = False
         self.use_amp = False
@@ -296,7 +312,7 @@ def __init__(
             if backend == "amp":
                 self.use_amp = True
-                self.scaler = torch.cuda.amp.GradScaler()
+                self.scaler = ShardedGradScaler() if self.sharded_ddp else torch.cuda.amp.GradScaler()
             else:
                 if not is_apex_available():
                     raise ImportError(
@@ -491,12 +507,21 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
                     "weight_decay": 0.0,
                 },
             ]
-            self.optimizer = AdamW(
-                optimizer_grouped_parameters,
-                lr=self.args.learning_rate,
-                betas=(self.args.adam_beta1, self.args.adam_beta2),
-                eps=self.args.adam_epsilon,
-            )
+            if self.sharded_ddp:
+                self.optimizer = OSS(
+                    params=optimizer_grouped_parameters,
+                    optim=AdamW,
+                    lr=self.args.learning_rate,
+                    betas=(self.args.adam_beta1, self.args.adam_beta2),
+                    eps=self.args.adam_epsilon,
+                )
+            else:
+                self.optimizer = AdamW(
+                    optimizer_grouped_parameters,
+                    lr=self.args.learning_rate,
+                    betas=(self.args.adam_beta1, self.args.adam_beta2),
+                    eps=self.args.adam_epsilon,
+                )
         if self.lr_scheduler is None:
             self.lr_scheduler = get_linear_schedule_with_warmup(
                 self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
@@ -643,7 +668,9 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
             model = torch.nn.DataParallel(model)
 
         # Distributed training (should be after apex fp16 initialization)
-        if self.args.local_rank != -1:
+        if self.sharded_ddp:
+            model = ShardedDDP(model, self.optimizer)
+        elif self.args.local_rank != -1:
             model = torch.nn.parallel.DistributedDataParallel(
                 model,
                 device_ids=[self.args.local_rank],
@@ -654,8 +681,8 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                     else True
                 ),
             )
-        # find_unused_parameters breaks checkpointing as per
-        # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+            # find_unused_parameters breaks checkpointing as per
+            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
 
         # Train!
         if is_torch_tpu_available():
@@ -895,6 +922,8 @@ def _save_checkpoint(self, model, trial, metrics=None):
         self.save_model(output_dir)
 
         # Save optimizer and scheduler
+        if self.sharded_ddp:
+            self.optimizer.consolidate_state_dict()
         if is_torch_tpu_available():
             xm.rendezvous("saving_optimizer_states")
             xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 0e11b9e16801a..0a02bd2232851 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -215,6 +215,9 @@ class TrainingArguments:
             The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
             :obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
             other choices will force the requested backend.
+        sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
+            training only). This is an experimental feature.
""" output_dir: str = field( @@ -386,6 +389,10 @@ class TrainingArguments: default="auto", metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]}, ) + sharded_ddp: bool = field( + default=False, + metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."}, + ) def __post_init__(self): if self.disable_tqdm is None: