mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 04:38:53 +08:00
[feat] support HyperParallel PT training and activation optimization (#10370)
This commit is contained in:
@@ -487,7 +487,7 @@ class FinetuningArguments(
|
|||||||
metadata={
|
metadata={
|
||||||
"help": (
|
"help": (
|
||||||
"Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
|
"Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
|
||||||
"Only supported for the 'sft' stage with full fine-tuning."
|
"Only supported for the 'pt' and 'sft' stages with full fine-tuning."
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .workflow import run_sft
|
from .workflow import run_pt, run_sft
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["run_sft"]
|
__all__ = ["run_pt", "run_sft"]
|
||||||
|
|||||||
222
src/llamafactory/train/hyper_parallel/trainer.py
Normal file
222
src/llamafactory/train/hyper_parallel/trainer.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""HyperParallel distributed trainer for LlamaFactory."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import types
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from hyper_parallel.integration.llamafactory import (
|
||||||
|
HSDPModule,
|
||||||
|
HyperParallelArguments,
|
||||||
|
export_to_hf_format,
|
||||||
|
fsdp2_prepare_model,
|
||||||
|
hsdp_sync_stream,
|
||||||
|
load_hsdp_model,
|
||||||
|
load_hsdp_optimizer_and_scheduler,
|
||||||
|
save_hsdp_checkpoint,
|
||||||
|
wrap_optimizer_with_skip_dtensor_dispatch,
|
||||||
|
)
|
||||||
|
from hyper_parallel.integration.llamafactory import (
|
||||||
|
clip_grad_norm_ as hp_clip_grad_norm_,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ..sft.trainer import CustomSeq2SeqTrainer
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
||||||
|
"""Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.
|
||||||
|
|
||||||
|
Inherits CustomSeq2SeqTrainer for training algorithm logic (loss, metrics,
|
||||||
|
prediction, sampler, etc.) and only overrides HSDP-specific behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hp_args: HyperParallelArguments,
|
||||||
|
finetuning_args=None,
|
||||||
|
processor=None,
|
||||||
|
ref_model: Optional[nn.Module] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self._hp_args = hp_args
|
||||||
|
|
||||||
|
# Let CustomSeq2SeqTrainer handle everything except ref_model —
|
||||||
|
# Custom would prepare it with accelerate's fsdp2_prepare_model,
|
||||||
|
# but we need HP's version instead.
|
||||||
|
super().__init__(
|
||||||
|
finetuning_args=finetuning_args,
|
||||||
|
processor=processor,
|
||||||
|
ref_model=None,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not getattr(self.accelerator, "is_fsdp2", False):
|
||||||
|
raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")
|
||||||
|
|
||||||
|
# Prepare ref_model with HP's fsdp2_prepare_model
|
||||||
|
self.ref_model = ref_model
|
||||||
|
if self.ref_model is not None:
|
||||||
|
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args)
|
||||||
|
|
||||||
|
self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
|
||||||
|
self._orig_fsdp2_prepare_model = None
|
||||||
|
self._accelerator_patches_active = False
|
||||||
|
|
||||||
|
def _activate_accelerator_patches(self) -> None:
|
||||||
|
"""Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_."""
|
||||||
|
if self._accelerator_patches_active:
|
||||||
|
return
|
||||||
|
|
||||||
|
import accelerate.accelerator as acc_module # pylint: disable=C0415
|
||||||
|
|
||||||
|
hp_args = self._hp_args
|
||||||
|
|
||||||
|
self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model
|
||||||
|
|
||||||
|
def _hp_fsdp2_prepare_model(accelerator, model):
|
||||||
|
return fsdp2_prepare_model(accelerator, model, hp_args)
|
||||||
|
|
||||||
|
acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model
|
||||||
|
|
||||||
|
def _hp_clip_grad_norm(accelerator, parameters, max_norm, norm_type=2):
|
||||||
|
if getattr(accelerator, "is_fsdp2", False):
|
||||||
|
accelerator.unscale_gradients()
|
||||||
|
parameter_list = list(parameters)
|
||||||
|
parameter_ids = {id(param) for param in parameter_list}
|
||||||
|
for model in accelerator._models: # pylint: disable=protected-access
|
||||||
|
if not isinstance(model, HSDPModule):
|
||||||
|
continue
|
||||||
|
model_param_ids = {id(param) for param in model.parameters()}
|
||||||
|
if parameter_ids and parameter_ids.issubset(model_param_ids):
|
||||||
|
return hp_clip_grad_norm_(parameter_list, max_norm, norm_type=norm_type)
|
||||||
|
return self._orig_accelerator_clip_grad_norm(parameters, max_norm, norm_type=norm_type)
|
||||||
|
|
||||||
|
self.accelerator.clip_grad_norm_ = types.MethodType(_hp_clip_grad_norm, self.accelerator)
|
||||||
|
self._accelerator_patches_active = True
|
||||||
|
|
||||||
|
def _restore_accelerator_patches(self) -> None:
|
||||||
|
"""Restore original Accelerate methods."""
|
||||||
|
if not self._accelerator_patches_active:
|
||||||
|
return
|
||||||
|
|
||||||
|
import accelerate.accelerator as acc_module # pylint: disable=C0415
|
||||||
|
|
||||||
|
if self._orig_fsdp2_prepare_model is not None:
|
||||||
|
acc_module.fsdp2_prepare_model = self._orig_fsdp2_prepare_model
|
||||||
|
self.accelerator.clip_grad_norm_ = self._orig_accelerator_clip_grad_norm
|
||||||
|
self._accelerator_patches_active = False
|
||||||
|
|
||||||
|
def _wrap_model(self, model: nn.Module, training: bool = True, dataloader=None) -> nn.Module:
|
||||||
|
"""Let Accelerate own FSDP2/HSDP wrapping so optimizer remapping stays correct."""
|
||||||
|
del dataloader
|
||||||
|
if isinstance(model, HSDPModule):
|
||||||
|
return model
|
||||||
|
if training and getattr(self.accelerator, "is_fsdp2", False):
|
||||||
|
return model
|
||||||
|
return super()._wrap_model(model, training=training)
|
||||||
|
|
||||||
|
def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
|
||||||
|
"""Skip redundant device moves for HSDP-wrapped models."""
|
||||||
|
if isinstance(model, HSDPModule):
|
||||||
|
return model
|
||||||
|
if device is None:
|
||||||
|
return model
|
||||||
|
return model.to(device)
|
||||||
|
|
||||||
|
def train(self, *args, **kwargs):
|
||||||
|
"""Activate HP patches during training and restore afterwards."""
|
||||||
|
self._activate_accelerator_patches()
|
||||||
|
try:
|
||||||
|
return super().train(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
self._restore_accelerator_patches()
|
||||||
|
|
||||||
|
def training_step(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: dict[str, Any],
|
||||||
|
num_items_in_batch: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Standard training step with HSDP gradient synchronization."""
|
||||||
|
model.train()
|
||||||
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
|
sync_gradients = getattr(self.accelerator, "sync_gradients", True)
|
||||||
|
if isinstance(model, HSDPModule):
|
||||||
|
model.set_is_last_backward(sync_gradients)
|
||||||
|
model.set_requires_gradient_sync(sync_gradients)
|
||||||
|
|
||||||
|
compute_loss_context_manager = getattr(self, "compute_loss_context_manager", nullcontext)
|
||||||
|
with compute_loss_context_manager():
|
||||||
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
|
|
||||||
|
if self.args.n_gpu > 1:
|
||||||
|
loss = loss.mean()
|
||||||
|
|
||||||
|
if not getattr(self, "model_accepts_loss_kwargs", False) and getattr(self, "compute_loss_func", None) is None:
|
||||||
|
loss = loss / self.args.gradient_accumulation_steps
|
||||||
|
|
||||||
|
self.accelerator.backward(loss)
|
||||||
|
|
||||||
|
if isinstance(model, HSDPModule) and sync_gradients:
|
||||||
|
hsdp_sync_stream()
|
||||||
|
|
||||||
|
return loss.detach()
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
"""Create optimizer and wrap step with SkipDTensorDispatch."""
|
||||||
|
optimizer = super().create_optimizer()
|
||||||
|
wrap_optimizer_with_skip_dtensor_dispatch(optimizer)
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def _save_optimizer_and_scheduler(self, output_dir: str) -> None:
|
||||||
|
"""Save model/optimizer shards per-rank and scheduler."""
|
||||||
|
save_hsdp_checkpoint(
|
||||||
|
model=self.model,
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
lr_scheduler=self.lr_scheduler,
|
||||||
|
output_dir=output_dir,
|
||||||
|
should_save_scheduler=self.args.should_save and self.lr_scheduler is not None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_from_checkpoint(self, resume_from_checkpoint: str, model: Optional[nn.Module] = None) -> None:
|
||||||
|
"""Load model from HSDP sharded checkpoint."""
|
||||||
|
target = model if model is not None else self.model
|
||||||
|
loaded = load_hsdp_model(target, resume_from_checkpoint)
|
||||||
|
if not loaded:
|
||||||
|
return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
|
||||||
|
self._pending_hsdp_checkpoint = resume_from_checkpoint
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _load_optimizer_and_scheduler(self, checkpoint: Optional[str] = None) -> None:
|
||||||
|
"""Load optimizer/scheduler from per-rank checkpoint files."""
|
||||||
|
ckpt_dir = getattr(self, "_pending_hsdp_checkpoint", None) or checkpoint
|
||||||
|
if ckpt_dir is None:
|
||||||
|
return
|
||||||
|
load_hsdp_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, ckpt_dir)
|
||||||
|
|
||||||
|
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
|
||||||
|
"""Save model weights in HuggingFace-compatible format."""
|
||||||
|
save_dir = output_dir or self.args.output_dir
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
export_to_hf_format(self.model, getattr(self, "processing_class", None), save_dir)
|
||||||
@@ -12,8 +12,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from transformers import DataCollatorForLanguageModeling
|
||||||
|
|
||||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
|
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
from ...extras.logging import get_logger
|
from ...extras.logging import get_logger
|
||||||
@@ -21,9 +24,9 @@ from ...extras.misc import calculate_tps
|
|||||||
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
|
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
|
||||||
from ...extras.ploting import plot_loss
|
from ...extras.ploting import plot_loss
|
||||||
from ...model import load_model, load_tokenizer
|
from ...model import load_model, load_tokenizer
|
||||||
from ..callbacks import SaveProcessorCallback
|
|
||||||
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
||||||
from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func
|
from ..trainer_utils import create_modelcard_and_push, create_ref_model
|
||||||
|
from .trainer import HyperParallelTrainer
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -35,6 +38,90 @@ if TYPE_CHECKING:
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelArguments"):
|
||||||
|
r"""Load HyperParallel arguments and apply LlamaFactory-side overrides.
|
||||||
|
|
||||||
|
When activation optimization is enabled, skip native gradient checkpointing
|
||||||
|
so HP can install its own via ``setup_activation_optimization``.
|
||||||
|
"""
|
||||||
|
if not is_hyper_parallel_available():
|
||||||
|
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
|
||||||
|
|
||||||
|
from hyper_parallel.integration.llamafactory import HyperParallelArguments # pylint: disable=C0415
|
||||||
|
|
||||||
|
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
|
||||||
|
if hp_args.activation_mode != "none":
|
||||||
|
model_args.disable_gradient_checkpointing = True
|
||||||
|
return hp_args
|
||||||
|
|
||||||
|
|
||||||
|
def run_pt(
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||||
|
):
|
||||||
|
hp_args = _prepare_hp_args(finetuning_args, model_args)
|
||||||
|
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
|
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||||
|
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
|
||||||
|
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||||
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
|
|
||||||
|
trainer = HyperParallelTrainer(
|
||||||
|
hp_args=hp_args,
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
finetuning_args=finetuning_args,
|
||||||
|
data_collator=data_collator,
|
||||||
|
callbacks=callbacks,
|
||||||
|
**dataset_module,
|
||||||
|
**tokenizer_module,
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_args.do_train:
|
||||||
|
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||||
|
trainer.save_model()
|
||||||
|
trainer.log_metrics("train", train_result.metrics)
|
||||||
|
trainer.save_metrics("train", train_result.metrics)
|
||||||
|
trainer.save_state()
|
||||||
|
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||||
|
keys = ["loss"]
|
||||||
|
if isinstance(dataset_module.get("eval_dataset"), dict):
|
||||||
|
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
|
||||||
|
else:
|
||||||
|
keys += ["eval_loss"]
|
||||||
|
|
||||||
|
plot_loss(training_args.output_dir, keys=keys)
|
||||||
|
|
||||||
|
if training_args.do_eval:
|
||||||
|
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||||
|
|
||||||
|
if isinstance(dataset_module.get("eval_dataset"), dict):
|
||||||
|
for key in dataset_module["eval_dataset"].keys():
|
||||||
|
try:
|
||||||
|
perplexity = math.exp(metrics[f"eval_{key}_loss"])
|
||||||
|
except OverflowError:
|
||||||
|
perplexity = float("inf")
|
||||||
|
|
||||||
|
metrics[f"eval_{key}_perplexity"] = perplexity
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
perplexity = math.exp(metrics["eval_loss"])
|
||||||
|
except OverflowError:
|
||||||
|
perplexity = float("inf")
|
||||||
|
|
||||||
|
metrics["eval_perplexity"] = perplexity
|
||||||
|
|
||||||
|
trainer.log_metrics("eval", metrics)
|
||||||
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
|
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||||
|
|
||||||
|
|
||||||
def run_sft(
|
def run_sft(
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
@@ -43,13 +130,7 @@ def run_sft(
|
|||||||
generating_args: "GeneratingArguments",
|
generating_args: "GeneratingArguments",
|
||||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||||
):
|
):
|
||||||
if not is_hyper_parallel_available():
|
hp_args = _prepare_hp_args(finetuning_args, model_args)
|
||||||
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
|
|
||||||
|
|
||||||
from hyper_parallel.integration.llamafactory import ( # pylint: disable=C0415
|
|
||||||
HyperParallelArguments,
|
|
||||||
HyperParallelTrainer,
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
@@ -94,25 +175,6 @@ def run_sft(
|
|||||||
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
|
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
|
||||||
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||||
|
|
||||||
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
|
|
||||||
|
|
||||||
callbacks = list(callbacks or [])
|
|
||||||
processor = tokenizer_module.get("processor")
|
|
||||||
if processor is not None:
|
|
||||||
callbacks.append(SaveProcessorCallback(processor))
|
|
||||||
|
|
||||||
compute_loss_func = None
|
|
||||||
if finetuning_args.use_dft_loss:
|
|
||||||
compute_loss_func = dft_loss_func
|
|
||||||
elif finetuning_args.use_eaft_loss:
|
|
||||||
compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func( # noqa: E731
|
|
||||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
|
||||||
)
|
|
||||||
elif finetuning_args.use_asft_loss:
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)
|
|
||||||
|
|
||||||
trainer = HyperParallelTrainer(
|
trainer = HyperParallelTrainer(
|
||||||
hp_args=hp_args,
|
hp_args=hp_args,
|
||||||
model=model,
|
model=model,
|
||||||
@@ -122,20 +184,11 @@ def run_sft(
|
|||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
gen_kwargs=gen_kwargs,
|
gen_kwargs=gen_kwargs,
|
||||||
ref_model=ref_model,
|
ref_model=ref_model,
|
||||||
compute_loss_func=compute_loss_func,
|
|
||||||
**dataset_module,
|
**dataset_module,
|
||||||
**tokenizer_module,
|
**tokenizer_module,
|
||||||
**metric_module,
|
**metric_module,
|
||||||
)
|
)
|
||||||
|
|
||||||
if finetuning_args.use_badam:
|
|
||||||
from types import MethodType
|
|
||||||
|
|
||||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
|
|
||||||
|
|
||||||
trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
|
|
||||||
trainer.add_callback(BAdamCallback)
|
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||||
|
|||||||
@@ -88,9 +88,16 @@ def _training_function(config: dict[str, Any]) -> None:
|
|||||||
|
|
||||||
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
||||||
|
|
||||||
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
|
if finetuning_args.stage in ["pt", "sft"] and finetuning_args.use_hyper_parallel:
|
||||||
if not is_hyper_parallel_available():
|
if not is_hyper_parallel_available():
|
||||||
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
|
raise ImportError(
|
||||||
|
"hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
|
||||||
|
)
|
||||||
|
if finetuning_args.stage == "pt":
|
||||||
|
from .hyper_parallel import run_pt as run_pt_hp
|
||||||
|
|
||||||
|
run_pt_hp(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||||
|
else:
|
||||||
from .hyper_parallel import run_sft as run_sft_hp
|
from .hyper_parallel import run_sft as run_sft_hp
|
||||||
|
|
||||||
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||||
|
|||||||
Reference in New Issue
Block a user