mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 22:02:51 +08:00
fix abort in webui DDP mode
Former-commit-id: 2187518762844472a96b72fbad4da15d8bc97bbd
This commit is contained in:
parent
326f180397
commit
b12d4beb8a
@ -71,28 +71,23 @@ def main():
|
|||||||
export_model()
|
export_model()
|
||||||
elif command == Command.TRAIN:
|
elif command == Command.TRAIN:
|
||||||
if get_device_count() > 1:
|
if get_device_count() > 1:
|
||||||
nnodes = os.environ.get("NNODES", "1")
|
|
||||||
node_rank = os.environ.get("RANK", "0")
|
|
||||||
nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
|
|
||||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
||||||
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||||
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
[
|
(
|
||||||
"torchrun",
|
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||||
"--nnodes",
|
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
||||||
nnodes,
|
).format(
|
||||||
"--node_rank",
|
nnodes=os.environ.get("NNODES", "1"),
|
||||||
node_rank,
|
node_rank=os.environ.get("RANK", "0"),
|
||||||
"--nproc_per_node",
|
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
|
||||||
nproc_per_node,
|
master_addr=master_addr,
|
||||||
"--master_addr",
|
master_port=master_port,
|
||||||
master_addr,
|
file_name=launcher.__file__,
|
||||||
"--master_port",
|
args=" ".join(sys.argv[1:]),
|
||||||
master_port,
|
),
|
||||||
launcher.__file__,
|
shell=True,
|
||||||
*sys.argv[1:],
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
run_exp()
|
run_exp()
|
||||||
|
@ -1,20 +1,17 @@
|
|||||||
import os
|
import os
|
||||||
import signal
|
|
||||||
import random
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from subprocess import Popen, TimeoutExpired
|
from subprocess import Popen, TimeoutExpired
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||||
|
|
||||||
import psutil
|
|
||||||
from transformers.trainer import TRAINING_ARGS_NAME
|
from transformers.trainer import TRAINING_ARGS_NAME
|
||||||
|
|
||||||
from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
|
from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
|
||||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc, get_device_count
|
from ..extras.misc import is_gpu_or_npu_available, torch_gc
|
||||||
from ..extras.packages import is_gradio_available
|
from ..extras.packages import is_gradio_available
|
||||||
from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
|
from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
|
||||||
from .locales import ALERTS
|
from .locales import ALERTS
|
||||||
from .utils import gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
|
from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
|
||||||
from .. import launcher
|
|
||||||
|
|
||||||
if is_gradio_available():
|
if is_gradio_available():
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -41,12 +38,7 @@ class Runner:
|
|||||||
def set_abort(self) -> None:
|
def set_abort(self) -> None:
|
||||||
self.aborted = True
|
self.aborted = True
|
||||||
if self.trainer is not None:
|
if self.trainer is not None:
|
||||||
for children in psutil.Process(self.trainer.pid).children(): # abort the child process
|
abort_leaf_process(self.trainer.pid)
|
||||||
grand_children = children.children()
|
|
||||||
if len(grand_children) > 0:
|
|
||||||
for grand_child in grand_children:
|
|
||||||
os.kill(grand_child.pid, signal.SIGABRT)
|
|
||||||
os.kill(children.pid, signal.SIGABRT)
|
|
||||||
|
|
||||||
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
|
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
|
||||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||||
@ -285,30 +277,7 @@ class Runner:
|
|||||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||||
env = deepcopy(os.environ)
|
env = deepcopy(os.environ)
|
||||||
env["LLAMABOARD_ENABLED"] = "1"
|
env["LLAMABOARD_ENABLED"] = "1"
|
||||||
if get_device_count() > 1:
|
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
|
||||||
nnodes = os.environ.get("NNODES", "1")
|
|
||||||
node_rank = os.environ.get("RANK", "0")
|
|
||||||
nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
|
|
||||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
|
||||||
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
|
||||||
|
|
||||||
self.trainer = Popen([
|
|
||||||
"torchrun",
|
|
||||||
"--nnodes",
|
|
||||||
nnodes,
|
|
||||||
"--node_rank",
|
|
||||||
node_rank,
|
|
||||||
"--nproc_per_node",
|
|
||||||
nproc_per_node,
|
|
||||||
"--master_addr",
|
|
||||||
master_addr,
|
|
||||||
"--master_port",
|
|
||||||
master_port,
|
|
||||||
launcher.__file__,
|
|
||||||
save_cmd(args)
|
|
||||||
], env=env, shell=True)
|
|
||||||
else:
|
|
||||||
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
|
|
||||||
yield from self.monitor()
|
yield from self.monitor()
|
||||||
|
|
||||||
def preview_train(self, data):
|
def preview_train(self, data):
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import signal
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import psutil
|
||||||
from transformers.trainer_utils import get_last_checkpoint
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
from yaml import safe_dump, safe_load
|
from yaml import safe_dump, safe_load
|
||||||
|
|
||||||
@ -17,6 +19,18 @@ if is_gradio_available():
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
|
def abort_leaf_process(pid: int) -> None:
|
||||||
|
r"""
|
||||||
|
Aborts the leaf processes.
|
||||||
|
"""
|
||||||
|
children = psutil.Process(pid).children()
|
||||||
|
if children:
|
||||||
|
for child in children:
|
||||||
|
abort_leaf_process(child.pid)
|
||||||
|
else:
|
||||||
|
os.kill(pid, signal.SIGABRT)
|
||||||
|
|
||||||
|
|
||||||
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
|
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
|
||||||
r"""
|
r"""
|
||||||
Judges if the quantization is available in this finetuning type.
|
Judges if the quantization is available in this finetuning type.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user