mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 09:52:14 +08:00 
			
		
		
		
	Merge commit from fork
* fix lfi and ssrf * move utils to common --------- Co-authored-by: d3do <chamlinx@outlook.com> Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
		
							parent
							
								
									d5bb4e6394
								
							
						
					
					
						commit
						95b7188090
					
				@ -26,7 +26,7 @@ from ..extras import logging
 | 
			
		||||
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 | 
			
		||||
from ..extras.misc import is_env_enabled
 | 
			
		||||
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 | 
			
		||||
from .common import dictify, jsonify
 | 
			
		||||
from .common import check_lfi_path, check_ssrf_url, dictify, jsonify
 | 
			
		||||
from .protocol import (
 | 
			
		||||
    ChatCompletionMessage,
 | 
			
		||||
    ChatCompletionResponse,
 | 
			
		||||
@ -121,8 +121,10 @@ def _process_request(
 | 
			
		||||
                    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
 | 
			
		||||
                        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
 | 
			
		||||
                    elif os.path.isfile(image_url):  # local file
 | 
			
		||||
                        check_lfi_path(image_url)
 | 
			
		||||
                        image_stream = open(image_url, "rb")
 | 
			
		||||
                    else:  # web uri
 | 
			
		||||
                        check_ssrf_url(image_url)
 | 
			
		||||
                        image_stream = requests.get(image_url, stream=True).raw
 | 
			
		||||
 | 
			
		||||
                    images.append(Image.open(image_stream).convert("RGB"))
 | 
			
		||||
@ -132,8 +134,10 @@ def _process_request(
 | 
			
		||||
                    if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url):  # base64 video
 | 
			
		||||
                        video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
 | 
			
		||||
                    elif os.path.isfile(video_url):  # local file
 | 
			
		||||
                        check_lfi_path(video_url)
 | 
			
		||||
                        video_stream = video_url
 | 
			
		||||
                    else:  # web uri
 | 
			
		||||
                        check_ssrf_url(video_url)
 | 
			
		||||
                        video_stream = requests.get(video_url, stream=True).raw
 | 
			
		||||
 | 
			
		||||
                    videos.append(video_stream)
 | 
			
		||||
@ -143,8 +147,10 @@ def _process_request(
 | 
			
		||||
                    if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url):  # base64 audio
 | 
			
		||||
                        audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
 | 
			
		||||
                    elif os.path.isfile(audio_url):  # local file
 | 
			
		||||
                        check_lfi_path(audio_url)
 | 
			
		||||
                        audio_stream = audio_url
 | 
			
		||||
                    else:  # web uri
 | 
			
		||||
                        check_ssrf_url(audio_url)
 | 
			
		||||
                        audio_stream = requests.get(audio_url, stream=True).raw
 | 
			
		||||
 | 
			
		||||
                    audios.append(audio_stream)
 | 
			
		||||
 | 
			
		||||
@ -12,14 +12,29 @@
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import ipaddress
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
import socket
 | 
			
		||||
from typing import TYPE_CHECKING, Any
 | 
			
		||||
from urllib.parse import urlparse
 | 
			
		||||
 | 
			
		||||
from ..extras.misc import is_env_enabled
 | 
			
		||||
from ..extras.packages import is_fastapi_available
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if is_fastapi_available():
 | 
			
		||||
    from fastapi import HTTPException, status
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from pydantic import BaseModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media"))
 | 
			
		||||
ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def dictify(data: "BaseModel") -> dict[str, Any]:
 | 
			
		||||
    try:  # pydantic v2
 | 
			
		||||
        return data.model_dump(exclude_unset=True)
 | 
			
		||||
@ -32,3 +47,50 @@ def jsonify(data: "BaseModel") -> str:
 | 
			
		||||
        return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
 | 
			
		||||
    except AttributeError:  # pydantic v1
 | 
			
		||||
        return data.json(exclude_unset=True, ensure_ascii=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_lfi_path(path: str) -> None:
 | 
			
		||||
    """Checks if a given path is vulnerable to LFI. Raises HTTPException if unsafe."""
 | 
			
		||||
    if not ALLOW_LOCAL_FILES:
 | 
			
		||||
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.")
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        os.makedirs(SAFE_MEDIA_PATH, exist_ok=True)
 | 
			
		||||
        real_path = os.path.realpath(path)
 | 
			
		||||
        safe_path = os.path.realpath(SAFE_MEDIA_PATH)
 | 
			
		||||
 | 
			
		||||
        if not real_path.startswith(safe_path):
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
                status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory."
 | 
			
		||||
            )
 | 
			
		||||
    except Exception:
 | 
			
		||||
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_ssrf_url(url: str) -> None:
 | 
			
		||||
    """Checks if a given URL is vulnerable to SSRF. Raises HTTPException if unsafe."""
 | 
			
		||||
    try:
 | 
			
		||||
        parsed_url = urlparse(url)
 | 
			
		||||
        if parsed_url.scheme not in ["http", "https"]:
 | 
			
		||||
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.")
 | 
			
		||||
 | 
			
		||||
        hostname = parsed_url.hostname
 | 
			
		||||
        if not hostname:
 | 
			
		||||
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.")
 | 
			
		||||
 | 
			
		||||
        ip_info = socket.getaddrinfo(hostname, parsed_url.port)
 | 
			
		||||
        ip_address_str = ip_info[0][4][0]
 | 
			
		||||
        ip = ipaddress.ip_address(ip_address_str)
 | 
			
		||||
 | 
			
		||||
        if not ip.is_global:
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
                status_code=status.HTTP_403_FORBIDDEN,
 | 
			
		||||
                detail="Access to private or reserved IP addresses is not allowed.",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    except socket.gaierror:
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {parsed_url.hostname}"
 | 
			
		||||
        )
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {e}")
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user