[v1] support resume training from checkpoint (#10280)

Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
浮梦
2026-04-20 20:28:08 +08:00
committed by GitHub
parent c5aecaf31d
commit c4bbac49b2
9 changed files with 577 additions and 10 deletions

76
scripts/dcp2hf.py Normal file
View File

@@ -0,0 +1,76 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a DCP checkpoint to HuggingFace model format.
Usage:
python scripts/dcp2hf.py convert --dcp_path=/path/to/dcp --hf_path=/path/to/hf --config_path=/path/to/config
Arguments:
dcp_path: Path to the DCP checkpoint directory.
hf_path: Output path (directory) for HuggingFace model.
config_path: Path to the HuggingFace model directory containing config.json.
"""
import fire
import torch
import torch.distributed.checkpoint as dcp
import transformers
from transformers import AutoConfig
def convert(dcp_path: str, hf_path: str, config_path: str) -> None:
    """Convert DCP model weights to HF.

    Note: this script is used to convert a DCP checkpoint to HuggingFace model format,
    it will just convert the DCP checkpoint to a HuggingFace model format, for the tokenizer,
    you may need to copy from the original model.

    Args:
        dcp_path: DCP checkpoint directory.
        hf_path: Output path (directory) for HuggingFace model.
        config_path: Path to the HuggingFace model directory containing config.json.

    Raises:
        ValueError: If any of the three paths is missing.
    """
    if not dcp_path or not hf_path or not config_path:
        raise ValueError("All 'dcp_path', 'hf_path', and 'config_path' are required.")

    print(f"Loading config from {config_path}...")
    config = AutoConfig.from_pretrained(config_path)
    # `architectures` may be absent or None in some configs; normalize to a list.
    architectures = getattr(config, "architectures", None) or []

    print("Initializing model on CPU...")
    # Instantiate the concrete architecture class when transformers exposes it.
    # Fall back to AutoModelForCausalLM.from_config otherwise: Auto classes
    # cannot be called directly (AutoModelForCausalLM(config) raises), so the
    # previous getattr(..., AutoModelForCausalLM) + model_cls(config) fallback
    # crashed for configs whose architecture is not in transformers.
    if architectures and hasattr(transformers, architectures[0]):
        model = getattr(transformers, architectures[0])(config).to(torch.bfloat16)
    else:
        model = transformers.AutoModelForCausalLM.from_config(config).to(torch.bfloat16)

    print(f"Loading DCP from {dcp_path}...")
    # dcp.load fills the provided state_dict in place with the sharded weights.
    state_dict = model.state_dict()
    dcp.load(state_dict, checkpoint_id=dcp_path)
    model.load_state_dict(state_dict)

    print(f"Saving to HF format at {hf_path}...")
    model.save_pretrained(hf_path)
    config.save_pretrained(hf_path)
    print("Done!")
def help() -> None:
    """Display usage information for this conversion script."""
    usage = __doc__
    print(usage)
# CLI entry point. Fire maps each dict key to a subcommand; "--convert" is
# registered as an extra alias so `python dcp2hf.py --convert ...` also works.
if __name__ == "__main__":
    fire.Fire({"convert": convert, "help": help, "--convert": convert})

View File

@@ -25,7 +25,8 @@ Arguments:
import fire
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM
import transformers
from transformers import AutoConfig
def convert(hf_path: str, dcp_path: str) -> None:
@@ -39,7 +40,14 @@ def convert(hf_path: str, dcp_path: str) -> None:
raise ValueError("Both 'hf_path' and 'dcp_path' are required.")
print(f"Loading HF model from {hf_path}...")
model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
config = AutoConfig.from_pretrained(hf_path)
architectures = getattr(config, "architectures", [])
if architectures:
model_cls = getattr(transformers, architectures[0], transformers.AutoModelForCausalLM)
else:
model_cls = transformers.AutoModelForCausalLM
model = model_cls.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
print(f"Saving to DCP format at {dcp_path}...")
dcp.save(model.state_dict(), checkpoint_id=dcp_path)

View File

@@ -85,6 +85,28 @@ class TrainingArguments:
default=42,
metadata={"help": "Random seed that will be set at the beginning of training."},
)
resume_from_checkpoint: str | None = field(
default=None,
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},
)
save_steps: int | None = field(
default=None,
metadata={"help": "Save a training checkpoint every N global steps."},
)
save_epochs: float | None = field(
default=None,
metadata={"help": "Save a training checkpoint every N epochs."},
)
save_ckpt_as_hf: bool = field(
default=False,
metadata={
"help": "Save intermediate checkpoints in HuggingFace format instead of distributed format. Warning: doubles memory usage."
},
)
save_total_limit: int | None = field(
default=None,
metadata={"help": "Maximum number of checkpoints to keep. Oldest checkpoints are deleted."},
)
logging_steps: int = field(
default=1,
metadata={"help": "Log metrics every N optimizer steps."},

View File

@@ -45,6 +45,7 @@ from ..utils.callbacks import (
from ..utils.helper import compute_valid_tokens
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
from .utils.batching import BatchGenerator
from .utils.checkpoint import TrainingCheckpointCoordinator
from .utils.rendering import Renderer
@@ -81,6 +82,10 @@ class BaseTrainer:
else:
self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
if self.args.save_epochs is not None:
steps_per_epoch = len(self.train_batch_generator)
self.args.save_steps = max(1, int(steps_per_epoch * self.args.save_epochs))
if self.args.enable_activation_checkpointing:
self.model.gradient_checkpointing_enable({"use_reentrant": False})
@@ -107,13 +112,28 @@ class BaseTrainer:
self._init_optimizer()
self._init_lr_scheduler()
self._resume_epoch = 0
self._checkpoint = TrainingCheckpointCoordinator(self)
if self.args.resume_from_checkpoint:
self._checkpoint.resume(self.args.resume_from_checkpoint)
if self.args.save_ckpt_as_hf:
logger.warning_rank0(
"save_ckpt_as_hf is enabled. Intermediate checkpoints will be saved in Hugging Face format. "
"Note that this will significantly increase memory consumption during saving."
)
# Callbacks
self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self)
for cb in callbacks or []:
self.callback_handler.add_callback(cb)
# Callbacks: TrainerState tracks progress across the full run.
self.state = TrainerState(num_training_steps=self.num_training_steps)
self.state = TrainerState(
num_training_steps=self.num_training_steps,
global_step=self.global_step,
epoch=self._resume_epoch,
)
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
@@ -206,7 +226,7 @@ class BaseTrainer:
"""Train the model."""
self.model.train()
self.callback_handler.on_train_begin(self.args, self.state)
for epoch in range(self.args.num_train_epochs):
for epoch in range(self._resume_epoch, self.args.num_train_epochs):
self.state.epoch = epoch
self.train_batch_generator.set_epoch(epoch)
self.callback_handler.on_epoch_begin(self.args, self.state)
@@ -300,6 +320,9 @@ class BaseTrainer:
}
self.callback_handler.on_log(self.args, self.state, logs)
if self.args.save_steps and self.global_step % self.args.save_steps == 0:
self._checkpoint.save(epoch)
# Check if max_steps is reached
if self.global_step >= self.num_training_steps:
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")

View File

@@ -165,7 +165,6 @@ class BatchGenerator(Iterator):
def __iter__(self):
if not self._is_resuming:
self._buffer.clear()
self._buffer_tokens = 0
self._data_iter = iter(self._data_provider)
self._is_resuming = False
@@ -203,14 +202,12 @@ class BatchGenerator(Iterator):
def state_dict(self) -> dict[str, Any]:
return {
"buffer": self._buffer,
"buffer_tokens": self._buffer_tokens,
"buffer": self._buffer.state_dict(),
"data_provider": self._data_provider.state_dict(),
}
def load_state_dict(self, state: dict[str, Any]) -> None:
self._buffer = state["buffer"]
self._buffer_tokens = state["buffer_tokens"]
self._buffer.load_state_dict(state["buffer"])
self._data_provider.load_state_dict(state["data_provider"])
self._is_resuming = True

View File

@@ -0,0 +1,339 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint utilities: low-level helpers and full training save/resume orchestration."""
import glob
import json
import os
import random
import shutil
from typing import Any
import numpy as np
import torch
from safetensors.torch import load_file
from ...accelerator.helper import DeviceType, get_current_accelerator
from ...accelerator.interface import DistributedInterface
from ...utils import logging
logger = logging.get_logger(__name__)
CHECKPOINT_COMPLETE_MARKER = "CHECKPOINT_COMPLETE"
def _parse_checkpoint_step(path: str) -> int:
"""Extract the step number from a checkpoint directory name, or -1 if invalid."""
try:
return int(os.path.basename(path).split("-")[-1])
except ValueError:
return -1
def find_latest_checkpoint(output_dir: str) -> str | None:
    """Return the newest complete ``checkpoint-*`` directory in output_dir, or None.

    Only directories containing the completion marker are considered; the
    newest candidates (highest step number) are inspected first.
    """
    candidates = sorted(
        (d for d in glob.glob(os.path.join(output_dir, "checkpoint-*")) if _parse_checkpoint_step(d) >= 0),
        key=_parse_checkpoint_step,
        reverse=True,
    )
    for candidate in candidates:
        if os.path.exists(os.path.join(candidate, CHECKPOINT_COMPLETE_MARKER)):
            return candidate
    return None
def rotate_checkpoints(output_dir: str, limit: int) -> None:
    """Prune checkpoints: drop incomplete leftovers, keep only the `limit` newest complete ones."""
    candidates = sorted(
        (d for d in glob.glob(os.path.join(output_dir, "checkpoint-*")) if _parse_checkpoint_step(d) >= 0),
        key=_parse_checkpoint_step,
    )
    complete = []
    for ckpt in candidates:
        if os.path.exists(os.path.join(ckpt, CHECKPOINT_COMPLETE_MARKER)):
            complete.append(ckpt)
        else:
            # A missing marker means the save was interrupted; the directory
            # is unusable for resume, so remove it outright.
            shutil.rmtree(ckpt)
            logger.info_rank0(f"Cleaned up incomplete checkpoint: {ckpt}")
    # Delete from the oldest (lowest step) end until within the limit.
    excess = len(complete) - limit
    for stale in complete[: max(excess, 0)]:
        shutil.rmtree(stale)
        logger.info_rank0(f"Deleted old checkpoint: {stale}")
def save_metadata(ckpt_dir: str, **kwargs) -> None:
    """Persist training metadata (step, epoch, ...) to ``metadata.json`` (rank 0 only)."""
    metadata_path = os.path.join(ckpt_dir, "metadata.json")
    with open(metadata_path, "w") as f:
        f.write(json.dumps(kwargs, indent=2))
def load_metadata(ckpt_dir: str) -> dict:
    """Read the training metadata previously written by ``save_metadata``."""
    metadata_path = os.path.join(ckpt_dir, "metadata.json")
    with open(metadata_path) as f:
        return json.load(f)
def _get_accelerator_rng_state():
    """Get RNG state for the current accelerator, device-agnostic.

    Returns None when the active device type has no matching branch
    (e.g. CPU-only runs); callers treat None as "nothing to restore".
    """
    device_type = get_current_accelerator().type
    # Branch lazily: torch.npu / torch.xpu presumably exist only on matching
    # builds, so each backend module is touched only when that device is active.
    if device_type == DeviceType.CUDA:
        return torch.cuda.get_rng_state_all()
    elif device_type == DeviceType.NPU:
        return torch.npu.get_rng_state_all()
    elif device_type == DeviceType.XPU:
        return torch.xpu.get_rng_state_all()
    return None
def _set_accelerator_rng_state(state) -> None:
    """Set RNG state for the current accelerator, device-agnostic.

    A None `state` (saved on a device without per-accelerator RNG) is a no-op.
    """
    if state is None:
        return
    device_type = get_current_accelerator().type
    # Mirror of _get_accelerator_rng_state: only the active backend's module
    # is touched, since the others may not exist on this build.
    if device_type == DeviceType.CUDA:
        torch.cuda.set_rng_state_all(state)
    elif device_type == DeviceType.NPU:
        torch.npu.set_rng_state_all(state)
    elif device_type == DeviceType.XPU:
        torch.xpu.set_rng_state_all(state)
def save_rng_state(ckpt_dir: str, rank: int) -> None:
    """Snapshot every RNG stream (python/numpy/torch/accelerator) for this rank."""
    rng_dir = os.path.join(ckpt_dir, "rng_state")
    os.makedirs(rng_dir, exist_ok=True)
    snapshot = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.random.get_rng_state(),
        "accelerator": _get_accelerator_rng_state(),
    }
    torch.save(snapshot, os.path.join(rng_dir, f"rank_{rank}.pt"))
def load_rng_state(ckpt_dir: str, rank: int) -> None:
    """Restore all RNG streams saved by ``save_rng_state``; skip quietly if absent."""
    path = os.path.join(ckpt_dir, "rng_state", f"rank_{rank}.pt")
    if not os.path.exists(path):
        logger.warning_rank0(f"RNG state file not found at {path}. Skipping RNG state restoration.")
        return
    # weights_only=False: the snapshot contains plain Python/numpy RNG state
    # objects, not just tensors.
    snapshot = torch.load(path, map_location="cpu", weights_only=False)
    random.setstate(snapshot["python"])
    np.random.set_state(snapshot["numpy"])
    torch.random.set_rng_state(snapshot["torch"])
    _set_accelerator_rng_state(snapshot.get("accelerator"))
def mark_checkpoint_complete(ckpt_dir: str) -> None:
    """Flag `ckpt_dir` as fully saved by creating the completion marker file.

    Resume logic only trusts directories containing this marker, so it must be
    written last, after every state file has been persisted.
    """
    # Use a context manager instead of a bare open(...).close() so the file
    # descriptor is released even if creation itself raises (e.g. disk full).
    with open(os.path.join(ckpt_dir, CHECKPOINT_COMPLETE_MARKER), "w"):
        pass
def resolve_resume_checkpoint_path(ckpt_path: str, output_dir: str) -> str | None:
    """Map 'auto' to the newest valid checkpoint in output_dir; pass explicit paths through.

    Returns None when 'auto' was requested but no complete checkpoint exists,
    signalling the caller to train from scratch.
    """
    if ckpt_path != "auto":
        return ckpt_path
    resolved = find_latest_checkpoint(output_dir)
    if resolved is None:
        logger.warning_rank0(
            "resume_from_checkpoint='auto' but no valid checkpoint found in "
            f"'{output_dir}'. Training from scratch."
        )
    else:
        logger.info_rank0(f"Auto-detected latest checkpoint: {resolved}")
    return resolved
def _save_standard_training_states(
    ckpt_dir: str,
    model: Any,
    optimizer: torch.optim.Optimizer,
    processor: Any,
    save_ckpt_as_hf: bool,
) -> None:
    """Save model and optimizer for DDP / single-GPU via save_pretrained.

    Only rank 0 writes; weights go to ``ckpt_dir/model`` (HF format) and the
    optimizer state to ``ckpt_dir/optimizer/state_dict.pt``.
    """
    rank = DistributedInterface().get_rank()
    if rank == 0:
        # Unwrap DDP's .module wrapper so save_pretrained sees the HF model.
        model_to_save = model.module if hasattr(model, "module") else model
        model_dir = os.path.join(ckpt_dir, "model")
        model_to_save.save_pretrained(model_dir, max_shard_size="4GB")
        processor.save_pretrained(model_dir)
        # NOTE(review): optimizer state is written by rank 0 only — assumes DDP
        # replicas hold identical optimizer states; confirm for other setups.
        os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
        torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer", "state_dict.pt"))
    if save_ckpt_as_hf:
        logger.info("Standard saving already uses HF format. No additional 'hf_model' directory created.")
def _load_standard_training_states(
    ckpt_dir: str,
    model: Any,
    optimizer: torch.optim.Optimizer,
    map_location: torch.device,
) -> None:
    """Load model and optimizer for DDP / single-GPU.

    Handles both PEFT adapter checkpoints (adapter_config.json present) and
    full-weight checkpoints (sharded *.safetensors, falling back to *.bin).
    Missing pieces are skipped with a warning rather than raising.
    """
    model_dir = os.path.join(ckpt_dir, "model")
    # Unwrap DDP's .module wrapper before loading state into the HF model.
    model_to_load = model.module if hasattr(model, "module") else model
    is_adapter_ckpt = os.path.exists(os.path.join(model_dir, "adapter_config.json"))
    if is_adapter_ckpt:
        # Local import keeps peft optional for non-adapter runs.
        from peft import set_peft_model_state_dict

        adapter_file = os.path.join(model_dir, "adapter_model.safetensors")
        if not os.path.exists(adapter_file):
            # Legacy torch-pickle adapter file.
            adapter_file = os.path.join(model_dir, "adapter_model.bin")
            adapter_state = torch.load(adapter_file, map_location="cpu", weights_only=True)
        else:
            adapter_state = load_file(adapter_file, device="cpu")
        set_peft_model_state_dict(model_to_load, adapter_state)
    else:
        # Merge all shards into one state dict; prefer safetensors over .bin.
        state_dict = {}
        for f in sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))):
            state_dict.update(load_file(f, device="cpu"))
        if not state_dict:
            for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
                state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
        if state_dict:
            model_to_load.load_state_dict(state_dict)
        else:
            logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")
    optim_path = os.path.join(ckpt_dir, "optimizer", "state_dict.pt")
    if os.path.exists(optim_path):
        optimizer.load_state_dict(torch.load(optim_path, map_location=map_location, weights_only=True))
class TrainingCheckpointCoordinator:
    """Coordinates full checkpoint save/resume for a trainer instance.

    A checkpoint directory (``checkpoint-<global_step>``) bundles: metadata
    (step/epoch), model + optimizer state (strategy-specific), LR scheduler,
    per-rank dataloader state, per-rank RNG state, and a completion marker
    written last so interrupted saves are never resumed from.
    """

    def __init__(self, trainer: Any) -> None:
        # Back-reference to the trainer; model, optimizer, scheduler,
        # dataloader, and args are all accessed through it.
        self._t = trainer

    @property
    def _dist_name(self) -> str | None:
        # Active distributed strategy name ("fsdp2", "deepspeed", ...) or None
        # when running without a dist_config.
        return self._t.args.dist_config.name if self._t.args.dist_config is not None else None

    def save(self, epoch: int) -> None:
        """Save a full training checkpoint at the current global step."""
        ckpt_dir = os.path.join(self._t.args.output_dir, f"checkpoint-{self._t.global_step}")
        os.makedirs(ckpt_dir, exist_ok=True)
        rank = DistributedInterface().get_rank()
        if rank == 0:
            save_metadata(
                ckpt_dir,
                global_step=self._t.global_step,
                epoch=epoch,
                num_training_steps=self._t.num_training_steps,
            )
        if self._dist_name in ("fsdp2", "deepspeed"):
            # Sharded strategies own their model/optimizer serialization.
            from ...plugins.trainer_plugins.distributed.hub import DistributedPlugin

            DistributedPlugin(self._dist_name).save_checkpoint(
                self._t.model,
                self._t.optimizer,
                ckpt_dir,
                save_ckpt_as_hf=self._t.args.save_ckpt_as_hf,
                processor=self._t.renderer.processor,
            )
        else:
            _save_standard_training_states(
                ckpt_dir,
                self._t.model,
                self._t.optimizer,
                self._t.renderer.processor,
                self._t.args.save_ckpt_as_hf,
            )
        # DeepSpeed is excluded below because its save_checkpoint path captures
        # scheduler and RNG state itself (via accelerator.save_state) —
        # presumably; confirm against the deepspeed plugin.
        if self._dist_name != "deepspeed" and rank == 0:
            torch.save(self._t.lr_scheduler.state_dict(), os.path.join(ckpt_dir, "scheduler.pt"))
        # Dataloader state is saved per rank: each rank consumes its own shard.
        dl_dir = os.path.join(ckpt_dir, "dataloader")
        os.makedirs(dl_dir, exist_ok=True)
        torch.save(
            self._t.train_batch_generator.state_dict(),
            os.path.join(dl_dir, f"rank_{rank}.pt"),
        )
        if self._dist_name != "deepspeed":
            save_rng_state(ckpt_dir, rank)
        # Barrier: every rank must finish writing before the completion marker
        # is dropped, or resume could pick up a half-written checkpoint.
        DistributedInterface().sync()
        if rank == 0:
            mark_checkpoint_complete(ckpt_dir)
            if self._t.args.save_total_limit is not None:
                rotate_checkpoints(self._t.args.output_dir, self._t.args.save_total_limit)
        logger.info_rank0(f"Checkpoint saved to {ckpt_dir}")

    def resume(self, ckpt_path: str) -> None:
        """Restore full training state from a checkpoint directory.

        Raises:
            ValueError: If an explicit (non-'auto') path does not exist.
        """
        ckpt_dir = resolve_resume_checkpoint_path(ckpt_path, self._t.args.output_dir)
        if ckpt_dir is None:
            # 'auto' with no checkpoint found: train from scratch.
            return
        if not os.path.isdir(ckpt_dir):
            raise ValueError(f"Checkpoint directory does not exist: {ckpt_dir}")
        rank = DistributedInterface().get_rank()
        metadata = load_metadata(ckpt_dir)
        # Restore progress counters first so subsequent logging sees them.
        self._t.global_step = metadata["global_step"]
        self._t._resume_epoch = metadata["epoch"]
        if self._dist_name in ("fsdp2", "deepspeed"):
            from ...plugins.trainer_plugins.distributed.hub import DistributedPlugin

            DistributedPlugin(self._dist_name).load_checkpoint(
                self._t.model,
                self._t.optimizer,
                ckpt_dir,
                processor=self._t.renderer.processor,
            )
        else:
            _load_standard_training_states(
                ckpt_dir,
                self._t.model,
                self._t.optimizer,
                self._t.device,
            )
        if self._dist_name != "deepspeed":
            sched_path = os.path.join(ckpt_dir, "scheduler.pt")
            if os.path.exists(sched_path):
                self._t.lr_scheduler.load_state_dict(torch.load(sched_path, map_location="cpu", weights_only=True))
        dl_path = os.path.join(ckpt_dir, "dataloader", f"rank_{rank}.pt")
        if os.path.exists(dl_path):
            # weights_only=False: dataloader state holds arbitrary Python
            # objects (buffer contents, provider state), not just tensors.
            self._t.train_batch_generator.load_state_dict(torch.load(dl_path, map_location="cpu", weights_only=False))
        else:
            logger.warning_rank0(
                f"Dataloader state file not found at {dl_path}. Skipping Dataloader state restoration."
            )
        if self._dist_name != "deepspeed":
            load_rng_state(ckpt_dir, rank)
        logger.info_rank0(f"Resumed from checkpoint: step={self._t.global_step}, epoch={self._t._resume_epoch}")

View File

@@ -19,6 +19,7 @@ this module leverages accelerate's Accelerator + DeepSpeedPlugin to handle
initialization, backward, gradient accumulation, and model saving.
"""
import os
from typing import Any, Optional
import torch
@@ -127,3 +128,22 @@ def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
accelerator.wait_for_everyone()
logger.info_rank0(f"Model saved to {output_dir}")
def save_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Save a resumable DeepSpeed checkpoint into `ckpt_dir` via accelerate.

    Keyword Args:
        save_ckpt_as_hf: When True, additionally export a HuggingFace-format
            copy under ``ckpt_dir/hf_model``.
        processor: Tokenizer/processor saved alongside the optional HF export.

    Note: `optimizer` is not referenced here — accelerate's ``save_state``
    presumably captures it through the DeepSpeed engine; the parameter keeps
    this hook's signature uniform with the other strategies.
    """
    save_ckpt_as_hf = kwargs.get("save_ckpt_as_hf", False)
    processor = kwargs.get("processor", None)
    # Always save DeepSpeed state for resume capability
    accelerator: Accelerator = model._accelerator  # type: ignore[union-attr]
    accelerator.save_state(ckpt_dir)
    # Additionally save HF format if requested
    if save_ckpt_as_hf:
        hf_dir = os.path.join(ckpt_dir, "hf_model")
        save_model(model, hf_dir, processor)
def load_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Restore DeepSpeed training state previously written by ``save_checkpoint``.

    `optimizer` and `**kwargs` are unused; ``load_state`` restores everything
    through the accelerator attached to the model. The signature mirrors the
    other strategies' hooks.
    """
    accelerator: Accelerator = model._accelerator  # type: ignore[union-attr]
    accelerator.load_state(ckpt_dir)

View File

@@ -18,9 +18,16 @@ import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from peft.tuners.lora import LoraLayer
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.fsdp import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
@@ -66,6 +73,51 @@ def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
logger.info(f"Model saved to {output_dir}")
def save_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Save sharded FSDP2 model/optimizer state with DCP for later resume.

    Keyword Args:
        save_ckpt_as_hf: Also export a consolidated HF-format copy (gathers the
            full state dict, which is expensive).
        processor: Tokenizer/processor saved with the optional HF export.
    """
    save_ckpt_as_hf = kwargs.get("save_ckpt_as_hf", False)
    processor = kwargs.get("processor", None)
    # Always save DCP format for resume capability
    options = StateDictOptions(full_state_dict=False, cpu_offload=True)
    model_state = get_model_state_dict(model, options=options)
    dcp.save(state_dict=model_state, checkpoint_id=os.path.join(ckpt_dir, "model"))
    optim_state = get_optimizer_state_dict(model, optimizer, options=options)
    dcp.save(state_dict=optim_state, checkpoint_id=os.path.join(ckpt_dir, "optimizer"))
    # Additionally save HF format if requested
    if save_ckpt_as_hf:
        if DistributedInterface().get_rank() == 0:
            logger.info("Gathering state dict for saving additional HF format checkpoint...")
        # full_state_dict=True gathers shards — presumably a collective, so it
        # must run on every rank, not just rank 0; only the write is rank-gated.
        hf_options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        hf_state_dict = get_model_state_dict(model, options=hf_options)
        if DistributedInterface().get_rank() == 0:
            model_to_save = model.module if hasattr(model, "module") else model
            hf_dir = os.path.join(ckpt_dir, "hf_model")
            model_to_save.save_pretrained(hf_dir, state_dict=hf_state_dict, max_shard_size="4GB")
            if processor is not None:
                # NOTE(review): max_shard_size looks copy-pasted from the model
                # save above — processors likely ignore it; confirm.
                processor.save_pretrained(hf_dir, max_shard_size="4GB")
            logger.info(f"Additional HF format checkpoint saved to {hf_dir}")
def load_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Load sharded FSDP2 model/optimizer state saved by ``save_checkpoint``.

    DCP loads in place into state dicts shaped like the current (sharded)
    model, then the set_* helpers push the values back into the modules and
    optimizer. `**kwargs` is accepted for hook-signature uniformity.
    """
    options = StateDictOptions(full_state_dict=False, cpu_offload=True)
    ckpt_model_dir = os.path.join(ckpt_dir, "model")
    model_state = get_model_state_dict(model, options=options)
    dcp.load(state_dict=model_state, checkpoint_id=ckpt_model_dir)
    set_model_state_dict(model, model_state, options=options)
    ckpt_optim_dir = os.path.join(ckpt_dir, "optimizer")
    optim_state = get_optimizer_state_dict(model, optimizer, options=options)
    dcp.load(state_dict=optim_state, checkpoint_id=ckpt_optim_dir)
    set_optimizer_state_dict(model, optimizer, optim_state, options=options)
class FSDP2Engine:
def __init__(self, dist_config: dict):
self.dist_interface = DistributedInterface()

View File

@@ -16,6 +16,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from ....config.arg_utils import PluginConfig
from ....utils.plugin import BasePlugin
@@ -43,6 +45,20 @@ def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> N
return save_model(model, output_dir, processor)
@DistributedPlugin("fsdp2").register("save_checkpoint")
def save_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Hub registration: forward checkpoint saving to the FSDP2 backend."""
    from .fsdp2 import save_checkpoint as _save_checkpoint

    return _save_checkpoint(model, optimizer, ckpt_dir, **kwargs)
@DistributedPlugin("fsdp2").register("load_checkpoint")
def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Hub registration: forward checkpoint loading to the FSDP2 backend."""
    from .fsdp2 import load_checkpoint as _load_checkpoint

    return _load_checkpoint(model, optimizer, ckpt_dir, **kwargs)
@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
from .deepspeed import DeepSpeedEngine
@@ -59,3 +75,17 @@ def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor)
from .deepspeed import save_model
return save_model(model, output_dir, processor)
@DistributedPlugin("deepspeed").register("save_checkpoint")
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Hub registration: forward checkpoint saving to the DeepSpeed backend.

    Fix: accept and forward **kwargs. The trainer's checkpoint coordinator
    calls this hook with ``save_ckpt_as_hf=...`` and ``processor=...`` (the
    same call it makes for fsdp2), so the previous kwargs-less signature
    raised ``TypeError: unexpected keyword argument`` at the first save.
    """
    from .deepspeed import save_checkpoint

    return save_checkpoint(model, optimizer, ckpt_dir, **kwargs)
@DistributedPlugin("deepspeed").register("load_checkpoint")
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
    """Hub registration: forward checkpoint loading to the DeepSpeed backend.

    Fix: accept and forward **kwargs. The trainer's checkpoint coordinator
    passes ``processor=...`` to this hook (matching the fsdp2 hook), so the
    previous kwargs-less signature raised a ``TypeError`` on resume.
    """
    from .deepspeed import load_checkpoint

    return load_checkpoint(model, optimizer, ckpt_dir, **kwargs)