mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 04:02:49 +08:00
parent
5c71f4630d
commit
de9148930f
@ -6,6 +6,7 @@ peft>=0.7.0
|
|||||||
trl>=0.7.6
|
trl>=0.7.6
|
||||||
gradio>=3.38.0,<4.0.0
|
gradio>=3.38.0,<4.0.0
|
||||||
scipy
|
scipy
|
||||||
|
einops
|
||||||
sentencepiece
|
sentencepiece
|
||||||
protobuf
|
protobuf
|
||||||
tiktoken
|
tiktoken
|
||||||
@ -17,4 +18,3 @@ pydantic
|
|||||||
fastapi
|
fastapi
|
||||||
sse-starlette
|
sse-starlette
|
||||||
matplotlib
|
matplotlib
|
||||||
einops
|
|
||||||
|
@ -118,13 +118,13 @@ class Runner:
|
|||||||
logging_steps=get("train.logging_steps"),
|
logging_steps=get("train.logging_steps"),
|
||||||
save_steps=get("train.save_steps"),
|
save_steps=get("train.save_steps"),
|
||||||
warmup_steps=get("train.warmup_steps"),
|
warmup_steps=get("train.warmup_steps"),
|
||||||
neftune_noise_alpha=get("train.neftune_alpha"),
|
neftune_noise_alpha=get("train.neftune_alpha") or None,
|
||||||
train_on_prompt=get("train.train_on_prompt"),
|
train_on_prompt=get("train.train_on_prompt"),
|
||||||
upcast_layernorm=get("train.upcast_layernorm"),
|
upcast_layernorm=get("train.upcast_layernorm"),
|
||||||
lora_rank=get("train.lora_rank"),
|
lora_rank=get("train.lora_rank"),
|
||||||
lora_dropout=get("train.lora_dropout"),
|
lora_dropout=get("train.lora_dropout"),
|
||||||
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
|
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
|
||||||
additional_target=get("train.additional_target") if get("train.additional_target") else None,
|
additional_target=get("train.additional_target") or None,
|
||||||
create_new_adapter=get("train.create_new_adapter"),
|
create_new_adapter=get("train.create_new_adapter"),
|
||||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
|
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
|
||||||
)
|
)
|
||||||
@ -164,7 +164,6 @@ class Runner:
|
|||||||
|
|
||||||
args = dict(
|
args = dict(
|
||||||
stage="sft",
|
stage="sft",
|
||||||
do_eval=True,
|
|
||||||
model_name_or_path=get("top.model_path"),
|
model_name_or_path=get("top.model_path"),
|
||||||
adapter_name_or_path=adapter_name_or_path,
|
adapter_name_or_path=adapter_name_or_path,
|
||||||
cache_dir=user_config.get("cache_dir", None),
|
cache_dir=user_config.get("cache_dir", None),
|
||||||
@ -187,8 +186,9 @@ class Runner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if get("eval.predict"):
|
if get("eval.predict"):
|
||||||
args.pop("do_eval", None)
|
|
||||||
args["do_predict"] = True
|
args["do_predict"] = True
|
||||||
|
else:
|
||||||
|
args["do_eval"] = True
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user