Merge pull request #6512 from hiyouga/hiyouga/fix_gen_logic

[trainer] fix generate logic

Former-commit-id: 72d86ecc9e327933a0a2c893b8ffd2740c99be6b
commit b921dde749
Author: hoshi-hiyouga
Date: 2025-01-02 19:36:54 +08:00 (committed by GitHub)
5 changed files with 17 additions and 27 deletions

src/llamafactory/data/preprocess.py

@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
 from .processors.feedback import preprocess_feedback_dataset
 from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
-from .processors.pretrain import preprocess_pretrain_dataset
+from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
 from .processors.supervised import (
     preprocess_packed_supervised_dataset,
     preprocess_supervised_dataset,
@@ -47,7 +47,7 @@ def get_preprocess_and_print_func(
             tokenizer=tokenizer,
             data_args=data_args,
         )
-        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
+        print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
    elif stage == "sft" and not do_generate:
        if data_args.packing:
            if data_args.neat_packing:  # hack datasets to have int32 attention mask
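A minimal sketch of the partial-binding pattern this hunk relies on: functools.partial pre-binds the tokenizer so the dataset code can call the printer with just one example dict. FakeTokenizer is a made-up stand-in for illustration, not the project's tokenizer:

from functools import partial

class FakeTokenizer:
    """Stand-in for a Hugging Face tokenizer (illustration only)."""
    def decode(self, ids, skip_special_tokens=False):
        return " ".join("<{}>".format(i) for i in ids)

def print_pretrain_dataset_example(example, tokenizer):
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

# partial() fixes the tokenizer argument up front, so the caller only
# supplies the example dict, exactly as the hunk above does.
print_function = partial(print_pretrain_dataset_example, tokenizer=FakeTokenizer())
print_function({"input_ids": [1, 2, 3]})  # prints the ids, then "<1> <2> <3>"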

src/llamafactory/data/processors/pretrain.py

@@ -52,3 +52,8 @@ def preprocess_pretrain_dataset(
                 result["input_ids"][i][0] = tokenizer.bos_token_id

     return result
+
+
+def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    print("input_ids:\n{}".format(example["input_ids"]))
+    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

src/llamafactory/data/processors/unsupervised.py

@@ -100,3 +100,5 @@ def preprocess_unsupervised_dataset(
 def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
     print("input_ids:\n{}".format(example["input_ids"]))
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+    print("label_ids:\n{}".format(example["labels"]))
+    print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))

src/llamafactory/train/sft/trainer.py

@@ -111,40 +111,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         inputs: Dict[str, Union["torch.Tensor", Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
+        **gen_kwargs,
     ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Removes the prompt part in the generated tokens.

         Subclass and override to inject custom behavior.
         """
-        labels = inputs["labels"] if "labels" in inputs else None
-        if self.args.predict_with_generate:
-            assert self.processing_class.padding_side == "left", "This method only accepts left-padded tensor."
-            labels = labels.detach().clone() if labels is not None else None  # backup labels
-            prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-            if prompt_len > label_len:
-                inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
-            if label_len > prompt_len:  # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
-                inputs["labels"] = inputs["labels"][:, :prompt_len]
+        if self.args.predict_with_generate:  # do not pass labels to model when generate
+            labels = inputs.pop("labels", None)
+        else:
+            labels = inputs.get("labels")

-        loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
-            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+        loss, generated_tokens, _ = super().prediction_step(
+            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
         )
         if generated_tokens is not None and self.args.predict_with_generate:
-            generated_tokens[:, :prompt_len] = self.processing_class.pad_token_id
+            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
             generated_tokens = generated_tokens.contiguous()

         return loss, generated_tokens, labels

-    def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Pads the tensor to the same length as the target tensor.
-        """
-        assert self.processing_class.pad_token_id is not None, "Pad token is required."
-        padded_tensor = self.processing_class.pad_token_id * torch.ones_like(tgt_tensor)
-        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
-        return padded_tensor.contiguous()  # in contiguous memory
-
     def save_predictions(
         self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
     ) -> None:
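The masking line relies on an invariant of left-padded generation: generate() output begins with the echoed prompt, so overwriting the first inputs["input_ids"].size(-1) columns with pad_token_id leaves only the completion (this is also why the old code asserted padding_side == "left"). A tensor-level sketch with illustrative ids:

import torch

pad_token_id = 0
# Batch of 2 left-padded prompts, width 4 (ids illustrative).
input_ids = torch.tensor([[0, 0, 5, 6],
                          [7, 8, 9, 10]])
# generate() echoes the prompt, then appends the new tokens.
generated = torch.tensor([[0, 0, 5, 6, 11, 12],
                          [7, 8, 9, 10, 13, 14]])

prompt_len = input_ids.size(-1)
generated[:, :prompt_len] = pad_token_id  # blank out the echoed prompt
print(generated)  # only [11, 12] and [13, 14] survive past the padding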

src/llamafactory/train/sft/workflow.py

@ -117,8 +117,6 @@ def run_sft(
# Evaluation # Evaluation
if training_args.do_eval: if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
metrics.pop("eval_loss", None)
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
@ -126,8 +124,6 @@ def run_sft(
if training_args.do_predict: if training_args.do_predict:
logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.") logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs) predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics) trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens) trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
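Both removed workarounds are obsolete for the same reason: since the trainer now pops labels from inputs before generating (see the trainer hunk above), the underlying prediction step has no labels to score against and reports no loss, so a misleading eval_loss/predict_loss never enters the metrics in the first place. A simplified sketch of that assumed contract (a toy model, not the real Seq2SeqTrainer signature):

from typing import Optional

def prediction_step_sketch(inputs: dict, predict_with_generate: bool) -> Optional[float]:
    """Toy stand-in for Trainer.prediction_step (illustrative only)."""
    if predict_with_generate:
        inputs.pop("labels", None)  # labels withheld from the forward pass
    if "labels" not in inputs:
        return None                 # no labels -> no (wrong) loss in metrics
    return 0.42                     # placeholder teacher-forcing loss

assert prediction_step_sketch({"labels": [1]}, predict_with_generate=True) is None
assert prediction_step_sketch({"labels": [1]}, predict_with_generate=False) == 0.42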