diff --git a/README.md b/README.md
index 799c5019..ad9ff286 100644
--- a/README.md
+++ b/README.md
@@ -87,6 +87,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
 
+[25/01/15] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
 
 [25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 6f61f2ab..c2e3e114 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -169,7 +169,7 @@ class HuggingfaceEngine(BaseEngine):
             if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                 value = torch.stack(value)  # assume they have same sizes
             elif isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
-                value = torch.stack([torch.stack(per_value) for per_value in value])
+                value = torch.stack([torch.stack(v) for v in value])
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)
 
@@ -215,7 +215,7 @@ class HuggingfaceEngine(BaseEngine):
         )
         generate_output = model.generate(**gen_kwargs)
         if isinstance(generate_output, tuple):
-            generate_output = generate_output[1][0]  # for minicpm_o
+            generate_output = generate_output[1][0]  # post-process the minicpm_o output
 
         response_ids = generate_output[:, prompt_length:]
         response = tokenizer.batch_decode(
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 07be3cec..4d7f5eb9 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -732,15 +732,18 @@ _register_template(
     stop_words=["<|im_end|>"],
 )
 
+
+# copied from intern2 template
 _register_template(
     name="intern3",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-    stop_words=["<|im_end|>"]
+    stop_words=["<|im_end|>"],
 )
 
+
 _register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 4d6b5a99..b450e72d 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -85,7 +85,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         )
     except Exception as e:
         raise OSError("Failed to load tokenizer.") from e
-    
+
     if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
         tokenizer.model_max_length = model_args.model_max_length
 
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 4dc7e1b5..c33527a6 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -110,7 +110,7 @@ def patch_config(
 
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
-    
+
     if getattr(config, "model_type", None) == "minicpmo":
         setattr(config, "init_audio", False)
         setattr(config, "init_tts", False)
@@ -119,7 +119,7 @@ def patch_config(
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
 
     if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
-        raise RuntimeError("InternLM3 model requires transformers >= 4.47.1, please upgrade it.")
+        raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
 
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
@@ -153,7 +153,7 @@ def patch_model(
     ):
         gen_config.do_sample = True
 
-    if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
+    if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
        model.generate.__func__
     ):
         model.generate = MethodType(PreTrainedModel.generate, model)
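
Note on the `hf_engine.py` change: the double `torch.stack` turns a minicpmv-style nested list (one list of same-shaped tensors per sample) into a single batched tensor. The sketch below is illustrative only and is not part of the PR; the tensor shapes and variable names are made up, and it assumes every inner tensor shares the same shape (otherwise `torch.stack` raises an error).

```python
import torch

# hypothetical processor output: 2 samples, each with 3 feature tensors of shape (4, 5)
value = [[torch.randn(4, 5) for _ in range(3)] for _ in range(2)]

# inner stack -> (3, 4, 5) per sample; outer stack -> (2, 3, 4, 5) batch
batched = torch.stack([torch.stack(v) for v in value])
print(batched.shape)  # torch.Size([2, 3, 4, 5])
```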