mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-21 20:36:02 +08:00
[v1] support resume training from checkpoint (#10280)
Co-authored-by: frozenleaves <frozen@Mac.local> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
76
scripts/dcp2hf.py
Normal file
76
scripts/dcp2hf.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert a DCP checkpoint to HuggingFace model format.
|
||||
|
||||
Usage:
|
||||
python scripts/dcp2hf.py convert --dcp_path=/path/to/dcp --hf_path=/path/to/hf --config_path=/path/to/config
|
||||
|
||||
Arguments:
|
||||
dcp_path: Path to the DCP checkpoint directory.
|
||||
hf_path: Output path (directory) for HuggingFace model.
|
||||
config_path: Path to the HuggingFace model directory containing config.json.
|
||||
"""
|
||||
|
||||
import fire
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import transformers
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
def convert(dcp_path: str, hf_path: str, config_path: str) -> None:
    """Convert DCP model weights to HF.

    Note: this script is used to convert a DCP checkpoint to HuggingFace model format,
    it will just convert the DCP checkpoint to a HuggingFace model format, for the tokenizer,
    you may need to copy from the original model.

    Args:
        dcp_path: DCP checkpoint directory.
        hf_path: Output path (directory) for HuggingFace model.
        config_path: Path to the HuggingFace model directory containing config.json.

    Raises:
        ValueError: If any of the three paths is missing.
    """
    if not dcp_path or not hf_path or not config_path:
        raise ValueError("All 'dcp_path', 'hf_path', and 'config_path' are required.")

    print(f"Loading config from {config_path}...")
    config = AutoConfig.from_pretrained(config_path)
    architectures = getattr(config, "architectures", [])
    # Resolve the concrete model class if the architecture name exists in transformers.
    model_cls = getattr(transformers, architectures[0], None) if architectures else None

    print("Initializing model on CPU...")
    if model_cls is not None:
        # Concrete classes (e.g. LlamaForCausalLM) can be instantiated from a config directly.
        model = model_cls(config).to(torch.bfloat16)
    else:
        # Fix: Auto classes cannot be called as `AutoModelForCausalLM(config)` --
        # that raises an error; they must be instantiated via `from_config`.
        model = transformers.AutoModelForCausalLM.from_config(config).to(torch.bfloat16)

    print(f"Loading DCP from {dcp_path}...")
    # Load the distributed checkpoint in place into the model's state dict,
    # then push it back into the module.
    state_dict = model.state_dict()
    dcp.load(state_dict, checkpoint_id=dcp_path)
    model.load_state_dict(state_dict)

    print(f"Saving to HF format at {hf_path}...")
    model.save_pretrained(hf_path)
    config.save_pretrained(hf_path)
    print("Done!")
|
||||
|
||||
|
||||
def help() -> None:
    """Print the module-level usage documentation to stdout."""
    usage = __doc__
    print(usage)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Expose the two sub-commands through fire.
    # NOTE(review): the "--convert" entry likely never matches, since fire
    # parses leading "--" tokens as flags rather than command names -- confirm intent.
    fire.Fire({"convert": convert, "help": help, "--convert": convert})
|
||||
@@ -25,7 +25,8 @@ Arguments:
|
||||
import fire
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dcp
|
||||
from transformers import AutoModelForCausalLM
|
||||
import transformers
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
def convert(hf_path: str, dcp_path: str) -> None:
|
||||
@@ -39,7 +40,14 @@ def convert(hf_path: str, dcp_path: str) -> None:
|
||||
raise ValueError("Both 'hf_path' and 'dcp_path' are required.")
|
||||
|
||||
print(f"Loading HF model from {hf_path}...")
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
|
||||
config = AutoConfig.from_pretrained(hf_path)
|
||||
architectures = getattr(config, "architectures", [])
|
||||
if architectures:
|
||||
model_cls = getattr(transformers, architectures[0], transformers.AutoModelForCausalLM)
|
||||
else:
|
||||
model_cls = transformers.AutoModelForCausalLM
|
||||
|
||||
model = model_cls.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
|
||||
|
||||
print(f"Saving to DCP format at {dcp_path}...")
|
||||
dcp.save(model.state_dict(), checkpoint_id=dcp_path)
|
||||
|
||||
@@ -85,6 +85,28 @@ class TrainingArguments:
|
||||
default=42,
|
||||
metadata={"help": "Random seed that will be set at the beginning of training."},
|
||||
)
|
||||
resume_from_checkpoint: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},
|
||||
)
|
||||
save_steps: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Save a training checkpoint every N global steps."},
|
||||
)
|
||||
save_epochs: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Save a training checkpoint every N epochs."},
|
||||
)
|
||||
save_ckpt_as_hf: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Save intermediate checkpoints in HuggingFace format instead of distributed format. Warning: doubles memory usage."
|
||||
},
|
||||
)
|
||||
save_total_limit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Maximum number of checkpoints to keep. Oldest checkpoints are deleted."},
|
||||
)
|
||||
logging_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Log metrics every N optimizer steps."},
|
||||
|
||||
@@ -45,6 +45,7 @@ from ..utils.callbacks import (
|
||||
from ..utils.helper import compute_valid_tokens
|
||||
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
|
||||
from .utils.batching import BatchGenerator
|
||||
from .utils.checkpoint import TrainingCheckpointCoordinator
|
||||
from .utils.rendering import Renderer
|
||||
|
||||
|
||||
@@ -81,6 +82,10 @@ class BaseTrainer:
|
||||
else:
|
||||
self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
|
||||
|
||||
if self.args.save_epochs is not None:
|
||||
steps_per_epoch = len(self.train_batch_generator)
|
||||
self.args.save_steps = max(1, int(steps_per_epoch * self.args.save_epochs))
|
||||
|
||||
if self.args.enable_activation_checkpointing:
|
||||
self.model.gradient_checkpointing_enable({"use_reentrant": False})
|
||||
|
||||
@@ -107,13 +112,28 @@ class BaseTrainer:
|
||||
self._init_optimizer()
|
||||
self._init_lr_scheduler()
|
||||
|
||||
self._resume_epoch = 0
|
||||
self._checkpoint = TrainingCheckpointCoordinator(self)
|
||||
if self.args.resume_from_checkpoint:
|
||||
self._checkpoint.resume(self.args.resume_from_checkpoint)
|
||||
|
||||
if self.args.save_ckpt_as_hf:
|
||||
logger.warning_rank0(
|
||||
"save_ckpt_as_hf is enabled. Intermediate checkpoints will be saved in Hugging Face format. "
|
||||
"Note that this will significantly increase memory consumption during saving."
|
||||
)
|
||||
|
||||
# Callbacks
|
||||
self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self)
|
||||
for cb in callbacks or []:
|
||||
self.callback_handler.add_callback(cb)
|
||||
|
||||
# Callbacks: TrainerState tracks progress across the full run.
|
||||
self.state = TrainerState(num_training_steps=self.num_training_steps)
|
||||
self.state = TrainerState(
|
||||
num_training_steps=self.num_training_steps,
|
||||
global_step=self.global_step,
|
||||
epoch=self._resume_epoch,
|
||||
)
|
||||
|
||||
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
|
||||
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
|
||||
@@ -206,7 +226,7 @@ class BaseTrainer:
|
||||
"""Train the model."""
|
||||
self.model.train()
|
||||
self.callback_handler.on_train_begin(self.args, self.state)
|
||||
for epoch in range(self.args.num_train_epochs):
|
||||
for epoch in range(self._resume_epoch, self.args.num_train_epochs):
|
||||
self.state.epoch = epoch
|
||||
self.train_batch_generator.set_epoch(epoch)
|
||||
self.callback_handler.on_epoch_begin(self.args, self.state)
|
||||
@@ -300,6 +320,9 @@ class BaseTrainer:
|
||||
}
|
||||
self.callback_handler.on_log(self.args, self.state, logs)
|
||||
|
||||
if self.args.save_steps and self.global_step % self.args.save_steps == 0:
|
||||
self._checkpoint.save(epoch)
|
||||
|
||||
# Check if max_steps is reached
|
||||
if self.global_step >= self.num_training_steps:
|
||||
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
|
||||
|
||||
@@ -165,7 +165,6 @@ class BatchGenerator(Iterator):
|
||||
def __iter__(self):
|
||||
if not self._is_resuming:
|
||||
self._buffer.clear()
|
||||
self._buffer_tokens = 0
|
||||
|
||||
self._data_iter = iter(self._data_provider)
|
||||
self._is_resuming = False
|
||||
@@ -203,14 +202,12 @@ class BatchGenerator(Iterator):
|
||||
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"buffer": self._buffer,
|
||||
"buffer_tokens": self._buffer_tokens,
|
||||
"buffer": self._buffer.state_dict(),
|
||||
"data_provider": self._data_provider.state_dict(),
|
||||
}
|
||||
|
||||
def load_state_dict(self, state: dict[str, Any]) -> None:
|
||||
self._buffer = state["buffer"]
|
||||
self._buffer_tokens = state["buffer_tokens"]
|
||||
self._buffer.load_state_dict(state["buffer"])
|
||||
self._data_provider.load_state_dict(state["data_provider"])
|
||||
self._is_resuming = True
|
||||
|
||||
|
||||
339
src/llamafactory/v1/core/utils/checkpoint.py
Normal file
339
src/llamafactory/v1/core/utils/checkpoint.py
Normal file
@@ -0,0 +1,339 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Checkpoint utilities: low-level helpers and full training save/resume orchestration."""
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from ...accelerator.helper import DeviceType, get_current_accelerator
|
||||
from ...accelerator.interface import DistributedInterface
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
CHECKPOINT_COMPLETE_MARKER = "CHECKPOINT_COMPLETE"
|
||||
|
||||
|
||||
def _parse_checkpoint_step(path: str) -> int:
|
||||
"""Extract the step number from a checkpoint directory name, or -1 if invalid."""
|
||||
try:
|
||||
return int(os.path.basename(path).split("-")[-1])
|
||||
except ValueError:
|
||||
return -1
|
||||
|
||||
|
||||
def find_latest_checkpoint(output_dir: str) -> str | None:
    """Return the newest checkpoint directory in `output_dir` that finished saving, else None."""
    candidates = [
        d
        for d in glob.glob(os.path.join(output_dir, "checkpoint-*"))
        if _parse_checkpoint_step(d) >= 0
    ]
    # Walk from the highest step downward; only a directory carrying the
    # completion marker is trusted (partial saves are skipped).
    for d in sorted(candidates, key=_parse_checkpoint_step, reverse=True):
        if os.path.exists(os.path.join(d, CHECKPOINT_COMPLETE_MARKER)):
            return d
    return None
|
||||
|
||||
|
||||
def rotate_checkpoints(output_dir: str, limit: int) -> None:
    """Keep only the latest `limit` complete checkpoints, delete older ones and incomplete leftovers."""
    checkpoint_dirs = sorted(
        (d for d in glob.glob(os.path.join(output_dir, "checkpoint-*")) if _parse_checkpoint_step(d) >= 0),
        key=_parse_checkpoint_step,
    )

    kept = []
    for ckpt in checkpoint_dirs:
        if os.path.exists(os.path.join(ckpt, CHECKPOINT_COMPLETE_MARKER)):
            kept.append(ckpt)
        else:
            # A checkpoint without the completion marker is never usable; drop it.
            shutil.rmtree(ckpt)
            logger.info_rank0(f"Cleaned up incomplete checkpoint: {ckpt}")

    # Remove the oldest complete checkpoints until we are within the limit.
    excess = max(0, len(kept) - limit)
    for oldest in kept[:excess]:
        shutil.rmtree(oldest)
        logger.info_rank0(f"Deleted old checkpoint: {oldest}")
|
||||
|
||||
|
||||
def save_metadata(ckpt_dir: str, **kwargs) -> None:
    """Serialize training metadata (passed as keyword args) to `metadata.json` (rank 0 only)."""
    metadata_path = os.path.join(ckpt_dir, "metadata.json")
    with open(metadata_path, "w") as f:
        json.dump(kwargs, f, indent=2)
|
||||
|
||||
|
||||
def load_metadata(ckpt_dir: str) -> dict:
    """Read back the training metadata JSON written by `save_metadata`."""
    metadata_path = os.path.join(ckpt_dir, "metadata.json")
    with open(metadata_path) as f:
        return json.load(f)
|
||||
|
||||
|
||||
def _get_accelerator_rng_state():
    """Collect the RNG state of the active accelerator backend; None for CPU/unknown."""
    device_type = get_current_accelerator().type
    # Resolve the torch submodule name lazily: torch.npu / torch.xpu only
    # exist on builds shipping those backends, so we must not touch them
    # unless the matching device is active.
    backend = {DeviceType.CUDA: "cuda", DeviceType.NPU: "npu", DeviceType.XPU: "xpu"}.get(device_type)
    if backend is None:
        return None
    return getattr(torch, backend).get_rng_state_all()
|
||||
|
||||
|
||||
def _set_accelerator_rng_state(state) -> None:
    """Restore the accelerator RNG state captured by `_get_accelerator_rng_state`."""
    if state is None:
        return  # nothing was captured (CPU-only run)

    device_type = get_current_accelerator().type
    # Lazy attribute lookup, mirroring the getter: torch.npu / torch.xpu
    # are only present on matching builds.
    backend = {DeviceType.CUDA: "cuda", DeviceType.NPU: "npu", DeviceType.XPU: "xpu"}.get(device_type)
    if backend is not None:
        getattr(torch, backend).set_rng_state_all(state)
|
||||
|
||||
|
||||
def save_rng_state(ckpt_dir: str, rank: int) -> None:
    """Persist every RNG source (python / numpy / torch / accelerator) for this rank."""
    rng_dir = os.path.join(ckpt_dir, "rng_state")
    os.makedirs(rng_dir, exist_ok=True)

    # One file per rank: each worker has its own generator streams.
    states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.random.get_rng_state(),
        "accelerator": _get_accelerator_rng_state(),
    }
    torch.save(states, os.path.join(rng_dir, f"rank_{rank}.pt"))
|
||||
|
||||
|
||||
def load_rng_state(ckpt_dir: str, rank: int) -> None:
    """Restore this rank's RNG sources saved by `save_rng_state`, if the file exists."""
    path = os.path.join(ckpt_dir, "rng_state", f"rank_{rank}.pt")
    if not os.path.exists(path):
        # Tolerated: e.g. resuming with a different world size than at save time.
        logger.warning_rank0(f"RNG state file not found at {path}. Skipping RNG state restoration.")
        return

    # weights_only=False because python/numpy RNG states are arbitrary pickled objects.
    saved = torch.load(path, map_location="cpu", weights_only=False)
    random.setstate(saved["python"])
    np.random.set_state(saved["numpy"])
    torch.random.set_rng_state(saved["torch"])
    _set_accelerator_rng_state(saved.get("accelerator"))
|
||||
|
||||
|
||||
def mark_checkpoint_complete(ckpt_dir: str) -> None:
    """Create the empty marker file that flags this checkpoint as fully saved."""
    marker_path = os.path.join(ckpt_dir, CHECKPOINT_COMPLETE_MARKER)
    with open(marker_path, "w"):
        pass  # touch: content is irrelevant, only existence matters
|
||||
|
||||
|
||||
def resolve_resume_checkpoint_path(ckpt_path: str, output_dir: str) -> str | None:
    """Resolve 'auto' to the latest valid checkpoint in `output_dir`; pass other values through."""
    if ckpt_path != "auto":
        return ckpt_path

    resolved = find_latest_checkpoint(output_dir)
    if resolved is None:
        # No usable checkpoint: the caller interprets None as "start fresh".
        logger.warning_rank0(
            "resume_from_checkpoint='auto' but no valid checkpoint found in "
            f"'{output_dir}'. Training from scratch."
        )
    else:
        logger.info_rank0(f"Auto-detected latest checkpoint: {resolved}")
    return resolved
|
||||
|
||||
|
||||
def _save_standard_training_states(
    ckpt_dir: str,
    model: Any,
    optimizer: torch.optim.Optimizer,
    processor: Any,
    save_ckpt_as_hf: bool,
) -> None:
    """Save model and optimizer for DDP / single-GPU via save_pretrained.

    Only rank 0 writes: in DDP every rank holds an identical replica, so a
    single writer avoids races on the shared output directory.
    """
    rank = DistributedInterface().get_rank()
    if rank == 0:
        # Unwrap DDP's .module to reach the underlying HF model.
        model_to_save = model.module if hasattr(model, "module") else model
        model_dir = os.path.join(ckpt_dir, "model")
        model_to_save.save_pretrained(model_dir, max_shard_size="4GB")
        processor.save_pretrained(model_dir)

        os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
        torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer", "state_dict.pt"))

    if save_ckpt_as_hf:
        # save_pretrained already produces HF format, so the flag is a no-op here.
        logger.info("Standard saving already uses HF format. No additional 'hf_model' directory created.")
|
||||
|
||||
|
||||
def _load_standard_training_states(
    ckpt_dir: str,
    model: Any,
    optimizer: torch.optim.Optimizer,
    map_location: torch.device,
) -> None:
    """Load model and optimizer for DDP / single-GPU.

    Restores the weights written by `_save_standard_training_states`: either a
    PEFT adapter checkpoint or a full model checkpoint, then the optimizer.
    """
    model_dir = os.path.join(ckpt_dir, "model")
    # Unwrap DDP's .module to reach the underlying HF model.
    model_to_load = model.module if hasattr(model, "module") else model

    # A PEFT/LoRA checkpoint is identified by its adapter_config.json.
    is_adapter_ckpt = os.path.exists(os.path.join(model_dir, "adapter_config.json"))

    if is_adapter_ckpt:
        from peft import set_peft_model_state_dict

        adapter_file = os.path.join(model_dir, "adapter_model.safetensors")
        if not os.path.exists(adapter_file):
            # Legacy .bin adapter: a plain torch pickle.
            adapter_file = os.path.join(model_dir, "adapter_model.bin")
            adapter_state = torch.load(adapter_file, map_location="cpu", weights_only=True)
        else:
            adapter_state = load_file(adapter_file, device="cpu")
        set_peft_model_state_dict(model_to_load, adapter_state)
    else:
        # Full checkpoint: prefer safetensors shards, fall back to .bin shards.
        state_dict = {}
        for f in sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))):
            state_dict.update(load_file(f, device="cpu"))
        if not state_dict:
            for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
                state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
        if state_dict:
            model_to_load.load_state_dict(state_dict)
        else:
            logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")

    # Optimizer state is optional; skip silently if the file was not saved.
    optim_path = os.path.join(ckpt_dir, "optimizer", "state_dict.pt")
    if os.path.exists(optim_path):
        optimizer.load_state_dict(torch.load(optim_path, map_location=map_location, weights_only=True))
|
||||
|
||||
|
||||
class TrainingCheckpointCoordinator:
    """Coordinates full checkpoint save/resume for a trainer instance.

    A full checkpoint bundles: metadata (step/epoch), model + optimizer state
    (backend-specific), LR scheduler, per-rank dataloader state, and per-rank
    RNG state, finalized by a completion marker file.
    """

    def __init__(self, trainer: Any) -> None:
        # The coordinator holds no state of its own; everything is read
        # through the owning trainer.
        self._t = trainer

    @property
    def _dist_name(self) -> str | None:
        """Name of the distributed backend (e.g. 'fsdp2', 'deepspeed') or None."""
        return self._t.args.dist_config.name if self._t.args.dist_config is not None else None

    def save(self, epoch: int) -> None:
        """Save a full training checkpoint at the current global step."""
        ckpt_dir = os.path.join(self._t.args.output_dir, f"checkpoint-{self._t.global_step}")
        os.makedirs(ckpt_dir, exist_ok=True)

        rank = DistributedInterface().get_rank()

        # Metadata is global, so only rank 0 writes it.
        if rank == 0:
            save_metadata(
                ckpt_dir,
                global_step=self._t.global_step,
                epoch=epoch,
                num_training_steps=self._t.num_training_steps,
            )

        if self._dist_name in ("fsdp2", "deepspeed"):
            # Sharded backends save model+optimizer through their plugin.
            from ...plugins.trainer_plugins.distributed.hub import DistributedPlugin

            DistributedPlugin(self._dist_name).save_checkpoint(
                self._t.model,
                self._t.optimizer,
                ckpt_dir,
                save_ckpt_as_hf=self._t.args.save_ckpt_as_hf,
                processor=self._t.renderer.processor,
            )
        else:
            _save_standard_training_states(
                ckpt_dir,
                self._t.model,
                self._t.optimizer,
                self._t.renderer.processor,
                self._t.args.save_ckpt_as_hf,
            )

        # DeepSpeed captures the scheduler inside accelerator.save_state,
        # so the separate scheduler file is only written for other backends.
        if self._dist_name != "deepspeed" and rank == 0:
            torch.save(self._t.lr_scheduler.state_dict(), os.path.join(ckpt_dir, "scheduler.pt"))

        # Dataloader state is rank-local: every rank writes its own file.
        dl_dir = os.path.join(ckpt_dir, "dataloader")
        os.makedirs(dl_dir, exist_ok=True)
        torch.save(
            self._t.train_batch_generator.state_dict(),
            os.path.join(dl_dir, f"rank_{rank}.pt"),
        )

        if self._dist_name != "deepspeed":
            save_rng_state(ckpt_dir, rank)

        # All ranks must finish writing before the checkpoint is marked complete.
        DistributedInterface().sync()

        if rank == 0:
            mark_checkpoint_complete(ckpt_dir)
            if self._t.args.save_total_limit is not None:
                rotate_checkpoints(self._t.args.output_dir, self._t.args.save_total_limit)

        logger.info_rank0(f"Checkpoint saved to {ckpt_dir}")

    def resume(self, ckpt_path: str) -> None:
        """Restore full training state from a checkpoint directory."""
        ckpt_dir = resolve_resume_checkpoint_path(ckpt_path, self._t.args.output_dir)
        if ckpt_dir is None:
            # 'auto' found nothing usable: train from scratch.
            return

        if not os.path.isdir(ckpt_dir):
            raise ValueError(f"Checkpoint directory does not exist: {ckpt_dir}")

        rank = DistributedInterface().get_rank()

        # Restore progress counters first; the trainer reads them when
        # rebuilding its TrainerState.
        metadata = load_metadata(ckpt_dir)
        self._t.global_step = metadata["global_step"]
        self._t._resume_epoch = metadata["epoch"]

        if self._dist_name in ("fsdp2", "deepspeed"):
            from ...plugins.trainer_plugins.distributed.hub import DistributedPlugin

            DistributedPlugin(self._dist_name).load_checkpoint(
                self._t.model,
                self._t.optimizer,
                ckpt_dir,
                processor=self._t.renderer.processor,
            )
        else:
            _load_standard_training_states(
                ckpt_dir,
                self._t.model,
                self._t.optimizer,
                self._t.device,
            )

        # Scheduler file only exists for non-DeepSpeed backends (see save()).
        if self._dist_name != "deepspeed":
            sched_path = os.path.join(ckpt_dir, "scheduler.pt")
            if os.path.exists(sched_path):
                self._t.lr_scheduler.load_state_dict(torch.load(sched_path, map_location="cpu", weights_only=True))

        dl_path = os.path.join(ckpt_dir, "dataloader", f"rank_{rank}.pt")

        if os.path.exists(dl_path):
            # weights_only=False: dataloader state holds arbitrary python objects.
            self._t.train_batch_generator.load_state_dict(torch.load(dl_path, map_location="cpu", weights_only=False))
        else:
            logger.warning_rank0(
                f"Dataloader state file not found at {dl_path}. Skipping Dataloader state restoration."
            )

        if self._dist_name != "deepspeed":
            load_rng_state(ckpt_dir, rank)

        logger.info_rank0(f"Resumed from checkpoint: step={self._t.global_step}, epoch={self._t._resume_epoch}")
|
||||
@@ -19,6 +19,7 @@ this module leverages accelerate's Accelerator + DeepSpeedPlugin to handle
|
||||
initialization, backward, gradient accumulation, and model saving.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
@@ -127,3 +128,22 @@ def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
logger.info_rank0(f"Model saved to {output_dir}")
|
||||
|
||||
|
||||
def save_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Save a resumable DeepSpeed checkpoint, optionally plus an HF-format copy.

    Keyword Args:
        save_ckpt_as_hf: when True, additionally export an `hf_model` directory.
        processor: tokenizer/processor to save alongside the HF export.
    """
    save_ckpt_as_hf = kwargs.get("save_ckpt_as_hf", False)
    processor = kwargs.get("processor", None)

    # Always save DeepSpeed state for resume capability: accelerate's
    # save_state captures the engine (model/optimizer/scheduler) together.
    accelerator: Accelerator = model._accelerator  # type: ignore[union-attr]
    accelerator.save_state(ckpt_dir)

    # Additionally save HF format if requested
    if save_ckpt_as_hf:
        save_model(model, os.path.join(ckpt_dir, "hf_model"), processor)
|
||||
|
||||
|
||||
def load_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Restore the state previously written by `save_checkpoint`.

    The `optimizer` argument is kept for interface parity with other backends;
    its state is restored through accelerate's saved engine state.
    """
    accelerator: Accelerator = model._accelerator  # type: ignore[union-attr]
    accelerator.load_state(ckpt_dir)
|
||||
|
||||
@@ -18,9 +18,16 @@ import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import torch.nn as nn
|
||||
from peft.tuners.lora import LoraLayer
|
||||
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
|
||||
from torch.distributed.checkpoint.state_dict import (
|
||||
StateDictOptions,
|
||||
get_model_state_dict,
|
||||
get_optimizer_state_dict,
|
||||
set_model_state_dict,
|
||||
set_optimizer_state_dict,
|
||||
)
|
||||
from torch.distributed.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
MixedPrecisionPolicy,
|
||||
@@ -66,6 +73,51 @@ def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
|
||||
logger.info(f"Model saved to {output_dir}")
|
||||
|
||||
|
||||
def save_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Save sharded (DCP) model + optimizer state; optionally also a full HF copy.

    Keyword Args:
        save_ckpt_as_hf: gather a full state dict and save_pretrained it too.
        processor: tokenizer/processor stored next to the HF copy.
    """
    save_ckpt_as_hf = kwargs.get("save_ckpt_as_hf", False)
    processor = kwargs.get("processor", None)

    # Always save DCP format for resume capability
    options = StateDictOptions(full_state_dict=False, cpu_offload=True)

    model_state = get_model_state_dict(model, options=options)
    dcp.save(state_dict=model_state, checkpoint_id=os.path.join(ckpt_dir, "model"))

    optim_state = get_optimizer_state_dict(model, optimizer, options=options)
    dcp.save(state_dict=optim_state, checkpoint_id=os.path.join(ckpt_dir, "optimizer"))

    # Additionally save HF format if requested
    if save_ckpt_as_hf:
        if DistributedInterface().get_rank() == 0:
            logger.info("Gathering state dict for saving additional HF format checkpoint...")

        # Every rank must participate in the full-state-dict gather; only
        # rank 0 then writes the result.
        hf_options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        hf_state_dict = get_model_state_dict(model, options=hf_options)

        if DistributedInterface().get_rank() == 0:
            model_to_save = model.module if hasattr(model, "module") else model
            hf_dir = os.path.join(ckpt_dir, "hf_model")
            model_to_save.save_pretrained(hf_dir, state_dict=hf_state_dict, max_shard_size="4GB")
            if processor is not None:
                # NOTE(review): processors do not shard, so max_shard_size is
                # presumably ignored here -- confirm this kwarg is intended.
                processor.save_pretrained(hf_dir, max_shard_size="4GB")

            logger.info(f"Additional HF format checkpoint saved to {hf_dir}")
|
||||
|
||||
|
||||
def load_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Load the sharded DCP model and optimizer states saved by `save_checkpoint`."""
    options = StateDictOptions(full_state_dict=False, cpu_offload=True)

    # Model: materialize the local shard layout, fill it from disk, push it back.
    model_state = get_model_state_dict(model, options=options)
    dcp.load(state_dict=model_state, checkpoint_id=os.path.join(ckpt_dir, "model"))
    set_model_state_dict(model, model_state, options=options)

    # Optimizer: identical pattern against the "optimizer" sub-directory.
    optim_state = get_optimizer_state_dict(model, optimizer, options=options)
    dcp.load(state_dict=optim_state, checkpoint_id=os.path.join(ckpt_dir, "optimizer"))
    set_optimizer_state_dict(model, optimizer, optim_state, options=options)
|
||||
|
||||
|
||||
class FSDP2Engine:
|
||||
def __init__(self, dist_config: dict):
|
||||
self.dist_interface = DistributedInterface()
|
||||
|
||||
@@ -16,6 +16,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from ....config.arg_utils import PluginConfig
|
||||
from ....utils.plugin import BasePlugin
|
||||
|
||||
@@ -43,6 +45,20 @@ def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> N
|
||||
return save_model(model, output_dir, processor)
|
||||
|
||||
|
||||
@DistributedPlugin("fsdp2").register("save_checkpoint")
def save_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Dispatch checkpoint saving to the FSDP2 backend implementation."""
    # Imported lazily so the FSDP stack is only loaded when actually used.
    from .fsdp2 import save_checkpoint as _impl

    return _impl(model, optimizer, ckpt_dir, **kwargs)
|
||||
|
||||
|
||||
@DistributedPlugin("fsdp2").register("load_checkpoint")
def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Dispatch checkpoint loading to the FSDP2 backend implementation."""
    # Imported lazily so the FSDP stack is only loaded when actually used.
    from .fsdp2 import load_checkpoint as _impl

    return _impl(model, optimizer, ckpt_dir, **kwargs)
|
||||
|
||||
|
||||
@DistributedPlugin("deepspeed").register()
|
||||
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
|
||||
from .deepspeed import DeepSpeedEngine
|
||||
@@ -59,3 +75,17 @@ def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor)
|
||||
from .deepspeed import save_model
|
||||
|
||||
return save_model(model, output_dir, processor)
|
||||
|
||||
|
||||
@DistributedPlugin("deepspeed").register("save_checkpoint")
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Dispatch checkpoint saving to the DeepSpeed backend implementation.

    Fix: accept and forward **kwargs. The trainer invokes save_checkpoint with
    save_ckpt_as_hf=... and processor=... for both fsdp2 and deepspeed, and the
    deepspeed implementation reads them from kwargs; without **kwargs this
    wrapper raised TypeError. Backward compatible for positional-only callers.
    """
    from .deepspeed import save_checkpoint

    return save_checkpoint(model, optimizer, ckpt_dir, **kwargs)
|
||||
|
||||
|
||||
@DistributedPlugin("deepspeed").register("load_checkpoint")
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Dispatch checkpoint loading to the DeepSpeed backend implementation.

    Fix: accept and forward **kwargs for parity with the fsdp2 wrapper. The
    trainer passes processor=... to load_checkpoint for both backends, and the
    deepspeed implementation already accepts **kwargs; without them this
    wrapper raised TypeError. Backward compatible for positional-only callers.
    """
    from .deepspeed import load_checkpoint

    return load_checkpoint(model, optimizer, ckpt_dir, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user