diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index f6748883..6a174838 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: path: Optional[str] bytes: Optional[bytes] - ImageInput = Union[str, EncodedImage, ImageObject] + ImageInput = Union[str, bytes, EncodedImage, ImageObject] VideoInput = str @@ -104,6 +104,8 @@ class BasePlugin: for image in images: if isinstance(image, str): image = Image.open(image) + elif isinstance(image, bytes): + image = Image.open(BytesIO(image)) elif isinstance(image, dict): if image["bytes"] is not None: image = Image.open(BytesIO(image["bytes"]))