From 92c398166de69e66655ab8534fab05dda2244ec7 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Fri, 30 Aug 2024 03:21:50 +0800
Subject: [PATCH] tiny fix

Former-commit-id: bee1bd43b946501690d70e4980205f9d82404296
---
 src/llamafactory/__init__.py                     |  2 +-
 src/llamafactory/chat/vllm_engine.py             | 13 ++++++-------
 src/llamafactory/data/collator.py                |  2 +-
 src/llamafactory/data/mm_plugin.py               | 12 ++++++++++++
 src/llamafactory/data/processors/feedback.py     |  2 +-
 src/llamafactory/data/processors/pairwise.py     |  2 +-
 src/llamafactory/data/processors/pretrain.py     |  2 +-
 src/llamafactory/data/processors/unsupervised.py |  2 +-
 8 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py
index ed54278f..fde2f568 100644
--- a/src/llamafactory/__init__.py
+++ b/src/llamafactory/__init__.py
@@ -20,7 +20,7 @@ Level:
 
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.44.3
+    transformers>=4.41.2,<=4.45.0
     datasets>=2.16.0,<=2.21.0
     accelerate>=0.30.1,<=0.33.0
     peft>=0.11.1,<=0.12.0
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 8dc7214a..d64f4d25 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -16,6 +16,7 @@ import uuid
 from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
 from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import IMAGE_PLACEHOLDER
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
 from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
@@ -115,13 +116,11 @@ class VllmEngine(BaseEngine):
     ) -> AsyncIterator["RequestOutput"]:
         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
 
-        if (
-            self.processor is not None
-            and image is not None
-            and not hasattr(self.processor, "image_seq_length")
-            and self.template.image_token not in messages[0]["content"]
-        ):  # llava-like models (TODO: paligemma models)
-            messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
+        if image is not None:
+            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
+                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
+
+            messages = self.template.mm_plugin.process_messages(messages, [image], self.processor)
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 0885705a..29bbc9eb 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -68,7 +68,7 @@ class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     """
 
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        image_grid_thw = None
+        image_grid_thw = None  # TODO: better handle various VLMs
         if "image_grid_thw" in features[0]:
             image_grid_thw_list = [
                 torch.Tensor(feature["image_grid_thw"]).long()
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 714f09fb..cd2604d8 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -74,6 +74,9 @@ class BasePlugin:
         images: Sequence["ImageObject"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
+        r"""
+        Pre-processes input messages before tokenization for VLMs.
+        """
         return messages
 
     def process_token_ids(
@@ -83,6 +86,9 @@ class BasePlugin:
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
     ) -> Tuple[List[int], Optional[List[int]]]:
+        r"""
+        Pre-processes token ids after tokenization for VLMs.
+        """
         return input_ids, labels
 
     def get_mm_inputs(
@@ -91,6 +97,9 @@ class BasePlugin:
         feature_seqlens: Dict[str, int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Any]:
+        r"""
+        Builds batched multimodal inputs for VLMs.
+        """
         return {}
 
     def process_model_inputs(
@@ -100,6 +109,9 @@ class BasePlugin:
         feature_seqlens: Dict[str, int],
         processor: Optional["ProcessorMixin"],
     ) -> None:
+        r"""
+        Appends multimodal inputs to model inputs for VLMs.
+        """
         return
 
 
diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py
index c09ef488..826919bc 100644
--- a/src/llamafactory/data/processors/feedback.py
+++ b/src/llamafactory/data/processors/feedback.py
@@ -84,7 +84,7 @@ def preprocess_feedback_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
     kl_response = examples["response"][::-1]
     model_inputs = defaultdict(list)
diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py
index fec25783..ad625d33 100644
--- a/src/llamafactory/data/processors/pairwise.py
+++ b/src/llamafactory/data/processors/pairwise.py
@@ -70,7 +70,7 @@ def preprocess_pairwise_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
     model_inputs = defaultdict(list)
     for i in range(len(examples["prompt"])):
diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py
index 67d6009b..93422259 100644
--- a/src/llamafactory/data/processors/pretrain.py
+++ b/src/llamafactory/data/processors/pretrain.py
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
 
 def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
     text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py
index cf9ff643..49a29aa6 100644
--- a/src/llamafactory/data/processors/unsupervised.py
+++ b/src/llamafactory/data/processors/unsupervised.py
@@ -62,7 +62,7 @@ def preprocess_unsupervised_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build inputs with format `<bos> X` and labels with format `Y <eos>`
     model_inputs = defaultdict(list)
     for i in range(len(examples["prompt"])):