Files
llamacpp-ha/src/llamacpp_ha/forwarder.py
2026-05-18 00:12:57 +02:00

190 lines
5.7 KiB
Python

from __future__ import annotations
import itertools
import logging
from typing import AsyncIterator
import aiohttp
from fastapi import Request
from fastapi.responses import Response, StreamingResponse
from starlette.datastructures import Headers
from .registry import BackendState
from .scheduler import Scheduler
from .slot_tracker import SlotTracker
log = logging.getLogger(__name__)
# Number of bytes requested per read when relaying a streamed backend body.
_STREAM_CHUNK = 8 * 1024

# A backend response is relayed chunk-by-chunk (instead of being buffered in
# full) when its Content-Type header contains any of these fragments.
_STREAMING_CONTENT_TYPES = (
    "text/event-stream",
    "audio/mpeg",
    "audio/ogg",
    "audio/wav",
    "audio/webm",
    "audio/aac",
)
# Lower-cased header names that are never copied between client and backend.
# Besides the classic hop-by-hop headers this also lists "content-encoding"
# and "content-length".
# NOTE(review): presumably those two are dropped because the proxied body may
# be decoded/re-framed by the HTTP client and server layers -- confirm.
_HOP_BY_HOP = frozenset(
    [
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-encoding",
        "content-length",
    ]
)


def _forward_headers(request: Request, backend: BackendState) -> dict[str, str]:
    """Build the header dict to send to *backend* for an incoming *request*.

    Every header listed in ``_HOP_BY_HOP`` is dropped, as are ``Host`` and the
    client's own ``Authorization`` header (matched case-insensitively). When
    the backend has an API key configured, a ``Bearer`` Authorization header
    carrying that key is added instead.

    Args:
        request: Incoming request; only ``request.headers.items()`` is read.
        backend: Target backend; only ``backend.config.api_key`` is read.

    Returns:
        A fresh ``dict`` mapping header name to value for the backend request.
    """
    headers: dict[str, str] = {}
    for name, value in request.headers.items():
        lowered = name.lower()
        # The client's credentials are never forwarded: filtering them here
        # (rather than popping fixed spellings afterwards) guarantees that the
        # backend receives at most one Authorization header, whatever casing
        # the incoming header used.
        if lowered in _HOP_BY_HOP or lowered in ("host", "authorization"):
            continue
        headers[name] = value

    if backend.config.api_key:
        headers["Authorization"] = f"Bearer {backend.config.api_key}"
    return headers
async def forward_request(
    request: Request,
    backend: BackendState,
    session: aiohttp.ClientSession,
    slot_tracker: SlotTracker,
    scheduler: Scheduler,
    model_id: str = "",
    path_override: str | None = None,
) -> Response:
    """Proxy *request* to *backend* and release its slot when finished.

    The request is re-sent to ``backend.url`` plus the request path (or
    ``path_override`` when it is non-empty) and the original query string.
    If the backend's Content-Type contains one of
    ``_STREAMING_CONTENT_TYPES`` the body is relayed chunk by chunk;
    otherwise it is read in full before the response is returned.

    On every exit path -- success, error, or task cancellation -- the backend
    response is closed (if it was opened) and
    ``slot_tracker.release(backend.url, model_id)`` followed by
    ``scheduler.notify_slot_released()`` runs exactly once. For streamed
    responses this happens when the body generator finishes or is closed.
    NOTE(review): this assumes the caller acquired the slot beforehand; the
    acquisition is not visible in this function.

    Args:
        request: Incoming request (method, URL, headers and body are used).
        backend: Backend to forward to.
        session: aiohttp client session used for the outgoing request.
        slot_tracker: Tracker whose slot is released once forwarding ends.
        scheduler: Scheduler notified after the slot has been released.
        model_id: Model identifier passed through to ``slot_tracker.release``.
        path_override: Path to use instead of ``request.url.path``.

    Returns:
        A ``StreamingResponse`` for streaming content types, otherwise a
        buffered ``Response``; both carry the backend's status code and its
        headers minus those in ``_HOP_BY_HOP``.

    Raises:
        Exception: Whatever reading the request body or talking to the
            backend raises is logged and re-raised after cleanup.
    """
    path = path_override or request.url.path
    query = request.url.query
    target = backend.url + path + (f"?{query}" if query else "")
    headers = _forward_headers(request, backend)

    resp_ctx = None
    resp_open = False  # True once resp_ctx.__aenter__() has succeeded.
    finished = False

    async def finish() -> None:
        """Close the backend response and release the slot, at most once."""
        nonlocal finished
        if finished:
            return
        finished = True
        try:
            if resp_open:
                await resp_ctx.__aexit__(None, None, None)
        finally:
            # Release even if closing the response failed, so that a broken
            # connection can never leak a slot.
            await slot_tracker.release(backend.url, model_id)
            scheduler.notify_slot_released()

    try:
        # Reading the body is inside the try block so that a failure here
        # (e.g. the client going away) still releases the slot.
        body = await request.body()
        # The context manager is entered by hand because, for streaming
        # responses, the backend connection must outlive this function.
        resp_ctx = session.request(
            method=request.method,
            url=target,
            headers=headers,
            data=body if body else None,
            allow_redirects=False,
        )
        resp = await resp_ctx.__aenter__()
        resp_open = True

        content_type = resp.headers.get("Content-Type", "")
        is_streaming = any(ct in content_type for ct in _STREAMING_CONTENT_TYPES)
        response_headers = {
            k: v
            for k, v in resp.headers.items()
            if k.lower() not in _HOP_BY_HOP
        }

        if is_streaming:

            async def generate() -> AsyncIterator[bytes]:
                # NOTE(review): if this generator is never iterated, the
                # finally block below does not run; cleanup then depends on
                # the response body actually being consumed or closed.
                try:
                    async for chunk in resp.content.iter_chunked(_STREAM_CHUNK):
                        yield chunk
                finally:
                    await finish()

            return StreamingResponse(
                generate(),
                status_code=resp.status,
                headers=response_headers,
                media_type=content_type,
            )

        data = await resp.read()
        # Free the slot before building the response, as soon as the backend
        # is done with this request.
        await finish()
        return Response(
            content=data,
            status_code=resp.status,
            headers=response_headers,
            media_type=content_type,
        )
    except BaseException as exc:
        # BaseException (not just Exception) so that task cancellation also
        # releases the slot; only ordinary errors are logged.
        if isinstance(exc, Exception):
            log.error("Forward error to %s: %s", backend.url, exc)
        # No-op if cleanup already ran, which prevents a double release.
        await finish()
        raise
# Process-wide counter that rotates the starting backend between calls.
_best_effort_counter = itertools.count()


async def forward_best_effort(
    request: Request,
    registry_backends: list[BackendState],
    session: aiohttp.ClientSession,
    preferred_url: str | None = None,
) -> Response:
    """Relay *request* to the first backend that answers, without slot gating.

    Candidates are visited starting at a rotating offset so successive calls
    spread across backends. If ``preferred_url`` is given, backends with that
    URL are moved to the front. A backend whose request raises is logged and
    skipped; the first backend that yields a response wins, and its body is
    buffered in full.

    Returns:
        The backend's response (status, body, headers minus ``_HOP_BY_HOP``),
        a 503 response when ``registry_backends`` is empty, or a 502 response
        when every backend failed.
    """
    if not registry_backends:
        return Response(content="No live backends", status_code=503)

    total = len(registry_backends)
    offset = next(_best_effort_counter) % total
    candidates = [registry_backends[(offset + i) % total] for i in range(total)]
    if preferred_url:
        # Stable sort on a boolean key: matching backends (False) come first,
        # and the rotated order is kept within each group.
        candidates.sort(key=lambda candidate: candidate.url != preferred_url)

    query = request.url.query
    path_and_query = request.url.path + (f"?{query}" if query else "")
    payload = (await request.body()) or None

    for backend in candidates:
        target = backend.url + path_and_query
        outgoing = _forward_headers(request, backend)
        try:
            async with session.request(
                method=request.method,
                url=target,
                headers=outgoing,
                data=payload,
                allow_redirects=False,
            ) as upstream:
                media = upstream.headers.get("Content-Type", "")
                passthrough = {}
                for key, val in upstream.headers.items():
                    if key.lower() not in _HOP_BY_HOP:
                        passthrough[key] = val
                content = await upstream.read()
                return Response(
                    content=content,
                    status_code=upstream.status,
                    headers=passthrough,
                    media_type=media,
                )
        except Exception as exc:
            # Best effort: note the failure and fall through to the next one.
            log.warning("Best-effort forward to %s failed: %s", backend.url, exc)

    return Response(content="All backends failed", status_code=502)