mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-17 12:20:37 +08:00
[misc] upgrade format to py39 (#7256)
@@ -17,7 +17,7 @@
 import json
 from dataclasses import asdict, dataclass, field, fields
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 import torch
 from transformers.training_args import _convert_str_dict
@@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling
 
 
 @dataclass
 class BaseModelArguments:
-    r"""
-    Arguments pertaining to the model.
-    """
+    r"""Arguments pertaining to the model."""
 
     model_name_or_path: Optional[str] = field(
         default=None,
@@ -184,9 +182,7 @@ class BaseModelArguments:
 
 
 @dataclass
 class QuantizationArguments:
-    r"""
-    Arguments pertaining to the quantization method.
-    """
+    r"""Arguments pertaining to the quantization method."""
 
     quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
         default="bitsandbytes",
@@ -212,9 +208,7 @@ class QuantizationArguments:
 
 
 @dataclass
 class ProcessorArguments:
-    r"""
-    Arguments pertaining to the image processor.
-    """
+    r"""Arguments pertaining to the image processor."""
 
     image_max_pixels: int = field(
         default=768 * 768,
@@ -244,9 +238,7 @@ class ProcessorArguments:
 
 
 @dataclass
 class ExportArguments:
-    r"""
-    Arguments pertaining to the model export.
-    """
+    r"""Arguments pertaining to the model export."""
 
     export_dir: Optional[str] = field(
         default=None,
@@ -292,9 +284,7 @@ class ExportArguments:
 
 
 @dataclass
 class VllmArguments:
-    r"""
-    Arguments pertaining to the vLLM worker.
-    """
+    r"""Arguments pertaining to the vLLM worker."""
 
     vllm_maxlen: int = field(
         default=4096,
@@ -324,8 +314,7 @@ class VllmArguments:
 
 @dataclass
 class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
-    r"""
-    Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
+    r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
 
     The class on the most right will be displayed first.
     """
@@ -335,7 +324,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
         init=False,
         metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
     )
-    device_map: Optional[Union[str, Dict[str, Any]]] = field(
+    device_map: Optional[Union[str, dict[str, Any]]] = field(
         default=None,
         init=False,
         metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
@@ -372,7 +361,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
 
         return result
 
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         args = asdict(self)
         args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
         return args
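For context: the commit adopts PEP 585 generics (available since Python 3.9), which let built-in containers such as dict and list be used directly as type annotations so typing.Dict can be dropped, and it collapses single-sentence docstrings onto one line. Below is a minimal sketch of the resulting style; the names (ExampleArguments, hub_token) are hypothetical and not taken from the repository.

# Minimal sketch (assumed example, not part of the commit): PEP 585 allows
# built-in containers to be used as generic types on Python 3.9+, so
# `typing.Dict` is no longer needed.
from dataclasses import asdict, dataclass, field
from typing import Any, Optional, Union


@dataclass
class ExampleArguments:
    r"""Arguments for a hypothetical component, written in the py39 style."""

    device_map: Optional[Union[str, dict[str, Any]]] = field(default=None)
    hub_token: Optional[str] = field(default=None)

    def to_dict(self) -> dict[str, Any]:
        # Mask fields whose names end with "token", mirroring the diff above.
        args = asdict(self)
        return {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}

On Python 3.8 and earlier, the dict[str, Any] annotations above would raise a TypeError at class definition time unless `from __future__ import annotations` were in effect, which is why this formatting change is tied to a Python 3.9 baseline.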