mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-21 20:36:02 +08:00
[v1] support resume training from checkpoint (#10280)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -25,7 +25,8 @@ Arguments:
|
||||
import fire
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dcp
|
||||
from transformers import AutoModelForCausalLM
|
||||
import transformers
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
def convert(hf_path: str, dcp_path: str) -> None:
|
||||
@@ -39,7 +40,14 @@ def convert(hf_path: str, dcp_path: str) -> None:
|
||||
raise ValueError("Both 'hf_path' and 'dcp_path' are required.")
|
||||
|
||||
print(f"Loading HF model from {hf_path}...")
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
|
||||
config = AutoConfig.from_pretrained(hf_path)
|
||||
architectures = getattr(config, "architectures", [])
|
||||
if architectures:
|
||||
model_cls = getattr(transformers, architectures[0], transformers.AutoModelForCausalLM)
|
||||
else:
|
||||
model_cls = transformers.AutoModelForCausalLM
|
||||
|
||||
model = model_cls.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
|
||||
|
||||
print(f"Saving to DCP format at {dcp_path}...")
|
||||
dcp.save(model.state_dict(), checkpoint_id=dcp_path)
|
||||
|
||||
Reference in New Issue
Block a user