diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py
index 189c7533..4da4ec18 100644
--- a/src/llamafactory/train/callbacks.py
+++ b/src/llamafactory/train/callbacks.py
@@ -35,7 +35,7 @@ from typing_extensions import override
 
 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import get_peak_memory
+from ..extras.misc import get_peak_memory, use_ray
 
 
 if is_safetensors_available():
@@ -194,7 +194,7 @@ class LogCallback(TrainerCallback):
         self.do_train = False
         # Web UI
         self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
-        if self.webui_mode:
+        if self.webui_mode and not use_ray():
             signal.signal(signal.SIGABRT, self._set_abort)
             self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
             logging.add_handler(self.logger_handler)
@@ -383,7 +383,7 @@ class ReporterCallback(TrainerCallback):
             )
 
         if self.finetuning_args.use_swanlab:
-            import swanlab
+            import swanlab  # type: ignore
 
             swanlab.config.update(
                 {
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index 24620c87..bbbef1cf 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -46,11 +46,12 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-def training_function(config: Dict[str, Any]) -> None:
+def _training_function(config: Dict[str, Any]) -> None:
     args = config.get("args")
     callbacks: List[Any] = config.get("callbacks")
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
 
+    callbacks.append(LogCallback())
     if finetuning_args.pissa_convert:
         callbacks.append(PissaConvertCallback())
 
@@ -76,21 +77,19 @@ def training_function(config: Dict[str, Any]) -> None:
 
 
 def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
-    callbacks = callbacks or []
-    callbacks.append(LogCallback())
-
     args = read_args(args)
     ray_args = get_ray_args(args)
+    callbacks = callbacks or []
     if ray_args.use_ray:
         callbacks.append(RayTrainReportCallback())
         trainer = get_ray_trainer(
-            training_function=training_function,
+            training_function=_training_function,
             train_loop_config={"args": args, "callbacks": callbacks},
            ray_args=ray_args,
         )
         trainer.fit()
     else:
-        training_function(config={"args": args, "callbacks": callbacks})
+        _training_function(config={"args": args, "callbacks": callbacks})
 
 
 def export_model(args: Optional[Dict[str, Any]] = None) -> None:
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index f5aecaeb..dc91ad50 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
 from transformers.trainer import TRAINING_ARGS_NAME
 
 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_gpu_or_npu_available, torch_gc
+from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
 from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
 from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
 from .locales import ALERTS, LOCALES
@@ -394,12 +394,12 @@ class Runner:
                 continue
 
         if self.do_train:
-            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
+            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
                 finish_info = ALERTS["info_finished"][lang]
             else:
                 finish_info = ALERTS["err_failed"][lang]
         else:
-            if os.path.exists(os.path.join(output_path, "all_results.json")):
+            if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
                 finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
             else:
                 finish_info = ALERTS["err_failed"][lang]