Merge commit from fork

* fix lfi and ssrf

* move utils to common

---------

Co-authored-by: d3do <chamlinx@outlook.com>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
Wu Wenhao 2025-10-07 20:55:29 +08:00 committed by GitHub
parent d5bb4e6394
commit 95b7188090
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 69 additions and 1 deletions

View File

@ -26,7 +26,7 @@ from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import is_env_enabled from ..extras.misc import is_env_enabled
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify from .common import check_lfi_path, check_ssrf_url, dictify, jsonify
from .protocol import ( from .protocol import (
ChatCompletionMessage, ChatCompletionMessage,
ChatCompletionResponse, ChatCompletionResponse,
@ -121,8 +121,10 @@ def _process_request(
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1])) image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(image_url): # local file elif os.path.isfile(image_url): # local file
check_lfi_path(image_url)
image_stream = open(image_url, "rb") image_stream = open(image_url, "rb")
else: # web uri else: # web uri
check_ssrf_url(image_url)
image_stream = requests.get(image_url, stream=True).raw image_stream = requests.get(image_url, stream=True).raw
images.append(Image.open(image_stream).convert("RGB")) images.append(Image.open(image_stream).convert("RGB"))
@ -132,8 +134,10 @@ def _process_request(
if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url): # base64 video if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url): # base64 video
video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1])) video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(video_url): # local file elif os.path.isfile(video_url): # local file
check_lfi_path(video_url)
video_stream = video_url video_stream = video_url
else: # web uri else: # web uri
check_ssrf_url(video_url)
video_stream = requests.get(video_url, stream=True).raw video_stream = requests.get(video_url, stream=True).raw
videos.append(video_stream) videos.append(video_stream)
@ -143,8 +147,10 @@ def _process_request(
if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url): # base64 audio if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url): # base64 audio
audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1])) audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(audio_url): # local file elif os.path.isfile(audio_url): # local file
check_lfi_path(audio_url)
audio_stream = audio_url audio_stream = audio_url
else: # web uri else: # web uri
check_ssrf_url(audio_url)
audio_stream = requests.get(audio_url, stream=True).raw audio_stream = requests.get(audio_url, stream=True).raw
audios.append(audio_stream) audios.append(audio_stream)

View File

@ -12,14 +12,29 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import ipaddress
import json import json
import os
import socket
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from ..extras.misc import is_env_enabled
from ..extras.packages import is_fastapi_available
if is_fastapi_available():
from fastapi import HTTPException, status
if TYPE_CHECKING: if TYPE_CHECKING:
from pydantic import BaseModel from pydantic import BaseModel
# Directory that check_lfi_path() confines local media files to. Overridable through the
# SAFE_MEDIA_PATH environment variable; defaults to a "safe_media" folder next to this module.
SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media"))
# Switch for local file access, read from the ALLOW_LOCAL_FILES environment variable.
# NOTE(review): the "1" fallback presumably means "enabled by default" -- confirm against is_env_enabled.
ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1")
def dictify(data: "BaseModel") -> dict[str, Any]: def dictify(data: "BaseModel") -> dict[str, Any]:
try: # pydantic v2 try: # pydantic v2
return data.model_dump(exclude_unset=True) return data.model_dump(exclude_unset=True)
@ -32,3 +47,50 @@ def jsonify(data: "BaseModel") -> str:
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except AttributeError: # pydantic v1 except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False) return data.json(exclude_unset=True, ensure_ascii=False)
def check_lfi_path(path: str) -> None:
    """Reject local file paths that resolve outside the safe media directory.

    The path is canonicalised with ``os.path.realpath`` (so symlinks and ``..``
    segments are resolved) and must be located inside ``SAFE_MEDIA_PATH``.

    Args:
        path: User-supplied filesystem path to validate.

    Raises:
        HTTPException: 403 if local file access is disabled or the path resolves
            outside the safe media directory; 400 if the path cannot be resolved.
    """
    if not ALLOW_LOCAL_FILES:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.")

    # Only the filesystem operations live inside the ``try``: the policy
    # violation below must surface as 403 and not be rewritten into a 400.
    try:
        os.makedirs(SAFE_MEDIA_PATH, exist_ok=True)
        real_path = os.path.realpath(path)
        safe_path = os.path.realpath(SAFE_MEDIA_PATH)
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path."
        ) from err

    # Compare whole path components. A raw string-prefix test would accept a
    # sibling such as "<safe>_evil/file" because it shares the textual prefix.
    try:
        common = os.path.commonpath([real_path, safe_path])
        is_contained = os.path.normcase(common) == os.path.normcase(safe_path)
    except ValueError:
        # e.g. paths on different drives: there is no common ancestor at all.
        is_contained = False

    if not is_contained:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory."
        )
def check_ssrf_url(url: str) -> None:
    """Reject URLs that could be used for server-side request forgery.

    Only ``http``/``https`` URLs are accepted, and every address the hostname
    resolves to must be globally routable.

    NOTE: this is a pre-flight check only. The caller resolves the hostname
    again when it actually fetches the URL, so a DNS answer that changes
    between the check and the fetch is not covered here.

    Args:
        url: User-supplied URL to validate.

    Raises:
        HTTPException: 400 if the URL is malformed, uses another scheme, has no
            hostname, or cannot be resolved; 403 if any resolved address is
            private, loopback, link-local or otherwise non-global.
    """
    try:
        parsed_url = urlparse(url)
    except Exception as err:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {err}") from err

    if parsed_url.scheme not in ("http", "https"):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.")

    hostname = parsed_url.hostname
    if not hostname:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.")

    # Keep the policy decisions outside this ``try`` so the 403 raised below is
    # not swallowed and re-labelled as a generic 400.
    try:
        # ``parsed_url.port`` raises ValueError for a malformed or out-of-range port.
        addr_infos = socket.getaddrinfo(hostname, parsed_url.port)
        # Drop any IPv6 zone id ("fe80::1%eth0") before parsing the address.
        addresses = {ipaddress.ip_address(info[4][0].split("%", 1)[0]) for info in addr_infos}
    except socket.gaierror as err:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {hostname}"
        ) from err
    except Exception as err:  # validation boundary for untrusted input: fail closed
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {err}") from err

    if not addresses:
        # Fail closed: with nothing to inspect the check below would pass vacuously.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {hostname}"
        )

    # Inspect every record, not just the first: a host may publish a public
    # address alongside a private one.
    if not all(ip.is_global for ip in addresses):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access to private or reserved IP addresses is not allowed.",
        )