Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-05 21:22:50 +08:00)

Merge pull request #741 from hiyouga/feature-addDatasetCheck

Feature add dataset check

Former-commit-id: 701a9d60cb030a18238b5426477752d41657f85d
Commit: 5e1a8b1c74
@@ -285,7 +285,7 @@ register_template(
     "Always answer as helpfully as possible, while being safe. "
     "Your answers should not include any harmful, unethical, "
     "racist, sexist, toxic, dangerous, or illegal content. "
-    "Please ensure that your responses are socially unbiased and positive in nature.\n"
+    "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
     "If a question does not make any sense, or is not factually coherent, "
     "explain why instead of answering something not correct. "
     "If you don't know the answer to a question, please don't share false information."
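A note on this hunk: the adjacent string literals concatenate into a single system prompt, so the only effect of the change is to insert a blank line after "positive in nature." A minimal sketch (not repository code) of the resulting string:

    # Adjacent literals merge into one string; "\n\n" now splits the prompt
    # into two paragraphs, matching Meta's reference Llama-2 system prompt.
    system_prompt = (
        "Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, "
        "racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information."
    )
    print(system_prompt)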
@@ -186,6 +186,18 @@ def get_train_args(
     else:
         model_args.compute_dtype = torch.float16
 
+    # transfer training stage to dataset stage
+    dataset_stage = general_args.stage
+    if general_args.stage == "ppo":
+        dataset_stage = "sft"
+    elif general_args.stage == "dpo":
+        dataset_stage = "rm"
+
+    for dataset_attr in data_args.dataset_list:
+        if dataset_attr.stage and dataset_attr.stage != dataset_stage:
+            raise ValueError("Dataset {} is not supported for the stage {}"
+                             .format(dataset_attr.dataset_name, general_args.stage))
+
     model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
 
     # Log on each process the small summary:
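What the new block does: a PPO run consumes SFT-formatted data and a DPO run consumes RM-formatted pairwise data, so the training stage is first mapped to the dataset stage it expects, and any dataset whose declared stage disagrees is rejected up front. A standalone sketch of the same check, assuming a plain dict in place of the repo's dataset attributes (check_dataset_stage and its arguments are hypothetical names):

    def check_dataset_stage(training_stage, dataset_stages):
        # PPO trains on "sft" data, DPO on "rm" (pairwise) data;
        # every other stage expects datasets tagged with its own name.
        expected = {"ppo": "sft", "dpo": "rm"}.get(training_stage, training_stage)
        for name, stage in dataset_stages.items():
            if stage and stage != expected:  # datasets with no stage tag are not checked
                raise ValueError("Dataset {} is not supported for the stage {}"
                                 .format(name, training_stage))

    check_dataset_stage("dpo", {"comparison_gpt4_en": "rm"})  # passes
    # check_dataset_stage("dpo", {"alpaca_en": "sft"})        # would raise ValueError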