mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-16 00:28:10 +08:00
fix webui
Former-commit-id: a0fe43aac968d9f6ca4724b8d718b45c03063b91
This commit is contained in:
parent
e1c9dcea93
commit
8659084ab0
@ -65,6 +65,8 @@ def get_dataset(
|
|||||||
max_samples_temp = min(len(dataset), max_samples)
|
max_samples_temp = min(len(dataset), max_samples)
|
||||||
dataset = dataset.select(range(max_samples_temp))
|
dataset = dataset.select(range(max_samples_temp))
|
||||||
|
|
||||||
|
# TODO: adapt to the sharegpt format
|
||||||
|
|
||||||
for column_name in ["prompt", "query", "response", "history"]: # align datasets
|
for column_name in ["prompt", "query", "response", "history"]: # align datasets
|
||||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||||
|
@ -138,13 +138,14 @@ class Runner:
|
|||||||
lora_rank=lora_rank,
|
lora_rank=lora_rank,
|
||||||
lora_dropout=lora_dropout,
|
lora_dropout=lora_dropout,
|
||||||
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
||||||
resume_lora_training=(
|
resume_lora_training=resume_lora_training,
|
||||||
False if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] else resume_lora_training
|
|
||||||
),
|
|
||||||
output_dir=output_dir
|
output_dir=output_dir
|
||||||
)
|
)
|
||||||
args[compute_type] = True
|
args[compute_type] = True
|
||||||
|
|
||||||
|
if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] and args["quantization_bit"] is None:
|
||||||
|
args["resume_lora_training"] = False
|
||||||
|
|
||||||
if args["quantization_bit"] is not None:
|
if args["quantization_bit"] is not None:
|
||||||
args["upcast_layernorm"] = True
|
args["upcast_layernorm"] = True
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user