lint (#6641)
Repository mirror: https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 22:32:54 +08:00)
Former-commit-id: 1278c3e92eeb297e883aab89e2384c1df1d0e910
Parent: 864ee06243
Commit: 91433d639c
@@ -87,6 +87,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 </details>

 ## Changelog

 [25/01/15] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.

 [25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
@@ -169,7 +169,7 @@ class HuggingfaceEngine(BaseEngine):
             if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                 value = torch.stack(value)  # assume they have same sizes
             elif isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
-                value = torch.stack([torch.stack(per_value) for per_value in value])
+                value = torch.stack([torch.stack(v) for v in value])
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)
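For context, here is a minimal sketch of the stacking behavior this branch relies on, using small dummy tensors; the shapes and variable names are illustrative assumptions, not taken from the engine:

```python
# Minimal sketch (assumed shapes): stacking a flat list vs. a nested list of tensors,
# mirroring the pixtral and minicpmv branches in the hunk above.
import torch

# pixtral-style input: a flat list of equally sized tensors
flat = [torch.zeros(3, 4), torch.ones(3, 4)]
stacked_flat = torch.stack(flat)  # shape: (2, 3, 4)

# minicpmv-style input: a list of lists of tensors
nested = [
    [torch.zeros(3, 4), torch.ones(3, 4)],
    [torch.full((3, 4), 2.0), torch.full((3, 4), 3.0)],
]
stacked_nested = torch.stack([torch.stack(v) for v in nested])  # shape: (2, 2, 3, 4)

print(stacked_flat.shape, stacked_nested.shape)
```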
@@ -215,7 +215,7 @@ class HuggingfaceEngine(BaseEngine):
         )
         generate_output = model.generate(**gen_kwargs)
         if isinstance(generate_output, tuple):
-            generate_output = generate_output[1][0]  # for minicpm_o
+            generate_output = generate_output[1][0]  # post-process the minicpm_o output

         response_ids = generate_output[:, prompt_length:]
         response = tokenizer.batch_decode(
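A hedged, self-contained sketch of the slice-and-decode step shown above; the model and tokenizer here are placeholders (a small public checkpoint), not the engine's actual objects:

```python
# Hedged sketch: strip the prompt tokens from the generated sequence and decode
# only the newly generated part, as the hunk above does with prompt_length.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
prompt_length = inputs["input_ids"].size(-1)

generate_output = model.generate(**inputs, max_new_tokens=8)  # (batch, prompt + new tokens)
response_ids = generate_output[:, prompt_length:]             # keep only the new tokens
response = tokenizer.batch_decode(response_ids, skip_special_tokens=True)
print(response)
```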
@@ -732,15 +732,18 @@ _register_template(
     stop_words=["<|im_end|>"],
 )


+# copied from intern2 template
 _register_template(
     name="intern3",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-    stop_words=["<|im_end|>"]
+    stop_words=["<|im_end|>"],
 )


 _register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
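To visualize what the intern3 slots above render to, here is an illustrative sketch that fills the `{{content}}` placeholder with plain string substitution; it is not the project's `StringFormatter` implementation:

```python
# Illustrative only: rendering one user/assistant turn of the intern3 template
# by simple string substitution over the slot strings registered above.
user_slot = "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
assistant_slot = "{{content}}<|im_end|>\n"

prompt = user_slot.replace("{{content}}", "Who are you?")
reply = assistant_slot.replace("{{content}}", "I am InternLM3.")
print(prompt + reply)
```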
@@ -119,7 +119,7 @@ def patch_config(
        raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")

    if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
-        raise RuntimeError("InternLM3 model requires transformers >= 4.47.1, please upgrade it.")
+        raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")

    # deepspeed zero3 is not compatible with low_cpu_mem_usage
    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
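A hedged sketch of the kind of version gate used here, built on `importlib.metadata` and `packaging`; the helper below mirrors the role of `is_transformers_version_greater_than` but is a stand-in, not the project's utility:

```python
# Hedged sketch: gate a model type on the installed transformers version.
# The helper is a stand-in for is_transformers_version_greater_than.
from importlib.metadata import version

from packaging.version import Version


def transformers_version_greater_than(target: str) -> bool:
    """Return True if the installed transformers version is strictly greater than target."""
    return Version(version("transformers")) > Version(target)


model_type = "internlm3"  # illustrative value
if model_type == "internlm3" and not transformers_version_greater_than("4.47.1"):
    raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
```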
@@ -153,7 +153,7 @@ def patch_model(
    ):
        gen_config.do_sample = True

-    if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
+    if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
        model.generate.__func__
    ):
        model.generate = MethodType(PreTrainedModel.generate, model)
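The last hunk rebinds `generate` on the model instance with `MethodType`; here is a hedged, self-contained sketch of that pattern on toy classes, not the actual model-patching code:

```python
# Hedged sketch of rebinding an instance method with MethodType, as done for
# PreTrainedModel.generate above. The classes here are toy stand-ins.
from types import MethodType


class Base:
    def generate(self):
        return "base generate"


class Custom(Base):
    def generate(self):
        return "custom generate"


obj = Custom()
print(obj.generate())  # "custom generate"

# Rebind the base implementation onto this instance only.
obj.generate = MethodType(Base.generate, obj)
print(obj.generate())  # "base generate"
```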