mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
fix webui
Former-commit-id: b240b6792fdb734dd77ed54861fdde059feb1855
This commit is contained in:
parent
64bf750a74
commit
9ef9cb316b
@ -65,6 +65,8 @@ def get_dataset(
|
|||||||
max_samples_temp = min(len(dataset), max_samples)
|
max_samples_temp = min(len(dataset), max_samples)
|
||||||
dataset = dataset.select(range(max_samples_temp))
|
dataset = dataset.select(range(max_samples_temp))
|
||||||
|
|
||||||
|
# TODO: adapt to the sharegpt format
|
||||||
|
|
||||||
for column_name in ["prompt", "query", "response", "history"]: # align datasets
|
for column_name in ["prompt", "query", "response", "history"]: # align datasets
|
||||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||||
|
@ -138,13 +138,14 @@ class Runner:
|
|||||||
lora_rank=lora_rank,
|
lora_rank=lora_rank,
|
||||||
lora_dropout=lora_dropout,
|
lora_dropout=lora_dropout,
|
||||||
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
||||||
resume_lora_training=(
|
resume_lora_training=resume_lora_training,
|
||||||
False if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] else resume_lora_training
|
|
||||||
),
|
|
||||||
output_dir=output_dir
|
output_dir=output_dir
|
||||||
)
|
)
|
||||||
args[compute_type] = True
|
args[compute_type] = True
|
||||||
|
|
||||||
|
if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] and args["quantization_bit"] is None:
|
||||||
|
args["resume_lora_training"] = False
|
||||||
|
|
||||||
if args["quantization_bit"] is not None:
|
if args["quantization_bit"] is not None:
|
||||||
args["upcast_layernorm"] = True
|
args["upcast_layernorm"] = True
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user