[launcher] Add elastic and fault-tolerant training support (#8286)
Signed-off-by: Butui Hu <hot123tea123@gmail.com>
parent 5308424705
commit 83688b0b4d
README.md
@@ -165,6 +165,14 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```

### Elastic and Fault-Tolerant Supervised Fine-Tuning on Multiple Nodes

To launch an elastic job that tolerates up to `MAX_RESTARTS` restarts, run the following command on at least `MIN_NNODES` and at most `MAX_NNODES` nodes. `RDZV_ID` should be set to a unique job ID shared by all nodes participating in the job. See also [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html).

```bash
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```
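
For reference, this is roughly the `torchrun` invocation that `llamafactory-cli` assembles from these variables (the flag values here are illustrative, and the launcher script path depends on where the package is installed; see the `cli.py` change below):

```bash
torchrun --nnodes 1:3 --nproc-per-node 8 \
    --rdzv-id llamafactory --rdzv-backend c10d --rdzv-endpoint 192.168.0.1:29500 \
    --max-restarts 3 src/llamafactory/launcher.py examples/train_full/llama3_full_sft.yaml
```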
#### Multimodal Supervised Fine-Tuning
```bash
README_zh.md
@@ -106,6 +106,14 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```

### Elastic and Fault-Tolerant Supervised Fine-Tuning on Multiple Nodes

To launch an elastic, fault-tolerant multi-node supervised fine-tuning job, run the following command on every node. The number of nodes may range elastically from `MIN_NNODES` to `MAX_NNODES`, and each node is allowed up to `MAX_RESTARTS` restarts on failure. `RDZV_ID` should be set to a unique job ID shared by all nodes participating in the job. For more details, see the official [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html) documentation.

```bash
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```
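
Because node ranks are assigned by the c10d rendezvous rather than by `NODE_RANK`, a recovered or newly added node can rejoin the same job by running the identical command with the shared `RDZV_ID`; torchrun then re-forms the worker group. A sketch with the same illustrative values as above:

```bash
# Run on a recovered or newly added node; the rendezvous assigns its rank automatically.
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory \
    MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 \
    llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```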
#### Distributing GPU Memory Evenly with DeepSpeed ZeRO-3

```bash
src/llamafactory/cli.py
@@ -18,6 +18,8 @@ import sys
from copy import deepcopy
from functools import partial

from .hparams import get_train_args


USAGE = (
    "-" * 70
@@ -76,7 +78,11 @@ def main():
    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
        # launch distributed training
        max_restarts = os.getenv("MAX_RESTARTS", "0")
        rdzv_id = os.getenv("RDZV_ID")
        nnodes = os.getenv("NNODES", "1")
        min_nnodes = os.getenv("MIN_NNODES")
        max_nnodes = os.getenv("MAX_NNODES")
        node_rank = os.getenv("NODE_RANK", "0")
        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
@@ -91,25 +97,51 @@ def main():
        env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

        if rdzv_id is not None:
            # launch elastic job with fault tolerant support when possible
            # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
            rdzv_nnodes = nnodes
            # elastic number of nodes if MIN_NNODES and MAX_NNODES are set
            if min_nnodes is not None and max_nnodes is not None:
                rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"

            cmd = [
                "torchrun",
                "--nnodes",
                rdzv_nnodes,
                "--nproc-per-node",
                nproc_per_node,
                "--rdzv-id",
                rdzv_id,
                "--rdzv-backend",
                "c10d",
                "--rdzv-endpoint",
                f"{master_addr}:{master_port}",
                "--max-restarts",
                max_restarts,
                launcher.__file__,
                *sys.argv[1:],
            ]
            process = subprocess.run(cmd, env=env, check=True)
        else:
            # NOTE: DO NOT USE shell=True to avoid security risk
            process = subprocess.run(
                (
                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
                )
                .format(
                    nnodes=nnodes,
                    node_rank=node_rank,
                    nproc_per_node=nproc_per_node,
                    master_addr=master_addr,
                    master_port=master_port,
                    file_name=launcher.__file__,
                    args=" ".join(sys.argv[1:]),
                )
                .split(),
                env=env,
                check=True,
            )

        sys.exit(process.returncode)
    elif command in COMMAND_MAP:
        COMMAND_MAP[command]()
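
In short, the presence of `RDZV_ID` in the environment selects the launch path: without it the CLI keeps the original fixed-size `--master_addr`/`--node_rank` invocation, and with it the job goes through the c10d rendezvous. For example, using the same addresses and config paths as the README examples above:

```bash
# Fixed-size launch (RDZV_ID unset): classic master/node-rank torchrun arguments.
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 \
    llamafactory-cli train examples/train_full/llama3_full_sft.yaml

# Elastic launch (RDZV_ID set): rendezvous-based, tolerant of node failures and restarts.
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory \
    MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 \
    llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```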