[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions

View File

@@ -15,7 +15,7 @@
import os
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
from typing import Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
@@ -122,7 +122,7 @@ class RopeScaling(str, Enum):
def register_model_group(
models: Dict[str, Dict[DownloadSource, str]],
models: dict[str, dict[DownloadSource, str]],
template: Optional[str] = None,
multimodal: bool = False,
) -> None:

View File

@@ -32,9 +32,7 @@ _default_log_level: "logging._Level" = logging.INFO
class LoggerHandler(logging.Handler):
r"""
Redirects the logging output to the logging file for LLaMA Board.
"""
r"""Redirect the logging output to the logging file for LLaMA Board."""
def __init__(self, output_dir: str) -> None:
super().__init__()
@@ -67,9 +65,7 @@ class LoggerHandler(logging.Handler):
class _Logger(logging.Logger):
r"""
A logger that supports rank0 logging.
"""
r"""A logger that supports rank0 logging."""
def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs)
@@ -82,9 +78,7 @@ class _Logger(logging.Logger):
def _get_default_logging_level() -> "logging._Level":
r"""
Returns the default logging level.
"""
r"""Return the default logging level."""
env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
if env_level_str:
if env_level_str.upper() in logging._nameToLevel:
@@ -104,9 +98,7 @@ def _get_library_root_logger() -> "_Logger":
def _configure_library_root_logger() -> None:
r"""
Configures root logger using a stdout stream handler with an explicit format.
"""
r"""Configure root logger using a stdout stream handler with an explicit format."""
global _default_handler
with _thread_lock:
@@ -126,9 +118,7 @@ def _configure_library_root_logger() -> None:
def get_logger(name: Optional[str] = None) -> "_Logger":
r"""
Returns a logger with the specified name. It it not supposed to be accessed externally.
"""
r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
if name is None:
name = _get_library_name()
@@ -137,17 +127,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":
def add_handler(handler: "logging.Handler") -> None:
r"""
Adds a handler to the root logger.
"""
r"""Add a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
r"""
Removes a handler to the root logger.
"""
r"""Remove a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().removeHandler(handler)

View File

@@ -17,7 +17,8 @@
import gc
import os
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Literal, Union
import torch
import torch.distributed as dist
@@ -54,9 +55,7 @@ logger = logging.get_logger(__name__)
class AverageMeter:
r"""
Computes and stores the average and current value.
"""
r"""Compute and store the average and current value."""
def __init__(self):
self.reset()
@@ -75,9 +74,7 @@ class AverageMeter:
def check_version(requirement: str, mandatory: bool = False) -> None:
r"""
Optionally checks the package version.
"""
r"""Optionally check the package version."""
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
@@ -91,9 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
r"""Check the version of the required packages."""
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.2.0")
check_version("accelerate>=0.34.0,<=1.2.1")
@@ -103,10 +98,8 @@ def check_dependencies() -> None:
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
r"""
Calculates effective tokens per second.
"""
def calculate_tps(dataset: Sequence[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
r"""Calculate effective tokens per second."""
effective_token_num = 0
for data in dataset:
if stage == "sft":
@@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float],
return result / dist.get_world_size() if dist.is_initialized() else result
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
r"""Return the number of trainable parameters and number of all parameters in the model."""
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
@@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
def get_current_device() -> "torch.device":
r"""
Gets the current available device.
"""
r"""Get the current available device."""
if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_npu_available():
@@ -166,9 +155,7 @@ def get_current_device() -> "torch.device":
def get_device_count() -> int:
r"""
Gets the number of available GPU or NPU devices.
"""
r"""Get the number of available GPU or NPU devices."""
if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
@@ -180,18 +167,14 @@ def get_device_count() -> int:
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.
"""
r"""Get logits processor that removes NaN and Inf logits."""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
def get_peak_memory() -> Tuple[int, int]:
r"""
Gets the peak memory usage for the current device (in Bytes).
"""
def get_peak_memory() -> tuple[int, int]:
r"""Get the peak memory usage for the current device (in Bytes)."""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available():
@@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]:
def has_tokenized_data(path: "os.PathLike") -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
r"""Check if the path has a tokenized dataset."""
return os.path.isdir(path) and len(os.listdir(path)) > 0
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
@@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
def is_gpu_or_npu_available() -> bool:
r"""
Checks if the GPU or NPU is available.
"""
r"""Check if the GPU or NPU is available."""
return is_torch_npu_available() or is_torch_cuda_available()
def is_env_enabled(env_var: str, default: str = "0") -> bool:
r"""
Checks if the environment variable is enabled.
"""
r"""Check if the environment variable is enabled."""
return os.getenv(env_var, default).lower() in ["true", "y", "1"]
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
r"""
Casts a torch tensor or a numpy array to a numpy array.
"""
r"""Cast a torch tensor or a numpy array to a numpy array."""
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
@@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
r"""Avoid flash attention import error in custom model files."""
if not is_env_enabled("FORCE_CHECK_IMPORTS"):
transformers.dynamic_module_utils.check_imports = get_relative_imports
def torch_gc() -> None:
r"""
Collects GPU or NPU memory.
"""
r"""Collect GPU or NPU memory."""
gc.collect()
if is_torch_xpu_available():
torch.xpu.empty_cache()

View File

@@ -15,7 +15,7 @@
import json
import math
import os
from typing import Any, Dict, List
from typing import Any
from transformers.trainer import TRAINER_STATE_NAME
@@ -31,10 +31,8 @@ if is_matplotlib_available():
logger = logging.get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]:
r"""
EMA implementation according to TensorBoard.
"""
def smooth(scalars: list[float]) -> list[float]:
r"""EMA implementation according to TensorBoard."""
if len(scalars) == 0:
return []
@@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""
Plots loss curves in LlamaBoard.
"""
def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""Plot loss curves in LlamaBoard."""
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
@@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
return fig
def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
r"""
Plots loss curves and saves the image.
"""
def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
r"""Plot loss curves and saves the image."""
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f)