fix abort in webui DDP mode

Former-commit-id: 2187518762844472a96b72fbad4da15d8bc97bbd
This commit is contained in:
hiyouga 2024-06-04 00:10:24 +08:00
parent 326f180397
commit b12d4beb8a
3 changed files with 32 additions and 54 deletions

View File

@ -71,28 +71,23 @@ def main():
export_model() export_model()
elif command == Command.TRAIN: elif command == Command.TRAIN:
if get_device_count() > 1: if get_device_count() > 1:
nnodes = os.environ.get("NNODES", "1")
node_rank = os.environ.get("RANK", "0")
nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999))) master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port)) logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
subprocess.run( subprocess.run(
[ (
"torchrun", "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--nnodes", "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
nnodes, ).format(
"--node_rank", nnodes=os.environ.get("NNODES", "1"),
node_rank, node_rank=os.environ.get("RANK", "0"),
"--nproc_per_node", nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
nproc_per_node, master_addr=master_addr,
"--master_addr", master_port=master_port,
master_addr, file_name=launcher.__file__,
"--master_port", args=" ".join(sys.argv[1:]),
master_port, ),
launcher.__file__, shell=True,
*sys.argv[1:],
]
) )
else: else:
run_exp() run_exp()

View File

@ -1,20 +1,17 @@
import os import os
import signal
import random
from copy import deepcopy from copy import deepcopy
from subprocess import Popen, TimeoutExpired from subprocess import Popen, TimeoutExpired
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
import psutil
from transformers.trainer import TRAINING_ARGS_NAME from transformers.trainer import TRAINING_ARGS_NAME
from ..extras.constants import PEFT_METHODS, TRAINING_STAGES from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc, get_device_count from ..extras.misc import is_gpu_or_npu_available, torch_gc
from ..extras.packages import is_gradio_available from ..extras.packages import is_gradio_available
from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
from .locales import ALERTS from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
from .. import launcher
if is_gradio_available(): if is_gradio_available():
import gradio as gr import gradio as gr
@ -41,12 +38,7 @@ class Runner:
def set_abort(self) -> None: def set_abort(self) -> None:
self.aborted = True self.aborted = True
if self.trainer is not None: if self.trainer is not None:
for children in psutil.Process(self.trainer.pid).children(): # abort the child process abort_leaf_process(self.trainer.pid)
grand_children = children.children()
if len(grand_children) > 0:
for grand_child in grand_children:
os.kill(grand_child.pid, signal.SIGABRT)
os.kill(children.pid, signal.SIGABRT)
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str: def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
@ -285,30 +277,7 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
env = deepcopy(os.environ) env = deepcopy(os.environ)
env["LLAMABOARD_ENABLED"] = "1" env["LLAMABOARD_ENABLED"] = "1"
if get_device_count() > 1: self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
nnodes = os.environ.get("NNODES", "1")
node_rank = os.environ.get("RANK", "0")
nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
self.trainer = Popen([
"torchrun",
"--nnodes",
nnodes,
"--node_rank",
node_rank,
"--nproc_per_node",
nproc_per_node,
"--master_addr",
master_addr,
"--master_port",
master_port,
launcher.__file__,
save_cmd(args)
], env=env, shell=True)
else:
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
yield from self.monitor() yield from self.monitor()
def preview_train(self, data): def preview_train(self, data):

View File

@ -1,8 +1,10 @@
import json import json
import os import os
import signal
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import psutil
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
from yaml import safe_dump, safe_load from yaml import safe_dump, safe_load
@ -17,6 +19,18 @@ if is_gradio_available():
import gradio as gr import gradio as gr
def abort_leaf_process(pid: int) -> None:
    r"""
    Recursively walks the process tree rooted at ``pid`` and sends
    SIGABRT to every leaf process (a process with no children of its own).
    """
    children = psutil.Process(pid).children()
    if not children:
        # No descendants: this process is a leaf — abort it directly.
        os.kill(pid, signal.SIGABRT)
        return

    # Interior node: descend into each child; only leaves receive the signal.
    for child in children:
        abort_leaf_process(child.pid)
def can_quantize(finetuning_type: str) -> "gr.Dropdown": def can_quantize(finetuning_type: str) -> "gr.Dropdown":
r""" r"""
Judges if the quantization is available in this finetuning type. Judges if the quantization is available in this finetuning type.