mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-02-26 07:45:59 +08:00
[algo] add ASFT (#10174)
This commit is contained in:
45
examples/extras/asft/llama2_full_asft.yaml
Normal file
45
examples/extras/asft/llama2_full_asft.yaml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
### model
|
||||||
|
model_name_or_path: models/Llama-2-7b
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: full
|
||||||
|
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||||
|
use_asft_loss: true
|
||||||
|
asft_alpha: 0.1
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: med
|
||||||
|
template: llama2
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 10000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/llama2-7b/full/asft2
|
||||||
|
logging_steps: 1
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 4
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
learning_rate: 2.0e-5
|
||||||
|
num_train_epochs: 3.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
|
### eval
|
||||||
|
# val_size: 0.1
|
||||||
|
# per_device_eval_batch_size: 1
|
||||||
|
# eval_strategy: steps
|
||||||
|
# eval_steps: 500
|
||||||
45
examples/extras/asft/qwen2_full_asft.yaml
Normal file
45
examples/extras/asft/qwen2_full_asft.yaml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
### model
|
||||||
|
model_name_or_path: models/Qwen2.5-7B
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: full
|
||||||
|
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||||
|
use_asft_loss: true
|
||||||
|
asft_alpha: 0.05
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: math
|
||||||
|
template: qwen
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 10000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/qwen2-7b/full/asft
|
||||||
|
logging_steps: 10
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 4
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
learning_rate: 5.0e-5
|
||||||
|
num_train_epochs: 1.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
|
### eval
|
||||||
|
# val_size: 0.1
|
||||||
|
# per_device_eval_batch_size: 1
|
||||||
|
# eval_strategy: steps
|
||||||
|
# eval_steps: 500
|
||||||
@@ -490,6 +490,14 @@ class FinetuningArguments(
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether to use the DFT loss."},
|
metadata={"help": "Whether to use the DFT loss."},
|
||||||
)
|
)
|
||||||
|
use_asft_loss: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to use the ASFT loss."},
|
||||||
|
)
|
||||||
|
asft_alpha: float = field(
|
||||||
|
default=0.1,
|
||||||
|
metadata={"help": "The alpha parameter for ASFT loss to control the power of adaptive weight."},
|
||||||
|
)
|
||||||
use_eaft_loss: bool = field(
|
use_eaft_loss: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether to use the EAFT loss."},
|
metadata={"help": "Whether to use the EAFT loss."},
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from functools import partial
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
@@ -52,6 +53,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
model_args: Optional["ModelArguments"] = None,
|
model_args: Optional["ModelArguments"] = None,
|
||||||
gen_kwargs: Optional[dict[str, Any]] = None,
|
gen_kwargs: Optional[dict[str, Any]] = None,
|
||||||
|
ref_model: Optional["torch.nn.Module"] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||||
@@ -82,6 +84,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||||
self.add_callback(BAdamCallback)
|
self.add_callback(BAdamCallback)
|
||||||
|
|
||||||
|
self.ref_model = ref_model
|
||||||
|
|
||||||
|
if ref_model is not None:
|
||||||
|
from trl.models.utils import prepare_deepspeed, prepare_fsdp
|
||||||
|
|
||||||
|
if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
|
||||||
|
if not (
|
||||||
|
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||||
|
): # quantized models are already set on the correct device
|
||||||
|
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||||
|
elif getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
|
||||||
|
if self.accelerator.is_fsdp2:
|
||||||
|
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
|
||||||
|
|
||||||
|
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
|
||||||
|
else:
|
||||||
|
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
|
||||||
|
else:
|
||||||
|
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||||
|
self.ref_model.eval()
|
||||||
|
|
||||||
if finetuning_args.use_dft_loss:
|
if finetuning_args.use_dft_loss:
|
||||||
from ..trainer_utils import dft_loss_func
|
from ..trainer_utils import dft_loss_func
|
||||||
|
|
||||||
@@ -93,6 +116,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
|
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
|
||||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||||
)
|
)
|
||||||
|
elif finetuning_args.use_asft_loss:
|
||||||
|
from ..trainer_utils import asft_loss_func
|
||||||
|
|
||||||
|
self.compute_loss_func = partial(
|
||||||
|
asft_loss_func,
|
||||||
|
asft_alpha=finetuning_args.asft_alpha,
|
||||||
|
)
|
||||||
|
|
||||||
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||||
verify_fp8_status(self.accelerator, training_args)
|
verify_fp8_status(self.accelerator, training_args)
|
||||||
@@ -119,7 +149,17 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def compute_loss(self, model, inputs, *args, **kwargs):
|
def compute_loss(self, model, inputs, *args, **kwargs):
|
||||||
return super().compute_loss(model, inputs, *args, **kwargs)
|
if self.finetuning_args.use_asft_loss:
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_outputs = self.ref_model(
|
||||||
|
input_ids=inputs["input_ids"],
|
||||||
|
attention_mask=inputs.get("attention_mask", None),
|
||||||
|
)
|
||||||
|
ref_logits = ref_outputs.logits
|
||||||
|
outputs = model(**inputs)
|
||||||
|
return self.compute_loss_func(outputs, inputs["labels"], ref_logits)
|
||||||
|
else:
|
||||||
|
return super().compute_loss(model, inputs, *args, **kwargs)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def prediction_step(
|
def prediction_step(
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from ...extras.misc import calculate_tps
|
|||||||
from ...extras.packages import is_transformers_version_greater_than
|
from ...extras.packages import is_transformers_version_greater_than
|
||||||
from ...extras.ploting import plot_loss
|
from ...extras.ploting import plot_loss
|
||||||
from ...model import load_model, load_tokenizer
|
from ...model import load_model, load_tokenizer
|
||||||
from ..trainer_utils import create_modelcard_and_push
|
from ..trainer_utils import create_modelcard_and_push, create_ref_model
|
||||||
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
||||||
from .trainer import CustomSeq2SeqTrainer
|
from .trainer import CustomSeq2SeqTrainer
|
||||||
|
|
||||||
@@ -52,6 +52,10 @@ def run_sft(
|
|||||||
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||||
|
|
||||||
|
ref_model = None
|
||||||
|
if finetuning_args.use_asft_loss:
|
||||||
|
ref_model = create_ref_model(model_args, finetuning_args)
|
||||||
|
|
||||||
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
||||||
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
||||||
|
|
||||||
@@ -124,6 +128,7 @@ def run_sft(
|
|||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
gen_kwargs=gen_kwargs,
|
gen_kwargs=gen_kwargs,
|
||||||
|
ref_model=ref_model,
|
||||||
**dataset_module,
|
**dataset_module,
|
||||||
**tokenizer_module,
|
**tokenizer_module,
|
||||||
**metric_module,
|
**metric_module,
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from collections.abc import Callable, Mapping
|
|||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
from transformers.modeling_utils import is_fsdp_enabled
|
from transformers.modeling_utils import is_fsdp_enabled
|
||||||
@@ -681,6 +682,88 @@ def _dft_cross_entropy(
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def asft_loss_func(
|
||||||
|
outputs,
|
||||||
|
labels: torch.Tensor,
|
||||||
|
ref_logits: torch.Tensor,
|
||||||
|
asft_alpha: float = 0.1,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
logits = outputs.get("logits")
|
||||||
|
if logits is None:
|
||||||
|
return outputs.get("loss", torch.tensor(0.0))
|
||||||
|
|
||||||
|
logits = logits.float()
|
||||||
|
|
||||||
|
# shift for causal LM
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
shift_ref_logits = ref_logits[..., :-1, :].contiguous()
|
||||||
|
|
||||||
|
vocab_size = shift_logits.size(-1)
|
||||||
|
|
||||||
|
# flatten
|
||||||
|
shift_logits = shift_logits.view(-1, vocab_size)
|
||||||
|
shift_ref_logits = shift_ref_logits.view(-1, vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1).to(shift_logits.device)
|
||||||
|
|
||||||
|
return _asft_cross_entropy(
|
||||||
|
policy_logits=shift_logits,
|
||||||
|
policy_labels=shift_labels,
|
||||||
|
ref_logits=shift_ref_logits,
|
||||||
|
asft_alpha=asft_alpha,
|
||||||
|
ignore_index=ignore_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _asft_cross_entropy(
|
||||||
|
policy_logits: torch.Tensor,
|
||||||
|
policy_labels: torch.Tensor,
|
||||||
|
ref_logits: torch.Tensor,
|
||||||
|
asft_alpha: float = 0.1,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
dft_loss = _dft_cross_entropy(
|
||||||
|
policy_logits,
|
||||||
|
policy_labels,
|
||||||
|
ignore_index=ignore_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
kl_loss = _kl_divergence(
|
||||||
|
policy_logits,
|
||||||
|
ref_logits,
|
||||||
|
policy_labels,
|
||||||
|
ignore_index=ignore_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
return dft_loss + asft_alpha * kl_loss
|
||||||
|
|
||||||
|
|
||||||
|
def _kl_divergence(
|
||||||
|
policy_logits: torch.Tensor,
|
||||||
|
ref_logits: torch.Tensor,
|
||||||
|
labels: torch.Tensor,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# log p(y|x)
|
||||||
|
log_p = F.log_softmax(policy_logits, dim=-1)
|
||||||
|
|
||||||
|
# q(y|x)
|
||||||
|
q = F.softmax(ref_logits, dim=-1)
|
||||||
|
|
||||||
|
# token-wise KL
|
||||||
|
kl = F.kl_div(
|
||||||
|
log_p,
|
||||||
|
q,
|
||||||
|
reduction="none",
|
||||||
|
).sum(dim=-1) # [N]
|
||||||
|
|
||||||
|
# mask padding tokens
|
||||||
|
mask = (labels != ignore_index).float()
|
||||||
|
|
||||||
|
return (kl * mask).sum() / mask.sum()
|
||||||
|
|
||||||
|
|
||||||
def eaft_loss_func(
|
def eaft_loss_func(
|
||||||
outputs: "torch.Tensor",
|
outputs: "torch.Tensor",
|
||||||
labels: "torch.Tensor",
|
labels: "torch.Tensor",
|
||||||
|
|||||||
Reference in New Issue
Block a user