Former-commit-id: 1278c3e92eeb297e883aab89e2384c1df1d0e910
hoshi-hiyouga 2025-01-14 18:40:07 +08:00 committed by GitHub
parent 864ee06243
commit 91433d639c
5 changed files with 11 additions and 7 deletions

View File

@@ -87,6 +87,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
</details>
## Changelog
+ [25/01/15] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thanks to [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.

View File

@@ -169,7 +169,7 @@ class HuggingfaceEngine(BaseEngine):
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
value = torch.stack(value) # assume they have same sizes
elif isinstance(value, list) and all(isinstance(v, list) for v in value): # for minicpmv inputs
- value = torch.stack([torch.stack(per_value) for per_value in value])
+ value = torch.stack([torch.stack(v) for v in value])
elif not isinstance(value, torch.Tensor):
value = torch.tensor(value)
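
For context, a minimal standalone sketch (with made-up tensor shapes, not real minicpmv data) of what the renamed comprehension does: minicpmv hands the engine a list of lists of equally sized tensors, and the nested `torch.stack` collapses it into one batched tensor.

```python
import torch

# Toy stand-in for minicpmv inputs: two samples, each a list of three
# equally sized tensors. Shapes are illustrative only; torch.stack
# requires identical shapes, mirroring the "assume they have same sizes"
# note in the pixtral branch above.
value = [[torch.zeros(3, 4) for _ in range(3)] for _ in range(2)]
stacked = torch.stack([torch.stack(v) for v in value])
print(stacked.shape)  # torch.Size([2, 3, 3, 4])
```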
@@ -215,7 +215,7 @@
)
generate_output = model.generate(**gen_kwargs)
if isinstance(generate_output, tuple):
- generate_output = generate_output[1][0]  # for minicpm_o
+ generate_output = generate_output[1][0]  # post-process the minicpm_o output
response_ids = generate_output[:, prompt_length:]
response = tokenizer.batch_decode(
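
A hedged sketch of the post-processing path the new comment describes: when `generate()` returns a tuple (as minicpm_o does here), the decodable ids sit at index `[1][0]`, and the prompt tokens are then sliced off. The tensors below are placeholders, not real model output.

```python
import torch

prompt_length = 4  # illustrative prompt size
token_ids = torch.tensor([[101, 102, 103, 104, 9, 8, 7]])  # prompt + response
generate_output = (None, [token_ids])  # mimics the tuple minicpm_o returns

if isinstance(generate_output, tuple):
    generate_output = generate_output[1][0]  # unwrap to the id tensor

response_ids = generate_output[:, prompt_length:]
print(response_ids)  # tensor([[9, 8, 7]])
```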

View File

@@ -732,15 +732,18 @@ _register_template(
stop_words=["<|im_end|>"],
)
# copied from intern2 template
_register_template(
name="intern3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"]
stop_words=["<|im_end|>"],
)
_register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),

View File

@@ -153,7 +153,7 @@ def patch_model(
):
gen_config.do_sample = True
if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
model.generate.__func__
):
model.generate = MethodType(PreTrainedModel.generate, model)
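
The only change here is the `None` default on `getattr`, which matters because a config may not define `model_type` at all: the two-argument form raises `AttributeError`, while the three-argument form falls through and lets the generate patch proceed. A toy illustration:

```python
class ToyConfig:  # hypothetical config without a model_type attribute
    pass

config = ToyConfig()

try:
    getattr(config, "model_type")  # old form: raises on a missing attribute
except AttributeError as err:
    print(f"two-arg getattr failed: {err}")

print(getattr(config, "model_type", None))  # new form: None, no exception
```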