mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
fix llamaboard with ray
Former-commit-id: c46675d5e56d175c27d705ef0068fb47dc89a872
This commit is contained in:
parent
b4174021d6
commit
0c1ad5f3fb
@ -35,7 +35,7 @@ from typing_extensions import override
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.misc import get_peak_memory
|
||||
from ..extras.misc import get_peak_memory, use_ray
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
@ -194,7 +194,7 @@ class LogCallback(TrainerCallback):
|
||||
self.do_train = False
|
||||
# Web UI
|
||||
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
|
||||
if self.webui_mode:
|
||||
if self.webui_mode and not use_ray():
|
||||
signal.signal(signal.SIGABRT, self._set_abort)
|
||||
self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
|
||||
logging.add_handler(self.logger_handler)
|
||||
@ -383,7 +383,7 @@ class ReporterCallback(TrainerCallback):
|
||||
)
|
||||
|
||||
if self.finetuning_args.use_swanlab:
|
||||
import swanlab
|
||||
import swanlab # type: ignore
|
||||
|
||||
swanlab.config.update(
|
||||
{
|
||||
|
@ -46,11 +46,12 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def training_function(config: Dict[str, Any]) -> None:
|
||||
def _training_function(config: Dict[str, Any]) -> None:
|
||||
args = config.get("args")
|
||||
callbacks: List[Any] = config.get("callbacks")
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||
|
||||
callbacks.append(LogCallback())
|
||||
if finetuning_args.pissa_convert:
|
||||
callbacks.append(PissaConvertCallback())
|
||||
|
||||
@ -76,21 +77,19 @@ def training_function(config: Dict[str, Any]) -> None:
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
|
||||
callbacks = callbacks or []
|
||||
callbacks.append(LogCallback())
|
||||
|
||||
args = read_args(args)
|
||||
ray_args = get_ray_args(args)
|
||||
callbacks = callbacks or []
|
||||
if ray_args.use_ray:
|
||||
callbacks.append(RayTrainReportCallback())
|
||||
trainer = get_ray_trainer(
|
||||
training_function=training_function,
|
||||
training_function=_training_function,
|
||||
train_loop_config={"args": args, "callbacks": callbacks},
|
||||
ray_args=ray_args,
|
||||
)
|
||||
trainer.fit()
|
||||
else:
|
||||
training_function(config={"args": args, "callbacks": callbacks})
|
||||
_training_function(config={"args": args, "callbacks": callbacks})
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
|
@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
|
||||
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc
|
||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
|
||||
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
|
||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
|
||||
from .locales import ALERTS, LOCALES
|
||||
@ -394,12 +394,12 @@ class Runner:
|
||||
continue
|
||||
|
||||
if self.do_train:
|
||||
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
|
||||
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
|
||||
finish_info = ALERTS["info_finished"][lang]
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
else:
|
||||
if os.path.exists(os.path.join(output_path, "all_results.json")):
|
||||
if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
|
||||
finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
Loading…
x
Reference in New Issue
Block a user