refactor ray integration, support save ckpt

Former-commit-id: d8cac6f54663e6cffeddf2c65e3da454e7b86a75
hiyouga 2025-01-07 08:54:41 +00:00
parent bba52e258e
commit b4174021d6
18 changed files with 215 additions and 161 deletions

View File

@@ -12,6 +12,7 @@ FORCE_CHECK_IMPORTS=
 LLAMAFACTORY_VERBOSITY=
 USE_MODELSCOPE_HUB=
 USE_OPENMIND_HUB=
+USE_RAY=
 RECORD_VRAM=
 # torchrun
 FORCE_TORCHRUN=

View File

@@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
 ```

+#### Supervised Fine-Tuning with Ray on 4 GPUs
+
+```bash
+USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
+```
+
 ### QLoRA Fine-Tuning

 #### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)

View File

@@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
 ```

+#### Fine-Tuning on 4 GPUs with Ray
+
+```bash
+USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
+```
+
 ### QLoRA Fine-Tuning

 #### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)

View File

@@ -9,7 +9,6 @@ finetuning_type: lora
 lora_target: all

 ### dataset
-dataset_dir: /home/ray/default/LLaMA-Factory/data/
 dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 2048
@@ -39,10 +38,3 @@ val_size: 0.1
 per_device_eval_batch_size: 1
 eval_strategy: steps
 eval_steps: 500
-
-### ray setup
-resources_per_worker:
-  GPU: 1
-num_workers: 4
-# placement_strategy: ...

View File

@@ -0,0 +1,48 @@
+### model
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct  # or use local absolute path
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: lora
+lora_target: all
+
+### dataset
+dataset: identity,alpaca_en_demo
+dataset_dir: REMOTE:llamafactory/demo_data  # or use local absolute path
+template: llama3
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+### output
+output_dir: tmp_dir
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+val_size: 0.1
+per_device_eval_batch_size: 1
+eval_strategy: steps
+eval_steps: 500
+
+### ray
+ray_run_name: llama3_8b_sft_lora
+ray_num_workers: 4  # number of GPUs to use
+resources_per_worker:
+  GPU: 1
+placement_strategy: PACK

View File

@@ -24,8 +24,7 @@ from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
 from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.misc import get_device_count
-from .integrations.ray.ray_utils import should_use_ray
+from .extras.misc import get_device_count, use_ray
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
@@ -88,8 +87,7 @@ def main():
         export_model()
     elif command == Command.TRAIN:
         force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
-        use_ray = should_use_ray()
-        if force_torchrun or (get_device_count() > 1 and not use_ray):
+        if force_torchrun or (get_device_count() > 1 and not use_ray()):
             master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
             master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
             logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")

View File

@@ -229,7 +229,7 @@ def skip_check_imports() -> None:
     r"""
     Avoids flash attention import error in custom model files.
     """
-    if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
+    if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
         transformers.dynamic_module_utils.check_imports = get_relative_imports
@@ -275,8 +275,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
 def use_modelscope() -> bool:
-    return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
+    return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]


 def use_openmind() -> bool:
-    return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
+    return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]


+def use_ray() -> bool:
+    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]

View File

@@ -62,6 +62,10 @@ def is_pillow_available():
     return _is_package_available("PIL")


+def is_ray_available():
+    return _is_package_available("ray")
+
+
 def is_requests_available():
     return _is_package_available("requests")

View File

@@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
-from .parser import get_eval_args, get_infer_args, get_train_args
+from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args
+from .training_args import RayArguments, TrainingArguments


 __all__ = [
@@ -26,7 +27,11 @@ __all__ = [
     "FinetuningArguments",
     "GeneratingArguments",
     "ModelArguments",
+    "RayArguments",
+    "TrainingArguments",
     "get_eval_args",
     "get_infer_args",
+    "get_ray_args",
     "get_train_args",
+    "read_args",
 ]

View File

@@ -19,12 +19,12 @@ import json
 import os
 import sys
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import transformers
 import yaml
-from transformers import HfArgumentParser, Seq2SeqTrainingArguments
+from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.training_args import ParallelMode
@@ -34,12 +34,12 @@ from transformers.utils.versions import require_version
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
 from ..extras.misc import check_dependencies, get_current_device
-from ..integrations.ray.ray_train_args import RayTrainArguments
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
+from .training_args import RayArguments, TrainingArguments

 logger = logging.get_logger(__name__)

@@ -47,60 +47,41 @@ logger = logging.get_logger(__name__)
 check_dependencies()

-_TRAIN_ARGS = [
-    ModelArguments,
-    DataArguments,
-    Seq2SeqTrainingArguments,
-    FinetuningArguments,
-    GeneratingArguments,
-    RayTrainArguments,
-]
-_TRAIN_CLS = Tuple[
-    ModelArguments,
-    DataArguments,
-    Seq2SeqTrainingArguments,
-    FinetuningArguments,
-    GeneratingArguments,
-    RayTrainArguments,
-]
+_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


-def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
     if args is not None:
         return args

     if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        # read yaml file
         return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
     elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        # read json file
         return json.loads(Path(sys.argv[1]).absolute().read_text())
     else:
-        return {}
+        return sys.argv[1:]


 def _parse_args(
-    parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False
+    parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
 ) -> Tuple[Any]:
-    args_dict = _read_args(args)
+    args = read_args(args)
+    if isinstance(args, dict):
+        return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)

-    if args_dict:
-        return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
-    else:
-        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(
-            args=args_dict, return_remaining_strings=True
-        )
+    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)

     if unknown_args:
         print(parser.format_help())
         print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

     return (*parsed_args,)


 def _set_transformers_logging() -> None:
@@ -141,7 +122,7 @@ def _verify_model_args(
 def _check_extra_dependencies(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
-    training_args: Optional["Seq2SeqTrainingArguments"] = None,
+    training_args: Optional["TrainingArguments"] = None,
 ) -> None:
     if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
@@ -177,31 +158,29 @@ def _check_extra_dependencies(
         require_version("rouge_chinese", "To fix: pip install rouge-chinese")


-def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
     return _parse_args(parser, args)


-def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
     return _parse_args(parser, args)


-def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
     parser = HfArgumentParser(_EVAL_ARGS)
     return _parse_args(parser, args)


-def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
-    parser = HfArgumentParser(RayTrainArguments)
-    ray_args = _parse_args(parser, args, allow_extra_keys=True)[0]
+def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
+    parser = HfArgumentParser(RayArguments)
+    (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
+    if ray_args.use_ray:
+        require_version("ray", "To fix: pip install ray")
     return ray_args


-def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
-    model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
+def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
+    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)

     # Setup logging
     if training_args.should_log:
@@ -410,7 +389,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     return model_args, data_args, training_args, finetuning_args, generating_args


-def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
     _set_transformers_logging()
@@ -443,7 +422,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     return model_args, data_args, finetuning_args, generating_args


-def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
     _set_transformers_logging()
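The reworked `read_args` now has three input paths instead of two. A minimal sketch of the contract, assuming the `llamafactory.hparams` re-export added above:

```python
from llamafactory.hparams import read_args

# 1. An explicit dict (programmatic use) is passed through untouched.
assert read_args({"stage": "sft"}) == {"stage": "sft"}

# 2. A lone .yaml/.yml or .json path in sys.argv is loaded from disk.
# 3. Anything else now returns the raw CLI tokens (sys.argv[1:]) so that
#    HfArgumentParser.parse_args_into_dataclasses can consume them, where
#    the old _read_args simply returned an empty dict.
```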

View File

@@ -0,0 +1,48 @@
+import json
+from dataclasses import dataclass, field
+from typing import Literal, Optional, Union
+
+from transformers import Seq2SeqTrainingArguments
+from transformers.training_args import _convert_str_dict
+
+from ..extras.misc import use_ray
+
+
+@dataclass
+class RayArguments:
+    r"""
+    Arguments pertaining to the Ray training.
+    """
+
+    ray_run_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "The training results will be saved at `saves/ray_run_name`."},
+    )
+    ray_num_workers: int = field(
+        default=1,
+        metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
+    )
+    resources_per_worker: Union[dict, str] = field(
+        default_factory=lambda: {"GPU": 1},
+        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
+    )
+    placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
+        default="PACK",
+        metadata={"help": "The placement strategy for Ray training. Default is PACK."},
+    )
+
+    def __post_init__(self):
+        self.use_ray = use_ray()
+        if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
+            self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
+
+
+@dataclass
+class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
+    r"""
+    Arguments pertaining to the trainer.
+    """
+
+    def __post_init__(self):
+        Seq2SeqTrainingArguments.__post_init__(self)
+        RayArguments.__post_init__(self)
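Because `resources_per_worker` is typed `Union[dict, str]`, a JSON string arriving from the command line is normalized in `__post_init__`. A minimal sketch of that behavior, assuming `USE_RAY=1` is set before the dataclass is instantiated:

```python
import os

# use_ray() reads the environment inside __post_init__, so set it first.
os.environ["USE_RAY"] = "1"

from llamafactory.hparams import RayArguments

# A CLI-style JSON string is parsed into a dict via _convert_str_dict.
args = RayArguments(resources_per_worker='{"GPU": 1}')
assert args.use_ray is True
assert args.resources_per_worker == {"GPU": 1}
```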

View File

@@ -1,26 +0,0 @@
-from typing import Any, Callable, Dict
-
-from ray.train import ScalingConfig
-from ray.train.torch import TorchTrainer
-
-from .ray_train_args import RayTrainArguments
-
-
-def get_ray_trainer(
-    training_function: Callable,
-    train_loop_config: Dict[str, Any],
-    ray_args: RayTrainArguments,
-) -> TorchTrainer:
-    if not ray_args.use_ray:
-        raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")
-
-    trainer = TorchTrainer(
-        training_function,
-        train_loop_config=train_loop_config,
-        scaling_config=ScalingConfig(
-            num_workers=ray_args.num_workers,
-            resources_per_worker=ray_args.resources_per_worker,
-            use_gpu=True,
-        ),
-    )
-    return trainer

View File

@@ -1,30 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Any, Dict, Literal, Optional
-
-from .ray_utils import should_use_ray
-
-
-@dataclass
-class RayTrainArguments:
-    r"""
-    Arguments pertaining to the Ray training.
-    """
-
-    resources_per_worker: Optional[Dict[str, Any]] = field(
-        default_factory=lambda: {"GPU": 1},
-        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
-    )
-    num_workers: Optional[int] = field(
-        default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}
-    )
-    placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(
-        default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}
-    )
-
-    @property
-    def use_ray(self) -> bool:
-        """
-        Always returns the value from the environment variable check.
-        This prevents manual setting of use_ray.
-        """
-        return should_use_ray()

View File

@@ -1,5 +0,0 @@
-import os
-
-
-def should_use_ray():
-    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]

View File

@@ -18,7 +18,8 @@
 # limitations under the License.

 from collections.abc import Mapping
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import Trainer
@@ -31,7 +32,7 @@ from typing_extensions import override

 from ..extras import logging
 from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available
+from ..extras.packages import is_galore_available, is_ray_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -40,11 +41,16 @@ if is_galore_available():
     from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore


+if is_ray_available():
+    from ray.train import RunConfig, ScalingConfig
+    from ray.train.torch import TorchTrainer
+
+
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback
+    from transformers import PreTrainedModel, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import DataArguments
+    from ..hparams import DataArguments, RayArguments, TrainingArguments


 logger = logging.get_logger(__name__)

@@ -75,7 +81,7 @@ def create_modelcard_and_push(
     trainer: "Trainer",
     model_args: "ModelArguments",
     data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> None:
     kwargs = {
@@ -188,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
 def _create_galore_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
@@ -272,7 +278,7 @@ def _create_galore_optimizer(
 def _create_loraplus_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     default_lr = training_args.learning_rate
@@ -312,7 +318,7 @@ def _create_loraplus_optimizer(
 def _create_badam_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     decay_params, nodecay_params = [], []
@@ -373,7 +379,7 @@ def _create_badam_optimizer(
 def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
 ) -> "torch.optim.Optimizer":
     from adam_mini import Adam_mini  # type: ignore
@@ -398,7 +404,7 @@ def _create_adam_mini_optimizer(
 def create_custom_optimizer(
     model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
     if finetuning_args.use_galore:
@@ -415,7 +421,7 @@ def create_custom_optimizer(
 def create_custom_scheduler(
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
     num_training_steps: int,
     optimizer: Optional["torch.optim.Optimizer"] = None,
 ) -> None:
@@ -499,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
         config={"Framework": "🦙LlamaFactory"},
     )
     return swanlab_callback
+
+
+def get_ray_trainer(
+    training_function: Callable,
+    train_loop_config: Dict[str, Any],
+    ray_args: "RayArguments",
+) -> "TorchTrainer":
+    if not ray_args.use_ray:
+        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
+
+    trainer = TorchTrainer(
+        training_function,
+        train_loop_config=train_loop_config,
+        scaling_config=ScalingConfig(
+            num_workers=ray_args.ray_num_workers,
+            resources_per_worker=ray_args.resources_per_worker,
+            placement_strategy=ray_args.placement_strategy,
+            use_gpu=True,
+        ),
+        run_config=RunConfig(
+            name=ray_args.ray_run_name,
+            storage_path=Path("./saves").absolute().as_posix(),
+        ),
+    )
+    return trainer
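This `RunConfig` is what backs the "save ckpt" part of the commit: Ray Train persists run output under `<storage_path>/<name>`. A comment-only sketch of the expected layout under Ray 2.x conventions (directory names are illustrative, not taken from the diff):

```python
# Expected on-disk layout after trainer.fit(), assuming
# ray_run_name="llama3_8b_sft_lora" as in the example config:
#
#   saves/llama3_8b_sft_lora/      # RunConfig(storage_path) / RunConfig(name)
#       TorchTrainer_<id>/         # per-run trial directory created by Ray
#           checkpoint_000000/     # reported via RayTrainReportCallback
```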

View File

@@ -22,8 +22,8 @@ from transformers import PreTrainedModel
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..hparams import get_infer_args, get_train_args
-from ..hparams.parser import _parse_ray_args, _read_args
+from ..extras.packages import is_ray_available
+from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo
@@ -32,7 +32,11 @@ from .ppo import run_ppo
 from .pt import run_pt
 from .rm import run_rm
 from .sft import run_sft
-from .trainer_utils import get_swanlab_callback
+from .trainer_utils import get_ray_trainer, get_swanlab_callback
+
+if is_ray_available():
+    from ray.train.huggingface.transformers import RayTrainReportCallback

 if TYPE_CHECKING:
@@ -43,10 +47,8 @@ logger = logging.get_logger(__name__)
 def training_function(config: Dict[str, Any]) -> None:
-    args = config.get("args", None)
-    callbacks = config.get("callbacks", [])
-    callbacks.append(LogCallback())
+    args = config.get("args")
+    callbacks: List[Any] = config.get("callbacks")
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

     if finetuning_args.pissa_convert:
@@ -73,31 +75,22 @@ def training_function(config: Dict[str, Any]) -> None:
         raise ValueError(f"Unknown task: {finetuning_args.stage}.")


-def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
-    args_dict = _read_args(args)
-    ray_args = _parse_ray_args(args_dict)
+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
+    callbacks = callbacks or []
+    callbacks.append(LogCallback())
+    args = read_args(args)
+    ray_args = get_ray_args(args)
     if ray_args.use_ray:
-        # Import lazily to avoid ray not installed error
-        from ..integrations.ray.ray_train import get_ray_trainer
-
-        # Initialize ray trainer
+        callbacks.append(RayTrainReportCallback())
         trainer = get_ray_trainer(
             training_function=training_function,
-            train_loop_config={
-                "args": args_dict,
-                "callbacks": callbacks,
-            },
+            train_loop_config={"args": args, "callbacks": callbacks},
             ray_args=ray_args,
         )
         trainer.fit()
     else:
-        training_function(
-            config={
-                "args": args_dict,
-                "callbacks": callbacks,
-            }
-        )
+        training_function(config={"args": args, "callbacks": callbacks})


 def export_model(args: Optional[Dict[str, Any]] = None) -> None:
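Finally, a usage sketch of the refactored entry point. The argument keys come from the example YAML above; `USE_RAY` picks the branch. This is illustrative, not part of the diff:

```python
import os

# "1"/"true" -> run_exp builds a Ray TorchTrainer; unset/"0" -> in-process run.
os.environ.setdefault("USE_RAY", "1")

from llamafactory.train.tuner import run_exp

run_exp(
    args={
        "stage": "sft",
        "do_train": True,
        "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
        "finetuning_type": "lora",
        "lora_target": "all",
        "dataset": "identity",
        "template": "llama3",
        "output_dir": "tmp_dir",
        # Ray knobs are part of TrainingArguments now, so they parse either way.
        "ray_run_name": "llama3_8b_sft_lora",
        "ray_num_workers": 4,
    }
)
```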