mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-02-26 07:45:59 +08:00
Compare commits
8 Commits
184304b5b4
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aab9b400bb | ||
|
|
50599c719b | ||
|
|
a0f3ad0cee | ||
|
|
f80e15dbb4 | ||
|
|
991267fd3b | ||
|
|
5c52afa30d | ||
|
|
675ce8cc7f | ||
|
|
ab073f4c13 |
1
.github/workflows/tests_cuda.yml
vendored
1
.github/workflows/tests_cuda.yml
vendored
@@ -61,6 +61,7 @@ jobs:
|
||||
uv venv
|
||||
uv pip install -e .
|
||||
uv pip install -r requirements/dev.txt
|
||||
uv pip install -r requirements/bitsandbytes.txt
|
||||
|
||||
- name: Check quality
|
||||
run: |
|
||||
|
||||
45
examples/extras/asft/llama2_full_asft.yaml
Normal file
45
examples/extras/asft/llama2_full_asft.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
### model
|
||||
model_name_or_path: models/Llama-2-7b
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||
use_asft_loss: true
|
||||
asft_alpha: 0.1
|
||||
|
||||
### dataset
|
||||
dataset: med
|
||||
template: llama2
|
||||
cutoff_len: 2048
|
||||
max_samples: 10000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama2-7b/full/asft2
|
||||
logging_steps: 1
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 4
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 2.0e-5
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
||||
45
examples/extras/asft/qwen2_full_asft.yaml
Normal file
45
examples/extras/asft/qwen2_full_asft.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
### model
|
||||
model_name_or_path: models/Qwen2.5-7B
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||
use_asft_loss: true
|
||||
asft_alpha: 0.05
|
||||
|
||||
### dataset
|
||||
dataset: math
|
||||
template: qwen
|
||||
cutoff_len: 2048
|
||||
max_samples: 10000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/qwen2-7b/full/asft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 4
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 5.0e-5
|
||||
num_train_epochs: 1.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
||||
38
examples/v1/train_freeze/train_freeze_sft.yaml
Normal file
38
examples/v1/train_freeze/train_freeze_sft.yaml
Normal file
@@ -0,0 +1,38 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
# Freeze Configuration
|
||||
peft_config:
|
||||
name: freeze
|
||||
freeze_trainable_layers: 2 # Train the last 2 layers
|
||||
freeze_trainable_modules: all # In these layers, train specific modules
|
||||
freeze_extra_modules: null # Extra modules to train (e.g. embed_tokens, lm_head)
|
||||
|
||||
# Kernel Config
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
# FSDP Config
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: ./outputs/test_freeze
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 4
|
||||
cutoff_len: 2048
|
||||
learning_rate: 2.0e-5
|
||||
bf16: false
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
25
examples/v1/train_full/train_full_deepspeed.yaml
Normal file
25
examples/v1/train_full/train_full_deepspeed.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
dist_config:
|
||||
name: deepspeed
|
||||
config_file: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/Qwen3-0.6B-deepspeed
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: true
|
||||
max_steps: 10
|
||||
|
||||
7
examples/v1/train_lora/export_lora.yaml
Normal file
7
examples/v1/train_lora/export_lora.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
peft_config:
|
||||
name: lora
|
||||
adapter_name_or_path: ./outputs/test_lora
|
||||
export_dir: ./merge_lora_model
|
||||
export_size: 5
|
||||
infer_dtype: auto
|
||||
39
examples/v1/train_lora/train_lora_sft.yaml
Normal file
39
examples/v1/train_lora/train_lora_sft.yaml
Normal file
@@ -0,0 +1,39 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
# PEFT Configuration
|
||||
peft_config:
|
||||
name: lora
|
||||
r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
target_modules: all
|
||||
|
||||
# Kernel Config
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
# FSDP Config
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: ./outputs/test_lora
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 4
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: true
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
43
examples/v1/train_qlora/quantization.yaml
Normal file
43
examples/v1/train_qlora/quantization.yaml
Normal file
@@ -0,0 +1,43 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
# PEFT Configuration
|
||||
peft_config:
|
||||
name: lora
|
||||
r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
target_modules: all
|
||||
|
||||
# Kernel Config
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
# FSDP Config
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null
|
||||
|
||||
# Quantization Config
|
||||
quant_config:
|
||||
name: bnb # choice: auto/bnb if auto is selected, the quantization method will be automatically selected based on the model and environment.
|
||||
quantization_bit: 4 # choice: 8/4(bnb)
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_quantization
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -65,6 +65,7 @@ MCA_SUPPORTED_MODELS = {
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
"qwen3",
|
||||
"qwen3_moe",
|
||||
"qwen3_next",
|
||||
|
||||
@@ -490,6 +490,14 @@ class FinetuningArguments(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the DFT loss."},
|
||||
)
|
||||
use_asft_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the ASFT loss."},
|
||||
)
|
||||
asft_alpha: float = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The alpha parameter for ASFT loss to control the power of adaptive weight."},
|
||||
)
|
||||
use_eaft_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the EAFT loss."},
|
||||
|
||||
@@ -142,6 +142,10 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
||||
|
||||
_set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
|
||||
|
||||
if model_type == "qwen3_next":
|
||||
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
|
||||
|
||||
_set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])
|
||||
|
||||
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.moe_aux_loss_coef:
|
||||
|
||||
@@ -82,9 +82,34 @@ def _check_model_support(model_args: "ModelArguments"):
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
if config.model_type not in MCA_SUPPORTED_MODELS:
|
||||
raise ValueError(f"Model {config.model_type} is not supported by MCA.")
|
||||
raise ValueError(
|
||||
f"Model {config.model_type} is not supported by mcore_adapter."
|
||||
"You can try to upgrade mcore_adapter to the latest version for more supported models."
|
||||
)
|
||||
|
||||
|
||||
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
|
||||
"""Freeze model parameters for qwen_vl series models based on finetuning arguments."""
|
||||
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]:
|
||||
return
|
||||
|
||||
params_to_freeze = []
|
||||
if finetuning_args.freeze_vision_tower:
|
||||
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
|
||||
if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe"]:
|
||||
params_to_freeze.extend(["vision_model.pos_embed"])
|
||||
|
||||
if finetuning_args.freeze_multi_modal_projector:
|
||||
params_to_freeze.extend(["multi_modal_projector"])
|
||||
|
||||
if finetuning_args.freeze_language_model:
|
||||
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
|
||||
|
||||
if params_to_freeze:
|
||||
for name, p in model.named_parameters():
|
||||
if any(name.startswith(k) for k in params_to_freeze):
|
||||
p.requires_grad_(False)
|
||||
|
||||
def run_pt(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
@@ -161,22 +186,8 @@ def run_sft(
|
||||
_check_model_support(model_args)
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
|
||||
# optional freezing for qwen2_vl, qwen2_5_vl
|
||||
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
|
||||
params_to_freeze = []
|
||||
if finetuning_args.freeze_vision_tower:
|
||||
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
|
||||
|
||||
if finetuning_args.freeze_multi_modal_projector:
|
||||
params_to_freeze.extend(["multi_modal_projector"])
|
||||
|
||||
if finetuning_args.freeze_language_model:
|
||||
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
|
||||
|
||||
if params_to_freeze:
|
||||
for name, p in model.named_parameters():
|
||||
if any(name.startswith(k) for k in params_to_freeze):
|
||||
p.requires_grad_(False)
|
||||
# optional freezing for qwen_vl series
|
||||
_freeze_model_parameters(model, finetuning_args)
|
||||
|
||||
pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
|
||||
data_collator = SFTDataCollatorWith4DAttentionMask(
|
||||
@@ -229,6 +240,8 @@ def run_dpo(
|
||||
_check_model_support(model_args)
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
|
||||
_freeze_model_parameters(model, finetuning_args)
|
||||
|
||||
if finetuning_args.use_ref_model:
|
||||
ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
ref_model = AutoModel.from_config(ref_config)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
@@ -52,6 +53,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
processor: Optional["ProcessorMixin"],
|
||||
model_args: Optional["ModelArguments"] = None,
|
||||
gen_kwargs: Optional[dict[str, Any]] = None,
|
||||
ref_model: Optional["torch.nn.Module"] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
@@ -82,6 +84,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
self.ref_model = ref_model
|
||||
|
||||
if ref_model is not None:
|
||||
from trl.models.utils import prepare_deepspeed, prepare_fsdp
|
||||
|
||||
if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
|
||||
if not (
|
||||
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||
): # quantized models are already set on the correct device
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
elif getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
|
||||
if self.accelerator.is_fsdp2:
|
||||
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
|
||||
|
||||
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
|
||||
else:
|
||||
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
self.ref_model.eval()
|
||||
|
||||
if finetuning_args.use_dft_loss:
|
||||
from ..trainer_utils import dft_loss_func
|
||||
|
||||
@@ -93,6 +116,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
|
||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||
)
|
||||
elif finetuning_args.use_asft_loss:
|
||||
from ..trainer_utils import asft_loss_func
|
||||
|
||||
self.compute_loss_func = partial(
|
||||
asft_loss_func,
|
||||
asft_alpha=finetuning_args.asft_alpha,
|
||||
)
|
||||
|
||||
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||
verify_fp8_status(self.accelerator, training_args)
|
||||
@@ -119,7 +149,17 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
|
||||
@override
|
||||
def compute_loss(self, model, inputs, *args, **kwargs):
|
||||
return super().compute_loss(model, inputs, *args, **kwargs)
|
||||
if self.finetuning_args.use_asft_loss:
|
||||
with torch.no_grad():
|
||||
ref_outputs = self.ref_model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs.get("attention_mask", None),
|
||||
)
|
||||
ref_logits = ref_outputs.logits
|
||||
outputs = model(**inputs)
|
||||
return self.compute_loss_func(outputs, inputs["labels"], ref_logits)
|
||||
else:
|
||||
return super().compute_loss(model, inputs, *args, **kwargs)
|
||||
|
||||
@override
|
||||
def prediction_step(
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...extras.misc import calculate_tps
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..trainer_utils import create_modelcard_and_push
|
||||
from ..trainer_utils import create_modelcard_and_push, create_ref_model
|
||||
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
||||
from .trainer import CustomSeq2SeqTrainer
|
||||
|
||||
@@ -52,6 +52,10 @@ def run_sft(
|
||||
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
ref_model = None
|
||||
if finetuning_args.use_asft_loss:
|
||||
ref_model = create_ref_model(model_args, finetuning_args)
|
||||
|
||||
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
||||
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
||||
|
||||
@@ -124,6 +128,7 @@ def run_sft(
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
gen_kwargs=gen_kwargs,
|
||||
ref_model=ref_model,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**metric_module,
|
||||
|
||||
@@ -23,6 +23,7 @@ from collections.abc import Callable, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import Trainer
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import is_fsdp_enabled
|
||||
@@ -681,6 +682,88 @@ def _dft_cross_entropy(
|
||||
return loss
|
||||
|
||||
|
||||
def asft_loss_func(
|
||||
outputs,
|
||||
labels: torch.Tensor,
|
||||
ref_logits: torch.Tensor,
|
||||
asft_alpha: float = 0.1,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
logits = outputs.get("logits")
|
||||
if logits is None:
|
||||
return outputs.get("loss", torch.tensor(0.0))
|
||||
|
||||
logits = logits.float()
|
||||
|
||||
# shift for causal LM
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_ref_logits = ref_logits[..., :-1, :].contiguous()
|
||||
|
||||
vocab_size = shift_logits.size(-1)
|
||||
|
||||
# flatten
|
||||
shift_logits = shift_logits.view(-1, vocab_size)
|
||||
shift_ref_logits = shift_ref_logits.view(-1, vocab_size)
|
||||
shift_labels = shift_labels.view(-1).to(shift_logits.device)
|
||||
|
||||
return _asft_cross_entropy(
|
||||
policy_logits=shift_logits,
|
||||
policy_labels=shift_labels,
|
||||
ref_logits=shift_ref_logits,
|
||||
asft_alpha=asft_alpha,
|
||||
ignore_index=ignore_index,
|
||||
)
|
||||
|
||||
|
||||
def _asft_cross_entropy(
|
||||
policy_logits: torch.Tensor,
|
||||
policy_labels: torch.Tensor,
|
||||
ref_logits: torch.Tensor,
|
||||
asft_alpha: float = 0.1,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
dft_loss = _dft_cross_entropy(
|
||||
policy_logits,
|
||||
policy_labels,
|
||||
ignore_index=ignore_index,
|
||||
)
|
||||
|
||||
kl_loss = _kl_divergence(
|
||||
policy_logits,
|
||||
ref_logits,
|
||||
policy_labels,
|
||||
ignore_index=ignore_index,
|
||||
)
|
||||
|
||||
return dft_loss + asft_alpha * kl_loss
|
||||
|
||||
|
||||
def _kl_divergence(
|
||||
policy_logits: torch.Tensor,
|
||||
ref_logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
# log p(y|x)
|
||||
log_p = F.log_softmax(policy_logits, dim=-1)
|
||||
|
||||
# q(y|x)
|
||||
q = F.softmax(ref_logits, dim=-1)
|
||||
|
||||
# token-wise KL
|
||||
kl = F.kl_div(
|
||||
log_p,
|
||||
q,
|
||||
reduction="none",
|
||||
).sum(dim=-1) # [N]
|
||||
|
||||
# mask padding tokens
|
||||
mask = (labels != ignore_index).float()
|
||||
|
||||
return (kl * mask).sum() / mask.sum()
|
||||
|
||||
|
||||
def eaft_loss_func(
|
||||
outputs: "torch.Tensor",
|
||||
labels: "torch.Tensor",
|
||||
|
||||
@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
|
||||
from ..extras.packages import is_mcore_adapter_available, is_ray_available
|
||||
from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
|
||||
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||
@@ -160,17 +160,28 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
|
||||
model = model.to(output_dtype)
|
||||
logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
|
||||
|
||||
model.save_pretrained(
|
||||
save_directory=model_args.export_dir,
|
||||
max_shard_size=f"{model_args.export_size}GB",
|
||||
safe_serialization=(not model_args.export_legacy_format),
|
||||
)
|
||||
# Prepare save arguments (safe_serialization removed in transformers v5.0.0)
|
||||
save_kwargs = {
|
||||
"save_directory": model_args.export_dir,
|
||||
"max_shard_size": f"{model_args.export_size}GB",
|
||||
}
|
||||
if not is_transformers_version_greater_than("5.0.0"):
|
||||
save_kwargs["safe_serialization"] = not model_args.export_legacy_format
|
||||
|
||||
model.save_pretrained(**save_kwargs)
|
||||
|
||||
if model_args.export_hub_model_id is not None:
|
||||
# Prepare push arguments (safe_serialization removed in transformers v5.0.0)
|
||||
push_kwargs = {
|
||||
"max_shard_size": f"{model_args.export_size}GB",
|
||||
}
|
||||
if not is_transformers_version_greater_than("5.0.0"):
|
||||
push_kwargs["safe_serialization"] = not model_args.export_legacy_format
|
||||
|
||||
model.push_to_hub(
|
||||
model_args.export_hub_model_id,
|
||||
token=model_args.hf_hub_token,
|
||||
max_shard_size=f"{model_args.export_size}GB",
|
||||
safe_serialization=(not model_args.export_legacy_format),
|
||||
**push_kwargs,
|
||||
)
|
||||
|
||||
if finetuning_args.stage == "rm":
|
||||
|
||||
@@ -76,19 +76,28 @@ class BaseTrainer:
|
||||
if self.args.enable_activation_checkpointing:
|
||||
self.model.gradient_checkpointing_enable({"use_reentrant": False})
|
||||
|
||||
if self.args.dist_config is not None:
|
||||
shard_need_optimizer = self.args.dist_config.name == "deepspeed"
|
||||
else:
|
||||
shard_need_optimizer = False
|
||||
self._accelerate_engine = None
|
||||
dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
|
||||
|
||||
if shard_need_optimizer:
|
||||
if dist_name == "deepspeed":
|
||||
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
|
||||
|
||||
self._deepspeed_engine = DistributedPlugin("deepspeed")(
|
||||
self.model,
|
||||
self.args.dist_config,
|
||||
num_micro_batch=self.train_batch_generator.num_micro_batch,
|
||||
micro_batch_size=self.args.micro_batch_size,
|
||||
)
|
||||
self._init_optimizer()
|
||||
self._shard_model()
|
||||
self._init_lr_scheduler()
|
||||
self.model, self.optimizer, self.lr_scheduler = self._deepspeed_engine.prepare(
|
||||
self.model, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
else:
|
||||
# fsdp2 / DDP / no dist
|
||||
self._shard_model()
|
||||
self._init_optimizer()
|
||||
|
||||
self._init_lr_scheduler()
|
||||
self._init_lr_scheduler()
|
||||
|
||||
def _create_batch_generator(self) -> None:
|
||||
self.train_batch_generator = BatchGenerator(
|
||||
@@ -171,25 +180,35 @@ class BaseTrainer:
|
||||
step_loss = 0
|
||||
step_valid_tokens = compute_valid_tokens(micro_batches)
|
||||
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
|
||||
for micro_batch in micro_batches:
|
||||
num_micro = len(micro_batches)
|
||||
for i, micro_batch in enumerate(micro_batches):
|
||||
loss = self.compute_loss(micro_batch)
|
||||
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
|
||||
# fsdp uses mean reduction so we need to scale the loss by dp_size
|
||||
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
|
||||
|
||||
loss.backward()
|
||||
if self._deepspeed_engine is not None:
|
||||
# deepspeed: set sync_gradients so engine.step() only fires on last micro-batch
|
||||
self._deepspeed_engine.accelerator.sync_gradients = i == num_micro - 1
|
||||
self._deepspeed_engine.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
step_loss += loss.item()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
|
||||
|
||||
# isfinite(): argument 'input' (position 1) must be Tensor, not float
|
||||
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
|
||||
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
|
||||
if self._deepspeed_engine is not None:
|
||||
# deepspeed: engine.step() already ran inside backward at the sync boundary
|
||||
grad_norm = self._deepspeed_engine.get_grad_norm()
|
||||
else:
|
||||
self.optimizer.step()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
|
||||
|
||||
self.lr_scheduler.step()
|
||||
self.optimizer.zero_grad()
|
||||
# isfinite(): argument 'input' (position 1) must be Tensor, not float
|
||||
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
|
||||
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
|
||||
else:
|
||||
self.optimizer.step()
|
||||
|
||||
self.lr_scheduler.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
|
||||
DistributedInterface().sync()
|
||||
@@ -203,7 +222,14 @@ class BaseTrainer:
|
||||
|
||||
def save_model(self) -> None:
|
||||
"""Save the model."""
|
||||
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
||||
model_to_save.save_pretrained(self.args.output_dir)
|
||||
self.renderer.processor.save_pretrained(self.args.output_dir)
|
||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||
if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
|
||||
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
|
||||
|
||||
DistributedPlugin(self.args.dist_config.name).save_model(
|
||||
self.model, self.args.output_dir, self.renderer.processor
|
||||
)
|
||||
else:
|
||||
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
||||
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||
|
||||
@@ -90,6 +90,26 @@ class ModelEngine:
|
||||
Transformers can choose the proper model init context.
|
||||
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
|
||||
"""
|
||||
if self.args.init_config is not None:
|
||||
from ..plugins.model_plugins.initialization import InitPlugin
|
||||
|
||||
init_device = InitPlugin(self.args.init_config.name)()
|
||||
else:
|
||||
init_device = DistributedInterface().current_device
|
||||
|
||||
init_kwargs = {"device_map": init_device}
|
||||
|
||||
if self.args.quant_config is not None:
|
||||
from ..plugins.model_plugins.quantization import QuantizationPlugin
|
||||
|
||||
init_kwargs = QuantizationPlugin(self.args.quant_config.name)(
|
||||
init_kwargs=init_kwargs,
|
||||
config=self.model_config,
|
||||
tokenizer=self.processor,
|
||||
model_args=self.args,
|
||||
is_trainable=self.is_train,
|
||||
)
|
||||
|
||||
if self.args.model_class == ModelClass.LLM:
|
||||
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
|
||||
|
||||
@@ -107,14 +127,8 @@ class ModelEngine:
|
||||
|
||||
AutoClass = AutoModel
|
||||
|
||||
if self.args.init_config is not None:
|
||||
from ..plugins.model_plugins.initialization import InitPlugin
|
||||
|
||||
init_device = InitPlugin(self.args.init_config.name)()
|
||||
else:
|
||||
init_device = DistributedInterface().current_device
|
||||
|
||||
if init_device.type == DeviceType.META:
|
||||
assert self.args.quant_config is None, "Quantization is not supported with meta device."
|
||||
with init_empty_weights():
|
||||
model = AutoClass.from_config(self.model_config)
|
||||
else:
|
||||
@@ -122,8 +136,8 @@ class ModelEngine:
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map=init_device,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
if self.args.peft_config is None:
|
||||
|
||||
@@ -125,6 +125,11 @@ def launch():
|
||||
|
||||
run_chat()
|
||||
|
||||
elif command == "merge":
|
||||
from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model
|
||||
|
||||
merge_and_export_model()
|
||||
|
||||
elif command == "env":
|
||||
raise NotImplementedError("Environment information is not implemented yet.")
|
||||
|
||||
|
||||
@@ -12,14 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
import re
|
||||
from typing import Literal, TypedDict, Union
|
||||
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
import torch
|
||||
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
|
||||
|
||||
from ...config import InputArgument, get_args
|
||||
from ...core.model_engine import ModelEngine
|
||||
from ...utils import logging
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import HFModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LoraConfigDict(TypedDict, total=False):
|
||||
name: Literal["lora"]
|
||||
"""Plugin name."""
|
||||
@@ -27,8 +35,28 @@ class LoraConfigDict(TypedDict, total=False):
|
||||
"""Lora rank."""
|
||||
lora_alpha: int
|
||||
"""Lora alpha."""
|
||||
target_modules: list[str]
|
||||
lora_dropout: float
|
||||
"""Lora dropout."""
|
||||
target_modules: Union[list[str], str]
|
||||
"""Target modules."""
|
||||
use_rslora: bool
|
||||
"""Use RS-LoRA."""
|
||||
use_dora: bool
|
||||
"""Use DoRA."""
|
||||
modules_to_save: list[str]
|
||||
"""Modules to save."""
|
||||
adapter_name_or_path: Union[list[str], str]
|
||||
"""Path to the adapter(s)."""
|
||||
export_dir: str
|
||||
"""Path to the export directory."""
|
||||
export_size: int
|
||||
"""Shard size for the export model."""
|
||||
export_hub_model_id: str
|
||||
"""Hub model ID for the export model."""
|
||||
infer_dtype: Literal["auto", "float16", "float32", "bfloat16"]
|
||||
"""Inference data type for the export model."""
|
||||
export_legacy_format: bool
|
||||
"""Use legacy format for the export model."""
|
||||
|
||||
|
||||
class FreezeConfigDict(TypedDict, total=False):
|
||||
@@ -36,22 +64,280 @@ class FreezeConfigDict(TypedDict, total=False):
|
||||
"""Plugin name."""
|
||||
freeze_trainable_layers: int
|
||||
"""Freeze trainable layers."""
|
||||
freeze_trainable_modules: list[str] | None
|
||||
freeze_trainable_modules: Union[list[str], str]
|
||||
"""Freeze trainable modules."""
|
||||
freeze_extra_modules: list[str]
|
||||
"""Freeze extra modules."""
|
||||
cast_trainable_params_to_fp32: bool
|
||||
"""Cast trainable params to fp32."""
|
||||
|
||||
|
||||
class PeftPlugin(BasePlugin):
|
||||
def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
|
||||
return super().__call__(model, config)
|
||||
return super().__call__(model, config, is_train)
|
||||
|
||||
|
||||
def _find_all_linear_modules(model: HFModel) -> list[str]:
|
||||
r"""Find all available modules to apply LoRA."""
|
||||
forbidden_modules = {"lm_head", "output_layer", "output"}
|
||||
module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if any(forbidden_module in name for forbidden_module in forbidden_modules):
|
||||
continue
|
||||
|
||||
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
|
||||
module_names.add(name.split(".")[-1])
|
||||
|
||||
return list(module_names)
|
||||
|
||||
|
||||
def merge_adapters(model: HFModel, adapter_name_or_path: Union[list[str], str]) -> HFModel:
|
||||
if not isinstance(adapter_name_or_path, list):
|
||||
adapter_name_or_path = [adapter_name_or_path]
|
||||
|
||||
for adapter_path in adapter_name_or_path:
|
||||
model = PeftModel.from_pretrained(model, adapter_path)
|
||||
model = model.merge_and_unload()
|
||||
logger.info_rank0(f"Merged adapter from {adapter_path}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is_train: bool) -> HFModel:
|
||||
r"""Loads adapter(s) into the model.
|
||||
|
||||
Determine adapter usage based on mode:
|
||||
- Training: Load the single adapter for continued training.
|
||||
- Inference: Merge all adapters to clean up the model.
|
||||
- Unmergeable: Keep the single adapter active without merging.
|
||||
"""
|
||||
if not isinstance(adapter_name_or_path, list):
|
||||
adapter_name_or_path = [adapter_name_or_path]
|
||||
|
||||
# TODO
|
||||
# Adapters fix for deepspeed and quant
|
||||
# Adapters fix for vision
|
||||
|
||||
if is_train and len(adapter_name_or_path) > 1:
|
||||
raise ValueError(
|
||||
"When `adapter_name_or_path` is provided for training, only a single LoRA adapter is supported. "
|
||||
"Training will continue on the specified adapter. "
|
||||
"Please merge multiple adapters before starting a new LoRA adapter."
|
||||
)
|
||||
|
||||
if is_train:
|
||||
adapter_to_merge = []
|
||||
adapter_to_resume = adapter_name_or_path[0]
|
||||
else:
|
||||
adapter_to_merge = adapter_name_or_path
|
||||
adapter_to_resume = None
|
||||
|
||||
if adapter_to_merge:
|
||||
model = merge_adapters(model, adapter_to_merge)
|
||||
|
||||
if adapter_to_resume is not None:
|
||||
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_train)
|
||||
if is_train:
|
||||
logger.info_rank0(
|
||||
f"Resuming training from existing LoRA adapter at {adapter_to_resume}. "
|
||||
"LoRA hyperparameters will be loaded from the adapter itself; "
|
||||
"the current LoRA configuration will be ignored. "
|
||||
"Merge the adapter into the base model before training if you want to start a new adapter."
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@PeftPlugin("lora").register()
|
||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
|
||||
peft_config = LoraConfig(**config)
|
||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
|
||||
adapter_name_or_path = config.get("adapter_name_or_path")
|
||||
|
||||
if adapter_name_or_path:
|
||||
return load_adapter(model, adapter_name_or_path, is_train)
|
||||
|
||||
logger.info_rank0("Fine-tuning method: LoRA")
|
||||
|
||||
target_modules = config.get("target_modules", "all")
|
||||
|
||||
# Handle target modules
|
||||
if target_modules == "all":
|
||||
target_modules = _find_all_linear_modules(model)
|
||||
elif isinstance(target_modules, str):
|
||||
target_modules = [target_modules]
|
||||
|
||||
logger.info_rank0(f"LoRA target modules: {target_modules}")
|
||||
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=not is_train,
|
||||
r=config.get("r", 8),
|
||||
lora_alpha=config.get("lora_alpha", 16),
|
||||
lora_dropout=config.get("lora_dropout", 0.05),
|
||||
use_rslora=config.get("use_rslora", False),
|
||||
use_dora=config.get("use_dora", False),
|
||||
target_modules=target_modules,
|
||||
modules_to_save=config.get("modules_to_save", None),
|
||||
)
|
||||
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
if is_train:
|
||||
model.print_trainable_parameters()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@PeftPlugin("freeze").register()
|
||||
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
|
||||
raise NotImplementedError()
|
||||
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool = False) -> HFModel:
|
||||
logger.info_rank0("Fine-tuning method: Freeze")
|
||||
|
||||
if not is_train:
|
||||
return model
|
||||
|
||||
freeze_trainable_layers = config.get("freeze_trainable_layers", 2)
|
||||
freeze_trainable_modules = config.get("freeze_trainable_modules", ["all"])
|
||||
freeze_extra_modules = config.get("freeze_extra_modules", [])
|
||||
cast_trainable_params_to_fp32 = config.get("cast_trainable_params_to_fp32", True)
|
||||
|
||||
if isinstance(freeze_trainable_modules, str):
|
||||
freeze_trainable_modules = [module.strip() for module in freeze_trainable_modules.split(",")]
|
||||
|
||||
if isinstance(freeze_extra_modules, str):
|
||||
freeze_extra_modules = [module.strip() for module in freeze_extra_modules.split(",")]
|
||||
|
||||
# Get number of layers
|
||||
num_layers = (
|
||||
getattr(model.config, "num_hidden_layers", None)
|
||||
or getattr(model.config, "num_layers", None)
|
||||
or getattr(model.config, "n_layer", None)
|
||||
)
|
||||
|
||||
if not num_layers:
|
||||
raise ValueError("Current model does not support freeze tuning.")
|
||||
|
||||
if freeze_trainable_layers > 0:
|
||||
# last n layers
|
||||
trainable_layer_ids = range(max(0, num_layers - freeze_trainable_layers), num_layers)
|
||||
else:
|
||||
# first n layers
|
||||
trainable_layer_ids = range(min(-freeze_trainable_layers, num_layers))
|
||||
|
||||
# Identify hidden and non-hidden modules
|
||||
hidden_modules = set()
|
||||
non_hidden_modules = set()
|
||||
for name, _ in model.named_parameters():
|
||||
if ".0." in name:
|
||||
hidden_modules.add(name.split(".0.")[-1].split(".")[0])
|
||||
elif ".1." in name:
|
||||
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
|
||||
|
||||
if re.search(r"\.\d+\.", name) is None:
|
||||
non_hidden_modules.add(name.split(".")[-2])
|
||||
|
||||
# Build list of trainable layer patterns
|
||||
trainable_layers = []
|
||||
for module_name in freeze_trainable_modules:
|
||||
if module_name == "all":
|
||||
for idx in trainable_layer_ids:
|
||||
trainable_layers.append(f".{idx:d}.")
|
||||
elif module_name in hidden_modules:
|
||||
for idx in trainable_layer_ids:
|
||||
trainable_layers.append(f".{idx:d}.{module_name}")
|
||||
else:
|
||||
raise ValueError(f"Module {module_name} not found in hidden modules: {hidden_modules}")
|
||||
|
||||
# Add extra modules
|
||||
if freeze_extra_modules:
|
||||
for module_name in freeze_extra_modules:
|
||||
if module_name in non_hidden_modules:
|
||||
trainable_layers.append(module_name)
|
||||
else:
|
||||
raise ValueError(f"Module {module_name} not found in non-hidden modules: {non_hidden_modules}")
|
||||
|
||||
# TODO
|
||||
# Multi-modal special handling
|
||||
|
||||
# Set requires_grad
|
||||
forbidden_modules = {"quant_state", "quantization_weight", "qweight", "qzeros", "scales"}
|
||||
for name, param in model.named_parameters():
|
||||
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
|
||||
forbidden_module in name for forbidden_module in forbidden_modules
|
||||
):
|
||||
param.requires_grad_(True)
|
||||
if cast_trainable_params_to_fp32:
|
||||
param.data = param.data.to(torch.float32) # Cast to fp32 for stability
|
||||
else:
|
||||
param.requires_grad_(False)
|
||||
|
||||
logger.info_rank0(f"Set trainable layers: {trainable_layers}")
|
||||
|
||||
# Count trainable params for verification
|
||||
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
all_params = sum(p.numel() for p in model.parameters())
|
||||
logger.info_rank0(
|
||||
f"trainable params: {trainable_params} || all params: {all_params} || trainable%: {100 * trainable_params / all_params:.4f}"
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def merge_and_export_model(args: InputArgument = None):
|
||||
model_args, _, _, _ = get_args(args)
|
||||
|
||||
export_config = model_args.peft_config
|
||||
if export_config is None:
|
||||
raise ValueError("Please specify peft_config to merge and export model.")
|
||||
|
||||
export_dir = export_config.get("export_dir")
|
||||
if export_dir is None:
|
||||
raise ValueError("Please specify export_dir.")
|
||||
|
||||
export_size = export_config.get("export_size", 5)
|
||||
export_hub_model_id = export_config.get("export_hub_model_id")
|
||||
infer_dtype = export_config.get("infer_dtype", "auto")
|
||||
export_legacy_format = export_config.get("export_legacy_format", False)
|
||||
|
||||
adapters = None
|
||||
if export_config.get("name") == "lora":
|
||||
adapters = export_config.get("adapter_name_or_path")
|
||||
else:
|
||||
raise ValueError("Currently merge and export model function is only supported for lora.")
|
||||
|
||||
if adapters is None:
|
||||
raise ValueError("Please set adapter_name_or_path to merge adapters into base model.")
|
||||
|
||||
logger.info_rank0("Loading model for export...")
|
||||
model_engine = ModelEngine(model_args, is_train=False)
|
||||
model = model_engine.model
|
||||
tokenizer = model_engine.processor
|
||||
|
||||
if infer_dtype == "auto":
|
||||
if model.config.torch_dtype == torch.float32 and torch.cuda.is_bf16_supported():
|
||||
model = model.to(torch.bfloat16)
|
||||
logger.info_rank0("Converted model to bfloat16.")
|
||||
else:
|
||||
target_dtype = getattr(torch, infer_dtype)
|
||||
model = model.to(target_dtype)
|
||||
logger.info_rank0(f"Converted model to {infer_dtype}.")
|
||||
|
||||
logger.info_rank0(f"Exporting model to {export_dir}...")
|
||||
model.save_pretrained(
|
||||
export_dir,
|
||||
max_shard_size=f"{export_size}GB",
|
||||
safe_serialization=not export_legacy_format,
|
||||
)
|
||||
if tokenizer is not None:
|
||||
try:
|
||||
if hasattr(tokenizer, "padding_side"):
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.save_pretrained(export_dir)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save tokenizer: {e}")
|
||||
|
||||
if export_hub_model_id:
|
||||
logger.info_rank0(f"Pushing to hub: {export_hub_model_id}...")
|
||||
model.push_to_hub(export_hub_model_id)
|
||||
if tokenizer is not None:
|
||||
tokenizer.push_to_hub(export_hub_model_id)
|
||||
|
||||
logger.info_rank0("Model exported successfully.")
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
from ...accelerator.helper import get_current_device
|
||||
from ...config.model_args import ModelArguments
|
||||
from ...utils import logging
|
||||
from ...utils.packages import check_version
|
||||
from ...utils.plugin import BasePlugin
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class QuantizationPlugin(BasePlugin):
|
||||
r"""Plugin for model quantization."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
init_kwargs: dict[str, Any] = None,
|
||||
config: "PretrainedConfig" = None,
|
||||
tokenizer: "PreTrainedTokenizer" = None,
|
||||
model_args: "ModelArguments" = None,
|
||||
is_trainable: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
return super().__call__(
|
||||
init_kwargs, config=config, tokenizer=tokenizer, model_args=model_args, is_trainable=is_trainable
|
||||
)
|
||||
|
||||
|
||||
@QuantizationPlugin("auto").register()
|
||||
def quantization_auto(
|
||||
init_kwargs: dict[str, Any],
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
"""Automatic quantization selection, only support bnb currently.
|
||||
|
||||
Args:
|
||||
init_kwargs (dict[str, Any]): The kwargs for model initialization.
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: The updated kwargs for model initialization.
|
||||
"""
|
||||
model_args: ModelArguments = kwargs.get("model_args", None)
|
||||
quant_config = model_args.quant_config
|
||||
|
||||
quantization_bit = quant_config.get("quantization_bit", None)
|
||||
if quantization_bit is not None:
|
||||
logger.info_rank0(f"Loading {quantization_bit}-bit quantized model.")
|
||||
if quantization_bit in [8, 4]:
|
||||
return quantization_with_bnb(init_kwargs, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization bit: {quantization_bit} for auto quantization.")
|
||||
logger.warning_rank0("No quantization method applied.")
|
||||
return init_kwargs
|
||||
|
||||
|
||||
@QuantizationPlugin("bnb").register()
|
||||
def quantization_with_bnb(
|
||||
init_kwargs: dict[str, Any],
|
||||
model_args: "ModelArguments" = None,
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
r"""Quantization with BNB."""
|
||||
logger.info_rank0("Using Bitsandbytes quantization.")
|
||||
quantization_bit = model_args.quant_config.get("quantization_bit", None)
|
||||
if quantization_bit is None:
|
||||
logger.warning_rank0("quantization_bit is not specified, default to 8-bit quantization.")
|
||||
quantization_bit = 4
|
||||
assert quantization_bit in [8, 4], "Bitsandbytes only accepts 4-bit or 8-bit quantization."
|
||||
if quantization_bit == 8:
|
||||
check_version("bitsandbytes>=0.37.0", mandatory=True)
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
elif quantization_bit == 4:
|
||||
check_version("bitsandbytes>=0.39.0", mandatory=True)
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.quant_config.get("compute_dtype", torch.float16),
|
||||
bnb_4bit_use_double_quant=model_args.quant_config.get("double_quantization", True),
|
||||
bnb_4bit_quant_type=model_args.quant_config.get("quantization_type", "nf4"),
|
||||
bnb_4bit_quant_storage=model_args.quant_config.get(
|
||||
"compute_dtype", torch.float16
|
||||
), # crucial for fsdp+qlora
|
||||
)
|
||||
else:
|
||||
raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
|
||||
|
||||
# TODO: improve deepspeed zero3 and fsdp detection.
|
||||
if kwargs.get("is_trainable", False):
|
||||
logger.info_rank0("Detected inference mode, setting device_map for bitsandbytes quantization.")
|
||||
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
|
||||
else:
|
||||
logger.info_rank0("Detected training mode, skip setting device_map for bitsandbytes quantization.")
|
||||
if model_args.quant_config.get("quantization_bit") != 4:
|
||||
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
|
||||
|
||||
check_version("bitsandbytes>=0.43.0", mandatory=True)
|
||||
|
||||
logger.info_rank0(f"Quantizing model to {model_args.quant_config.get('quantization_bit')} bit with bitsandbytes.")
|
||||
return init_kwargs
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""DeepSpeed integration via accelerate's built-in capabilities.
|
||||
|
||||
Instead of manually calling deepspeed.initialize() and syncing config,
|
||||
this module leverages accelerate's Accelerator + DeepSpeedPlugin to handle
|
||||
initialization, backward, gradient accumulation, and model saving.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
from ....utils.logging import get_logger
|
||||
from ....utils.types import HFModel, Processor
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DeepSpeedEngine:
|
||||
"""DeepSpeed integration using accelerate's built-in capabilities.
|
||||
|
||||
This replaces the manual DeepSpeedConfigHelper / DeepSpeedEngine approach
|
||||
with accelerate's Accelerator + DeepSpeedPlugin, which handles:
|
||||
- Config syncing (auto values, batch size, lr, etc.)
|
||||
- deepspeed.initialize() call
|
||||
- Optimizer / LR scheduler wrapping
|
||||
- Backward + gradient accumulation boundary
|
||||
- ZeRO-3 parameter gathering for saving
|
||||
"""
|
||||
|
||||
def __init__(self, dist_config: dict[str, Any], num_micro_batch: int = 1, micro_batch_size: int = 1):
|
||||
config_file = dist_config.get("config_file")
|
||||
if not config_file:
|
||||
raise ValueError("DeepSpeed config_file is required in dist_config")
|
||||
|
||||
ds_plugin = DeepSpeedPlugin(hf_ds_config=config_file)
|
||||
|
||||
self.accelerator = Accelerator(
|
||||
deepspeed_plugin=ds_plugin,
|
||||
gradient_accumulation_steps=num_micro_batch,
|
||||
)
|
||||
|
||||
# Resolve "auto" for train_micro_batch_size_per_gpu so that
|
||||
# accelerate.prepare() does not require a DataLoader to infer it.
|
||||
ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
if ds_config.get("train_micro_batch_size_per_gpu") in (None, "auto"):
|
||||
ds_config["train_micro_batch_size_per_gpu"] = micro_batch_size
|
||||
|
||||
logger.info_rank0(f"DeepSpeedEngine initialized with config: {config_file}")
|
||||
|
||||
def shard_model(self, model: HFModel) -> "DeepSpeedEngine":
|
||||
"""No-op shard — actual model wrapping happens in prepare().
|
||||
|
||||
Returns self so the caller gets the engine instance via the hub interface.
|
||||
"""
|
||||
return self
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
model: HFModel,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
lr_scheduler: Optional[Any] = None,
|
||||
) -> tuple[HFModel, torch.optim.Optimizer, Any]:
|
||||
"""Prepare model, optimizer, and lr_scheduler using accelerate.
|
||||
|
||||
Internally calls deepspeed.initialize() and wraps the returned objects.
|
||||
"""
|
||||
if lr_scheduler is not None:
|
||||
model, optimizer, lr_scheduler = self.accelerator.prepare(model, optimizer, lr_scheduler)
|
||||
else:
|
||||
model, optimizer = self.accelerator.prepare(model, optimizer)
|
||||
|
||||
model._accelerator = self.accelerator # type: ignore[assignment]
|
||||
|
||||
logger.info_rank0("Model, optimizer, and lr_scheduler prepared via accelerate")
|
||||
return model, optimizer, lr_scheduler
|
||||
|
||||
def backward(self, loss: torch.Tensor) -> None:
|
||||
"""Backward pass using accelerate.
|
||||
|
||||
Delegates to DeepSpeedEngineWrapper.backward() which respects
|
||||
sync_gradients to control gradient accumulation boundaries.
|
||||
When sync_gradients=True: engine.backward(loss) + engine.step()
|
||||
When sync_gradients=False: engine.backward(loss) only
|
||||
"""
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
def get_grad_norm(self) -> float:
|
||||
"""Get the global gradient norm from the DeepSpeed engine."""
|
||||
engine_wrapper = getattr(self.accelerator, "deepspeed_engine_wrapped", None)
|
||||
if engine_wrapper is not None:
|
||||
return engine_wrapper.engine.get_global_grad_norm() or 0.0
|
||||
return 0.0
|
||||
|
||||
|
||||
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
|
||||
"""Save model using accelerate's built-in ZeRO-aware utilities.
|
||||
|
||||
Expects model._accelerator to be set during prepare().
|
||||
Handles ZeRO-3 parameter gathering automatically via
|
||||
accelerator.get_state_dict().
|
||||
"""
|
||||
accelerator: Accelerator = model._accelerator # type: ignore[union-attr]
|
||||
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
unwrapped_model.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB")
|
||||
processor.save_pretrained(output_dir, max_shard_size="4GB")
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
logger.info_rank0(f"Model saved to {output_dir}")
|
||||
|
||||
@@ -17,23 +17,24 @@ import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from peft.tuners.lora import LoraLayer
|
||||
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
|
||||
from torch.distributed.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
MixedPrecisionPolicy,
|
||||
fully_shard,
|
||||
)
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ....accelerator.helper import get_current_accelerator
|
||||
from ....accelerator.interface import DistributedInterface
|
||||
from ....utils.logging import get_logger
|
||||
from ....utils.types import HFModel, Processor
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
|
||||
def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
|
||||
no_split_modules = getattr(model, "_no_split_modules", None)
|
||||
if no_split_modules:
|
||||
if isinstance(no_split_modules, (list, tuple)):
|
||||
@@ -49,6 +50,20 @@ def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
|
||||
return None
|
||||
|
||||
|
||||
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
|
||||
if DistributedInterface().get_rank() == 0:
|
||||
logger.info("Gathering state dict for saving...")
|
||||
|
||||
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
|
||||
state_dict = get_model_state_dict(model, options=options)
|
||||
|
||||
if DistributedInterface().get_rank() == 0:
|
||||
model_to_save = model.module if hasattr(model, "module") else model
|
||||
model_to_save.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB")
|
||||
processor.save_pretrained(output_dir, max_shard_size="4GB")
|
||||
logger.info(f"Model saved to {output_dir}")
|
||||
|
||||
|
||||
class FSDP2Engine:
|
||||
def __init__(self, dist_config: dict):
|
||||
self.dist_interface = DistributedInterface()
|
||||
@@ -94,7 +109,10 @@ class FSDP2Engine:
|
||||
cast_forward_inputs=True,
|
||||
)
|
||||
|
||||
def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel:
|
||||
def is_lora_module_wrap(self, model) -> bool:
|
||||
return any(isinstance(module, LoraLayer) for module in model.modules())
|
||||
|
||||
def prepare_model(self, model: HFModel) -> HFModel:
|
||||
if self.fsdp_mesh is None:
|
||||
logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
|
||||
return model
|
||||
@@ -111,6 +129,25 @@ class FSDP2Engine:
|
||||
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
|
||||
transformer_layer_cls_to_wrap = {layer_cls}
|
||||
|
||||
if self.is_lora_module_wrap(model):
|
||||
lora_modules = []
|
||||
for module in model.modules():
|
||||
if len(list(module.children())) != 0:
|
||||
continue
|
||||
if any(param.requires_grad for param in module.parameters(recurse=False)):
|
||||
lora_modules.append(module)
|
||||
|
||||
for module in lora_modules:
|
||||
fully_shard(
|
||||
module,
|
||||
mesh=self.fsdp_mesh,
|
||||
reshard_after_forward=self.reshard_after_forward,
|
||||
mp_policy=mp_policy,
|
||||
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
|
||||
)
|
||||
|
||||
logger.info("Applying FSDP wrap for LoRA layer separately.")
|
||||
|
||||
for name, module in model.named_modules():
|
||||
should_wrap = False
|
||||
|
||||
@@ -156,7 +193,7 @@ class FSDP2Engine:
|
||||
return model
|
||||
|
||||
@torch.no_grad()
|
||||
def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None):
|
||||
def materialize_and_load(self, model: HFModel, hf_model_path: str, dcp_path: str = None):
|
||||
if self.rank == 0:
|
||||
logger.info("Materializing sharded model params...")
|
||||
|
||||
@@ -176,7 +213,7 @@ class FSDP2Engine:
|
||||
|
||||
return model
|
||||
|
||||
def shard_model(self, model: PreTrainedModel) -> PreTrainedModel:
|
||||
def shard_model(self, model: HFModel) -> HFModel:
|
||||
if model.device.type == "meta":
|
||||
model = self.prepare_model(model)
|
||||
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
|
||||
@@ -184,7 +221,7 @@ class FSDP2Engine:
|
||||
model = self.prepare_model(model)
|
||||
return model
|
||||
|
||||
def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str):
|
||||
def _load_from_dcp(self, model: HFModel, dcp_path: str):
|
||||
import torch.distributed.checkpoint as dcp
|
||||
|
||||
try:
|
||||
@@ -203,7 +240,7 @@ class FSDP2Engine:
|
||||
logger.error(f"Failed to load from DCP: {e}")
|
||||
raise e
|
||||
|
||||
def _load_weights_from_hf_checkpoint(self, model, hf_model_path):
|
||||
def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
|
||||
import glob
|
||||
import json
|
||||
|
||||
|
||||
@@ -12,9 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ....config.arg_utils import PluginConfig
|
||||
from ....utils.plugin import BasePlugin
|
||||
from ....utils.types import HFModel
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ....utils.types import HFModel, Processor
|
||||
|
||||
|
||||
class DistributedPlugin(BasePlugin):
|
||||
@@ -23,12 +30,32 @@ class DistributedPlugin(BasePlugin):
|
||||
|
||||
|
||||
@DistributedPlugin("fsdp2").register()
|
||||
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig) -> HFModel:
|
||||
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
|
||||
from .fsdp2 import FSDP2Engine
|
||||
|
||||
return FSDP2Engine(dist_config).shard_model(model)
|
||||
|
||||
|
||||
@DistributedPlugin("fsdp2").register("save_model")
|
||||
def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> None:
|
||||
from .fsdp2 import save_model
|
||||
|
||||
return save_model(model, output_dir, processor)
|
||||
|
||||
|
||||
@DistributedPlugin("deepspeed").register()
|
||||
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig) -> HFModel:
|
||||
return model
|
||||
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
|
||||
from .deepspeed import DeepSpeedEngine
|
||||
|
||||
return DeepSpeedEngine(
|
||||
dist_config,
|
||||
num_micro_batch=kwargs.get("num_micro_batch"),
|
||||
micro_batch_size=kwargs.get("micro_batch_size"),
|
||||
).shard_model(model)
|
||||
|
||||
|
||||
@DistributedPlugin("deepspeed").register("save_model")
|
||||
def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor) -> None:
|
||||
from .deepspeed import save_model
|
||||
|
||||
return save_model(model, output_dir, processor)
|
||||
|
||||
@@ -33,7 +33,7 @@ def run_sft(args: InputArgument = None):
|
||||
model_args, data_args, training_args, _ = get_args(args)
|
||||
DistributedInterface(training_args.dist_config)
|
||||
train_dataset = DataEngine(data_args.train_dataset)
|
||||
model_engine = ModelEngine(model_args)
|
||||
model_engine = ModelEngine(model_args, is_train=True)
|
||||
trainer = SFTTrainer(
|
||||
args=training_args,
|
||||
model=model_engine.model,
|
||||
|
||||
@@ -21,6 +21,13 @@ from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from packaging import version
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from . import logging
|
||||
from .env import is_env_enabled
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -41,3 +48,22 @@ def _get_package_version(name: str) -> "Version":
|
||||
@lru_cache
|
||||
def is_transformers_version_greater_than(content: str):
|
||||
return _get_package_version("transformers") >= version.parse(content)
|
||||
|
||||
|
||||
def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
r"""Optionally check the package version."""
|
||||
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
|
||||
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
return
|
||||
|
||||
if "gptqmodel" in requirement or "autoawq" in requirement:
|
||||
pip_command = f"pip install {requirement} --no-build-isolation"
|
||||
else:
|
||||
pip_command = f"pip install {requirement}"
|
||||
|
||||
if mandatory:
|
||||
hint = f"To fix: run `{pip_command}`."
|
||||
else:
|
||||
hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
|
||||
|
||||
require_version(requirement, hint)
|
||||
|
||||
@@ -166,3 +166,33 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
||||
def fix_valuehead_cpu_loading():
|
||||
"""Fix valuehead model loading."""
|
||||
patch_valuehead_model()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def bypass_mistral_regex_check():
|
||||
"""Disable Mistral regex network check.
|
||||
|
||||
Monkey-patch TokenizersBackend._patch_mistral_regex into a no-op.
|
||||
"""
|
||||
try:
|
||||
from transformers.tokenization_utils_fast import TokenizersBackend
|
||||
except ImportError:
|
||||
# Very old transformers, nothing to patch
|
||||
yield
|
||||
return
|
||||
|
||||
if not hasattr(TokenizersBackend, "_patch_mistral_regex"):
|
||||
# Method does not exist in this version
|
||||
yield
|
||||
return
|
||||
|
||||
# Backup original method
|
||||
original = TokenizersBackend._patch_mistral_regex
|
||||
|
||||
# Replace with no-op
|
||||
TokenizersBackend._patch_mistral_regex = lambda cls, tokenizer, *args, **kwargs: tokenizer
|
||||
|
||||
yield
|
||||
|
||||
# Restore original method
|
||||
TokenizersBackend._patch_mistral_regex = original
|
||||
|
||||
@@ -172,3 +172,33 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
||||
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
||||
elif CURRENT_DEVICE == "npu":
|
||||
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def bypass_mistral_regex_check():
|
||||
"""Disable Mistral regex network check.
|
||||
|
||||
Monkey-patch TokenizersBackend._patch_mistral_regex into a no-op.
|
||||
"""
|
||||
try:
|
||||
from transformers.tokenization_utils_fast import TokenizersBackend
|
||||
except ImportError:
|
||||
# Very old transformers, nothing to patch
|
||||
yield
|
||||
return
|
||||
|
||||
if not hasattr(TokenizersBackend, "_patch_mistral_regex"):
|
||||
# Method does not exist in this version
|
||||
yield
|
||||
return
|
||||
|
||||
# Backup original method
|
||||
original = TokenizersBackend._patch_mistral_regex
|
||||
|
||||
# Replace with no-op
|
||||
TokenizersBackend._patch_mistral_regex = lambda cls, tokenizer, *args, **kwargs: tokenizer
|
||||
|
||||
yield
|
||||
|
||||
# Restore original method
|
||||
TokenizersBackend._patch_mistral_regex = original
|
||||
|
||||
156
tests_v1/plugins/model_plugins/test_peft.py
Normal file
156
tests_v1/plugins/model_plugins/test_peft.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from llamafactory.v1.plugins.model_plugins import peft as peft_module
|
||||
from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model
|
||||
|
||||
|
||||
TINY_MODEL = "llamafactory/tiny-random-qwen3"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def model_path():
|
||||
return TINY_MODEL
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def model(model_path):
|
||||
return AutoModelForCausalLM.from_pretrained(model_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def tokenizer(model_path):
|
||||
return AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def adapter_path(tmp_path):
|
||||
# Create a dummy adapter
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
base_model = AutoModelForCausalLM.from_pretrained(TINY_MODEL)
|
||||
peft_model = get_peft_model(base_model, lora_config)
|
||||
save_path = tmp_path / "test_adapter"
|
||||
peft_model.save_pretrained(save_path)
|
||||
return str(save_path)
|
||||
|
||||
|
||||
def test_find_all_linear_modules(model):
|
||||
"""Verify linear modules are discoverable and include q_proj / v_proj for tiny-random-qwen3."""
|
||||
modules = peft_module._find_all_linear_modules(model)
|
||||
expected_subset = {"q_proj", "v_proj"}
|
||||
assert expected_subset.issubset(set(modules))
|
||||
|
||||
|
||||
def test_get_lora_model(model):
|
||||
"""Verify a PeftModel is returned and LoRA config takes effect."""
|
||||
config = {"name": "lora", "r": 8, "target_modules": "all", "lora_alpha": 16}
|
||||
model = peft_module.get_lora_model(model, config, is_train=True)
|
||||
assert isinstance(model, PeftModel)
|
||||
assert model.peft_config["default"].r == 8
|
||||
assert "q_proj" in model.peft_config["default"].target_modules
|
||||
|
||||
|
||||
def test_get_freeze_model_layers(model):
|
||||
"""Verify layer-wise freezing: only the last layer stays trainable."""
|
||||
# Freeze all but last layer
|
||||
config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "all"}
|
||||
|
||||
# Ensure we start with something known
|
||||
model = peft_module.get_freeze_model(model, config, is_train=True)
|
||||
|
||||
num_layers = model.config.num_hidden_layers
|
||||
assert num_layers > 0
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if f"layers.{num_layers - 1}" in name:
|
||||
assert param.requires_grad, f"{name} should be trainable"
|
||||
elif "layers.0" in name and num_layers > 1:
|
||||
assert not param.requires_grad, f"{name} should be frozen"
|
||||
|
||||
|
||||
def test_get_freeze_model_modules(model):
|
||||
"""Verify module-wise freezing: only last-layer self_attn is trainable."""
|
||||
# Freeze specific modules (e.g. only self_attn)
|
||||
config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "self_attn"}
|
||||
model = peft_module.get_freeze_model(model, config, is_train=True)
|
||||
|
||||
num_layers = model.config.num_hidden_layers
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if f"layers.{num_layers - 1}" in name and "self_attn" in name:
|
||||
assert param.requires_grad, f"{name} should be trainable"
|
||||
else:
|
||||
assert not param.requires_grad, f"{name} should be frozen"
|
||||
|
||||
|
||||
def test_load_adapter_single_for_inference(model, adapter_path):
|
||||
"""Verify single adapter is merged+unloaded in inference mode."""
|
||||
# Test loading single adapter for inference (merge and unload)
|
||||
model_result = peft_module.load_adapter(model, adapter_path, is_train=False)
|
||||
assert not isinstance(model_result, PeftModel)
|
||||
|
||||
|
||||
def test_load_adapter_resume_train(model, adapter_path):
|
||||
"""Verify training mode returns a trainable PeftModel."""
|
||||
# Test loading for training
|
||||
model_result = peft_module.load_adapter(model, adapter_path, is_train=True)
|
||||
assert isinstance(model_result, PeftModel)
|
||||
|
||||
|
||||
def test_load_adapter_train_multiple_disallowed(model, adapter_path):
|
||||
"""Verify multiple adapters are rejected in training mode."""
|
||||
with pytest.raises(ValueError, match="only a single LoRA adapter"):
|
||||
peft_module.load_adapter(model, [adapter_path, adapter_path], is_train=True)
|
||||
|
||||
|
||||
def test_load_adapter_infer_multiple_merges(model, adapter_path):
|
||||
"""Verify multiple adapters are merged in inference mode."""
|
||||
# Test merging multiple adapters
|
||||
model_result = peft_module.load_adapter(model, [adapter_path, adapter_path], is_train=False)
|
||||
assert not isinstance(model_result, PeftModel)
|
||||
|
||||
|
||||
def test_merge_and_export_model(tmp_path, adapter_path):
|
||||
"""Verify merge_and_export_model produces export artifacts."""
|
||||
export_dir = tmp_path / "export"
|
||||
|
||||
args_dict = {
|
||||
"model": TINY_MODEL,
|
||||
"peft_config": {
|
||||
"name": "lora",
|
||||
"adapter_name_or_path": adapter_path,
|
||||
"export_dir": str(export_dir),
|
||||
"export_size": 1,
|
||||
"infer_dtype": "float16",
|
||||
},
|
||||
}
|
||||
|
||||
merge_and_export_model(args_dict)
|
||||
|
||||
assert export_dir.exists()
|
||||
assert (export_dir / "config.json").exists()
|
||||
assert (export_dir / "model.safetensors").exists()
|
||||
assert (export_dir / "tokenizer_config.json").exists()
|
||||
51
tests_v1/plugins/model_plugins/test_quantization_plugin.py
Normal file
51
tests_v1/plugins/model_plugins/test_quantization_plugin.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
from llamafactory.v1.config.model_args import ModelArguments
|
||||
from llamafactory.v1.core.model_engine import ModelEngine
|
||||
|
||||
|
||||
bitsandbytes = pytest.importorskip("bitsandbytes")
|
||||
|
||||
|
||||
def check_quantization_status(model):
|
||||
quantized_info = {"bnb": []}
|
||||
|
||||
for name, module in model.named_modules():
|
||||
# check BitsAndBytes quantization
|
||||
if isinstance(module, bitsandbytes.nn.modules.Linear8bitLt) or isinstance(
|
||||
module, bitsandbytes.nn.modules.Linear4bit
|
||||
):
|
||||
quantized_info["bnb"].append(name)
|
||||
|
||||
return quantized_info
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cuda"])
|
||||
@pytest.mark.parametrize("name, quantization_bit", [("bnb", 4), ("auto", 4)])
|
||||
def test_quantization_plugin(name, quantization_bit):
|
||||
model_args = ModelArguments(
|
||||
model="llamafactory/tiny-random-qwen3",
|
||||
quant_config={
|
||||
"name": name,
|
||||
"quantization_bit": quantization_bit,
|
||||
},
|
||||
)
|
||||
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
quantized_info = check_quantization_status(model_engine.model)
|
||||
print(f"Quantized weights for method {name} with {quantization_bit} bit: {quantized_info}")
|
||||
assert any(v for v in quantized_info.values()), "model is not quantized properly."
|
||||
Reference in New Issue
Block a user