diff --git a/examples/README.md b/examples/README.md
index 1898f3a2..1e4cd2df 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -165,6 +165,14 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
 FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
 ```
 
+#### Elastic and Fault-Tolerant Supervised Fine-Tuning on Multiple Nodes
+
+To launch an elastic job that tolerates up to `MAX_RESTARTS` restarts on failure, run the following command on at least `MIN_NNODES` and at most `MAX_NNODES` nodes. `RDZV_ID` should be set to a unique job ID shared by all nodes participating in the job. See also [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html).
+
+```bash
+FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+```
+
 #### Multimodal Supervised Fine-Tuning
 
 ```bash
diff --git a/examples/README_zh.md b/examples/README_zh.md
index 8e6c6b64..1cb49a7d 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -106,6 +106,14 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
 FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 ```
 
+#### Elastic and Fault-Tolerant Supervised Fine-Tuning on Multiple Nodes
+
+To launch a multi-node supervised fine-tuning job with elastic nodes and fault tolerance, run the following command on every node. The number of nodes may vary within the range `MIN_NNODES:MAX_NNODES`, and at most `MAX_RESTARTS` restarts are allowed on failure. `RDZV_ID` should be set to a unique job ID shared by all nodes participating in the job. For more information, see the official [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html) documentation.
+
+```bash
+FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+```
+
 #### Use DeepSpeed ZeRO-3 to Evenly Distribute GPU Memory
 
 ```bash
diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index 6eb61702..78d04d5d 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -18,6 +18,8 @@ import sys
 from copy import deepcopy
 from functools import partial
 
+from .hparams import get_train_args
+
 
 USAGE = (
     "-" * 70
@@ -76,7 +78,11 @@ def main():
     command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
+        max_restarts = os.getenv("MAX_RESTARTS", "0")
+        rdzv_id = os.getenv("RDZV_ID")
         nnodes = os.getenv("NNODES", "1")
+        min_nnodes = os.getenv("MIN_NNODES")
+        max_nnodes = os.getenv("MAX_NNODES")
         node_rank = os.getenv("NODE_RANK", "0")
         nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
         master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
@@ -91,25 +97,51 @@ def main():
             env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
             env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
 
-        # NOTE: DO NOT USE shell=True to avoid security risk
-        process = subprocess.run(
-            (
-                "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
-                "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
+        if rdzv_id is not None:
+            # launch an elastic job with fault-tolerant support when possible
+            # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
+            rdzv_nnodes = nnodes
+            # use an elastic number of nodes if MIN_NNODES and MAX_NNODES are set
+            if min_nnodes is not None and max_nnodes is not None:
+                rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
+            cmd = [
"torchrun", + "--nnodes", + rdzv_nnodes, + "--nproc-per-node", + nproc_per_node, + "--rdzv-id", + rdzv_id, + "--rdzv-backend", + "c10d", + "--rdzv-endpoint", + f"{master_addr}:{master_port}", + "--max-restarts", + max_restarts, + launcher.__file__, + *sys.argv[1:], + ] + process = subprocess.run(cmd, env=env, check=True) + else: + # NOTE: DO NOT USE shell=True to avoid security risk + process = subprocess.run( + ( + "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " + "--master_addr {master_addr} --master_port {master_port} {file_name} {args}" + ) + .format( + nnodes=nnodes, + node_rank=node_rank, + nproc_per_node=nproc_per_node, + master_addr=master_addr, + master_port=master_port, + file_name=launcher.__file__, + args=" ".join(sys.argv[1:]), + ) + .split(), + env=env, + check=True, ) - .format( - nnodes=nnodes, - node_rank=node_rank, - nproc_per_node=nproc_per_node, - master_addr=master_addr, - master_port=master_port, - file_name=launcher.__file__, - args=" ".join(sys.argv[1:]), - ) - .split(), - env=env, - check=True, - ) sys.exit(process.returncode) elif command in COMMAND_MAP: COMMAND_MAP[command]()