2 Commits

Author SHA1 Message Date
Junyou Su
675ce8cc7f [algo] add ASFT (#10174) 2026-02-12 13:12:14 +08:00
jiaqiw09
ab073f4c13 [v1] add LoRA/Freeze support and merge workflow (#10157) 2026-02-12 13:02:09 +08:00
15 changed files with 805 additions and 14 deletions

View File

@@ -0,0 +1,45 @@
### model
model_name_or_path: models/Llama-2-7b
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z0_config.json
use_asft_loss: true
asft_alpha: 0.1
### dataset
dataset: med
template: llama2
cutoff_len: 2048
max_samples: 10000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama2-7b/full/asft2
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 8
learning_rate: 2.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -0,0 +1,45 @@
### model
model_name_or_path: models/Qwen2.5-7B
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z0_config.json
use_asft_loss: true
asft_alpha: 0.05
### dataset
dataset: math
template: qwen
cutoff_len: 2048
max_samples: 10000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen2-7b/full/asft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 8
learning_rate: 5.0e-5
num_train_epochs: 1.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -0,0 +1,38 @@
model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# Freeze Configuration
peft_config:
name: freeze
freeze_trainable_layers: 2 # Train the last 2 layers
freeze_trainable_modules: all # In these layers, train specific modules
freeze_extra_modules: null # Extra modules to train (e.g. embed_tokens, lm_head)
# Kernel Config
kernel_config:
name: auto
include_kernels: auto
# FSDP Config
dist_config:
name: fsdp2
dcp_path: null
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: ./outputs/test_freeze
micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048
learning_rate: 2.0e-5
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,7 @@
model: Qwen/Qwen3-4B
peft_config:
name: lora
adapter_name_or_path: ./outputs/test_lora
export_dir: ./merge_lora_model
export_size: 5
infer_dtype: auto

View File

@@ -0,0 +1,39 @@
model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# PEFT Configuration
peft_config:
name: lora
r: 16
lora_alpha: 32
lora_dropout: 0.05
target_modules: all
# Kernel Config
kernel_config:
name: auto
include_kernels: auto
# FSDP Config
dist_config:
name: fsdp2
dcp_path: null
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: ./outputs/test_lora
micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -490,6 +490,14 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether to use the DFT loss."},
)
use_asft_loss: bool = field(
default=False,
metadata={"help": "Whether to use the ASFT loss."},
)
asft_alpha: float = field(
default=0.1,
metadata={"help": "The alpha parameter for ASFT loss to control the power of adaptive weight."},
)
use_eaft_loss: bool = field(
default=False,
metadata={"help": "Whether to use the EAFT loss."},

View File

@@ -17,6 +17,7 @@
import json
import os
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional, Union
@@ -52,6 +53,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
processor: Optional["ProcessorMixin"],
model_args: Optional["ModelArguments"] = None,
gen_kwargs: Optional[dict[str, Any]] = None,
ref_model: Optional["torch.nn.Module"] = None,
**kwargs,
) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer")
@@ -82,6 +84,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
self.ref_model = ref_model
if ref_model is not None:
from trl.models.utils import prepare_deepspeed, prepare_fsdp
if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
if not (
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
elif getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
if self.accelerator.is_fsdp2:
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
else:
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
if finetuning_args.use_dft_loss:
from ..trainer_utils import dft_loss_func
@@ -93,6 +116,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
)
elif finetuning_args.use_asft_loss:
from ..trainer_utils import asft_loss_func
self.compute_loss_func = partial(
asft_loss_func,
asft_alpha=finetuning_args.asft_alpha,
)
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args)
@@ -119,7 +149,17 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
@override
def compute_loss(self, model, inputs, *args, **kwargs):
# NOTE(review): as rendered here, this unconditional return would make the ASFT
# branch below unreachable; it looks like the pre-diff line that was removed —
# confirm against the applied file.
return super().compute_loss(model, inputs, *args, **kwargs)
# ASFT needs the frozen reference model's logits in addition to the policy outputs.
if self.finetuning_args.use_asft_loss:
# The reference forward pass must not accumulate gradients.
with torch.no_grad():
ref_outputs = self.ref_model(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
)
ref_logits = ref_outputs.logits
outputs = model(**inputs)
# compute_loss_func is asft_loss_func partially applied with asft_alpha.
return self.compute_loss_func(outputs, inputs["labels"], ref_logits)
else:
return super().compute_loss(model, inputs, *args, **kwargs)
@override
def prediction_step(

View File

@@ -24,7 +24,7 @@ from ...extras.misc import calculate_tps
from ...extras.packages import is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer
@@ -52,6 +52,10 @@ def run_sft(
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
ref_model = None
if finetuning_args.use_asft_loss:
ref_model = create_ref_model(model_args, finetuning_args)
if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
@@ -124,6 +128,7 @@ def run_sft(
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
ref_model=ref_model,
**dataset_module,
**tokenizer_module,
**metric_module,

View File

@@ -23,6 +23,7 @@ from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
import torch.nn.functional as F
from transformers import Trainer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
@@ -681,6 +682,88 @@ def _dft_cross_entropy(
return loss
def asft_loss_func(
    outputs,
    labels: "torch.Tensor",
    ref_logits: "torch.Tensor",
    asft_alpha: float = 0.1,
    ignore_index: int = -100,
) -> "torch.Tensor":
    r"""Compute the ASFT loss: adaptive (DFT) cross-entropy plus a KL penalty to a reference model.

    Args:
        outputs: model forward outputs; must expose ``logits`` via ``.get``.
        labels: target token ids of shape (batch, seq); positions equal to
            ``ignore_index`` are excluded from the loss.
        ref_logits: reference-model logits, same shape as the policy logits.
        asft_alpha: weight of the KL regularization term.
        ignore_index: label value marking prompt/padding tokens.

    Returns:
        Scalar loss tensor.
    """
    logits = outputs.get("logits")
    if logits is None:  # e.g. loss-only outputs; fall back gracefully
        return outputs.get("loss", torch.tensor(0.0))

    # Upcast BOTH policy and reference logits to fp32 so both softmaxes are
    # computed at the same precision (the original only upcast the policy side,
    # leaving the reference softmax in bf16/fp16).
    logits = logits.float()
    ref_logits = ref_logits.float()

    # Shift for causal LM: position t predicts token t+1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_ref_logits = ref_logits[..., :-1, :].contiguous()
    vocab_size = shift_logits.size(-1)

    # Flatten to (N, vocab) / (N,) for the token-level helpers; also align
    # devices in case labels/ref logits live elsewhere.
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_ref_logits = shift_ref_logits.view(-1, vocab_size).to(shift_logits.device)
    shift_labels = shift_labels.view(-1).to(shift_logits.device)
    return _asft_cross_entropy(
        policy_logits=shift_logits,
        policy_labels=shift_labels,
        ref_logits=shift_ref_logits,
        asft_alpha=asft_alpha,
        ignore_index=ignore_index,
    )
def _asft_cross_entropy(
    policy_logits: "torch.Tensor",
    policy_labels: "torch.Tensor",
    ref_logits: "torch.Tensor",
    asft_alpha: float = 0.1,
    ignore_index: int = -100,
) -> "torch.Tensor":
    r"""Combine the DFT cross-entropy with an alpha-weighted KL penalty toward the reference."""
    # Adaptive (DFT) cross-entropy on the policy predictions.
    ce_term = _dft_cross_entropy(policy_logits, policy_labels, ignore_index=ignore_index)
    # Mean KL between reference and policy distributions over supervised tokens.
    kl_term = _kl_divergence(policy_logits, ref_logits, policy_labels, ignore_index=ignore_index)
    return ce_term + asft_alpha * kl_term
def _kl_divergence(
policy_logits: torch.Tensor,
ref_logits: torch.Tensor,
labels: torch.Tensor,
ignore_index: int = -100,
) -> torch.Tensor:
# log p(y|x)
log_p = F.log_softmax(policy_logits, dim=-1)
# q(y|x)
q = F.softmax(ref_logits, dim=-1)
# token-wise KL
kl = F.kl_div(
log_p,
q,
reduction="none",
).sum(dim=-1) # [N]
# mask padding tokens
mask = (labels != ignore_index).float()
return (kl * mask).sum() / mask.sum()
def eaft_loss_func(
outputs: "torch.Tensor",
labels: "torch.Tensor",

View File

@@ -204,6 +204,16 @@ class BaseTrainer:
def save_model(self) -> None:
"""Save the model."""
# Unwrap DDP-style containers before saving.
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
# NOTE(review): as rendered, the model is saved twice (here and at the end);
# this line looks like the pre-diff version that was removed — confirm in the
# applied file.
model_to_save.save_pretrained(self.args.output_dir)
state_dict = None
# FSDP2 shards parameters across ranks, so gather a full CPU state dict first.
if self.args.dist_config is not None and self.args.dist_config.name == "fsdp2":
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
state_dict = get_model_state_dict(self.model, options=options)
# Only rank 0 writes checkpoint files (gathering above still runs on all ranks).
if DistributedInterface().get_rank() != 0:
return
model_to_save.save_pretrained(self.args.output_dir, state_dict=state_dict)
self.renderer.processor.save_pretrained(self.args.output_dir)
logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -125,6 +125,11 @@ def launch():
run_chat()
elif command == "merge":
from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model
merge_and_export_model()
elif command == "env":
raise NotImplementedError("Environment information is not implemented yet.")

View File

@@ -12,14 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, TypedDict
import re
from typing import Literal, TypedDict, Union
from peft import LoraConfig, PeftModel, get_peft_model
import torch
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from ...config import InputArgument, get_args
from ...core.model_engine import ModelEngine
from ...utils import logging
from ...utils.plugin import BasePlugin
from ...utils.types import HFModel
logger = logging.get_logger(__name__)
class LoraConfigDict(TypedDict, total=False):
name: Literal["lora"]
"""Plugin name."""
@@ -27,8 +35,28 @@ class LoraConfigDict(TypedDict, total=False):
"""Lora rank."""
lora_alpha: int
"""Lora alpha."""
target_modules: list[str]
lora_dropout: float
"""Lora dropout."""
target_modules: Union[list[str], str]
"""Target modules."""
use_rslora: bool
"""Use RS-LoRA."""
use_dora: bool
"""Use DoRA."""
modules_to_save: list[str]
"""Modules to save."""
adapter_name_or_path: Union[list[str], str]
"""Path to the adapter(s)."""
export_dir: str
"""Path to the export directory."""
export_size: int
"""Shard size for the export model."""
export_hub_model_id: str
"""Hub model ID for the export model."""
infer_dtype: Literal["auto", "float16", "float32", "bfloat16"]
"""Inference data type for the export model."""
export_legacy_format: bool
"""Use legacy format for the export model."""
class FreezeConfigDict(TypedDict, total=False):
@@ -36,22 +64,280 @@ class FreezeConfigDict(TypedDict, total=False):
"""Plugin name."""
freeze_trainable_layers: int
"""Freeze trainable layers."""
freeze_trainable_modules: list[str] | None
freeze_trainable_modules: Union[list[str], str]
"""Freeze trainable modules."""
freeze_extra_modules: list[str]
"""Freeze extra modules."""
cast_trainable_params_to_fp32: bool
"""Cast trainable params to fp32."""
class PeftPlugin(BasePlugin):
def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
return super().__call__(model, config)
return super().__call__(model, config, is_train)
def _find_all_linear_modules(model: HFModel) -> list[str]:
r"""Find all available modules to apply LoRA."""
forbidden_modules = {"lm_head", "output_layer", "output"}
module_names = set()
for name, module in model.named_modules():
if any(forbidden_module in name for forbidden_module in forbidden_modules):
continue
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
module_names.add(name.split(".")[-1])
return list(module_names)
def merge_adapters(model: "HFModel", adapter_name_or_path: Union[list[str], str]) -> "HFModel":
    r"""Sequentially merge one or more LoRA adapters into the base model weights."""
    paths = adapter_name_or_path if isinstance(adapter_name_or_path, list) else [adapter_name_or_path]
    for adapter_path in paths:
        # Attach the adapter, then fold its weights into the base model.
        model = PeftModel.from_pretrained(model, adapter_path).merge_and_unload()
        logger.info_rank0(f"Merged adapter from {adapter_path}")
    return model
def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is_train: bool) -> HFModel:
r"""Loads adapter(s) into the model.
Determine adapter usage based on mode:
- Training: Load the single adapter for continued training.
- Inference: Merge all adapters to clean up the model.
- Unmergeable: Keep the single adapter active without merging.
"""
# Normalize to a list so single- and multi-adapter inputs share one code path.
if not isinstance(adapter_name_or_path, list):
adapter_name_or_path = [adapter_name_or_path]
# TODO
# Adapters fix for deepspeed and quant
# Adapters fix for vision
# Training can only resume from exactly one adapter.
if is_train and len(adapter_name_or_path) > 1:
raise ValueError(
"When `adapter_name_or_path` is provided for training, only a single LoRA adapter is supported. "
"Training will continue on the specified adapter. "
"Please merge multiple adapters before starting a new LoRA adapter."
)
if is_train:
adapter_to_merge = []
adapter_to_resume = adapter_name_or_path[0]
else:
# Inference: merge every adapter into the base weights.
adapter_to_merge = adapter_name_or_path
adapter_to_resume = None
if adapter_to_merge:
model = merge_adapters(model, adapter_to_merge)
if adapter_to_resume is not None:
# Keep the adapter attached (trainable in training mode) rather than merging.
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_train)
if is_train:
logger.info_rank0(
f"Resuming training from existing LoRA adapter at {adapter_to_resume}. "
"LoRA hyperparameters will be loaded from the adapter itself; "
"the current LoRA configuration will be ignored. "
"Merge the adapter into the base model before training if you want to start a new adapter."
)
return model
@PeftPlugin("lora").register()
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
peft_config = LoraConfig(**config)
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
r"""Wrap the model with a new LoRA adapter, or load an existing one when a path is given."""
# Fast path: resume/merge an existing adapter instead of creating a fresh one.
adapter_name_or_path = config.get("adapter_name_or_path")
if adapter_name_or_path:
return load_adapter(model, adapter_name_or_path, is_train)
logger.info_rank0("Fine-tuning method: LoRA")
target_modules = config.get("target_modules", "all")
# Handle target modules
if target_modules == "all":
# "all" expands to every eligible Linear layer (lm_head etc. excluded).
target_modules = _find_all_linear_modules(model)
elif isinstance(target_modules, str):
target_modules = [target_modules]
logger.info_rank0(f"LoRA target modules: {target_modules}")
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=not is_train,
r=config.get("r", 8),
lora_alpha=config.get("lora_alpha", 16),
lora_dropout=config.get("lora_dropout", 0.05),
use_rslora=config.get("use_rslora", False),
use_dora=config.get("use_dora", False),
target_modules=target_modules,
modules_to_save=config.get("modules_to_save", None),
)
model = get_peft_model(model, peft_config)
if is_train:
# Log the trainable-parameter summary for sanity checking.
model.print_trainable_parameters()
return model
@PeftPlugin("freeze").register()
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
raise NotImplementedError()
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool = False) -> HFModel:
r"""Freeze-tune: mark only the configured transformer layers/modules as trainable.

Inference mode returns the model untouched. Raises ValueError when the model
exposes no layer count or when a requested module name cannot be located.
"""
logger.info_rank0("Fine-tuning method: Freeze")
if not is_train:
return model
freeze_trainable_layers = config.get("freeze_trainable_layers", 2)
freeze_trainable_modules = config.get("freeze_trainable_modules", ["all"])
freeze_extra_modules = config.get("freeze_extra_modules", [])
cast_trainable_params_to_fp32 = config.get("cast_trainable_params_to_fp32", True)
# Accept comma-separated strings as well as lists.
if isinstance(freeze_trainable_modules, str):
freeze_trainable_modules = [module.strip() for module in freeze_trainable_modules.split(",")]
if isinstance(freeze_extra_modules, str):
freeze_extra_modules = [module.strip() for module in freeze_extra_modules.split(",")]
# Get number of layers
num_layers = (
getattr(model.config, "num_hidden_layers", None)
or getattr(model.config, "num_layers", None)
or getattr(model.config, "n_layer", None)
)
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
# Positive count -> train the last n layers; non-positive -> the first |n| layers.
if freeze_trainable_layers > 0:
# last n layers
trainable_layer_ids = range(max(0, num_layers - freeze_trainable_layers), num_layers)
else:
# first n layers
trainable_layer_ids = range(min(-freeze_trainable_layers, num_layers))
# Identify hidden and non-hidden modules
# NOTE(review): this heuristic samples layers 0/1 to discover per-layer module
# names (e.g. self_attn, mlp) — assumes at least one such layer index appears
# in parameter names; confirm for exotic architectures.
hidden_modules = set()
non_hidden_modules = set()
for name, _ in model.named_parameters():
if ".0." in name:
hidden_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name:
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
if re.search(r"\.\d+\.", name) is None:
non_hidden_modules.add(name.split(".")[-2])
# Build list of trainable layer patterns
trainable_layers = []
for module_name in freeze_trainable_modules:
if module_name == "all":
for idx in trainable_layer_ids:
trainable_layers.append(f".{idx:d}.")
elif module_name in hidden_modules:
for idx in trainable_layer_ids:
trainable_layers.append(f".{idx:d}.{module_name}")
else:
raise ValueError(f"Module {module_name} not found in hidden modules: {hidden_modules}")
# Add extra modules
if freeze_extra_modules:
for module_name in freeze_extra_modules:
if module_name in non_hidden_modules:
trainable_layers.append(module_name)
else:
raise ValueError(f"Module {module_name} not found in non-hidden modules: {non_hidden_modules}")
# TODO
# Multi-modal special handling
# Set requires_grad
# Quantized-weight buffers must never be marked trainable.
forbidden_modules = {"quant_state", "quantization_weight", "qweight", "qzeros", "scales"}
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
forbidden_module in name for forbidden_module in forbidden_modules
):
param.requires_grad_(True)
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32) # Cast to fp32 for stability
else:
param.requires_grad_(False)
logger.info_rank0(f"Set trainable layers: {trainable_layers}")
# Count trainable params for verification
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in model.parameters())
logger.info_rank0(
f"trainable params: {trainable_params} || all params: {all_params} || trainable%: {100 * trainable_params / all_params:.4f}"
)
return model
def merge_and_export_model(args: InputArgument = None):
r"""Merge LoRA adapter(s) into the base model and export the merged weights.

Reads the export options from ``peft_config`` (export_dir, export_size,
infer_dtype, ...), loads the model via ModelEngine, optionally converts the
dtype, then saves (and optionally pushes) model and tokenizer.
"""
model_args, _, _, _ = get_args(args)
export_config = model_args.peft_config
if export_config is None:
raise ValueError("Please specify peft_config to merge and export model.")
export_dir = export_config.get("export_dir")
if export_dir is None:
raise ValueError("Please specify export_dir.")
export_size = export_config.get("export_size", 5)
export_hub_model_id = export_config.get("export_hub_model_id")
infer_dtype = export_config.get("infer_dtype", "auto")
export_legacy_format = export_config.get("export_legacy_format", False)
adapters = None
if export_config.get("name") == "lora":
adapters = export_config.get("adapter_name_or_path")
else:
raise ValueError("Currently merge and export model function is only supported for lora.")
if adapters is None:
raise ValueError("Please set adapter_name_or_path to merge adapters into base model.")
logger.info_rank0("Loading model for export...")
# NOTE(review): the actual adapter merge presumably happens inside ModelEngine
# (inference mode + adapter_name_or_path in peft_config) — confirm; `adapters`
# is only validated, never used directly here.
model_engine = ModelEngine(model_args, is_train=False)
model = model_engine.model
tokenizer = model_engine.processor
# Dtype policy: "auto" upgrades fp32 -> bf16 when the GPU supports it.
if infer_dtype == "auto":
if model.config.torch_dtype == torch.float32 and torch.cuda.is_bf16_supported():
model = model.to(torch.bfloat16)
logger.info_rank0("Converted model to bfloat16.")
else:
target_dtype = getattr(torch, infer_dtype)
model = model.to(target_dtype)
logger.info_rank0(f"Converted model to {infer_dtype}.")
logger.info_rank0(f"Exporting model to {export_dir}...")
model.save_pretrained(
export_dir,
max_shard_size=f"{export_size}GB",
safe_serialization=not export_legacy_format,
)
# Tokenizer saving is best-effort; export should not fail on tokenizer issues.
if tokenizer is not None:
try:
if hasattr(tokenizer, "padding_side"):
tokenizer.padding_side = "left"
tokenizer.save_pretrained(export_dir)
except Exception as e:
logger.warning(f"Failed to save tokenizer: {e}")
if export_hub_model_id:
logger.info_rank0(f"Pushing to hub: {export_hub_model_id}...")
model.push_to_hub(export_hub_model_id)
if tokenizer is not None:
tokenizer.push_to_hub(export_hub_model_id)
logger.info_rank0("Model exported successfully.")

View File

@@ -24,6 +24,7 @@ from torch.distributed.fsdp import (
fully_shard,
)
from transformers import PreTrainedModel
from peft.tuners.lora import LoraLayer
from ....accelerator.helper import get_current_accelerator
from ....accelerator.interface import DistributedInterface
@@ -93,6 +94,10 @@ class FSDP2Engine:
reduce_dtype=reduce_dtype,
cast_forward_inputs=True,
)
def is_lora_module_wrap(self, model) -> bool:
    """Return True when the model contains at least one LoRA layer."""
    for submodule in model.modules():
        if isinstance(submodule, LoraLayer):
            return True
    return False
def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel:
if self.fsdp_mesh is None:
@@ -110,6 +115,26 @@ class FSDP2Engine:
else:
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
transformer_layer_cls_to_wrap = {layer_cls}
if self.is_lora_module_wrap(model):
lora_modules = []
for module in model.modules():
if len(list(module.children())) != 0:
continue
if any(param.requires_grad for param in module.parameters(recurse=False)):
lora_modules.append(module)
for module in lora_modules:
fully_shard(
module,
mesh=self.fsdp_mesh,
reshard_after_forward=self.reshard_after_forward,
mp_policy=mp_policy,
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
logger.info(f"Applying FSDP wrap for LoRA layer separately.")
for name, module in model.named_modules():
should_wrap = False
@@ -154,7 +179,6 @@ class FSDP2Engine:
)
return model
@torch.no_grad()
def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None):
if self.rank == 0:

View File

@@ -33,7 +33,7 @@ def run_sft(args: InputArgument = None):
model_args, data_args, training_args, _ = get_args(args)
DistributedInterface(training_args.dist_config)
train_dataset = DataEngine(data_args.train_dataset)
model_engine = ModelEngine(model_args)
model_engine = ModelEngine(model_args, is_train=True)
trainer = SFTTrainer(
args=training_args,
model=model_engine.model,

View File

@@ -0,0 +1,156 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from llamafactory.v1.plugins.model_plugins import peft as peft_module
from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model
TINY_MODEL = "llamafactory/tiny-random-qwen3"
# Module-scoped: the hub id is a constant shared by all tests.
@pytest.fixture(scope="module")
def model_path():
"""Return the HF hub id of the tiny test model."""
return TINY_MODEL
# Function-scoped: each test gets a fresh model so freeze/LoRA mutations do not leak.
@pytest.fixture(scope="function")
def model(model_path):
"""Load a fresh tiny causal-LM instance per test."""
return AutoModelForCausalLM.from_pretrained(model_path)
@pytest.fixture(scope="function")
def tokenizer(model_path):
"""Load a fresh tokenizer per test."""
return AutoTokenizer.from_pretrained(model_path)
@pytest.fixture(scope="function")
def adapter_path(tmp_path):
"""Train-free dummy adapter: save a freshly initialized LoRA and return its path."""
# Create a dummy adapter
lora_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
base_model = AutoModelForCausalLM.from_pretrained(TINY_MODEL)
peft_model = get_peft_model(base_model, lora_config)
save_path = tmp_path / "test_adapter"
peft_model.save_pretrained(save_path)
return str(save_path)
def test_find_all_linear_modules(model):
    """Verify linear modules are discoverable and include q_proj / v_proj for tiny-random-qwen3."""
    discovered = set(peft_module._find_all_linear_modules(model))
    assert {"q_proj", "v_proj"} <= discovered
def test_get_lora_model(model):
    """Verify a PeftModel is returned and LoRA config takes effect."""
    lora_cfg = {"name": "lora", "r": 8, "target_modules": "all", "lora_alpha": 16}
    model = peft_module.get_lora_model(model, lora_cfg, is_train=True)
    assert isinstance(model, PeftModel)
    default_cfg = model.peft_config["default"]
    assert default_cfg.r == 8
    assert "q_proj" in default_cfg.target_modules
def test_get_freeze_model_layers(model):
"""Verify layer-wise freezing: only the last layer stays trainable."""
# Freeze all but last layer
config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "all"}
# Ensure we start with something known
model = peft_module.get_freeze_model(model, config, is_train=True)
num_layers = model.config.num_hidden_layers
assert num_layers > 0
# Params outside layers 0 and num_layers-1 (e.g. embeddings) are not asserted here.
for name, param in model.named_parameters():
if f"layers.{num_layers - 1}" in name:
assert param.requires_grad, f"{name} should be trainable"
elif "layers.0" in name and num_layers > 1:
assert not param.requires_grad, f"{name} should be frozen"
def test_get_freeze_model_modules(model):
"""Verify module-wise freezing: only last-layer self_attn is trainable."""
# Freeze specific modules (e.g. only self_attn)
config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "self_attn"}
model = peft_module.get_freeze_model(model, config, is_train=True)
num_layers = model.config.num_hidden_layers
# Everything except the selected module in the selected layer must be frozen.
for name, param in model.named_parameters():
if f"layers.{num_layers - 1}" in name and "self_attn" in name:
assert param.requires_grad, f"{name} should be trainable"
else:
assert not param.requires_grad, f"{name} should be frozen"
def test_load_adapter_single_for_inference(model, adapter_path):
"""Verify single adapter is merged+unloaded in inference mode."""
# Test loading single adapter for inference (merge and unload)
model_result = peft_module.load_adapter(model, adapter_path, is_train=False)
# merge_and_unload returns the plain base model, so no PeftModel wrapper remains.
assert not isinstance(model_result, PeftModel)
def test_load_adapter_resume_train(model, adapter_path):
"""Verify training mode returns a trainable PeftModel."""
# Test loading for training
model_result = peft_module.load_adapter(model, adapter_path, is_train=True)
# Training keeps the adapter attached (not merged) for continued fine-tuning.
assert isinstance(model_result, PeftModel)
def test_load_adapter_train_multiple_disallowed(model, adapter_path):
"""Verify multiple adapters are rejected in training mode."""
# Passing the same adapter twice suffices to trigger the multi-adapter guard.
with pytest.raises(ValueError, match="only a single LoRA adapter"):
peft_module.load_adapter(model, [adapter_path, adapter_path], is_train=True)
def test_load_adapter_infer_multiple_merges(model, adapter_path):
"""Verify multiple adapters are merged in inference mode."""
# Test merging multiple adapters
model_result = peft_module.load_adapter(model, [adapter_path, adapter_path], is_train=False)
# All adapters folded into the base weights -> no PeftModel wrapper remains.
assert not isinstance(model_result, PeftModel)
def test_merge_and_export_model(tmp_path, adapter_path):
"""Verify merge_and_export_model produces export artifacts."""
export_dir = tmp_path / "export"
# Args mirror the CLI/YAML input consumed by get_args inside the function.
args_dict = {
"model": TINY_MODEL,
"peft_config": {
"name": "lora",
"adapter_name_or_path": adapter_path,
"export_dir": str(export_dir),
"export_size": 1,
"infer_dtype": "float16",
},
}
merge_and_export_model(args_dict)
# Export must contain model config, safetensors weights, and tokenizer files.
assert export_dir.exists()
assert (export_dir / "config.json").exists()
assert (export_dir / "model.safetensors").exists()
assert (export_dir / "tokenizer_config.json").exists()