diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index a74445a6..c14ae6ec 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -71,28 +71,23 @@ def main():
         export_model()
     elif command == Command.TRAIN:
         if get_device_count() > 1:
-            nnodes = os.environ.get("NNODES", "1")
-            node_rank = os.environ.get("RANK", "0")
-            nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
             master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
             master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
             logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
             subprocess.run(
-                [
-                    "torchrun",
-                    "--nnodes",
-                    nnodes,
-                    "--node_rank",
-                    node_rank,
-                    "--nproc_per_node",
-                    nproc_per_node,
-                    "--master_addr",
-                    master_addr,
-                    "--master_port",
-                    master_port,
-                    launcher.__file__,
-                    *sys.argv[1:],
-                ]
+                (
+                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
+                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
+                ).format(
+                    nnodes=os.environ.get("NNODES", "1"),
+                    node_rank=os.environ.get("RANK", "0"),
+                    nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
+                    master_addr=master_addr,
+                    master_port=master_port,
+                    file_name=launcher.__file__,
+                    args=" ".join(sys.argv[1:]),
+                ),
+                shell=True,
             )
         else:
             run_exp()
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index 36f593ae..6e1facef 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -1,20 +1,17 @@
 import os
-import signal
-import random
 from copy import deepcopy
 from subprocess import Popen, TimeoutExpired
 from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
 
-import psutil
 from transformers.trainer import TRAINING_ARGS_NAME
 
 from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_gpu_or_npu_available, torch_gc, get_device_count
+from ..extras.misc import is_gpu_or_npu_available, torch_gc
 from ..extras.packages import is_gradio_available
 from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
 from .locales import ALERTS
-from .utils import gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
-from .. import launcher
+from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
+
 
 if is_gradio_available():
     import gradio as gr
@@ -41,12 +38,7 @@ class Runner:
     def set_abort(self) -> None:
         self.aborted = True
         if self.trainer is not None:
-            for children in psutil.Process(self.trainer.pid).children():  # abort the child process
-                grand_children = children.children()
-                if len(grand_children) > 0:
-                    for grand_child in grand_children:
-                        os.kill(grand_child.pid, signal.SIGABRT)
-                os.kill(children.pid, signal.SIGABRT)
+            abort_leaf_process(self.trainer.pid)
 
     def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
@@ -285,30 +277,7 @@ class Runner:
         args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
         env = deepcopy(os.environ)
         env["LLAMABOARD_ENABLED"] = "1"
-        if get_device_count() > 1:
-            nnodes = os.environ.get("NNODES", "1")
-            node_rank = os.environ.get("RANK", "0")
-            nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
-            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
-            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
-
-            self.trainer = Popen([
-                "torchrun",
-                "--nnodes",
-                nnodes,
-                "--node_rank",
-                node_rank,
-                "--nproc_per_node",
-                nproc_per_node,
-                "--master_addr",
-                master_addr,
-                "--master_port",
-                master_port,
-                launcher.__file__,
-                save_cmd(args)
-            ], env=env, shell=True)
-        else:
-            self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
+        self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
         yield from self.monitor()
 
     def preview_train(self, data):
diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py
index 0446cb47..fc258806 100644
--- a/src/llamafactory/webui/utils.py
+++ b/src/llamafactory/webui/utils.py
@@ -1,8 +1,10 @@
 import json
 import os
+import signal
 from datetime import datetime
 from typing import Any, Dict, List, Optional, Tuple
 
+import psutil
 from transformers.trainer_utils import get_last_checkpoint
 from yaml import safe_dump, safe_load
 
@@ -17,6 +19,18 @@ if is_gradio_available():
     import gradio as gr
 
 
+def abort_leaf_process(pid: int) -> None:
+    r"""
+    Aborts the leaf processes.
+    """
+    children = psutil.Process(pid).children()
+    if children:
+        for child in children:
+            abort_leaf_process(child.pid)
+    else:
+        os.kill(pid, signal.SIGABRT)
+
+
 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
     r"""
     Judges if the quantization is available in this finetuning type.
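
Why both call sites now build a single command string: on POSIX, subprocess treats a sequence passed together with shell=True as "the first element is the shell command, the rest become the shell's positional parameters", so the Popen([...], env=env, shell=True) call removed from runner.py ran a bare torchrun and silently dropped every flag after it. A minimal standalone sketch of that pitfall (not part of the patch):

    import subprocess

    # List + shell=True: only "echo" reaches the shell as a command;
    # "--nnodes" and "1" become the shell's $0 and $1, so only a blank
    # line is printed.
    subprocess.run(["echo", "--nnodes", "1"], shell=True)

    # A single string + shell=True executes the whole command line.
    subprocess.run("echo --nnodes 1", shell=True)  # prints: --nnodes 1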
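The new abort_leaf_process helper recurses down the psutil process tree and sends SIGABRT only to leaf processes, letting intermediate wrappers (the shell spawned by Popen(..., shell=True), torchrun, and so on) exit once their workers abort. A minimal usage sketch, assuming the patch is applied (the sleep command is purely illustrative):

    import subprocess
    import time

    from llamafactory.webui.utils import abort_leaf_process

    # Launching through a shell typically yields a small tree: sh -> sleep.
    proc = subprocess.Popen("sleep 60", shell=True)
    time.sleep(1)  # give the shell a moment to start its child

    # SIGABRT goes to the deepest descendants; if the shell exec'ed the
    # command (so proc.pid has no children), proc.pid itself is the leaf.
    abort_leaf_process(proc.pid)
    proc.wait()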