Mirror of https://github.com/hiyouga/LLaMA-Factory.git
add dataset stage check
Former-commit-id: f7fdc088d49564f7d436fd445e7e1987a9a00a0b
parent 00c3cd5454
commit c955d9267c
@@ -186,6 +186,18 @@ def get_train_args(
     else:
         model_args.compute_dtype = torch.float16
 
+    # transfer training stage to dataset stage
+    dataset_stage = general_args.stage
+    if general_args.stage == "ppo":
+        dataset_stage = "sft"
+    elif general_args.stage == "dpo":
+        dataset_stage = "rm"
+
+    for dataset_attr in data_args.dataset_list:
+        if dataset_attr.stage and dataset_attr.stage != dataset_stage:
+            raise ValueError("Dataset {} is not supported for the stage {}"
+                             .format(dataset_attr.dataset_name, general_args.stage))
+
     model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
 
     # Log on each process the small summary:
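In effect, the added check maps a PPO run onto SFT-formatted data and a DPO run onto reward-model (RM) formatted data, then rejects any dataset whose declared stage disagrees with the stage required by the run. Below is a minimal, self-contained sketch of that logic; DatasetAttr, check_dataset_stages, and the dataset names are hypothetical stand-ins for illustration, not the project's actual classes or dataset entries.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class DatasetAttr:
    """Hypothetical stand-in for the dataset attribute object parsed from the dataset registry."""
    dataset_name: str
    stage: Optional[str] = None  # the stage this dataset's format targets, if declared


def check_dataset_stages(dataset_list: List[DatasetAttr], training_stage: str) -> None:
    # PPO consumes SFT-formatted prompts and DPO consumes RM-formatted preference pairs,
    # so both are mapped onto the dataset stage they actually need before checking.
    dataset_stage = {"ppo": "sft", "dpo": "rm"}.get(training_stage, training_stage)
    for dataset_attr in dataset_list:
        if dataset_attr.stage and dataset_attr.stage != dataset_stage:
            raise ValueError("Dataset {} is not supported for the stage {}"
                             .format(dataset_attr.dataset_name, training_stage))


# Example: a DPO run accepts reward-model style data but rejects an SFT-only dataset.
check_dataset_stages([DatasetAttr("pairwise_data", stage="rm")], "dpo")    # passes
# check_dataset_stages([DatasetAttr("chat_data", stage="sft")], "dpo")     # raises ValueError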