435 lines
14 KiB
Python
435 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any, AsyncGenerator
|
|
|
|
import aiohttp
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse, Response
|
|
|
|
from .config import ProxyConfig
|
|
from .forwarder import forward_best_effort, forward_request
|
|
from .middleware import ApiKeyMiddleware
|
|
from .monitor import ProxyStats, build_router as build_monitor_router
|
|
from .policies import RoundRobinPolicy
|
|
from .queue import QueueEntry, RequestQueue
|
|
from .registry import BackendRegistry, BackendState
|
|
from .scheduler import Scheduler
|
|
from .session_store import SessionStore
|
|
from .slot_tracker import SlotTracker
|
|
|
|
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Media type used for the hand-built JSON responses in this module.
_APPLICATION_JSON = "application/json"
# Cookie that carries the session id (read by _session_id_from, written by
# _attach_session).
_SESSION_COOKIE = "x-llm-session"
# Header that carries the session id; consulted when the cookie is absent and
# echoed back on responses.
_SESSION_HEADER = "X-Session-ID"
# Paths that create_app registers (POST only) with the slot-gated inference
# handler; every other path is served by the catch-all route.
_SLOT_GATED_PATHS = [
    "/v1/chat/completions",
    "/v1/completions",
    "/v1/embeddings",
    "/v1/images/generations",
    "/v1/images/edits",
    "/v1/images/variations",
    "/v1/audio/speech",
    "/v1/audio/transcriptions",
    "/v1/audio/translations",
]
|
|
|
|
|
|
class _HttpSession:
|
|
"""Mutable holder for the shared aiohttp session, created inside lifespan."""
|
|
|
|
__slots__ = ("client",)
|
|
|
|
def __init__(self) -> None:
|
|
self.client: aiohttp.ClientSession | None = None
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Pure helpers
|
|
# ------------------------------------------------------------------
|
|
|
|
def _estimate_tokens(body: dict) -> int | None:
|
|
msgs = body.get("messages", [])
|
|
if not msgs:
|
|
prompt = body.get("prompt", "")
|
|
if isinstance(prompt, str):
|
|
return max(1, len(prompt) // 4)
|
|
return None
|
|
total = 0
|
|
for m in msgs:
|
|
content = m.get("content", "")
|
|
if isinstance(content, str):
|
|
total += max(1, len(content) // 4)
|
|
return total or None
|
|
|
|
|
|
def _get_model(body: dict) -> str:
|
|
return body.get("model", "")
|
|
|
|
|
|
def _session_id_from(request: Request) -> str | None:
    """Extract the caller's session id from the request, if it sent one.

    The session cookie takes precedence; the session header is consulted
    only when the cookie is missing or empty.
    """
    from_cookie = request.cookies.get(_SESSION_COOKIE)
    if from_cookie:
        return from_cookie
    return request.headers.get(_SESSION_HEADER)
|
|
|
|
|
|
def _attach_session(response: Response, session_id: str) -> None:
    """Stamp *session_id* onto *response* as both a cookie and a header."""
    cookie_flags = {"httponly": True, "samesite": "lax", "secure": False}
    response.set_cookie(_SESSION_COOKIE, session_id, **cookie_flags)
    response.headers[_SESSION_HEADER] = session_id
|
|
|
|
|
|
def _parse_body(raw: bytes) -> dict[str, Any]:
|
|
try:
|
|
return json.loads(raw) if raw else {}
|
|
except json.JSONDecodeError:
|
|
return {}
|
|
|
|
|
|
def _init_slot_tracker(
|
|
registry: BackendRegistry,
|
|
slot_tracker: SlotTracker,
|
|
default_max_models: int | None,
|
|
model_unload_delay: float,
|
|
) -> None:
|
|
slot_tracker.set_model_unload_delay(model_unload_delay)
|
|
for state in registry.get_all_states():
|
|
slot_tracker.set_capacity(state.url, state.slot_capacity)
|
|
effective_max = (
|
|
state.config.max_models
|
|
if state.config.max_models is not None
|
|
else default_max_models
|
|
)
|
|
slot_tracker.set_max_models(state.url, effective_max)
|
|
|
|
|
|
def _init_global_model_limits(config: ProxyConfig, slot_tracker: SlotTracker) -> None:
|
|
for model_id, max_concurrent in config.model_limits.items():
|
|
slot_tracker.set_global_model_limit(model_id, max_concurrent)
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Background tasks
|
|
# ------------------------------------------------------------------
|
|
|
|
async def _sync_capacities(
|
|
registry: BackendRegistry, slot_tracker: SlotTracker, interval: float
|
|
) -> None:
|
|
while True:
|
|
await asyncio.sleep(interval)
|
|
for state in registry.get_all_states():
|
|
slot_tracker.set_capacity(state.url, state.slot_capacity)
|
|
|
|
|
|
async def _expire_sessions(session_store: SessionStore) -> None:
|
|
while True:
|
|
await asyncio.sleep(60)
|
|
await session_store.expire()
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Route handlers (bound to app components by closures inside create_app)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def _abandon_entry(
    entry: QueueEntry,
    request_queue: RequestQueue,
    slot_tracker: SlotTracker,
    scheduler: Scheduler,
    model_id: str,
) -> None:
    """Withdraw *entry* from the queue and give back any slot already granted to it."""
    await request_queue.remove(entry)
    if entry.future.done() and not entry.future.cancelled():
        # The scheduler won the race: it acquired a slot and resolved the
        # future before we managed to remove the entry. Release that slot now
        # or it leaks.
        resolved: BackendState = entry.future.result()
        await slot_tracker.release(resolved.url, model_id)
        scheduler.notify_slot_released()
    elif not entry.future.done():
        entry.future.cancel()


async def _dispatch_entry(
    request_queue: RequestQueue,
    stats: ProxyStats,
    config: ProxyConfig,
    slot_tracker: SlotTracker,
    scheduler: Scheduler,
    model_id: str,
    session_id: str | None,
    body: dict,
) -> BackendState | JSONResponse:
    """Queue a request and wait for a backend slot to be assigned to it.

    Counts the request in *stats*, enqueues a ``QueueEntry`` carrying a fresh
    future, and waits up to ``config.slot_wait_timeout`` seconds for that
    future to be resolved with a backend.

    Returns:
        The assigned ``BackendState``, or a 503 ``JSONResponse`` when no slot
        was granted before the timeout.

    Raises:
        asyncio.CancelledError: re-raised if the waiting task is cancelled.

    The future is awaited through ``asyncio.shield``, so neither a timeout
    nor cancellation of this task cancels it. On both exits the entry is
    therefore removed from the queue explicitly and any slot that was granted
    in the meantime is released; otherwise the entry could later be given a
    slot that nobody would ever hand back.
    """
    stats.increment_requests()
    loop = asyncio.get_running_loop()
    entry = QueueEntry(
        model_id=model_id,
        session_id=session_id,
        estimated_tokens=_estimate_tokens(body),
        future=loop.create_future(),
    )
    await request_queue.enqueue(entry)
    try:
        backend: BackendState = await asyncio.wait_for(
            asyncio.shield(entry.future),
            timeout=config.slot_wait_timeout,
        )
    except asyncio.TimeoutError:
        await _abandon_entry(entry, request_queue, slot_tracker, scheduler, model_id)
        return JSONResponse(
            {"error": {"message": "No slot available (timeout)", "type": "overloaded"}},
            status_code=503,
        )
    except asyncio.CancelledError:
        # The waiter went away (e.g. shutdown or an aborted request). The
        # shielded future is still live, so clean up exactly as on timeout
        # before propagating the cancellation.
        await _abandon_entry(entry, request_queue, slot_tracker, scheduler, model_id)
        raise
    return backend
|
|
|
|
|
|
async def _recover_session_affinity(
|
|
session_id: str,
|
|
messages: list,
|
|
session_store: SessionStore,
|
|
) -> None:
|
|
"""Pre-seed a cookieless new session's preferred backend via message-prefix matching."""
|
|
if len(messages) < 2:
|
|
return
|
|
hint = await session_store.find_by_prefix(messages)
|
|
if hint:
|
|
await session_store.update(session_id, preferred_backend=hint)
|
|
|
|
|
|
async def _find_failover(
|
|
model_id: str,
|
|
excluded_urls: set[str],
|
|
registry: BackendRegistry,
|
|
slot_tracker: SlotTracker,
|
|
) -> BackendState | None:
|
|
"""Acquire a slot on any live backend not already tried. Returns None if none are free."""
|
|
candidates = (
|
|
registry.get_backends_for_model(model_id)
|
|
if model_id
|
|
else registry.get_all_live_backends()
|
|
)
|
|
for b in candidates:
|
|
if b.url in excluded_urls or not slot_tracker.can_accept(b.url, model_id):
|
|
continue
|
|
try:
|
|
async with asyncio.timeout(0):
|
|
await slot_tracker.acquire(b.url, model_id)
|
|
return b
|
|
except TimeoutError:
|
|
continue
|
|
return None
|
|
|
|
|
|
def _list_models(*, registry: BackendRegistry) -> JSONResponse:
    """Build the ``/v1/models`` listing from every model the registry reports."""
    listing: dict[str, Any] = {"object": "list", "data": []}
    for model_name in registry.get_all_models():
        listing["data"].append(
            {
                "id": model_name,
                "object": "model",
                "created": 0,
                "owned_by": "llamacpp-ha",
            }
        )
    return JSONResponse(listing)
|
|
|
|
|
|
def _health(*, registry: BackendRegistry) -> Response:
    """Report proxy health: 200 if any backend is live, otherwise 503."""
    if not registry.get_all_live_backends():
        return Response(
            content='{"status":"no live backends"}',
            status_code=503,
            media_type=_APPLICATION_JSON,
        )
    return Response(content='{"status":"ok"}', media_type=_APPLICATION_JSON)
|
|
|
|
|
|
async def _catch_all(
    request: Request,
    *,
    http: _HttpSession,
    registry: BackendRegistry,
    session_store: SessionStore,
) -> Response:
    """Forward any non-slot-gated request to a live backend on a best-effort basis.

    Returns a plain 503 while the shared HTTP client has not been created.
    These requests are not counted as "requests served". When the caller
    presents a session id, the backend recorded as preferred for that session
    is passed along as a hint so auxiliary calls land on the same backend as
    the conversation.
    """
    if http.client is None:
        return Response(content="Proxy not ready", status_code=503)

    sticky_backend: str | None = None
    caller_session = _session_id_from(request)
    if caller_session:
        sticky_backend = await session_store.get_preferred_backend(caller_session)

    live_backends = registry.get_all_live_backends()
    return await forward_best_effort(
        request, live_backends, http.client, preferred_url=sticky_backend
    )
|
|
|
|
|
|
async def _inference_endpoint(
    request: Request,
    *,
    http: _HttpSession,
    slot_tracker: SlotTracker,
    scheduler: Scheduler,
    session_store: SessionStore,
    request_queue: RequestQueue,
    stats: ProxyStats,
    config: ProxyConfig,
    registry: BackendRegistry,
) -> Response:
    """Serve one slot-gated inference request end to end.

    Steps, in order:

    1. Return a JSON 503 if the shared HTTP client is not available yet.
    2. Parse the body, read the model name, and determine the session id
       (the caller's, or a freshly generated one).
    3. For callers without a session id, try to seed session affinity from
       the message history; for callers with one, look up the preferred
       backend (used for statistics only in this function).
    4. Queue the request and wait for a backend; a queue timeout response is
       returned unchanged.
    5. Forward to the backend. On ``aiohttp.ClientError`` look for a failover
       backend that has not been tried, and answer 502 if there is none.
    6. Record statistics, update the session (only when a model was named),
       and attach the session id to the response.
    """
    if http.client is None:
        return Response(
            content='{"error":{"message":"Proxy not ready","type":"server_error"}}',
            status_code=503,
            media_type=_APPLICATION_JSON,
        )

    body = _parse_body(await request.body())
    model_id = _get_model(body)
    incoming_session_id = _session_id_from(request)
    session_id = incoming_session_id or session_store.new_session_id()

    preferred_url: str | None = None
    if incoming_session_id:
        preferred_url = await session_store.get_preferred_backend(session_id)
    else:
        await _recover_session_affinity(
            session_id, body.get("messages") or [], session_store
        )

    dispatched = await _dispatch_entry(
        request_queue, stats, config, slot_tracker, scheduler, model_id, session_id, body
    )
    if isinstance(dispatched, Response):
        # Queue wait timed out: pass the 503 straight through.
        return dispatched

    backend = dispatched
    tried: set[str] = {backend.url}

    while True:
        try:
            response = await forward_request(
                request=request,
                backend=backend,
                session=http.client,
                slot_tracker=slot_tracker,
                scheduler=scheduler,
                model_id=model_id,
            )
        except aiohttp.ClientError as exc:
            log.warning("Backend %s unreachable (%s); trying failover", backend.url, exc)
            replacement = await _find_failover(model_id, tried, registry, slot_tracker)
            if replacement is None:
                return JSONResponse(
                    {"error": {"message": "Backend unreachable, no failover available", "type": "server_error"}},
                    status_code=502,
                    media_type=_APPLICATION_JSON,
                )
            tried.add(replacement.url)
            backend = replacement
        else:
            break

    stats.record_model(model_id, _estimate_tokens(body))
    stats.record_backend(backend.url)
    stats.record_session(bool(incoming_session_id), preferred_url, backend.url)

    if model_id:
        messages = body.get("messages", [])
        await session_store.update(
            session_id,
            model_id=model_id,
            messages=messages or None,
            preferred_backend=backend.url,
        )

    _attach_session(response, session_id)
    return response
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# App factory
|
|
# ------------------------------------------------------------------
|
|
|
|
def create_app(config: ProxyConfig) -> FastAPI:
    """Assemble the proxy's FastAPI application from *config*.

    Creates the shared components (slot tracker, session store, request
    queue, stats, registry, scheduler), wires them into the route handlers
    through closures, and installs a lifespan that starts and stops the
    background machinery.

    Routes registered, in order: ``GET /v1/models``, ``GET /health``, one
    ``POST`` route per entry of ``_SLOT_GATED_PATHS``, and finally a
    catch-all for every other path.
    """
    slot_tracker = SlotTracker()
    session_store = SessionStore(ttl=config.session_idle_ttl)
    request_queue = RequestQueue()
    stats = ProxyStats()
    http = _HttpSession()

    async def _on_backend_recovered(url: str) -> None:
        # `scheduler` is bound further down; this closure only runs after
        # create_app has finished, so the late lookup is safe.
        await slot_tracker.reset_acquired(url)
        scheduler.notify_slot_released()

    registry = BackendRegistry(config, on_backend_recovered=_on_backend_recovered)
    scheduler = Scheduler(
        queue=request_queue,
        registry=registry,
        slot_tracker=slot_tracker,
        session_store=session_store,
        policy=RoundRobinPolicy(),
        max_queue_skip=config.max_queue_skip,
    )

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
        http.client = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=300),
            # NOTE(review): ssl=False turns off TLS certificate verification
            # for backend connections; confirm this is intended.
            connector=aiohttp.TCPConnector(ssl=False),
        )
        _init_slot_tracker(
            registry, slot_tracker, config.default_max_models, config.model_unload_delay
        )
        _init_global_model_limits(config, slot_tracker)
        registry.start()
        scheduler.start()
        background = [
            asyncio.create_task(
                _sync_capacities(registry, slot_tracker, config.poll_interval),
                name="capacity-sync",
            ),
            asyncio.create_task(
                _expire_sessions(session_store), name="session-expiry"
            ),
        ]
        try:
            yield
        finally:
            # Teardown runs even when the lifespan is torn down by an
            # exception or cancellation, so nothing is left running.
            for task in background:
                task.cancel()
            # Wait for the cancelled tasks to finish so they are not
            # destroyed while still pending.
            await asyncio.gather(*background, return_exceptions=True)
            await scheduler.stop()
            await registry.stop()
            if http.client is not None:
                await http.client.close()
                # Late requests then get the "Proxy not ready" response
                # instead of touching a closed session.
                http.client = None

    app = FastAPI(title="llamacpp-ha", lifespan=lifespan)
    app.add_middleware(ApiKeyMiddleware, api_keys=config.api_keys)
    app.include_router(
        build_monitor_router(
            registry=registry,
            slot_tracker=slot_tracker,
            request_queue=request_queue,
            session_store=session_store,
            stats=stats,
        )
    )

    def list_models_handler() -> JSONResponse:
        return _list_models(registry=registry)

    def health_handler() -> Response:
        return _health(registry=registry)

    async def inference_handler(request: Request) -> Response:
        return await _inference_endpoint(
            request,
            http=http,
            slot_tracker=slot_tracker,
            scheduler=scheduler,
            session_store=session_store,
            request_queue=request_queue,
            stats=stats,
            config=config,
            registry=registry,
        )

    async def catch_all_handler(request: Request, full_path: str) -> Response:  # noqa: ARG001
        return await _catch_all(
            request, http=http, registry=registry, session_store=session_store
        )

    app.add_api_route("/v1/models", list_models_handler, methods=["GET"])
    app.add_api_route("/health", health_handler, methods=["GET"])

    for path in _SLOT_GATED_PATHS:
        app.add_api_route(path, inference_handler, methods=["POST"])

    # Registered last so the specific routes above take precedence.
    app.add_api_route(
        "/{full_path:path}",
        catch_all_handler,
        methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"],
        include_in_schema=False,
    )

    return app
|