[v1] support resume training from checkpoint (#10280)

Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: 浮梦
Date: 2026-04-20 20:28:08 +08:00
Committed by: GitHub
Parent: c5aecaf31d
Commit: c4bbac49b2
9 changed files with 577 additions and 10 deletions


@@ -25,7 +25,8 @@ Arguments:
 import fire
 import torch
 import torch.distributed.checkpoint as dcp
-from transformers import AutoModelForCausalLM
+import transformers
+from transformers import AutoConfig
 
 
 def convert(hf_path: str, dcp_path: str) -> None:
@@ -39,7 +40,14 @@ def convert(hf_path: str, dcp_path: str) -> None:
         raise ValueError("Both 'hf_path' and 'dcp_path' are required.")
 
     print(f"Loading HF model from {hf_path}...")
-    model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
+    config = AutoConfig.from_pretrained(hf_path)
+    architectures = getattr(config, "architectures", [])
+    if architectures:
+        model_cls = getattr(transformers, architectures[0], transformers.AutoModelForCausalLM)
+    else:
+        model_cls = transformers.AutoModelForCausalLM
+
+    model = model_cls.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
 
     print(f"Saving to DCP format at {dcp_path}...")
     dcp.save(model.state_dict(), checkpoint_id=dcp_path)
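
For context, a checkpoint written this way can be read back with the same `torch.distributed.checkpoint` API when resuming. The sketch below shows the reverse direction; it is not part of this diff, and the `load_from_dcp` helper name and single-process usage are assumptions for illustration:

```python
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM


def load_from_dcp(hf_path: str, dcp_path: str) -> torch.nn.Module:
    """Hypothetical helper: rebuild a model skeleton, then let DCP fill its weights."""
    # Load the HF model only to obtain a correctly shaped state_dict;
    # its weights are replaced by the checkpoint contents below.
    model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
    state_dict = model.state_dict()
    # dcp.load populates the given state_dict in place from the DCP directory
    # (recent PyTorch versions fall back to non-distributed mode if no
    # process group is initialized).
    dcp.load(state_dict, checkpoint_id=dcp_path)
    model.load_state_dict(state_dict)
    return model
```

Since the conversion script exposes `convert` through `fire`, it would be invoked as something like `python convert_hf_to_dcp.py --hf_path <model> --dcp_path <dir>` (the script filename here is assumed; the actual name is not shown in this hunk).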