mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-17 20:30:36 +08:00
refactor mm training
This commit is contained in:
@@ -130,6 +130,9 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
if "pixel_values" in batch:
|
||||
model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
|
||||
if "image_grid_thw" in batch:
|
||||
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
|
||||
|
||||
if "{}token_type_ids".format(prefix) in batch:
|
||||
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user