From 3e3652f4e0f3b08d57b9abc90e9b80d63a3ee984 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 28 Jul 2023 17:02:26 +0800
Subject: [PATCH] fix #268

Former-commit-id: 91dd17d8a6fcb0a154f29a2d1ff9f4266b720b9e
---
 src/llmtuner/extras/misc.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index 40c8d337..9c4e165e 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -28,7 +28,7 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
-# Avoid runtime error in model.generate(do_sample=True).
+# Avoids runtime error in model.generate(do_sample=True).
 class InvalidScoreLogitsProcessor(LogitsProcessor):
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -63,7 +63,7 @@ def print_trainable_params(model: torch.nn.Module) -> None:
 def prepare_model_for_training(
     model: PreTrainedModel,
     finetuning_type: str,
-    output_embedding_layer_name: Optional[str] = "lm_head",
+    output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
 ) -> PreTrainedModel:
@@ -83,19 +83,23 @@ def prepare_model_for_training(
         model.gradient_checkpointing_enable()
         model.config.use_cache = False # turn off when gradient checkpointing is enabled
 
-    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
-        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
-        input_dtype = output_embedding_layer.weight.dtype
+    if finetuning_type != "full" and hasattr(model, output_layer_name):
+        output_layer: torch.nn.Linear = getattr(model, output_layer_name)
+        input_dtype = output_layer.weight.dtype
 
         class CastOutputToFloat(torch.nn.Sequential):
 
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return super().forward(x.to(input_dtype)).to(torch.float32)
 
-        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
-
+        new_output_layer = CastOutputToFloat(output_layer)
+        # adapt to LLaMA-2's pretraining_tp (LLaMA models can cast automatically, but BLOOM models cannot)
+        # (https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py#L819)
+        setattr(new_output_layer, "weight", output_layer.weight)
+        setattr(model, output_layer_name, new_output_layer)
     return model
 
+
 def torch_gc() -> None:
     r"""
     Collects GPU memory.
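
A minimal stand-alone sketch of what the new code does (not part of the patch itself): the output layer is wrapped so its logits are computed in the layer's own dtype and returned in float32, and the wrapper re-exposes .weight because code such as LLaMA-2's pretraining_tp branch reads model.lm_head.weight directly, which a bare torch.nn.Sequential wrapper would not provide. The toy layer sizes and the bfloat16 dtype below are assumptions chosen so the example runs on CPU.

import torch

# toy stand-in for a reduced-precision output layer (model.lm_head in the patch)
lm_head = torch.nn.Linear(8, 16).to(torch.bfloat16)
input_dtype = lm_head.weight.dtype

class CastOutputToFloat(torch.nn.Sequential):
    # cast the input to the wrapped layer's dtype, run it, and return float32 logits
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x.to(input_dtype)).to(torch.float32)

wrapped = CastOutputToFloat(lm_head)
# mirror setattr(new_output_layer, "weight", output_layer.weight): the same
# parameter is registered on the wrapper, so direct attribute access keeps working
wrapped.weight = lm_head.weight

assert wrapped.weight is lm_head.weight                     # same parameter, no copy
assert wrapped(torch.randn(2, 8)).dtype == torch.float32    # float32 logits for sampling/loss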