Former-commit-id: 142191e4664cb1b920aff2f51d1bac6180f2c24b
This commit is contained in:
hiyouga 2024-12-17 10:06:46 +00:00
parent 0f49e9cb07
commit 50ca43c3fb
3 changed files with 26 additions and 16 deletions

View File

@ -106,9 +106,15 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
if self.tokenizer.padding_side == "right":
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
else:
features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
batch_images = fake_images
batch_input_ids[0] = features[0]["input_ids"]
@ -123,7 +129,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features: Dict[str, "torch.Tensor"] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
features["position_ids"], _ = self.model.get_rope_index(
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
input_ids=features["input_ids"],
image_grid_thw=mm_inputs.get("image_grid_thw", None),
video_grid_thw=mm_inputs.get("video_grid_thw", None),

View File

@ -34,7 +34,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from torch.utils.data import Dataset
from transformers import ProcessorMixin
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments
@ -53,6 +53,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
@ -113,7 +115,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
"""
labels = inputs["labels"] if "labels" in inputs else None
if self.args.predict_with_generate:
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
assert self.processing_class.padding_side == "left", "This method only accepts left-padded tensor."
labels = labels.detach().clone() if labels is not None else None # backup labels
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if prompt_len > label_len:
@ -125,7 +127,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
if generated_tokens is not None and self.args.predict_with_generate:
generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
generated_tokens[:, :prompt_len] = self.processing_class.pad_token_id
generated_tokens = generated_tokens.contiguous()
return loss, generated_tokens, labels
@ -134,8 +136,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
r"""
Pads the tensor to the same length as the target tensor.
"""
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
assert self.processing_class.pad_token_id is not None, "Pad token is required."
padded_tensor = self.processing_class.pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory
@ -152,20 +154,22 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
labels = np.where(
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
)
preds = np.where(
predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
predict_results.predictions != IGNORE_INDEX,
predict_results.predictions,
self.processing_class.pad_token_id,
)
for i in range(len(preds)):
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
if len(pad_len): # move pad token to last
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=True)
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=True)
with open(output_prediction_file, "w", encoding="utf-8") as f:
for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):

View File

@ -56,7 +56,7 @@ def run_sft(
data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
model=model,
model=model if not training_args.predict_with_generate else None,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,