optimize predict vram

Author: hiyouga
Date: 2024-08-30 23:08:45 +08:00
Commit: 51a0016873 (parent: c883542583)
Former-commit-id: a244f143f48a01910ce1cd56c0855ef11d62a72a

5 changed files with 10 additions and 10 deletions


@@ -172,7 +172,7 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
 - [Example dataset](mllm_demo.json)
 
-Multimodal datasets require a `images` column containing the paths to the input images. Currently we only support one image.
+Multimodal datasets require a `images` column containing the paths to the input images.
 
 ```json
 [
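The JSON example that follows this hunk is truncated in the page. For orientation only, here is a minimal sketch of an entry with an `images` column, written as a Python literal so it can carry comments; the layout assumes the `messages`/`images` shape that `mllm_demo.json` appears to use, and every path and message below is hypothetical, not quoted from the repo:

```python
# Hypothetical dataset entry for illustration; field values are made up.
sample = [
    {
        "messages": [
            {"role": "user", "content": "<image>Who are they?"},
            {"role": "assistant", "content": "Two players from the demo image."},
        ],
        "images": ["mllm_demo_data/1.jpg"],  # one path per input image
    }
]
```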


@@ -172,7 +172,7 @@ KTO datasets additionally require a `kto_tag` column containing bool-typed human feedback
 - [Example dataset](mllm_demo.json)
 
-Multimodal datasets additionally require an `images` column containing the paths to the input images. Currently we only support single-image input.
+Multimodal datasets additionally require an `images` column containing the paths to the input images.
 
 ```json
 [


@@ -75,8 +75,8 @@ class PairwiseTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)
 
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
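The only change in this hunk is quoting the `torch.Tensor` annotations. Quoted annotations are plain strings at runtime and are evaluated lazily, so they type-check identically while allowing the import they reference to sit behind `TYPE_CHECKING`. A minimal sketch of that pattern, with an illustrative function name not taken from the repo:

```python
from typing import TYPE_CHECKING, Dict

if TYPE_CHECKING:  # seen by type checkers only, never executed at runtime
    import torch


def summarize(inputs: Dict[str, "torch.Tensor"]) -> "torch.Tensor":
    # The quoted annotations are never evaluated here, so this module
    # imports cleanly even if torch is unavailable at runtime.
    return inputs["scores"]
```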


@@ -54,7 +54,7 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
     if logits.dim() != 3:
         raise ValueError("Cannot process the logits.")
 
-    return torch.argmax(logits, dim=-1)
+    return torch.argmax(logits, dim=-1).cpu()
 
 
 @dataclass
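Returning the argmax on CPU matters because the evaluation loop holds every batch's processed logits until metrics are computed at the end; left on the GPU, that backlog grows VRAM usage with the size of the eval set. An illustrative loop showing the effect (this is a sketch, not the Trainer's actual accumulation code):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

all_preds = []
for _ in range(100):  # stand-in for evaluation batches
    logits = torch.randn(4, 32, 1000, device=device)  # (batch, seq, vocab)
    preds = torch.argmax(logits, dim=-1)  # (batch, seq), still on `device`
    all_preds.append(preds.cpu())  # d2h copy: the backlog lives in RAM;
    # appending `preds` directly would instead keep 100 tensors in VRAM
```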


@@ -78,16 +78,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def prediction_step(
         self,
         model: "torch.nn.Module",
-        inputs: Dict[str, Union[torch.Tensor, Any]],
+        inputs: Dict[str, Union["torch.Tensor", Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
-    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Removes the prompt part in the generated tokens.
 
         Subclass and override to inject custom behavior.
         """
-        labels = inputs["labels"].detach().clone() if "labels" in inputs else None  # backup labels
+        labels = inputs["labels"].detach().clone().cpu() if "labels" in inputs else None  # backup labels (d2h)
         if self.args.predict_with_generate:
             assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
             prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)

@@ -101,11 +101,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         )
 
         if generated_tokens is not None and self.args.predict_with_generate:
             generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
-            generated_tokens = generated_tokens.contiguous()
+            generated_tokens = generated_tokens.contiguous().cpu()  # d2h
 
         return loss, generated_tokens, labels
 
-    def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
+    def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
         r"""
         Pads the tensor to the same length as the target tensor.
         """