[v1] support resume training from checkpoint (#10280)

Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
浮梦
2026-04-20 20:28:08 +08:00
committed by GitHub
parent c5aecaf31d
commit c4bbac49b2
9 changed files with 577 additions and 10 deletions

76
scripts/dcp2hf.py Normal file
View File

@@ -0,0 +1,76 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a DCP checkpoint to HuggingFace model format.
Usage:
python scripts/dcp2hf.py convert --dcp_path=/path/to/dcp --hf_path=/path/to/hf --config_path=/path/to/config
Arguments:
dcp_path: Path to the DCP checkpoint directory.
hf_path: Output path (directory) for HuggingFace model.
config_path: Path to the HuggingFace model directory containing config.json.
"""
import fire
import torch
import torch.distributed.checkpoint as dcp
import transformers
from transformers import AutoConfig
def convert(dcp_path: str, hf_path: str, config_path: str) -> None:
    """Convert DCP model weights to HF.

    Note: this script is used to convert a DCP checkpoint to HuggingFace model format,
    it will just convert the DCP checkpoint to a HuggingFace model format, for the tokenizer,
    you may need to copy from the original model.

    Args:
        dcp_path: DCP checkpoint directory.
        hf_path: Output path (directory) for HuggingFace model.
        config_path: Path to the HuggingFace model directory containing config.json.

    Raises:
        ValueError: If any of the three paths is missing/empty.
    """
    if not dcp_path or not hf_path or not config_path:
        raise ValueError("All 'dcp_path', 'hf_path', and 'config_path' are required.")

    print(f"Loading config from {config_path}...")
    config = AutoConfig.from_pretrained(config_path)
    architectures = getattr(config, "architectures", [])

    print("Initializing model on CPU...")
    # Prefer the concrete architecture class named in config.json. Bug fix:
    # the previous fallback called `AutoModelForCausalLM(config)` directly,
    # which raises (Auto classes must be built via `from_config`, not
    # `__init__`). Choose the construction path per resolved class instead.
    if architectures and hasattr(transformers, architectures[0]):
        model = getattr(transformers, architectures[0])(config)
    else:
        model = transformers.AutoModelForCausalLM.from_config(config)
    model = model.to(torch.bfloat16)

    print(f"Loading DCP from {dcp_path}...")
    # dcp.load fills the provided state dict in place; load_state_dict then
    # applies it to the module (belt-and-braces for non-shared tensors).
    state_dict = model.state_dict()
    dcp.load(state_dict, checkpoint_id=dcp_path)
    model.load_state_dict(state_dict)

    print(f"Saving to HF format at {hf_path}...")
    model.save_pretrained(hf_path)
    config.save_pretrained(hf_path)
    print("Done!")
def help() -> None:
    """Print the module usage text (the module docstring) to stdout."""
    usage = __doc__
    print(usage)
if __name__ == "__main__":
    # Expose `convert` under both the "convert" and "--convert" spellings so
    # either CLI form works; `help` prints the module usage text.
    fire.Fire({"convert": convert, "help": help, "--convert": convert})

View File

@@ -25,7 +25,8 @@ Arguments:
import fire import fire
import torch import torch
import torch.distributed.checkpoint as dcp import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM import transformers
from transformers import AutoConfig
def convert(hf_path: str, dcp_path: str) -> None: def convert(hf_path: str, dcp_path: str) -> None:
@@ -39,7 +40,14 @@ def convert(hf_path: str, dcp_path: str) -> None:
raise ValueError("Both 'hf_path' and 'dcp_path' are required.") raise ValueError("Both 'hf_path' and 'dcp_path' are required.")
print(f"Loading HF model from {hf_path}...") print(f"Loading HF model from {hf_path}...")
model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16) config = AutoConfig.from_pretrained(hf_path)
architectures = getattr(config, "architectures", [])
if architectures:
model_cls = getattr(transformers, architectures[0], transformers.AutoModelForCausalLM)
else:
model_cls = transformers.AutoModelForCausalLM
model = model_cls.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
print(f"Saving to DCP format at {dcp_path}...") print(f"Saving to DCP format at {dcp_path}...")
dcp.save(model.state_dict(), checkpoint_id=dcp_path) dcp.save(model.state_dict(), checkpoint_id=dcp_path)

View File

@@ -85,6 +85,28 @@ class TrainingArguments:
default=42, default=42,
metadata={"help": "Random seed that will be set at the beginning of training."}, metadata={"help": "Random seed that will be set at the beginning of training."},
) )
resume_from_checkpoint: str | None = field(
default=None,
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},
)
save_steps: int | None = field(
default=None,
metadata={"help": "Save a training checkpoint every N global steps."},
)
save_epochs: float | None = field(
default=None,
metadata={"help": "Save a training checkpoint every N epochs."},
)
save_ckpt_as_hf: bool = field(
default=False,
metadata={
"help": "Save intermediate checkpoints in HuggingFace format instead of distributed format. Warning: doubles memory usage."
},
)
save_total_limit: int | None = field(
default=None,
metadata={"help": "Maximum number of checkpoints to keep. Oldest checkpoints are deleted."},
)
logging_steps: int = field( logging_steps: int = field(
default=1, default=1,
metadata={"help": "Log metrics every N optimizer steps."}, metadata={"help": "Log metrics every N optimizer steps."},

View File

@@ -45,6 +45,7 @@ from ..utils.callbacks import (
from ..utils.helper import compute_valid_tokens from ..utils.helper import compute_valid_tokens
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
from .utils.batching import BatchGenerator from .utils.batching import BatchGenerator
from .utils.checkpoint import TrainingCheckpointCoordinator
from .utils.rendering import Renderer from .utils.rendering import Renderer
@@ -81,6 +82,10 @@ class BaseTrainer:
else: else:
self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator) self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
if self.args.save_epochs is not None:
steps_per_epoch = len(self.train_batch_generator)
self.args.save_steps = max(1, int(steps_per_epoch * self.args.save_epochs))
if self.args.enable_activation_checkpointing: if self.args.enable_activation_checkpointing:
self.model.gradient_checkpointing_enable({"use_reentrant": False}) self.model.gradient_checkpointing_enable({"use_reentrant": False})
@@ -107,13 +112,28 @@ class BaseTrainer:
self._init_optimizer() self._init_optimizer()
self._init_lr_scheduler() self._init_lr_scheduler()
self._resume_epoch = 0
self._checkpoint = TrainingCheckpointCoordinator(self)
if self.args.resume_from_checkpoint:
self._checkpoint.resume(self.args.resume_from_checkpoint)
if self.args.save_ckpt_as_hf:
logger.warning_rank0(
"save_ckpt_as_hf is enabled. Intermediate checkpoints will be saved in Hugging Face format. "
"Note that this will significantly increase memory consumption during saving."
)
# Callbacks # Callbacks
self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self) self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self)
for cb in callbacks or []: for cb in callbacks or []:
self.callback_handler.add_callback(cb) self.callback_handler.add_callback(cb)
# Callbacks: TrainerState tracks progress across the full run. # Callbacks: TrainerState tracks progress across the full run.
self.state = TrainerState(num_training_steps=self.num_training_steps) self.state = TrainerState(
num_training_steps=self.num_training_steps,
global_step=self.global_step,
epoch=self._resume_epoch,
)
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1: if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future. # qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
@@ -206,7 +226,7 @@ class BaseTrainer:
"""Train the model.""" """Train the model."""
self.model.train() self.model.train()
self.callback_handler.on_train_begin(self.args, self.state) self.callback_handler.on_train_begin(self.args, self.state)
for epoch in range(self.args.num_train_epochs): for epoch in range(self._resume_epoch, self.args.num_train_epochs):
self.state.epoch = epoch self.state.epoch = epoch
self.train_batch_generator.set_epoch(epoch) self.train_batch_generator.set_epoch(epoch)
self.callback_handler.on_epoch_begin(self.args, self.state) self.callback_handler.on_epoch_begin(self.args, self.state)
@@ -300,6 +320,9 @@ class BaseTrainer:
} }
self.callback_handler.on_log(self.args, self.state, logs) self.callback_handler.on_log(self.args, self.state, logs)
if self.args.save_steps and self.global_step % self.args.save_steps == 0:
self._checkpoint.save(epoch)
# Check if max_steps is reached # Check if max_steps is reached
if self.global_step >= self.num_training_steps: if self.global_step >= self.num_training_steps:
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.") logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")

View File

@@ -165,7 +165,6 @@ class BatchGenerator(Iterator):
def __iter__(self): def __iter__(self):
if not self._is_resuming: if not self._is_resuming:
self._buffer.clear() self._buffer.clear()
self._buffer_tokens = 0
self._data_iter = iter(self._data_provider) self._data_iter = iter(self._data_provider)
self._is_resuming = False self._is_resuming = False
@@ -203,14 +202,12 @@ class BatchGenerator(Iterator):
def state_dict(self) -> dict[str, Any]: def state_dict(self) -> dict[str, Any]:
return { return {
"buffer": self._buffer, "buffer": self._buffer.state_dict(),
"buffer_tokens": self._buffer_tokens,
"data_provider": self._data_provider.state_dict(), "data_provider": self._data_provider.state_dict(),
} }
def load_state_dict(self, state: dict[str, Any]) -> None: def load_state_dict(self, state: dict[str, Any]) -> None:
self._buffer = state["buffer"] self._buffer.load_state_dict(state["buffer"])
self._buffer_tokens = state["buffer_tokens"]
self._data_provider.load_state_dict(state["data_provider"]) self._data_provider.load_state_dict(state["data_provider"])
self._is_resuming = True self._is_resuming = True

View File

@@ -0,0 +1,339 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint utilities: low-level helpers and full training save/resume orchestration."""
import glob
import json
import os
import random
import shutil
from typing import Any
import numpy as np
import torch
from safetensors.torch import load_file
from ...accelerator.helper import DeviceType, get_current_accelerator
from ...accelerator.interface import DistributedInterface
from ...utils import logging
logger = logging.get_logger(__name__)
CHECKPOINT_COMPLETE_MARKER = "CHECKPOINT_COMPLETE"
def _parse_checkpoint_step(path: str) -> int:
"""Extract the step number from a checkpoint directory name, or -1 if invalid."""
try:
return int(os.path.basename(path).split("-")[-1])
except ValueError:
return -1
def find_latest_checkpoint(output_dir: str) -> str | None:
    """Find the latest valid checkpoint directory in output_dir."""
    candidates = glob.glob(os.path.join(output_dir, "checkpoint-*"))
    # Newest first; entries whose name does not parse to a step are discarded.
    ordered = sorted(
        (d for d in candidates if _parse_checkpoint_step(d) >= 0),
        key=_parse_checkpoint_step,
        reverse=True,
    )
    # Only trust a checkpoint that carries the completion marker; a crash
    # mid-save leaves a directory without one.
    for candidate in ordered:
        if os.path.exists(os.path.join(candidate, CHECKPOINT_COMPLETE_MARKER)):
            return candidate
    return None
def rotate_checkpoints(output_dir: str, limit: int) -> None:
    """Keep only the latest `limit` complete checkpoints, delete older ones and incomplete leftovers."""
    candidates = glob.glob(os.path.join(output_dir, "checkpoint-*"))
    ordered = sorted(
        (d for d in candidates if _parse_checkpoint_step(d) >= 0),
        key=_parse_checkpoint_step,
    )

    complete: list[str] = []
    for ckpt in ordered:
        if os.path.exists(os.path.join(ckpt, CHECKPOINT_COMPLETE_MARKER)):
            complete.append(ckpt)
        else:
            # No completion marker means a partial/interrupted save: remove it.
            shutil.rmtree(ckpt)
            logger.info_rank0(f"Cleaned up incomplete checkpoint: {ckpt}")

    # Drop the oldest complete checkpoints until we are within the limit.
    excess = len(complete) - limit
    for stale in complete[: max(0, excess)]:
        shutil.rmtree(stale)
        logger.info_rank0(f"Deleted old checkpoint: {stale}")
def save_metadata(ckpt_dir: str, **kwargs) -> None:
    """Save training metadata as JSON (rank 0 only)."""
    metadata_path = os.path.join(ckpt_dir, "metadata.json")
    payload = json.dumps(kwargs, indent=2)
    with open(metadata_path, "w") as handle:
        handle.write(payload)
def load_metadata(ckpt_dir: str) -> dict:
    """Load training metadata from a checkpoint directory."""
    metadata_path = os.path.join(ckpt_dir, "metadata.json")
    with open(metadata_path) as handle:
        return json.load(handle)
def _get_accelerator_rng_state():
"""Get RNG state for the current accelerator, device-agnostic."""
device_type = get_current_accelerator().type
if device_type == DeviceType.CUDA:
return torch.cuda.get_rng_state_all()
elif device_type == DeviceType.NPU:
return torch.npu.get_rng_state_all()
elif device_type == DeviceType.XPU:
return torch.xpu.get_rng_state_all()
return None
def _set_accelerator_rng_state(state) -> None:
"""Set RNG state for the current accelerator, device-agnostic."""
if state is None:
return
device_type = get_current_accelerator().type
if device_type == DeviceType.CUDA:
torch.cuda.set_rng_state_all(state)
elif device_type == DeviceType.NPU:
torch.npu.set_rng_state_all(state)
elif device_type == DeviceType.XPU:
torch.xpu.set_rng_state_all(state)
def save_rng_state(ckpt_dir: str, rank: int) -> None:
    """Save per-rank RNG states for reproducibility."""
    rng_dir = os.path.join(ckpt_dir, "rng_state")
    os.makedirs(rng_dir, exist_ok=True)
    # Capture every RNG stream that can influence training: Python, NumPy,
    # CPU torch, and (when present) the accelerator's per-device states.
    snapshot = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.random.get_rng_state(),
        "accelerator": _get_accelerator_rng_state(),
    }
    torch.save(snapshot, os.path.join(rng_dir, f"rank_{rank}.pt"))
def load_rng_state(ckpt_dir: str, rank: int) -> None:
    """Restore per-rank RNG states from a checkpoint."""
    path = os.path.join(ckpt_dir, "rng_state", f"rank_{rank}.pt")
    if not os.path.exists(path):
        # Tolerate checkpoints written without RNG state; warn and continue.
        logger.warning_rank0(f"RNG state file not found at {path}. Skipping RNG state restoration.")
        return
    # weights_only=False: the snapshot holds plain Python objects (tuples,
    # numpy state), not just tensors.
    snapshot = torch.load(path, map_location="cpu", weights_only=False)
    random.setstate(snapshot["python"])
    np.random.set_state(snapshot["numpy"])
    torch.random.set_rng_state(snapshot["torch"])
    _set_accelerator_rng_state(snapshot.get("accelerator"))
def mark_checkpoint_complete(ckpt_dir: str) -> None:
    """Write a marker file indicating the checkpoint is fully saved.

    Args:
        ckpt_dir: Checkpoint directory to mark as complete.
    """
    # Idiom fix: use a context manager instead of a bare open(...).close()
    # so the handle is always released, even if an error interleaves.
    with open(os.path.join(ckpt_dir, CHECKPOINT_COMPLETE_MARKER), "w"):
        pass
def resolve_resume_checkpoint_path(ckpt_path: str, output_dir: str) -> str | None:
    """Resolve 'auto' to the latest valid checkpoint, or return the path as-is."""
    # Explicit paths are passed through untouched.
    if ckpt_path != "auto":
        return ckpt_path

    resolved = find_latest_checkpoint(output_dir)
    if resolved is not None:
        logger.info_rank0(f"Auto-detected latest checkpoint: {resolved}")
    else:
        # No usable checkpoint: caller is expected to start from scratch.
        logger.warning_rank0(
            "resume_from_checkpoint='auto' but no valid checkpoint found in "
            f"'{output_dir}'. Training from scratch."
        )
    return resolved
def _save_standard_training_states(
    ckpt_dir: str,
    model: Any,
    optimizer: torch.optim.Optimizer,
    processor: Any,
    save_ckpt_as_hf: bool,
) -> None:
    """Save model and optimizer for DDP / single-GPU via save_pretrained.

    Args:
        ckpt_dir: Target checkpoint directory (created by the caller).
        model: The (possibly DDP-wrapped) HF model.
        optimizer: Optimizer whose state_dict is serialized alongside.
        processor: Tokenizer/processor saved next to the model weights.
        save_ckpt_as_hf: Flag from args; in the standard path the model is
            already saved in HF format, so it only triggers an info log.
    """
    # NOTE(review): source formatting lost indentation; the rank-0 nesting
    # below is reconstructed (weights and optimizer state are replicated in
    # DDP, so a single writer suffices) — confirm against upstream.
    rank = DistributedInterface().get_rank()
    if rank == 0:
        # Unwrap DDP's .module so parameters are saved under their plain names.
        model_to_save = model.module if hasattr(model, "module") else model
        model_dir = os.path.join(ckpt_dir, "model")
        model_to_save.save_pretrained(model_dir, max_shard_size="4GB")
        processor.save_pretrained(model_dir)
        os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
        torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer", "state_dict.pt"))
        if save_ckpt_as_hf:
            logger.info("Standard saving already uses HF format. No additional 'hf_model' directory created.")
def _load_standard_training_states(
    ckpt_dir: str,
    model: Any,
    optimizer: torch.optim.Optimizer,
    map_location: torch.device,
) -> None:
    """Load model and optimizer for DDP / single-GPU.

    Args:
        ckpt_dir: Checkpoint directory written by `_save_standard_training_states`.
        model: The (possibly DDP-wrapped) HF model to load weights into.
        optimizer: Optimizer to restore; skipped if no state file exists.
        map_location: Device onto which the optimizer state is loaded.
    """
    # NOTE(review): source formatting lost indentation; the adapter fallback
    # nesting below is reconstructed — confirm against upstream.
    model_dir = os.path.join(ckpt_dir, "model")
    model_to_load = model.module if hasattr(model, "module") else model

    # A PEFT/LoRA checkpoint is detected by the presence of adapter_config.json.
    is_adapter_ckpt = os.path.exists(os.path.join(model_dir, "adapter_config.json"))
    if is_adapter_ckpt:
        from peft import set_peft_model_state_dict

        adapter_file = os.path.join(model_dir, "adapter_model.safetensors")
        if not os.path.exists(adapter_file):
            # Legacy .bin adapter fallback: loaded via torch.load instead
            # of safetensors.
            adapter_file = os.path.join(model_dir, "adapter_model.bin")
            adapter_state = torch.load(adapter_file, map_location="cpu", weights_only=True)
        else:
            adapter_state = load_file(adapter_file, device="cpu")
        set_peft_model_state_dict(model_to_load, adapter_state)
    else:
        # Full-weights checkpoint: merge all shards, preferring safetensors.
        state_dict = {}
        for f in sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))):
            state_dict.update(load_file(f, device="cpu"))
        if not state_dict:
            for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
                state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
        if state_dict:
            model_to_load.load_state_dict(state_dict)
        else:
            logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")

    # Optimizer state is optional: absent for e.g. converted/foreign checkpoints.
    optim_path = os.path.join(ckpt_dir, "optimizer", "state_dict.pt")
    if os.path.exists(optim_path):
        optimizer.load_state_dict(torch.load(optim_path, map_location=map_location, weights_only=True))
class TrainingCheckpointCoordinator:
    """Coordinates full checkpoint save/resume for a trainer instance.

    A checkpoint bundles: metadata (step/epoch), model + optimizer state
    (distributed-plugin format or standard HF format), LR scheduler state,
    per-rank dataloader state, and per-rank RNG state. A completion marker
    written last makes partially-written checkpoints detectable.

    NOTE(review): source formatting lost indentation; nesting in `save` and
    `resume` below is reconstructed — confirm against upstream.
    """

    def __init__(self, trainer: Any) -> None:
        # Back-reference to the trainer; model, optimizer, lr_scheduler,
        # train_batch_generator, renderer and args are all read through it.
        self._t = trainer

    @property
    def _dist_name(self) -> str | None:
        # Name of the distributed backend ("fsdp2", "deepspeed", ...) or
        # None when running without a dist_config.
        return self._t.args.dist_config.name if self._t.args.dist_config is not None else None

    def save(self, epoch: int) -> None:
        """Save a full training checkpoint at the current global step."""
        ckpt_dir = os.path.join(self._t.args.output_dir, f"checkpoint-{self._t.global_step}")
        os.makedirs(ckpt_dir, exist_ok=True)
        rank = DistributedInterface().get_rank()
        # Metadata is tiny and global, so a single writer is enough.
        if rank == 0:
            save_metadata(
                ckpt_dir,
                global_step=self._t.global_step,
                epoch=epoch,
                num_training_steps=self._t.num_training_steps,
            )
        # Model/optimizer: delegate to the distributed plugin when sharded,
        # otherwise fall back to the standard rank-0 save_pretrained path.
        if self._dist_name in ("fsdp2", "deepspeed"):
            from ...plugins.trainer_plugins.distributed.hub import DistributedPlugin

            DistributedPlugin(self._dist_name).save_checkpoint(
                self._t.model,
                self._t.optimizer,
                ckpt_dir,
                save_ckpt_as_hf=self._t.args.save_ckpt_as_hf,
                processor=self._t.renderer.processor,
            )
        else:
            _save_standard_training_states(
                ckpt_dir,
                self._t.model,
                self._t.optimizer,
                self._t.renderer.processor,
                self._t.args.save_ckpt_as_hf,
            )
        # Scheduler state is replicated, so rank 0 writes it; deepspeed
        # bundles it inside its own accelerator state instead.
        if self._dist_name != "deepspeed" and rank == 0:
            torch.save(self._t.lr_scheduler.state_dict(), os.path.join(ckpt_dir, "scheduler.pt"))
        # Dataloader state differs per rank (each rank consumes its own
        # shard), so every rank writes its own file.
        dl_dir = os.path.join(ckpt_dir, "dataloader")
        os.makedirs(dl_dir, exist_ok=True)
        torch.save(
            self._t.train_batch_generator.state_dict(),
            os.path.join(dl_dir, f"rank_{rank}.pt"),
        )
        if self._dist_name != "deepspeed":
            save_rng_state(ckpt_dir, rank)
        # Barrier before marking complete: the marker must only appear once
        # every rank has finished writing its pieces.
        DistributedInterface().sync()
        if rank == 0:
            mark_checkpoint_complete(ckpt_dir)
            if self._t.args.save_total_limit is not None:
                rotate_checkpoints(self._t.args.output_dir, self._t.args.save_total_limit)
        logger.info_rank0(f"Checkpoint saved to {ckpt_dir}")

    def resume(self, ckpt_path: str) -> None:
        """Restore full training state from a checkpoint directory.

        Args:
            ckpt_path: Explicit checkpoint directory, or 'auto' to pick the
                latest complete checkpoint under the trainer's output_dir.

        Raises:
            ValueError: If a resolved path is not an existing directory.
        """
        ckpt_dir = resolve_resume_checkpoint_path(ckpt_path, self._t.args.output_dir)
        if ckpt_dir is None:
            # 'auto' found nothing; training starts from scratch.
            return
        if not os.path.isdir(ckpt_dir):
            raise ValueError(f"Checkpoint directory does not exist: {ckpt_dir}")
        rank = DistributedInterface().get_rank()
        # Restore progress counters first so the trainer resumes at the
        # correct step/epoch regardless of later partial restores.
        metadata = load_metadata(ckpt_dir)
        self._t.global_step = metadata["global_step"]
        self._t._resume_epoch = metadata["epoch"]
        if self._dist_name in ("fsdp2", "deepspeed"):
            from ...plugins.trainer_plugins.distributed.hub import DistributedPlugin

            DistributedPlugin(self._dist_name).load_checkpoint(
                self._t.model,
                self._t.optimizer,
                ckpt_dir,
                processor=self._t.renderer.processor,
            )
        else:
            _load_standard_training_states(
                ckpt_dir,
                self._t.model,
                self._t.optimizer,
                self._t.device,
            )
        # Scheduler: only the non-deepspeed path stores a separate file.
        if self._dist_name != "deepspeed":
            sched_path = os.path.join(ckpt_dir, "scheduler.pt")
            if os.path.exists(sched_path):
                self._t.lr_scheduler.load_state_dict(torch.load(sched_path, map_location="cpu", weights_only=True))
        # Per-rank dataloader state; tolerate a missing file (e.g. resuming
        # with a different world size) with a warning.
        dl_path = os.path.join(ckpt_dir, "dataloader", f"rank_{rank}.pt")
        if os.path.exists(dl_path):
            self._t.train_batch_generator.load_state_dict(torch.load(dl_path, map_location="cpu", weights_only=False))
        else:
            logger.warning_rank0(
                f"Dataloader state file not found at {dl_path}. Skipping Dataloader state restoration."
            )
        if self._dist_name != "deepspeed":
            load_rng_state(ckpt_dir, rank)
        logger.info_rank0(f"Resumed from checkpoint: step={self._t.global_step}, epoch={self._t._resume_epoch}")

View File

@@ -19,6 +19,7 @@ this module leverages accelerate's Accelerator + DeepSpeedPlugin to handle
initialization, backward, gradient accumulation, and model saving. initialization, backward, gradient accumulation, and model saving.
""" """
import os
from typing import Any, Optional from typing import Any, Optional
import torch import torch
@@ -127,3 +128,22 @@ def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
logger.info_rank0(f"Model saved to {output_dir}") logger.info_rank0(f"Model saved to {output_dir}")
def save_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
save_ckpt_as_hf = kwargs.get("save_ckpt_as_hf", False)
processor = kwargs.get("processor", None)
# Always save DeepSpeed state for resume capability
accelerator: Accelerator = model._accelerator # type: ignore[union-attr]
accelerator.save_state(ckpt_dir)
# Additionally save HF format if requested
if save_ckpt_as_hf:
hf_dir = os.path.join(ckpt_dir, "hf_model")
save_model(model, hf_dir, processor)
def load_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
accelerator: Accelerator = model._accelerator # type: ignore[union-attr]
accelerator.load_state(ckpt_dir)

View File

@@ -18,9 +18,16 @@ import os
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn import torch.nn as nn
from peft.tuners.lora import LoraLayer from peft.tuners.lora import LoraLayer
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.fsdp import ( from torch.distributed.fsdp import (
CPUOffloadPolicy, CPUOffloadPolicy,
MixedPrecisionPolicy, MixedPrecisionPolicy,
@@ -66,6 +73,51 @@ def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
logger.info(f"Model saved to {output_dir}") logger.info(f"Model saved to {output_dir}")
def save_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
save_ckpt_as_hf = kwargs.get("save_ckpt_as_hf", False)
processor = kwargs.get("processor", None)
# Always save DCP format for resume capability
options = StateDictOptions(full_state_dict=False, cpu_offload=True)
model_state = get_model_state_dict(model, options=options)
dcp.save(state_dict=model_state, checkpoint_id=os.path.join(ckpt_dir, "model"))
optim_state = get_optimizer_state_dict(model, optimizer, options=options)
dcp.save(state_dict=optim_state, checkpoint_id=os.path.join(ckpt_dir, "optimizer"))
# Additionally save HF format if requested
if save_ckpt_as_hf:
if DistributedInterface().get_rank() == 0:
logger.info("Gathering state dict for saving additional HF format checkpoint...")
hf_options = StateDictOptions(full_state_dict=True, cpu_offload=True)
hf_state_dict = get_model_state_dict(model, options=hf_options)
if DistributedInterface().get_rank() == 0:
model_to_save = model.module if hasattr(model, "module") else model
hf_dir = os.path.join(ckpt_dir, "hf_model")
model_to_save.save_pretrained(hf_dir, state_dict=hf_state_dict, max_shard_size="4GB")
if processor is not None:
processor.save_pretrained(hf_dir, max_shard_size="4GB")
logger.info(f"Additional HF format checkpoint saved to {hf_dir}")
def load_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
options = StateDictOptions(full_state_dict=False, cpu_offload=True)
ckpt_model_dir = os.path.join(ckpt_dir, "model")
model_state = get_model_state_dict(model, options=options)
dcp.load(state_dict=model_state, checkpoint_id=ckpt_model_dir)
set_model_state_dict(model, model_state, options=options)
ckpt_optim_dir = os.path.join(ckpt_dir, "optimizer")
optim_state = get_optimizer_state_dict(model, optimizer, options=options)
dcp.load(state_dict=optim_state, checkpoint_id=ckpt_optim_dir)
set_optimizer_state_dict(model, optimizer, optim_state, options=options)
class FSDP2Engine: class FSDP2Engine:
def __init__(self, dist_config: dict): def __init__(self, dist_config: dict):
self.dist_interface = DistributedInterface() self.dist_interface = DistributedInterface()

View File

@@ -16,6 +16,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch
from ....config.arg_utils import PluginConfig from ....config.arg_utils import PluginConfig
from ....utils.plugin import BasePlugin from ....utils.plugin import BasePlugin
@@ -43,6 +45,20 @@ def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> N
return save_model(model, output_dir, processor) return save_model(model, output_dir, processor)
@DistributedPlugin("fsdp2").register("save_checkpoint")
def save_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
from .fsdp2 import save_checkpoint
return save_checkpoint(model, optimizer, ckpt_dir, **kwargs)
@DistributedPlugin("fsdp2").register("load_checkpoint")
def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
from .fsdp2 import load_checkpoint
return load_checkpoint(model, optimizer, ckpt_dir, **kwargs)
@DistributedPlugin("deepspeed").register() @DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel: def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
from .deepspeed import DeepSpeedEngine from .deepspeed import DeepSpeedEngine
@@ -59,3 +75,17 @@ def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor)
from .deepspeed import save_model from .deepspeed import save_model
return save_model(model, output_dir, processor) return save_model(model, output_dir, processor)
@DistributedPlugin("deepspeed").register("save_checkpoint")
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None:
from .deepspeed import save_checkpoint
return save_checkpoint(model, optimizer, ckpt_dir)
@DistributedPlugin("deepspeed").register("load_checkpoint")
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None:
from .deepspeed import load_checkpoint
return load_checkpoint(model, optimizer, ckpt_dir)