Mirror of https://github.com/hiyouga/LLaMA-Factory.git

[misc] upgrade format to py39 (#7256)

This commit migrates type annotations to the built-in generics available since Python 3.9 (PEP 585), moves container ABCs to collections.abc, and collapses multi-line docstrings into single-line, imperative-mood ones.
@@ -15,7 +15,7 @@
 import os
 from collections import OrderedDict, defaultdict
 from enum import Enum
-from typing import Dict, Optional
+from typing import Optional

 from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
 from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
@@ -122,7 +122,7 @@ class RopeScaling(str, Enum):


 def register_model_group(
-    models: Dict[str, Dict[DownloadSource, str]],
+    models: dict[str, dict[DownloadSource, str]],
     template: Optional[str] = None,
     multimodal: bool = False,
 ) -> None:
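The hunks above are the core of the py39 upgrade: PEP 585 made the built-in containers subscriptable at annotation time, so `typing.Dict` (deprecated since Python 3.9) can be dropped in favor of `dict`. `Optional` stays, because the `str | None` union syntax only arrives with PEP 604 in Python 3.10. A minimal sketch of the new style; the function and model entry here are hypothetical, not part of LLaMA-Factory:

# Python 3.9+ style: built-in generics (PEP 585) instead of typing.Dict.
from typing import Optional  # Optional is kept; "str | None" needs Python 3.10+


def describe_models(models: dict[str, dict[str, str]], template: Optional[str] = None) -> None:
    # Hypothetical helper for illustration only.
    for name, sources in models.items():
        print(f"{name} (template={template}): {sources}")


describe_models({"llama3-8b": {"hf": "meta-llama/Meta-Llama-3-8B"}}, template="llama3")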
@@ -32,9 +32,7 @@ _default_log_level: "logging._Level" = logging.INFO


 class LoggerHandler(logging.Handler):
-    r"""
-    Redirects the logging output to the logging file for LLaMA Board.
-    """
+    r"""Redirect the logging output to the logging file for LLaMA Board."""

     def __init__(self, output_dir: str) -> None:
         super().__init__()
@@ -67,9 +65,7 @@ class LoggerHandler(logging.Handler):


 class _Logger(logging.Logger):
-    r"""
-    A logger that supports rank0 logging.
-    """
+    r"""A logger that supports rank0 logging."""

     def info_rank0(self, *args, **kwargs) -> None:
         self.info(*args, **kwargs)
@@ -82,9 +78,7 @@ class _Logger(logging.Logger):


 def _get_default_logging_level() -> "logging._Level":
-    r"""
-    Returns the default logging level.
-    """
+    r"""Return the default logging level."""
     env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
     if env_level_str:
         if env_level_str.upper() in logging._nameToLevel:
@@ -104,9 +98,7 @@ def _get_library_root_logger() -> "_Logger":


 def _configure_library_root_logger() -> None:
-    r"""
-    Configures root logger using a stdout stream handler with an explicit format.
-    """
+    r"""Configure root logger using a stdout stream handler with an explicit format."""
     global _default_handler

     with _thread_lock:
@@ -126,9 +118,7 @@ def _configure_library_root_logger() -> None:


 def get_logger(name: Optional[str] = None) -> "_Logger":
-    r"""
-    Returns a logger with the specified name. It it not supposed to be accessed externally.
-    """
+    r"""Return a logger with the specified name. It is not supposed to be accessed externally."""
     if name is None:
         name = _get_library_name()

@@ -137,17 +127,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":


 def add_handler(handler: "logging.Handler") -> None:
-    r"""
-    Adds a handler to the root logger.
-    """
+    r"""Add a handler to the root logger."""
     _configure_library_root_logger()
     _get_library_root_logger().addHandler(handler)


 def remove_handler(handler: logging.Handler) -> None:
-    r"""
-    Removes a handler to the root logger.
-    """
+    r"""Remove a handler from the root logger."""
     _configure_library_root_logger()
     _get_library_root_logger().removeHandler(handler)
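Every docstring hunk in this file follows the same pattern: the body is hoisted onto the opening line and the first verb switches to the imperative mood. This matches pydocstyle's D200 (a one-line docstring should fit on one line) and D401 (first line in imperative mood); the commit message only says "upgrade format to py39", so attributing it to those specific lint rules is an inference. A before/after sketch with a hypothetical function:

def get_level():
    r"""
    Returns the default logging level.
    """


def get_level():  # noqa: F811 -- redefined on purpose for the comparison
    r"""Return the default logging level."""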
@@ -17,7 +17,8 @@

 import gc
 import os
-from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any, Literal, Union

 import torch
 import torch.distributed as dist
@@ -54,9 +55,7 @@ logger = logging.get_logger(__name__)


 class AverageMeter:
-    r"""
-    Computes and stores the average and current value.
-    """
+    r"""Compute and store the average and current value."""

     def __init__(self):
         self.reset()
@@ -75,9 +74,7 @@ class AverageMeter:


 def check_version(requirement: str, mandatory: bool = False) -> None:
-    r"""
-    Optionally checks the package version.
-    """
+    r"""Optionally check the package version."""
     if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
         logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
         return
@@ -91,9 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:


 def check_dependencies() -> None:
-    r"""
-    Checks the version of the required packages.
-    """
+    r"""Check the version of the required packages."""
     check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
     check_version("datasets>=2.16.0,<=3.2.0")
     check_version("accelerate>=0.34.0,<=1.2.1")
@@ -103,10 +98,8 @@ def check_dependencies() -> None:
     logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")


-def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
-    r"""
-    Calculates effective tokens per second.
-    """
+def calculate_tps(dataset: Sequence[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
+    r"""Calculate effective tokens per second."""
     effective_token_num = 0
     for data in dataset:
         if stage == "sft":
@@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float],
     return result / dist.get_world_size() if dist.is_initialized() else result


-def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
-    r"""
-    Returns the number of trainable parameters and number of all parameters in the model.
-    """
+def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
+    r"""Return the number of trainable parameters and number of all parameters in the model."""
     trainable_params, all_param = 0, 0
     for param in model.parameters():
         num_params = param.numel()
@@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:


 def get_current_device() -> "torch.device":
-    r"""
-    Gets the current available device.
-    """
+    r"""Get the current available device."""
     if is_torch_xpu_available():
         device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
     elif is_torch_npu_available():
@@ -166,9 +155,7 @@ def get_current_device() -> "torch.device":


 def get_device_count() -> int:
-    r"""
-    Gets the number of available GPU or NPU devices.
-    """
+    r"""Get the number of available GPU or NPU devices."""
     if is_torch_xpu_available():
         return torch.xpu.device_count()
     elif is_torch_npu_available():
@@ -180,18 +167,14 @@ def get_device_count() -> int:


 def get_logits_processor() -> "LogitsProcessorList":
-    r"""
-    Gets logits processor that removes NaN and Inf logits.
-    """
+    r"""Get logits processor that removes NaN and Inf logits."""
     logits_processor = LogitsProcessorList()
     logits_processor.append(InfNanRemoveLogitsProcessor())
     return logits_processor


-def get_peak_memory() -> Tuple[int, int]:
-    r"""
-    Gets the peak memory usage for the current device (in Bytes).
-    """
+def get_peak_memory() -> tuple[int, int]:
+    r"""Get the peak memory usage for the current device (in Bytes)."""
     if is_torch_npu_available():
         return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
     elif is_torch_cuda_available():
@@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]:


 def has_tokenized_data(path: "os.PathLike") -> bool:
-    r"""
-    Checks if the path has a tokenized dataset.
-    """
+    r"""Check if the path has a tokenized dataset."""
     return os.path.isdir(path) and len(os.listdir(path)) > 0


 def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
-    r"""
-    Infers the optimal dtype according to the model_dtype and device compatibility.
-    """
+    r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
     if _is_bf16_available and model_dtype == torch.bfloat16:
         return torch.bfloat16
     elif _is_fp16_available:
@@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":


 def is_gpu_or_npu_available() -> bool:
-    r"""
-    Checks if the GPU or NPU is available.
-    """
+    r"""Check if the GPU or NPU is available."""
     return is_torch_npu_available() or is_torch_cuda_available()


 def is_env_enabled(env_var: str, default: str = "0") -> bool:
-    r"""
-    Checks if the environment variable is enabled.
-    """
+    r"""Check if the environment variable is enabled."""
     return os.getenv(env_var, default).lower() in ["true", "y", "1"]


 def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
-    r"""
-    Casts a torch tensor or a numpy array to a numpy array.
-    """
+    r"""Cast a torch tensor or a numpy array to a numpy array."""
     if isinstance(inputs, torch.Tensor):
         inputs = inputs.cpu()
         if inputs.dtype == torch.bfloat16:  # numpy does not support bfloat16 until 1.21.4
@@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":


 def skip_check_imports() -> None:
-    r"""
-    Avoids flash attention import error in custom model files.
-    """
+    r"""Avoid flash attention import error in custom model files."""
     if not is_env_enabled("FORCE_CHECK_IMPORTS"):
         transformers.dynamic_module_utils.check_imports = get_relative_imports


 def torch_gc() -> None:
-    r"""
-    Collects GPU or NPU memory.
-    """
+    r"""Collect GPU or NPU memory."""
     gc.collect()
     if is_torch_xpu_available():
         torch.xpu.empty_cache()
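The import hunk at the top of this file shows the other half of the py39 migration: `Sequence` moves from `typing` to `collections.abc`, because since Python 3.9 the ABCs there are directly subscriptable and the `typing` aliases for them are deprecated, while `Dict` and `Tuple` disappear in favor of the builtins. A small self-contained sketch under those rules; the function and data are hypothetical:

from collections.abc import Sequence  # subscriptable since Python 3.9


def count_tokens(dataset: Sequence[dict[str, list[int]]]) -> tuple[int, int]:
    # Hypothetical helper: (number of samples, total token count).
    total = sum(len(sample["input_ids"]) for sample in dataset)
    return len(dataset), total


print(count_tokens([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]))  # (2, 5)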
@@ -15,7 +15,7 @@
 import json
 import math
 import os
-from typing import Any, Dict, List
+from typing import Any

 from transformers.trainer import TRAINER_STATE_NAME

@@ -31,10 +31,8 @@ if is_matplotlib_available():
 logger = logging.get_logger(__name__)


-def smooth(scalars: List[float]) -> List[float]:
-    r"""
-    EMA implementation according to TensorBoard.
-    """
+def smooth(scalars: list[float]) -> list[float]:
+    r"""EMA implementation according to TensorBoard."""
     if len(scalars) == 0:
         return []

@@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]:
     return smoothed


-def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
-    r"""
-    Plots loss curves in LlamaBoard.
-    """
+def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
+    r"""Plot loss curves in LlamaBoard."""
     plt.close("all")
     plt.switch_backend("agg")
     fig = plt.figure()
@@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
     return fig


-def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
-    r"""
-    Plots loss curves and saves the image.
-    """
+def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
+    r"""Plot loss curves and save the image."""
     plt.switch_backend("agg")
     with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
         data = json.load(f)
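The body of `smooth` falls outside these hunks; its docstring points at TensorBoard's EMA smoothing, whose classic form keeps a running value `last = last * weight + (1 - weight) * point`. A rough sketch of that idea follows; the fixed weight of 0.9 is an assumption for illustration, not necessarily what LLaMA-Factory's implementation uses:

def ema_smooth(scalars: list[float], weight: float = 0.9) -> list[float]:
    # TensorBoard-style exponential moving average (assumed fixed weight).
    smoothed: list[float] = []
    last = scalars[0] if scalars else 0.0
    for point in scalars:
        last = last * weight + (1.0 - weight) * point
        smoothed.append(last)
    return smoothed


print(ema_smooth([0.0, 1.0, 1.0, 1.0]))  # approximately [0.0, 0.1, 0.19, 0.271]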