Former-commit-id: 3ae735ffe8acbe1df05324daa9e18ec37f33d594
This commit is contained in:
hiyouga 2024-01-08 21:42:25 +08:00
parent 5c71f4630d
commit de9148930f
2 changed files with 5 additions and 5 deletions

View File

@ -6,6 +6,7 @@ peft>=0.7.0
trl>=0.7.6 trl>=0.7.6
gradio>=3.38.0,<4.0.0 gradio>=3.38.0,<4.0.0
scipy scipy
einops
sentencepiece sentencepiece
protobuf protobuf
tiktoken tiktoken
@ -17,4 +18,3 @@ pydantic
fastapi fastapi
sse-starlette sse-starlette
matplotlib matplotlib
einops

View File

@ -118,13 +118,13 @@ class Runner:
logging_steps=get("train.logging_steps"), logging_steps=get("train.logging_steps"),
save_steps=get("train.save_steps"), save_steps=get("train.save_steps"),
warmup_steps=get("train.warmup_steps"), warmup_steps=get("train.warmup_steps"),
neftune_noise_alpha=get("train.neftune_alpha"), neftune_noise_alpha=get("train.neftune_alpha") or None,
train_on_prompt=get("train.train_on_prompt"), train_on_prompt=get("train.train_on_prompt"),
upcast_layernorm=get("train.upcast_layernorm"), upcast_layernorm=get("train.upcast_layernorm"),
lora_rank=get("train.lora_rank"), lora_rank=get("train.lora_rank"),
lora_dropout=get("train.lora_dropout"), lora_dropout=get("train.lora_dropout"),
lora_target=get("train.lora_target") or get_module(get("top.model_name")), lora_target=get("train.lora_target") or get_module(get("top.model_name")),
additional_target=get("train.additional_target") if get("train.additional_target") else None, additional_target=get("train.additional_target") or None,
create_new_adapter=get("train.create_new_adapter"), create_new_adapter=get("train.create_new_adapter"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")) output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
) )
@ -164,7 +164,6 @@ class Runner:
args = dict( args = dict(
stage="sft", stage="sft",
do_eval=True,
model_name_or_path=get("top.model_path"), model_name_or_path=get("top.model_path"),
adapter_name_or_path=adapter_name_or_path, adapter_name_or_path=adapter_name_or_path,
cache_dir=user_config.get("cache_dir", None), cache_dir=user_config.get("cache_dir", None),
@ -187,8 +186,9 @@ class Runner:
) )
if get("eval.predict"): if get("eval.predict"):
args.pop("do_eval", None)
args["do_predict"] = True args["do_predict"] = True
else:
args["do_eval"] = True
return args return args