diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py index c0d87196..93236c5c 100644 --- a/src/llamafactory/api/chat.py +++ b/src/llamafactory/api/chat.py @@ -26,7 +26,7 @@ from ..extras import logging from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.misc import is_env_enabled from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available -from .common import dictify, jsonify +from .common import check_lfi_path, check_ssrf_url, dictify, jsonify from .protocol import ( ChatCompletionMessage, ChatCompletionResponse, @@ -121,8 +121,10 @@ def _process_request( if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1])) elif os.path.isfile(image_url): # local file + check_lfi_path(image_url) image_stream = open(image_url, "rb") else: # web uri + check_ssrf_url(image_url) image_stream = requests.get(image_url, stream=True).raw images.append(Image.open(image_stream).convert("RGB")) @@ -132,8 +134,10 @@ def _process_request( if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url): # base64 video video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1])) elif os.path.isfile(video_url): # local file + check_lfi_path(video_url) video_stream = video_url else: # web uri + check_ssrf_url(video_url) video_stream = requests.get(video_url, stream=True).raw videos.append(video_stream) @@ -143,8 +147,10 @@ def _process_request( if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url): # base64 audio audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1])) elif os.path.isfile(audio_url): # local file + check_lfi_path(audio_url) audio_stream = audio_url else: # web uri + check_ssrf_url(audio_url) audio_stream = requests.get(audio_url, stream=True).raw audios.append(audio_stream) diff --git a/src/llamafactory/api/common.py b/src/llamafactory/api/common.py index f4d0c2fb..7b4e9602 100644 --- a/src/llamafactory/api/common.py +++ b/src/llamafactory/api/common.py @@ -12,14 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ipaddress import json +import os +import socket from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse + +from ..extras.misc import is_env_enabled +from ..extras.packages import is_fastapi_available + + +if is_fastapi_available(): + from fastapi import HTTPException, status if TYPE_CHECKING: from pydantic import BaseModel +SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media")) +ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1") + + def dictify(data: "BaseModel") -> dict[str, Any]: try: # pydantic v2 return data.model_dump(exclude_unset=True) @@ -32,3 +47,50 @@ def jsonify(data: "BaseModel") -> str: return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) except AttributeError: # pydantic v1 return data.json(exclude_unset=True, ensure_ascii=False) + + +def check_lfi_path(path: str) -> None: + """Checks if a given path is vulnerable to LFI. Raises HTTPException if unsafe.""" + if not ALLOW_LOCAL_FILES: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.") + + try: + os.makedirs(SAFE_MEDIA_PATH, exist_ok=True) + real_path = os.path.realpath(path) + safe_path = os.path.realpath(SAFE_MEDIA_PATH) + + if not real_path.startswith(safe_path): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory." + ) + except Exception: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path.") + + +def check_ssrf_url(url: str) -> None: + """Checks if a given URL is vulnerable to SSRF. Raises HTTPException if unsafe.""" + try: + parsed_url = urlparse(url) + if parsed_url.scheme not in ["http", "https"]: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.") + + hostname = parsed_url.hostname + if not hostname: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.") + + ip_info = socket.getaddrinfo(hostname, parsed_url.port) + ip_address_str = ip_info[0][4][0] + ip = ipaddress.ip_address(ip_address_str) + + if not ip.is_global: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access to private or reserved IP addresses is not allowed.", + ) + + except socket.gaierror: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {parsed_url.hostname}" + ) + except Exception as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {e}")