mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 22:32:54 +08:00)
adapt to new mllm_param
Former-commit-id: 291384dea8a5c10f0358a30d124eaf85557548eb
parent d5b18ee4a6
commit 08e8499a98
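In outline: the multimodal data collator drops its local `pad` helper in favor of calling `torch.nn.utils.rnn.pad_sequence` directly, and the minicpmv-specific freezing branch in `get_forbidden_modules` is replaced by a `_register_composite_model` entry listing the model's vision-side modules ("vpm", "apm", "resampler", "tts").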
@@ -37,10 +37,6 @@ if TYPE_CHECKING:
     from .template import Template


-def pad(seq, padding_value=0):
-    return pad_sequence(seq, batch_first=True, padding_value=padding_value)
-
-
 def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
     r"""
     Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
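The deleted `pad` was a thin wrapper over `torch.nn.utils.rnn.pad_sequence`, which the call sites below now use directly. A minimal sketch of the behavior the wrapper delegated to, on hypothetical toy sequences:

import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical ragged batch: three sequences of different lengths.
seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]

# pad_sequence stacks them into one (batch_size, max_len) tensor,
# filling shorter sequences with padding_value on the right.
padded = pad_sequence(seqs, batch_first=True, padding_value=0)
print(padded)
# tensor([[1, 2, 3],
#         [4, 5, 0],
#         [6, 0, 0]])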
@@ -159,14 +155,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

         if "image_bound" in features:  # for minicpmv inputs
             features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
-            features["input_ids"] = pad(
-                features["input_ids"],
-            )
-            features["position_ids"] = pad(features["position_ids"])
-            features["labels"] = pad(features["labels"], padding_value=-100)
-            features["attention_mask"] = pad(
-                features["attention_mask"],
-            )
+            features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0)
+            features["labels"] = pad_sequence(features["labels"], batch_first=True, padding_value=-100)
+            features["attention_mask"] = pad_sequence(features["attention_mask"], batch_first=True, padding_value=0)
             new_features = {}
             new_features.update({"data": features})
             new_features.update(features)
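For minicpmv inputs, the collator builds one `position_ids` tensor per sample and then pads `position_ids`, `labels`, and `attention_mask` to the batch maximum; labels are padded with -100 so the loss ignores those positions, and the padded batch is additionally exposed under a "data" key. A runnable sketch of that padding step, with hypothetical two-sample features:

import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical unpadded features mimicking two samples of different lengths.
features = {
    "input_ids": [torch.tensor([101, 7, 8, 102]), torch.tensor([101, 9, 102])],
    "labels": [torch.tensor([-100, 7, 8, -100]), torch.tensor([-100, 9, -100])],
    "attention_mask": [torch.ones(4, dtype=torch.long), torch.ones(3, dtype=torch.long)],
}

# One position id per token, per sample, before padding.
features["position_ids"] = [torch.arange(ids.size(0)).long() for ids in features["input_ids"]]

# Pad each field to the batch maximum; -100 on labels is skipped by the loss.
features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0)
features["labels"] = pad_sequence(features["labels"], batch_first=True, padding_value=-100)
features["attention_mask"] = pad_sequence(features["attention_mask"], batch_first=True, padding_value=0)

print(features["position_ids"].shape)  # torch.Size([2, 4])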
@@ -171,13 +171,6 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments")
         logger.info_rank0(f"Set language model not trainable: {language_model_keys}.")
         forbidden_modules.update(language_model_keys)

-    elif model_type == "minicpmv":
-        if finetuning_args.freeze_vision_tower:
-            forbidden_modules.add("vpm")
-            forbidden_modules.add("apm")
-            forbidden_modules.add("resampler")
-            forbidden_modules.add("tts")
-
     return forbidden_modules


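This hunk removes the hard-coded minicpmv branch; the same information now lives in the composite-model registry added below, so freezing can be driven by registered keys rather than per-model conditionals. A minimal sketch of that lookup style — `COMPOSITE_MODELS`, the paligemma entry, and the function name are hypothetical stand-ins, not the project's actual structures:

from typing import Set

# Hypothetical registry mapping model_type to its vision-side module names,
# mirroring what _register_composite_model appears to record.
COMPOSITE_MODELS = {
    "minicpmv": {"vision_model_keys": ["vpm", "apm", "resampler", "tts"]},
    "paligemma": {"vision_model_keys": ["vision_tower"]},  # assumed key name
}

def get_forbidden_modules_sketch(model_type: str, freeze_vision_tower: bool) -> Set[str]:
    # Derive the frozen modules from the registry instead of branching per model.
    forbidden: Set[str] = set()
    if freeze_vision_tower and model_type in COMPOSITE_MODELS:
        forbidden.update(COMPOSITE_MODELS[model_type]["vision_model_keys"])
    return forbidden

print(get_forbidden_modules_sketch("minicpmv", True))
# {'vpm', 'apm', 'resampler', 'tts'} (set order may vary)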
@@ -257,6 +250,12 @@ _register_composite_model(
 )


+_register_composite_model(
+    model_type="minicpmv",
+    vision_model_keys=["vpm", "apm", "resampler", "tts"],
+)
+
+
 _register_composite_model(
     model_type="paligemma",
 )
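With minicpmv registered, its `vision_model_keys` ("vpm", "apm", "resampler", "tts") name the submodules to freeze when the vision tower is frozen. A self-contained sketch of how such keys translate into frozen parameters — `TinyMiniCPMV` is a toy stand-in, not the real model:

import torch.nn as nn

# Toy composite model using the module names minicpmv registers.
class TinyMiniCPMV(nn.Module):
    def __init__(self):
        super().__init__()
        self.vpm = nn.Linear(4, 4)        # vision encoder
        self.apm = nn.Linear(4, 4)        # audio encoder
        self.resampler = nn.Linear(4, 4)  # token resampler
        self.tts = nn.Linear(4, 4)        # speech head
        self.llm = nn.Linear(4, 4)        # language model, stays trainable

model = TinyMiniCPMV()
frozen_keys = ["vpm", "apm", "resampler", "tts"]

# Freeze every parameter that lives under a registered vision-side key.
for name, param in model.named_parameters():
    if any(name.startswith(key + ".") for key in frozen_keys):
        param.requires_grad_(False)

print([n for n, p in model.named_parameters() if p.requires_grad])
# ['llm.weight', 'llm.bias']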
|
Loading…
x
Reference in New Issue
Block a user