Former-commit-id: bee1bd43b946501690d70e4980205f9d82404296
This commit is contained in:
hiyouga 2024-08-30 03:21:50 +08:00
parent 913ee05e74
commit 92c398166d
8 changed files with 24 additions and 13 deletions

View File

@ -20,7 +20,7 @@ Level:
Dependency graph:
main:
transformers>=4.41.2,<=4.44.3
transformers>=4.41.2,<=4.45.0
datasets>=2.16.0,<=2.21.0
accelerate>=0.30.1,<=0.33.0
peft>=0.11.1,<=0.12.0

View File

@ -16,6 +16,7 @@ import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
@ -115,13 +116,11 @@ class VllmEngine(BaseEngine):
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
if (
self.processor is not None
and image is not None
and not hasattr(self.processor, "image_seq_length")
and self.template.image_token not in messages[0]["content"]
): # llava-like models (TODO: paligemma models)
messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
if image is not None:
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
messages = self.template.mm_plugin.process_messages(messages, [image], self.processor)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]

View File

@ -68,7 +68,7 @@ class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
image_grid_thw = None
image_grid_thw = None # TODO: better handle various VLMs
if "image_grid_thw" in features[0]:
image_grid_thw_list = [
torch.Tensor(feature["image_grid_thw"]).long()

View File

@ -74,6 +74,9 @@ class BasePlugin:
images: Sequence["ImageObject"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
r"""
Pre-processes input messages before tokenization for VLMs.
"""
return messages
def process_token_ids(
@ -83,6 +86,9 @@ class BasePlugin:
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
r"""
Pre-processes token ids after tokenization for VLMs.
"""
return input_ids, labels
def get_mm_inputs(
@ -91,6 +97,9 @@ class BasePlugin:
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]:
r"""
Builds batched multimodal inputs for VLMs.
"""
return {}
def process_model_inputs(
@ -100,6 +109,9 @@ class BasePlugin:
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
r"""
Appends multimodal inputs to model inputs for VLMs.
"""
return

View File

@ -84,7 +84,7 @@ def preprocess_feedback_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1]
model_inputs = defaultdict(list)

View File

@ -70,7 +70,7 @@ def preprocess_pairwise_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])):

View File

@ -27,7 +27,7 @@ if TYPE_CHECKING:
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]

View File

@ -62,7 +62,7 @@ def preprocess_unsupervised_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])):