diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 3ac04982..eeed9a29 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -86,12 +86,12 @@ class HuggingfaceEngine(BaseEngine): mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]} if images is not None: mm_input_dict.update({"images": images, "imglens": [len(images)]}) - if not any(IMAGE_PLACEHOLDER not in message["content"] for message in messages): + if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] if videos is not None: mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]}) - if not any(VIDEO_PLACEHOLDER not in message["content"] for message in messages): + if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages): messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"] messages = template.mm_plugin.process_messages( diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 37feccc2..5f6612be 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -107,7 +107,7 @@ class VllmEngine(BaseEngine): ) -> AsyncIterator["RequestOutput"]: request_id = f"chatcmpl-{uuid.uuid4().hex}" if images is not None: - if not any(IMAGE_PLACEHOLDER not in message["content"] for message in messages): + if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages): messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"] paired_messages = messages + [{"role": "assistant", "content": ""}] diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index f6748883..6a174838 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: path: Optional[str] bytes: Optional[bytes] - ImageInput = Union[str, EncodedImage, ImageObject] + ImageInput = Union[str, bytes, EncodedImage, ImageObject] VideoInput = str @@ -104,6 +104,8 @@ class BasePlugin: for image in images: if isinstance(image, str): image = Image.open(image) + elif isinstance(image, bytes): + image = Image.open(BytesIO(image)) elif isinstance(image, dict): if image["bytes"] is not None: image = Image.open(BytesIO(image["bytes"]))