mirror of
https://github.com/facebookresearch/sam2.git
synced 2025-09-18 12:42:48 +08:00
Summary: The Strawberry GraphQL library recently disabled multipart requests by default. This resulted in a video upload request returning "Unsupported content type" instead of uploading the video, processing it, and returning the video path. This change enables multipart request support on the endpoint view. Test Plan: Tested locally and upload succeeds
141 lines
4.0 KiB
Python
141 lines
4.0 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import logging
|
|
from typing import Any, Generator
|
|
|
|
from app_conf import (
|
|
GALLERY_PATH,
|
|
GALLERY_PREFIX,
|
|
POSTERS_PATH,
|
|
POSTERS_PREFIX,
|
|
UPLOADS_PATH,
|
|
UPLOADS_PREFIX,
|
|
)
|
|
from data.loader import preload_data
|
|
from data.schema import schema
|
|
from data.store import set_videos
|
|
from flask import Flask, make_response, Request, request, Response, send_from_directory
|
|
from flask_cors import CORS
|
|
from inference.data_types import PropagateDataResponse, PropagateInVideoRequest
|
|
from inference.multipart import MultipartResponseBuilder
|
|
from inference.predictor import InferenceAPI
|
|
from strawberry.flask.views import GraphQLView
|
|
|
|
# Logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Flask application object; CORS is enabled on it with credential support.
app: Flask = Flask(__name__)
cors: CORS = CORS(app, supports_credentials=True)

# Load the video data at import time and register it with the data store.
videos = preload_data()
set_videos(videos)

# One InferenceAPI instance, created at import time and used by the
# /propagate_in_video route and the GraphQL context further down.
inference_api: InferenceAPI = InferenceAPI()
|
|
|
|
|
|
@app.route("/healthy")
def healthy() -> Response:
    """Health-check endpoint: respond with the body ``OK`` and status 200."""
    ok_response = make_response("OK", 200)
    return ok_response
|
|
|
|
|
|
@app.route(f"/{GALLERY_PREFIX}/<path:path>", methods=["GET"])
def send_gallery_video(path: str) -> Response:
    """Serve the file at ``path`` from under ``GALLERY_PATH``.

    Args:
        path: The ``<path:path>`` portion of the request URL.

    Returns:
        The response built by ``send_from_directory``.

    Raises:
        ValueError: If ``send_from_directory`` raises any ``Exception``; the
            original error is attached as ``__cause__``.
    """
    try:
        return send_from_directory(GALLERY_PATH, path)
    except Exception as err:
        # ``Exception`` instead of a bare ``except`` so KeyboardInterrupt and
        # SystemExit are no longer rewritten into a ValueError.
        # NOTE(review): a missing file presumably ought to surface as HTTP 404
        # rather than a ValueError; the exception type is kept unchanged here
        # for backward compatibility -- confirm before changing.
        raise ValueError("resource not found") from err
|
|
|
|
|
|
@app.route(f"/{POSTERS_PREFIX}/<path:path>", methods=["GET"])
def send_poster_image(path: str) -> Response:
    """Serve the file at ``path`` from under ``POSTERS_PATH``.

    Args:
        path: The ``<path:path>`` portion of the request URL.

    Returns:
        The response built by ``send_from_directory``.

    Raises:
        ValueError: If ``send_from_directory`` raises any ``Exception``; the
            original error is attached as ``__cause__``.
    """
    try:
        return send_from_directory(POSTERS_PATH, path)
    except Exception as err:
        # ``Exception`` instead of a bare ``except`` so KeyboardInterrupt and
        # SystemExit are no longer rewritten into a ValueError.
        # NOTE(review): a missing file presumably ought to surface as HTTP 404
        # rather than a ValueError; the exception type is kept unchanged here
        # for backward compatibility -- confirm before changing.
        raise ValueError("resource not found") from err
|
|
|
|
|
|
@app.route(f"/{UPLOADS_PREFIX}/<path:path>", methods=["GET"])
def send_uploaded_video(path: str) -> Response:
    """Serve the file at ``path`` from under ``UPLOADS_PATH``.

    Args:
        path: The ``<path:path>`` portion of the request URL.

    Returns:
        The response built by ``send_from_directory``.

    Raises:
        ValueError: If ``send_from_directory`` raises any ``Exception``; the
            original error is attached as ``__cause__``.
    """
    try:
        return send_from_directory(UPLOADS_PATH, path)
    except Exception as err:
        # ``Exception`` instead of a bare ``except`` so KeyboardInterrupt and
        # SystemExit are no longer rewritten into a ValueError.
        # NOTE(review): a missing file presumably ought to surface as HTTP 404
        # rather than a ValueError; the exception type is kept unchanged here
        # for backward compatibility -- confirm before changing.
        raise ValueError("resource not found") from err
|
|
|
|
|
|
# TODO: Protect route with ToS permission check
|
|
@app.route("/propagate_in_video", methods=["POST"])
def propagate_in_video() -> Response:
    """Stream propagation results for a session as a multipart response.

    Reads the JSON request body, which must contain ``session_id`` and may
    contain ``start_frame_index`` (defaults to 0), and returns a streaming
    ``Response`` fed by ``gen_track_with_mask_stream``.
    """
    payload = request.json
    session_id = payload["session_id"]
    start_frame_index = payload.get("start_frame_index", 0)

    # The same boundary token is handed to the generator and advertised in
    # the response mimetype.
    boundary = "frame"
    stream = gen_track_with_mask_stream(
        boundary,
        session_id=session_id,
        start_frame_index=start_frame_index,
    )
    return Response(stream, mimetype=f"multipart/x-savi-stream; boundary={boundary}")
|
|
|
|
|
|
def gen_track_with_mask_stream(
    boundary: str,
    session_id: str,
    start_frame_index: int,
) -> Generator[bytes, None, None]:
    """Yield one multipart message per chunk of a video propagation.

    Args:
        boundary: Boundary token passed to ``MultipartResponseBuilder.build``.
        session_id: Session identifier placed in the propagation request.
        start_frame_index: Frame index placed in the propagation request.

    Yields:
        The result of ``get_message()`` for each part; every part wraps the
        UTF-8 encoded JSON of one chunk from
        ``inference_api.propagate_in_video``.
    """
    # The autocast context stays open for the whole lifetime of the generator.
    with inference_api.autocast_context():
        propagate_request = PropagateInVideoRequest(
            type="propagate_in_video",
            session_id=session_id,
            start_frame_index=start_frame_index,
        )

        chunks = inference_api.propagate_in_video(request=propagate_request)
        for chunk in chunks:
            # A fresh headers dict is built for every part; both frame
            # counters are sent as the fixed string "-1".
            part_headers = {
                "Content-Type": "application/json; charset=utf-8",
                "Frame-Current": "-1",
                # Total frames minus the reference frame
                "Frame-Total": "-1",
                "Mask-Type": "RLE[]",
            }
            part = MultipartResponseBuilder.build(
                boundary=boundary,
                headers=part_headers,
                body=chunk.to_json().encode("UTF-8"),
            )
            yield part.get_message()
|
|
|
|
|
|
class MyGraphQLView(GraphQLView):
    """GraphQL view whose context exposes the module-level inference API."""

    def get_context(self, request: Request, response: Response) -> Any:
        """Return the context dict: ``inference_api`` under the same-named key.

        The ``request`` and ``response`` arguments are not used.
        """
        context = {"inference_api": inference_api}
        return context
|
|
|
|
|
|
# Register the GraphQL endpoint on the Flask app.
app.add_url_rule(
    rule="/graphql",
    view_func=MyGraphQLView.as_view(
        "graphql_view",
        schema=schema,
        # GET queries are turned off. See:
        # https://strawberry.rocks/docs/operations/deployment
        # https://strawberry.rocks/docs/integrations/flask
        allow_queries_via_get=False,
        # Multipart requests must be enabled explicitly on the view since
        # Strawberry changed its multipart request handling. See:
        # https://github.com/strawberry-graphql/strawberry/issues/3655
        multipart_uploads_enabled=True,
    ),
)


# When executed as a script, serve on all interfaces at port 5000.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)
|