[v1] support training with fsdp2 (#9773)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
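For orientation, a minimal usage sketch of the launcher this diff touches (illustrative only, not part of the commit; the config path and addresses are placeholders, while the command name and environment variables are the ones read in the code below):

# single process / single device
#   llamafactory-cli sft path/to/config.yaml
# force a torchrun re-launch even on one device
#   FORCE_TORCHRUN=1 llamafactory-cli sft path/to/config.yaml
# multi-node, rank 0 of 2 nodes
#   NNODES=2 NODE_RANK=0 MASTER_ADDR=10.0.0.1 MASTER_PORT=29500 llamafactory-cli sft path/to/config.yaml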
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import sys

from ..extras.env import VERSION, print_env
from copy import deepcopy


USAGE = (
@@ -27,27 +28,97 @@ USAGE = (
    + "-" * 70
)


WELCOME = (
    "-" * 58
    + "\n"
    + f"| Welcome to LLaMA Factory, version {VERSION}"
    + " " * (21 - len(VERSION))
    + "|\n|"
    + " " * 56
    + "|\n"
    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
    + "-" * 58
)
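# Illustrative check (not part of this diff): the banner math works out because the literal
# prefix "| Welcome to LLaMA Factory, version " is 36 characters, so each row is
# 36 + 21 + 1 = 58 characters wide, matching the "-" * 58 rules, provided len(VERSION) <= 21:
#   for row in WELCOME.split("\n"):
#       assert len(row) == 58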
_DIST_TRAIN_COMMANDS = ("train", "sft", "dpo", "rm")


def launch():
    from .accelerator.helper import get_device_count
    from .utils.env import find_available_port, is_env_enabled, use_kt, use_ray
    from .utils.logging import get_logger

    logger = get_logger(__name__)

    # NOTE:
    # `llamafactory-cli <command> ...` enters here first.
    # We may re-launch via `torchrun` for distributed training. In that case we must
    # forward `<command>` as argv[1] to the re-executed script, otherwise the script
    # will misinterpret the first user argument (e.g. yaml config) as the command.
    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
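    # Illustrative trace (not part of this diff): `llamafactory-cli sft path/to/config.yaml`
    # arrives as sys.argv == ["<cli>", "sft", "path/to/config.yaml"]; after the pop,
    # command == "sft" and sys.argv == ["<cli>", "path/to/config.yaml"], so the trainer
    # (or the re-launched script) only sees the config path as its first user argument.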

    if command == "sft":  # train command will fallback to sft command
        from .trainers.sft_trainer import run_sft
    if command in _DIST_TRAIN_COMMANDS and (
        is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
    ):
        nnodes = os.getenv("NNODES", "1")
        node_rank = os.getenv("NODE_RANK", "0")
        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
        logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
        if int(nnodes) > 1:
            logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")

        run_sft()
        # elastic launch support
        max_restarts = os.getenv("MAX_RESTARTS", "0")
        rdzv_id = os.getenv("RDZV_ID")
        min_nnodes = os.getenv("MIN_NNODES")
        max_nnodes = os.getenv("MAX_NNODES")

        env = deepcopy(os.environ)
        if is_env_enabled("OPTIM_TORCH", "1"):
            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
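            # Background (not part of this diff): "expandable_segments:True" lets the CUDA
            # caching allocator grow existing segments instead of reserving new fixed-size
            # blocks, reducing fragmentation; TORCH_NCCL_AVOID_RECORD_STREAMS=1 makes NCCL
            # collectives skip recordStream bookkeeping so communication buffers are freed sooner.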

        torchrun_args = [
            "torchrun",
            "--nproc-per-node",
            nproc_per_node,
        ]
        if rdzv_id is not None:
            # launch elastic job with fault tolerant support when possible
            # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
            rdzv_nnodes = nnodes
            # elastic number of nodes if MIN_NNODES and MAX_NNODES are set
            if min_nnodes is not None and max_nnodes is not None:
                rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
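                # Example (not part of this diff): MIN_NNODES=1 MAX_NNODES=4 yields
                # "--nnodes 1:4", letting torchrun scale the job between 1 and 4 nodes.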

            torchrun_args.extend(
                [
                    "--nnodes",
                    rdzv_nnodes,
                    "--rdzv-id",
                    rdzv_id,
                    "--rdzv-backend",
                    "c10d",
                    "--rdzv-endpoint",
                    f"{master_addr}:{master_port}",
                    "--max-restarts",
                    max_restarts,
                ]
            )
        else:
            # NOTE: DO NOT USE shell=True to avoid security risk
            torchrun_args.extend(
                [
                    "--nnodes",
                    nnodes,
                    "--node_rank",
                    node_rank,
                    "--master_addr",
                    master_addr,
                    "--master_port",
                    master_port,
                ]
            )

        script_args = [__file__, command] + sys.argv[1:]
        process = subprocess.run(
            torchrun_args + script_args,
            env=env,
            check=True,
        )

        sys.exit(process.returncode)
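        # Illustrative command lines (not part of this diff): with the defaults above, the
        # non-elastic branch re-execs roughly
        #   torchrun --nproc-per-node <device_count> --nnodes 1 --node_rank 0 \
        #       --master_addr 127.0.0.1 --master_port <free port> <this file> sft <config.yaml>
        # and the elastic branch (RDZV_ID set) roughly
        #   torchrun --nproc-per-node <device_count> --nnodes <min>:<max> --rdzv-id <RDZV_ID> \
        #       --rdzv-backend c10d --rdzv-endpoint <addr>:<port> --max-restarts 0 <this file> sft <config.yaml>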

    elif command == "chat":
        from .samplers.cli_sampler import run_chat
@@ -55,17 +126,54 @@ def launch():

        run_chat()

    elif command == "env":
        print_env()
        raise NotImplementedError("Environment information is not implemented yet.")

    elif command == "version":
        print(WELCOME)
        raise NotImplementedError("Version information is not implemented yet.")

    elif command == "help":
        print(USAGE)

    elif command in _DIST_TRAIN_COMMANDS:
        # Single GPU training without torchrun
        if command in ("train", "sft"):
            from llamafactory.v1.trainers.sft_trainer import run_sft

            run_sft()
        elif command == "dpo":
            raise NotImplementedError("DPO trainer is not implemented yet.")
        elif command == "rm":
            raise NotImplementedError("RM trainer is not implemented yet.")

    else:
        print(f"Unknown command: {command}.\n{USAGE}")


def main():
    # sys.argv[1] contains the command (sft/dpo/rm/train), sys.argv[2:] contains the rest args
    command = sys.argv[1] if len(sys.argv) > 1 else "sft"

    # Routing needs the sub-command, but downstream trainers usually expect argv without it.
    if command in _DIST_TRAIN_COMMANDS:
        sys.argv.pop(1)
    else:
        # Backward-compat: if someone runs `torchrun launcher.py config.yaml`,
        # treat it as sft by default.
        if len(sys.argv) > 1 and sys.argv[1].endswith((".yaml", ".yml")):
            command = "sft"
    if command in ("train", "sft"):
        from llamafactory.v1.trainers.sft_trainer import run_sft

        run_sft()
    elif command == "dpo":
        # from llamafactory.v1.trainers.dpo_trainer import run_dpo
        # run_dpo()
        raise NotImplementedError("DPO trainer is not implemented yet.")
    elif command == "rm":
        # from llamafactory.v1.trainers.rm_trainer import run_rm
        # run_rm()
        raise NotImplementedError("RM trainer is not implemented yet.")
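    # Illustrative flow (not part of this diff): when launch() re-execs under torchrun, each
    # worker runs this file with argv ["<this file>", "sft", "path/to/config.yaml"]; main()
    # pops "sft" because it is in _DIST_TRAIN_COMMANDS, so run_sft() again sees only the
    # config path. A direct `torchrun <this file> path/to/config.yaml` skips the pop and
    # falls back to sft via the yaml/yml check above.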


if __name__ == "__main__":
    pass
    main()