From f45e81e186ec22282524bdd8322d36f928488cf6 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Fri, 7 Jun 2024 19:16:06 +0800 Subject: [PATCH] fix #4137 Former-commit-id: cdc0d6f5a2e5040e145c82c4801f37bd76529047 --- src/llamafactory/cli.py | 8 ++------ src/llamafactory/webui/runner.py | 3 +++ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index b9e734e4..5042e53c 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -72,12 +72,8 @@ def main(): elif command == Command.EXPORT: export_model() elif command == Command.TRAIN: - disable_torchrun = os.environ.get("TORCHRUN_DISABLED", "0").lower() in ["true", "1"] - if disable_torchrun and get_device_count() > 1: - logger.warning("`torchrun` cannot be disabled when device count > 1.") - disable_torchrun = False - - if (not disable_torchrun) and (get_device_count() > 0): + force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"] + if force_torchrun or get_device_count() > 1: master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999))) logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port)) diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index e8fdd129..c046152c 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -278,6 +278,9 @@ class Runner: args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) env = deepcopy(os.environ) env["LLAMABOARD_ENABLED"] = "1" + if args.get("deepspeed", None) is not None: + env["FORCE_TORCHRUN"] = "1" + self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True) yield from self.monitor()