Merge pull request #6542 from erictang000/et/ray-integration

Ray Train integration with LLaMA-Factory

Former-commit-id: d23a98825bcb569bc51e21a3c2236eccd2f6d2fd
hoshi-hiyouga 2025-01-08 11:46:03 +08:00 committed by GitHub
commit 984b202f83
14 changed files with 239 additions and 45 deletions

View File

@@ -12,6 +12,7 @@ FORCE_CHECK_IMPORTS=
LLAMAFACTORY_VERBOSITY=
USE_MODELSCOPE_HUB=
USE_OPENMIND_HUB=
+USE_RAY=
RECORD_VRAM=
# torchrun
FORCE_TORCHRUN=

View File

@@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```

+#### Supervised Fine-Tuning with Ray on 4 GPUs
+
+```bash
+USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
+```
+
### QLoRA Fine-Tuning

#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)

View File

@@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```

+#### 使用 Ray 在 4 张 GPU 上微调
+
+```bash
+USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
+```
+
### QLoRA 微调

#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)

View File

@@ -0,0 +1,48 @@
+### model
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: lora
+lora_target: all
+
+### dataset
+dataset: identity,alpaca_en_demo
+dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path
+template: llama3
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+### output
+output_dir: tmp_dir
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+val_size: 0.1
+per_device_eval_batch_size: 1
+eval_strategy: steps
+eval_steps: 500
+
+### ray
+ray_run_name: llama3_8b_sft_lora
+ray_num_workers: 4 # number of GPUs to use
+resources_per_worker:
+  GPU: 1
+placement_strategy: PACK
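
The new `### ray` block is the only Ray-specific part of this config; its keys are consumed by the `RayArguments` dataclass added later in this commit, while the rest of the file is parsed per worker as usual. A minimal sketch of reading those keys back from Python, assuming the file is saved at the path the README hunk above refers to (`examples/train_full/llama3_lora_sft_ray.yaml`):

```python
# Minimal sketch (not part of this commit): load the example config and
# inspect the Ray-specific keys. The file path is an assumption taken
# from the README hunk above.
import yaml

with open("examples/train_full/llama3_lora_sft_ray.yaml") as f:
    config = yaml.safe_load(f)

print(config["ray_num_workers"])       # 4
print(config["resources_per_worker"])  # {'GPU': 1}
print(config["placement_strategy"])    # 'PACK'
```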

View File

@@ -24,7 +24,7 @@ from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
-from .extras.misc import get_device_count
+from .extras.misc import get_device_count, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
@@ -87,7 +87,7 @@ def main():
        export_model()
    elif command == Command.TRAIN:
        force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
-        if force_torchrun or get_device_count() > 1:
+        if force_torchrun or (get_device_count() > 1 and not use_ray()):
            master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
            master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
            logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")

View File

@@ -229,7 +229,7 @@ def skip_check_imports() -> None:
    r"""
    Avoids flash attention import error in custom model files.
    """
-    if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
+    if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
        transformers.dynamic_module_utils.check_imports = get_relative_imports
@@ -275,8 +275,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
def use_modelscope() -> bool:
-    return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
+    return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]


def use_openmind() -> bool:
-    return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
+    return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]


+def use_ray() -> bool:
+    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]

View File

@@ -62,6 +62,10 @@ def is_pillow_available():
    return _is_package_available("PIL")


+def is_ray_available():
+    return _is_package_available("ray")
+
+
def is_requests_available():
    return _is_package_available("requests")

View File

@@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
-from .parser import get_eval_args, get_infer_args, get_train_args
+from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args
+from .training_args import RayArguments, TrainingArguments


__all__ = [
@@ -26,7 +27,11 @@ __all__ = [
    "FinetuningArguments",
    "GeneratingArguments",
    "ModelArguments",
+    "RayArguments",
+    "TrainingArguments",
    "get_eval_args",
    "get_infer_args",
+    "get_ray_args",
    "get_train_args",
+    "read_args",
]

View File

@@ -15,13 +15,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import json
import os
import sys
-from typing import Any, Dict, Optional, Tuple
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import transformers
+import yaml
-from transformers import HfArgumentParser, Seq2SeqTrainingArguments
+from transformers import HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
@@ -36,33 +39,42 @@ from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
+from .training_args import RayArguments, TrainingArguments


logger = logging.get_logger(__name__)


check_dependencies()


-_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
-_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


-def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
+def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
    if args is not None:
-        return parser.parse_dict(args)
+        return args

    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
-
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
-
-    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    else:
+        return sys.argv[1:]
+
+
+def _parse_args(
+    parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
+) -> Tuple[Any]:
+    args = read_args(args)
+    if isinstance(args, dict):
+        return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
+
+    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
    if unknown_args:
        print(parser.format_help())
@@ -110,7 +122,7 @@ def _verify_model_args(
def _check_extra_dependencies(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
-    training_args: Optional["Seq2SeqTrainingArguments"] = None,
+    training_args: Optional["TrainingArguments"] = None,
) -> None:
    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
@@ -146,22 +158,28 @@ def _check_extra_dependencies(
        require_version("rouge_chinese", "To fix: pip install rouge-chinese")


-def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
    parser = HfArgumentParser(_TRAIN_ARGS)
    return _parse_args(parser, args)


-def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
    parser = HfArgumentParser(_INFER_ARGS)
    return _parse_args(parser, args)


-def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
    parser = HfArgumentParser(_EVAL_ARGS)
    return _parse_args(parser, args)


-def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
+    parser = HfArgumentParser(RayArguments)
+    (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
+    return ray_args
+
+
+def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)

    # Setup logging
@@ -371,7 +389,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    return model_args, data_args, training_args, finetuning_args, generating_args


-def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
    model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
    _set_transformers_logging()
@@ -404,7 +422,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
    return model_args, data_args, finetuning_args, generating_args


-def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
    model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
    _set_transformers_logging()
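
`read_args` now returns the raw dict (or argv list) instead of parsed dataclasses, which lets the same config be parsed twice: once leniently for the Ray fields and once strictly for everything else. A hedged sketch of the lenient first pass (the config values are illustrative, not from this commit):

```python
from llamafactory.hparams import get_ray_args

# A full training config; get_ray_args() only extracts the Ray fields and
# ignores the remaining keys because it parses with allow_extra_keys=True.
config = {
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
    "stage": "sft",
    "output_dir": "tmp_dir",
    "ray_run_name": "demo_run",
    "ray_num_workers": 2,
    "resources_per_worker": {"GPU": 1},
}

ray_args = get_ray_args(config)
print(ray_args.ray_num_workers)  # 2
print(ray_args.use_ray)          # True only if the USE_RAY env var is set
```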

View File

@@ -0,0 +1,48 @@
+import json
+from dataclasses import dataclass, field
+from typing import Literal, Optional, Union
+
+from transformers import Seq2SeqTrainingArguments
+from transformers.training_args import _convert_str_dict
+
+from ..extras.misc import use_ray
+
+
+@dataclass
+class RayArguments:
+    r"""
+    Arguments pertaining to the Ray training.
+    """
+
+    ray_run_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "The training results will be saved at `saves/ray_run_name`."},
+    )
+    ray_num_workers: int = field(
+        default=1,
+        metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
+    )
+    resources_per_worker: Union[dict, str] = field(
+        default_factory=lambda: {"GPU": 1},
+        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
+    )
+    placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
+        default="PACK",
+        metadata={"help": "The placement strategy for Ray training. Default is PACK."},
+    )
+
+    def __post_init__(self):
+        self.use_ray = use_ray()
+        if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
+            self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
+
+
+@dataclass
+class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
+    r"""
+    Arguments pertaining to the trainer.
+    """
+
+    def __post_init__(self):
+        Seq2SeqTrainingArguments.__post_init__(self)
+        RayArguments.__post_init__(self)
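
Because `resources_per_worker` is typed `Union[dict, str]`, it can also arrive as a JSON string (for example from a command-line flag); `__post_init__` then normalizes it into a dict using the same `_convert_str_dict` helper that `transformers` uses for its own dict-valued options. A small sketch, assuming `RayArguments` is imported via the `llamafactory.hparams` export added earlier in this diff:

```python
# Minimal sketch: string and dict forms end up identical after __post_init__.
from llamafactory.hparams import RayArguments

args = RayArguments(resources_per_worker='{"GPU": 1, "CPU": 4}')
print(args.resources_per_worker)  # {'GPU': 1, 'CPU': 4}
print(args.placement_strategy)    # 'PACK' (the default)
print(args.use_ray)               # mirrors the USE_RAY environment variable
```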

View File

@@ -35,7 +35,7 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import get_peak_memory
+from ..extras.misc import get_peak_memory, use_ray


if is_safetensors_available():
@@ -194,7 +194,7 @@ class LogCallback(TrainerCallback):
        self.do_train = False
        # Web UI
        self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
-        if self.webui_mode:
+        if self.webui_mode and not use_ray():
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
            logging.add_handler(self.logger_handler)
@@ -383,7 +383,7 @@ class ReporterCallback(TrainerCallback):
        )
        if self.finetuning_args.use_swanlab:
-            import swanlab
+            import swanlab  # type: ignore

            swanlab.config.update(
                {

View File

@@ -18,7 +18,8 @@
# limitations under the License.

from collections.abc import Mapping
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import Trainer
@@ -31,7 +32,7 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available
+from ..extras.packages import is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -40,11 +41,16 @@ if is_galore_available():
    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore


+if is_ray_available():
+    from ray.train import RunConfig, ScalingConfig
+    from ray.train.torch import TorchTrainer
+
+
if TYPE_CHECKING:
-    from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback
+    from transformers import PreTrainedModel, TrainerCallback
    from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import DataArguments
+    from ..hparams import DataArguments, RayArguments, TrainingArguments


logger = logging.get_logger(__name__)
@@ -75,7 +81,7 @@ def create_modelcard_and_push(
    trainer: "Trainer",
    model_args: "ModelArguments",
    data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> None:
    kwargs = {
@@ -188,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
def _create_galore_optimizer(
    model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
@@ -272,7 +278,7 @@ def _create_galore_optimizer(
def _create_loraplus_optimizer(
    model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    default_lr = training_args.learning_rate
@@ -312,7 +318,7 @@ def _create_loraplus_optimizer(
def _create_badam_optimizer(
    model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    decay_params, nodecay_params = [], []
@@ -373,7 +379,7 @@ def _create_badam_optimizer(
def _create_adam_mini_optimizer(
    model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
) -> "torch.optim.Optimizer":
    from adam_mini import Adam_mini  # type: ignore
@@ -398,7 +404,7 @@ def _create_adam_mini_optimizer(
def create_custom_optimizer(
    model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
    if finetuning_args.use_galore:
@@ -415,7 +421,7 @@ def create_custom_optimizer(
def create_custom_scheduler(
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
    num_training_steps: int,
    optimizer: Optional["torch.optim.Optimizer"] = None,
) -> None:
@@ -499,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
        config={"Framework": "🦙LlamaFactory"},
    )
    return swanlab_callback
+
+
+def get_ray_trainer(
+    training_function: Callable,
+    train_loop_config: Dict[str, Any],
+    ray_args: "RayArguments",
+) -> "TorchTrainer":
+    if not ray_args.use_ray:
+        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
+
+    trainer = TorchTrainer(
+        training_function,
+        train_loop_config=train_loop_config,
+        scaling_config=ScalingConfig(
+            num_workers=ray_args.ray_num_workers,
+            resources_per_worker=ray_args.resources_per_worker,
+            placement_strategy=ray_args.placement_strategy,
+            use_gpu=True,
+        ),
+        run_config=RunConfig(
+            name=ray_args.ray_run_name,
+            storage_path=Path("./saves").absolute().as_posix(),
+        ),
+    )
+    return trainer
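
`get_ray_trainer` is a thin wrapper that turns the parsed `RayArguments` into Ray Train's `ScalingConfig`/`RunConfig` and hands the per-worker training function to a `TorchTrainer`. A hedged usage sketch on a machine with GPUs available to Ray (the tiny training function here is illustrative only, not the one LLaMA-Factory uses):

```python
import os

os.environ["USE_RAY"] = "1"  # get_ray_trainer() refuses to run without this flag

from llamafactory.hparams import get_ray_args
from llamafactory.train.trainer_utils import get_ray_trainer

def print_worker_config(config):
    # Runs once on every Ray worker; `config` is the train_loop_config below.
    print("worker received keys:", sorted(config))

ray_args = get_ray_args({"ray_num_workers": 2, "resources_per_worker": {"GPU": 1}})
trainer = get_ray_trainer(
    training_function=print_worker_config,
    train_loop_config={"args": {"stage": "sft"}, "callbacks": []},
    ray_args=ray_args,
)
result = trainer.fit()  # blocks until all workers have finished
```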

View File

@@ -22,7 +22,8 @@ from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..hparams import get_infer_args, get_train_args
+from ..extras.packages import is_ray_available
+from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
@@ -31,7 +32,11 @@ from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
-from .trainer_utils import get_swanlab_callback
+from .trainer_utils import get_ray_trainer, get_swanlab_callback


+if is_ray_available():
+    from ray.train.huggingface.transformers import RayTrainReportCallback
+
+
if TYPE_CHECKING:
@@ -41,10 +46,12 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)


-def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
-    callbacks.append(LogCallback())
+def _training_function(config: Dict[str, Any]) -> None:
+    args = config.get("args")
+    callbacks: List[Any] = config.get("callbacks")
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
+    callbacks.append(LogCallback())

    if finetuning_args.pissa_convert:
        callbacks.append(PissaConvertCallback())
@@ -69,6 +76,22 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
        raise ValueError(f"Unknown task: {finetuning_args.stage}.")


+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
+    args = read_args(args)
+    ray_args = get_ray_args(args)
+    callbacks = callbacks or []
+    if ray_args.use_ray:
+        callbacks.append(RayTrainReportCallback())
+        trainer = get_ray_trainer(
+            training_function=_training_function,
+            train_loop_config={"args": args, "callbacks": callbacks},
+            ray_args=ray_args,
+        )
+        trainer.fit()
+    else:
+        _training_function(config={"args": args, "callbacks": callbacks})


def export_model(args: Optional[Dict[str, Any]] = None) -> None:
    model_args, data_args, finetuning_args, _ = get_infer_args(args)
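
With this change `run_exp` becomes a thin dispatcher: it pre-parses only the Ray fields from the raw args, then either calls `_training_function` in-process or wraps it in Ray's `TorchTrainer` so each worker re-parses the full config. A hedged sketch of driving it from Python instead of the CLI (the argument values are illustrative and mirror the example YAML above):

```python
import os

# Equivalent to prefixing the CLI command with USE_RAY=1.
os.environ["USE_RAY"] = "1"

from llamafactory.train.tuner import run_exp

run_exp(
    args={
        "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
        "stage": "sft",
        "do_train": True,
        "finetuning_type": "lora",
        "dataset": "identity,alpaca_en_demo",
        "template": "llama3",
        "output_dir": "tmp_dir",
        "ray_run_name": "llama3_8b_sft_lora",
        "ray_num_workers": 4,
        "resources_per_worker": {"GPU": 1},
    }
)
```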

View File

@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
from transformers.trainer import TRAINING_ARGS_NAME

from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_gpu_or_npu_available, torch_gc
+from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
from .locales import ALERTS, LOCALES
@@ -394,12 +394,12 @@ class Runner:
                continue

        if self.do_train:
-            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
+            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
                finish_info = ALERTS["info_finished"][lang]
            else:
                finish_info = ALERTS["err_failed"][lang]
        else:
-            if os.path.exists(os.path.join(output_path, "all_results.json")):
+            if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
                finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
            else:
                finish_info = ALERTS["err_failed"][lang]