[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent bcd287848c
commit efa86e730c
113 changed files with 984 additions and 1407 deletions

View File

@@ -18,7 +18,7 @@
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np
import torch
@@ -44,21 +44,19 @@ logger = logging.get_logger(__name__)
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
"""
r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""
def __init__(
self,
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
gen_kwargs: Optional[Dict[str, Any]] = None,
gen_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
@@ -99,13 +97,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def prediction_step(
self,
model: "torch.nn.Module",
inputs: Dict[str, Union["torch.Tensor", Any]],
inputs: dict[str, Union["torch.Tensor", Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
ignore_keys: Optional[list[str]] = None,
**gen_kwargs,
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""
Removes the prompt part in the generated tokens.
) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""Remove the prompt part in the generated tokens.
Subclass and override to inject custom behavior.
"""
@@ -126,8 +123,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def save_predictions(
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
) -> None:
r"""
Saves model predictions to `output_dir`.
r"""Save model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
"""