mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-02 01:36:02 +08:00
[deps] goodbye python 3.9 (#9677)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com> Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -16,22 +16,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
|
||||
|
||||
template: Optional[str] = field(
|
||||
template: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."},
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
|
||||
)
|
||||
eval_dataset: Optional[str] = field(
|
||||
eval_dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
|
||||
)
|
||||
@@ -39,7 +39,7 @@ class DataArguments:
|
||||
default="data",
|
||||
metadata={"help": "Path to the folder containing the datasets."},
|
||||
)
|
||||
media_dir: Optional[str] = field(
|
||||
media_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
|
||||
)
|
||||
@@ -67,7 +67,7 @@ class DataArguments:
|
||||
default="concat",
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
|
||||
)
|
||||
interleave_probs: Optional[str] = field(
|
||||
interleave_probs: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
|
||||
)
|
||||
@@ -79,15 +79,15 @@ class DataArguments:
|
||||
default=1000,
|
||||
metadata={"help": "The number of examples in one group in pre-processing."},
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
preprocessing_num_workers: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the pre-processing."},
|
||||
)
|
||||
max_samples: Optional[int] = field(
|
||||
max_samples: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
|
||||
)
|
||||
eval_num_beams: Optional[int] = field(
|
||||
eval_num_beams: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
|
||||
)
|
||||
@@ -103,7 +103,7 @@ class DataArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to evaluate on each dataset separately."},
|
||||
)
|
||||
packing: Optional[bool] = field(
|
||||
packing: bool | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
|
||||
)
|
||||
@@ -111,19 +111,19 @@ class DataArguments:
|
||||
default=False,
|
||||
metadata={"help": "Enable sequence packing without cross-attention."},
|
||||
)
|
||||
tool_format: Optional[str] = field(
|
||||
tool_format: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Tool format to use for constructing function calling examples."},
|
||||
)
|
||||
default_system: Optional[str] = field(
|
||||
default_system: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Override the default system message in the template."},
|
||||
)
|
||||
enable_thinking: Optional[bool] = field(
|
||||
enable_thinking: bool | None = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
|
||||
)
|
||||
tokenized_path: Optional[str] = field(
|
||||
tokenized_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from datasets import DownloadMode
|
||||
|
||||
@@ -46,7 +46,7 @@ class EvaluationArguments:
|
||||
default=5,
|
||||
metadata={"help": "Number of examplars for few-shot learning."},
|
||||
)
|
||||
save_dir: Optional[str] = field(
|
||||
save_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save the evaluation results."},
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -40,7 +40,7 @@ class FreezeArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
freeze_extra_modules: Optional[str] = field(
|
||||
freeze_extra_modules: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -56,7 +56,7 @@ class FreezeArguments:
|
||||
class LoraArguments:
|
||||
r"""Arguments pertaining to the LoRA training."""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
additional_target: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -66,7 +66,7 @@ class LoraArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
lora_alpha: Optional[int] = field(
|
||||
lora_alpha: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
|
||||
)
|
||||
@@ -88,7 +88,7 @@ class LoraArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
loraplus_lr_ratio: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
|
||||
)
|
||||
@@ -126,7 +126,7 @@ class LoraArguments:
|
||||
class OFTArguments:
|
||||
r"""Arguments pertaining to the OFT training."""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
additional_target: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -220,27 +220,27 @@ class RLHFArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
|
||||
)
|
||||
ref_model: Optional[str] = field(
|
||||
ref_model: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
|
||||
)
|
||||
ref_model_adapters: Optional[str] = field(
|
||||
ref_model_adapters: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapters of the reference model."},
|
||||
)
|
||||
ref_model_quantization_bit: Optional[int] = field(
|
||||
ref_model_quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reference model."},
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
reward_model: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the reward model used for the PPO training."},
|
||||
)
|
||||
reward_model_adapters: Optional[str] = field(
|
||||
reward_model_adapters: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapters of the reward model."},
|
||||
)
|
||||
reward_model_quantization_bit: Optional[int] = field(
|
||||
reward_model_quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reward model."},
|
||||
)
|
||||
@@ -248,7 +248,7 @@ class RLHFArguments:
|
||||
default="lora",
|
||||
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
|
||||
)
|
||||
ld_alpha: Optional[float] = field(
|
||||
ld_alpha: float | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -361,15 +361,15 @@ class BAdamArgument:
|
||||
default="layer",
|
||||
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
|
||||
)
|
||||
badam_start_block: Optional[int] = field(
|
||||
badam_start_block: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The starting block index for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
|
||||
badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field(
|
||||
default="ascending",
|
||||
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_interval: Optional[int] = field(
|
||||
badam_switch_interval: int | None = field(
|
||||
default=50,
|
||||
metadata={
|
||||
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
|
||||
@@ -406,15 +406,15 @@ class SwanLabArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
|
||||
)
|
||||
swanlab_project: Optional[str] = field(
|
||||
swanlab_project: str | None = field(
|
||||
default="llamafactory",
|
||||
metadata={"help": "The project name in SwanLab."},
|
||||
)
|
||||
swanlab_workspace: Optional[str] = field(
|
||||
swanlab_workspace: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The workspace name in SwanLab."},
|
||||
)
|
||||
swanlab_run_name: Optional[str] = field(
|
||||
swanlab_run_name: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The experiment name in SwanLab."},
|
||||
)
|
||||
@@ -422,19 +422,19 @@ class SwanLabArguments:
|
||||
default="cloud",
|
||||
metadata={"help": "The mode of SwanLab."},
|
||||
)
|
||||
swanlab_api_key: Optional[str] = field(
|
||||
swanlab_api_key: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The API key for SwanLab."},
|
||||
)
|
||||
swanlab_logdir: Optional[str] = field(
|
||||
swanlab_logdir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The log directory for SwanLab."},
|
||||
)
|
||||
swanlab_lark_webhook_url: Optional[str] = field(
|
||||
swanlab_lark_webhook_url: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
|
||||
)
|
||||
swanlab_lark_secret: Optional[str] = field(
|
||||
swanlab_lark_secret: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The Lark(飞书) secret for SwanLab."},
|
||||
)
|
||||
@@ -510,7 +510,7 @@ class FinetuningArguments(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to disable the shuffling of the training set."},
|
||||
)
|
||||
early_stopping_steps: Optional[int] = field(
|
||||
early_stopping_steps: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
|
||||
)
|
||||
@@ -530,11 +530,11 @@ class FinetuningArguments(
|
||||
return arg
|
||||
|
||||
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
|
||||
self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)
|
||||
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target: list[str] = split_arg(self.lora_target)
|
||||
self.oft_target: list[str] = split_arg(self.oft_target)
|
||||
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
|
||||
self.additional_target: list[str] | None = split_arg(self.additional_target)
|
||||
self.galore_target: list[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: list[str] = split_arg(self.apollo_target)
|
||||
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
||||
|
||||
@@ -17,12 +17,11 @@
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Literal, Self
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from transformers.training_args import _convert_str_dict
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
|
||||
from ..extras.logging import get_logger
|
||||
@@ -35,13 +34,13 @@ logger = get_logger(__name__)
|
||||
class BaseModelArguments:
|
||||
r"""Arguments pertaining to the model."""
|
||||
|
||||
model_name_or_path: Optional[str] = field(
|
||||
model_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
|
||||
},
|
||||
)
|
||||
adapter_name_or_path: Optional[str] = field(
|
||||
adapter_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -50,11 +49,11 @@ class BaseModelArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
adapter_folder: Optional[str] = field(
|
||||
adapter_folder: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The folder containing the adapter weights to load."},
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
cache_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
|
||||
)
|
||||
@@ -70,17 +69,17 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
|
||||
)
|
||||
add_tokens: Optional[str] = field(
|
||||
add_tokens: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
|
||||
},
|
||||
)
|
||||
add_special_tokens: Optional[str] = field(
|
||||
add_special_tokens: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
|
||||
)
|
||||
new_special_tokens_config: Optional[str] = field(
|
||||
new_special_tokens_config: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -110,7 +109,7 @@ class BaseModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use memory-efficient model loading."},
|
||||
)
|
||||
rope_scaling: Optional[RopeScaling] = field(
|
||||
rope_scaling: RopeScaling | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
||||
)
|
||||
@@ -122,7 +121,7 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
|
||||
)
|
||||
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
|
||||
mixture_of_depths: Literal["convert", "load"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
|
||||
)
|
||||
@@ -138,7 +137,7 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to enable liger kernel for faster training."},
|
||||
)
|
||||
moe_aux_loss_coef: Optional[float] = field(
|
||||
moe_aux_loss_coef: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
|
||||
)
|
||||
@@ -182,15 +181,15 @@ class BaseModelArguments:
|
||||
default="auto",
|
||||
metadata={"help": "Data type for model weights and activations at inference."},
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(
|
||||
hf_hub_token: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
||||
)
|
||||
ms_hub_token: Optional[str] = field(
|
||||
ms_hub_token: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with ModelScope Hub."},
|
||||
)
|
||||
om_hub_token: Optional[str] = field(
|
||||
om_hub_token: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Modelers Hub."},
|
||||
)
|
||||
@@ -283,7 +282,7 @@ class QuantizationArguments:
|
||||
default=QuantizationMethod.BNB,
|
||||
metadata={"help": "Quantization method to use for on-the-fly quantization."},
|
||||
)
|
||||
quantization_bit: Optional[int] = field(
|
||||
quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
|
||||
)
|
||||
@@ -295,7 +294,7 @@ class QuantizationArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
|
||||
)
|
||||
quantization_device_map: Optional[Literal["auto"]] = field(
|
||||
quantization_device_map: Literal["auto"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||
)
|
||||
@@ -375,7 +374,7 @@ class ProcessorArguments:
|
||||
class ExportArguments:
|
||||
r"""Arguments pertaining to the model export."""
|
||||
|
||||
export_dir: Optional[str] = field(
|
||||
export_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."},
|
||||
)
|
||||
@@ -387,11 +386,11 @@ class ExportArguments:
|
||||
default="cpu",
|
||||
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
|
||||
)
|
||||
export_quantization_bit: Optional[int] = field(
|
||||
export_quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the exported model."},
|
||||
)
|
||||
export_quantization_dataset: Optional[str] = field(
|
||||
export_quantization_dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
|
||||
)
|
||||
@@ -407,7 +406,7 @@ class ExportArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
|
||||
)
|
||||
export_hub_model_id: Optional[str] = field(
|
||||
export_hub_model_id: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
|
||||
)
|
||||
@@ -437,7 +436,7 @@ class VllmArguments:
|
||||
default=32,
|
||||
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
|
||||
)
|
||||
vllm_config: Optional[Union[dict, str]] = field(
|
||||
vllm_config: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
|
||||
)
|
||||
@@ -463,7 +462,7 @@ class SGLangArguments:
|
||||
default=-1,
|
||||
metadata={"help": "Tensor parallel size for the SGLang engine."},
|
||||
)
|
||||
sglang_config: Optional[Union[dict, str]] = field(
|
||||
sglang_config: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
|
||||
)
|
||||
@@ -487,21 +486,21 @@ class KTransformersArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
|
||||
)
|
||||
kt_optimize_rule: Optional[str] = field(
|
||||
kt_optimize_rule: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
|
||||
},
|
||||
)
|
||||
cpu_infer: Optional[int] = field(
|
||||
cpu_infer: int | None = field(
|
||||
default=32,
|
||||
metadata={"help": "Number Of CPU Cores Used For Computation."},
|
||||
)
|
||||
chunk_size: Optional[int] = field(
|
||||
chunk_size: int | None = field(
|
||||
default=8192,
|
||||
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
|
||||
)
|
||||
mode: Optional[str] = field(
|
||||
mode: str | None = field(
|
||||
default="normal",
|
||||
metadata={"help": "Normal Or Long_Context For Llama Models."},
|
||||
)
|
||||
@@ -539,17 +538,17 @@ class ModelArguments(
|
||||
The class on the most right will be displayed first.
|
||||
"""
|
||||
|
||||
compute_dtype: Optional[torch.dtype] = field(
|
||||
compute_dtype: torch.dtype | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
|
||||
)
|
||||
device_map: Optional[Union[str, dict[str, Any]]] = field(
|
||||
device_map: str | dict[str, Any] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
|
||||
)
|
||||
model_max_length: Optional[int] = field(
|
||||
model_max_length: int | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@@ -65,7 +65,7 @@ else:
|
||||
_TRAIN_MCA_CLS = tuple()
|
||||
|
||||
|
||||
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
|
||||
def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
|
||||
r"""Get arguments from the command line or a config file."""
|
||||
if args is not None:
|
||||
return args
|
||||
@@ -83,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
|
||||
|
||||
|
||||
def _parse_args(
|
||||
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
|
||||
parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
|
||||
) -> tuple[Any]:
|
||||
args = read_args(args)
|
||||
if isinstance(args, dict):
|
||||
@@ -205,13 +205,13 @@ def _check_extra_dependencies(
|
||||
check_version("rouge_chinese", mandatory=True)
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def _parse_train_mca_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_MCA_CLS:
|
||||
def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_MCA_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
|
||||
@@ -232,25 +232,25 @@ def _configure_mca_training_args(training_args, data_args, finetuning_args) -> N
|
||||
finetuning_args.use_mca = True
|
||||
|
||||
|
||||
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||
def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
|
||||
parser = HfArgumentParser(_INFER_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
|
||||
def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
|
||||
parser = HfArgumentParser(_EVAL_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
|
||||
def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
|
||||
parser = HfArgumentParser(RayArguments)
|
||||
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
|
||||
return ray_args
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
|
||||
if is_env_enabled("USE_MCA"):
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
|
||||
else:
|
||||
@@ -473,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||
def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
|
||||
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
|
||||
|
||||
# Setup logging
|
||||
@@ -508,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
|
||||
def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
|
||||
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
|
||||
|
||||
# Setup logging
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.training_args import _convert_str_dict
|
||||
@@ -40,7 +40,7 @@ else:
|
||||
class RayArguments:
|
||||
r"""Arguments pertaining to the Ray training."""
|
||||
|
||||
ray_run_name: Optional[str] = field(
|
||||
ray_run_name: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
|
||||
)
|
||||
@@ -48,7 +48,7 @@ class RayArguments:
|
||||
default="./saves",
|
||||
metadata={"help": "The storage path to save training results to"},
|
||||
)
|
||||
ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
|
||||
ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
|
||||
)
|
||||
@@ -56,7 +56,7 @@ class RayArguments:
|
||||
default=1,
|
||||
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
|
||||
)
|
||||
resources_per_worker: Union[dict, str] = field(
|
||||
resources_per_worker: dict | str = field(
|
||||
default_factory=lambda: {"GPU": 1},
|
||||
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
|
||||
)
|
||||
@@ -64,7 +64,7 @@ class RayArguments:
|
||||
default="PACK",
|
||||
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
|
||||
)
|
||||
ray_init_kwargs: Optional[Union[dict, str]] = field(
|
||||
ray_init_kwargs: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user