diff --git a/src/llmtuner/tuner/core/utils.py b/src/llmtuner/tuner/core/utils.py
index 5ec820a8..3ab02246 100644
--- a/src/llmtuner/tuner/core/utils.py
+++ b/src/llmtuner/tuner/core/utils.py
@@ -54,7 +54,7 @@ def prepare_model_for_training(
         input_embed: torch.nn.Embedding = model.get_input_embeddings()
 
         def noisy_forward(self: torch.nn.Embedding, x: torch.Tensor) -> torch.Tensor:
-            embeddings = input_embed.forward(x)
+            embeddings = torch.nn.Embedding.forward(self, x)
             if self.training:
                 dims = self.num_embeddings * self.embedding_dim
                 mag_norm = finetuning_args.neft_alpha / (dims ** 0.5)
@@ -79,7 +79,7 @@ def prepare_model_for_training(
         input_dtype = output_layer.weight.dtype
 
         def forward_in_fp32(self, x: torch.Tensor) -> torch.Tensor:
-            return output_layer.forward(x.to(input_dtype)).to(torch.float32)
+            return torch.nn.Linear.forward(self, x.to(input_dtype)).to(torch.float32)
 
         output_layer.forward = MethodType(forward_in_fp32, output_layer)
 