Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-07-31 10:42:50 +08:00
[misc] upgrade cli (#7714)
parent 1fd4d14fbb
commit 3ef36d0057
requirements.txt
@@ -3,16 +3,14 @@ datasets>=2.16.0,<=3.5.0
 accelerate>=0.34.0,<=1.6.0
 peft>=0.14.0,<=0.15.1
 trl>=0.8.6,<=0.9.6
-tokenizers>=0.19.0,<=0.21.0
-gradio>=4.38.0,<=5.21.0
-pandas>=2.0.0
+tokenizers>=0.19.0,<=0.21.1
+gradio>=4.38.0,<=5.25.0
 scipy
 einops
 sentencepiece
 tiktoken
 protobuf
 uvicorn
-pydantic
 fastapi
 sse-starlette
 matplotlib>=3.7.0
@@ -21,6 +19,7 @@ packaging
 pyyaml
 numpy<2.0.0
+pydantic<=2.10.6
+pandas>=2.0.0
 av
 librosa
 tyro<0.9.0
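The pin changes are easy to misread in a flat list: tokenizers and gradio get slightly wider upper bounds, while pydantic and pandas move down the file and pydantic gains a <=2.10.6 cap. A small standalone checker (not part of the commit; it uses only the standard library plus `packaging`, which is itself in this requirements file) can verify an environment against the new bounds:

# Sketch: check installed versions against the pins introduced above.
from importlib.metadata import version, PackageNotFoundError
from packaging.specifiers import SpecifierSet

PINS = {
    "tokenizers": ">=0.19.0,<=0.21.1",
    "gradio": ">=4.38.0,<=5.25.0",
    "pydantic": "<=2.10.6",
    "pandas": ">=2.0.0",
}

for name, spec in PINS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed")
        continue
    ok = installed in SpecifierSet(spec)  # string versions are coerced
    print(f"{name} {installed}: {'ok' if ok else f'violates {spec}'}")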
src/llamafactory/cli.py
@@ -15,6 +15,7 @@
 import os
 import subprocess
 import sys
+from copy import deepcopy
 from enum import Enum, unique
 
 from . import launcher
@@ -96,6 +97,13 @@ def main():
        if int(nnodes) > 1:
            print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
 
+        env = deepcopy(os.environ)
+        if is_env_enabled("OPTIM_TORCH", "1"):
+            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
+            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+        # NOTE: DO NOT USE shell=True to avoid security risk
        process = subprocess.run(
            (
                "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
@@ -110,7 +118,9 @@ def main():
                file_name=launcher.__file__,
                args=" ".join(sys.argv[1:]),
            )
-            .split()
+            .split(),
+            env=env,
+            check=True,
        )
        sys.exit(process.returncode)
    else:
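The net effect of this hunk: the CLI now mutates a deep copy of the parent environment (so the DDP tweaks do not leak into the caller's process), passes that copy to the torchrun subprocess, and fails loudly via check=True. For review purposes, the launch path can be read as this self-contained sketch (simplified from the diff: `train.py` and the single-node defaults are placeholders, and the `is_env_enabled` guard is inlined):

import os
import subprocess
from copy import deepcopy

def launch(nnodes: str = "1", node_rank: str = "0", nproc_per_node: str = "1") -> int:
    env = deepcopy(os.environ)  # mutate a copy, not the parent process env
    if os.getenv("OPTIM_TORCH", "1").lower() in ("true", "y", "1"):
        # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
        env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

    command = (
        "torchrun --nnodes {nnodes} --node_rank {node_rank} "
        "--nproc_per_node {nproc_per_node} train.py"
    ).format(nnodes=nnodes, node_rank=node_rank, nproc_per_node=nproc_per_node)
    # list argv + shell=False (the default) avoids shell injection;
    # check=True raises CalledProcessError on a non-zero exit code
    process = subprocess.run(command.split(), env=env, check=True)
    return process.returncode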
src/llamafactory/extras/constants.py
@@ -727,23 +727,23 @@ register_model_group(
        },
        "GLM-4-9B-Chat-0414": {
            DownloadSource.DEFAULT: "THUDM/GLM-4-9B-Chat-0414",
-            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-9B-Chat-0414" ,
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-9B-Chat-0414",
        },
        "GLM-4-32B-0414": {
            DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
-            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414" ,
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
        },
        "GLM-4-32B-Chat-0414": {
            DownloadSource.DEFAULT: "THUDM/GLM-4-32B-Chat-0414",
-            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-Chat-0414" ,
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-Chat-0414",
        },
        "GLM-4-Z1-9B-Chat-0414": {
            DownloadSource.DEFAULT: "THUDM/GLM-4-Z1-9B-0414",
-            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-Z1-9B-0414" ,
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-Z1-9B-0414",
        },
        "GLM-4-Z1-32B-Chat-0414": {
            DownloadSource.DEFAULT: "THUDM/GLM-4-Z1-32B-0414",
-            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-Z1-32B-0414" ,
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-Z1-32B-0414",
        },
    },
    template="glm4",
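Worth noting for review: the stray space sat between the closing quote and the comma, outside the string literal, so this hunk is purely cosmetic; the ModelScope repository IDs were already correct. For readers unfamiliar with the registry, a stripped-down sketch of how such a source-keyed mapping is typically consumed (assumed structure for illustration, not the project's actual resolver):

from enum import Enum

class DownloadSource(str, Enum):
    DEFAULT = "hf"
    MODELSCOPE = "ms"

MODELS = {
    "GLM-4-9B-Chat-0414": {
        DownloadSource.DEFAULT: "THUDM/GLM-4-9B-Chat-0414",
        DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-9B-Chat-0414",
    },
}

def resolve(name: str, use_modelscope: bool) -> str:
    # prefer the ModelScope mirror when requested and available
    paths = MODELS[name]
    if use_modelscope and DownloadSource.MODELSCOPE in paths:
        return paths[DownloadSource.MODELSCOPE]
    return paths[DownloadSource.DEFAULT]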
src/llamafactory/hparams/parser.py
@@ -390,8 +390,10 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
 def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
 
+    # Setup logging
     _set_transformers_logging()
 
+    # Check arguments
     if model_args.infer_backend == "vllm":
         if finetuning_args.stage != "sft":
             raise ValueError("vLLM engine only supports auto-regressive models.")
@@ -408,6 +410,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
+    # Post-process model arguments
     if model_args.export_dir is not None and model_args.export_device == "cpu":
         model_args.device_map = {"": torch.device("cpu")}
     if data_args.cutoff_len != DataArguments().cutoff_len:  # override cutoff_len if it is not default
@@ -421,8 +424,10 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
 def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
 
+    # Setup logging
     _set_transformers_logging()
 
+    # Check arguments
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
 
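These hunks add only section comments; behavior is unchanged. As the signatures show, both parsers accept either a dict of options or a list of CLI-style strings. A hedged usage sketch (the import path follows the package layout implied by the diff, and the option names are illustrative, not exhaustive):

from llamafactory.hparams import get_infer_args

model_args, data_args, finetuning_args, generating_args = get_infer_args({
    "model_name_or_path": "THUDM/GLM-4-9B-Chat-0414",
    "template": "glm4",
    "infer_backend": "huggingface",  # "vllm" would trigger the checks above
})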
src/llamafactory/model/patcher.py
@@ -96,6 +96,7 @@ def patch_config(
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
 
     if is_torch_npu_available():
+        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
         torch.npu.set_compile_mode(jit_compile=is_env_enabled("JIT_COMPILE"))
 
     configure_attn_implementation(config, model_args, is_trainable)
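Both the CLI hunk and this one gate behavior on is_env_enabled. A plausible implementation, assumed from the two observed call sites (a variable name plus an optional string default, returning a bool):

import os

def is_env_enabled(env_var: str, default: str = "0") -> bool:
    """Return True when the environment variable is set to a truthy string."""
    return os.getenv(env_var, default).strip().lower() in ("true", "y", "1")

Under that reading, OPTIM_TORCH defaults to on ("1") in the CLI, while JIT_COMPILE defaults to off here, i.e. NPU JIT compilation stays disabled unless the user opts in.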
src/llamafactory/webui/runner.py
@@ -368,6 +368,7 @@ class Runner:
         if args.get("deepspeed", None) is not None:
             env["FORCE_TORCHRUN"] = "1"
 
+        # NOTE: DO NOT USE shell=True to avoid security risk
         self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
         yield from self.monitor()
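A quick illustration (not from the commit) of what the NOTE guards against: with a list argv and the default shell=False, hostile input stays a single literal argument instead of being parsed by a shell.

from subprocess import Popen

user_arg = "config.yaml; rm -rf ~"  # attacker-controlled string
proc = Popen(["echo", user_arg])    # safe: echoes the literal string
proc.wait()
# Unsafe pattern: Popen("echo " + user_arg, shell=True) would let the
# shell execute "rm -rf ~" as a second command.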