mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	refactor ray integration, support save ckpt
Former-commit-id: 2f50b27e608b2092bfceab6c6e84e6631e973ee2
This commit is contained in:
		
							parent
							
								
									4f31ad997c
								
							
						
					
					
						commit
						944a2aec4d
					
				@ -12,6 +12,7 @@ FORCE_CHECK_IMPORTS=
 | 
			
		||||
LLAMAFACTORY_VERBOSITY=
 | 
			
		||||
USE_MODELSCOPE_HUB=
 | 
			
		||||
USE_OPENMIND_HUB=
 | 
			
		||||
USE_RAY=
 | 
			
		||||
RECORD_VRAM=
 | 
			
		||||
# torchrun
 | 
			
		||||
FORCE_TORCHRUN=
 | 
			
		||||
 | 
			
		||||
@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
 | 
			
		||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### Supervised Fine-Tuning with Ray on 4 GPUs
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### QLoRA Fine-Tuning
 | 
			
		||||
 | 
			
		||||
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
 | 
			
		||||
 | 
			
		||||
@ -95,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
 | 
			
		||||
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### 使用 Ray 在 4 张 GPU 上微调
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### QLoRA 微调
 | 
			
		||||
 | 
			
		||||
#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)
 | 
			
		||||
 | 
			
		||||
@ -9,7 +9,6 @@ finetuning_type: lora
 | 
			
		||||
lora_target: all
 | 
			
		||||
 | 
			
		||||
### dataset
 | 
			
		||||
dataset_dir: /home/ray/default/LLaMA-Factory/data/
 | 
			
		||||
dataset: identity,alpaca_en_demo
 | 
			
		||||
template: llama3
 | 
			
		||||
cutoff_len: 2048
 | 
			
		||||
@ -39,10 +38,3 @@ val_size: 0.1
 | 
			
		||||
per_device_eval_batch_size: 1
 | 
			
		||||
eval_strategy: steps
 | 
			
		||||
eval_steps: 500
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### ray setup
 | 
			
		||||
resources_per_worker:
 | 
			
		||||
  GPU: 1
 | 
			
		||||
num_workers: 4
 | 
			
		||||
# placement_strategy: ...
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										48
									
								
								examples/train_lora/llama3_lora_sft_ray.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								examples/train_lora/llama3_lora_sft_ray.yaml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,48 @@
 | 
			
		||||
### model
 | 
			
		||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct  # or use local absolute path
 | 
			
		||||
trust_remote_code: true
 | 
			
		||||
 | 
			
		||||
### method
 | 
			
		||||
stage: sft
 | 
			
		||||
do_train: true
 | 
			
		||||
finetuning_type: lora
 | 
			
		||||
lora_target: all
 | 
			
		||||
 | 
			
		||||
### dataset
 | 
			
		||||
dataset: identity,alpaca_en_demo
 | 
			
		||||
dataset_dir: REMOTE:llamafactory/demo_data  # or use local absolute path
 | 
			
		||||
template: llama3
 | 
			
		||||
cutoff_len: 2048
 | 
			
		||||
max_samples: 1000
 | 
			
		||||
overwrite_cache: true
 | 
			
		||||
preprocessing_num_workers: 16
 | 
			
		||||
 | 
			
		||||
### output
 | 
			
		||||
output_dir: tmp_dir
 | 
			
		||||
logging_steps: 10
 | 
			
		||||
save_steps: 500
 | 
			
		||||
plot_loss: true
 | 
			
		||||
overwrite_output_dir: true
 | 
			
		||||
 | 
			
		||||
### train
 | 
			
		||||
per_device_train_batch_size: 1
 | 
			
		||||
gradient_accumulation_steps: 8
 | 
			
		||||
learning_rate: 1.0e-4
 | 
			
		||||
num_train_epochs: 3.0
 | 
			
		||||
lr_scheduler_type: cosine
 | 
			
		||||
warmup_ratio: 0.1
 | 
			
		||||
bf16: true
 | 
			
		||||
ddp_timeout: 180000000
 | 
			
		||||
 | 
			
		||||
### eval
 | 
			
		||||
val_size: 0.1
 | 
			
		||||
per_device_eval_batch_size: 1
 | 
			
		||||
eval_strategy: steps
 | 
			
		||||
eval_steps: 500
 | 
			
		||||
 | 
			
		||||
### ray
 | 
			
		||||
ray_run_name: llama3_8b_sft_lora
 | 
			
		||||
ray_num_workers: 4  # number of GPUs to use
 | 
			
		||||
resources_per_worker:
 | 
			
		||||
  GPU: 1
 | 
			
		||||
placement_strategy: PACK
 | 
			
		||||
@ -24,8 +24,7 @@ from .chat.chat_model import run_chat
 | 
			
		||||
from .eval.evaluator import run_eval
 | 
			
		||||
from .extras import logging
 | 
			
		||||
from .extras.env import VERSION, print_env
 | 
			
		||||
from .extras.misc import get_device_count
 | 
			
		||||
from .integrations.ray.ray_utils import should_use_ray
 | 
			
		||||
from .extras.misc import get_device_count, use_ray
 | 
			
		||||
from .train.tuner import export_model, run_exp
 | 
			
		||||
from .webui.interface import run_web_demo, run_web_ui
 | 
			
		||||
 | 
			
		||||
@ -88,8 +87,7 @@ def main():
 | 
			
		||||
        export_model()
 | 
			
		||||
    elif command == Command.TRAIN:
 | 
			
		||||
        force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
 | 
			
		||||
        use_ray = should_use_ray()
 | 
			
		||||
        if force_torchrun or (get_device_count() > 1 and not use_ray):
 | 
			
		||||
        if force_torchrun or (get_device_count() > 1 and not use_ray()):
 | 
			
		||||
            master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
 | 
			
		||||
            master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
 | 
			
		||||
            logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
 | 
			
		||||
 | 
			
		||||
@ -229,7 +229,7 @@ def skip_check_imports() -> None:
 | 
			
		||||
    r"""
 | 
			
		||||
    Avoids flash attention import error in custom model files.
 | 
			
		||||
    """
 | 
			
		||||
    if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
 | 
			
		||||
    if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
 | 
			
		||||
        transformers.dynamic_module_utils.check_imports = get_relative_imports
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -275,8 +275,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def use_modelscope() -> bool:
 | 
			
		||||
    return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
 | 
			
		||||
    return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def use_openmind() -> bool:
 | 
			
		||||
    return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
 | 
			
		||||
    return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def use_ray() -> bool:
 | 
			
		||||
    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
 | 
			
		||||
 | 
			
		||||
@ -62,6 +62,10 @@ def is_pillow_available():
 | 
			
		||||
    return _is_package_available("PIL")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_ray_available():
 | 
			
		||||
    return _is_package_available("ray")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_requests_available():
 | 
			
		||||
    return _is_package_available("requests")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments
 | 
			
		||||
from .finetuning_args import FinetuningArguments
 | 
			
		||||
from .generating_args import GeneratingArguments
 | 
			
		||||
from .model_args import ModelArguments
 | 
			
		||||
from .parser import get_eval_args, get_infer_args, get_train_args
 | 
			
		||||
from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args
 | 
			
		||||
from .training_args import RayArguments, TrainingArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
@ -26,7 +27,11 @@ __all__ = [
 | 
			
		||||
    "FinetuningArguments",
 | 
			
		||||
    "GeneratingArguments",
 | 
			
		||||
    "ModelArguments",
 | 
			
		||||
    "RayArguments",
 | 
			
		||||
    "TrainingArguments",
 | 
			
		||||
    "get_eval_args",
 | 
			
		||||
    "get_infer_args",
 | 
			
		||||
    "get_ray_args",
 | 
			
		||||
    "get_train_args",
 | 
			
		||||
    "read_args",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
@ -19,12 +19,12 @@ import json
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Any, Dict, Optional, Tuple
 | 
			
		||||
from typing import Any, Dict, List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import transformers
 | 
			
		||||
import yaml
 | 
			
		||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 | 
			
		||||
from transformers import HfArgumentParser
 | 
			
		||||
from transformers.integrations import is_deepspeed_zero3_enabled
 | 
			
		||||
from transformers.trainer_utils import get_last_checkpoint
 | 
			
		||||
from transformers.training_args import ParallelMode
 | 
			
		||||
@ -34,12 +34,12 @@ from transformers.utils.versions import require_version
 | 
			
		||||
from ..extras import logging
 | 
			
		||||
from ..extras.constants import CHECKPOINT_NAMES
 | 
			
		||||
from ..extras.misc import check_dependencies, get_current_device
 | 
			
		||||
from ..integrations.ray.ray_train_args import RayTrainArguments
 | 
			
		||||
from .data_args import DataArguments
 | 
			
		||||
from .evaluation_args import EvaluationArguments
 | 
			
		||||
from .finetuning_args import FinetuningArguments
 | 
			
		||||
from .generating_args import GeneratingArguments
 | 
			
		||||
from .model_args import ModelArguments
 | 
			
		||||
from .training_args import RayArguments, TrainingArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
@ -47,60 +47,41 @@ logger = logging.get_logger(__name__)
 | 
			
		||||
check_dependencies()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_TRAIN_ARGS = [
 | 
			
		||||
    ModelArguments,
 | 
			
		||||
    DataArguments,
 | 
			
		||||
    Seq2SeqTrainingArguments,
 | 
			
		||||
    FinetuningArguments,
 | 
			
		||||
    GeneratingArguments,
 | 
			
		||||
    RayTrainArguments,
 | 
			
		||||
]
 | 
			
		||||
_TRAIN_CLS = Tuple[
 | 
			
		||||
    ModelArguments,
 | 
			
		||||
    DataArguments,
 | 
			
		||||
    Seq2SeqTrainingArguments,
 | 
			
		||||
    FinetuningArguments,
 | 
			
		||||
    GeneratingArguments,
 | 
			
		||||
    RayTrainArguments,
 | 
			
		||||
]
 | 
			
		||||
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
 | 
			
		||||
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
 | 
			
		||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 | 
			
		||||
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 | 
			
		||||
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 | 
			
		||||
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
 | 
			
		||||
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
 | 
			
		||||
    if args is not None:
 | 
			
		||||
        return args
 | 
			
		||||
 | 
			
		||||
    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
 | 
			
		||||
        # read yaml file
 | 
			
		||||
        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
 | 
			
		||||
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
 | 
			
		||||
        # read json file
 | 
			
		||||
        return json.loads(Path(sys.argv[1]).absolute().read_text())
 | 
			
		||||
    else:
 | 
			
		||||
        return {}
 | 
			
		||||
        return sys.argv[1:]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _parse_args(
 | 
			
		||||
    parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False
 | 
			
		||||
    parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
 | 
			
		||||
) -> Tuple[Any]:
 | 
			
		||||
    args_dict = _read_args(args)
 | 
			
		||||
    args = read_args(args)
 | 
			
		||||
    if isinstance(args, dict):
 | 
			
		||||
        return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
 | 
			
		||||
 | 
			
		||||
    if args_dict:
 | 
			
		||||
        return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
 | 
			
		||||
    else:
 | 
			
		||||
        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(
 | 
			
		||||
            args=args_dict, return_remaining_strings=True
 | 
			
		||||
        )
 | 
			
		||||
    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
 | 
			
		||||
 | 
			
		||||
        if unknown_args:
 | 
			
		||||
            print(parser.format_help())
 | 
			
		||||
            print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
 | 
			
		||||
            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
 | 
			
		||||
    if unknown_args:
 | 
			
		||||
        print(parser.format_help())
 | 
			
		||||
        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
 | 
			
		||||
        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
 | 
			
		||||
 | 
			
		||||
        return (*parsed_args,)
 | 
			
		||||
    return (*parsed_args,)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _set_transformers_logging() -> None:
 | 
			
		||||
@ -141,7 +122,7 @@ def _verify_model_args(
 | 
			
		||||
def _check_extra_dependencies(
 | 
			
		||||
    model_args: "ModelArguments",
 | 
			
		||||
    finetuning_args: "FinetuningArguments",
 | 
			
		||||
    training_args: Optional["Seq2SeqTrainingArguments"] = None,
 | 
			
		||||
    training_args: Optional["TrainingArguments"] = None,
 | 
			
		||||
) -> None:
 | 
			
		||||
    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
 | 
			
		||||
        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
 | 
			
		||||
@ -177,31 +158,29 @@ def _check_extra_dependencies(
 | 
			
		||||
        require_version("rouge_chinese", "To fix: pip install rouge-chinese")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 | 
			
		||||
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
 | 
			
		||||
    parser = HfArgumentParser(_TRAIN_ARGS)
 | 
			
		||||
    return _parse_args(parser, args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
 | 
			
		||||
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
 | 
			
		||||
    parser = HfArgumentParser(_INFER_ARGS)
 | 
			
		||||
    return _parse_args(parser, args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
 | 
			
		||||
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
 | 
			
		||||
    parser = HfArgumentParser(_EVAL_ARGS)
 | 
			
		||||
    return _parse_args(parser, args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
 | 
			
		||||
    parser = HfArgumentParser(RayTrainArguments)
 | 
			
		||||
    ray_args = _parse_args(parser, args, allow_extra_keys=True)[0]
 | 
			
		||||
    if ray_args.use_ray:
 | 
			
		||||
        require_version("ray", "To fix: pip install ray")
 | 
			
		||||
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
 | 
			
		||||
    parser = HfArgumentParser(RayArguments)
 | 
			
		||||
    (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
 | 
			
		||||
    return ray_args
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 | 
			
		||||
    model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
 | 
			
		||||
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
 | 
			
		||||
    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
 | 
			
		||||
 | 
			
		||||
    # Setup logging
 | 
			
		||||
    if training_args.should_log:
 | 
			
		||||
@ -410,7 +389,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 | 
			
		||||
    return model_args, data_args, training_args, finetuning_args, generating_args
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
 | 
			
		||||
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
 | 
			
		||||
    model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
 | 
			
		||||
 | 
			
		||||
    _set_transformers_logging()
 | 
			
		||||
@ -443,7 +422,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
 | 
			
		||||
    return model_args, data_args, finetuning_args, generating_args
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
 | 
			
		||||
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
 | 
			
		||||
    model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
 | 
			
		||||
 | 
			
		||||
    _set_transformers_logging()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										48
									
								
								src/llamafactory/hparams/training_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								src/llamafactory/hparams/training_args.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,48 @@
 | 
			
		||||
import json
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
from typing import Literal, Optional, Union
 | 
			
		||||
 | 
			
		||||
from transformers import Seq2SeqTrainingArguments
 | 
			
		||||
from transformers.training_args import _convert_str_dict
 | 
			
		||||
 | 
			
		||||
from ..extras.misc import use_ray
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class RayArguments:
 | 
			
		||||
    r"""
 | 
			
		||||
    Arguments pertaining to the Ray training.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    ray_run_name: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "The training results will be saved at `saves/ray_run_name`."},
 | 
			
		||||
    )
 | 
			
		||||
    ray_num_workers: int = field(
 | 
			
		||||
        default=1,
 | 
			
		||||
        metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
 | 
			
		||||
    )
 | 
			
		||||
    resources_per_worker: Union[dict, str] = field(
 | 
			
		||||
        default_factory=lambda: {"GPU": 1},
 | 
			
		||||
        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
 | 
			
		||||
    )
 | 
			
		||||
    placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
 | 
			
		||||
        default="PACK",
 | 
			
		||||
        metadata={"help": "The placement strategy for Ray training. Default is PACK."},
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        self.use_ray = use_ray()
 | 
			
		||||
        if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
 | 
			
		||||
            self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
 | 
			
		||||
    r"""
 | 
			
		||||
    Arguments pertaining to the trainer.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        Seq2SeqTrainingArguments.__post_init__(self)
 | 
			
		||||
        RayArguments.__post_init__(self)
 | 
			
		||||
@ -1,26 +0,0 @@
 | 
			
		||||
from typing import Any, Callable, Dict
 | 
			
		||||
 | 
			
		||||
from ray.train import ScalingConfig
 | 
			
		||||
from ray.train.torch import TorchTrainer
 | 
			
		||||
 | 
			
		||||
from .ray_train_args import RayTrainArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_ray_trainer(
 | 
			
		||||
    training_function: Callable,
 | 
			
		||||
    train_loop_config: Dict[str, Any],
 | 
			
		||||
    ray_args: RayTrainArguments,
 | 
			
		||||
) -> TorchTrainer:
 | 
			
		||||
    if not ray_args.use_ray:
 | 
			
		||||
        raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")
 | 
			
		||||
 | 
			
		||||
    trainer = TorchTrainer(
 | 
			
		||||
        training_function,
 | 
			
		||||
        train_loop_config=train_loop_config,
 | 
			
		||||
        scaling_config=ScalingConfig(
 | 
			
		||||
            num_workers=ray_args.num_workers,
 | 
			
		||||
            resources_per_worker=ray_args.resources_per_worker,
 | 
			
		||||
            use_gpu=True,
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    return trainer
 | 
			
		||||
@ -1,30 +0,0 @@
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
from typing import Any, Dict, Literal, Optional
 | 
			
		||||
 | 
			
		||||
from .ray_utils import should_use_ray
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class RayTrainArguments:
 | 
			
		||||
    r"""
 | 
			
		||||
    Arguments pertaining to the Ray training.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    resources_per_worker: Optional[Dict[str, Any]] = field(
 | 
			
		||||
        default_factory=lambda: {"GPU": 1},
 | 
			
		||||
        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
 | 
			
		||||
    )
 | 
			
		||||
    num_workers: Optional[int] = field(
 | 
			
		||||
        default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}
 | 
			
		||||
    )
 | 
			
		||||
    placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(
 | 
			
		||||
        default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def use_ray(self) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Always returns the value from the environment variable check.
 | 
			
		||||
        This prevents manual setting of use_ray.
 | 
			
		||||
        """
 | 
			
		||||
        return should_use_ray()
 | 
			
		||||
@ -1,5 +0,0 @@
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def should_use_ray():
 | 
			
		||||
    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
 | 
			
		||||
@ -18,7 +18,8 @@
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from collections.abc import Mapping
 | 
			
		||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from transformers import Trainer
 | 
			
		||||
@ -31,7 +32,7 @@ from typing_extensions import override
 | 
			
		||||
 | 
			
		||||
from ..extras import logging
 | 
			
		||||
from ..extras.constants import IGNORE_INDEX
 | 
			
		||||
from ..extras.packages import is_galore_available
 | 
			
		||||
from ..extras.packages import is_galore_available, is_ray_available
 | 
			
		||||
from ..hparams import FinetuningArguments, ModelArguments
 | 
			
		||||
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
 | 
			
		||||
 | 
			
		||||
@ -40,11 +41,16 @@ if is_galore_available():
 | 
			
		||||
    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if is_ray_available():
 | 
			
		||||
    from ray.train import RunConfig, ScalingConfig
 | 
			
		||||
    from ray.train.torch import TorchTrainer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback
 | 
			
		||||
    from transformers import PreTrainedModel, TrainerCallback
 | 
			
		||||
    from trl import AutoModelForCausalLMWithValueHead
 | 
			
		||||
 | 
			
		||||
    from ..hparams import DataArguments
 | 
			
		||||
    from ..hparams import DataArguments, RayArguments, TrainingArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
@ -75,7 +81,7 @@ def create_modelcard_and_push(
 | 
			
		||||
    trainer: "Trainer",
 | 
			
		||||
    model_args: "ModelArguments",
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
    training_args: "Seq2SeqTrainingArguments",
 | 
			
		||||
    training_args: "TrainingArguments",
 | 
			
		||||
    finetuning_args: "FinetuningArguments",
 | 
			
		||||
) -> None:
 | 
			
		||||
    kwargs = {
 | 
			
		||||
@ -188,7 +194,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
 | 
			
		||||
 | 
			
		||||
def _create_galore_optimizer(
 | 
			
		||||
    model: "PreTrainedModel",
 | 
			
		||||
    training_args: "Seq2SeqTrainingArguments",
 | 
			
		||||
    training_args: "TrainingArguments",
 | 
			
		||||
    finetuning_args: "FinetuningArguments",
 | 
			
		||||
) -> "torch.optim.Optimizer":
 | 
			
		||||
    if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
 | 
			
		||||
@ -272,7 +278,7 @@ def _create_galore_optimizer(
 | 
			
		||||
 | 
			
		||||
def _create_loraplus_optimizer(
 | 
			
		||||
    model: "PreTrainedModel",
 | 
			
		||||
    training_args: "Seq2SeqTrainingArguments",
 | 
			
		||||
    training_args: "TrainingArguments",
 | 
			
		||||
    finetuning_args: "FinetuningArguments",
 | 
			
		||||
) -> "torch.optim.Optimizer":
 | 
			
		||||
    default_lr = training_args.learning_rate
 | 
			
		||||
@ -312,7 +318,7 @@ def _create_loraplus_optimizer(
 | 
			
		||||
 | 
			
		||||
def _create_badam_optimizer(
 | 
			
		||||
    model: "PreTrainedModel",
 | 
			
		||||
    training_args: "Seq2SeqTrainingArguments",
 | 
			
		||||
    training_args: "TrainingArguments",
 | 
			
		||||
    finetuning_args: "FinetuningArguments",
 | 
			
		||||
) -> "torch.optim.Optimizer":
 | 
			
		||||
    decay_params, nodecay_params = [], []
 | 
			
		||||
@ -373,7 +379,7 @@ def _create_badam_optimizer(
 | 
			
		||||
 | 
			
		||||
def _create_adam_mini_optimizer(
 | 
			
		||||
    model: "PreTrainedModel",
 | 
			
		||||
    training_args: "Seq2SeqTrainingArguments",
 | 
			
		||||
    training_args: "TrainingArguments",
 | 
			
		||||
) -> "torch.optim.Optimizer":
 | 
			
		||||
    from adam_mini import Adam_mini  # type: ignore
 | 
			
		||||
 | 
			
		||||
@ -398,7 +404,7 @@ def _create_adam_mini_optimizer(
 | 
			
		||||
 | 
			
		||||
def create_custom_optimizer(
 | 
			
		||||
    model: "PreTrainedModel",
 | 
			
		||||
    training_args: "Seq2SeqTrainingArguments",
 | 
			
		||||
    training_args: "TrainingArguments",
 | 
			
		||||
    finetuning_args: "FinetuningArguments",
 | 
			
		||||
) -> Optional["torch.optim.Optimizer"]:
 | 
			
		||||
    if finetuning_args.use_galore:
 | 
			
		||||
@ -415,7 +421,7 @@ def create_custom_optimizer(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_custom_scheduler(
 | 
			
		||||
    training_args: "Seq2SeqTrainingArguments",
 | 
			
		||||
    training_args: "TrainingArguments",
 | 
			
		||||
    num_training_steps: int,
 | 
			
		||||
    optimizer: Optional["torch.optim.Optimizer"] = None,
 | 
			
		||||
) -> None:
 | 
			
		||||
@ -499,3 +505,28 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
 | 
			
		||||
        config={"Framework": "🦙LlamaFactory"},
 | 
			
		||||
    )
 | 
			
		||||
    return swanlab_callback
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_ray_trainer(
 | 
			
		||||
    training_function: Callable,
 | 
			
		||||
    train_loop_config: Dict[str, Any],
 | 
			
		||||
    ray_args: "RayArguments",
 | 
			
		||||
) -> "TorchTrainer":
 | 
			
		||||
    if not ray_args.use_ray:
 | 
			
		||||
        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
 | 
			
		||||
 | 
			
		||||
    trainer = TorchTrainer(
 | 
			
		||||
        training_function,
 | 
			
		||||
        train_loop_config=train_loop_config,
 | 
			
		||||
        scaling_config=ScalingConfig(
 | 
			
		||||
            num_workers=ray_args.ray_num_workers,
 | 
			
		||||
            resources_per_worker=ray_args.resources_per_worker,
 | 
			
		||||
            placement_strategy=ray_args.placement_strategy,
 | 
			
		||||
            use_gpu=True,
 | 
			
		||||
        ),
 | 
			
		||||
        run_config=RunConfig(
 | 
			
		||||
            name=ray_args.ray_run_name,
 | 
			
		||||
            storage_path=Path("./saves").absolute().as_posix(),
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    return trainer
 | 
			
		||||
 | 
			
		||||
@ -22,8 +22,8 @@ from transformers import PreTrainedModel
 | 
			
		||||
from ..data import get_template_and_fix_tokenizer
 | 
			
		||||
from ..extras import logging
 | 
			
		||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 | 
			
		||||
from ..hparams import get_infer_args, get_train_args
 | 
			
		||||
from ..hparams.parser import _parse_ray_args, _read_args
 | 
			
		||||
from ..extras.packages import is_ray_available
 | 
			
		||||
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 | 
			
		||||
from ..model import load_model, load_tokenizer
 | 
			
		||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
 | 
			
		||||
from .dpo import run_dpo
 | 
			
		||||
@ -32,7 +32,11 @@ from .ppo import run_ppo
 | 
			
		||||
from .pt import run_pt
 | 
			
		||||
from .rm import run_rm
 | 
			
		||||
from .sft import run_sft
 | 
			
		||||
from .trainer_utils import get_swanlab_callback
 | 
			
		||||
from .trainer_utils import get_ray_trainer, get_swanlab_callback
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if is_ray_available():
 | 
			
		||||
    from ray.train.huggingface.transformers import RayTrainReportCallback
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@ -43,10 +47,8 @@ logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def training_function(config: Dict[str, Any]) -> None:
 | 
			
		||||
    args = config.get("args", None)
 | 
			
		||||
    callbacks = config.get("callbacks", [])
 | 
			
		||||
 | 
			
		||||
    callbacks.append(LogCallback())
 | 
			
		||||
    args = config.get("args")
 | 
			
		||||
    callbacks: List[Any] = config.get("callbacks")
 | 
			
		||||
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.pissa_convert:
 | 
			
		||||
@ -73,31 +75,22 @@ def training_function(config: Dict[str, Any]) -> None:
 | 
			
		||||
        raise ValueError(f"Unknown task: {finetuning_args.stage}.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
 | 
			
		||||
    args_dict = _read_args(args)
 | 
			
		||||
    ray_args = _parse_ray_args(args_dict)
 | 
			
		||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
 | 
			
		||||
    callbacks = callbacks or []
 | 
			
		||||
    callbacks.append(LogCallback())
 | 
			
		||||
 | 
			
		||||
    args = read_args(args)
 | 
			
		||||
    ray_args = get_ray_args(args)
 | 
			
		||||
    if ray_args.use_ray:
 | 
			
		||||
        # Import lazily to avoid ray not installed error
 | 
			
		||||
        from ..integrations.ray.ray_train import get_ray_trainer
 | 
			
		||||
 | 
			
		||||
        # Initialize ray trainer
 | 
			
		||||
        callbacks.append(RayTrainReportCallback())
 | 
			
		||||
        trainer = get_ray_trainer(
 | 
			
		||||
            training_function=training_function,
 | 
			
		||||
            train_loop_config={
 | 
			
		||||
                "args": args_dict,
 | 
			
		||||
                "callbacks": callbacks,
 | 
			
		||||
            },
 | 
			
		||||
            train_loop_config={"args": args, "callbacks": callbacks},
 | 
			
		||||
            ray_args=ray_args,
 | 
			
		||||
        )
 | 
			
		||||
        trainer.fit()
 | 
			
		||||
    else:
 | 
			
		||||
        training_function(
 | 
			
		||||
            config={
 | 
			
		||||
                "args": args_dict,
 | 
			
		||||
                "callbacks": callbacks,
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        training_function(config={"args": args, "callbacks": callbacks})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user