mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 18:02:19 +08:00
[feature] add support for dft loss (#8917)

This commit is contained in:
    parent 936f4fd78e
    commit 1ada15981a
examples/extras/dft/qwen2_full_sft.yaml (new file, 43 lines)
@@ -0,0 +1,43 @@
### model
model_name_or_path: Qwen/Qwen2-1.5B-Instruct
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: full
use_dft_loss: true

### dataset
dataset: identity,alpaca_en_demo
template: qwen
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/qwen2-1_5b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
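The only DFT-specific setting in this example is use_dft_loss: true; everything else is a standard full-parameter SFT recipe for Qwen2-1.5B-Instruct. Configs under examples/ are typically launched with the project's CLI, e.g. llamafactory-cli train examples/extras/dft/qwen2_full_sft.yaml (an assumption based on the repository's usual workflow, not part of this commit). A minimal sketch, assuming PyYAML is installed and the file is saved at the path above, that loads the config and confirms the new flag:

# Minimal sketch (assumption: PyYAML installed, config at the path added in this commit).
import yaml

with open("examples/extras/dft/qwen2_full_sft.yaml") as f:
    config = yaml.safe_load(f)

assert config["use_dft_loss"] is True
print(config["stage"], config["finetuning_type"])  # sft full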
@@ -428,6 +428,10 @@ class FinetuningArguments(
        default=False,
        metadata={"help": "Whether or not to use the Muon optimizer."},
    )
    use_dft_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use the DFT loss."},
    )
    freeze_vision_tower: bool = field(
        default=True,
        metadata={"help": "Whether or not to freeze the vision tower in MLLM training."},
    )
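The new use_dft_loss field is what lets the YAML key above land on FinetuningArguments, which LLaMA-Factory populates through transformers' HfArgumentParser. A minimal sketch of that mapping, using a stand-in dataclass rather than the real FinetuningArguments:

# Sketch: how a boolean config key fills a dataclass field via HfArgumentParser.
# DemoArguments is an illustrative stand-in, not the real FinetuningArguments class.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DemoArguments:
    use_dft_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use the DFT loss."},
    )


parser = HfArgumentParser(DemoArguments)
(demo_args,) = parser.parse_dict({"use_dft_loss": True})
print(demo_args.use_dft_loss)  # True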
@@ -78,6 +78,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        if finetuning_args.use_dft_loss:
            from ..trainer_utils import dft_loss_func

            self.compute_loss_func = dft_loss_func

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
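Assigning self.compute_loss_func hooks into the loss hook that recent transformers.Trainer releases (roughly 4.46 onward; stated as an assumption about the installed version) consult inside compute_loss: when it is not None, the trainer calls it as compute_loss_func(outputs, labels, num_items_in_batch=...) instead of taking the model's built-in loss, which is why dft_loss_func below uses exactly that signature. A minimal stand-alone sketch of a function satisfying the same contract:

# Sketch of the loss-hook contract assumed above; my_loss_func is illustrative only.
import torch


def my_loss_func(outputs, labels, num_items_in_batch=None):
    # Must accept (outputs, labels, num_items_in_batch) and return a scalar loss tensor.
    logits = outputs["logits"][..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)), shift_labels.view(-1), ignore_index=-100
    )


# Inside the trainer (simplified):
# loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)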
@@ -631,6 +631,51 @@ def get_batch_logps(
    return logps, valid_length


def dft_loss_func(outputs, labels, num_items_in_batch=None):
    logits = outputs.get("logits")
    if logits is None:
        return outputs.get("loss", torch.tensor(0.0))

    logits = logits.float()
    vocab_size = logits.size(-1)
    labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
    shift_labels = labels[..., 1:].contiguous()
    logits = logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    shift_labels = shift_labels.to(logits.device)

    loss = _dft_cross_entropy(logits, shift_labels, num_items_in_batch)
    return loss


def _dft_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
) -> torch.Tensor:
    per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
    valid_mask = target != ignore_index
    if not valid_mask.any():
        return torch.tensor(0.0, device=source.device, dtype=source.dtype)

    valid_losses = per_token_loss[valid_mask]

    with torch.no_grad():
        target_probs = torch.exp(-valid_losses)

    weighted_losses = valid_losses * target_probs

    if num_items_in_batch is not None:
        total_loss = weighted_losses.sum()
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(total_loss.device)
        loss = total_loss / num_items_in_batch
    else:
        loss = weighted_losses.mean()
    return loss


def nested_detach(
    tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
    clone: bool = False,
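dft_loss_func implements a reweighted token-level cross-entropy: for each supervised token, the standard CE loss is multiplied by the model's probability of the gold token, obtained as exp(-CE) under torch.no_grad() so the weight is detached and gradients flow only through the CE term. Tokens the model already predicts with high probability keep weights near 1, while low-probability tokens are strongly down-weighted. A self-contained numeric sketch of the same weighting (illustration only, not repo code):

# Numeric sketch of the DFT weighting: scale each token's CE by the detached
# probability of its gold token, then average over valid (non -100) positions.
import torch

torch.manual_seed(0)
logits = torch.randn(4, 8)              # 4 token positions, vocabulary of 8
labels = torch.tensor([3, 5, -100, 1])  # -100 marks an ignored position

ce = torch.nn.functional.cross_entropy(logits, labels, ignore_index=-100, reduction="none")
valid = labels != -100
ce_valid = ce[valid]
with torch.no_grad():
    p_gold = torch.exp(-ce_valid)       # probability the model assigns to the gold token
print("plain CE:", ce_valid.mean().item(), "DFT-weighted:", (ce_valid * p_gold).mean().item())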