Former-commit-id: 142191e4664cb1b920aff2f51d1bac6180f2c24b
hiyouga 2024-12-17 10:06:46 +00:00
parent 0f49e9cb07
commit 50ca43c3fb
3 changed files with 26 additions and 16 deletions

View File

@@ -106,9 +106,15 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
 fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
 fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
-features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
-features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
-features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
+if self.tokenizer.padding_side == "right":
+    features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
+    features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
+    features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
+else:
+    features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
+    features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
+    features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
 batch_images = fake_images
 batch_input_ids[0] = features[0]["input_ids"]
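Note on the hunk above: the dummy image tokens are now placed on whichever side the tokenizer pads, so the masked span stays adjacent to the padding. A minimal runnable sketch of that branch, using plain Python lists in place of the real feature dicts (append_fake_tokens, pad_side and fake_ids are illustrative names, not from the repository):

    IGNORE_INDEX = -100  # label value excluded from the loss

    def append_fake_tokens(feature: dict, fake_ids: list, pad_side: str) -> dict:
        # Keep the masked dummy span next to where pad tokens will later be added:
        # appended on the right for right padding, prepended for left padding.
        mask = [0] * len(fake_ids)
        ignored = [IGNORE_INDEX] * len(fake_ids)
        if pad_side == "right":
            feature["input_ids"] = feature["input_ids"] + fake_ids
            feature["attention_mask"] = feature["attention_mask"] + mask
            feature["labels"] = feature["labels"] + ignored
        else:  # left padding, as used for generation
            feature["input_ids"] = fake_ids + feature["input_ids"]
            feature["attention_mask"] = mask + feature["attention_mask"]
            feature["labels"] = ignored + feature["labels"]
        return feature

    example = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]}
    print(append_fake_tokens(example, fake_ids=[9, 9], pad_side="left"))
    # {'input_ids': [9, 9, 1, 2, 3], 'attention_mask': [0, 0, 1, 1, 1], 'labels': [-100, -100, 1, 2, 3]}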
@@ -123,7 +129,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 features: Dict[str, "torch.Tensor"] = super().__call__(features)
 if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
-    features["position_ids"], _ = self.model.get_rope_index(
+    features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
         input_ids=features["input_ids"],
         image_grid_thw=mm_inputs.get("image_grid_thw", None),
         video_grid_thw=mm_inputs.get("video_grid_thw", None),
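The right-hand side now also keeps the second value returned by get_rope_index under features["rope_deltas"] instead of discarding it. A duck-typed sketch of that collator step, assuming only a model object whose get_rope_index returns a (position_ids, rope_deltas) pair in the Qwen2-VL style; ToyMRopeModel and add_mrope_inputs are stand-ins for illustration, not repository code:

    from typing import Dict, Optional, Tuple
    import torch

    class ToyMRopeModel:
        # Stand-in for a model exposing a Qwen2-VL style get_rope_index.
        def get_rope_index(
            self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            batch_size, seq_len = input_ids.shape
            position_ids = torch.arange(seq_len).expand(3, batch_size, seq_len)  # (temporal, height, width)
            rope_deltas = torch.zeros(batch_size, 1, dtype=torch.long)
            return position_ids, rope_deltas

    def add_mrope_inputs(model, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        if model is not None and hasattr(model, "get_rope_index"):
            # Store rope_deltas next to position_ids so the forward/generate pass
            # can reuse it instead of recomputing it from the image/video grids.
            features["position_ids"], features["rope_deltas"] = model.get_rope_index(
                input_ids=features["input_ids"], attention_mask=features["attention_mask"]
            )
        return features

    batch = {"input_ids": torch.ones(2, 5, dtype=torch.long), "attention_mask": torch.ones(2, 5, dtype=torch.long)}
    print(sorted(add_mrope_inputs(ToyMRopeModel(), batch).keys()))
    # ['attention_mask', 'input_ids', 'position_ids', 'rope_deltas']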

View File

@@ -34,7 +34,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import ProcessorMixin
+    from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.trainer import PredictionOutput

     from ...hparams import FinetuningArguments
@@ -53,6 +53,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     ) -> None:
         if is_transformers_version_greater_than("4.46"):
             kwargs["processing_class"] = kwargs.pop("tokenizer")
+        else:
+            self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")

         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
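Background for the hunk above: transformers 4.46 renamed the Trainer's tokenizer argument to processing_class, so on newer releases the kwarg is re-routed and on older ones the tokenizer is stored under self.processing_class by hand, letting the rest of the class use a single attribute. A standalone sketch of the same routing, checked with packaging instead of the repository's is_transformers_version_greater_than helper (split_trainer_kwargs is a hypothetical name):

    import transformers
    from packaging import version

    def split_trainer_kwargs(tokenizer, **kwargs) -> dict:
        # Hand the tokenizer to the Trainer under whichever keyword the installed
        # transformers release expects.
        if version.parse(transformers.__version__) >= version.parse("4.46.0"):
            kwargs["processing_class"] = tokenizer  # new keyword since 4.46
        else:
            kwargs["tokenizer"] = tokenizer  # legacy keyword, deprecated later
        return kwargs

    # Usage (illustrative): Seq2SeqTrainer(model=model, args=args, **split_trainer_kwargs(tok, data_collator=collator))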
@@ -113,7 +115,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         """
         labels = inputs["labels"] if "labels" in inputs else None
         if self.args.predict_with_generate:
-            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+            assert self.processing_class.padding_side == "left", "This method only accepts left-padded tensor."
             labels = labels.detach().clone() if labels is not None else None  # backup labels
             prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
             if prompt_len > label_len:
@@ -125,7 +127,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
         if generated_tokens is not None and self.args.predict_with_generate:
-            generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
+            generated_tokens[:, :prompt_len] = self.processing_class.pad_token_id
             generated_tokens = generated_tokens.contiguous()

         return loss, generated_tokens, labels
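In this prediction step the batch is left-padded and generate() echoes the prompt before the new tokens, so the prompt prefix is blanked out with the pad id before decoding. A small self-contained sketch of that masking, with a made-up pad id of 0 (strip_prompt is an illustrative helper, not from the repository):

    import torch

    def strip_prompt(generated_tokens: torch.Tensor, prompt_len: int, pad_token_id: int) -> torch.Tensor:
        # generate() returns [prompt | new tokens]; overwrite the prompt part with
        # pad ids so only freshly generated tokens survive decoding.
        generated_tokens = generated_tokens.clone()
        generated_tokens[:, :prompt_len] = pad_token_id
        return generated_tokens.contiguous()

    gen = torch.tensor([[11, 12, 13, 21, 22], [11, 12, 13, 31, 32]])
    print(strip_prompt(gen, prompt_len=3, pad_token_id=0))
    # tensor([[ 0,  0,  0, 21, 22],
    #         [ 0,  0,  0, 31, 32]])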
@@ -134,8 +136,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         r"""
         Pads the tensor to the same length as the target tensor.
         """
-        assert self.tokenizer.pad_token_id is not None, "Pad token is required."
-        padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
+        assert self.processing_class.pad_token_id is not None, "Pad token is required."
+        padded_tensor = self.processing_class.pad_token_id * torch.ones_like(tgt_tensor)
         padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
         return padded_tensor.contiguous()  # in contiguous memory
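The helper above fills a target-shaped tensor with pad ids and writes the source into its rightmost columns, i.e. left padding. A self-contained sketch with the pad id passed explicitly rather than read from the tokenizer (pad_tensors_to_target_len here mirrors the method's behaviour but is not the repository's code):

    import torch

    def pad_tensors_to_target_len(src_tensor: torch.Tensor, tgt_tensor: torch.Tensor, pad_token_id: int) -> torch.Tensor:
        # Allocate a pad-filled tensor shaped like the target, then copy the source
        # into the last columns so the padding ends up on the left.
        padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
        padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor
        return padded_tensor.contiguous()

    src = torch.tensor([[5, 6], [7, 8]])
    tgt = torch.zeros(2, 4, dtype=torch.long)
    print(pad_tensors_to_target_len(src, tgt, pad_token_id=0))
    # tensor([[0, 0, 5, 6],
    #         [0, 0, 7, 8]])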
@@ -152,20 +154,22 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

         labels = np.where(
-            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
+            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
         )
         preds = np.where(
-            predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
+            predict_results.predictions != IGNORE_INDEX,
+            predict_results.predictions,
+            self.processing_class.pad_token_id,
         )

         for i in range(len(preds)):
-            pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
+            pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
             if len(pad_len):  # move pad token to last
                 preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

-        decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
-        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
-        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
+        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=True)

         with open(output_prediction_file, "w", encoding="utf-8") as f:
             for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
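The loop above rotates each left-padded prediction so its pad prefix moves to the tail and decoding starts at the first real token. A numpy sketch of that rotation alone, with an illustrative pad id (move_padding_to_end is not a repository function):

    import numpy as np

    PAD_TOKEN_ID = 0  # illustrative; the real value comes from the tokenizer

    def move_padding_to_end(pred: np.ndarray) -> np.ndarray:
        non_pad = np.nonzero(pred != PAD_TOKEN_ID)[0]  # indices of real tokens
        if len(non_pad):  # rotate: real tokens first, former pad prefix last
            pred = np.concatenate((pred[non_pad[0]:], pred[:non_pad[0]]), axis=-1)
        return pred

    print(move_padding_to_end(np.array([0, 0, 42, 43, 44])))
    # [42 43 44  0  0]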

View File

@@ -56,7 +56,7 @@ def run_sft(
     data_collator = SFTDataCollatorWith4DAttentionMask(
         template=template,
-        model=model,
+        model=model if not training_args.predict_with_generate else None,
         pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         block_diag_attn=model_args.block_diag_attn,
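Reading of the last hunk: when predict_with_generate is set the collator is built without the model, which, given the is-not-None check in the first file, skips the collator-side mrope position-id precomputation and leaves position handling to the generation path. A minimal sketch of the conditional wiring; ToyCollator and build_collator stand in for SFTDataCollatorWith4DAttentionMask and the surrounding workflow code:

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class ToyCollator:
        model: Optional[Any] = None
        pad_to_multiple_of: Optional[int] = None

    def build_collator(model: Any, do_train: bool, predict_with_generate: bool) -> ToyCollator:
        return ToyCollator(
            # Drop the model reference for generation batches so the collator's
            # model-dependent steps (e.g. rope-index precomputation) are skipped.
            model=model if not predict_with_generate else None,
            pad_to_multiple_of=8 if do_train else None,
        )

    print(build_collator(model=object(), do_train=False, predict_with_generate=True))
    # ToyCollator(model=None, pad_to_multiple_of=None)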