fix llamaboard with ray

Former-commit-id: c46675d5e56d175c27d705ef0068fb47dc89a872
hiyouga 2025-01-07 09:59:24 +00:00
parent b4174021d6
commit 0c1ad5f3fb
3 changed files with 11 additions and 12 deletions

View File

@@ -35,7 +35,7 @@ from typing_extensions import override
 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import get_peak_memory
+from ..extras.misc import get_peak_memory, use_ray


 if is_safetensors_available():
@@ -194,7 +194,7 @@ class LogCallback(TrainerCallback):
         self.do_train = False
         # Web UI
         self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
-        if self.webui_mode:
+        if self.webui_mode and not use_ray():
             signal.signal(signal.SIGABRT, self._set_abort)
             self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
             logging.add_handler(self.logger_handler)
@@ -383,7 +383,7 @@ class ReporterCallback(TrainerCallback):
             )

         if self.finetuning_args.use_swanlab:
-            import swanlab
+            import swanlab  # type: ignore

             swanlab.config.update(
                 {
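With the `use_ray()` guard, LogCallback skips its Web-UI-specific setup (the SIGABRT handler and the LLAMABOARD_WORKDIR file logger) when the run is driven by Ray, where that setup would land in worker processes rather than the UI's own process. A minimal sketch of what such a helper could look like, assuming it is toggled by an environment variable set by the launcher (the real implementation lives in `..extras.misc` and is not shown in this diff):

    import os

    def use_ray() -> bool:
        # Sketch only: assume Ray mode is signalled through an environment
        # variable; the real helper may derive this from the parsed ray
        # arguments instead.
        return os.environ.get("USE_RAY", "0").lower() in ["true", "1"]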

View File

@@ -46,11 +46,12 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def training_function(config: Dict[str, Any]) -> None:
+def _training_function(config: Dict[str, Any]) -> None:
     args = config.get("args")
     callbacks: List[Any] = config.get("callbacks")
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
+    callbacks.append(LogCallback())

     if finetuning_args.pissa_convert:
         callbacks.append(PissaConvertCallback())
@@ -76,21 +77,19 @@ def training_function(config: Dict[str, Any]) -> None:
 def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
-    callbacks = callbacks or []
-    callbacks.append(LogCallback())
     args = read_args(args)
     ray_args = get_ray_args(args)
+    callbacks = callbacks or []

     if ray_args.use_ray:
         callbacks.append(RayTrainReportCallback())
         trainer = get_ray_trainer(
-            training_function=training_function,
+            training_function=_training_function,
             train_loop_config={"args": args, "callbacks": callbacks},
             ray_args=ray_args,
         )
         trainer.fit()
     else:
-        training_function(config={"args": args, "callbacks": callbacks})
+        _training_function(config={"args": args, "callbacks": callbacks})


 def export_model(args: Optional[Dict[str, Any]] = None) -> None:
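The LogCallback append moves from `run_exp` into `_training_function` because, under Ray, `run_exp` only builds and launches the trainer on the driver, while `_training_function` is what each Ray worker actually executes; callbacks that need to observe training therefore have to be created inside that function. A rough sketch of how `get_ray_trainer` could wrap the function with Ray Train (the helper itself is not part of this diff, and `ray_num_workers` plus the GPU setting are assumptions):

    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    def get_ray_trainer(training_function, train_loop_config, ray_args):
        # Sketch only: run the training function on ray_args.ray_num_workers
        # workers (field name assumed), passing the same
        # {"args": ..., "callbacks": ...} config that the non-Ray path
        # feeds to _training_function directly.
        return TorchTrainer(
            training_function,
            train_loop_config=train_loop_config,
            scaling_config=ScalingConfig(num_workers=ray_args.ray_num_workers, use_gpu=True),
        )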

View File

@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
 from transformers.trainer import TRAINING_ARGS_NAME

 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_gpu_or_npu_available, torch_gc
+from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
 from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
 from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
 from .locales import ALERTS, LOCALES
@@ -394,12 +394,12 @@ class Runner:
                 continue

         if self.do_train:
-            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
+            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
                 finish_info = ALERTS["info_finished"][lang]
             else:
                 finish_info = ALERTS["err_failed"][lang]
         else:
-            if os.path.exists(os.path.join(output_path, "all_results.json")):
+            if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
                 finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
             else:
                 finish_info = ALERTS["err_failed"][lang]
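The `or use_ray()` short-circuits keep the Web UI from reporting a Ray-driven run as failed: when Ray manages training, the marker files may be written under Ray's own storage path rather than the UI's `output_path`, so the local file checks cannot be trusted. Condensed into a single hypothetical helper for illustration, with `use_ray` taken from `..extras.misc` as in the imports above:

    import os

    from transformers.trainer import TRAINING_ARGS_NAME

    def run_finished(output_path: str, do_train: bool) -> bool:
        # Sketch only: under Ray the marker files may not exist locally,
        # so the run is treated as finished without checking the path.
        marker = TRAINING_ARGS_NAME if do_train else "all_results.json"
        return use_ray() or os.path.exists(os.path.join(output_path, marker))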