mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
lint (#6641)
Former-commit-id: 1278c3e92eeb297e883aab89e2384c1df1d0e910
This commit is contained in:
parent
864ee06243
commit
91433d639c
@ -87,6 +87,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
</details>
|
||||
|
||||
## Changelog
|
||||
|
||||
[25/01/15] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
|
||||
|
||||
[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
|
||||
|
@ -169,7 +169,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
|
||||
value = torch.stack(value) # assume they have same sizes
|
||||
elif isinstance(value, list) and all(isinstance(v, list) for v in value): # for minicpmv inputs
|
||||
value = torch.stack([torch.stack(per_value) for per_value in value])
|
||||
value = torch.stack([torch.stack(v) for v in value])
|
||||
elif not isinstance(value, torch.Tensor):
|
||||
value = torch.tensor(value)
|
||||
|
||||
@ -215,7 +215,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
)
|
||||
generate_output = model.generate(**gen_kwargs)
|
||||
if isinstance(generate_output, tuple):
|
||||
generate_output = generate_output[1][0] # for minicpm_o
|
||||
generate_output = generate_output[1][0] # post-process the minicpm_o output
|
||||
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
response = tokenizer.batch_decode(
|
||||
|
@ -732,15 +732,18 @@ _register_template(
|
||||
stop_words=["<|im_end|>"],
|
||||
)
|
||||
|
||||
|
||||
# copied from intern2 template
|
||||
_register_template(
|
||||
name="intern3",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|im_end|>"]
|
||||
stop_words=["<|im_end|>"],
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llama2",
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
|
@ -153,7 +153,7 @@ def patch_model(
|
||||
):
|
||||
gen_config.do_sample = True
|
||||
|
||||
if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
|
||||
if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
|
||||
model.generate.__func__
|
||||
):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
Loading…
x
Reference in New Issue
Block a user