Merge pull request #5895 from hiyouga/dev

[inference] support multiple images

Former-commit-id: 0a55e60693ab15d92fbe3d7d536408e26228ab82
hoshi-hiyouga 2024-11-01 16:52:55 +08:00 committed by GitHub
commit 25093c2d82
12 changed files with 222 additions and 67 deletions

CONTRIBUTING.md
View File

@@ -19,3 +19,49 @@ There are several ways you can contribute to LLaMA Factory:
### Style guide
LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
### Create a Pull Request
1. Fork the [repository](https://github.com/hiyouga/LLaMA-Factory) by clicking on the [Fork](https://github.com/hiyouga/LLaMA-Factory/fork) button on the repository's page. This creates a copy of the code under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote:
```bash
git clone git@github.com:[username]/LLaMA-Factory.git
cd LLaMA-Factory
git remote add upstream https://github.com/hiyouga/LLaMA-Factory.git
```
3. Create a new branch to hold your development changes:
```bash
git checkout -b dev_your_branch
```
4. Set up a development environment by running the following command in a virtual environment:
```bash
pip install -e ".[dev]"
```
If LLaMA Factory was already installed in the virtual environment, remove it with `pip uninstall llamafactory` before reinstalling it in editable mode with the `-e` flag.
5. Check your code before committing:
```bash
make commit
make style && make quality
make test
```
6. Submit changes:
```bash
git add .
git commit -m "commit message"
git fetch upstream
git rebase upstream/main
git push -u origin dev_your_branch
```
7. Create a pull request from your branch `dev_your_branch` to the [original repository](https://github.com/hiyouga/LLaMA-Factory).

.pre-commit-config.yaml
View File

@@ -12,7 +12,7 @@ repos:
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- id: no-commit-to-branch
args: ['--branch', 'master']
args: ['--branch', 'main']
- repo: https://github.com/asottile/pyupgrade
rev: v3.17.0

README.md
View File

@@ -584,6 +584,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
> [!TIP]
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
>
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
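
Since this PR, a single user message may carry several `image_url` parts. Below is a minimal sketch of a multi-image request (not part of this commit; it reuses the sample image URLs from `scripts/test_image.py` and assumes the API server above is running on the default port):

```python
import os

from openai import OpenAI

client = OpenAI(
    api_key=os.environ.get("API_KEY", "0"),
    base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
)
result = client.chat.completions.create(
    model="test",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Compare these two images."},
                {"type": "image_url", "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"}},
                {"type": "image_url", "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"}},
            ],
        }
    ],
)
print(result.choices[0].message.content)
```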
### Download from ModelScope Hub

README_zh.md
View File

@@ -585,6 +585,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
> [!TIP]
> See [this page](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
>
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
### Download from ModelScope Hub

scripts/test_image.py Normal file
View File

@@ -0,0 +1,65 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from openai import OpenAI
from transformers.utils.versions import require_version


require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")


def main():
    client = OpenAI(
        api_key="{}".format(os.environ.get("API_KEY", "0")),
        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
    )
    messages = []
    messages.append(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Output the color and number of each box."},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"},
                },
            ],
        }
    )
    result = client.chat.completions.create(messages=messages, model="test")
    messages.append(result.choices[0].message)
    print("Round 1:", result.choices[0].message.content)
    # The image shows a pyramid of colored blocks with numbers on them. Here are the colors and numbers of ...
    messages.append(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What kind of flower is this?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"},
                },
            ],
        }
    )
    result = client.chat.completions.create(messages=messages, model="test")
    messages.append(result.choices[0].message)
    print("Round 2:", result.choices[0].message.content)
    # The image shows a cluster of forget-me-not flowers. Forget-me-nots are small ...


if __name__ == "__main__":
    main()
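
To reproduce the transcript above, start the OpenAI-style API server first (for example, `API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml`, pointing the YAML at a vision-language model such as Qwen2-VL) and then run `python scripts/test_image.py`.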

src/llamafactory/api/chat.py
View File

@@ -69,7 +69,7 @@ ROLE_MAPPING = {
def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
logger.info(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
if len(request.messages) == 0:
@@ -84,7 +84,7 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
image = None
images = []
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -111,10 +111,11 @@ def _process_request(
else: # web uri
image_stream = requests.get(image_url, stream=True).raw
image = Image.open(image_stream).convert("RGB")
images.append(Image.open(image_stream).convert("RGB"))
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
images = None if len(images) == 0 else images
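# Note: every image_url part above appends one decoded image, so a single
# request can now carry multiple images; the old single `image` variable kept
# only the last one. The empty list is normalized to None for the engines.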
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
@@ -124,7 +125,7 @@ def _process_request(
else:
tools = None
return input_messages, system, tools, image
return input_messages, system, tools, images
def _create_stream_chat_completion_chunk(
@@ -143,12 +144,12 @@ async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, image = _process_request(request)
input_messages, system, tools, images = _process_request(request)
responses = await chat_model.achat(
input_messages,
system,
tools,
image,
images,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
@@ -194,7 +195,7 @@ async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, image = _process_request(request)
input_messages, system, tools, images = _process_request(request)
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -208,7 +209,7 @@ async def create_stream_chat_completion_response(
input_messages,
system,
tools,
image,
images,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,

src/llamafactory/chat/base_engine.py
View File

@@ -66,8 +66,8 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
@@ -81,8 +81,8 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""

src/llamafactory/chat/chat_model.py
View File

@@ -64,15 +64,15 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
)
return task.result()
@@ -81,28 +81,28 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Asynchronously gets a list of responses of the chat model.
"""
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> Generator[str, None, None]:
r"""
Gets the response token-by-token of the chat model.
"""
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -115,14 +115,14 @@ class ChatModel:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Asynchronously gets the response token-by-token of the chat model.
"""
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
yield new_token
def get_scores(
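
Taken together, the `ChatModel` entry points now accept media lists. A minimal sketch of the updated calling convention (the model path, template name, and image files below are placeholder assumptions, not from this commit):

```python
from llamafactory.chat import ChatModel

# Placeholder configuration: substitute a real vision-language model.
chat_model = ChatModel(dict(model_name_or_path="path/to/qwen2_vl_model", template="qwen2_vl"))
# Both images are attached to the single user turn below.
messages = [{"role": "user", "content": "<image><image>Compare the two pictures."}]
responses = chat_model.chat(messages, images=["first.png", "second.png"])
print(responses[0].response_text)
```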

src/llamafactory/chat/hf_engine.py
View File

@@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if image is not None:
mm_input_dict.update({"images": [image], "imglens": [1]})
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if video is not None:
mm_input_dict.update({"videos": [video], "vidlens": [1]})
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
if videos is not None:
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
messages = template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
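
The padding rule above inserts one placeholder per image when the prompt contains none. A standalone sketch of just that rule (assuming the placeholder token is `<image>`, the value of `IMAGE_PLACEHOLDER` in the library's constants):

```python
IMAGE_PLACEHOLDER = "<image>"  # assumed value of llamafactory's constant

messages = [{"role": "user", "content": "Describe both pictures."}]
images = ["a.png", "b.png"]

# Prepend one placeholder per image if the user did not position them.
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
    messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

print(messages[0]["content"])  # -> "<image><image>Describe both pictures."
```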
@@ -186,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
model,
tokenizer,
processor,
template,
generating_args,
messages,
system,
tools,
images,
videos,
input_kwargs,
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
@@ -222,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
model,
tokenizer,
processor,
template,
generating_args,
messages,
system,
tools,
images,
videos,
input_kwargs,
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
@@ -270,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
@@ -287,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
messages,
system,
tools,
image,
video,
images,
videos,
input_kwargs,
)
async with self.semaphore:
@@ -301,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
@@ -318,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
messages,
system,
tools,
image,
video,
images,
videos,
input_kwargs,
)
async with self.semaphore:

src/llamafactory/chat/vllm_engine.py
View File

@@ -101,14 +101,14 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = f"chatcmpl-{uuid.uuid4().hex}"
if image is not None:
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
if images is not None:
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
@@ -157,14 +157,18 @@ class VllmEngine(BaseEngine):
skip_special_tokens=True,
)
if image is not None: # add image features
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
if images is not None: # add image features
image_data = []
for image in images:
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
if isinstance(image, str):
image = Image.open(image).convert("RGB")
if isinstance(image, str):
image = Image.open(image).convert("RGB")
multi_modal_data = {"image": image}
image_data.append(image)
multi_modal_data = {"image": image_data}
else:
multi_modal_data = None
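
In effect, `multi_modal_data["image"]` changes from a single `PIL.Image` to a list of decoded images, which is the list form vLLM accepts for multi-image prompts.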
@@ -182,12 +186,12 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
async for request_output in generator:
final_output = request_output
@@ -210,12 +214,12 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text

src/llamafactory/data/mm_plugin.py
View File

@@ -226,6 +226,14 @@ class BasePlugin:
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r"""
Builds batched multimodal inputs for VLMs.
Arguments:
images: a list of image inputs, shape (num_images,)
videos: a list of video inputs, shape (num_videos,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
seqlens: number of tokens in each sample, shape (batch_size,)
processor: a processor for pre-processing images and videos
"""
self._validate_input(images, videos)
return {}
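
The lens arguments describe how the flat media lists split across a batch. A hypothetical example (paths stand in for `ImageInput`/`VideoInput` values):

```python
# Batch of two samples: the first has two images, the second one image and one video.
images = ["a.png", "b.png", "c.png"]  # flat across the batch, shape (num_images,)
videos = ["d.mp4"]                    # shape (num_videos,)
imglens = [2, 1]                      # images per sample, shape (batch_size,)
vidlens = [0, 1]                      # videos per sample
seqlens = [512, 384]                  # tokens per sample

assert sum(imglens) == len(images) and sum(vidlens) == len(videos)
```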

src/llamafactory/webui/chatter.py
View File

@@ -141,7 +141,14 @@ class WebChatModel(ChatModel):
chatbot[-1][1] = ""
response = ""
for new_text in self.stream_chat(
messages, system, tools, image, video, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
messages,
system,
tools,
images=[image] if image else None,
videos=[video] if video else None,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
):
response += new_text
if tools: