From da8721a70ec3da5b8be650bd1106e561303c487a Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 2 Jan 2025 11:17:29 +0000
Subject: [PATCH] fix #6499

Former-commit-id: 1800f8c72dfa618c71c84a3a18ecdef4d82754f7
---
 src/llamafactory/data/preprocess.py          |  4 +--
 src/llamafactory/data/processors/pretrain.py |  5 ++++
 .../data/processors/unsupervised.py          |  2 ++
 src/llamafactory/train/sft/trainer.py        | 29 +++++--------------
 src/llamafactory/train/sft/workflow.py       |  4 ---
 5 files changed, 17 insertions(+), 27 deletions(-)

diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py
index 9f015b38..c5a10ec9 100644
--- a/src/llamafactory/data/preprocess.py
+++ b/src/llamafactory/data/preprocess.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
 
 from .processors.feedback import preprocess_feedback_dataset
 from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
-from .processors.pretrain import preprocess_pretrain_dataset
+from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
 from .processors.supervised import (
     preprocess_packed_supervised_dataset,
     preprocess_supervised_dataset,
@@ -47,7 +47,7 @@ def get_preprocess_and_print_func(
             tokenizer=tokenizer,
             data_args=data_args,
         )
-        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
+        print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not do_generate:
         if data_args.packing:
             if data_args.neat_packing:  # hack datasets to have int32 attention mask
diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py
index 6d6b98d6..2cee40e5 100644
--- a/src/llamafactory/data/processors/pretrain.py
+++ b/src/llamafactory/data/processors/pretrain.py
@@ -52,3 +52,8 @@ def preprocess_pretrain_dataset(
             result["input_ids"][i][0] = tokenizer.bos_token_id
 
     return result
+
+
+def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    print("input_ids:\n{}".format(example["input_ids"]))
+    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py
index bc5ad34c..e21ebd42 100644
--- a/src/llamafactory/data/processors/unsupervised.py
+++ b/src/llamafactory/data/processors/unsupervised.py
@@ -100,3 +100,5 @@ def preprocess_unsupervised_dataset(
 def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
     print("input_ids:\n{}".format(example["input_ids"]))
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+    print("label_ids:\n{}".format(example["labels"]))
+    print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 45998262..28ec25eb 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -111,40 +111,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         inputs: Dict[str, Union["torch.Tensor", Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
+        **gen_kwargs,
     ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Removes the prompt part in the generated tokens.
 
         Subclass and override to inject custom behavior.
         """
-        labels = inputs["labels"] if "labels" in inputs else None
-        if self.args.predict_with_generate:
-            assert self.processing_class.padding_side == "left", "This method only accepts left-padded tensor."
-            labels = labels.detach().clone() if labels is not None else None  # backup labels
-            prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-            if prompt_len > label_len:
-                inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
-            if label_len > prompt_len:  # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
-                inputs["labels"] = inputs["labels"][:, :prompt_len]
+        if self.args.predict_with_generate:  # do not pass labels to model when generate
+            labels = inputs.pop("labels", None)
+        else:
+            labels = inputs.get("labels")
 
-        loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
-            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+        loss, generated_tokens, _ = super().prediction_step(
+            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
         )
         if generated_tokens is not None and self.args.predict_with_generate:
-            generated_tokens[:, :prompt_len] = self.processing_class.pad_token_id
+            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
             generated_tokens = generated_tokens.contiguous()
 
         return loss, generated_tokens, labels
 
-    def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Pads the tensor to the same length as the target tensor.
-        """
-        assert self.processing_class.pad_token_id is not None, "Pad token is required."
-        padded_tensor = self.processing_class.pad_token_id * torch.ones_like(tgt_tensor)
-        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
-        return padded_tensor.contiguous()  # in contiguous memory
-
     def save_predictions(
         self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
     ) -> None:
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index 5f4a09cc..1ccfa9ef 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -117,8 +117,6 @@ def run_sft(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
-        if training_args.predict_with_generate:  # eval_loss will be wrong if predict_with_generate is enabled
-            metrics.pop("eval_loss", None)
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
 
@@ -126,8 +124,6 @@ def run_sft(
     if training_args.do_predict:
         logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
         predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
-        if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
-            predict_results.metrics.pop("predict_loss", None)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
         trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)