diff --git a/.env.local b/.env.local index 203aebaf..1d8f2a00 100644 --- a/.env.local +++ b/.env.local @@ -12,6 +12,7 @@ FORCE_CHECK_IMPORTS= LLAMAFACTORY_VERBOSITY= USE_MODELSCOPE_HUB= USE_OPENMIND_HUB= +USE_RAY= RECORD_VRAM= # torchrun FORCE_TORCHRUN= diff --git a/examples/README.md b/examples/README.md index cc2afc46..89f7d174 100644 --- a/examples/README.md +++ b/examples/README.md @@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml ``` +#### Supervised Fine-Tuning with Ray on 4 GPUs + +```bash +USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml +``` + ### QLoRA Fine-Tuning #### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended) diff --git a/examples/README_zh.md b/examples/README_zh.md index b41d7ab8..2c108e56 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml ``` +#### 使用 Ray 在 4 张 GPU 上微调 + +```bash +USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml +``` + ### QLoRA 微调 #### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐) diff --git a/examples/train_lora/llama3_lora_sft_ray.yaml b/examples/train_lora/llama3_lora_sft_ray.yaml new file mode 100644 index 00000000..4aac08bc --- /dev/null +++ b/examples/train_lora/llama3_lora_sft_ray.yaml @@ -0,0 +1,48 @@ +### model +model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path +trust_remote_code: true + +### method +stage: sft +do_train: true +finetuning_type: lora +lora_target: all + +### dataset +dataset: identity,alpaca_en_demo +dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path +template: llama3 +cutoff_len: 2048 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +### output +output_dir: tmp_dir +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +### train +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 +learning_rate: 1.0e-4 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +bf16: true +ddp_timeout: 180000000 + +### eval +val_size: 0.1 +per_device_eval_batch_size: 1 +eval_strategy: steps +eval_steps: 500 + +### ray +ray_run_name: llama3_8b_sft_lora +ray_num_workers: 4 # number of GPUs to use +resources_per_worker: + GPU: 1 +placement_strategy: PACK diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index ef5bd1dc..72085e2d 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -24,7 +24,7 @@ from .chat.chat_model import run_chat from .eval.evaluator import run_eval from .extras import logging from .extras.env import VERSION, print_env -from .extras.misc import get_device_count +from .extras.misc import get_device_count, use_ray from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui @@ -87,7 +87,7 @@ def main(): export_model() elif command == Command.TRAIN: force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"] - if force_torchrun or get_device_count() > 1: + if force_torchrun or (get_device_count() > 1 and not use_ray()): master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999))) logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}") diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 735c5d63..11797f9f 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -229,7 +229,7 @@ def skip_check_imports() -> None: r""" Avoids flash attention import error in custom model files. """ - if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]: + if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]: transformers.dynamic_module_utils.check_imports = get_relative_imports @@ -275,8 +275,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str: def use_modelscope() -> bool: - return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"] + return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"] def use_openmind() -> bool: - return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"] + return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"] + + +def use_ray() -> bool: + return os.getenv("USE_RAY", "0").lower() in ["true", "1"] diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py index 44b9bb8a..6b2bc3f3 100644 --- a/src/llamafactory/extras/packages.py +++ b/src/llamafactory/extras/packages.py @@ -62,6 +62,10 @@ def is_pillow_available(): return _is_package_available("PIL") +def is_ray_available(): + return _is_package_available("ray") + + def is_requests_available(): return _is_package_available("requests") diff --git a/src/llamafactory/hparams/__init__.py b/src/llamafactory/hparams/__init__.py index cfe448c1..254a845e 100644 --- a/src/llamafactory/hparams/__init__.py +++ b/src/llamafactory/hparams/__init__.py @@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments -from .parser import get_eval_args, get_infer_args, get_train_args +from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args +from .training_args import RayArguments, TrainingArguments __all__ = [ @@ -26,7 +27,11 @@ __all__ = [ "FinetuningArguments", "GeneratingArguments", "ModelArguments", + "RayArguments", + "TrainingArguments", "get_eval_args", "get_infer_args", + "get_ray_args", "get_train_args", + "read_args", ] diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 4a254367..62edbf78 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -15,13 +15,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import sys -from typing import Any, Dict, Optional, Tuple +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union import torch import transformers -from transformers import HfArgumentParser, Seq2SeqTrainingArguments +import yaml +from transformers import HfArgumentParser from transformers.integrations import is_deepspeed_zero3_enabled from transformers.trainer_utils import get_last_checkpoint from transformers.training_args import ParallelMode @@ -36,33 +39,42 @@ from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments +from .training_args import RayArguments, TrainingArguments logger = logging.get_logger(__name__) - check_dependencies() -_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] -_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] +_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] +_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] -def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: +def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]: if args is not None: - return parser.parse_dict(args) + return args if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")): - return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + return yaml.safe_load(Path(sys.argv[1]).absolute().read_text()) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + return json.loads(Path(sys.argv[1]).absolute().read_text()) + else: + return sys.argv[1:] - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - return parser.parse_json_file(os.path.abspath(sys.argv[1])) - (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True) +def _parse_args( + parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False +) -> Tuple[Any]: + args = read_args(args) + if isinstance(args, dict): + return parser.parse_dict(args, allow_extra_keys=allow_extra_keys) + + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True) if unknown_args: print(parser.format_help()) @@ -110,7 +122,7 @@ def _verify_model_args( def _check_extra_dependencies( model_args: "ModelArguments", finetuning_args: "FinetuningArguments", - training_args: Optional["Seq2SeqTrainingArguments"] = None, + training_args: Optional["TrainingArguments"] = None, ) -> None: if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.") @@ -146,22 +158,28 @@ def _check_extra_dependencies( require_version("rouge_chinese", "To fix: pip install rouge-chinese") -def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: +def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: parser = HfArgumentParser(_TRAIN_ARGS) return _parse_args(parser, args) -def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: +def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: parser = HfArgumentParser(_INFER_ARGS) return _parse_args(parser, args) -def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: +def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: parser = HfArgumentParser(_EVAL_ARGS) return _parse_args(parser, args) -def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: +def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments: + parser = HfArgumentParser(RayArguments) + (ray_args,) = _parse_args(parser, args, allow_extra_keys=True) + return ray_args + + +def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS: model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) # Setup logging @@ -371,7 +389,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: return model_args, data_args, training_args, finetuning_args, generating_args -def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: +def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS: model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) _set_transformers_logging() @@ -404,7 +422,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: return model_args, data_args, finetuning_args, generating_args -def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: +def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS: model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) _set_transformers_logging() diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py new file mode 100644 index 00000000..9df24ace --- /dev/null +++ b/src/llamafactory/hparams/training_args.py @@ -0,0 +1,48 @@ +import json +from dataclasses import dataclass, field +from typing import Literal, Optional, Union + +from transformers import Seq2SeqTrainingArguments +from transformers.training_args import _convert_str_dict + +from ..extras.misc import use_ray + + +@dataclass +class RayArguments: + r""" + Arguments pertaining to the Ray training. + """ + + ray_run_name: Optional[str] = field( + default=None, + metadata={"help": "The training results will be saved at `saves/ray_run_name`."}, + ) + ray_num_workers: int = field( + default=1, + metadata={"help": "The number of workers for Ray training. Default is 1 worker."}, + ) + resources_per_worker: Union[dict, str] = field( + default_factory=lambda: {"GPU": 1}, + metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, + ) + placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field( + default="PACK", + metadata={"help": "The placement strategy for Ray training. Default is PACK."}, + ) + + def __post_init__(self): + self.use_ray = use_ray() + if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"): + self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker)) + + +@dataclass +class TrainingArguments(RayArguments, Seq2SeqTrainingArguments): + r""" + Arguments pertaining to the trainer. + """ + + def __post_init__(self): + Seq2SeqTrainingArguments.__post_init__(self) + RayArguments.__post_init__(self) diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 189c7533..4da4ec18 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -35,7 +35,7 @@ from typing_extensions import override from ..extras import logging from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ..extras.misc import get_peak_memory +from ..extras.misc import get_peak_memory, use_ray if is_safetensors_available(): @@ -194,7 +194,7 @@ class LogCallback(TrainerCallback): self.do_train = False # Web UI self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"] - if self.webui_mode: + if self.webui_mode and not use_ray(): signal.signal(signal.SIGABRT, self._set_abort) self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR")) logging.add_handler(self.logger_handler) @@ -383,7 +383,7 @@ class ReporterCallback(TrainerCallback): ) if self.finetuning_args.use_swanlab: - import swanlab + import swanlab # type: ignore swanlab.config.update( { diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 4cd2337e..6aca53cf 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -18,7 +18,8 @@ # limitations under the License. from collections.abc import Mapping -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import Trainer @@ -31,7 +32,7 @@ from typing_extensions import override from ..extras import logging from ..extras.constants import IGNORE_INDEX -from ..extras.packages import is_galore_available +from ..extras.packages import is_galore_available, is_ray_available from ..hparams import FinetuningArguments, ModelArguments from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params @@ -40,11 +41,16 @@ if is_galore_available(): from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore +if is_ray_available(): + from ray.train import RunConfig, ScalingConfig + from ray.train.torch import TorchTrainer + + if TYPE_CHECKING: - from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback + from transformers import PreTrainedModel, TrainerCallback from trl import AutoModelForCausalLMWithValueHead - from ..hparams import DataArguments + from ..hparams import DataArguments, RayArguments, TrainingArguments logger = logging.get_logger(__name__) @@ -75,7 +81,7 @@ def create_modelcard_and_push( trainer: "Trainer", model_args: "ModelArguments", data_args: "DataArguments", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> None: kwargs = { @@ -188,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]: def _create_galore_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all": @@ -272,7 +278,7 @@ def _create_galore_optimizer( def _create_loraplus_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": default_lr = training_args.learning_rate @@ -312,7 +318,7 @@ def _create_loraplus_optimizer( def _create_badam_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": decay_params, nodecay_params = [], [] @@ -373,7 +379,7 @@ def _create_badam_optimizer( def _create_adam_mini_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", ) -> "torch.optim.Optimizer": from adam_mini import Adam_mini # type: ignore @@ -398,7 +404,7 @@ def _create_adam_mini_optimizer( def create_custom_optimizer( model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", finetuning_args: "FinetuningArguments", ) -> Optional["torch.optim.Optimizer"]: if finetuning_args.use_galore: @@ -415,7 +421,7 @@ def create_custom_optimizer( def create_custom_scheduler( - training_args: "Seq2SeqTrainingArguments", + training_args: "TrainingArguments", num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None, ) -> None: @@ -499,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall config={"Framework": "🦙LlamaFactory"}, ) return swanlab_callback + + +def get_ray_trainer( + training_function: Callable, + train_loop_config: Dict[str, Any], + ray_args: "RayArguments", +) -> "TorchTrainer": + if not ray_args.use_ray: + raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.") + + trainer = TorchTrainer( + training_function, + train_loop_config=train_loop_config, + scaling_config=ScalingConfig( + num_workers=ray_args.ray_num_workers, + resources_per_worker=ray_args.resources_per_worker, + placement_strategy=ray_args.placement_strategy, + use_gpu=True, + ), + run_config=RunConfig( + name=ray_args.ray_run_name, + storage_path=Path("./saves").absolute().as_posix(), + ), + ) + return trainer diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 6c79320e..bbbef1cf 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -22,7 +22,8 @@ from transformers import PreTrainedModel from ..data import get_template_and_fix_tokenizer from ..extras import logging from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ..hparams import get_infer_args, get_train_args +from ..extras.packages import is_ray_available +from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args from ..model import load_model, load_tokenizer from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback from .dpo import run_dpo @@ -31,7 +32,11 @@ from .ppo import run_ppo from .pt import run_pt from .rm import run_rm from .sft import run_sft -from .trainer_utils import get_swanlab_callback +from .trainer_utils import get_ray_trainer, get_swanlab_callback + + +if is_ray_available(): + from ray.train.huggingface.transformers import RayTrainReportCallback if TYPE_CHECKING: @@ -41,10 +46,12 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: - callbacks.append(LogCallback()) +def _training_function(config: Dict[str, Any]) -> None: + args = config.get("args") + callbacks: List[Any] = config.get("callbacks") model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) + callbacks.append(LogCallback()) if finetuning_args.pissa_convert: callbacks.append(PissaConvertCallback()) @@ -69,6 +76,22 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb raise ValueError(f"Unknown task: {finetuning_args.stage}.") +def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None: + args = read_args(args) + ray_args = get_ray_args(args) + callbacks = callbacks or [] + if ray_args.use_ray: + callbacks.append(RayTrainReportCallback()) + trainer = get_ray_trainer( + training_function=_training_function, + train_loop_config={"args": args, "callbacks": callbacks}, + ray_args=ray_args, + ) + trainer.fit() + else: + _training_function(config={"args": args, "callbacks": callbacks}) + + def export_model(args: Optional[Dict[str, Any]] = None) -> None: model_args, data_args, finetuning_args, _ = get_infer_args(args) diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index f5aecaeb..dc91ad50 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional from transformers.trainer import TRAINING_ARGS_NAME from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES -from ..extras.misc import is_gpu_or_npu_available, torch_gc +from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46 from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config from .locales import ALERTS, LOCALES @@ -394,12 +394,12 @@ class Runner: continue if self.do_train: - if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)): + if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray(): finish_info = ALERTS["info_finished"][lang] else: finish_info = ALERTS["err_failed"][lang] else: - if os.path.exists(os.path.join(output_path, "all_results.json")): + if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray(): finish_info = get_eval_results(os.path.join(output_path, "all_results.json")) else: finish_info = ALERTS["err_failed"][lang]