From 51a0016873437f1dfb1aa27c9ffe3cad99659b9b Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Fri, 30 Aug 2024 23:08:45 +0800
Subject: [PATCH] optimize predict vram

Former-commit-id: a244f143f48a01910ce1cd56c0855ef11d62a72a
---
 data/README.md                        |  2 +-
 data/README_zh.md                     |  2 +-
 src/llamafactory/train/rm/trainer.py  |  4 ++--
 src/llamafactory/train/sft/metric.py  |  2 +-
 src/llamafactory/train/sft/trainer.py | 10 +++++-----
 5 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/data/README.md b/data/README.md
index 5a34bcbe..7e030ac1 100644
--- a/data/README.md
+++ b/data/README.md
@@ -172,7 +172,7 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
 
 - [Example dataset](mllm_demo.json)
 
-Multimodal datasets require a `images` column containing the paths to the input images. Currently we only support one image.
+Multimodal datasets require a `images` column containing the paths to the input images.
 
 ```json
 [
diff --git a/data/README_zh.md b/data/README_zh.md
index 7456ed1d..cd0a4b0e 100644
--- a/data/README_zh.md
+++ b/data/README_zh.md
@@ -172,7 +172,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人类反馈。
 
 - [样例数据集](mllm_demo.json)
 
-多模态数据集需要额外添加一个 `images` 列,包含输入图像的路径。目前我们仅支持单张图像输入。
+多模态数据集需要额外添加一个 `images` 列,包含输入图像的路径。
 
 ```json
 [
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index 45d9e26b..9ceebbe8 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -75,8 +75,8 @@ class PairwiseTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)
 
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
 
diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py
index 69327379..47657b75 100644
--- a/src/llamafactory/train/sft/metric.py
+++ b/src/llamafactory/train/sft/metric.py
@@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
     if logits.dim() != 3:
         raise ValueError("Cannot process the logits.")
 
-    return torch.argmax(logits, dim=-1)
+    return torch.argmax(logits, dim=-1).cpu()
 
 
 @dataclass
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index e4958aa2..d08b2eda 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -78,16 +78,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def prediction_step(
         self,
         model: "torch.nn.Module",
-        inputs: Dict[str, Union[torch.Tensor, Any]],
+        inputs: Dict[str, Union["torch.Tensor", Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
-    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Removes the prompt part in the generated tokens.
 
         Subclass and override to inject custom behavior.
""" - labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels + labels = inputs["labels"].detach().clone().cpu() if "labels" in inputs else None # backup labels (d2h) if self.args.predict_with_generate: assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) @@ -101,11 +101,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): ) if generated_tokens is not None and self.args.predict_with_generate: generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id - generated_tokens = generated_tokens.contiguous() + generated_tokens = generated_tokens.contiguous().cpu() # d2h return loss, generated_tokens, labels - def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor: + def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor": r""" Pads the tensor to the same length as the target tensor. """