mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-06 19:56:01 +08:00
[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@@ -32,14 +33,14 @@ logger = logging.get_logger(__name__)
|
||||
class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
def _encode_data_example(
|
||||
self,
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
prompt: Sequence[dict[str, str]],
|
||||
response: Sequence[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
) -> tuple[list[int], list[int]]:
|
||||
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
|
||||
input_ids, labels = self.template.mm_plugin.process_token_ids(
|
||||
[], [], images, videos, audios, self.tokenizer, self.processor
|
||||
@@ -85,7 +86,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = defaultdict(list)
|
||||
@@ -114,7 +115,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
|
||||
return model_inputs
|
||||
|
||||
def print_data_example(self, example: Dict[str, List[int]]) -> None:
|
||||
def print_data_example(self, example: dict[str, list[int]]) -> None:
|
||||
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
@@ -124,7 +125,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
|
||||
@dataclass
|
||||
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
||||
# TODO: use `position_ids` to achieve packing
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
|
||||
Reference in New Issue
Block a user