mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00

run style check

Former-commit-id: 1e8e7be0a535e55888f58bbe2c38bc1c382e9012
This commit is contained in:
parent 1217240918
commit bba52e258e
@@ -9,7 +9,7 @@ finetuning_type: lora
 lora_target: all
 
 ### dataset
-dataset_dir: /home/ray/default/lf/data/
+dataset_dir: /home/ray/default/LLaMA-Factory/data/
 dataset: identity,alpaca_en_demo
 template: llama3
 cutoff_len: 2048
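The only substantive change in this config is pointing dataset_dir at the repository's own data/ folder (where dataset_info.json lives) instead of a stale lf/ checkout. For reference, a minimal sketch of driving the same run programmatically: run_exp accepts the config keys as a dict. The model name and output path below are illustrative assumptions, not part of this commit.

# Sketch: passing this YAML's keys to run_exp as a dict instead of a file path.
from llamafactory.train.tuner import run_exp

run_exp(
    args={
        "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # assumed model
        "stage": "sft",
        "do_train": True,
        "finetuning_type": "lora",
        "lora_target": "all",
        "dataset_dir": "/home/ray/default/LLaMA-Factory/data/",
        "dataset": "identity,alpaca_en_demo",
        "template": "llama3",
        "cutoff_len": 2048,
        "output_dir": "saves/llama3-8b/lora/sft",  # assumed output path
    }
)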
@@ -25,9 +25,10 @@ from .eval.evaluator import run_eval
 from .extras import logging
 from .extras.env import VERSION, print_env
 from .extras.misc import get_device_count
+from .integrations.ray.ray_utils import should_use_ray
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
-from .integrations.ray.ray_utils import should_use_ray
+
 
 USAGE = (
     "-" * 70
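Moving the should_use_ray import into the sorted import block is purely cosmetic, but the function itself is what lets the CLI skip torchrun when Ray manages worker placement. A hypothetical sketch of that dispatch; the actual cli.py branching is not shown in this hunk, and the absolute import paths are assumed from the relative imports above.

# Hypothetical sketch of the dispatch this import enables.
from llamafactory.extras.misc import get_device_count
from llamafactory.integrations.ray.ray_utils import should_use_ray
from llamafactory.train.tuner import run_exp


def launch_train() -> None:
    if get_device_count() > 1 and not should_use_ray():
        # Without Ray, multi-GPU runs would normally be launched via torchrun.
        raise SystemExit("multi-GPU training without Ray: launch via torchrun")
    run_exp()  # single entry point; Ray spawns its own workers when enabled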
@@ -15,16 +15,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import os
 import sys
-from typing import Any, Dict, Optional, Tuple
-
-import json
-import yaml
 from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import transformers
+import yaml
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
@@ -35,21 +34,35 @@ from transformers.utils.versions import require_version
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
 from ..extras.misc import check_dependencies, get_current_device
+from ..integrations.ray.ray_train_args import RayTrainArguments
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
 
-from ..integrations.ray.ray_train_args import RayTrainArguments
-
 
 logger = logging.get_logger(__name__)
 
 check_dependencies()
 
 
-_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
-_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments, RayTrainArguments]
+_TRAIN_ARGS = [
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+    RayTrainArguments,
+]
+_TRAIN_CLS = Tuple[
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+    RayTrainArguments,
+]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
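_TRAIN_ARGS is the list of dataclass types handed to HfArgumentParser, and _TRAIN_CLS is the matching return-tuple annotation; reformatting both one entry per line keeps the two in visible sync now that RayTrainArguments rides along. A small sketch of how such a list is consumed, with absolute import paths assumed from the relative imports above and an illustrative model name:

# Sketch: HfArgumentParser takes several dataclass types and returns one
# parsed instance per type, in the same order as the list.
from transformers import HfArgumentParser

from llamafactory.hparams.data_args import DataArguments
from llamafactory.hparams.model_args import ModelArguments

parser = HfArgumentParser([ModelArguments, DataArguments])
model_args, data_args = parser.parse_dict(
    {"model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct", "template": "llama3"},
    allow_extra_keys=False,
)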
@@ -70,14 +83,17 @@ def _read_args(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
         return {}
 
 
-def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False) -> Tuple[Any]:
+def _parse_args(
+    parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None, allow_extra_keys: bool = False
+) -> Tuple[Any]:
     args_dict = _read_args(args)
 
     if args_dict:
         return parser.parse_dict(args_dict, allow_extra_keys=allow_extra_keys)
     else:
-        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args_dict, return_remaining_strings=True)
+        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(
+            args=args_dict, return_remaining_strings=True
+        )
 
     if unknown_args:
         print(parser.format_help())
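The signature and the parse_args_into_dataclasses call are only rewrapped to satisfy the line-length check; behavior is unchanged. Dict or YAML/JSON input goes through parse_dict, CLI-style input falls back to argparse, and leftover tokens trigger the ValueError below. A self-contained sketch of the two paths, using a toy dataclass rather than the real argument classes:

# Toy illustration of the two parsing paths in _parse_args.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ToyArguments:  # stand-in for ModelArguments and friends
    learning_rate: float = field(default=5.0e-5)


parser = HfArgumentParser(ToyArguments)

# Dict path: taken when a dict (or a YAML/JSON file resolved by _read_args) is given.
(toy,) = parser.parse_dict({"learning_rate": 1.0e-4}, allow_extra_keys=False)

# CLI path: unrecognized tokens come back as strings and would raise the ValueError.
toy, unknown_args = parser.parse_args_into_dataclasses(
    args=["--learning_rate", "1e-4", "--bogus_flag", "1"], return_remaining_strings=True
)
assert unknown_args == ["--bogus_flag", "1"]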
@@ -85,7 +101,6 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
 
     return (*parsed_args,)
 
 
-
 def _set_transformers_logging() -> None:
@@ -187,7 +202,7 @@ def _parse_ray_args(args: Optional[Dict[str, Any]] = None) -> RayTrainArguments:
 
 def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     model_args, data_args, training_args, finetuning_args, generating_args, _ = _parse_train_args(args)
 
     # Setup logging
     if training_args.should_log:
         _set_transformers_logging()
@@ -1,21 +1,19 @@
-
 from typing import Any, Callable, Dict
 
-from ray.train.torch import TorchTrainer
 from ray.train import ScalingConfig
+from ray.train.torch import TorchTrainer
 
 from .ray_train_args import RayTrainArguments
 
 
 def get_ray_trainer(
     training_function: Callable,
     train_loop_config: Dict[str, Any],
     ray_args: RayTrainArguments,
 ) -> TorchTrainer:
-
     if not ray_args.use_ray:
         raise ValueError("Ray is not enabled. Please set USE_RAY=1 in your environment.")
 
     trainer = TorchTrainer(
         training_function,
         train_loop_config=train_loop_config,
@@ -25,4 +23,4 @@ def get_ray_trainer(
             use_gpu=True,
         ),
     )
     return trainer
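Taken together, these two hunks leave get_ray_trainer as a thin wrapper that maps RayTrainArguments onto Ray's TorchTrainer and ScalingConfig. A usage sketch, assuming USE_RAY is set and a running Ray cluster with GPUs; the training function and config values are illustrative, not from the commit:

# Illustrative wiring of get_ray_trainer; import paths assumed.
import os

os.environ["USE_RAY"] = "1"  # RayTrainArguments.use_ray reads this at access time

from llamafactory.integrations.ray.ray_train import get_ray_trainer
from llamafactory.integrations.ray.ray_train_args import RayTrainArguments


def train_fn(config: dict) -> None:  # placeholder per-worker training loop
    print("worker got args:", config["args"])


ray_args = RayTrainArguments(num_workers=2, placement_strategy="PACK")
trainer = get_ray_trainer(
    training_function=train_fn,
    train_loop_config={"args": {"stage": "sft"}, "callbacks": []},
    ray_args=ray_args,
)
result = trainer.fit()  # schedules 2 workers with 1 GPU each and runs train_fn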
@@ -3,14 +3,23 @@ from typing import Any, Dict, Literal, Optional
 
 from .ray_utils import should_use_ray
 
 
 @dataclass
 class RayTrainArguments:
     r"""
     Arguments pertaining to the Ray training.
     """
-    resources_per_worker: Optional[Dict[str, Any]] = field(default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."})
-    num_workers: Optional[int] = field(default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."})
-    placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."})
+
+    resources_per_worker: Optional[Dict[str, Any]] = field(
+        default_factory=lambda: {"GPU": 1},
+        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
+    )
+    num_workers: Optional[int] = field(
+        default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}
+    )
+    placement_strategy: Optional[Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"]] = field(
+        default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}
+    )
 
     @property
     def use_ray(self) -> bool:
@@ -19,4 +28,3 @@ class RayTrainArguments:
         This prevents manual setting of use_ray.
         """
         return should_use_ray()
-
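use_ray stays a read-only property backed by the environment, so a config file or attribute assignment cannot flip it; only USE_RAY can. A small sketch, with the absolute import path assumed:

# Sketch: use_ray has no backing field and no setter, so it is derived from
# the USE_RAY env var on every access and cannot be set manually.
import os

from llamafactory.integrations.ray.ray_train_args import RayTrainArguments

args = RayTrainArguments()  # defaults: 1 worker, {"GPU": 1}, "PACK"
try:
    args.use_ray = True  # property without a setter: raises AttributeError
except AttributeError:
    pass
os.environ["USE_RAY"] = "1"
assert args.use_ray  # only the environment variable enables Ray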
@@ -1,9 +1,5 @@
-
 import os
 
 
 def should_use_ray():
     return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
-
-
-
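For reference, the values should_use_ray treats as enabling Ray: "true" or "1", case-insensitively, with everything else (including unset) off. A quick check, with the absolute import path assumed:

# Truth table for should_use_ray()'s accepted USE_RAY values.
import os

from llamafactory.integrations.ray.ray_utils import should_use_ray

for value, expected in [("1", True), ("true", True), ("TRUE", True), ("0", False), ("yes", False)]:
    os.environ["USE_RAY"] = value
    assert should_use_ray() is expected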
@@ -37,14 +37,15 @@ from .trainer_utils import get_swanlab_callback
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
 
 
 logger = logging.get_logger(__name__)
 
 
 def training_function(config: Dict[str, Any]) -> None:
     args = config.get("args", None)
     callbacks = config.get("callbacks", [])
+
     callbacks.append(LogCallback())
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
@@ -71,15 +72,15 @@ def training_function(config: Dict[str, Any]) -> None:
     else:
         raise ValueError(f"Unknown task: {finetuning_args.stage}.")
 
 
 def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
-
     args_dict = _read_args(args)
     ray_args = _parse_ray_args(args_dict)
 
     if ray_args.use_ray:
         # Import lazily to avoid ray not installed error
         from ..integrations.ray.ray_train import get_ray_trainer
+
         # Initialize ray trainer
         trainer = get_ray_trainer(
             training_function=training_function,
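The page truncates the hunk at this point, but the shape of the new dispatch is already visible: resolve the raw args once, parse only the Ray fields, then either hand training_function to Ray or run it in-process. A condensed sketch of that flow; the non-Ray branch is an assumption, since the remaining diff lines are cut off, and _read_args and _parse_ray_args are the private helpers from the parser hunks above.

# Condensed sketch of run_exp's Ray dispatch (else-branch assumed).
from typing import Any, Dict, Optional

from llamafactory.hparams.parser import _parse_ray_args, _read_args
from llamafactory.train.tuner import training_function


def run_exp_sketch(args: Optional[Dict[str, Any]] = None) -> None:
    args_dict = _read_args(args)           # CLI / YAML / dict input, normalized
    ray_args = _parse_ray_args(args_dict)  # extracts only the Ray fields

    if ray_args.use_ray:
        # imported lazily in the real code so ray stays an optional dependency
        from llamafactory.integrations.ray.ray_train import get_ray_trainer

        trainer = get_ray_trainer(
            training_function=training_function,
            train_loop_config={"args": args_dict, "callbacks": []},
            ray_args=ray_args,
        )
        trainer.fit()  # Ray executes training_function on each worker
    else:
        training_function({"args": args_dict, "callbacks": []})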