diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index b9e734e4..5042e53c 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -72,12 +72,8 @@ def main():
     elif command == Command.EXPORT:
         export_model()
     elif command == Command.TRAIN:
-        disable_torchrun = os.environ.get("TORCHRUN_DISABLED", "0").lower() in ["true", "1"]
-        if disable_torchrun and get_device_count() > 1:
-            logger.warning("`torchrun` cannot be disabled when device count > 1.")
-            disable_torchrun = False
-
-        if (not disable_torchrun) and (get_device_count() > 0):
+        force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
+        if force_torchrun or get_device_count() > 1:
             master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
             master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
             logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index e8fdd129..c046152c 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -278,6 +278,9 @@ class Runner:
         args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
 
         env = deepcopy(os.environ)
         env["LLAMABOARD_ENABLED"] = "1"
+        if args.get("deepspeed", None) is not None:
+            env["FORCE_TORCHRUN"] = "1"
+
         self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
         yield from self.monitor()
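
Note for reviewers: the cli.py change inverts the old opt-out flag (`TORCHRUN_DISABLED`) into an opt-in one (`FORCE_TORCHRUN`) and also changes the single-device case: with one visible device and the flag unset, training now runs in-process instead of under `torchrun`. The runner.py change covers the one case that still needs the launcher, presumably because DeepSpeed expects a distributed environment even on a single GPU, so the web UI forces the flag whenever a `deepspeed` config is present. Below is a minimal sketch of the new launch predicate; the helper name `should_launch_with_torchrun` and the test values are illustrative, not part of the patch:

    import os


    def should_launch_with_torchrun(device_count: int) -> bool:
        # Mirrors the new condition in cli.py: torchrun is used when
        # FORCE_TORCHRUN is truthy ("true" or "1", case-insensitive)
        # or when more than one accelerator is visible.
        force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
        return force_torchrun or device_count > 1


    # Behavior at the boundaries (assuming FORCE_TORCHRUN starts out unset):
    assert not should_launch_with_torchrun(device_count=1)  # old code launched torchrun here
    assert should_launch_with_torchrun(device_count=2)      # multi-GPU still auto-launches

    os.environ["FORCE_TORCHRUN"] = "1"                      # what the web UI now sets for DeepSpeed runs
    assert should_launch_with_torchrun(device_count=1)

Users outside the web UI can opt in the same way from the shell, e.g. `FORCE_TORCHRUN=1 llamafactory-cli train <config>`.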