[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions

View File

@@ -13,7 +13,8 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, List, Sequence
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
import pytest
import torch
@@ -69,12 +70,12 @@ LABELS = [0, 1, 2, 3, 4]
BATCH_IDS = [[1] * 1024]
def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
def _get_mm_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
return image_processor(images=IMAGES, return_tensors="pt")
def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
def _is_close(batch_a: dict[str, Any], batch_b: dict[str, Any]) -> None:
assert batch_a.keys() == batch_b.keys()
for key in batch_a.keys():
if isinstance(batch_a[key], torch.Tensor):
@@ -96,11 +97,11 @@ def _check_plugin(
plugin: "BasePlugin",
tokenizer: "PreTrainedTokenizer",
processor: "ProcessorMixin",
expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
expected_input_ids: List[int] = INPUT_IDS,
expected_labels: List[int] = LABELS,
expected_mm_inputs: Dict[str, Any] = {},
expected_no_mm_inputs: Dict[str, Any] = {},
expected_mm_messages: Sequence[dict[str, str]] = MM_MESSAGES,
expected_input_ids: list[int] = INPUT_IDS,
expected_labels: list[int] = LABELS,
expected_mm_inputs: dict[str, Any] = {},
expected_no_mm_inputs: dict[str, Any] = {},
) -> None:
# test mm_messages
if plugin.__class__.__name__ != "BasePlugin":