Former-commit-id: 1278c3e92eeb297e883aab89e2384c1df1d0e910
This commit is contained in:
hoshi-hiyouga 2025-01-14 18:40:07 +08:00 committed by GitHub
parent 864ee06243
commit 91433d639c
5 changed files with 11 additions and 7 deletions

View File

@ -87,6 +87,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
</details> </details>
## Changelog ## Changelog
[25/01/15] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR. [25/01/15] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model. [25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.

View File

@ -169,7 +169,7 @@ class HuggingfaceEngine(BaseEngine):
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
value = torch.stack(value) # assume they have same sizes value = torch.stack(value) # assume they have same sizes
elif isinstance(value, list) and all(isinstance(v, list) for v in value): # for minicpmv inputs elif isinstance(value, list) and all(isinstance(v, list) for v in value): # for minicpmv inputs
value = torch.stack([torch.stack(per_value) for per_value in value]) value = torch.stack([torch.stack(v) for v in value])
elif not isinstance(value, torch.Tensor): elif not isinstance(value, torch.Tensor):
value = torch.tensor(value) value = torch.tensor(value)
@ -215,7 +215,7 @@ class HuggingfaceEngine(BaseEngine):
) )
generate_output = model.generate(**gen_kwargs) generate_output = model.generate(**gen_kwargs)
if isinstance(generate_output, tuple): if isinstance(generate_output, tuple):
generate_output = generate_output[1][0] # for minicpm_o generate_output = generate_output[1][0] # post-process the minicpm_o output
response_ids = generate_output[:, prompt_length:] response_ids = generate_output[:, prompt_length:]
response = tokenizer.batch_decode( response = tokenizer.batch_decode(

View File

@ -732,15 +732,18 @@ _register_template(
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
) )
# copied from intern2 template
_register_template( _register_template(
name="intern3", name="intern3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"] stop_words=["<|im_end|>"],
) )
_register_template( _register_template(
name="llama2", name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),

View File

@ -119,7 +119,7 @@ def patch_config(
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf") raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"): if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
raise RuntimeError("InternLM3 model requires transformers >= 4.47.1, please upgrade it.") raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
# deepspeed zero3 is not compatible with low_cpu_mem_usage # deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled()) init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
@ -153,7 +153,7 @@ def patch_model(
): ):
gen_config.do_sample = True gen_config.do_sample = True
if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str( if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
model.generate.__func__ model.generate.__func__
): ):
model.generate = MethodType(PreTrainedModel.generate, model) model.generate = MethodType(PreTrainedModel.generate, model)