Files
llamacpp-ha/src/llamacpp_ha/forwarder.py
2026-05-17 09:54:18 +02:00

161 lines
4.8 KiB
Python

from __future__ import annotations
import logging
from typing import AsyncIterator
import aiohttp
from fastapi import Request
from fastapi.responses import Response, StreamingResponse
from starlette.datastructures import Headers
from .registry import BackendState
from .scheduler import Scheduler
from .slot_tracker import SlotTracker
log = logging.getLogger(__name__)

# Read size (bytes) used when relaying text/event-stream bodies to the client.
_STREAM_CHUNK = 8192

# Lowercase header names that are never copied between client and backend.
# The set is applied in both directions: request headers in _forward_headers
# and upstream response headers in the forwarders.
#
# * Hop-by-hop headers (RFC 2616 §13.5.1, RFC 7230 §6.1). RFC 2616 §13.5.1
#   spells one of them "Trailers", but the actual header field is "Trailer"
#   (RFC 2616 §14.40), so both spellings are filtered.
# * "content-encoding" / "content-length": the forwarders re-emit the body
#   read through aiohttp. This assumes the session keeps aiohttp's default
#   auto_decompress=True (TODO confirm where the session is created). Under
#   that assumption the original encoding and length no longer describe the
#   bytes being sent, and the length is recomputed downstream.
#   NOTE(review): on the request side this also drops a client's
#   Content-Encoding, so a compressed request body would be forwarded
#   unlabeled.
_HOP_BY_HOP = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-encoding",
        "content-length",
    }
)
def _forward_headers(request: Request, backend: BackendState) -> dict[str, str]:
    """Build the headers to send upstream to *backend*.

    The incoming request headers are copied except:

    * names in ``_HOP_BY_HOP``;
    * ``Host`` (aiohttp derives it from the target URL);
    * any client-supplied ``Authorization``, in any letter case.

    If ``backend.config.api_key`` is set, ``Authorization: Bearer <key>`` is
    added. Otherwise no Authorization header is sent at all.

    Args:
        request: Incoming request whose ``headers.items()`` are copied.
        backend: Target backend; only ``backend.config.api_key`` is read.

    Returns:
        A new dict of header name to value. If a name repeats in the
        incoming headers, the last value wins.
    """
    headers: dict[str, str] = {}
    for name, value in request.headers.items():
        lowered = name.lower()
        # Client credentials are never forwarded. Filtering on the lowercased
        # name catches every spelling. Incoming names are typically lowercase
        # (e.g. "authorization"), so popping only exact spellings afterwards
        # could leave the client's token next to the backend's "Authorization",
        # and aiohttp would then send both.
        if lowered in _HOP_BY_HOP or lowered in ("host", "authorization"):
            continue
        headers[name] = value
    if backend.config.api_key:
        headers["Authorization"] = f"Bearer {backend.config.api_key}"
    return headers
async def forward_request(
    request: Request,
    backend: BackendState,
    session: aiohttp.ClientSession,
    slot_tracker: SlotTracker,
    scheduler: Scheduler,
    model_id: str = "",
    path_override: str | None = None,
) -> Response:
    """Proxy *request* to *backend*, releasing the backend slot exactly once.

    The caller presumably claimed a slot on *backend* for *model_id* before
    calling (TODO confirm against callers). This function takes ownership of
    that slot and gives it back via ``slot_tracker.release(backend.url,
    model_id)`` followed by ``scheduler.notify_slot_released()``:

    * buffered responses: once the upstream body has been read;
    * ``text/event-stream`` responses: when the returned stream's generator
      finishes or is closed;
    * on any error or cancellation before a response is handed back.

    Args:
        request: Incoming client request; method, query, headers and body are
            forwarded.
        backend: Target backend; ``backend.url`` is the URL prefix.
        session: aiohttp session used for the upstream call.
        slot_tracker: Tracker whose ``release`` is awaited on slot return.
        scheduler: Notified after the slot is released.
        model_id: Passed through to ``slot_tracker.release``.
        path_override: Upstream path to use instead of ``request.url.path``.

    Returns:
        A ``StreamingResponse`` when the upstream Content-Type contains
        ``text/event-stream``. Otherwise a buffered ``Response``. Either way
        it carries the upstream status and headers minus ``_HOP_BY_HOP``.

    Raises:
        Exception: Errors reading the client body or talking to the backend
            are logged and re-raised after the slot has been released.
    """
    path = path_override or request.url.path
    query = request.url.query
    target = backend.url + path + (f"?{query}" if query else "")
    headers = _forward_headers(request, backend)

    released = False

    async def release_slot() -> None:
        # Idempotent: more than one cleanup path may reach this, but the slot
        # must be returned exactly once.
        nonlocal released
        if released:
            return
        released = True
        await slot_tracker.release(backend.url, model_id)
        scheduler.notify_slot_released()

    # Assigned only after __aenter__ succeeds, so cleanup knows whether there
    # is an upstream response to close.
    resp_ctx = None
    # Set once a StreamingResponse owns the upstream response and the slot;
    # from then on cleanup happens in generate() instead of below.
    handed_off = False
    try:
        body = await request.body()
        ctx = session.request(
            method=request.method,
            url=target,
            headers=headers,
            data=body if body else None,
            allow_redirects=False,
        )
        resp = await ctx.__aenter__()
        resp_ctx = ctx
        content_type = resp.headers.get("Content-Type", "")
        response_headers = {
            k: v for k, v in resp.headers.items() if k.lower() not in _HOP_BY_HOP
        }

        if "text/event-stream" in content_type:

            async def generate() -> AsyncIterator[bytes]:
                try:
                    async for chunk in resp.content.iter_chunked(_STREAM_CHUNK):
                        yield chunk
                finally:
                    try:
                        await ctx.__aexit__(None, None, None)
                    finally:
                        await release_slot()

            # NOTE(review): if the server never starts iterating this
            # generator (e.g. the client disconnects before the first send),
            # its finally block never runs and the slot is not returned.
            streaming = StreamingResponse(
                generate(),
                status_code=resp.status,
                headers=response_headers,
                media_type=content_type,
            )
            handed_off = True
            return streaming

        data = await resp.read()
        return Response(
            content=data,
            status_code=resp.status,
            headers=response_headers,
            # A missing upstream Content-Type must not turn into an empty
            # "content-type:" header on the client response.
            media_type=content_type or None,
        )
    except Exception as exc:
        log.error("Forward error to %s: %s", backend.url, exc)
        raise
    finally:
        # Runs on success, error and cancellation alike. The slot is
        # released exactly once unless the stream generator now owns it.
        if not handed_off:
            try:
                if resp_ctx is not None:
                    await resp_ctx.__aexit__(None, None, None)
            finally:
                await release_slot()
async def forward_best_effort(
    request: Request,
    registry_backends: list[BackendState],
    session: aiohttp.ClientSession,
) -> Response:
    """Forward without slot gating to any live backend (catch-all paths).

    Backends are tried in list order. A backend that cannot be connected to
    (``aiohttp.ClientConnectorError``: the request was never delivered, so
    trying the next backend cannot duplicate it) is skipped. Any other
    upstream error ends the attempt with a 502.

    The upstream body is always read fully into memory; nothing is streamed.

    Args:
        request: Incoming client request; path, query, method, headers and
            body are forwarded.
        registry_backends: Candidate backends, tried in order.
        session: aiohttp session used for the upstream call.

    Returns:
        The upstream response with ``_HOP_BY_HOP`` headers removed. A 503 if
        *registry_backends* is empty. A 502 if every backend refused the
        connection or the request failed otherwise.
    """
    if not registry_backends:
        return Response(content="No live backends", status_code=503)
    path = request.url.path
    query = request.url.query
    suffix = path + (f"?{query}" if query else "")
    body = await request.body()
    last_exc: Exception | None = None
    for backend in registry_backends:
        try:
            async with session.request(
                method=request.method,
                url=backend.url + suffix,
                # Rebuilt per backend: each may carry its own API key.
                headers=_forward_headers(request, backend),
                data=body if body else None,
                allow_redirects=False,
            ) as resp:
                content_type = resp.headers.get("Content-Type", "")
                response_headers = {
                    k: v
                    for k, v in resp.headers.items()
                    if k.lower() not in _HOP_BY_HOP
                }
                data = await resp.read()
                return Response(
                    content=data,
                    status_code=resp.status,
                    headers=response_headers,
                    # Avoid emitting an empty "content-type:" header.
                    media_type=content_type or None,
                )
        except aiohttp.ClientConnectorError as exc:
            # Connection never established: safe to try the next backend.
            log.warning(
                "Best-effort forward: cannot connect to %s: %s", backend.url, exc
            )
            last_exc = exc
        except Exception as exc:
            log.error("Best-effort forward error: %s", exc)
            return Response(content="Backend error", status_code=502)
    log.error("Best-effort forward error: %s", last_exc)
    return Response(content="Backend error", status_code=502)