Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 22:32:54 +08:00)
Merge pull request #6359 from hiyouga/hiyouga/fix_qwen2vl_infer
[model] fix qwen2vl infer

Former-commit-id: 81815f053f9eef23fa4906cc47496806cfc1735c
Commit 4caf043cf8
@@ -106,9 +106,15 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
             fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
             fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
             if self.tokenizer.padding_side == "right":
                 features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
                 features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
                 features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
             else:
                 features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
                 features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
                 features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]

             batch_images = fake_images
             batch_input_ids[0] = features[0]["input_ids"]
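Note (not part of the commit): the hunk above appears to attach masked dummy image tokens to the first sample when a batch carries no real images, and the branching makes that respect the tokenizer's padding side. Below is a minimal, self-contained sketch of that append/prepend logic; `attach_dummy_tokens` and the token ids are made up for illustration, only the masking idea mirrors the diff.

```python
IGNORE_INDEX = -100  # label id ignored by the loss, as in the diff above


def attach_dummy_tokens(feature: dict, fake_input_ids: list, padding_side: str = "right") -> dict:
    """Attach dummy token ids on the padding side, masked out of attention and loss."""
    zeros = [0] * len(fake_input_ids)
    ignores = [IGNORE_INDEX] * len(fake_input_ids)
    if padding_side == "right":  # pads live on the right, so append the dummy tokens there
        feature["input_ids"] = feature["input_ids"] + fake_input_ids
        feature["attention_mask"] = feature["attention_mask"] + zeros
        feature["labels"] = feature["labels"] + ignores
    else:  # left padding (the usual setting for generation), so prepend instead
        feature["input_ids"] = fake_input_ids + feature["input_ids"]
        feature["attention_mask"] = zeros + feature["attention_mask"]
        feature["labels"] = ignores + feature["labels"]
    return feature


# toy usage with made-up ids
sample = {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "labels": [5, 6, 7]}
print(attach_dummy_tokens(sample, fake_input_ids=[901, 902], padding_side="left"))
# {'input_ids': [901, 902, 5, 6, 7], 'attention_mask': [0, 0, 1, 1, 1], 'labels': [-100, -100, 5, 6, 7]}
```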
@@ -123,7 +129,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features: Dict[str, "torch.Tensor"] = super().__call__(features)

         if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
-            features["position_ids"], _ = self.model.get_rope_index(
+            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
                 input_ids=features["input_ids"],
                 image_grid_thw=mm_inputs.get("image_grid_thw", None),
                 video_grid_thw=mm_inputs.get("video_grid_thw", None),
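The one-line change above keeps the second value returned by `get_rope_index` in the batch as `rope_deltas` instead of discarding it, so the model can receive it at inference time. A rough sketch with a stand-in model, not the real Qwen2-VL class; the stub's shapes are assumptions for illustration only.

```python
import torch


class StubQwen2VL:
    """Stand-in exposing a get_rope_index with the same two-value return as the collator expects."""

    def get_rope_index(self, input_ids, image_grid_thw=None, video_grid_thw=None, attention_mask=None):
        batch_size, seq_len = input_ids.shape
        # assumed shape: one set of position ids per mrope section (temporal/height/width)
        position_ids = torch.arange(seq_len).view(1, 1, -1).expand(3, batch_size, -1)
        rope_deltas = torch.zeros(batch_size, 1, dtype=torch.long)
        return position_ids, rope_deltas


features = {"input_ids": torch.tensor([[1, 2, 3, 4]])}
model = StubQwen2VL()

if model is not None and hasattr(model, "get_rope_index"):  # same guard as the collator
    # before the fix, the second value was assigned to "_" and lost
    features["position_ids"], features["rope_deltas"] = model.get_rope_index(
        input_ids=features["input_ids"]
    )

print(features["position_ids"].shape, features["rope_deltas"])  # torch.Size([3, 1, 4]) tensor([[0]])
```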
@@ -34,7 +34,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import ProcessorMixin
+    from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.trainer import PredictionOutput

     from ...hparams import FinetuningArguments
@@ -53,6 +53,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     ) -> None:
         if is_transformers_version_greater_than("4.46"):
             kwargs["processing_class"] = kwargs.pop("tokenizer")
+        else:
+            self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")

         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
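Context for the new else branch: transformers 4.46 moved Trainer from a `tokenizer` argument to `processing_class`, so on older releases the trainer now mirrors the tokenizer onto `self.processing_class` itself and the rest of the file can use one attribute name. A hedged sketch of the same version gate written with packaging instead of LLaMA-Factory's `is_transformers_version_greater_than` helper; `normalize_trainer_kwargs` is a hypothetical name.

```python
from importlib.metadata import version

from packaging.version import Version


def normalize_trainer_kwargs(kwargs: dict) -> dict:
    """Pass the tokenizer under the keyword the installed transformers release expects."""
    if Version(version("transformers")) >= Version("4.46.0"):
        # newer releases accept processing_class and deprecate the tokenizer keyword
        kwargs["processing_class"] = kwargs.pop("tokenizer")
    # on older releases the tokenizer keyword stays, and the caller can expose it
    # as .processing_class manually, as the diff above does
    return kwargs
```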
@@ -113,7 +115,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         """
         labels = inputs["labels"] if "labels" in inputs else None
         if self.args.predict_with_generate:
-            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+            assert self.processing_class.padding_side == "left", "This method only accepts left-padded tensor."
             labels = labels.detach().clone() if labels is not None else None  # backup labels
             prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
             if prompt_len > label_len:
@@ -125,7 +127,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
         if generated_tokens is not None and self.args.predict_with_generate:
-            generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
+            generated_tokens[:, :prompt_len] = self.processing_class.pad_token_id
             generated_tokens = generated_tokens.contiguous()

         return loss, generated_tokens, labels
@@ -134,8 +136,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         r"""
         Pads the tensor to the same length as the target tensor.
         """
-        assert self.tokenizer.pad_token_id is not None, "Pad token is required."
-        padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
+        assert self.processing_class.pad_token_id is not None, "Pad token is required."
+        padded_tensor = self.processing_class.pad_token_id * torch.ones_like(tgt_tensor)
         padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
         return padded_tensor.contiguous()  # in contiguous memory
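A small runnable illustration of the padding helper touched above (only the attribute supplying `pad_token_id` changed): the source tensor is copied into the right end of a pad-filled tensor shaped like the target. `pad_to_target_len` and the pad id 0 are placeholders for this sketch.

```python
import torch


def pad_to_target_len(src: torch.Tensor, tgt: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Left-pad src to the width of tgt using pad_token_id."""
    padded = pad_token_id * torch.ones_like(tgt)
    padded[:, -src.shape[-1]:] = src  # keep the real tokens flush right (left padding)
    return padded.contiguous()


src = torch.tensor([[11, 12, 13]])
tgt = torch.zeros(1, 6, dtype=torch.long)
print(pad_to_target_len(src, tgt, pad_token_id=0))
# tensor([[ 0,  0,  0, 11, 12, 13]])
```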
@@ -152,20 +154,22 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

         labels = np.where(
-            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
+            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
         )
         preds = np.where(
-            predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
+            predict_results.predictions != IGNORE_INDEX,
+            predict_results.predictions,
+            self.processing_class.pad_token_id,
         )

         for i in range(len(preds)):
-            pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
+            pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
             if len(pad_len):  # move pad token to last
                 preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

-        decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
-        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
-        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
+        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=True)

         with open(output_prediction_file, "w", encoding="utf-8") as f:
             for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
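The "move pad token to last" loop above rotates each prediction so that generated tokens come first and the padding trails. A short numpy-only sketch with a made-up pad id (the real code reads it from `processing_class`):

```python
import numpy as np

PAD_ID = 0  # hypothetical pad token id


def move_pads_to_end(seq: np.ndarray) -> np.ndarray:
    """Rotate a left-padded sequence so its pad tokens end up on the right."""
    non_pad = np.nonzero(seq != PAD_ID)[0]  # indices of the real tokens
    if len(non_pad):
        return np.concatenate((seq[non_pad[0]:], seq[:non_pad[0]]), axis=-1)
    return seq  # all padding, nothing to rotate


print(move_pads_to_end(np.array([0, 0, 7, 8, 9])))  # [7 8 9 0 0]
```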
@@ -56,7 +56,7 @@ def run_sft(

     data_collator = SFTDataCollatorWith4DAttentionMask(
         template=template,
-        model=model,
+        model=model if not training_args.predict_with_generate else None,
         pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         block_diag_attn=model_args.block_diag_attn,
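A possible reading of the `model=None` change (not stated in the commit message): the multimodal collator only precomputes mrope position ids when it holds a model (see the `get_rope_index` guard in the collator hunk above), so withholding the model under `predict_with_generate` leaves that work to the generation path. A toy sketch of the same gating with stand-in classes:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyCollator:
    """Stand-in collator: it only precomputes rope indices when it holds a model."""

    model: Optional[object] = None

    def __call__(self, features: dict) -> dict:
        if self.model is not None and hasattr(self.model, "get_rope_index"):
            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(features["input_ids"])
        return features


predict_with_generate = True
collator = ToyCollator(model=None if predict_with_generate else object())
print(collator({"input_ids": [1, 2, 3]}))  # no position_ids injected on the generate path
```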