mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-28 09:40:34 +08:00
* fix lfi and ssrf * move utils to common --------- Co-authored-by: d3do <chamlinx@outlook.com> Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
97 lines
3.5 KiB
Python
97 lines
3.5 KiB
Python
# Copyright 2025 the LlamaFactory team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import ipaddress
|
|
import json
|
|
import os
|
|
import socket
|
|
from typing import TYPE_CHECKING, Any
|
|
from urllib.parse import urlparse
|
|
|
|
from ..extras.misc import is_env_enabled
|
|
from ..extras.packages import is_fastapi_available
|
|
|
|
|
|
if is_fastapi_available():
|
|
from fastapi import HTTPException, status
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from pydantic import BaseModel
|
|
|
|
|
|
# Directory that local media files must live in to pass `check_lfi_path`.
# Taken from the SAFE_MEDIA_PATH environment variable; when that is unset it
# falls back to a "safe_media" folder located next to this module.
SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media"))
|
|
# Switch read from the ALLOW_LOCAL_FILES environment variable; when it is falsy,
# `check_lfi_path` rejects every local path with HTTP 403.
# NOTE(review): "1" is passed as the default, so this is presumably enabled when
# the variable is unset -- confirm against `is_env_enabled`.
ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1")
|
|
|
|
|
|
def dictify(data: "BaseModel") -> dict[str, Any]:
    """Convert a pydantic model into a plain dict, leaving out unset fields.

    The pydantic v2 method ``model_dump`` is tried first; if that raises
    ``AttributeError`` the pydantic v1 method ``dict`` is used instead.

    Args:
        data: The model instance to convert.

    Returns:
        The result of the model's dump method called with ``exclude_unset=True``.
    """
    try:
        dumped = data.model_dump(exclude_unset=True)  # pydantic v2
    except AttributeError:
        dumped = data.dict(exclude_unset=True)  # pydantic v1

    return dumped
|
|
|
|
|
|
def jsonify(data: "BaseModel") -> str:
    """Serialize a pydantic model to a JSON string, leaving out unset fields.

    With pydantic v2 the model is dumped to a dict and encoded by ``json.dumps``;
    if that path raises ``AttributeError`` the pydantic v1 ``json`` method is
    used. Both paths pass ``ensure_ascii=False`` so non-ASCII characters are
    emitted as-is rather than escaped.

    Args:
        data: The model instance to serialize.

    Returns:
        The JSON text for the model.
    """
    try:  # pydantic v2
        payload = data.model_dump(exclude_unset=True)
        encoded = json.dumps(payload, ensure_ascii=False)
    except AttributeError:  # pydantic v1
        encoded = data.json(exclude_unset=True, ensure_ascii=False)

    return encoded
|
|
|
|
|
|
def check_lfi_path(path: str) -> None:
    """Reject local file paths that resolve outside the safe media directory.

    The path is resolved with ``os.path.realpath`` (which follows symlinks and
    collapses ``..``) and must be ``SAFE_MEDIA_PATH`` itself or lie beneath it.
    The safe directory is created if it does not exist yet.

    Args:
        path: Filesystem path supplied by the client.

    Raises:
        HTTPException: 403 if local file access is disabled or the path resolves
            outside ``SAFE_MEDIA_PATH``; 400 if the path cannot be resolved or
            compared.
    """
    if not ALLOW_LOCAL_FILES:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.")

    try:
        os.makedirs(SAFE_MEDIA_PATH, exist_ok=True)
        real_path = os.path.realpath(path)
        safe_path = os.path.realpath(SAFE_MEDIA_PATH)
        # Compare whole path components. A plain string-prefix test would also
        # accept sibling directories such as "<safe_path>_evil/secret.txt".
        common_path = os.path.commonpath([safe_path, real_path])
        is_inside = os.path.normcase(common_path) == os.path.normcase(safe_path)
    except Exception:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path.")

    # Raised outside the try block so the catch-all handler above cannot turn
    # this 403 into a generic 400.
    if not is_inside:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory."
        )
|
|
|
|
|
|
def check_ssrf_url(url: str) -> None:
    """Reject URLs that could be used for server-side request forgery.

    Only ``http``/``https`` URLs with a hostname are accepted, and every address
    the hostname resolves to must be globally routable.

    NOTE: the hostname is resolved here and again by whatever later fetches the
    URL, so this check alone does not protect against DNS rebinding.

    Args:
        url: URL supplied by the client.

    Raises:
        HTTPException: 400 if the URL is malformed, uses another scheme, has no
            hostname or cannot be resolved; 403 if any resolved address is not
            global (private, loopback, link-local, reserved, ...).
    """
    try:
        parsed_url = urlparse(url)
        if parsed_url.scheme not in ["http", "https"]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.")

        hostname = parsed_url.hostname
        if not hostname:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.")

        ip_info = socket.getaddrinfo(hostname, parsed_url.port)
        if not ip_info:  # fail closed if the resolver returns nothing
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {hostname}"
            )

        # A hostname can resolve to several addresses and the HTTP client may
        # connect to any of them, so every one must pass, not just the first.
        for addr_info in ip_info:
            ip = ipaddress.ip_address(addr_info[4][0])
            if not ip.is_global:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Access to private or reserved IP addresses is not allowed.",
                )

    except HTTPException:
        # Deliberate rejections keep their own status code and detail instead of
        # being re-wrapped by the catch-all handler below.
        raise
    except socket.gaierror:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {parsed_url.hostname}"
        )
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {e}")
|