run style check

Former-commit-id: 1e8e7be0a535e55888f58bbe2c38bc1c382e9012
This commit is contained in:
Eric Tang 2025-01-06 23:55:56 +00:00 committed by hiyouga
parent 1217240918
commit bba52e258e
7 changed files with 54 additions and 35 deletions

View File

@ -9,7 +9,7 @@ finetuning_type: lora
lora_target: all lora_target: all
### dataset ### dataset
dataset_dir: /home/ray/default/lf/data/ dataset_dir: /home/ray/default/LLaMA-Factory/data/
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 2048 cutoff_len: 2048

View File

@ -25,9 +25,10 @@ from .eval.evaluator import run_eval
from .extras import logging from .extras import logging
from .extras.env import VERSION, print_env from .extras.env import VERSION, print_env
from .extras.misc import get_device_count from .extras.misc import get_device_count
from .integrations.ray.ray_utils import should_use_ray
from .train.tuner import export_model, run_exp from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui from .webui.interface import run_web_demo, run_web_ui
from .integrations.ray.ray_utils import should_use_ray
USAGE = ( USAGE = (
"-" * 70 "-" * 70

View File

@ -15,16 +15,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import os import os
import sys import sys
from typing import Any, Dict, Optional, Tuple
import json
import yaml
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import torch import torch
import transformers import transformers
import yaml
from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
@ -35,21 +34,35 @@ from transformers.utils.versions import require_version
from ..extras import logging from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES from ..extras.constants import CHECKPOINT_NAMES
from ..extras.misc import check_dependencies, get_current_device from ..extras.misc import check_dependencies, get_current_device
from ..integrations.ray.ray_train_args import RayTrainArguments
from .data_args import DataArguments from .data_args import DataArguments
from .evaluation_args import EvaluationArguments from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments from .generating_args import GeneratingArguments
from .model_args import ModelArguments from .model_args import ModelArguments
from ..integrations.ray.ray_train_args import RayTrainArguments
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
check_dependencies() check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments] _TRAIN_ARGS = [
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments] ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
RayTrainArguments,
]
_TRAIN_CLS = Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments,
RayTrainArguments,
]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
@ -70,14 +83,17 @@ def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
return {} return {}
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]: def _parse_args(
parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False
) -> Tuple[Any]:
args_dict = _read_args(args) args_dict = _read_args(args)
if args_dict: if args_dict:
return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys) return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
else: else:
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True) (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(
args=args_dict, return_remaining_strings=True
)
if unknown_args: if unknown_args:
print(parser.format_help()) print(parser.format_help())
@ -85,7 +101,6 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
return (*parsed_args,) return (*parsed_args,)
def _set_transformers_logging() -> None: def _set_transformers_logging() -> None:
@ -187,7 +202,7 @@ def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args) model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
# Setup logging # Setup logging
if training_args.should_log: if training_args.should_log:
_set_transformers_logging() _set_transformers_logging()

View File

@ -1,21 +1,19 @@
from typing import Any, Callable, Dict from typing import Any, Callable, Dict
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from .ray_train_args import RayTrainArguments from .ray_train_args import RayTrainArguments
def get_ray_trainer( def get_ray_trainer(
training_function: Callable, training_function: Callable,
train_loop_config: Dict[str, Any], train_loop_config: Dict[str, Any],
ray_args: RayTrainArguments, ray_args: RayTrainArguments,
) -> TorchTrainer: ) -> TorchTrainer:
if not ray_args.use_ray: if not ray_args.use_ray:
raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.") raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")
trainer = TorchTrainer( trainer = TorchTrainer(
training_function, training_function,
train_loop_config=train_loop_config, train_loop_config=train_loop_config,
@ -25,4 +23,4 @@ def get_ray_trainer(
use_gpu=True, use_gpu=True,
), ),
) )
return trainer return trainer

View File

@ -3,14 +3,23 @@ from typing import Any, Dict, Literal, Optional
from .ray_utils import should_use_ray from .ray_utils import should_use_ray
@dataclass @dataclass
class RayTrainArguments: class RayTrainArguments:
r""" r"""
Arguments pertaining to the Ray training. Arguments pertaining to the Ray training.
""" """
resources_per_worker: Optional[Dict[str, Any]] = field(default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."})
num_workers: Optional[int] = field(default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}) resources_per_worker: Optional[Dict[str, Any]] = field(
placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}) default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
)
num_workers: Optional[int] = field(
default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}
)
placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(
default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}
)
@property @property
def use_ray(self) -> bool: def use_ray(self) -> bool:
@ -19,4 +28,3 @@ class RayTrainArguments:
This prevents manual setting of use_ray. This prevents manual setting of use_ray.
""" """
return should_use_ray() return should_use_ray()

View File

@ -1,9 +1,5 @@
import os import os
def should_use_ray(): def should_use_ray():
return os.getenv("USE_RAY", "0").lower() in ["true", "1"] return os.getenv("USE_RAY", "0").lower() in ["true", "1"]

View File

@ -37,14 +37,15 @@ from .trainer_utils import get_swanlab_callback
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def training_function(config: Dict[str, Any]) -> None: def training_function(config: Dict[str, Any]) -> None:
args = config.get("args", None) args = config.get("args", None)
callbacks = config.get("callbacks", []) callbacks = config.get("callbacks", [])
callbacks.append(LogCallback()) callbacks.append(LogCallback())
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
@ -71,15 +72,15 @@ def training_function(config: Dict[str, Any]) -> None:
else: else:
raise ValueError(f"Unknown task: {finetuning_args.stage}.") raise ValueError(f"Unknown task: {finetuning_args.stage}.")
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
args_dict = _read_args(args) args_dict = _read_args(args)
ray_args = _parse_ray_args(args_dict) ray_args = _parse_ray_args(args_dict)
if ray_args.use_ray: if ray_args.use_ray:
# Import lazily to avoid ray not installed error # Import lazily to avoid ray not installed error
from ..integrations.ray.ray_train import get_ray_trainer from ..integrations.ray.ray_train import get_ray_trainer
# Initialize ray trainer # Initialize ray trainer
trainer = get_ray_trainer( trainer = get_ray_trainer(
training_function=training_function, training_function=training_function,