mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 18:20:35 +08:00
[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -18,7 +18,7 @@
|
||||
import json
|
||||
import os
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -44,21 +44,19 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
r"""
|
||||
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
|
||||
"""
|
||||
r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
finetuning_args: "FinetuningArguments",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
gen_kwargs: Optional[Dict[str, Any]] = None,
|
||||
gen_kwargs: Optional[dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if is_transformers_version_greater_than("4.46"):
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
else:
|
||||
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
|
||||
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
@@ -99,13 +97,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
def prediction_step(
|
||||
self,
|
||||
model: "torch.nn.Module",
|
||||
inputs: Dict[str, Union["torch.Tensor", Any]],
|
||||
inputs: dict[str, Union["torch.Tensor", Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
ignore_keys: Optional[list[str]] = None,
|
||||
**gen_kwargs,
|
||||
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
|
||||
r"""
|
||||
Removes the prompt part in the generated tokens.
|
||||
) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
|
||||
r"""Remove the prompt part in the generated tokens.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
@@ -126,8 +123,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
def save_predictions(
|
||||
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
|
||||
) -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
r"""Save model predictions to `output_dir`.
|
||||
|
||||
A custom behavior that not contained in Seq2SeqTrainer.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user