refactor ray integration, support saving checkpoints

Former-commit-id: d8cac6f54663e6cffeddf2c65e3da454e7b86a75
This commit is contained in:
hiyouga 2025-01-07 08:54:41 +00:00
parent bba52e258e
commit b4174021d6
18 changed files with 215 additions and 161 deletions

View File

@@ -12,6 +12,7 @@ FORCE_CHECK_IMPORTS=
LLAMAFACTORY_VERBOSITY=
USE_MODELSCOPE_HUB=
USE_OPENMIND_HUB=
USE_RAY=
RECORD_VRAM=
# torchrun
FORCE_TORCHRUN=

View File

@@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```
#### Supervised Fine-Tuning with Ray on 4 GPUs
```bash
USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
```
### QLoRA Fine-Tuning
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)

View File

@@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```
#### Fine-Tuning with Ray on 4 GPUs
```bash
USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
```
### QLoRA Fine-Tuning
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)

View File

@@ -9,7 +9,6 @@ finetuning_type: lora
lora_target: all
### dataset
dataset_dir: /home/ray/default/LLaMA-Factory/data/
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
@@ -39,10 +38,3 @@ val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
### ray setup
resources_per_worker:
GPU: 1
num_workers: 4
# placement_strategy: ...

View File

@@ -0,0 +1,48 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
### dataset
dataset: identity,alpaca_en_demo
dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: tmp_dir
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
### ray
ray_run_name: llama3_8b_sft_lora
ray_num_workers: 4 # number of GPUs to use
resources_per_worker:
GPU: 1
placement_strategy: PACK
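For reference, the `### ray` keys above are parsed into `RayArguments` (see the new `training_args.py` below) and forwarded to Ray Train. A minimal sketch of the resulting scaling configuration, assuming `ray` is installed; the values mirror this example config:
```python
# Sketch: how the `### ray` YAML block maps onto Ray Train's ScalingConfig.
from ray.train import ScalingConfig

scaling_config = ScalingConfig(
    num_workers=4,                    # ray_num_workers
    resources_per_worker={"GPU": 1},  # resources_per_worker
    placement_strategy="PACK",        # placement_strategy
    use_gpu=True,                     # always set by get_ray_trainer
)
```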

View File

@@ -24,8 +24,7 @@ from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import get_device_count
from .integrations.ray.ray_utils import should_use_ray
from .extras.misc import get_device_count, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
@@ -88,8 +87,7 @@ def main():
export_model()
elif command == Command.TRAIN:
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
use_ray = should_use_ray()
if force_torchrun or (get_device_count() > 1 and not use_ray):
if force_torchrun or (get_device_count() > 1 and not use_ray()):
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
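In other words, `USE_RAY=1` now bypasses the torchrun launcher entirely and lets Ray schedule the workers. A condensed sketch of the decision shown in this hunk (the helper name is hypothetical):
```python
# Sketch of the launch decision in cli.py after this commit: torchrun is
# spawned for multi-device runs unless Ray is handling worker scheduling.
import os

def should_launch_torchrun(device_count: int) -> bool:
    force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
    use_ray = os.getenv("USE_RAY", "0").lower() in ["true", "1"]
    return force_torchrun or (device_count > 1 and not use_ray)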

View File

@@ -229,7 +229,7 @@ def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports
@@ -275,8 +275,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
def use_modelscope() -> bool:
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
def use_openmind() -> bool:
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
def use_ray() -> bool:
return os.getenv("USE_RAY", "0").lower() in ["true", "1"]

View File

@@ -62,6 +62,10 @@ def is_pillow_available():
return _is_package_available("PIL")
def is_ray_available():
return _is_package_available("ray")
def is_requests_available():
return _is_package_available("requests")

View File

@@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .parser import get_eval_args, get_infer_args, get_train_args
from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args
from .training_args import RayArguments, TrainingArguments
__all__ = [
@@ -26,7 +27,11 @@ __all__ = [
"FinetuningArguments",
"GeneratingArguments",
"ModelArguments",
"RayArguments",
"TrainingArguments",
"get_eval_args",
"get_infer_args",
"get_ray_args",
"get_train_args",
"read_args",
]

View File

@@ -19,12 +19,12 @@ import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import transformers
import yaml
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers import HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
@@ -34,12 +34,12 @@ from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES
from ..extras.misc import check_dependencies, get_current_device
from ..integrations.ray.ray_train_args import RayTrainArguments
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .training_args import RayArguments, TrainingArguments
logger = logging.get_logger(__name__)
@@ -47,53 +47,34 @@ logger = logging.get_logger(__name__)
check_dependencies()
_TRAIN_ARGS = [
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
RayTrainArguments,
]
_TRAIN_CLS = Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
RayTrainArguments,
]
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
if args is not None:
return args
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
# read yaml file
return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# read json file
return json.loads(Path(sys.argv[1]).absolute().read_text())
else:
return {}
return sys.argv[1:]
def _parse_args(
parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
) -> Tuple[Any]:
args_dict = _read_args(args)
args = read_args(args)
if isinstance(args, dict):
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
if args_dict:
return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
else:
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(
args=args_dict, return_remaining_strings=True
)
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
if unknown_args:
print(parser.format_help())
@@ -141,7 +122,7 @@ def _verify_model_args(
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["Seq2SeqTrainingArguments"] = None,
training_args: Optional["TrainingArguments"] = None,
) -> None:
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
@@ -177,31 +158,29 @@ def _check_extra_dependencies(
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return _parse_args(parser, args)
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return _parse_args(parser, args)
def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
parser = HfArgumentParser(RayTrainArguments)
ray_args = _parse_args(parser, args, allow_extra_keys=True)[0]
if ray_args.use_ray:
require_version("ray", "To fix: pip install ray")
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
if training_args.should_log:
@@ -410,7 +389,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging()
@@ -443,7 +422,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging()
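Since `read_args` now accepts either a dict or an argv-style list and falls back to `sys.argv`, the same entry points work programmatically and from the CLI. A short sketch of the three input modes (paths and flags are illustrative; the YAML file must exist on disk):
```python
# Sketch: read_args() passes dicts through unchanged, loads a YAML/JSON
# config path from sys.argv, and otherwise returns the raw argument list
# for HfArgumentParser.parse_args_into_dataclasses.
import sys

from llamafactory.hparams import read_args

assert read_args({"stage": "sft"}) == {"stage": "sft"}  # dict passthrough

sys.argv = ["llamafactory-cli", "examples/train_lora/llama3_lora_sft.yaml"]
config = read_args()  # dict loaded from the YAML file

sys.argv = ["llamafactory-cli", "--stage", "sft", "--do_train", "true"]
argv = read_args()  # ["--stage", "sft", "--do_train", "true"]
```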

View File

@@ -0,0 +1,48 @@
import json
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
from ..extras.misc import use_ray
@dataclass
class RayArguments:
r"""
Arguments pertaining to the Ray training.
"""
ray_run_name: Optional[str] = field(
default=None,
metadata={"help": "The training results will be saved at `saves/ray_run_name`."},
)
ray_num_workers: int = field(
default=1,
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
)
resources_per_worker: Union[dict, str] = field(
default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
)
placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
default="PACK",
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
)
def __post_init__(self):
self.use_ray = use_ray()
if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
@dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
r"""
Arguments pertaining to the trainer.
"""
def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self)
RayArguments.__post_init__(self)
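With this layout, a single `HfArgumentParser` pass populates both the Hugging Face training options and the Ray fields, and `get_ray_args` can cheaply re-parse just `RayArguments` with `allow_extra_keys=True` before the full parse. A small sketch of the JSON-string normalization (values are illustrative):
```python
# Sketch: resources_per_worker may arrive as a JSON string (e.g. from the
# command line); __post_init__ normalizes it into a dict via _convert_str_dict.
from llamafactory.hparams import RayArguments

ray_args = RayArguments(
    ray_run_name="llama3_8b_sft_lora",
    ray_num_workers=4,
    resources_per_worker='{"GPU": 1}',
)
print(ray_args.resources_per_worker)  # {'GPU': 1}
print(ray_args.use_ray)               # True only when USE_RAY=1 is set
```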

View File

@@ -1,26 +0,0 @@
from typing import Any, Callable, Dict
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from .ray_train_args import RayTrainArguments
def get_ray_trainer(
training_function: Callable,
train_loop_config: Dict[str, Any],
ray_args: RayTrainArguments,
) -> TorchTrainer:
if not ray_args.use_ray:
raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")
trainer = TorchTrainer(
training_function,
train_loop_config=train_loop_config,
scaling_config=ScalingConfig(
num_workers=ray_args.num_workers,
resources_per_worker=ray_args.resources_per_worker,
use_gpu=True,
),
)
return trainer

View File

@@ -1,30 +0,0 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
from .ray_utils import should_use_ray
@dataclass
class RayTrainArguments:
r"""
Arguments pertaining to the Ray training.
"""
resources_per_worker: Optional[Dict[str, Any]] = field(
default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
)
num_workers: Optional[int] = field(
default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}
)
placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(
default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}
)
@property
def use_ray(self) -> bool:
"""
Always returns the value from the environment variable check.
This prevents manual setting of use_ray.
"""
return should_use_ray()

View File

@@ -1,5 +0,0 @@
import os
def should_use_ray():
return os.getenv("USE_RAY", "0").lower() in ["true", "1"]

View File

@@ -18,7 +18,8 @@
# limitations under the License.
from collections.abc import Mapping
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
@@ -31,7 +32,7 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX
from ..extras.packages import is_galore_available
from ..extras.packages import is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -40,11 +41,16 @@ if is_galore_available():
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore
if is_ray_available():
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
if TYPE_CHECKING:
from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback
from transformers import PreTrainedModel, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments
from ..hparams import DataArguments, RayArguments, TrainingArguments
logger = logging.get_logger(__name__)
@@ -75,7 +81,7 @@ def create_modelcard_and_push(
trainer: "Trainer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> None:
kwargs = {
@@ -188,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
def _create_galore_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
@@ -272,7 +278,7 @@ def _create_galore_optimizer(
def _create_loraplus_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
default_lr = training_args.learning_rate
@@ -312,7 +318,7 @@ def _create_loraplus_optimizer(
def _create_badam_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
decay_params, nodecay_params = [], []
@@ -373,7 +379,7 @@ def _create_badam_optimizer(
def _create_adam_mini_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
) -> "torch.optim.Optimizer":
from adam_mini import Adam_mini # type: ignore
@@ -398,7 +404,7 @@ def _create_adam_mini_optimizer(
def create_custom_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
if finetuning_args.use_galore:
@@ -415,7 +421,7 @@ def create_custom_optimizer(
def create_custom_scheduler(
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
num_training_steps: int,
optimizer: Optional["torch.optim.Optimizer"] = None,
) -> None:
@@ -499,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
config={"Framework": "🦙LlamaFactory"},
)
return swanlab_callback
def get_ray_trainer(
training_function: Callable,
train_loop_config: Dict[str, Any],
ray_args: "RayArguments",
) -> "TorchTrainer":
if not ray_args.use_ray:
raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
trainer = TorchTrainer(
training_function,
train_loop_config=train_loop_config,
scaling_config=ScalingConfig(
num_workers=ray_args.ray_num_workers,
resources_per_worker=ray_args.resources_per_worker,
placement_strategy=ray_args.placement_strategy,
use_gpu=True,
),
run_config=RunConfig(
name=ray_args.ray_run_name,
storage_path=Path("./saves").absolute().as_posix(),
),
)
return trainer
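The `RunConfig` is what enables checkpoint saving: Ray persists run artifacts, including checkpoints reported from the workers, under `./saves/<ray_run_name>`. A usage sketch, assuming `USE_RAY=1` is set before `RayArguments` is constructed (the `args` dict is trimmed for brevity; a real run needs the full config shown earlier):
```python
# Sketch: build the Ray trainer and run it; the returned Result points at
# the latest checkpoint persisted under ./saves/<ray_run_name>.
import os

os.environ["USE_RAY"] = "1"  # must be set before RayArguments is created

from llamafactory.hparams import RayArguments
from llamafactory.train.tuner import training_function

ray_args = RayArguments(ray_run_name="llama3_8b_sft_lora", ray_num_workers=4)
trainer = get_ray_trainer(
    training_function=training_function,  # the train loop from tuner.py
    train_loop_config={"args": {"stage": "sft"}, "callbacks": []},  # illustrative
    ray_args=ray_args,
)
result = trainer.fit()
print(result.checkpoint)  # ray.train.Checkpoint for the last reported save
```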

View File

@@ -22,8 +22,8 @@ from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..hparams import get_infer_args, get_train_args
from ..hparams.parser import _parse_ray_args, _read_args
from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
@@ -32,7 +32,11 @@ from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .trainer_utils import get_swanlab_callback
from .trainer_utils import get_ray_trainer, get_swanlab_callback
if is_ray_available():
from ray.train.huggingface.transformers import RayTrainReportCallback
if TYPE_CHECKING:
@@ -43,10 +47,8 @@ logger = logging.get_logger(__name__)
def training_function(config: Dict[str, Any]) -> None:
args = config.get("args", None)
callbacks = config.get("callbacks", [])
callbacks.append(LogCallback())
args = config.get("args")
callbacks: List[Any] = config.get("callbacks")
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
if finetuning_args.pissa_convert:
@@ -73,31 +75,22 @@ def training_function(config: Dict[str, Any]) -> None:
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
args_dict = _read_args(args)
ray_args = _parse_ray_args(args_dict)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
callbacks = callbacks or []
callbacks.append(LogCallback())
args = read_args(args)
ray_args = get_ray_args(args)
if ray_args.use_ray:
# Import lazily to avoid ray not installed error
from ..integrations.ray.ray_train import get_ray_trainer
# Initialize ray trainer
callbacks.append(RayTrainReportCallback())
trainer = get_ray_trainer(
training_function=training_function,
train_loop_config={
"args": args_dict,
"callbacks": callbacks,
},
train_loop_config={"args": args, "callbacks": callbacks},
ray_args=ray_args,
)
trainer.fit()
else:
training_function(
config={
"args": args_dict,
"callbacks": callbacks,
}
)
training_function(config={"args": args, "callbacks": callbacks})
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
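End to end, `run_exp` now reads the arguments once, parses the Ray flags first, and either fits a `TorchTrainer` (with `RayTrainReportCallback` reporting Hugging Face checkpoints back to Ray) or calls `training_function` directly. A usage sketch with illustrative values taken from the example config above:
```python
# Sketch: the same entry point covers both paths. With USE_RAY unset this
# trains locally; with USE_RAY=1 it schedules ray_num_workers Ray workers.
from llamafactory.train.tuner import run_exp

run_exp({
    "stage": "sft",
    "do_train": True,
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
    "dataset": "identity,alpaca_en_demo",
    "template": "llama3",
    "finetuning_type": "lora",
    "output_dir": "tmp_dir",
})
```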