add dataset stage check

Former-commit-id: 5c719a7ce988339d034a653456da9742dc2cec7c
This commit is contained in:
codemayq 2023-08-30 16:23:08 +08:00
parent f7658db1b6
commit f641075789

View File

@ -186,6 +186,18 @@ def get_train_args(
else:
model_args.compute_dtype = torch.float16
# transfer training stage to dataset stage
dataset_stage = general_args.stage
if general_args.stage == "ppo":
dataset_stage = "sft"
elif general_args.stage == "dpo":
dataset_stage = "rm"
for dataset_attr in data_args.dataset_list:
if dataset_attr.stage and dataset_attr.stage != dataset_stage:
raise ValueError("Dataset {} is not supported for the stage {}"
.format(dataset_attr.dataset_name, general_args.stage))
model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
# Log on each process the small summary: