Files
llamacpp-ha/src/llamacpp_ha/proxy.py
2026-05-18 00:34:27 +02:00

435 lines
14 KiB
Python

from __future__ import annotations
import asyncio
import json
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator
import aiohttp
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from .config import ProxyConfig
from .forwarder import forward_best_effort, forward_request
from .middleware import ApiKeyMiddleware
from .monitor import ProxyStats, build_router as build_monitor_router
from .policies import RoundRobinPolicy
from .queue import QueueEntry, RequestQueue
from .registry import BackendRegistry, BackendState
from .scheduler import Scheduler
from .session_store import SessionStore
from .slot_tracker import SlotTracker
log = logging.getLogger(__name__)

_APPLICATION_JSON = "application/json"

# A session id is read from this cookie first, then from this header.
_SESSION_COOKIE = "x-llm-session"
_SESSION_HEADER = "X-Session-ID"

# POST endpoints routed through the slot queue/scheduler (inference_handler)
# rather than the best-effort catch-all forwarder.
_SLOT_GATED_PATHS = [
    "/v1/" + endpoint
    for endpoint in (
        "chat/completions",
        "completions",
        "embeddings",
        "images/generations",
        "images/edits",
        "images/variations",
        "audio/speech",
        "audio/transcriptions",
        "audio/translations",
    )
]
class _HttpSession:
    """Late-bound holder for the proxy's shared aiohttp client.

    The client itself is created in the app lifespan handler; until then
    (and in any code that runs before startup) ``client`` stays ``None``.
    """

    __slots__ = ("client",)

    def __init__(self) -> None:
        # Filled in by the lifespan handler.
        self.client: aiohttp.ClientSession | None = None
# ------------------------------------------------------------------
# Pure helpers
# ------------------------------------------------------------------
def _estimate_tokens(body: dict) -> int | None:
msgs = body.get("messages", [])
if not msgs:
prompt = body.get("prompt", "")
if isinstance(prompt, str):
return max(1, len(prompt) // 4)
return None
total = 0
for m in msgs:
content = m.get("content", "")
if isinstance(content, str):
total += max(1, len(content) // 4)
return total or None
def _get_model(body: dict) -> str:
    """Return the requested model id, or ``""`` when absent or not a string.

    Clients can send ``"model": null`` or a non-string value; coercing those
    to ``""`` keeps the declared ``str`` return type honest for downstream
    code that treats the empty string as "no model requested".
    """
    model = body.get("model")
    return model if isinstance(model, str) else ""
def _session_id_from(request: Request) -> str | None:
    """Return the caller's session id.

    A non-empty session cookie takes precedence; otherwise the session
    header is consulted (which may yield ``None``).
    """
    cookie_value = request.cookies.get(_SESSION_COOKIE)
    if cookie_value:
        return cookie_value
    return request.headers.get(_SESSION_HEADER)
def _attach_session(response: Response, session_id: str) -> None:
    """Echo *session_id* back to the client as both a cookie and a header."""
    response.set_cookie(
        _SESSION_COOKIE,
        session_id,
        httponly=True,
        samesite="lax",
        # Not Secure-flagged, so the cookie is also sent over plain HTTP.
        secure=False,
    )
    response.headers[_SESSION_HEADER] = session_id
def _parse_body(raw: bytes) -> dict[str, Any]:
    """Decode a JSON request body into a dict.

    Returns ``{}`` for an empty body, for anything that fails to decode, and
    for valid JSON that is not an object. Callers rely on ``.get`` working.

    Failure cases that yield ``{}``:
      * syntax errors (``json.JSONDecodeError``);
      * invalid UTF-8 (``UnicodeDecodeError``) — both are ``ValueError``;
      * pathological nesting (``RecursionError``).
    """
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except (ValueError, RecursionError):
        return {}
    # A list/str/number payload would break every ``body.get(...)`` downstream.
    return parsed if isinstance(parsed, dict) else {}
def _init_slot_tracker(
    registry: BackendRegistry,
    slot_tracker: SlotTracker,
    default_max_models: int | None,
    model_unload_delay: float,
) -> None:
    """Seed the slot tracker from the registry's backend configuration.

    Sets the model-unload delay once. Then, for every backend state, it records
    the slot capacity and the model limit. The limit is the backend's own
    ``config.max_models`` when that is not ``None``, otherwise
    *default_max_models*.
    """
    slot_tracker.set_model_unload_delay(model_unload_delay)
    for backend in registry.get_all_states():
        slot_tracker.set_capacity(backend.url, backend.slot_capacity)
        configured = backend.config.max_models
        limit = default_max_models if configured is None else configured
        slot_tracker.set_max_models(backend.url, limit)
def _init_global_model_limits(config: ProxyConfig, slot_tracker: SlotTracker) -> None:
    """Register every configured per-model concurrency cap with the slot tracker."""
    limits = config.model_limits
    for model_limit in limits.items():
        # Each item is a (model_id, max_concurrent) pair.
        slot_tracker.set_global_model_limit(*model_limit)
# ------------------------------------------------------------------
# Background tasks
# ------------------------------------------------------------------
async def _sync_capacities(
    registry: BackendRegistry, slot_tracker: SlotTracker, interval: float
) -> None:
    """Periodically copy every backend state's ``slot_capacity`` into the tracker.

    Loops until cancelled, sleeping *interval* seconds before each pass. A
    failing pass is logged and the loop continues, so one transient error
    cannot permanently stop capacity updates.
    """
    while True:
        await asyncio.sleep(interval)
        try:
            for state in registry.get_all_states():
                slot_tracker.set_capacity(state.url, state.slot_capacity)
        except Exception:
            # Background-task boundary: an uncaught exception would end this
            # task silently. CancelledError derives from BaseException, so
            # cancellation at shutdown still stops the loop.
            log.exception("Capacity sync pass failed; retrying in %.1fs", interval)
async def _expire_sessions(session_store: SessionStore, interval: float = 60.0) -> None:
    """Periodically ask the session store to expire idle sessions.

    Loops until cancelled, sleeping *interval* seconds (default 60, the
    previous hard-coded value) before each pass. A failing pass is logged
    and the loop keeps running instead of dying silently.
    """
    while True:
        await asyncio.sleep(interval)
        try:
            await session_store.expire()
        except Exception:
            # Background-task boundary; cancellation (BaseException) still propagates.
            log.exception("Session expiry pass failed; retrying in %.1fs", interval)
# ------------------------------------------------------------------
# Route handlers (bound to app components via closures defined in create_app)
# ------------------------------------------------------------------
async def _abandon_queue_entry(
    entry: QueueEntry,
    request_queue: RequestQueue,
    slot_tracker: SlotTracker,
    scheduler: Scheduler,
    model_id: str,
) -> None:
    """Withdraw a queued request we stopped waiting for, releasing any slot it got.

    Removes *entry* from the queue. If the scheduler already resolved its
    future with a backend, that backend's slot is released and the scheduler
    is notified. If the future is still pending, it is cancelled.
    """
    await request_queue.remove(entry)
    future = entry.future
    if not future.done():
        future.cancel()
    elif not future.cancelled() and future.exception() is None:
        # Scheduler won the race: it acquired a slot and resolved the future
        # before we withdrew. Release that slot now or it leaks.
        resolved: BackendState = future.result()
        await slot_tracker.release(resolved.url, model_id)
        scheduler.notify_slot_released()


async def _dispatch_entry(
    request_queue: RequestQueue,
    stats: ProxyStats,
    config: ProxyConfig,
    slot_tracker: SlotTracker,
    scheduler: Scheduler,
    model_id: str,
    session_id: str | None,
    body: dict,
) -> BackendState | JSONResponse:
    """Enqueue a request and wait for the scheduler to assign it a backend slot.

    Returns:
        The ``BackendState`` the scheduler resolved the queue entry with. The
        waiter must later release that slot (presumably done by
        ``forward_request``, which receives the tracker — confirm). Returns a
        503 ``JSONResponse`` if nothing was assigned within
        ``config.slot_wait_timeout``.

    Raises:
        asyncio.CancelledError: re-raised after cleanup when the waiting task
            is cancelled (e.g. the handler task is torn down), so a late
            scheduler grant cannot leak a slot.
    """
    stats.increment_requests()
    loop = asyncio.get_running_loop()
    entry = QueueEntry(
        model_id=model_id,
        session_id=session_id,
        estimated_tokens=_estimate_tokens(body),
        future=loop.create_future(),
    )
    await request_queue.enqueue(entry)
    try:
        # shield() keeps wait_for's timeout/cancellation from cancelling the
        # shared future; _abandon_queue_entry decides its fate explicitly.
        backend: BackendState = await asyncio.wait_for(
            asyncio.shield(entry.future),
            timeout=config.slot_wait_timeout,
        )
    except asyncio.TimeoutError:
        await _abandon_queue_entry(entry, request_queue, slot_tracker, scheduler, model_id)
        return JSONResponse(
            {"error": {"message": "No slot available (timeout)", "type": "overloaded"}},
            status_code=503,
        )
    except asyncio.CancelledError:
        # Without this, the entry would stay queued and a later grant would
        # acquire a slot that nobody ever releases.
        await _abandon_queue_entry(entry, request_queue, slot_tracker, scheduler, model_id)
        raise
    return backend
async def _recover_session_affinity(
    session_id: str,
    messages: list,
    session_store: SessionStore,
) -> None:
    """Give a brand-new session a preferred backend, if its history matches one.

    With at least two messages, the store is asked for a backend associated
    with a matching message prefix. A hit is recorded as the new session's
    ``preferred_backend``.
    """
    if len(messages) >= 2:
        matched_backend = await session_store.find_by_prefix(messages)
        if matched_backend:
            await session_store.update(session_id, preferred_backend=matched_backend)
async def _find_failover(
model_id: str,
excluded_urls: set[str],
registry: BackendRegistry,
slot_tracker: SlotTracker,
) -> BackendState | None:
"""Acquire a slot on any live backend not already tried. Returns None if none are free."""
candidates = (
registry.get_backends_for_model(model_id)
if model_id
else registry.get_all_live_backends()
)
for b in candidates:
if b.url in excluded_urls or not slot_tracker.can_accept(b.url, model_id):
continue
try:
async with asyncio.timeout(0):
await slot_tracker.acquire(b.url, model_id)
return b
except TimeoutError:
continue
return None
def _list_models(*, registry: BackendRegistry) -> JSONResponse:
    """Return every model known to the registry as an OpenAI-style model list.

    ``created`` is always 0 and ``owned_by`` is always ``"llamacpp-ha"``.
    """
    payload = {
        "object": "list",
        "data": [
            {"id": name, "object": "model", "created": 0, "owned_by": "llamacpp-ha"}
            for name in registry.get_all_models()
        ],
    }
    return JSONResponse(payload)
def _health(*, registry: BackendRegistry) -> Response:
    """Report 200 ``ok`` if any backend is live, otherwise 503."""
    if not registry.get_all_live_backends():
        return Response(
            content='{"status":"no live backends"}',
            status_code=503,
            media_type=_APPLICATION_JSON,
        )
    return Response(content='{"status":"ok"}', media_type=_APPLICATION_JSON)
async def _catch_all(
    request: Request,
    *,
    http: _HttpSession,
    registry: BackendRegistry,
    session_store: SessionStore,
) -> Response:
    """Best-effort forward of a non-slot-gated request to a live backend.

    These requests skip the slot queue and are not counted in the request
    stats. If the caller carries a session id, that session's preferred
    backend is passed as a hint. That backend already holds the session's
    KV cache, so web-UI API calls stay consistent with the chat.
    """
    if http.client is None:
        return Response(content="Proxy not ready", status_code=503)
    session_id = _session_id_from(request)
    preferred_url = (
        await session_store.get_preferred_backend(session_id) if session_id else None
    )
    return await forward_best_effort(
        request,
        registry.get_all_live_backends(),
        http.client,
        preferred_url=preferred_url,
    )
async def _inference_endpoint(
    request: Request,
    *,
    http: _HttpSession,
    slot_tracker: SlotTracker,
    scheduler: Scheduler,
    session_store: SessionStore,
    request_queue: RequestQueue,
    stats: ProxyStats,
    config: ProxyConfig,
    registry: BackendRegistry,
) -> Response:
    """Serve one slot-gated inference request end to end.

    1. Resolve the caller's session, or mint a new one and try to recover
       affinity from the message prefix.
    2. Queue for a backend slot.
    3. Forward the request, failing over to untried backends on
       ``aiohttp.ClientError``.
    4. Record stats, update the session, and attach the session id to the
       response.
    """
    if http.client is None:
        return Response(
            content='{"error":{"message":"Proxy not ready","type":"server_error"}}',
            status_code=503,
            media_type=_APPLICATION_JSON,
        )

    body = _parse_body(await request.body())
    model_id = _get_model(body)

    incoming_session_id = _session_id_from(request)
    preferred_url: str | None = None
    if incoming_session_id:
        session_id = incoming_session_id
        preferred_url = await session_store.get_preferred_backend(session_id)
    else:
        # New session: seed a preferred backend from the message prefix, if any.
        session_id = session_store.new_session_id()
        await _recover_session_affinity(session_id, body.get("messages") or [], session_store)

    dispatched = await _dispatch_entry(
        request_queue, stats, config, slot_tracker, scheduler, model_id, session_id, body
    )
    if isinstance(dispatched, Response):
        return dispatched  # 503: no slot became available in time

    backend = dispatched
    tried: set[str] = {backend.url}
    while True:
        try:
            response = await forward_request(
                request=request,
                backend=backend,
                session=http.client,
                slot_tracker=slot_tracker,
                scheduler=scheduler,
                model_id=model_id,
            )
        except aiohttp.ClientError as exc:
            log.warning("Backend %s unreachable (%s); trying failover", backend.url, exc)
            failover = await _find_failover(model_id, tried, registry, slot_tracker)
            if failover is None:
                return JSONResponse(
                    {"error": {"message": "Backend unreachable, no failover available", "type": "server_error"}},
                    status_code=502,
                    media_type=_APPLICATION_JSON,
                )
            tried.add(failover.url)
            backend = failover
            continue
        break

    stats.record_model(model_id, _estimate_tokens(body))
    stats.record_backend(backend.url)
    stats.record_session(bool(incoming_session_id), preferred_url, backend.url)

    if model_id:
        messages = body.get("messages", [])
        await session_store.update(
            session_id,
            model_id=model_id,
            messages=messages or None,
            preferred_backend=backend.url,
        )

    _attach_session(response, session_id)
    return response
# ------------------------------------------------------------------
# App factory
# ------------------------------------------------------------------
def create_app(config: ProxyConfig) -> FastAPI:
    """Build the FastAPI app and wire together every proxy component.

    Creates the slot tracker, session store, request queue, stats, backend
    registry and scheduler. The lifespan handler starts and stops them along
    with two background loops (capacity sync and session expiry) and the
    shared aiohttp client. Routes registered:
      * ``GET /v1/models`` and ``GET /health``;
      * ``POST`` on each slot-gated inference path;
      * a best-effort catch-all for everything else.
    """
    slot_tracker = SlotTracker()
    session_store = SessionStore(ttl=config.session_idle_ttl)
    request_queue = RequestQueue()
    stats = ProxyStats()
    http = _HttpSession()

    async def _on_backend_recovered(url: str) -> None:
        # ``scheduler`` is bound further down. This assumes the registry only
        # invokes the callback after it has been started — TODO confirm.
        await slot_tracker.reset_acquired(url)
        scheduler.notify_slot_released()

    registry = BackendRegistry(config, on_backend_recovered=_on_backend_recovered)
    scheduler = Scheduler(
        queue=request_queue,
        registry=registry,
        slot_tracker=slot_tracker,
        session_store=session_store,
        policy=RoundRobinPolicy(),
        max_queue_skip=config.max_queue_skip,
    )

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
        http.client = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=300),
            # NOTE(review): ssl=False disables TLS certificate verification for
            # backend connections; acceptable only on a trusted network.
            connector=aiohttp.TCPConnector(ssl=False),
        )
        _init_slot_tracker(registry, slot_tracker, config.default_max_models, config.model_unload_delay)
        _init_global_model_limits(config, slot_tracker)
        registry.start()
        scheduler.start()
        background = [
            asyncio.create_task(
                _sync_capacities(registry, slot_tracker, config.poll_interval),
                name="capacity-sync",
            ),
            asyncio.create_task(_expire_sessions(session_store), name="session-expiry"),
        ]
        try:
            yield
        finally:
            # Teardown runs even when an exception is thrown in at ``yield``.
            for task in background:
                task.cancel()
            # Wait for the cancelled loops to finish; their CancelledError is
            # collected rather than raised.
            await asyncio.gather(*background, return_exceptions=True)
            await scheduler.stop()
            await registry.stop()
            # Detach first so any late handler answers "not ready" instead of
            # using a closed session.
            client, http.client = http.client, None
            if client is not None:
                await client.close()

    app = FastAPI(title="llamacpp-ha", lifespan=lifespan)
    app.add_middleware(ApiKeyMiddleware, api_keys=config.api_keys)
    app.include_router(
        build_monitor_router(
            registry=registry,
            slot_tracker=slot_tracker,
            request_queue=request_queue,
            session_store=session_store,
            stats=stats,
        )
    )

    # async def keeps these on the event loop. Plain ``def`` handlers would be
    # run by FastAPI in a worker thread while the loop mutates registry state.
    async def list_models_handler() -> JSONResponse:
        return _list_models(registry=registry)

    async def health_handler() -> Response:
        return _health(registry=registry)

    async def inference_handler(request: Request) -> Response:
        return await _inference_endpoint(
            request,
            http=http,
            slot_tracker=slot_tracker,
            scheduler=scheduler,
            session_store=session_store,
            request_queue=request_queue,
            stats=stats,
            config=config,
            registry=registry,
        )

    async def catch_all_handler(request: Request, full_path: str) -> Response:  # noqa: ARG001
        return await _catch_all(request, http=http, registry=registry, session_store=session_store)

    app.add_api_route("/v1/models", list_models_handler, methods=["GET"])
    app.add_api_route("/health", health_handler, methods=["GET"])
    for path in _SLOT_GATED_PATHS:
        app.add_api_route(path, inference_handler, methods=["POST"])
    # Registered last so the specific routes above take precedence.
    app.add_api_route(
        "/{full_path:path}",
        catch_all_handler,
        methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"],
        include_in_schema=False,
    )
    return app