diff --git a/src/llmtuner/tuner/sft/trainer.py b/src/llmtuner/tuner/sft/trainer.py index e07a1936..851767f8 100644 --- a/src/llmtuner/tuner/sft/trainer.py +++ b/src/llmtuner/tuner/sft/trainer.py @@ -36,31 +36,44 @@ class Seq2SeqPeftTrainer(PeftTrainer): inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) if label_len > prompt_len: inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"]) + if "attention_mask" in inputs: + inputs["attention_mask"] = self._pad_tensors_to_target_len( + inputs["attention_mask"], inputs["labels"], pad_token_id=0 + ) + if "position_ids" in inputs: + inputs["position_ids"] = self._pad_tensors_to_target_len( + inputs["position_ids"], inputs["labels"], pad_token_id=0 + ) loss, generated_tokens, labels = super().prediction_step( model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys ) - generated_tokens = generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None + generated_tokens = ( + generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None + ) return (loss, generated_tokens, labels) - def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor: + def _pad_tensors_to_target_len( + self, + src_tensor: torch.Tensor, + tgt_tensor: torch.Tensor, + pad_token_id: Optional[int] = None + ) -> torch.Tensor: r""" Pads the tensor to the same length as the target tensor. Should only be called when predict_with_generate=True. """ - if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): - assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." - # If PAD token is not defined at least EOS token has to be defined - pad_token_id = ( - self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id - ) - else: - if self.model.config.pad_token_id is not None: - pad_token_id = self.model.config.pad_token_id + if pad_token_id is None: + if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): + assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." + pad_token_id = self.tokenizer.pad_token_id else: - raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") + if self.model.config.pad_token_id is not None: + pad_token_id = self.model.config.pad_token_id + else: + raise ValueError("Pad_token_id must be set in the configuration of the model.") padded_tensor = pad_token_id * torch.ones_like(tgt_tensor) padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding