From 4583c534f96c2eeb3c0cd13d803dfff507200fc6 Mon Sep 17 00:00:00 2001 From: BUAADreamer <1428195643@qq.com> Date: Mon, 13 May 2024 23:28:28 +0800 Subject: [PATCH 01/12] add yi-vl Former-commit-id: 64dac4085e3949f20ab66e507cfb199b09189ead --- src/llmtuner/data/template.py | 14 ++++++++++++++ src/llmtuner/model/patcher.py | 3 ++- src/llmtuner/model/utils/visual.py | 26 +++++++++++++++++++++++++- src/llmtuner/train/sft/trainer.py | 12 ++++++++++-- src/llmtuner/train/sft/workflow.py | 2 ++ 5 files changed, 53 insertions(+), 4 deletions(-) diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index ada6cfcd..7ab31147 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -856,6 +856,20 @@ _register_template( ) +_register_template( + name="yi-vl", + format_user=StringFormatter(slots=["### Human:\n{{content}}\n### Assistant: "]), + stop_words=["###"], + default_system=( + "This is a chat between an inquisitive human and an AI assistant. " + "Assume the role of the AI assistant. " + "Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers." + "这是一个好奇的人类和一个人工智能助手之间的对话。" + "假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。" + ), +) + + _register_template( name="yuan", format_user=StringFormatter(slots=["{{content}}", {"token": ""}]), diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index fd99bd3b..b7cad67c 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -16,7 +16,7 @@ from .utils.moe import add_z3_leaf_module, configure_moe from .utils.quantization import configure_quantization from .utils.rope import configure_rope from .utils.valuehead import prepare_valuehead_model -from .utils.visual import autocast_projector_dtype, configure_hidden_size +from .utils.visual import autocast_projector_dtype, configure_hidden_size, configure_visual if TYPE_CHECKING: @@ -50,6 +50,7 @@ def patch_config( configure_quantization(config, tokenizer, model_args, init_kwargs) configure_moe(config, model_args, is_trainable) configure_hidden_size(config) + configure_visual(config, model_args) if model_args.use_cache and not is_trainable: setattr(config, "use_cache", True) diff --git a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py index b29a9ba5..d1556bb3 100644 --- a/src/llmtuner/model/utils/visual.py +++ b/src/llmtuner/model/utils/visual.py @@ -1,12 +1,14 @@ from typing import TYPE_CHECKING, Tuple import torch +import transformers +from torch import nn from ...extras.logging import get_logger if TYPE_CHECKING: - from transformers import PretrainedConfig, PreTrainedModel + from transformers import PretrainedConfig, PreTrainedModel, LlavaConfig from ...hparams import ModelArguments @@ -31,3 +33,25 @@ def autocast_projector_dtype( logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name) mm_projector.register_forward_hook(_mm_projector_forward_post_hook) + + +class LlavaMultiModalProjectorYiVL(nn.Module): + def __init__(self, config: "LlavaConfig"): + super().__init__() + self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_2 = nn.LayerNorm(config.text_config.hidden_size, bias=True) + self.linear_3 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_4 = nn.LayerNorm(config.text_config.hidden_size, bias=True) + 
self.act = nn.GELU() + self.proj = nn.Sequential(*[self.linear_1, self.linear_2, self.act, self.linear_3, self.linear_4]) + + def forward(self, image_features): + hidden_states = self.proj(image_features) + return hidden_states + + +def configure_visual(config: "PretrainedConfig", model_args: "ModelArguments") -> None: + logger = get_logger(__name__) + if model_args.visual_inputs and "Yi" in getattr(config.text_config, "_name_or_path", None): + transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorYiVL + logger.info("Patched Multimodal Projector for Yi-VL.") diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py index def427fd..1b456e50 100644 --- a/src/llmtuner/train/sft/trainer.py +++ b/src/llmtuner/train/sft/trainer.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from transformers import Seq2SeqTrainer +from transformers import Seq2SeqTrainer, ProcessorMixin from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger @@ -26,9 +26,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE. """ - def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None: + def __init__(self, finetuning_args: "FinetuningArguments", processor: "ProcessorMixin", **kwargs) -> None: super().__init__(**kwargs) self.finetuning_args = finetuning_args + self.processor = processor if finetuning_args.use_badam: from badam import clip_grad_norm_for_sparse_tensor @@ -120,3 +121,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): for label, pred in zip(decoded_labels, decoded_preds): res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) writer.write("\n".join(res)) + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + super().save_model(output_dir, _internal_call) + if self.processor is not None: + if output_dir is None: + output_dir = self.args.output_dir + getattr(self.processor, "image_processor").save_pretrained(output_dir) \ No newline at end of file diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py index 4a9775b4..3b7b909a 100644 --- a/src/llmtuner/train/sft/workflow.py +++ b/src/llmtuner/train/sft/workflow.py @@ -30,6 +30,7 @@ def run_sft( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] + processor = tokenizer_module["processor"] dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) @@ -55,6 +56,7 @@ def run_sft( model=model, args=training_args, finetuning_args=finetuning_args, + processor=processor, tokenizer=tokenizer, data_collator=data_collator, callbacks=callbacks, From 661565fc2e8a0b64f7bc7d50f6db87e5a933810f Mon Sep 17 00:00:00 2001 From: BUAADreamer <1428195643@qq.com> Date: Tue, 14 May 2024 14:03:19 +0800 Subject: [PATCH 02/12] add support for Yi-VL Former-commit-id: ab3464ce6530830c14fde68f0a8990185db80592 --- src/llmtuner/model/utils/visual.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py index d1556bb3..79a6570e 100644 --- a/src/llmtuner/model/utils/visual.py +++ b/src/llmtuner/model/utils/visual.py @@ -43,10 +43,13 @@ class LlavaMultiModalProjectorYiVL(nn.Module): self.linear_3 = 
nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) self.linear_4 = nn.LayerNorm(config.text_config.hidden_size, bias=True) self.act = nn.GELU() - self.proj = nn.Sequential(*[self.linear_1, self.linear_2, self.act, self.linear_3, self.linear_4]) def forward(self, image_features): - hidden_states = self.proj(image_features) + hidden_states = self.linear_1(image_features) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_3(hidden_states) + hidden_states = self.linear_4(hidden_states) return hidden_states From 9e247245a2e4137394a19643df72b70e022b43d1 Mon Sep 17 00:00:00 2001 From: BUAADreamer <1428195643@qq.com> Date: Tue, 14 May 2024 16:45:28 +0800 Subject: [PATCH 03/12] modify yi-vl template Former-commit-id: d72e6f8dfd670533f3bbdf0bf5e7d596e2dd34ac --- src/llmtuner/data/template.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index 7fa2ccce..631c79c1 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -857,15 +857,16 @@ _register_template( _register_template( - name="yi-vl", - format_user=StringFormatter(slots=["### Human:\n{{content}}\n### Assistant: "]), + name="yivl", + format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]), + format_assistant=StringFormatter(slots=[" {{content}}"]), stop_words=["###"], default_system=( "This is a chat between an inquisitive human and an AI assistant. " "Assume the role of the AI assistant. " "Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers." "这是一个好奇的人类和一个人工智能助手之间的对话。" - "假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。" + "假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n" ), ) From 92b184101f337e07818a3c0844b5b7865552e029 Mon Sep 17 00:00:00 2001 From: BUAADreamer <1428195643@qq.com> Date: Wed, 15 May 2024 09:54:00 +0800 Subject: [PATCH 04/12] add yivl and save processor to model_dir Former-commit-id: afc6c7b9fd350f9f611a220363a3caa930ac56aa --- src/llmtuner/data/template.py | 2 +- src/llmtuner/model/utils/visual.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index 631c79c1..0b2ca0e6 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -859,7 +859,7 @@ _register_template( _register_template( name="yivl", format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]), - format_assistant=StringFormatter(slots=[" {{content}}"]), + format_assistant=StringFormatter(slots=[" {{content}}\n"]), stop_words=["###"], default_system=( "This is a chat between an inquisitive human and an AI assistant. 
" diff --git a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py index 79a6570e..8553cf86 100644 --- a/src/llmtuner/model/utils/visual.py +++ b/src/llmtuner/model/utils/visual.py @@ -29,7 +29,8 @@ def autocast_projector_dtype( ) -> "torch.Tensor": return output.to(model_args.compute_dtype) - if hasattr(model, mm_projector_name) and getattr(model.config, "quantization_method", None): + if hasattr(model, mm_projector_name) and (getattr(model.config, "quantization_method", None) + or "Yi" in getattr(model.config.text_config, "_name_or_path", None)): logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name) mm_projector.register_forward_hook(_mm_projector_forward_post_hook) From dbc7b1c0464049951e3a4822a1ee6fc978266866 Mon Sep 17 00:00:00 2001 From: BUAADreamer <1428195643@qq.com> Date: Wed, 15 May 2024 10:18:10 +0800 Subject: [PATCH 05/12] modify style Former-commit-id: 771bed5bde510f3893d12cafc4163409d6cb21f3 --- src/llmtuner/model/utils/visual.py | 8 +++++--- src/llmtuner/train/sft/trainer.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py index 8553cf86..0dc844f5 100644 --- a/src/llmtuner/model/utils/visual.py +++ b/src/llmtuner/model/utils/visual.py @@ -8,7 +8,7 @@ from ...extras.logging import get_logger if TYPE_CHECKING: - from transformers import PretrainedConfig, PreTrainedModel, LlavaConfig + from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel from ...hparams import ModelArguments @@ -29,8 +29,10 @@ def autocast_projector_dtype( ) -> "torch.Tensor": return output.to(model_args.compute_dtype) - if hasattr(model, mm_projector_name) and (getattr(model.config, "quantization_method", None) - or "Yi" in getattr(model.config.text_config, "_name_or_path", None)): + if hasattr(model, mm_projector_name) and ( + getattr(model.config, "quantization_method", None) + or "Yi" in getattr(model.config.text_config, "_name_or_path", None) + ): logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name) mm_projector.register_forward_hook(_mm_projector_forward_post_hook) diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py index 1b456e50..5f187375 100644 --- a/src/llmtuner/train/sft/trainer.py +++ b/src/llmtuner/train/sft/trainer.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from transformers import Seq2SeqTrainer, ProcessorMixin +from transformers import ProcessorMixin, Seq2SeqTrainer from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger @@ -127,4 +127,4 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): if self.processor is not None: if output_dir is None: output_dir = self.args.output_dir - getattr(self.processor, "image_processor").save_pretrained(output_dir) \ No newline at end of file + getattr(self.processor, "image_processor").save_pretrained(output_dir) From 3f38ef9f59ccf8d024b702187b6f3fa35d3cddbb Mon Sep 17 00:00:00 2001 From: BUAADreamer <1428195643@qq.com> Date: Wed, 15 May 2024 11:22:15 +0800 Subject: [PATCH 06/12] cast dtype in mm_proj Former-commit-id: d2bf69740043012a0025dd9d80c7adf979dc3a88 --- src/llmtuner/model/utils/visual.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git 
a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py index 0dc844f5..b8696096 100644 --- a/src/llmtuner/model/utils/visual.py +++ b/src/llmtuner/model/utils/visual.py @@ -8,7 +8,7 @@ from ...extras.logging import get_logger if TYPE_CHECKING: - from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel + from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, LlavaForConditionalGeneration from ...hparams import ModelArguments @@ -29,10 +29,7 @@ def autocast_projector_dtype( ) -> "torch.Tensor": return output.to(model_args.compute_dtype) - if hasattr(model, mm_projector_name) and ( - getattr(model.config, "quantization_method", None) - or "Yi" in getattr(model.config.text_config, "_name_or_path", None) - ): + if hasattr(model, mm_projector_name) and getattr(model.config, "quantization_method", None): logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name) mm_projector.register_forward_hook(_mm_projector_forward_post_hook) @@ -48,11 +45,13 @@ class LlavaMultiModalProjectorYiVL(nn.Module): self.act = nn.GELU() def forward(self, image_features): + dtype_ = self.linear_1.weight.dtype hidden_states = self.linear_1(image_features) hidden_states = self.linear_2(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.linear_3(hidden_states) hidden_states = self.linear_4(hidden_states) + hidden_states = hidden_states.to(dtype_) return hidden_states From e1c2ff41a04856fb97f2defaa8a1c4556ff214fb Mon Sep 17 00:00:00 2001 From: BUAADreamer <1428195643@qq.com> Date: Wed, 15 May 2024 12:48:18 +0800 Subject: [PATCH 07/12] rm extra import Former-commit-id: db1622f76b0fe9d669af206299ecec10954647af --- src/llmtuner/model/utils/visual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py index b8696096..1f770861 100644 --- a/src/llmtuner/model/utils/visual.py +++ b/src/llmtuner/model/utils/visual.py @@ -8,7 +8,7 @@ from ...extras.logging import get_logger if TYPE_CHECKING: - from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, LlavaForConditionalGeneration + from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel from ...hparams import ModelArguments From 7622300c4b13a97c0600358b9e5e8048879c7d20 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 15 May 2024 14:13:01 +0800 Subject: [PATCH 08/12] Update workflow.py Former-commit-id: c309605ff565dc34d043314269fce5881212c27c --- src/llmtuner/train/sft/workflow.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py index 3b7b909a..d9d7c8e9 100644 --- a/src/llmtuner/train/sft/workflow.py +++ b/src/llmtuner/train/sft/workflow.py @@ -30,7 +30,6 @@ def run_sft( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - processor = tokenizer_module["processor"] dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) @@ -56,11 +55,10 @@ def run_sft( model=model, args=training_args, finetuning_args=finetuning_args, - processor=processor, - tokenizer=tokenizer, data_collator=data_collator, callbacks=callbacks, compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, + **tokenizer_module, 
**split_dataset(dataset, data_args, training_args), ) From cea8cea9dd074a85f102fb7470c63f9ad43a4329 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 15 May 2024 14:13:26 +0800 Subject: [PATCH 09/12] Update trainer.py Former-commit-id: aa4a8933dd520227401b7041dae40fc6fb2ddaa2 --- src/llmtuner/train/sft/trainer.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py index 5f187375..35671e1b 100644 --- a/src/llmtuner/train/sft/trainer.py +++ b/src/llmtuner/train/sft/trainer.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from transformers import ProcessorMixin, Seq2SeqTrainer +from transformers import Seq2SeqTrainer from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger @@ -13,6 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler if TYPE_CHECKING: + from transformers import ProcessorMixin from transformers.trainer import PredictionOutput from ...hparams import FinetuningArguments @@ -26,7 +27,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE. """ - def __init__(self, finetuning_args: "FinetuningArguments", processor: "ProcessorMixin", **kwargs) -> None: + def __init__( + self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs + ) -> None: super().__init__(**kwargs) self.finetuning_args = finetuning_args self.processor = processor @@ -46,6 +49,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): create_custom_scheduler(self.args, num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer) + def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: + super()._save(output_dir, state_dict) + if self.processor is not None: + output_dir = output_dir if output_dir is not None else self.args.output_dir + getattr(self.processor, "image_processor").save_pretrained(output_dir) + def prediction_step( self, model: "torch.nn.Module", @@ -121,10 +130,3 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): for label, pred in zip(decoded_labels, decoded_preds): res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) writer.write("\n".join(res)) - - def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): - super().save_model(output_dir, _internal_call) - if self.processor is not None: - if output_dir is None: - output_dir = self.args.output_dir - getattr(self.processor, "image_processor").save_pretrained(output_dir) From 3d65c4ceabcfb9b660324e02f5f1450a729a4832 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 15 May 2024 14:20:39 +0800 Subject: [PATCH 10/12] Update template.py Former-commit-id: 780ca8306b31d5ac856f68de3abed7e838848464 --- src/llmtuner/data/template.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py index 0b2ca0e6..b20c9203 100644 --- a/src/llmtuner/data/template.py +++ b/src/llmtuner/data/template.py @@ -857,17 +857,17 @@ _register_template( _register_template( - name="yivl", + name="yi_vl", format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]), - format_assistant=StringFormatter(slots=[" {{content}}\n"]), - stop_words=["###"], + format_separator=EmptyFormatter(slots=["\n"]), 
default_system=( "This is a chat between an inquisitive human and an AI assistant. " - "Assume the role of the AI assistant. " - "Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers." - "这是一个好奇的人类和一个人工智能助手之间的对话。" - "假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n" + "Assume the role of the AI assistant. Read all the images carefully, " + "and respond to the human's questions with informative, helpful, detailed and polite answers. " + "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。" + "仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n" ), + stop_words=["###"], ) From e09d68985f55ecd1ed826ddba976c28fca9c37be Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 15 May 2024 15:37:07 +0800 Subject: [PATCH 11/12] Update patcher.py Former-commit-id: 5a0c8a8d343adb15b510f65286ee08f33b1b2751 --- src/llmtuner/model/patcher.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index bddea594..70aed709 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -17,7 +17,7 @@ from .utils.moe import add_z3_leaf_module, configure_moe from .utils.quantization import configure_quantization from .utils.rope import configure_rope from .utils.valuehead import prepare_valuehead_model -from .utils.visual import autocast_projector_dtype, configure_hidden_size, configure_visual +from .utils.visual import autocast_projector_dtype, configure_visual_model if TYPE_CHECKING: @@ -54,8 +54,7 @@ def patch_config( configure_longlora(config, model_args, is_trainable) configure_quantization(config, tokenizer, model_args, init_kwargs) configure_moe(config, model_args, is_trainable) - configure_hidden_size(config) - configure_visual(config, model_args) + configure_visual_model(config) if model_args.use_cache and not is_trainable: setattr(config, "use_cache", True) From e80e50805cd501addada9cdbcd9e55d475b39b6e Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 15 May 2024 16:39:57 +0800 Subject: [PATCH 12/12] Update visual.py Former-commit-id: cbeef2aaea0577fd1929e7f156a2b8601b31814e --- src/llmtuner/model/utils/visual.py | 53 ++++++++++++++---------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py index 1f770861..9a5134ff 100644 --- a/src/llmtuner/model/utils/visual.py +++ b/src/llmtuner/model/utils/visual.py @@ -1,8 +1,8 @@ from typing import TYPE_CHECKING, Tuple import torch -import transformers -from torch import nn +import transformers.models +from transformers.activations import ACT2FN from ...extras.logging import get_logger @@ -16,9 +16,23 @@ if TYPE_CHECKING: logger = get_logger(__name__) -def configure_hidden_size(config: "PretrainedConfig") -> None: - if getattr(config, "model_type", None) == "llava": - setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) +class LlavaMultiModalProjector(torch.nn.Module): + def __init__(self, config: "LlavaConfig"): + super().__init__() + + self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True) + self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True) + self.act = ACT2FN[config.projector_hidden_act] + + def forward(self, 
image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_3(hidden_states) + hidden_states = self.linear_4(hidden_states) + return hidden_states def autocast_projector_dtype( @@ -35,28 +49,9 @@ def autocast_projector_dtype( mm_projector.register_forward_hook(_mm_projector_forward_post_hook) -class LlavaMultiModalProjectorYiVL(nn.Module): - def __init__(self, config: "LlavaConfig"): - super().__init__() - self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) - self.linear_2 = nn.LayerNorm(config.text_config.hidden_size, bias=True) - self.linear_3 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) - self.linear_4 = nn.LayerNorm(config.text_config.hidden_size, bias=True) - self.act = nn.GELU() +def configure_visual_model(config: "PretrainedConfig") -> None: + if getattr(config, "model_type", None) == "llava": + setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) - def forward(self, image_features): - dtype_ = self.linear_1.weight.dtype - hidden_states = self.linear_1(image_features) - hidden_states = self.linear_2(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.linear_3(hidden_states) - hidden_states = self.linear_4(hidden_states) - hidden_states = hidden_states.to(dtype_) - return hidden_states - - -def configure_visual(config: "PretrainedConfig", model_args: "ModelArguments") -> None: - logger = get_logger(__name__) - if model_args.visual_inputs and "Yi" in getattr(config.text_config, "_name_or_path", None): - transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorYiVL - logger.info("Patched Multimodal Projector for Yi-VL.") + if getattr(config, "is_yi_vl_derived_model", None): + transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjector
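
Reviewer note (not part of the patch series): the series converges on three pieces — a "yi_vl" chat template with "### Human:" / "### Assistant:" turns and "###" as a stop word, a Yi-VL style multimodal projector that transformers' stock LlavaMultiModalProjector is monkey-patched with whenever the checkpoint config sets is_yi_vl_derived_model, and image-processor saving hooked into the trainer's _save. The sketch below only exercises the projector layout from the final visual.py as a standalone smoke test; the dummy config values and the 576-patch count are illustrative assumptions, not values taken from the patches or from any real Yi-VL checkpoint.

    # Standalone smoke test of the Yi-VL projector layout from patch 12/12.
    # Assumes torch >= 2.1 (for LayerNorm's bias kwarg) and a recent transformers.
    import torch
    from transformers.activations import ACT2FN


    class DummyVisionConfig:
        hidden_size = 32  # illustrative only; real CLIP vision towers use e.g. 1024


    class DummyTextConfig:
        hidden_size = 64  # illustrative only; Yi-6B uses 4096


    class DummyLlavaConfig:
        vision_config = DummyVisionConfig()
        text_config = DummyTextConfig()
        projector_hidden_act = "gelu"


    class YiVLProjector(torch.nn.Module):
        """Linear -> LayerNorm -> act -> Linear -> LayerNorm, mirroring the final visual.py."""

        def __init__(self, config):
            super().__init__()
            self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
            self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
            self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
            self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
            self.act = ACT2FN[config.projector_hidden_act]

        def forward(self, image_features):
            hidden_states = self.linear_1(image_features)
            hidden_states = self.linear_2(hidden_states)
            hidden_states = self.act(hidden_states)
            hidden_states = self.linear_3(hidden_states)
            hidden_states = self.linear_4(hidden_states)
            return hidden_states


    projector = YiVLProjector(DummyLlavaConfig())
    image_features = torch.randn(1, 576, DummyVisionConfig.hidden_size)  # (batch, patches, vision dim)
    print(projector(image_features).shape)  # expected: torch.Size([1, 576, 64])

Swapping the projector class (rather than renaming checkpoint weights) appears to be what lets Yi-VL reuse the stock LlavaForConditionalGeneration loading path: the main architectural difference from LLaVA's projector is the extra LayerNorm after each linear layer, so replacing the class keyed on is_yi_vl_derived_model leaves the rest of the model untouched.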