mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-12 17:10:36 +08:00
[feature] add support for EAFT loss (#9720)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
40
examples/extras/eaft/qwen25_05b_eaft_full.yaml
Normal file
40
examples/extras/eaft/qwen25_05b_eaft_full.yaml
Normal file
@@ -0,0 +1,40 @@
|
||||
### model
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
use_eaft_loss: true
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: qwen
|
||||
cutoff_len: 2048
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: qwen2.5-0_5b/full/sft_eaft
|
||||
logging_steps: 1
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-5
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
|
||||
@@ -490,6 +490,14 @@ class FinetuningArguments(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the DFT loss."},
|
||||
)
|
||||
use_eaft_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the EAFT loss."},
|
||||
)
|
||||
eaft_alpha: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The alpha parameter for EAFT loss to control the power of adaptive weight."},
|
||||
)
|
||||
freeze_vision_tower: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
|
||||
|
||||
@@ -87,6 +87,15 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
|
||||
self.compute_loss_func = dft_loss_func
|
||||
|
||||
|
||||
elif finetuning_args.use_eaft_loss:
|
||||
from ..trainer_utils import eaft_loss_func
|
||||
|
||||
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
|
||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||
)
|
||||
|
||||
|
||||
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||
verify_fp8_status(self.accelerator, training_args)
|
||||
|
||||
|
||||
@@ -679,6 +679,61 @@ def _dft_cross_entropy(
|
||||
return loss
|
||||
|
||||
|
||||
def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
|
||||
logits = outputs.get("logits")
|
||||
if logits is None:
|
||||
return outputs.get("loss", torch.tensor(0.0))
|
||||
|
||||
logits = logits.float()
|
||||
vocab_size = logits.size(-1)
|
||||
labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
logits = logits.view(-1, vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
shift_labels = shift_labels.to(logits.device)
|
||||
|
||||
loss = _eaft_cross_entropy(logits, shift_labels, num_items_in_batch, alpha)
|
||||
return loss
|
||||
|
||||
|
||||
def _eaft_cross_entropy(
|
||||
source: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
||||
alpha: float = 1.0,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||
valid_mask = target != ignore_index
|
||||
if not valid_mask.any():
|
||||
return torch.tensor(0.0, device=source.device, dtype=source.dtype)
|
||||
|
||||
valid_losses = per_token_loss[valid_mask]
|
||||
|
||||
with torch.no_grad():
|
||||
source_detached = source[valid_mask].detach()
|
||||
|
||||
topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
|
||||
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
|
||||
log_probs_topk = topk_val - logsumexp_topk
|
||||
probs_topk = torch.exp(log_probs_topk)
|
||||
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
|
||||
|
||||
entropy_term = entropy_approx / 3.0
|
||||
adaptive_weight = torch.pow(entropy_term, alpha)
|
||||
|
||||
weighted_losses = valid_losses * adaptive_weight
|
||||
|
||||
if num_items_in_batch is not None:
|
||||
total_loss = weighted_losses.sum()
|
||||
if torch.is_tensor(num_items_in_batch):
|
||||
num_items_in_batch = num_items_in_batch.to(total_loss.device)
|
||||
loss = total_loss / num_items_in_batch
|
||||
else:
|
||||
loss = weighted_losses.mean()
|
||||
return loss
|
||||
|
||||
|
||||
def nested_detach(
|
||||
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
|
||||
clone: bool = False,
|
||||
|
||||
Reference in New Issue
Block a user