mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
Merge pull request #6512 from hiyouga/hiyouga/fix_gen_logic
[trainer] fix generate logic Former-commit-id: 72d86ecc9e327933a0a2c893b8ffd2740c99be6b
This commit is contained in:
commit
b921dde749
@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
|
||||
|
||||
from .processors.feedback import preprocess_feedback_dataset
|
||||
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
|
||||
from .processors.pretrain import preprocess_pretrain_dataset
|
||||
from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
|
||||
from .processors.supervised import (
|
||||
preprocess_packed_supervised_dataset,
|
||||
preprocess_supervised_dataset,
|
||||
@ -47,7 +47,7 @@ def get_preprocess_and_print_func(
|
||||
tokenizer=tokenizer,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "sft" and not do_generate:
|
||||
if data_args.packing:
|
||||
if data_args.neat_packing: # hack datasets to have int32 attention mask
|
||||
|
@ -52,3 +52,8 @@ def preprocess_pretrain_dataset(
|
||||
result["input_ids"][i][0] = tokenizer.bos_token_id
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
@ -100,3 +100,5 @@ def preprocess_unsupervised_dataset(
|
||||
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))
|
||||
|
@ -111,40 +111,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
inputs: Dict[str, Union["torch.Tensor", Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
**gen_kwargs,
|
||||
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
|
||||
r"""
|
||||
Removes the prompt part in the generated tokens.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
labels = inputs["labels"] if "labels" in inputs else None
|
||||
if self.args.predict_with_generate:
|
||||
assert self.processing_class.padding_side == "left", "This method only accepts left-padded tensor."
|
||||
labels = labels.detach().clone() if labels is not None else None # backup labels
|
||||
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
|
||||
if prompt_len > label_len:
|
||||
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
|
||||
if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
|
||||
inputs["labels"] = inputs["labels"][:, :prompt_len]
|
||||
if self.args.predict_with_generate: # do not pass labels to model when generate
|
||||
labels = inputs.pop("labels", None)
|
||||
else:
|
||||
labels = inputs.get("labels")
|
||||
|
||||
loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
|
||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
||||
loss, generated_tokens, _ = super().prediction_step(
|
||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
|
||||
)
|
||||
if generated_tokens is not None and self.args.predict_with_generate:
|
||||
generated_tokens[:, :prompt_len] = self.processing_class.pad_token_id
|
||||
generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
|
||||
generated_tokens = generated_tokens.contiguous()
|
||||
|
||||
return loss, generated_tokens, labels
|
||||
|
||||
def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
|
||||
r"""
|
||||
Pads the tensor to the same length as the target tensor.
|
||||
"""
|
||||
assert self.processing_class.pad_token_id is not None, "Pad token is required."
|
||||
padded_tensor = self.processing_class.pad_token_id * torch.ones_like(tgt_tensor)
|
||||
padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
|
||||
return padded_tensor.contiguous() # in contiguous memory
|
||||
|
||||
def save_predictions(
|
||||
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
|
||||
) -> None:
|
||||
|
@ -117,8 +117,6 @@ def run_sft(
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
|
||||
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
|
||||
metrics.pop("eval_loss", None)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
@ -126,8 +124,6 @@ def run_sft(
|
||||
if training_args.do_predict:
|
||||
logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
|
||||
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
|
||||
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||
predict_results.metrics.pop("predict_loss", None)
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
|
||||
|
Loading…
x
Reference in New Issue
Block a user