190 lines
5.7 KiB
Python
190 lines
5.7 KiB
Python
from __future__ import annotations
|
|
|
|
import itertools
|
|
import logging
|
|
from typing import AsyncIterator
|
|
|
|
import aiohttp
|
|
from fastapi import Request
|
|
from fastapi.responses import Response, StreamingResponse
|
|
from starlette.datastructures import Headers
|
|
|
|
from .registry import BackendState
|
|
from .scheduler import Scheduler
|
|
from .slot_tracker import SlotTracker
|
|
|
|
log = logging.getLogger(__name__)

# Read size, in bytes, used when relaying a streamed upstream body.
_STREAM_CHUNK = 8 * 1024

# Substrings of the upstream Content-Type that select the chunked streaming
# relay instead of buffering the whole body in memory.
_STREAMING_CONTENT_TYPES = (
    "text/event-stream",  # server-sent events
    "audio/mpeg",
    "audio/ogg",
    "audio/wav",
    "audio/webm",
    "audio/aac",
)

# Lower-cased header names that are never copied across the proxy, neither
# from the client request nor from the backend response.
_HOP_BY_HOP = frozenset(
    {
        # Hop-by-hop headers: they describe a single connection only.
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailers",
        "transfer-encoding",
        "upgrade",
        # Not hop-by-hop, but stripped too. Presumably this is because aiohttp
        # decompresses bodies itself and the framing is recomputed on the
        # outgoing side. NOTE(review): confirm this intent.
        "content-encoding",
        "content-length",
    }
)
|
|
|
|
|
|
def _forward_headers(request: Request, backend: BackendState) -> dict[str, str]:
    """Build the headers to send upstream for *request*.

    Headers listed in ``_HOP_BY_HOP``, the ``Host`` header and any
    client-supplied ``Authorization`` header are dropped; header names are
    compared case-insensitively. If the backend has an API key configured,
    ``Authorization: Bearer <key>`` is added. The client's own credentials are
    never forwarded.

    Note: repeated request headers collapse to their last value because the
    result is a plain dict.

    Args:
        request: The incoming client request.
        backend: The target backend; ``backend.config.api_key`` is read.

    Returns:
        A new dict of header name -> value for the upstream request.
    """
    headers: dict[str, str] = {}
    for name, value in request.headers.items():
        lowered = name.lower()
        # Drop Authorization here, whatever its case. Starlette reports
        # lower-case names, so popping only "Authorization" afterwards would
        # leave the client's header next to the injected backend one.
        if lowered in _HOP_BY_HOP or lowered in ("host", "authorization"):
            continue
        headers[name] = value

    api_key = backend.config.api_key
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    return headers
|
|
|
|
|
|
async def forward_request(
    request: Request,
    backend: BackendState,
    session: aiohttp.ClientSession,
    slot_tracker: SlotTracker,
    scheduler: Scheduler,
    model_id: str = "",
    path_override: str | None = None,
) -> Response:
    """Proxy *request* to *backend* and release its scheduler slot exactly once.

    The function releases the slot ``(backend.url, model_id)`` on
    *slot_tracker* and calls ``scheduler.notify_slot_released()`` exactly once
    on every exit path: success, error or cancellation. Presumably the caller
    acquired that slot before calling; confirm against the call sites.

    How the response is returned depends on its Content-Type:

    - If it matches ``_STREAMING_CONTENT_TYPES``, the body is relayed chunk by
      chunk through a ``StreamingResponse``. The upstream connection and the
      slot are released when that stream finishes or is aborted.
    - Otherwise the body is read fully, the connection and slot are released,
      and a buffered ``Response`` is returned.

    Args:
        request: Incoming client request (method, path, query, headers, body).
        backend: Target backend; ``backend.url`` is the URL prefix.
        session: Shared aiohttp session used for the upstream call.
        slot_tracker: Tracker whose slot for this request is released.
        scheduler: Notified after the slot is released.
        model_id: Slot key passed to ``slot_tracker.release``.
        path_override: Upstream path to use instead of ``request.url.path``.

    Returns:
        A ``StreamingResponse`` or ``Response`` mirroring the upstream status
        and headers (minus ``_HOP_BY_HOP`` headers).

    Raises:
        Any exception from reading the client body or performing the upstream
        request. It is logged and re-raised after the slot is released.
    """
    released = False

    async def release_slot() -> None:
        # Idempotent: several cleanup paths may call this, but the slot must
        # be returned only once.
        nonlocal released
        if released:
            return
        released = True
        await slot_tracker.release(backend.url, model_id)
        scheduler.notify_slot_released()

    # Becomes True once a StreamingResponse owns the upstream connection and
    # the slot. From then on, cleanup happens inside its generator instead.
    handed_off = False
    try:
        path = path_override or request.url.path
        query = request.url.query
        target = backend.url + path + (f"?{query}" if query else "")
        headers = _forward_headers(request, backend)
        # Read the body inside the try, so a client disconnect during the read
        # still releases the slot.
        body = await request.body()

        resp_ctx = session.request(
            method=request.method,
            url=target,
            headers=headers,
            data=body if body else None,
            allow_redirects=False,
        )
        resp = await resp_ctx.__aenter__()
        try:
            content_type = resp.headers.get("Content-Type", "")
            lowered_type = content_type.lower()
            is_streaming = any(ct in lowered_type for ct in _STREAMING_CONTENT_TYPES)
            response_headers = {
                k: v for k, v in resp.headers.items() if k.lower() not in _HOP_BY_HOP
            }

            if is_streaming:

                async def generate() -> AsyncIterator[bytes]:
                    try:
                        async for chunk in resp.content.iter_chunked(_STREAM_CHUNK):
                            yield chunk
                    finally:
                        # Release the slot even if closing the upstream fails.
                        try:
                            await resp_ctx.__aexit__(None, None, None)
                        finally:
                            await release_slot()

                # NOTE(review): if the server never starts iterating this
                # generator, its finally block never runs and the slot stays
                # held. Consider a disconnect watchdog if that case matters.
                streaming = StreamingResponse(
                    generate(),
                    status_code=resp.status,
                    headers=response_headers,
                    # None (not "") so no empty content-type header is emitted.
                    media_type=content_type or None,
                )
                handed_off = True
                return streaming

            data = await resp.read()
        finally:
            if not handed_off:
                await resp_ctx.__aexit__(None, None, None)
    except Exception as exc:
        log.error("Forward error to %s: %s", backend.url, exc)
        raise
    finally:
        # Also runs on asyncio.CancelledError, which `except Exception` misses.
        if not handed_off:
            await release_slot()

    return Response(
        content=data,
        status_code=resp.status,
        headers=response_headers,
        media_type=content_type or None,
    )
|
|
|
|
|
|
# Module-wide, ever-increasing counter. forward_best_effort takes it modulo the
# number of live backends to rotate which backend it tries first.
_best_effort_counter = itertools.count(0)
|
|
|
|
|
|
async def forward_best_effort(
    request: Request,
    registry_backends: list[BackendState],
    session: aiohttp.ClientSession,
    preferred_url: str | None = None,
) -> Response:
    """Forward without slot gating to any live backend (catch-all paths).

    Backends are tried in round-robin order. If sending the request or reading
    its response raises, the next backend is attempted until all live
    backends are exhausted. When preferred_url is given (session affinity
    hint), that backend is tried first.

    The upstream body is always buffered in full. Any HTTP status a backend
    returns, including 5xx, is relayed as-is and does not trigger a retry.

    Args:
        request: Incoming client request (method, path, query, headers, body).
        registry_backends: Candidate backends, in registry order.
        session: Shared aiohttp session used for the upstream calls.
        preferred_url: Optional backend URL to try before the others.

    Returns:
        A 503 response if *registry_backends* is empty, a 502 response if
        every backend failed, otherwise the first upstream response obtained.
    """
    if not registry_backends:
        return Response(content="No live backends", status_code=503)

    # Rotate the starting backend on each call to spread load.
    start = next(_best_effort_counter) % len(registry_backends)
    ordered = registry_backends[start:] + registry_backends[:start]

    if preferred_url:
        # Stable partition: the preferred backend goes first and the rotated
        # order is kept for the rest.
        preferred = [b for b in ordered if b.url == preferred_url]
        rest = [b for b in ordered if b.url != preferred_url]
        ordered = preferred + rest

    # Same for every attempt, so compute these once.
    query = request.url.query
    path_and_query = request.url.path + (f"?{query}" if query else "")
    body = await request.body()
    payload = body if body else None

    for backend in ordered:
        target = backend.url + path_and_query
        # Rebuilt per backend, because the injected API key is per backend.
        headers = _forward_headers(request, backend)
        try:
            async with session.request(
                method=request.method,
                url=target,
                headers=headers,
                data=payload,
                allow_redirects=False,
            ) as resp:
                content_type = resp.headers.get("Content-Type", "")
                response_headers = {
                    k: v
                    for k, v in resp.headers.items()
                    if k.lower() not in _HOP_BY_HOP
                }
                data = await resp.read()
                return Response(
                    content=data,
                    status_code=resp.status,
                    headers=response_headers,
                    # None (not "") so no empty content-type header is emitted.
                    media_type=content_type or None,
                )
        except Exception as exc:
            # NOTE(review): this also catches failures that happen after the
            # request was sent (e.g. while reading the body). A non-idempotent
            # request can therefore reach more than one backend.
            log.warning("Best-effort forward to %s failed: %s", backend.url, exc)

    return Response(content="All backends failed", status_code=502)
|