Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)

Commit 1217240918 (parent a0bcac80c0): drafting ray integration

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Former-commit-id: 163ddb680b6f84a4424a887a3b8a5d668044e87c
Training config (YAML):

@@ -9,6 +9,7 @@ finetuning_type: lora
 lora_target: all
 
 ### dataset
+dataset_dir: /home/ray/default/lf/data/
 dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 2048

@@ -38,3 +39,10 @@ val_size: 0.1
 per_device_eval_batch_size: 1
 eval_strategy: steps
 eval_steps: 500
+
+### ray setup
+resources_per_worker:
+  GPU: 1
+num_workers: 4
+# placement_strategy: ...
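For orientation, a minimal sketch of driving this config programmatically. This is hedged: the model id and output path are placeholders, and the dict-based call relies on run_exp accepting a plain args dict (per the tuner.py change later in this diff). Note that the `### ray setup` keys alone do not enable Ray; use_ray is derived from the USE_RAY environment variable, not from the config.

import os

# USE_RAY gates everything (see ray_utils.py below); the "### ray setup" keys
# are parsed into RayTrainArguments but ignored unless this is set.
os.environ["USE_RAY"] = "1"

from llamafactory.train.tuner import run_exp  # module path per this diff

run_exp(
    args={
        "stage": "sft",
        "do_train": True,
        "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
        "finetuning_type": "lora",
        "lora_target": "all",
        "dataset": "identity,alpaca_en_demo",
        "template": "llama3",
        "output_dir": "saves/llama3-lora-ray",  # placeholder
        "num_workers": 4,
        "resources_per_worker": {"GPU": 1},
    }
)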
src/llamafactory/cli.py:

@@ -27,7 +27,7 @@ from .extras.env import VERSION, print_env
 from .extras.misc import get_device_count
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
+
+from .integrations.ray.ray_utils import should_use_ray
 
 USAGE = (
     "-" * 70

@@ -87,7 +87,8 @@ def main():
         export_model()
     elif command == Command.TRAIN:
         force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
-        if force_torchrun or get_device_count() > 1:
+        use_ray = should_use_ray()
+        if force_torchrun or (get_device_count() > 1 and not use_ray):
             master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
             master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
             logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
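The net effect of the cli.py change is a three-way launch decision. A small self-contained sketch of that decision; choose_launcher is an invented illustration, not a function in the repo:

import os

def _env_flag(name: str) -> bool:
    return os.getenv(name, "0").lower() in ["true", "1"]

def choose_launcher(device_count: int) -> str:
    """Illustrative mirror of the branch in main() after this change."""
    force_torchrun = _env_flag("FORCE_TORCHRUN")
    use_ray = _env_flag("USE_RAY")  # what should_use_ray() checks
    if force_torchrun or (device_count > 1 and not use_ray):
        return "torchrun"  # spawn distributed workers via torchrun
    # otherwise fall through to run_exp(), which handles both remaining cases
    return "ray" if use_ray else "in-process"

print(choose_launcher(device_count=4))  # "torchrun" unless USE_RAY is set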
src/llamafactory/hparams/parser.py:

@@ -19,6 +19,10 @@ import os
 import sys
 from typing import Any, Dict, Optional, Tuple
 
+import json
+import yaml
+from pathlib import Path
+
 import torch
 import transformers
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
@@ -37,39 +41,51 @@ from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
+
+from ..integrations.ray.ray_train_args import RayTrainArguments
 
 
 logger = logging.get_logger(__name__)
 
 
 check_dependencies()
 
 
-_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
-_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
+_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 
 
-def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
+def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
     if args is not None:
-        return parser.parse_dict(args)
+        return args
 
     if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
+        # read yaml file
+        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # read json file
+        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    else:
+        return {}
 
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
 
-    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]:
+    args_dict = _read_args(args)
 
-    if unknown_args:
-        print(parser.format_help())
-        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
-        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
+    if args_dict:
+        return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
+    else:
+        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True)
 
-    return (*parsed_args,)
+        if unknown_args:
+            print(parser.format_help())
+            print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
+            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
+
+        return (*parsed_args,)
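Why allow_extra_keys matters: the Ray options share one YAML/dict with all the training options, so pulling a single dataclass out of the full config has to tolerate keys it does not own. A self-contained demonstration with HfArgumentParser (DemoArgs is an invented stand-in, not a class in the repo):

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DemoArgs:  # stand-in for RayTrainArguments
    num_workers: int = field(default=1)

config = {"num_workers": 4, "learning_rate": 1.0e-4, "dataset": "alpaca_en_demo"}

# Without allow_extra_keys=True this raises, because "learning_rate" and
# "dataset" are not fields of DemoArgs.
(demo_args,) = HfArgumentParser(DemoArgs).parse_dict(config, allow_extra_keys=True)
print(demo_args.num_workers)  # 4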
@@ -161,8 +177,16 @@ def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     return _parse_args(parser, args)
 
 
+def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
+    parser = HfArgumentParser(RayTrainArguments)
+    ray_args = _parse_args(parser, args, allow_extra_keys=True)[0]
+    if ray_args.use_ray:
+        require_version("ray", "To fix: pip install ray")
+    return ray_args
+
+
 def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
-    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+    model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
 
     # Setup logging
     if training_args.should_log:
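Hypothetical usage of the new helper (the import path follows this diff; extra keys pass through because _parse_ray_args sets allow_extra_keys=True):

from llamafactory.hparams.parser import _parse_ray_args

ray_args = _parse_ray_args({"num_workers": 2, "learning_rate": 5.0e-5})
print(ray_args.num_workers)  # 2; "learning_rate" is simply ignored here
print(ray_args.use_ray)      # tracks the USE_RAY environment variable, not the dict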
src/llamafactory/integrations/__init__.py (new file, empty)
src/llamafactory/integrations/ray/__init__.py (new file, empty)

src/llamafactory/integrations/ray/ray_train.py (new file, 28 lines):
@@ -0,0 +1,28 @@
+from typing import Any, Callable, Dict
+
+from ray.train.torch import TorchTrainer
+from ray.train import ScalingConfig
+
+from .ray_train_args import RayTrainArguments
+
+
+def get_ray_trainer(
+    training_function: Callable,
+    train_loop_config: Dict[str, Any],
+    ray_args: RayTrainArguments,
+) -> TorchTrainer:
+    if not ray_args.use_ray:
+        raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")
+
+    trainer = TorchTrainer(
+        training_function,
+        train_loop_config=train_loop_config,
+        scaling_config=ScalingConfig(
+            num_workers=ray_args.num_workers,
+            resources_per_worker=ray_args.resources_per_worker,
+            use_gpu=True,
+        ),
+    )
+    return trainer
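For readers new to Ray Train, a minimal runnable sketch of the same TorchTrainer pattern, decoupled from LLaMA-Factory. It assumes `ray[train]` and torch are installed, and uses CPU workers so it runs anywhere, whereas get_ray_trainer above hardcodes use_gpu=True:

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop(config):
    # Ray runs this function once per worker, passing train_loop_config.
    print("worker received:", config["message"])

trainer = TorchTrainer(
    train_loop,
    train_loop_config={"message": "hello from ray"},
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
result = trainer.fit()  # blocks until all workers finish

One thing to note in this draft: RayTrainArguments (below) defines placement_strategy, but get_ray_trainer does not yet forward it to ScalingConfig, matching the commented-out `# placement_strategy: ...` line in the example config.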
src/llamafactory/integrations/ray/ray_train_args.py (new file, 22 lines):
@@ -0,0 +1,22 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, Literal, Optional
+
+from .ray_utils import should_use_ray
+
+
+@dataclass
+class RayTrainArguments:
+    r"""
+    Arguments pertaining to the Ray training.
+    """
+    resources_per_worker: Optional[Dict[str, Any]] = field(default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."})
+    num_workers: Optional[int] = field(default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."})
+    placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."})
+
+    @property
+    def use_ray(self) -> bool:
+        """
+        Always returns the value from the environment variable check.
+        This prevents manual setting of use_ray.
+        """
+        return should_use_ray()
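And a quick sketch of how the use_ray property behaves (import path assumed from the layout in this commit):

import os
from llamafactory.integrations.ray.ray_train_args import RayTrainArguments

os.environ.pop("USE_RAY", None)
args = RayTrainArguments(num_workers=2)
print(args.use_ray)  # False: the property re-reads the environment each time,
                     # so it cannot be toggled from the config or constructor

os.environ["USE_RAY"] = "1"
print(args.use_ray)  # True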
src/llamafactory/integrations/ray/ray_utils.py (new file, 9 lines):
@@ -0,0 +1,9 @@
+import os
+
+
+def should_use_ray():
+    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
src/llamafactory/train/tuner.py:

@@ -23,6 +23,7 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..hparams import get_infer_args, get_train_args
+from ..hparams.parser import _parse_ray_args, _read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo
@@ -40,8 +41,10 @@ if TYPE_CHECKING:
 
 logger = logging.get_logger(__name__)
 
 
-def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
+def training_function(config: Dict[str, Any]) -> None:
+    args = config.get("args", None)
+    callbacks = config.get("callbacks", [])
     callbacks.append(LogCallback())
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
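Side note on the new signature: Ray Train calls each worker's train loop as training_function(config), passing the train_loop_config dict, which is why run_exp below bundles args and callbacks into a single dict rather than keeping them as separate parameters.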
@@ -68,6 +71,33 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
     else:
         raise ValueError(f"Unknown task: {finetuning_args.stage}.")
 
+
+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
+    args_dict = _read_args(args)
+    ray_args = _parse_ray_args(args_dict)
+
+    if ray_args.use_ray:
+        # Import lazily to avoid ray not installed error
+        from ..integrations.ray.ray_train import get_ray_trainer
+
+        # Initialize ray trainer
+        trainer = get_ray_trainer(
+            training_function=training_function,
+            train_loop_config={
+                "args": args_dict,
+                "callbacks": callbacks,
+            },
+            ray_args=ray_args,
+        )
+        trainer.fit()
+    else:
+        training_function(
+            config={
+                "args": args_dict,
+                "callbacks": callbacks,
+            }
+        )
 
 
 def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     model_args, data_args, finetuning_args, _ = get_infer_args(args)