mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[trainer] fix llama3.2 vision kto train (#6904)
Former-commit-id: 1563e89adc8988fc6e4250634a3f1e385979b0e5
This commit is contained in:
		
							parent
							
								
									2581cc844b
								
							
						
					
					
						commit
						0c0cdc26bc
					
				@ -285,6 +285,8 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
 | 
			
		||||
        batch["kl_input_ids"] = kl_batch["input_ids"]
 | 
			
		||||
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
 | 
			
		||||
        batch["kl_labels"] = kl_batch["labels"]
 | 
			
		||||
        if "cross_attention_mask" in kl_batch:  # for mllama inputs.
 | 
			
		||||
            batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]
 | 
			
		||||
        if "token_type_ids" in kl_batch:
 | 
			
		||||
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -156,6 +156,15 @@ class CustomKTOTrainer(KTOTrainer):
 | 
			
		||||
        if "image_grid_thw" in batch:
 | 
			
		||||
            model_inputs["image_grid_thw"] = batch["image_grid_thw"]
 | 
			
		||||
 | 
			
		||||
        if "aspect_ratio_ids" in batch:
 | 
			
		||||
            model_inputs["aspect_ratio_ids"] = batch["aspect_ratio_ids"]
 | 
			
		||||
 | 
			
		||||
        if "aspect_ratio_mask" in batch:
 | 
			
		||||
            model_inputs["aspect_ratio_mask"] = batch["aspect_ratio_mask"]
 | 
			
		||||
 | 
			
		||||
        if f"{prefix}cross_attention_mask" in batch:
 | 
			
		||||
            model_inputs["cross_attention_mask"] = batch[f"{prefix}cross_attention_mask"]
 | 
			
		||||
 | 
			
		||||
        logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
 | 
			
		||||
        logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
 | 
			
		||||
        return logits, logps, logps / valid_length
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user