diff --git a/data/dataset_info.json b/data/dataset_info.json
index b985582e..5a90e077 100644
--- a/data/dataset_info.json
+++ b/data/dataset_info.json
@@ -38,6 +38,20 @@
       "assistant_tag": "assistant"
     }
   },
+  "mllm_pt_demo": {
+    "file_name": "mllm_pt_demo.json",
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "messages",
+      "images": "images"
+    },
+    "tags": {
+      "role_tag": "role",
+      "content_tag": "content",
+      "user_tag": "user",
+      "assistant_tag": "assistant"
+    }
+  },
   "alpaca_en": {
     "hf_hub_url": "llamafactory/alpaca_en",
     "ms_hub_url": "llamafactory/alpaca_en"
diff --git a/data/mllm_pt_demo.json b/data/mllm_pt_demo.json
new file mode 100644
index 00000000..2ee01ce6
--- /dev/null
+++ b/data/mllm_pt_demo.json
@@ -0,0 +1,92 @@
+[
+  {
+    "messages": [
+      {
+        "content": "Render a clear and concise summary of the photo.",
+        "role": "user"
+      },
+      {
+        "content": "There are two soccer players on the field.",
+        "role": "assistant"
+      }
+    ],
+    "images": [
+      "mllm_demo_data/1.jpg"
+    ]
+  },
+  {
+    "messages": [
+      {
+        "content": "Write a terse but informative summary of the picture.",
+        "role": "user"
+      },
+      {
+        "content": "A soccer player is sliding on his knees to celebrate",
+        "role": "assistant"
+      }
+    ],
+    "images": [
+      "mllm_demo_data/2.jpg"
+    ]
+  },
+  {
+    "messages": [
+      {
+        "content": "What is this?",
+        "role": "user"
+      },
+      {
+        "content": "A man is giving a speech.",
+        "role": "assistant"
+      }
+    ],
+    "images": [
+      "mllm_demo_data/3.jpg"
+    ]
+  },
+  {
+    "messages": [
+      {
+        "content": "对照片进行简明扼要的概括。",
+        "role": "user"
+      },
+      {
+        "content": "两个足球运动员在场上",
+        "role": "assistant"
+      }
+    ],
+    "images": [
+      "mllm_demo_data/1.jpg"
+    ]
+  },
+  {
+    "messages": [
+      {
+        "content": "为图片写一个简短但内容丰富的摘要。",
+        "role": "user"
+      },
+      {
+        "content": "一个足球运动员在跪地滑行庆祝",
+        "role": "assistant"
+      }
+    ],
+    "images": [
+      "mllm_demo_data/2.jpg"
+    ]
+  },
+  {
+    "messages": [
+      {
+        "content": "这是什么?",
+        "role": "user"
+      },
+      {
+        "content": "一个男人在演讲",
+        "role": "assistant"
+      }
+    ],
+    "images": [
+      "mllm_demo_data/3.jpg"
+    ]
+  }
+]
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 5885bb09..255051dc 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -85,6 +85,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
     )
+    tune_mm_proj: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to only fine-tune mm_projector for MLLM."},
+    )
     moe_aux_loss_coef: Optional[float] = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 49b347d5..d9784593 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -163,6 +163,13 @@ def load_model(
     else:
         model.train()
 
+    if model_args.visual_inputs and model_args.tune_mm_proj:
+        # Freeze every language-model parameter so that only the multimodal
+        # projector remains trainable.
+        for name, param in model.named_parameters():
+            if "language_model" in name:
+                param.requires_grad_(False)
+
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
         param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(