refactor mm training

Former-commit-id: 3382317e32
This commit is contained in:
hiyouga
2024-08-30 02:14:31 +08:00
parent 98b0c7530c
commit a83756b5e9
32 changed files with 505 additions and 472 deletions

View File

@@ -130,6 +130,9 @@ class CustomKTOTrainer(KTOTrainer):
if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"]
if "image_grid_thw" in batch:
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
if "{}token_type_ids".format(prefix) in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]