Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

commit 4f31ad997c (parent 8683582300)
run style check

Former-commit-id: 5ec33baf5f95df9fa2afe5523c825d3eda8a076b
@@ -9,7 +9,7 @@ finetuning_type: lora
 lora_target: all
 
 ### dataset
-dataset_dir: /home/ray/default/lf/data/
+dataset_dir: /home/ray/default/LLaMA-Factory/data/
 dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 2048
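This hunk corrects the example config's dataset_dir, which pointed at a shorthand
checkout path ("lf") rather than the actual LLaMA-Factory clone. A minimal sketch of
how such a YAML config reaches training, assuming the run_exp entry point shown later
in this diff; the dict keys mirror the YAML, and any value not visible in the hunk is
an illustrative assumption:

    from llamafactory.train.tuner import run_exp

    # Keys mirror the YAML config above; "stage" and the model path are
    # assumptions for illustration, not part of the visible hunk.
    run_exp(
        args={
            "stage": "sft",
            "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
            "finetuning_type": "lora",
            "lora_target": "all",
            "dataset_dir": "/home/ray/default/LLaMA-Factory/data/",
            "dataset": "identity,alpaca_en_demo",
            "template": "llama3",
            "cutoff_len": 2048,
        }
    )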
@@ -25,9 +25,10 @@ from .eval.evaluator import run_eval
 from .extras import logging
 from .extras.env import VERSION, print_env
 from .extras.misc import get_device_count
+from .integrations.ray.ray_utils import should_use_ray
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
-from .integrations.ray.ray_utils import should_use_ray
 
 
 USAGE = (
     "-" * 70
@@ -15,16 +15,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import os
 import sys
-from typing import Any, Dict, Optional, Tuple
-
-import json
-import yaml
 from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import transformers
+import yaml
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
@@ -35,21 +34,35 @@ from transformers.utils.versions import require_version
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
 from ..extras.misc import check_dependencies, get_current_device
+from ..integrations.ray.ray_train_args import RayTrainArguments
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
 
-from ..integrations.ray.ray_train_args import RayTrainArguments
 
 logger = logging.get_logger(__name__)
 
 check_dependencies()
 
 
-_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
-_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
+_TRAIN_ARGS = [
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+    RayTrainArguments,
+]
+_TRAIN_CLS = Tuple[
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+    RayTrainArguments,
+]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
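Breaking _TRAIN_ARGS and _TRAIN_CLS across lines is purely cosmetic; both still feed
HfArgumentParser, which accepts a list of dataclass types. A hedged sketch of how the
list is consumed (the call site, _parse_train_args, is not part of this hunk, and
output_dir is an illustrative value required by Seq2SeqTrainingArguments):

    from transformers import HfArgumentParser

    parser = HfArgumentParser(_TRAIN_ARGS)
    # parse_dict returns one instance per dataclass, in list order.
    model_args, data_args, training_args, finetuning_args, generating_args, ray_args = parser.parse_dict(
        {"output_dir": "saves/demo"}, allow_extra_keys=True
    )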
@@ -70,14 +83,17 @@ def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
         return {}
 
 
-def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]:
+def _parse_args(
+    parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False
+) -> Tuple[Any]:
     args_dict = _read_args(args)
-    
+
     if args_dict:
         return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
     else:
-        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True)
+        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(
+            args=args_dict, return_remaining_strings=True
+        )
 
         if unknown_args:
             print(parser.format_help())
@@ -85,7 +101,6 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
             raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
 
         return (*parsed_args,)
-    
 
 
 def _set_transformers_logging() -> None:
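The reflowed _parse_args keeps two paths: parse_dict when a config dict (or a
YAML/JSON file read by _read_args) is present, and parse_args_into_dataclasses over
CLI tokens otherwise, with return_remaining_strings=True surfacing the unknown
arguments that trigger the ValueError above. A toy reproduction with a stand-in
dataclass (Example and its field are hypothetical):

    from dataclasses import dataclass

    from transformers import HfArgumentParser

    @dataclass
    class Example:
        learning_rate: float = 5e-5

    parser = HfArgumentParser([Example])

    # Dict path: taken when _read_args found a config.
    (example_args,) = parser.parse_dict({"learning_rate": 1e-4})

    # CLI path: unknown tokens come back as the trailing list, which
    # _parse_args reports in the "not used by the HfArgumentParser" error.
    (*parsed, remaining) = parser.parse_args_into_dataclasses(
        args=["--learning_rate", "1e-4", "--bogus", "1"], return_remaining_strings=True
    )
    # remaining == ["--bogus", "1"]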
@@ -187,7 +202,7 @@ def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
 
 def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
-    
+
     # Setup logging
     if training_args.should_log:
         _set_transformers_logging()
@@ -1,21 +1,19 @@
 from typing import Any, Callable, Dict
 
-from ray.train.torch import TorchTrainer
 from ray.train import ScalingConfig
+from ray.train.torch import TorchTrainer
 
 from .ray_train_args import RayTrainArguments
 
-    
+
 def get_ray_trainer(
     training_function: Callable,
     train_loop_config: Dict[str, Any],
     ray_args: RayTrainArguments,
 ) -> TorchTrainer:
-    
     if not ray_args.use_ray:
         raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")
-    
+
     trainer = TorchTrainer(
         training_function,
         train_loop_config=train_loop_config,
@@ -25,4 +23,4 @@ def get_ray_trainer(
             use_gpu=True,
         ),
     )
-    return trainer
+    return trainer
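The hunks elide the middle of the TorchTrainer call; only use_gpu=True and the closing
parentheses are visible. A hedged reconstruction of the full construction, wiring the
ScalingConfig from the RayTrainArguments fields defined below (the exact wiring is
inferred, not shown in the diff):

    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    trainer = TorchTrainer(
        training_function,
        train_loop_config=train_loop_config,
        scaling_config=ScalingConfig(
            num_workers=ray_args.num_workers,                    # inferred wiring
            resources_per_worker=ray_args.resources_per_worker,  # inferred wiring
            placement_strategy=ray_args.placement_strategy,      # inferred wiring
            use_gpu=True,                                        # visible in the hunk
        ),
    )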
@@ -3,14 +3,23 @@ from typing import Any, Dict, Literal, Optional
 from .ray_utils import should_use_ray
 
 
 @dataclass
 class RayTrainArguments:
     r"""
     Arguments pertaining to the Ray training.
     """
-    resources_per_worker: Optional[Dict[str, Any]] = field(default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."})
-    num_workers: Optional[int] = field(default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."})
-    placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."})
+
+    resources_per_worker: Optional[Dict[str, Any]] = field(
+        default_factory=lambda: {"GPU": 1},
+        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
+    )
+    num_workers: Optional[int] = field(
+        default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}
+    )
+    placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(
+        default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}
+    )
 
     @property
     def use_ray(self) -> bool:
@@ -19,4 +28,3 @@ class RayTrainArguments:
         This prevents manual setting of use_ray.
         """
         return should_use_ray()
-    
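Because run_exp (below) parses RayTrainArguments out of the same flat argument dict as
the other dataclasses, these fields can plausibly be set from the training YAML. A
hypothetical fragment in the style of the config hunk above; the key names come from
the dataclass, the values are illustrative:

    ### ray
    num_workers: 2
    resources_per_worker:
      GPU: 1
    placement_strategy: PACK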
@@ -1,9 +1,5 @@
 import os
 
 
 def should_use_ray():
     return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
-
-
-
-
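should_use_ray is the single source of truth for the use_ray property above: a
case-insensitive check of the USE_RAY environment variable. A quick sketch (the
absolute module path is assumed from the cli.py import earlier in this diff):

    import os

    from llamafactory.integrations.ray.ray_utils import should_use_ray

    os.environ["USE_RAY"] = "TRUE"  # "true"/"1" in any casing enables Ray
    assert should_use_ray()

    os.environ["USE_RAY"] = "0"     # anything else disables it
    assert not should_use_ray()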
@@ -37,14 +37,15 @@ from .trainer_utils import get_swanlab_callback
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
-    
+
 
 logger = logging.get_logger(__name__)
 
 
 def training_function(config: Dict[str, Any]) -> None:
     args = config.get("args", None)
     callbacks = config.get("callbacks", [])
-    
+
     callbacks.append(LogCallback())
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
@@ -71,15 +72,15 @@ def training_function(config: Dict[str, Any]) -> None:
     else:
         raise ValueError(f"Unknown task: {finetuning_args.stage}.")
 
 
 def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
-    
     args_dict = _read_args(args)
     ray_args = _parse_ray_args(args_dict)
-    
-    if ray_args.use_ray: 
+
+    if ray_args.use_ray:
         # Import lazily to avoid ray not installed error
         from ..integrations.ray.ray_train import get_ray_trainer
-        
+
         # Initialize ray trainer
         trainer = get_ray_trainer(
             training_function=training_function,
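The hunk cuts off inside the get_ray_trainer call. A hedged sketch of the complete
dispatch; the train_loop_config keys follow the config.get() lookups in
training_function above, while the fit() call and the non-Ray branch are assumptions
about the elided lines:

    def run_exp_sketch(args=None, callbacks=[]):
        args_dict = _read_args(args)
        ray_args = _parse_ray_args(args_dict)

        if ray_args.use_ray:
            # Import lazily to avoid an import error when ray is not installed.
            from llamafactory.integrations.ray.ray_train import get_ray_trainer

            trainer = get_ray_trainer(
                training_function=training_function,
                train_loop_config={"args": args_dict, "callbacks": callbacks},
                ray_args=ray_args,
            )
            trainer.fit()  # TorchTrainer.fit() runs training_function on each worker
        else:
            training_function({"args": args_dict, "callbacks": callbacks})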