mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
fix llamaboard with ray
Former-commit-id: c46675d5e56d175c27d705ef0068fb47dc89a872
This commit is contained in:
parent
b4174021d6
commit
0c1ad5f3fb
@ -35,7 +35,7 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
from ..extras.misc import get_peak_memory
|
from ..extras.misc import get_peak_memory, use_ray
|
||||||
|
|
||||||
|
|
||||||
if is_safetensors_available():
|
if is_safetensors_available():
|
||||||
@ -194,7 +194,7 @@ class LogCallback(TrainerCallback):
|
|||||||
self.do_train = False
|
self.do_train = False
|
||||||
# Web UI
|
# Web UI
|
||||||
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
|
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
|
||||||
if self.webui_mode:
|
if self.webui_mode and not use_ray():
|
||||||
signal.signal(signal.SIGABRT, self._set_abort)
|
signal.signal(signal.SIGABRT, self._set_abort)
|
||||||
self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
|
self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
|
||||||
logging.add_handler(self.logger_handler)
|
logging.add_handler(self.logger_handler)
|
||||||
@ -383,7 +383,7 @@ class ReporterCallback(TrainerCallback):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.finetuning_args.use_swanlab:
|
if self.finetuning_args.use_swanlab:
|
||||||
import swanlab
|
import swanlab # type: ignore
|
||||||
|
|
||||||
swanlab.config.update(
|
swanlab.config.update(
|
||||||
{
|
{
|
||||||
|
@ -46,11 +46,12 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def training_function(config: Dict[str, Any]) -> None:
|
def _training_function(config: Dict[str, Any]) -> None:
|
||||||
args = config.get("args")
|
args = config.get("args")
|
||||||
callbacks: List[Any] = config.get("callbacks")
|
callbacks: List[Any] = config.get("callbacks")
|
||||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||||
|
|
||||||
|
callbacks.append(LogCallback())
|
||||||
if finetuning_args.pissa_convert:
|
if finetuning_args.pissa_convert:
|
||||||
callbacks.append(PissaConvertCallback())
|
callbacks.append(PissaConvertCallback())
|
||||||
|
|
||||||
@ -76,21 +77,19 @@ def training_function(config: Dict[str, Any]) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
|
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
|
||||||
callbacks = callbacks or []
|
|
||||||
callbacks.append(LogCallback())
|
|
||||||
|
|
||||||
args = read_args(args)
|
args = read_args(args)
|
||||||
ray_args = get_ray_args(args)
|
ray_args = get_ray_args(args)
|
||||||
|
callbacks = callbacks or []
|
||||||
if ray_args.use_ray:
|
if ray_args.use_ray:
|
||||||
callbacks.append(RayTrainReportCallback())
|
callbacks.append(RayTrainReportCallback())
|
||||||
trainer = get_ray_trainer(
|
trainer = get_ray_trainer(
|
||||||
training_function=training_function,
|
training_function=_training_function,
|
||||||
train_loop_config={"args": args, "callbacks": callbacks},
|
train_loop_config={"args": args, "callbacks": callbacks},
|
||||||
ray_args=ray_args,
|
ray_args=ray_args,
|
||||||
)
|
)
|
||||||
trainer.fit()
|
trainer.fit()
|
||||||
else:
|
else:
|
||||||
training_function(config={"args": args, "callbacks": callbacks})
|
_training_function(config={"args": args, "callbacks": callbacks})
|
||||||
|
|
||||||
|
|
||||||
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||||
|
@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
|||||||
from transformers.trainer import TRAINING_ARGS_NAME
|
from transformers.trainer import TRAINING_ARGS_NAME
|
||||||
|
|
||||||
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
||||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc
|
from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
|
||||||
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
|
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
|
||||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
|
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
|
||||||
from .locales import ALERTS, LOCALES
|
from .locales import ALERTS, LOCALES
|
||||||
@ -394,12 +394,12 @@ class Runner:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if self.do_train:
|
if self.do_train:
|
||||||
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
|
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
|
||||||
finish_info = ALERTS["info_finished"][lang]
|
finish_info = ALERTS["info_finished"][lang]
|
||||||
else:
|
else:
|
||||||
finish_info = ALERTS["err_failed"][lang]
|
finish_info = ALERTS["err_failed"][lang]
|
||||||
else:
|
else:
|
||||||
if os.path.exists(os.path.join(output_path, "all_results.json")):
|
if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
|
||||||
finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
|
finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
|
||||||
else:
|
else:
|
||||||
finish_info = ALERTS["err_failed"][lang]
|
finish_info = ALERTS["err_failed"][lang]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user