diff --git a/scripts/dcp2hf.py b/scripts/dcp2hf.py new file mode 100644 index 000000000..8e3256bbb --- /dev/null +++ b/scripts/dcp2hf.py @@ -0,0 +1,76 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert a DCP checkpoint to HuggingFace model format. + +Usage: + python scripts/dcp2hf.py convert --dcp_path=/path/to/dcp --hf_path=/path/to/hf --config_path=/path/to/config + +Arguments: + dcp_path: Path to the DCP checkpoint directory. + hf_path: Output path (directory) for HuggingFace model. + config_path: Path to the HuggingFace model directory containing config.json. +""" + +import fire +import torch +import torch.distributed.checkpoint as dcp +import transformers +from transformers import AutoConfig + + +def convert(dcp_path: str, hf_path: str, config_path: str) -> None: + """Convert DCP model weights to HF. + + Note: this script is used to convert a DCP checkpoint to HuggingFace model format, + it will just convert the DCP checkpoint to a HuggingFace model format, for the tokenizer, + you may need to copy from the original model. + + Args: + dcp_path: DCP checkpoint directory. + hf_path: Output path (directory) for HuggingFace model. + config_path: Path to the HuggingFace model directory containing config.json. 
+ """ + if not dcp_path or not hf_path or not config_path: + raise ValueError("All 'dcp_path', 'hf_path', and 'config_path' are required.") + + print(f"Loading config from {config_path}...") + config = AutoConfig.from_pretrained(config_path) + architectures = getattr(config, "architectures", []) + if architectures: + model_cls = getattr(transformers, architectures[0], transformers.AutoModelForCausalLM) + else: + model_cls = transformers.AutoModelForCausalLM + + print("Initializing model on CPU...") + model = model_cls(config).to(torch.bfloat16) + + print(f"Loading DCP from {dcp_path}...") + state_dict = model.state_dict() + dcp.load(state_dict, checkpoint_id=dcp_path) + model.load_state_dict(state_dict) + + print(f"Saving to HF format at {hf_path}...") + model.save_pretrained(hf_path) + config.save_pretrained(hf_path) + print("Done!") + + +def help() -> None: + """Show help message.""" + print(__doc__) + + +if __name__ == "__main__": + fire.Fire({"convert": convert, "help": help, "--convert": convert}) diff --git a/scripts/hf2dcp.py b/scripts/hf2dcp.py index 9e6fbf8c5..da51580b6 100644 --- a/scripts/hf2dcp.py +++ b/scripts/hf2dcp.py @@ -25,7 +25,8 @@ Arguments: import fire import torch import torch.distributed.checkpoint as dcp -from transformers import AutoModelForCausalLM +import transformers +from transformers import AutoConfig def convert(hf_path: str, dcp_path: str) -> None: @@ -39,7 +40,14 @@ def convert(hf_path: str, dcp_path: str) -> None: raise ValueError("Both 'hf_path' and 'dcp_path' are required.") print(f"Loading HF model from {hf_path}...") - model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16) + config = AutoConfig.from_pretrained(hf_path) + architectures = getattr(config, "architectures", []) + if architectures: + model_cls = getattr(transformers, architectures[0], transformers.AutoModelForCausalLM) + else: + model_cls = transformers.AutoModelForCausalLM + + model = model_cls.from_pretrained(hf_path, 
device_map="cpu", torch_dtype=torch.bfloat16) print(f"Saving to DCP format at {dcp_path}...") dcp.save(model.state_dict(), checkpoint_id=dcp_path) diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py index 750899ccd..30b95f99e 100644 --- a/src/llamafactory/v1/config/training_args.py +++ b/src/llamafactory/v1/config/training_args.py @@ -85,6 +85,28 @@ class TrainingArguments: default=42, metadata={"help": "Random seed that will be set at the beginning of training."}, ) + resume_from_checkpoint: str | None = field( + default=None, + metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."}, + ) + save_steps: int | None = field( + default=None, + metadata={"help": "Save a training checkpoint every N global steps."}, + ) + save_epochs: float | None = field( + default=None, + metadata={"help": "Save a training checkpoint every N epochs."}, + ) + save_ckpt_as_hf: bool = field( + default=False, + metadata={ + "help": "Save intermediate checkpoints in HuggingFace format instead of distributed format. Warning: doubles memory usage." + }, + ) + save_total_limit: int | None = field( + default=None, + metadata={"help": "Maximum number of checkpoints to keep. 
Oldest checkpoints are deleted."}, + ) logging_steps: int = field( default=1, metadata={"help": "Log metrics every N optimizer steps."}, diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index a8afffec1..11cec3c65 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -45,6 +45,7 @@ from ..utils.callbacks import ( from ..utils.helper import compute_valid_tokens from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset from .utils.batching import BatchGenerator +from .utils.checkpoint import TrainingCheckpointCoordinator from .utils.rendering import Renderer @@ -81,6 +82,10 @@ class BaseTrainer: else: self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator) + if self.args.save_epochs is not None: + steps_per_epoch = len(self.train_batch_generator) + self.args.save_steps = max(1, int(steps_per_epoch * self.args.save_epochs)) + if self.args.enable_activation_checkpointing: self.model.gradient_checkpointing_enable({"use_reentrant": False}) @@ -107,13 +112,28 @@ class BaseTrainer: self._init_optimizer() self._init_lr_scheduler() + self._resume_epoch = 0 + self._checkpoint = TrainingCheckpointCoordinator(self) + if self.args.resume_from_checkpoint: + self._checkpoint.resume(self.args.resume_from_checkpoint) + + if self.args.save_ckpt_as_hf: + logger.warning_rank0( + "save_ckpt_as_hf is enabled. Intermediate checkpoints will be saved in Hugging Face format. " + "Note that this will significantly increase memory consumption during saving." + ) + # Callbacks self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self) for cb in callbacks or []: self.callback_handler.add_callback(cb) # Callbacks: TrainerState tracks progress across the full run. 
- self.state = TrainerState(num_training_steps=self.num_training_steps) + self.state = TrainerState( + num_training_steps=self.num_training_steps, + global_step=self.global_step, + epoch=self._resume_epoch, + ) if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1: # qwen3.5 is not supported because of the different attention implementation, which will be supported in the future. @@ -206,7 +226,7 @@ class BaseTrainer: """Train the model.""" self.model.train() self.callback_handler.on_train_begin(self.args, self.state) - for epoch in range(self.args.num_train_epochs): + for epoch in range(self._resume_epoch, self.args.num_train_epochs): self.state.epoch = epoch self.train_batch_generator.set_epoch(epoch) self.callback_handler.on_epoch_begin(self.args, self.state) @@ -300,6 +320,9 @@ class BaseTrainer: } self.callback_handler.on_log(self.args, self.state, logs) + if self.args.save_steps and self.global_step % self.args.save_steps == 0: + self._checkpoint.save(epoch) + # Check if max_steps is reached if self.global_step >= self.num_training_steps: logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.") diff --git a/src/llamafactory/v1/core/utils/batching.py b/src/llamafactory/v1/core/utils/batching.py index 2243fb078..25a626b94 100644 --- a/src/llamafactory/v1/core/utils/batching.py +++ b/src/llamafactory/v1/core/utils/batching.py @@ -165,7 +165,6 @@ class BatchGenerator(Iterator): def __iter__(self): if not self._is_resuming: self._buffer.clear() - self._buffer_tokens = 0 self._data_iter = iter(self._data_provider) self._is_resuming = False @@ -203,14 +202,12 @@ class BatchGenerator(Iterator): def state_dict(self) -> dict[str, Any]: return { - "buffer": self._buffer, - "buffer_tokens": self._buffer_tokens, + "buffer": self._buffer.state_dict(), "data_provider": self._data_provider.state_dict(), } def load_state_dict(self, state: dict[str, Any]) -> None: - self._buffer = state["buffer"] - 
self._buffer_tokens = state["buffer_tokens"] + self._buffer.load_state_dict(state["buffer"]) self._data_provider.load_state_dict(state["data_provider"]) self._is_resuming = True diff --git a/src/llamafactory/v1/core/utils/checkpoint.py b/src/llamafactory/v1/core/utils/checkpoint.py new file mode 100644 index 000000000..b9cfcd911 --- /dev/null +++ b/src/llamafactory/v1/core/utils/checkpoint.py @@ -0,0 +1,339 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Checkpoint utilities: low-level helpers and full training save/resume orchestration.""" + +import glob +import json +import os +import random +import shutil +from typing import Any + +import numpy as np +import torch +from safetensors.torch import load_file + +from ...accelerator.helper import DeviceType, get_current_accelerator +from ...accelerator.interface import DistributedInterface +from ...utils import logging + + +logger = logging.get_logger(__name__) + +CHECKPOINT_COMPLETE_MARKER = "CHECKPOINT_COMPLETE" + + +def _parse_checkpoint_step(path: str) -> int: + """Extract the step number from a checkpoint directory name, or -1 if invalid.""" + try: + return int(os.path.basename(path).split("-")[-1]) + except ValueError: + return -1 + + +def find_latest_checkpoint(output_dir: str) -> str | None: + """Find the latest valid checkpoint directory in output_dir.""" + pattern = os.path.join(output_dir, "checkpoint-*") + ckpt_dirs = [d for d in glob.glob(pattern) if _parse_checkpoint_step(d) >= 0] + ckpt_dirs.sort(key=_parse_checkpoint_step) + for d in reversed(ckpt_dirs): + if os.path.exists(os.path.join(d, CHECKPOINT_COMPLETE_MARKER)): + return d + return None + + +def rotate_checkpoints(output_dir: str, limit: int) -> None: + """Keep only the latest `limit` complete checkpoints, delete older ones and incomplete leftovers.""" + pattern = os.path.join(output_dir, "checkpoint-*") + all_dirs = [d for d in glob.glob(pattern) if _parse_checkpoint_step(d) >= 0] + all_dirs.sort(key=_parse_checkpoint_step) + + complete_dirs = [] + for d in all_dirs: + if os.path.exists(os.path.join(d, CHECKPOINT_COMPLETE_MARKER)): + complete_dirs.append(d) + else: + shutil.rmtree(d) + logger.info_rank0(f"Cleaned up incomplete checkpoint: {d}") + + while len(complete_dirs) > limit: + oldest = complete_dirs.pop(0) + shutil.rmtree(oldest) + logger.info_rank0(f"Deleted old checkpoint: {oldest}") + + +def save_metadata(ckpt_dir: str, **kwargs) -> None: + """Save training metadata as JSON 
(rank 0 only).""" + with open(os.path.join(ckpt_dir, "metadata.json"), "w") as f: + json.dump(kwargs, f, indent=2) + + +def load_metadata(ckpt_dir: str) -> dict: + """Load training metadata from a checkpoint directory.""" + with open(os.path.join(ckpt_dir, "metadata.json")) as f: + return json.load(f) + + +def _get_accelerator_rng_state(): + """Get RNG state for the current accelerator, device-agnostic.""" + device_type = get_current_accelerator().type + if device_type == DeviceType.CUDA: + return torch.cuda.get_rng_state_all() + elif device_type == DeviceType.NPU: + return torch.npu.get_rng_state_all() + elif device_type == DeviceType.XPU: + return torch.xpu.get_rng_state_all() + return None + + +def _set_accelerator_rng_state(state) -> None: + """Set RNG state for the current accelerator, device-agnostic.""" + if state is None: + return + + device_type = get_current_accelerator().type + if device_type == DeviceType.CUDA: + torch.cuda.set_rng_state_all(state) + elif device_type == DeviceType.NPU: + torch.npu.set_rng_state_all(state) + elif device_type == DeviceType.XPU: + torch.xpu.set_rng_state_all(state) + + +def save_rng_state(ckpt_dir: str, rank: int) -> None: + """Save per-rank RNG states for reproducibility.""" + rng_state = { + "python": random.getstate(), + "numpy": np.random.get_state(), + "torch": torch.random.get_rng_state(), + "accelerator": _get_accelerator_rng_state(), + } + rng_dir = os.path.join(ckpt_dir, "rng_state") + os.makedirs(rng_dir, exist_ok=True) + torch.save(rng_state, os.path.join(rng_dir, f"rank_{rank}.pt")) + + +def load_rng_state(ckpt_dir: str, rank: int) -> None: + """Restore per-rank RNG states from a checkpoint.""" + path = os.path.join(ckpt_dir, "rng_state", f"rank_{rank}.pt") + + if not os.path.exists(path): + logger.warning_rank0(f"RNG state file not found at {path}. 
Skipping RNG state restoration.") + return + + rng_state = torch.load(path, map_location="cpu", weights_only=False) + random.setstate(rng_state["python"]) + np.random.set_state(rng_state["numpy"]) + torch.random.set_rng_state(rng_state["torch"]) + _set_accelerator_rng_state(rng_state.get("accelerator")) + + +def mark_checkpoint_complete(ckpt_dir: str) -> None: + """Write a marker file indicating the checkpoint is fully saved.""" + open(os.path.join(ckpt_dir, CHECKPOINT_COMPLETE_MARKER), "w").close() + + +def resolve_resume_checkpoint_path(ckpt_path: str, output_dir: str) -> str | None: + """Resolve 'auto' to the latest valid checkpoint, or return the path as-is.""" + if ckpt_path == "auto": + resolved = find_latest_checkpoint(output_dir) + if resolved is None: + logger.warning_rank0( + "resume_from_checkpoint='auto' but no valid checkpoint found in " + f"'{output_dir}'. Training from scratch." + ) + else: + logger.info_rank0(f"Auto-detected latest checkpoint: {resolved}") + return resolved + return ckpt_path + + +def _save_standard_training_states( + ckpt_dir: str, + model: Any, + optimizer: torch.optim.Optimizer, + processor: Any, + save_ckpt_as_hf: bool, +) -> None: + """Save model and optimizer for DDP / single-GPU via save_pretrained.""" + rank = DistributedInterface().get_rank() + if rank == 0: + model_to_save = model.module if hasattr(model, "module") else model + model_dir = os.path.join(ckpt_dir, "model") + model_to_save.save_pretrained(model_dir, max_shard_size="4GB") + processor.save_pretrained(model_dir) + + os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True) + torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer", "state_dict.pt")) + + if save_ckpt_as_hf: + logger.info("Standard saving already uses HF format. 
No additional 'hf_model' directory created.") + + +def _load_standard_training_states( + ckpt_dir: str, + model: Any, + optimizer: torch.optim.Optimizer, + map_location: torch.device, +) -> None: + """Load model and optimizer for DDP / single-GPU.""" + model_dir = os.path.join(ckpt_dir, "model") + model_to_load = model.module if hasattr(model, "module") else model + + is_adapter_ckpt = os.path.exists(os.path.join(model_dir, "adapter_config.json")) + + if is_adapter_ckpt: + from peft import set_peft_model_state_dict + + adapter_file = os.path.join(model_dir, "adapter_model.safetensors") + if not os.path.exists(adapter_file): + adapter_file = os.path.join(model_dir, "adapter_model.bin") + adapter_state = torch.load(adapter_file, map_location="cpu", weights_only=True) + else: + adapter_state = load_file(adapter_file, device="cpu") + set_peft_model_state_dict(model_to_load, adapter_state) + else: + state_dict = {} + for f in sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))): + state_dict.update(load_file(f, device="cpu")) + if not state_dict: + for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))): + state_dict.update(torch.load(f, map_location="cpu", weights_only=True)) + if state_dict: + model_to_load.load_state_dict(state_dict) + else: + logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.") + + optim_path = os.path.join(ckpt_dir, "optimizer", "state_dict.pt") + if os.path.exists(optim_path): + optimizer.load_state_dict(torch.load(optim_path, map_location=map_location, weights_only=True)) + + +class TrainingCheckpointCoordinator: + """Coordinates full checkpoint save/resume for a trainer instance.""" + + def __init__(self, trainer: Any) -> None: + self._t = trainer + + @property + def _dist_name(self) -> str | None: + return self._t.args.dist_config.name if self._t.args.dist_config is not None else None + + def save(self, epoch: int) -> None: + """Save a full training checkpoint at the current global step.""" + 
ckpt_dir = os.path.join(self._t.args.output_dir, f"checkpoint-{self._t.global_step}") + os.makedirs(ckpt_dir, exist_ok=True) + + rank = DistributedInterface().get_rank() + + if rank == 0: + save_metadata( + ckpt_dir, + global_step=self._t.global_step, + epoch=epoch, + num_training_steps=self._t.num_training_steps, + ) + + if self._dist_name in ("fsdp2", "deepspeed"): + from ...plugins.trainer_plugins.distributed.hub import DistributedPlugin + + DistributedPlugin(self._dist_name).save_checkpoint( + self._t.model, + self._t.optimizer, + ckpt_dir, + save_ckpt_as_hf=self._t.args.save_ckpt_as_hf, + processor=self._t.renderer.processor, + ) + else: + _save_standard_training_states( + ckpt_dir, + self._t.model, + self._t.optimizer, + self._t.renderer.processor, + self._t.args.save_ckpt_as_hf, + ) + + if self._dist_name != "deepspeed" and rank == 0: + torch.save(self._t.lr_scheduler.state_dict(), os.path.join(ckpt_dir, "scheduler.pt")) + + dl_dir = os.path.join(ckpt_dir, "dataloader") + os.makedirs(dl_dir, exist_ok=True) + torch.save( + self._t.train_batch_generator.state_dict(), + os.path.join(dl_dir, f"rank_{rank}.pt"), + ) + + if self._dist_name != "deepspeed": + save_rng_state(ckpt_dir, rank) + + DistributedInterface().sync() + + if rank == 0: + mark_checkpoint_complete(ckpt_dir) + if self._t.args.save_total_limit is not None: + rotate_checkpoints(self._t.args.output_dir, self._t.args.save_total_limit) + + logger.info_rank0(f"Checkpoint saved to {ckpt_dir}") + + def resume(self, ckpt_path: str) -> None: + """Restore full training state from a checkpoint directory.""" + ckpt_dir = resolve_resume_checkpoint_path(ckpt_path, self._t.args.output_dir) + if ckpt_dir is None: + return + + if not os.path.isdir(ckpt_dir): + raise ValueError(f"Checkpoint directory does not exist: {ckpt_dir}") + + rank = DistributedInterface().get_rank() + + metadata = load_metadata(ckpt_dir) + self._t.global_step = metadata["global_step"] + self._t._resume_epoch = metadata["epoch"] + + if 
self._dist_name in ("fsdp2", "deepspeed"): + from ...plugins.trainer_plugins.distributed.hub import DistributedPlugin + + DistributedPlugin(self._dist_name).load_checkpoint( + self._t.model, + self._t.optimizer, + ckpt_dir, + processor=self._t.renderer.processor, + ) + else: + _load_standard_training_states( + ckpt_dir, + self._t.model, + self._t.optimizer, + self._t.device, + ) + + if self._dist_name != "deepspeed": + sched_path = os.path.join(ckpt_dir, "scheduler.pt") + if os.path.exists(sched_path): + self._t.lr_scheduler.load_state_dict(torch.load(sched_path, map_location="cpu", weights_only=True)) + + dl_path = os.path.join(ckpt_dir, "dataloader", f"rank_{rank}.pt") + + if os.path.exists(dl_path): + self._t.train_batch_generator.load_state_dict(torch.load(dl_path, map_location="cpu", weights_only=False)) + else: + logger.warning_rank0( + f"Dataloader state file not found at {dl_path}. Skipping Dataloader state restoration." + ) + + if self._dist_name != "deepspeed": + load_rng_state(ckpt_dir, rank) + + logger.info_rank0(f"Resumed from checkpoint: step={self._t.global_step}, epoch={self._t._resume_epoch}") diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/deepspeed.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/deepspeed.py index d839045b3..a68b1f8ab 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/deepspeed.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/deepspeed.py @@ -19,6 +19,7 @@ this module leverages accelerate's Accelerator + DeepSpeedPlugin to handle initialization, backward, gradient accumulation, and model saving. 
""" +import os from typing import Any, Optional import torch @@ -127,3 +128,22 @@ def save_model(model: HFModel, output_dir: str, processor: Processor) -> None: accelerator.wait_for_everyone() logger.info_rank0(f"Model saved to {output_dir}") + + +def save_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None: + save_ckpt_as_hf = kwargs.get("save_ckpt_as_hf", False) + processor = kwargs.get("processor", None) + + # Always save DeepSpeed state for resume capability + accelerator: Accelerator = model._accelerator # type: ignore[union-attr] + accelerator.save_state(ckpt_dir) + + # Additionally save HF format if requested + if save_ckpt_as_hf: + hf_dir = os.path.join(ckpt_dir, "hf_model") + save_model(model, hf_dir, processor) + + +def load_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None: + accelerator: Accelerator = model._accelerator # type: ignore[union-attr] + accelerator.load_state(ckpt_dir) diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py index 7ba7130dc..eb905263a 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py @@ -18,9 +18,16 @@ import os import torch import torch.distributed as dist +import torch.distributed.checkpoint as dcp import torch.nn as nn from peft.tuners.lora import LoraLayer -from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_optimizer_state_dict, + set_model_state_dict, + set_optimizer_state_dict, +) from torch.distributed.fsdp import ( CPUOffloadPolicy, MixedPrecisionPolicy, @@ -66,6 +73,51 @@ def save_model(model: HFModel, output_dir: str, processor: Processor) -> None: logger.info(f"Model saved to 
{output_dir}") +def save_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None: + save_ckpt_as_hf = kwargs.get("save_ckpt_as_hf", False) + processor = kwargs.get("processor", None) + + # Always save DCP format for resume capability + options = StateDictOptions(full_state_dict=False, cpu_offload=True) + + model_state = get_model_state_dict(model, options=options) + dcp.save(state_dict=model_state, checkpoint_id=os.path.join(ckpt_dir, "model")) + + optim_state = get_optimizer_state_dict(model, optimizer, options=options) + dcp.save(state_dict=optim_state, checkpoint_id=os.path.join(ckpt_dir, "optimizer")) + + # Additionally save HF format if requested + if save_ckpt_as_hf: + if DistributedInterface().get_rank() == 0: + logger.info("Gathering state dict for saving additional HF format checkpoint...") + + hf_options = StateDictOptions(full_state_dict=True, cpu_offload=True) + hf_state_dict = get_model_state_dict(model, options=hf_options) + + if DistributedInterface().get_rank() == 0: + model_to_save = model.module if hasattr(model, "module") else model + hf_dir = os.path.join(ckpt_dir, "hf_model") + model_to_save.save_pretrained(hf_dir, state_dict=hf_state_dict, max_shard_size="4GB") + if processor is not None: + processor.save_pretrained(hf_dir, max_shard_size="4GB") + + logger.info(f"Additional HF format checkpoint saved to {hf_dir}") + + +def load_checkpoint(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None: + options = StateDictOptions(full_state_dict=False, cpu_offload=True) + + ckpt_model_dir = os.path.join(ckpt_dir, "model") + model_state = get_model_state_dict(model, options=options) + dcp.load(state_dict=model_state, checkpoint_id=ckpt_model_dir) + set_model_state_dict(model, model_state, options=options) + + ckpt_optim_dir = os.path.join(ckpt_dir, "optimizer") + optim_state = get_optimizer_state_dict(model, optimizer, options=options) + dcp.load(state_dict=optim_state, 
checkpoint_id=ckpt_optim_dir) + set_optimizer_state_dict(model, optimizer, optim_state, options=options) + + class FSDP2Engine: def __init__(self, dist_config: dict): self.dist_interface = DistributedInterface() diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py index 24b7052c2..47f64a3a6 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py @@ -16,6 +16,8 @@ from __future__ import annotations from typing import TYPE_CHECKING +import torch + from ....config.arg_utils import PluginConfig from ....utils.plugin import BasePlugin @@ -43,6 +45,20 @@ def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> N return save_model(model, output_dir, processor) +@DistributedPlugin("fsdp2").register("save_checkpoint") +def save_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None: + from .fsdp2 import save_checkpoint + + return save_checkpoint(model, optimizer, ckpt_dir, **kwargs) + + +@DistributedPlugin("fsdp2").register("load_checkpoint") +def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None: + from .fsdp2 import load_checkpoint + + return load_checkpoint(model, optimizer, ckpt_dir, **kwargs) + + @DistributedPlugin("deepspeed").register() def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel: from .deepspeed import DeepSpeedEngine @@ -59,3 +75,17 @@ def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor) from .deepspeed import save_model return save_model(model, output_dir, processor) + + +@DistributedPlugin("deepspeed").register("save_checkpoint") +def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None: + from .deepspeed import save_checkpoint + + return 
save_checkpoint(model, optimizer, ckpt_dir, **kwargs) + + +@DistributedPlugin("deepspeed").register("load_checkpoint") +def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None: + from .deepspeed import load_checkpoint + + return load_checkpoint(model, optimizer, ckpt_dir, **kwargs)