drafting ray integration

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>

Former-commit-id: 163ddb680b6f84a4424a887a3b8a5d668044e87c
Kourosh Hakhamaneshi 2024-12-30 16:48:52 -08:00 committed by hiyouga
parent a0bcac80c0
commit 1217240918
9 changed files with 143 additions and 21 deletions

View File

@@ -9,6 +9,7 @@ finetuning_type: lora
 lora_target: all

 ### dataset
+dataset_dir: /home/ray/default/lf/data/
 dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 2048
@@ -38,3 +39,10 @@ val_size: 0.1
 per_device_eval_batch_size: 1
 eval_strategy: steps
 eval_steps: 500
+
+### ray setup
+resources_per_worker:
+  GPU: 1
+num_workers: 4
+# placement_strategy: ...
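For reference, a minimal sketch (not part of this commit) of how these new keys map onto Ray Train's ScalingConfig inside the new get_ray_trainer() further down; placement_strategy stays commented out here and is not yet forwarded to ScalingConfig. Assumes ray[train] is installed.

from ray.train import ScalingConfig

# num_workers / resources_per_worker come straight from the YAML keys above;
# use_gpu is currently hard-coded to True in get_ray_trainer().
scaling = ScalingConfig(
    num_workers=4,
    resources_per_worker={"GPU": 1},
    use_gpu=True,
)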

View File

@ -27,7 +27,7 @@ from .extras.env import VERSION, print_env
from .extras.misc import get_device_count from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui from .webui.interface import run_web_demo, run_web_ui
from .integrations.ray.ray_utils import should_use_ray
USAGE = ( USAGE = (
"-" * 70 "-" * 70
@ -87,7 +87,8 @@ def main():
export_model() export_model()
elif command == Command.TRAIN: elif command == Command.TRAIN:
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"] force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1: use_ray = should_use_ray()
if force_torchrun or (get_device_count() > 1 and not use_ray):
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999))) master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}") logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")

View File

@@ -19,6 +19,10 @@ import os
 import sys
 from typing import Any, Dict, Optional, Tuple

+import json
+import yaml
+from pathlib import Path
+
 import torch
 import transformers
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
@@ -37,39 +41,51 @@ from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
+from ..integrations.ray.ray_train_args import RayTrainArguments


 logger = logging.get_logger(__name__)

 check_dependencies()


-_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
-_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
+_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


-def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
+def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
     if args is not None:
-        return parser.parse_dict(args)
+        return args

     if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
-
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
-
-    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
-    if unknown_args:
-        print(parser.format_help())
-        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
-        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
-
-    return (*parsed_args,)
+        # read yaml file
+        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # read json file
+        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    else:
+        return {}
+
+
+def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]:
+    args_dict = _read_args(args)
+
+    if args_dict:
+        return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
+    else:
+        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True)
+        if unknown_args:
+            print(parser.format_help())
+            print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
+            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
+
+        return (*parsed_args,)


 def _set_transformers_logging() -> None:
@@ -161,9 +177,17 @@ def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     return _parse_args(parser, args)


+def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
+    parser = HfArgumentParser(RayTrainArguments)
+    ray_args = _parse_args(parser, args, allow_extra_keys=True)[0]
+    if ray_args.use_ray:
+        require_version("ray", "To fix: pip install ray")
+    return ray_args
+
+
 def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
-    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+    model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)

     # Setup logging
     if training_args.should_log:
         _set_transformers_logging()
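A small sketch (not part of the commit) of why the Ray-only pre-parse uses allow_extra_keys=True: RayTrainArguments declares only the ray fields, yet _parse_ray_args receives the full training dict, so every non-ray key has to be tolerated rather than rejected. The dataclass below is a simplified stand-in, not the real RayTrainArguments.

from dataclasses import dataclass
from transformers import HfArgumentParser

@dataclass
class RaySketchArguments:
    # simplified stand-in for RayTrainArguments
    num_workers: int = 1
    placement_strategy: str = "PACK"

args_dict = {"model_name_or_path": "some/model", "stage": "sft", "num_workers": 4}
parser = HfArgumentParser(RaySketchArguments)

(ray_args,) = parser.parse_dict(args_dict, allow_extra_keys=True)
print(ray_args.num_workers)  # 4; the unrelated training keys are simply ignored

# With the default allow_extra_keys=False the same call raises ValueError,
# which is why only this ray-specific pre-parse opts into it.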

View File

@@ -0,0 +1,28 @@
+from typing import Any, Callable, Dict
+
+from ray.train.torch import TorchTrainer
+from ray.train import ScalingConfig
+
+from .ray_train_args import RayTrainArguments
+
+
+def get_ray_trainer(
+    training_function: Callable,
+    train_loop_config: Dict[str, Any],
+    ray_args: RayTrainArguments,
+) -> TorchTrainer:
+    if not ray_args.use_ray:
+        raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")
+
+    trainer = TorchTrainer(
+        training_function,
+        train_loop_config=train_loop_config,
+        scaling_config=ScalingConfig(
+            num_workers=ray_args.num_workers,
+            resources_per_worker=ray_args.resources_per_worker,
+            use_gpu=True,
+        ),
+    )
+    return trainer
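For orientation, a runnable sketch (not part of the commit) of what happens once TorchTrainer takes over: it starts num_workers processes, initializes torch.distributed in each, and calls the loop function with the same train_loop_config dict, which is exactly how training_function in tuner.py is invoked. demo_loop is a placeholder; assumes ray[train] is installed.

import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def demo_loop(config):
    # stand-in for training_function: every worker receives the same config dict
    ctx = ray.train.get_context()
    print(f"worker {ctx.get_world_rank()}/{ctx.get_world_size()} got keys: {list(config)}")

trainer = TorchTrainer(
    demo_loop,
    train_loop_config={"args": {}, "callbacks": []},
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),  # CPU-only smoke test
)
# trainer.fit()  # uncomment to run against a local or remote Ray cluster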

View File

@ -0,0 +1,22 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
from .ray_utils import should_use_ray
@dataclass
class RayTrainArguments:
r"""
Arguments pertaining to the Ray training.
"""
resources_per_worker: Optional[Dict[str, Any]] = field(default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."})
num_workers: Optional[int] = field(default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."})
placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."})
@property
def use_ray(self) -> bool:
"""
Always returns the value from the environment variable check.
This prevents manual setting of use_ray.
"""
return should_use_ray()
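A quick sketch (not part of the commit) of the behaviour this property enforces: use_ray cannot be set on the dataclass and is re-derived from the USE_RAY environment variable on every access. The import path is assumed from the relative imports in this diff.

import os
from llamafactory.integrations.ray.ray_train_args import RayTrainArguments  # path assumed

ray_args = RayTrainArguments(num_workers=4)

os.environ.pop("USE_RAY", None)
print(ray_args.use_ray)   # False: nothing in the constructor can turn Ray on

os.environ["USE_RAY"] = "1"
print(ray_args.use_ray)   # True: the property re-reads the environment each call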

View File

@@ -0,0 +1,9 @@
+import os
+
+
+def should_use_ray():
+    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]

View File

@@ -23,6 +23,7 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..hparams import get_infer_args, get_train_args
+from ..hparams.parser import _parse_ray_args, _read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 from .dpo import run_dpo
@@ -36,12 +37,14 @@ from .trainer_utils import get_swanlab_callback
 if TYPE_CHECKING:
     from transformers import TrainerCallback


 logger = logging.get_logger(__name__)


-def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
+def training_function(config: Dict[str, Any]) -> None:
+    args = config.get("args", None)
+    callbacks = config.get("callbacks", [])
     callbacks.append(LogCallback())
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
@@ -68,6 +71,33 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
     else:
         raise ValueError(f"Unknown task: {finetuning_args.stage}.")


+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
+    args_dict = _read_args(args)
+    ray_args = _parse_ray_args(args_dict)
+
+    if ray_args.use_ray:
+        # Import lazily to avoid ray not installed error
+        from ..integrations.ray.ray_train import get_ray_trainer
+
+        # Initialize ray trainer
+        trainer = get_ray_trainer(
+            training_function=training_function,
+            train_loop_config={
+                "args": args_dict,
+                "callbacks": callbacks,
+            },
+            ray_args=ray_args,
+        )
+        trainer.fit()
+    else:
+        training_function(
+            config={
+                "args": args_dict,
+                "callbacks": callbacks,
+            }
+        )
+
+
 def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     model_args, data_args, finetuning_args, _ = get_infer_args(args)
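Putting the pieces together, a hedged end-to-end sketch (not part of the commit) of launching a run through the new code path. The import path follows the relative imports in this diff; the model name, dataset and output path are placeholders, and any other required training arguments would come from the usual YAML config.

import os

os.environ["USE_RAY"] = "1"  # should_use_ray() -> True, so run_exp() hands off to Ray

from llamafactory.train.tuner import run_exp  # path assumed

run_exp(
    args={
        "stage": "sft",
        "do_train": True,
        "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
        "finetuning_type": "lora",
        "lora_target": "all",
        "dataset": "identity,alpaca_en_demo",
        "template": "llama3",
        "output_dir": "saves/llama3-lora-sft-ray",  # placeholder
        # parsed into RayTrainArguments by _parse_ray_args()
        "resources_per_worker": {"GPU": 1},
        "num_workers": 4,
    }
)
# Without USE_RAY set, the same call runs training_function() in-process, as before.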