Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 22:32:54 +08:00)

Commit: fd79cf8551
Parent: 66e473d519
Former-commit-id: 3d3cc6705d4575f7f20bf4da2b7dab60b337006b
Message: tiny fix
@@ -183,6 +183,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [PaliGemma](https://huggingface.co/google)                   | 3B                               | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft)            | 1.3B/2.7B                        | -         |
 | [Phi-3](https://huggingface.co/microsoft)                    | 4B/7B/14B                        | phi       |
+| [Pixtral](https://huggingface.co/mistralai/Pixtral-12B-2409) | 12B                              | pixtral   |
 | [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen)  | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen      |
 | [Qwen2-VL](https://huggingface.co/Qwen)                      | 2B/7B/72B                        | qwen2_vl  |
 | [StarCoder 2](https://huggingface.co/bigcode)                | 3B/7B/15B                        | -         |
@@ -184,6 +184,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [PaliGemma](https://huggingface.co/google)                   | 3B                               | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft)            | 1.3B/2.7B                        | -         |
 | [Phi-3](https://huggingface.co/microsoft)                    | 4B/7B/14B                        | phi       |
+| [Pixtral](https://huggingface.co/mistralai/Pixtral-12B-2409) | 12B                              | pixtral   |
 | [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen)  | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen      |
 | [Qwen2-VL](https://huggingface.co/Qwen)                      | 2B/7B/72B                        | qwen2_vl  |
 | [StarCoder 2](https://huggingface.co/bigcode)                | 3B/7B/15B                        | -         |
@@ -325,14 +325,6 @@ class PaliGemmaPlugin(BasePlugin):
         return mm_inputs
 
 class PixtralPlugin(BasePlugin):
-    # @override
-    # def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
-    #     image = super()._preprocess_image(image, **kwargs)
-    #     UP_SIZE = (512,512)
-    #     image = image.resize(UP_SIZE, resample=Image.NEAREST)
-
-    #     return image
-
     @override
     def process_messages(
         self,
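The commented-out `_preprocess_image` override deleted in this hunk upscaled every image to a fixed 512x512 resolution before encoding. For reference, a standalone sketch of that behavior (a hypothetical helper, not part of the plugin) would be:

from PIL import Image

def upscale_to_fixed(image: Image.Image, size: tuple = (512, 512)) -> Image.Image:
    """Resize an image to a fixed resolution with nearest-neighbor resampling,
    mirroring the commented-out override removed above."""
    return image.resize(size, resample=Image.NEAREST)

# Usage sketch: img = upscale_to_fixed(Image.open("example.jpg"))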
@@ -396,16 +388,15 @@ class PixtralPlugin(BasePlugin):
         seqlens: Sequence[int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 
         self._validate_input(images, videos)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
-        if mm_inputs.get('image_sizes'):
-            del mm_inputs['image_sizes']
-        # TODO fix this type error
-        # if isinstance(mm_inputs.get("pixel_values"), list): #List[List[torch.tensor]] -> [B C W H]
-        # recommended for batch==1 per GPU, otherwise it raises a BatchEncoding error.
+        if mm_inputs.get("image_sizes"):
+            mm_inputs.pop("image_sizes")
+        if isinstance(mm_inputs.get("pixel_values"), list) and len(mm_inputs.get("pixel_values")[0]) >= 2:
+            raise ValueError("Only batch size 1 per GPU is supported for now, since List[tensor] cannot be packed into a BatchEncoding.")
 
         mm_inputs["pixel_values"] = mm_inputs.get("pixel_values")[0][0].unsqueeze(0)
-        # mm_inputs["pixel_values"] = mm_inputs.get("pixel_values")
 
         return mm_inputs
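The new guard exists because the Pixtral processor can return `pixel_values` as a list of per-sample lists of tensors with differing shapes; such a structure cannot be collated into a single batched tensor, so only one image per step is accepted and its tensor is unsqueezed into a batch of one. A minimal sketch of the shape handling (illustrative only, with made-up dimensions):

import torch

# Illustrative structure: one sample containing one image tensor of shape (C, H, W).
pixel_values = [[torch.randn(3, 512, 512)]]  # List[List[Tensor]], as the processor may return

if len(pixel_values[0]) >= 2:
    # Two or more variably sized images cannot be stacked into one tensor.
    raise ValueError("more than one image per sample is not supported here")

batched = pixel_values[0][0].unsqueeze(0)  # -> shape (1, 3, 512, 512)
print(batched.shape)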
@@ -1060,8 +1060,8 @@ register_model_group(
 register_model_group(
     models={
         "Pixtral-12B-2409": {
-            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
+            DownloadSource.DEFAULT: "mistralai/Pixtral-12B-2409",
+            DownloadSource.MODELSCOPE: "LLM-Research/Pixtral-12B-2409",
         }
     },
     template="mistral",
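This hunk only swaps the download sources for the registered checkpoint: the Hugging Face id moves from the community mirror to mistralai/Pixtral-12B-2409, and the ModelScope id to LLM-Research/Pixtral-12B-2409. As an illustration of how such a registry entry can be resolved to a hub id (a sketch, not the project's actual `register_model_group` implementation):

from enum import Enum

class DownloadSource(str, Enum):
    DEFAULT = "hf"
    MODELSCOPE = "ms"

# Hypothetical, trimmed-down registry mirroring the entry in this hunk.
MODEL_REGISTRY = {
    "Pixtral-12B-2409": {
        DownloadSource.DEFAULT: "mistralai/Pixtral-12B-2409",
        DownloadSource.MODELSCOPE: "LLM-Research/Pixtral-12B-2409",
    }
}

def resolve_model_id(name: str, source: DownloadSource = DownloadSource.DEFAULT) -> str:
    """Pick the hub id for the requested download source, falling back to the default."""
    entry = MODEL_REGISTRY[name]
    return entry.get(source, entry[DownloadSource.DEFAULT])

print(resolve_model_id("Pixtral-12B-2409", DownloadSource.MODELSCOPE))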
@@ -96,6 +96,9 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments"
         mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
     elif model_type == "qwen2_vl":
         mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
+    # TODO: check it
+    elif model_type == "pixtral":
+        mm_projector: "torch.nn.Module" = getattr(model, "vision_language_adapter")
     else:
         return
 
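The added `elif` only locates the projector module for the Pixtral architecture; the surrounding function then keeps that projector's computation in the configured compute dtype. A minimal sketch of that general pattern (assuming a `compute_dtype` setting; not necessarily the exact hook used in the repository):

import torch

def cast_projector_output(module: torch.nn.Module, compute_dtype: torch.dtype) -> None:
    """Register a forward hook so the projector's output is cast to the compute dtype,
    keeping multimodal features compatible with the language model weights."""

    def _hook(_module, _inputs, output):
        return output.to(compute_dtype)

    module.register_forward_hook(_hook)

# Usage sketch: cast_projector_output(getattr(model, "multi_modal_projector"), torch.bfloat16)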
@@ -122,9 +125,11 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments"
     """
     model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
-    if model_type in ["llava", "paligemma"]:
+    if model_type in ["llava", "paligemma", "pixtral"]:
         if finetuning_args.freeze_vision_tower:
             forbidden_modules.add("vision_tower")
+            # TODO: check it
+            forbidden_modules.add("vision_encoder")
 
         if finetuning_args.train_mm_proj_only:
             forbidden_modules.add("language_model")
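`get_forbidden_modules` collects name fragments of modules that should not be trained or adapted; the new `vision_encoder` entry covers Pixtral's vision stack alongside the `vision_tower` name used by LLaVA-style models. A sketch of how such a set is typically applied when freezing parameters (illustrative, not the repository's exact logic):

from typing import Set
import torch

def freeze_forbidden(model: torch.nn.Module, forbidden_modules: Set[str]) -> None:
    """Disable gradients for every parameter whose name contains a forbidden fragment."""
    for name, param in model.named_parameters():
        if any(fragment in name for fragment in forbidden_modules):
            param.requires_grad_(False)

# Usage sketch: freeze_forbidden(model, {"vision_tower", "vision_encoder"})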
@@ -150,7 +155,7 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
         image_seqlen += 1
     elif model_type == "paligemma":
         image_seqlen = config.vision_config.num_image_tokens
-    elif model_type == "qwen2_vl":  # variable length
+    elif model_type in ["qwen2_vl", "pixtral"]:  # variable length
         image_seqlen = -1
 
     return image_seqlen
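Returning -1 is the function's sentinel for "variable length": Qwen2-VL and Pixtral emit a different number of image tokens per image, so the placeholder cannot be expanded to a fixed count ahead of time. A sketch of how a caller might branch on that sentinel (hypothetical helper and token name):

def expand_image_placeholder(prompt: str, image_seqlen: int, num_patches: int, token: str = "<image>") -> str:
    """Repeat the image token a fixed number of times, or per-image when the length is variable."""
    repeat = num_patches if image_seqlen == -1 else image_seqlen
    return prompt.replace(token, token * repeat, 1)

# Usage sketch: expand_image_placeholder("Describe <image>", image_seqlen=-1, num_patches=256)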
@@ -168,10 +173,14 @@ def patch_target_modules(
             return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
         elif model_type == "qwen2_vl":
             return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "pixtral":
+            return "^(?!.*vision_encoder).*(?:{}).*".format("|".join(target_modules))
         else:
             return target_modules
     else:
         if model_type == "qwen2_vl":
             return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "pixtral":
+            return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
         else:
             return target_modules
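The returned pattern is a negative lookahead: it matches a module name only if the name does not contain the excluded component anywhere and does contain one of the requested target modules, which keeps LoRA off the frozen vision stack. A quick check of the Pixtral variant (module names are illustrative):

import re

target_modules = ["q_proj", "v_proj"]
pattern = re.compile("^(?!.*vision_encoder).*(?:{}).*".format("|".join(target_modules)))

names = [
    "language_model.model.layers.0.self_attn.q_proj",        # matches: LoRA is applied
    "vision_encoder.transformer.layers.0.attention.q_proj",  # excluded by the lookahead
]
for name in names:
    print(name, bool(pattern.match(name)))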