161 lines
4.8 KiB
Python
161 lines
4.8 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import AsyncIterator
|
|
|
|
import aiohttp
|
|
from fastapi import Request
|
|
from fastapi.responses import Response, StreamingResponse
|
|
from starlette.datastructures import Headers
|
|
|
|
from .registry import BackendState
|
|
from .scheduler import Scheduler
|
|
from .slot_tracker import SlotTracker
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Size in bytes of each chunk requested from ``iter_chunked`` when a streamed
# backend body is relayed to the client.
_STREAM_CHUNK = 8 * 1024

# Lower-case names of headers that are not copied between the client side and
# the backend side of the proxy (checked against ``name.lower()``).
# NOTE(review): "content-encoding" and "content-length" are end-to-end rather
# than hop-by-hop headers; presumably they are dropped because the relayed
# body may be re-encoded or re-framed by the HTTP client -- confirm.
_HOP_BY_HOP = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-encoding",
        "content-length",
    }
)
|
|
|
|
|
|
def _forward_headers(request: Request, backend: BackendState) -> dict[str, str]:
    """Build the header dict to send to *backend* for a proxied *request*.

    Every incoming header is copied except those listed in ``_HOP_BY_HOP``,
    ``Host``, and ``Authorization``. The client's ``Authorization`` value is
    never forwarded: it is replaced by ``Bearer <backend.config.api_key>``
    when the backend has an API key configured, and omitted otherwise.

    Args:
        request: Incoming request whose headers are copied.
        backend: Target backend; only ``backend.config.api_key`` is read.

    Returns:
        A new ``dict`` mapping header names to values.
    """
    # Names are compared lower-cased so the filter is independent of how the
    # incoming header names happen to be cased. Dropping "authorization" here
    # (rather than popping two fixed spellings afterwards) guarantees the
    # client credential can never sit next to the backend key under a
    # differently-cased dict key.
    skipped = _HOP_BY_HOP | {"host", "authorization"}
    headers: dict[str, str] = {
        name: value
        for name, value in request.headers.items()
        if name.lower() not in skipped
    }

    api_key = backend.config.api_key
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    return headers
|
|
|
|
|
|
async def forward_request(
    request: Request,
    backend: BackendState,
    session: aiohttp.ClientSession,
    slot_tracker: SlotTracker,
    scheduler: Scheduler,
    model_id: str = "",
    path_override: str | None = None,
) -> Response:
    """Proxy *request* to *backend*, releasing the backend slot exactly once.

    The upstream URL is ``backend.url`` + ``path_override`` (or the incoming
    path when the override is empty/``None``) + the original query string.
    A backend response whose ``Content-Type`` contains ``text/event-stream``
    is relayed chunk by chunk through a ``StreamingResponse``; any other
    response is read fully and returned as a plain ``Response``.

    Slot accounting: ``slot_tracker.release(backend.url, model_id)`` followed
    by ``scheduler.notify_slot_released()`` is performed once per call --
    when this function returns or raises (buffered responses and all
    failures, including cancellation), or when the stream generator finishes
    or is closed (event streams).
    NOTE(review): presumably the caller acquired the slot before calling --
    confirm at the call site.

    Args:
        request: Incoming client request (method, URL, headers, body).
        backend: Backend to forward to.
        session: aiohttp session used for the upstream call.
        slot_tracker: Tracker whose slot for ``(backend.url, model_id)`` is
            released when the exchange ends.
        scheduler: Notified after the slot has been released.
        model_id: Model identifier passed to ``slot_tracker.release``.
        path_override: Upstream path to use instead of the request path.

    Returns:
        The relayed backend response.

    Raises:
        Exception: Any error raised while reading the request body or
            talking to the backend is logged and re-raised.
    """
    path = path_override or request.url.path
    query = request.url.query
    target = backend.url + path + (f"?{query}" if query else "")
    headers = _forward_headers(request, backend)

    slot_released = False

    async def release_slot() -> None:
        # Idempotent, so overlapping cleanup paths can never release twice.
        nonlocal slot_released
        if slot_released:
            return
        slot_released = True
        await slot_tracker.release(backend.url, model_id)
        scheduler.notify_slot_released()

    resp_ctx = None
    resp = None
    # True once the stream generator owns the upstream response and the slot.
    handed_off = False

    try:
        # Read inside the try so a failure here also releases the slot.
        body = await request.body()

        # The context manager is entered by hand because, for event streams,
        # the upstream response must outlive this function call.
        resp_ctx = session.request(
            method=request.method,
            url=target,
            headers=headers,
            data=body if body else None,
            allow_redirects=False,
        )
        resp = await resp_ctx.__aenter__()

        content_type = resp.headers.get("Content-Type", "")
        is_streaming = "text/event-stream" in content_type

        response_headers = {
            k: v
            for k, v in resp.headers.items()
            if k.lower() not in _HOP_BY_HOP
        }

        if is_streaming:

            async def generate() -> AsyncIterator[bytes]:
                try:
                    async for chunk in resp.content.iter_chunked(_STREAM_CHUNK):
                        yield chunk
                finally:
                    # Release the slot even if closing the upstream fails.
                    try:
                        await resp_ctx.__aexit__(None, None, None)
                    finally:
                        await release_slot()

            streaming_response = StreamingResponse(
                generate(),
                status_code=resp.status,
                headers=response_headers,
                media_type=content_type,
            )
            # NOTE(review): if the returned response is never iterated, the
            # generator body never starts and its cleanup never runs.
            handed_off = True
            return streaming_response

        data = await resp.read()
        return Response(
            content=data,
            status_code=resp.status,
            headers=response_headers,
            media_type=content_type,
        )

    except Exception as exc:
        log.error("Forward error to %s: %s", backend.url, exc)
        raise

    finally:
        # Runs on success, on errors and on cancellation (which is not an
        # ``Exception`` and therefore bypasses the handler above).
        if not handed_off:
            try:
                # Only close a response that was actually entered.
                if resp is not None:
                    await resp_ctx.__aexit__(None, None, None)
            finally:
                await release_slot()
|
|
|
|
|
|
async def forward_best_effort(
    request: Request,
    registry_backends: list[BackendState],
    session: aiohttp.ClientSession,
) -> Response:
    """Relay a catch-all request to the first listed backend, with no slot gating.

    Args:
        request: Incoming client request to relay.
        registry_backends: Candidate backends; only the first entry is used.
        session: aiohttp session used for the upstream call.

    Returns:
        The fully buffered backend response; a 503 response when
        ``registry_backends`` is empty; a 502 response when the upstream
        call raises.
    """
    if not registry_backends:
        return Response(content="No live backends", status_code=503)

    upstream = registry_backends[0]

    # Upstream URL: backend base + incoming path (+ "?query" when present).
    target = upstream.url + request.url.path
    query_string = request.url.query
    if query_string:
        target += f"?{query_string}"

    outgoing_headers = _forward_headers(request, upstream)
    payload = await request.body()

    try:
        async with session.request(
            method=request.method,
            url=target,
            headers=outgoing_headers,
            data=payload or None,  # an empty body is sent as "no body"
            allow_redirects=False,
        ) as upstream_resp:
            content_type = upstream_resp.headers.get("Content-Type", "")

            # Copy backend headers, leaving out the ones in _HOP_BY_HOP.
            relayed_headers = {}
            for key, value in upstream_resp.headers.items():
                if key.lower() in _HOP_BY_HOP:
                    continue
                relayed_headers[key] = value

            content = await upstream_resp.read()
            return Response(
                content=content,
                status_code=upstream_resp.status,
                headers=relayed_headers,
                media_type=content_type,
            )
    except Exception as exc:
        log.error("Best-effort forward error: %s", exc)
        return Response(content="Backend error", status_code=502)
|