From 7344aa4ef4141fc4818c0c9ebe297eea8c7b9096 Mon Sep 17 00:00:00 2001 From: chacha Date: Sun, 17 May 2026 09:54:18 +0200 Subject: [PATCH] first commit --- .claude/settings.local.json | 7 + .gitignore | 1 + AGENTS.md | 70 +++++ CLAUDE.md | 1 + README.md | 225 +++++++++++++++ config.json.example | 20 ++ pyproject.toml | 33 +++ src/llamacpp_ha/__init__.py | 0 src/llamacpp_ha/__main__.py | 51 ++++ src/llamacpp_ha/config.py | 54 ++++ src/llamacpp_ha/forwarder.py | 160 +++++++++++ src/llamacpp_ha/middleware.py | 36 +++ src/llamacpp_ha/monitor.py | 208 ++++++++++++++ src/llamacpp_ha/policies.py | 30 ++ src/llamacpp_ha/proxy.py | 342 ++++++++++++++++++++++ src/llamacpp_ha/queue.py | 78 +++++ src/llamacpp_ha/registry.py | 191 +++++++++++++ src/llamacpp_ha/scheduler.py | 212 ++++++++++++++ src/llamacpp_ha/session_store.py | 107 +++++++ src/llamacpp_ha/slot_tracker.py | 122 ++++++++ tests/__init__.py | 0 tests/test_config.py | 130 +++++++++ tests/test_forwarder.py | 276 ++++++++++++++++++ tests/test_integration.py | 477 +++++++++++++++++++++++++++++++ tests/test_middleware.py | 79 +++++ tests/test_monitor.py | 109 +++++++ tests/test_policies.py | 59 ++++ tests/test_queue.py | 92 ++++++ tests/test_registry.py | 152 ++++++++++ tests/test_scheduler.py | 300 +++++++++++++++++++ tests/test_session_store.py | 112 ++++++++ tests/test_slot_tracker.py | 187 ++++++++++++ 32 files changed, 3921 insertions(+) create mode 100644 .claude/settings.local.json create mode 100644 .gitignore create mode 100644 AGENTS.md create mode 100644 CLAUDE.md create mode 100644 README.md create mode 100644 config.json.example create mode 100644 pyproject.toml create mode 100644 src/llamacpp_ha/__init__.py create mode 100644 src/llamacpp_ha/__main__.py create mode 100644 src/llamacpp_ha/config.py create mode 100644 src/llamacpp_ha/forwarder.py create mode 100644 src/llamacpp_ha/middleware.py create mode 100644 src/llamacpp_ha/monitor.py create mode 100644 src/llamacpp_ha/policies.py create mode 100644 src/llamacpp_ha/proxy.py create mode 100644 src/llamacpp_ha/queue.py create mode 100644 src/llamacpp_ha/registry.py create mode 100644 src/llamacpp_ha/scheduler.py create mode 100644 src/llamacpp_ha/session_store.py create mode 100644 src/llamacpp_ha/slot_tracker.py create mode 100644 tests/__init__.py create mode 100644 tests/test_config.py create mode 100644 tests/test_forwarder.py create mode 100644 tests/test_integration.py create mode 100644 tests/test_middleware.py create mode 100644 tests/test_monitor.py create mode 100644 tests/test_policies.py create mode 100644 tests/test_queue.py create mode 100644 tests/test_registry.py create mode 100644 tests/test_scheduler.py create mode 100644 tests/test_session_store.py create mode 100644 tests/test_slot_tracker.py diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..9c2a36a --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,7 @@ +{ + "permissions": { + "allow": [ + "Bash(python -m pytest)" + ] + } +} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ed8ebf5 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..45bbccb --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,70 @@ +# llamacpp-ha — Agent instructions + +## Project overview + +`llamacpp-ha` is a slot-aware load balancer for llama.cpp servers. It exposes a single OpenAI-compatible HTTP API and distributes inference requests across multiple backends using a global request queue, per-backend slot tracking, and session affinity. The codebase is pure Python 3.13 async with FastAPI + aiohttp. + +## Commands + +```bash +# Install (editable, with test deps) +pip install -e ".[test]" + +# Run all tests +python -m pytest + +# Run a single test file +python -m pytest tests/test_scheduler.py -v + +# Start the proxy +llamacpp-ha --config config.json +``` + +## Architecture + +All source lives in `src/llamacpp_ha/`. Module responsibilities: + +| Module | Responsibility | +|---|---| +| `config.py` | `BackendConfig` + `ProxyConfig` (pydantic-settings, `LLAMACPP_HA_` env prefix) | +| `registry.py` | Polls `/health`, `/v1/models`, `/slots` on each backend; maintains live/dead state | +| `slot_tracker.py` | Per-backend `asyncio.Condition`; `acquire()` blocks until a slot is free | +| `queue.py` | FIFO `asyncio.Future` queue; `asyncio.Event` wakeup for the scheduler | +| `scheduler.py` | Drains the queue, resolves futures with a chosen `BackendState`; `notify_slot_released()` triggers re-drain | +| `policies.py` | `RoundRobinPolicy` — per-model atomic counter over `BackendCandidate` list | +| `session_store.py` | SHA-256 prefix hash → preferred backend; TTL eviction | +| `forwarder.py` | aiohttp outbound request; streaming SSE via `iter_chunked`; releases slot in `finally` | +| `middleware.py` | `ApiKeyMiddleware` — `Bearer` token validation; `/monitor` and `/monitor/data` are exempt | +| `monitor.py` | `ProxyStats` dataclass + `build_router()` for `/monitor` (HTML) and `/monitor/data` (JSON) | +| `proxy.py` | `create_app()` — wires all components; registers FastAPI routes; lifespan manages tasks | +| `__main__.py` | CLI entry point (`llamacpp-ha` script) | + +## Critical design invariants + +**Slot release must be awaited, not task-spawned.** `SlotTracker.release()` is `async` and must be `await`ed directly. Scheduling it as a task causes a race where the scheduler checks `has_free_slot()` before the release runs. This applies in `forwarder.py` (both the streaming `finally` block and the error path). + +**`QueueEntry.future` has no default.** The `future` field is `None` by default. Callers must create it explicitly: `entry.future = asyncio.get_running_loop().create_future()`. Never use `asyncio.get_event_loop()` — it is deprecated in async contexts. + +**`build_router()` creates a new `APIRouter` instance each call.** The router must not be a module-level singleton; doing so causes route handlers to share state across test cases. + +**`ProxyStats` is per-app.** Created in `create_app()` and passed into `build_router()`. Never use module-level counters or timestamps in `monitor.py`. + +**`scheduler.start()` is `def`, not `async def`.** It creates a task; it does not need to be awaited. + +**`asyncio.shield` in `_dispatch`.** The queue entry future is shielded so that a client timeout doesn't cancel the backend dispatch. On timeout, the entry is removed from the queue and the future is cancelled manually. + +## Testing + +- Tests are in `tests/` using `unittest` (not pytest-native style), with `IsolatedAsyncioTestCase` for async tests. +- `pytest.ini_options` sets `asyncio_mode = "auto"` so async test methods run without additional decorators. +- Integration tests (`test_integration.py`) use `with TestClient(app) as client:` — the context manager form is required to run the FastAPI lifespan (which starts the scheduler and opens the aiohttp session). +- Test helpers: `_entry(**kwargs)` creates a `QueueEntry` with a future attached via `asyncio.get_running_loop().create_future()`. +- All 113 tests must pass: `python -m pytest` should exit 0. + +## Style + +- Python 3.13+; use `from __future__ import annotations` in all modules. +- No comments unless the WHY is non-obvious (hidden constraint, asyncio race, workaround). +- No docstrings on internal helpers. +- Type annotations on all public functions and dataclass fields. +- `asyncio.get_running_loop()` everywhere; never `asyncio.get_event_loop()`. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..43c994c --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +@AGENTS.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..e04caab --- /dev/null +++ b/README.md @@ -0,0 +1,225 @@ +# llamacpp-ha + +Smart load balancer for [llama.cpp](https://github.com/ggerganov/llama.cpp) servers. Presents a single OpenAI-compatible API endpoint while distributing inference requests across multiple backends with slot-aware scheduling, session affinity, model-affinity reordering, and a live monitor page. + +## Features + +- **OpenAI-compatible API** — drop-in replacement for any client using `/v1/chat/completions`, `/v1/completions`, `/v1/embeddings`, `/v1/models`, etc. +- **Slot-aware scheduling** — tracks each backend's inference slot capacity (from `/slots`) and queues requests instead of sending them to overloaded backends +- **Preemption prevention** — optional `max_models` limit per backend prevents llama.cpp from evicting a running model's KV cache when a new model request arrives +- **Model-affinity reordering** — queued requests for an already-loaded model can be promoted ahead of waiting requests (configurable with starvation protection) +- **Session affinity** — routes follow-up turns in a conversation back to the same backend, improving KV-cache reuse +- **Round-robin policy** — distributes load evenly across backends per model +- **Backend health polling** — continuously polls `/health` and `/v1/models`; removes dead backends from rotation automatically; resets slot counters on recovery +- **Request queue** — FIFO queue with configurable timeout; returns 503 with a JSON error body when no slot becomes available in time +- **API key auth** — optional `Bearer` token validation; the proxy rewrites outbound auth to match each backend's own key +- **Monitor page** — self-contained HTML dashboard at `/monitor` (no CDN dependencies); auto-refreshes every 3 seconds +- **Catch-all proxy** — non-inference paths are forwarded best-effort to a live backend + +## Requirements + +- Python 3.13+ +- llama.cpp server(s) with `--slots` endpoint enabled + +## Installation + +```bash +pip install . +# or in editable mode for development: +pip install -e ".[test]" +``` + +## Quick start + +Copy the example config and edit it: + +```bash +cp config.json.example config.json +llamacpp-ha --config config.json +``` + +The proxy starts on `http://0.0.0.0:8080` by default. + +## Configuration + +Configuration is a JSON file. All fields also accept environment variable overrides with the `LLAMACPP_HA_` prefix (nested fields use `__` as delimiter). + +```json +{ + "host": "0.0.0.0", + "port": 8080, + "api_keys": ["your-secret-key"], + "poll_interval": 5, + "slot_wait_timeout": 30, + "session_idle_ttl": 300, + "default_max_models": 1, + "max_queue_skip": 4, + "backends": [ + { + "url": "http://localhost:8081", + "api_key": null, + "model_ids": [], + "max_models": 1 + }, + { + "url": "http://localhost:8082", + "api_key": "backend-secret", + "model_ids": ["llama3"] + } + ] +} +``` + +### Global fields + +| Field | Default | Description | +|---|---|---| +| `host` | `0.0.0.0` | Listen address | +| `port` | `8080` | Listen port | +| `api_keys` | `[]` | Accepted bearer tokens. Empty = no auth. | +| `poll_interval` | `5.0` | Seconds between backend health polls | +| `slot_wait_timeout` | `30.0` | Max seconds a request waits for a free slot | +| `session_idle_ttl` | `300.0` | Seconds before an idle session is evicted | +| `default_max_models` | `null` | Maximum concurrent models per backend (null = unlimited). Applied to backends that do not set their own `max_models`. | +| `max_queue_skip` | `0` | How many times a queued request may be bypassed by a model-affinity promotion before it is frozen at head-of-line. `0` disables reordering. | + +### Per-backend fields + +| Field | Default | Description | +|---|---|---| +| `url` | required | Backend base URL | +| `api_key` | `null` | Injected as `Authorization: Bearer ` on outbound requests; client key is stripped | +| `model_ids` | `[]` | Override the model list instead of polling `/v1/models` | +| `max_models` | `null` | Maximum concurrent distinct models on this backend. Overrides `default_max_models`. `null` = unlimited. | + +### Environment variable overrides + +```bash +LLAMACPP_HA_PORT=9090 llamacpp-ha --config config.json +LLAMACPP_HA_API_KEYS='["key1","key2"]' llamacpp-ha --config config.json +``` + +## Preemption prevention (`max_models`) + +llama.cpp evicts the current model's KV cache when a different model is loaded, which can interrupt an in-flight request. Setting `max_models: 1` on a backend tells the proxy to block requests for a second model until all slots for the first model are released. + +```json +{ + "default_max_models": 1, + "backends": [ + {"url": "http://localhost:8081"}, + {"url": "http://localhost:8082"} + ] +} +``` + +With this configuration each backend serves exactly one model at a time. When both backends are busy with different models, a third request waits in the queue until a slot is freed on a compatible backend. + +Set `default_max_models: 2` (or higher) to allow two models to share a backend's slots simultaneously when hardware permits. + +## Model-affinity reordering (`max_queue_skip`) + +By default the queue is strict FIFO. Setting `max_queue_skip` to a positive integer enables a two-phase dispatch: + +1. **Affinity pass** — scans the queue for requests whose model is already active on a free backend. Matching requests are promoted and dispatched immediately, bypassing earlier entries. +2. **FIFO pass** — remaining entries are dispatched in arrival order. + +Each time a request is bypassed, its `skip_count` is incremented. Once `skip_count` reaches `max_queue_skip` the entry is frozen at head-of-line and blocks the affinity pass — preventing indefinite starvation. + +```json +{ + "max_queue_skip": 4 +} +``` + +This trades strict fairness for throughput: a warm model serves back-to-back requests without stalling for a cold-start on another model, while the `max_queue_skip` cap guarantees every request is eventually served. + +## CLI reference + +``` +llamacpp-ha [--config PATH] [--host HOST] [--port PORT] [--log-level LEVEL] +``` + +`--config` defaults to `config.json` in the current directory. `--host` and `--port` override the values in the config file. + +## API endpoints + +| Method | Path | Description | +|---|---|---| +| `GET` | `/health` | Returns `{"status":"ok"}` if at least one backend is live, 503 otherwise | +| `GET` | `/v1/models` | Aggregated model list across all live backends | +| `POST` | `/v1/chat/completions` | Slot-gated, session-aware inference (streaming supported) | +| `POST` | `/v1/completions` | Same as above | +| `POST` | `/v1/embeddings` | Slot-gated pass-through | +| `POST` | `/v1/images/*`, `/v1/audio/*` | Slot-gated pass-through | +| `*` | `/*` | Best-effort forward to any live backend | +| `GET` | `/monitor` | HTML dashboard | +| `GET` | `/monitor/data` | Dashboard data as JSON (exempt from API key auth) | + +### Session affinity + +The proxy assigns a session ID to every request. It is sent back via both a cookie (`x-llm-session`) and a response header (`X-Session-ID`). Clients can echo either on subsequent requests to pin their conversation to the same backend. The affinity record expires after `session_idle_ttl` seconds of inactivity. + +### Streaming + +SSE streaming responses (`text/event-stream`) are passed through transparently. The backend slot is held for the duration of the stream and released when the final chunk is sent. + +## Monitor + +Open `http://localhost:8080/monitor` in a browser. The page polls `/monitor/data` every 3 seconds and shows: + +- Uptime, total requests served, queue depth, active sessions, live backend count +- Per-backend: URL, live/dead status, models, slot usage (`acquired/total`), time since last poll +- Current queue contents with wait time and estimated token count +- Active sessions grouped by model + +The monitor page and its data endpoint are exempt from API key authentication. + +## Architecture + +``` +client + │ + ▼ +ApiKeyMiddleware + │ + ▼ +FastAPI app (proxy.py) + ├── GET /v1/models ──► BackendRegistry.get_all_models() + ├── POST /v1/chat/... ──► RequestQueue ──► Scheduler ──► SlotTracker + │ │ + │ BackendState ◄─┘ + │ │ + │ forwarder.py ──► aiohttp ──► backend + ├── GET /health + ├── GET /monitor[/data] ──► monitor.py + └── /* catch-all ──► forwarder.forward_best_effort() + +BackendRegistry polls /health + /v1/models + /slots every poll_interval; + calls on_backend_recovered when a dead backend comes back live +SlotTracker asyncio.Condition per backend; acquire blocks until slot free; + enforces max_models (preemption prevention) +SessionStore SHA-256 prefix hash → preferred backend URL; TTL eviction +RequestQueue FIFO asyncio.Future queue; asyncio.Event for wakeup +Scheduler two-phase dispatch (affinity pass + FIFO); N-skip reordering +RoundRobinPolicy per-model atomic counter for backend selection +``` + +## Development + +```bash +# Install with test dependencies +pip install -e ".[test]" + +# Run all tests +python -m pytest + +# Run a specific module +python -m pytest tests/test_scheduler.py -v +``` + +Tests use `unittest.IsolatedAsyncioTestCase` for async tests and Starlette's `TestClient` as a context manager for integration tests that require the full lifespan (aiohttp session, scheduler, registry). + +## License + +MIT diff --git a/config.json.example b/config.json.example new file mode 100644 index 0000000..2d43709 --- /dev/null +++ b/config.json.example @@ -0,0 +1,20 @@ +{ + "host": "0.0.0.0", + "port": 8080, + "api_keys": ["your-secret-key"], + "poll_interval": 5, + "slot_wait_timeout": 30, + "session_idle_ttl": 300, + "backends": [ + { + "url": "http://localhost:8081", + "api_key": null, + "model_ids": [] + }, + { + "url": "http://localhost:8082", + "api_key": "backend-secret", + "model_ids": ["llama3"] + } + ] +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..cbab03e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,33 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "llamacpp-ha" +version = "0.1.0" +description = "Smart load balancer for llama.cpp servers" +requires-python = ">=3.13" +dependencies = [ + "fastapi>=0.115", + "uvicorn[standard]>=0.32", + "aiohttp>=3.11", + "pydantic-settings>=2.7", + "pydantic>=2.10", +] + +[project.optional-dependencies] +test = [ + "httpx>=0.28", + "pytest>=8", + "pytest-asyncio>=0.24", +] + +[project.scripts] +llamacpp-ha = "llamacpp_ha.__main__:main" + +[tool.hatch.build.targets.wheel] +packages = ["src/llamacpp_ha"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] diff --git a/src/llamacpp_ha/__init__.py b/src/llamacpp_ha/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/llamacpp_ha/__main__.py b/src/llamacpp_ha/__main__.py new file mode 100644 index 0000000..f35a622 --- /dev/null +++ b/src/llamacpp_ha/__main__.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import argparse +import logging +import sys + +import uvicorn + +from .config import ProxyConfig +from .proxy import create_app + + +def main() -> None: + parser = argparse.ArgumentParser(description="llamacpp-ha load balancer") + parser.add_argument( + "--config", + default="config.json", + help="Path to JSON config file (default: config.json)", + ) + parser.add_argument("--host", help="Override listen host") + parser.add_argument("--port", type=int, help="Override listen port") + parser.add_argument( + "--log-level", + default="info", + choices=["debug", "info", "warning", "error"], + ) + args = parser.parse_args() + + logging.basicConfig( + level=args.log_level.upper(), + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + overrides = {} + if args.host: + overrides["host"] = args.host + if args.port: + overrides["port"] = args.port + + try: + config = ProxyConfig.from_file(args.config, **overrides) + except FileNotFoundError: + print(f"Config file not found: {args.config}", file=sys.stderr) + sys.exit(1) + + app = create_app(config) + uvicorn.run(app, host=config.host, port=config.port, log_level=args.log_level) + + +if __name__ == "__main__": + main() diff --git a/src/llamacpp_ha/config.py b/src/llamacpp_ha/config.py new file mode 100644 index 0000000..195f16a --- /dev/null +++ b/src/llamacpp_ha/config.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import json +from typing import Any + +from pydantic import BaseModel, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class BackendConfig(BaseModel): + url: str + api_key: str | None = None + model_ids: list[str] = [] + # Maximum number of distinct models allowed in-flight simultaneously on this + # backend. Set to 1 for llama.cpp to prevent mid-request model preemption. + # None means the backend handles concurrency itself (no proxy-level limit). + max_models: int | None = None + + @field_validator("url") + @classmethod + def strip_trailing_slash(cls, v: str) -> str: + return v.rstrip("/") + + +class ProxyConfig(BaseSettings): + model_config = SettingsConfigDict( + env_prefix="LLAMACPP_HA_", + env_nested_delimiter="__", + ) + + host: str = "0.0.0.0" + port: int = 8080 + api_keys: list[str] = [] + poll_interval: float = 5.0 + slot_wait_timeout: float = 30.0 + session_idle_ttl: float = 300.0 + default_slot_capacity: int = 1 + # Fallback max_models applied to any backend that does not set its own. + # None = unlimited. Set to 1 globally when all backends are llama.cpp. + default_max_models: int | None = None + # How many queue positions a model-affinity request may skip ahead. + # 0 = pure FIFO (default). N > 0 enables reordering: the scheduler looks + # for a request matching an already-active model and can promote it up to N + # positions; each bypassed entry accumulates a skip count and is immune to + # further skipping once it reaches N. + max_queue_skip: int = 0 + backends: list[BackendConfig] = [] + + @classmethod + def from_file(cls, path: str, **overrides: Any) -> "ProxyConfig": + with open(path) as f: + data = json.load(f) + data.update(overrides) + return cls(**data) diff --git a/src/llamacpp_ha/forwarder.py b/src/llamacpp_ha/forwarder.py new file mode 100644 index 0000000..8f41c0e --- /dev/null +++ b/src/llamacpp_ha/forwarder.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import logging +from typing import AsyncIterator + +import aiohttp +from fastapi import Request +from fastapi.responses import Response, StreamingResponse +from starlette.datastructures import Headers + +from .registry import BackendState +from .scheduler import Scheduler +from .slot_tracker import SlotTracker + +log = logging.getLogger(__name__) + +_STREAM_CHUNK = 8192 +_HOP_BY_HOP = frozenset( + [ + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + "content-encoding", + "content-length", + ] +) + + +def _forward_headers(request: Request, backend: BackendState) -> dict[str, str]: + headers: dict[str, str] = {} + for name, value in request.headers.items(): + if name.lower() not in _HOP_BY_HOP and name.lower() != "host": + headers[name] = value + + if backend.config.api_key: + headers["Authorization"] = f"Bearer {backend.config.api_key}" + else: + headers.pop("Authorization", None) + headers.pop("authorization", None) + + return headers + + +async def forward_request( + request: Request, + backend: BackendState, + session: aiohttp.ClientSession, + slot_tracker: SlotTracker, + scheduler: Scheduler, + model_id: str = "", + path_override: str | None = None, +) -> Response: + path = path_override or request.url.path + query = request.url.query + target = backend.url + path + (f"?{query}" if query else "") + headers = _forward_headers(request, backend) + body = await request.body() + + try: + resp_ctx = session.request( + method=request.method, + url=target, + headers=headers, + data=body if body else None, + allow_redirects=False, + ) + + resp = await resp_ctx.__aenter__() + + content_type = resp.headers.get("Content-Type", "") + is_streaming = "text/event-stream" in content_type + + response_headers = { + k: v + for k, v in resp.headers.items() + if k.lower() not in _HOP_BY_HOP + } + + if is_streaming: + async def generate() -> AsyncIterator[bytes]: + try: + async for chunk in resp.content.iter_chunked(_STREAM_CHUNK): + yield chunk + finally: + await resp_ctx.__aexit__(None, None, None) + await slot_tracker.release(backend.url, model_id) + scheduler.notify_slot_released() + + return StreamingResponse( + generate(), + status_code=resp.status, + headers=response_headers, + media_type=content_type, + ) + else: + try: + data = await resp.read() + finally: + await resp_ctx.__aexit__(None, None, None) + await slot_tracker.release(backend.url, model_id) + scheduler.notify_slot_released() + return Response( + content=data, + status_code=resp.status, + headers=response_headers, + media_type=content_type, + ) + + except Exception as exc: + log.error("Forward error to %s: %s", backend.url, exc) + await slot_tracker.release(backend.url, model_id) + scheduler.notify_slot_released() + raise + + +async def forward_best_effort( + request: Request, + registry_backends: list[BackendState], + session: aiohttp.ClientSession, +) -> Response: + """Forward without slot gating to any live backend (catch-all paths).""" + if not registry_backends: + return Response(content="No live backends", status_code=503) + + backend = registry_backends[0] + path = request.url.path + query = request.url.query + target = backend.url + path + (f"?{query}" if query else "") + headers = _forward_headers(request, backend) + body = await request.body() + + try: + async with session.request( + method=request.method, + url=target, + headers=headers, + data=body if body else None, + allow_redirects=False, + ) as resp: + content_type = resp.headers.get("Content-Type", "") + response_headers = { + k: v + for k, v in resp.headers.items() + if k.lower() not in _HOP_BY_HOP + } + data = await resp.read() + return Response( + content=data, + status_code=resp.status, + headers=response_headers, + media_type=content_type, + ) + except Exception as exc: + log.error("Best-effort forward error: %s", exc) + return Response(content="Backend error", status_code=502) diff --git a/src/llamacpp_ha/middleware.py b/src/llamacpp_ha/middleware.py new file mode 100644 index 0000000..c136ce1 --- /dev/null +++ b/src/llamacpp_ha/middleware.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from fastapi import Request, Response +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp + +_EXEMPT_PATHS = frozenset(["/monitor", "/monitor/data"]) + + +class ApiKeyMiddleware(BaseHTTPMiddleware): + def __init__(self, app: ASGIApp, api_keys: list[str]) -> None: + super().__init__(app) + self._keys: frozenset[str] = frozenset(api_keys) + + async def dispatch(self, request: Request, call_next) -> Response: + if not self._keys: + return await call_next(request) + + if request.url.path in _EXEMPT_PATHS: + return await call_next(request) + + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer "): + return JSONResponse( + {"error": {"message": "Missing or invalid API key", "type": "auth_error"}}, + status_code=401, + ) + token = auth[len("Bearer "):] + if token not in self._keys: + return JSONResponse( + {"error": {"message": "Invalid API key", "type": "auth_error"}}, + status_code=401, + ) + + return await call_next(request) diff --git a/src/llamacpp_ha/monitor.py b/src/llamacpp_ha/monitor.py new file mode 100644 index 0000000..a7c9494 --- /dev/null +++ b/src/llamacpp_ha/monitor.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass, field + +from fastapi import APIRouter +from fastapi.responses import HTMLResponse, JSONResponse + +from .queue import RequestQueue +from .registry import BackendRegistry +from .session_store import SessionStore +from .slot_tracker import SlotTracker + + +@dataclass +class ProxyStats: + """Per-app counters and timing. Created by create_app, passed to build_router.""" + start_time: float = field(default_factory=time.monotonic) + total_requests: int = 0 + + def increment_requests(self) -> None: + self.total_requests += 1 + + def uptime_str(self) -> str: + secs = int(time.monotonic() - self.start_time) + h, remainder = divmod(secs, 3600) + m, s = divmod(remainder, 60) + return f"{h:02d}:{m:02d}:{s:02d}" + + +_HTML = """ + + + +llamacpp-ha Monitor + + + +

llamacpp-ha loading...

+
Smart Load Balancer for llama.cpp
+ +
+
-
Uptime
+
-
Requests Served
+
-
Queue Depth
+
-
Active Sessions
+
-
Live Backends
+
+ +

Backends

+ + + +
URLStatusModelsSlotsLast Poll
Loading...
+ +

Queue

+ + + +
Request IDModelSessionWait (s)Est. Tokens
Queue is empty
+ +

Sessions by Model

+ + + +
ModelActive Sessions
No active sessions
+ + + + +""" + + +def build_router( + registry: BackendRegistry, + slot_tracker: SlotTracker, + request_queue: RequestQueue, + session_store: SessionStore, + stats: ProxyStats, +) -> APIRouter: + router = APIRouter() + + @router.get("/monitor", response_class=HTMLResponse, include_in_schema=False) + async def monitor_page() -> HTMLResponse: + return HTMLResponse(content=_HTML) + + @router.get("/monitor/data", include_in_schema=False) + async def monitor_data() -> JSONResponse: + states = registry.get_all_states() + backends_data = [] + for state in states: + acquired, total = slot_tracker.usage(state.url) + age = state.last_poll_age + backends_data.append( + { + "url": state.url, + "live": state.live, + "models": list(state.models), + "slots_acquired": acquired, + "slots_total": total, + "last_poll_age": None if age == float("inf") else round(age, 1), + } + ) + + queue_snapshot = await request_queue.snapshot() + session_count = await session_store.count() + sessions_by_model = await session_store.count_by_model() + live_count = sum(1 for s in states if s.live) + + return JSONResponse( + { + "uptime": stats.uptime_str(), + "total_requests": stats.total_requests, + "queue_depth": len(queue_snapshot), + "session_count": session_count, + "live_backend_count": live_count, + "backends": backends_data, + "queue": queue_snapshot, + "sessions_by_model": sessions_by_model, + } + ) + + return router diff --git a/src/llamacpp_ha/policies.py b/src/llamacpp_ha/policies.py new file mode 100644 index 0000000..6b28019 --- /dev/null +++ b/src/llamacpp_ha/policies.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field + + +@dataclass +class BackendCandidate: + """Minimal view of a backend passed to routing policies.""" + url: str + index: int # position in the original backends list + + +class RoutingPolicy(ABC): + @abstractmethod + def select(self, model_id: str, candidates: list[BackendCandidate]) -> BackendCandidate: + """Select one backend from the candidate list for the given model.""" + + +class RoundRobinPolicy(RoutingPolicy): + def __init__(self) -> None: + self._counters: dict[str, int] = {} + + def select(self, model_id: str, candidates: list[BackendCandidate]) -> BackendCandidate: + if not candidates: + raise ValueError("No candidates available") + count = self._counters.get(model_id, 0) + chosen = candidates[count % len(candidates)] + self._counters[model_id] = count + 1 + return chosen diff --git a/src/llamacpp_ha/proxy.py b/src/llamacpp_ha/proxy.py new file mode 100644 index 0000000..09fdc5a --- /dev/null +++ b/src/llamacpp_ha/proxy.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +import asyncio +import json +import logging +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator + +import aiohttp +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response + +from .config import ProxyConfig +from .forwarder import forward_best_effort, forward_request +from .middleware import ApiKeyMiddleware +from .monitor import ProxyStats, build_router as build_monitor_router +from .policies import RoundRobinPolicy +from .queue import QueueEntry, RequestQueue +from .registry import BackendRegistry, BackendState +from .scheduler import Scheduler +from .session_store import SessionStore +from .slot_tracker import SlotTracker + +log = logging.getLogger(__name__) + +_APPLICATION_JSON = "application/json" +_SESSION_COOKIE = "x-llm-session" +_SESSION_HEADER = "X-Session-ID" +_SLOT_GATED_PATHS = [ + "/v1/chat/completions", + "/v1/completions", + "/v1/embeddings", + "/v1/images/generations", + "/v1/images/edits", + "/v1/images/variations", + "/v1/audio/speech", + "/v1/audio/transcriptions", + "/v1/audio/translations", +] + + +class _HttpSession: + """Mutable holder for the shared aiohttp session, created inside lifespan.""" + + __slots__ = ("client",) + + def __init__(self) -> None: + self.client: aiohttp.ClientSession | None = None + + +# ------------------------------------------------------------------ +# Pure helpers +# ------------------------------------------------------------------ + +def _estimate_tokens(body: dict) -> int | None: + msgs = body.get("messages", []) + if not msgs: + prompt = body.get("prompt", "") + if isinstance(prompt, str): + return max(1, len(prompt) // 4) + return None + total = 0 + for m in msgs: + content = m.get("content", "") + if isinstance(content, str): + total += max(1, len(content) // 4) + return total or None + + +def _get_model(body: dict) -> str: + return body.get("model", "") + + +def _session_id_from(request: Request) -> str | None: + return ( + request.cookies.get(_SESSION_COOKIE) + or request.headers.get(_SESSION_HEADER) + ) + + +def _attach_session(response: Response, session_id: str) -> None: + response.set_cookie( + _SESSION_COOKIE, session_id, httponly=True, samesite="lax", secure=True + ) + response.headers[_SESSION_HEADER] = session_id + + +def _init_slot_tracker( + registry: BackendRegistry, + slot_tracker: SlotTracker, + default_max_models: int | None, +) -> None: + for state in registry.get_all_states(): + slot_tracker.set_capacity(state.url, state.slot_capacity) + effective_max = ( + state.config.max_models + if state.config.max_models is not None + else default_max_models + ) + slot_tracker.set_max_models(state.url, effective_max) + + +# ------------------------------------------------------------------ +# Background tasks +# ------------------------------------------------------------------ + +async def _sync_capacities( + registry: BackendRegistry, slot_tracker: SlotTracker, interval: float +) -> None: + while True: + await asyncio.sleep(interval) + for state in registry.get_all_states(): + slot_tracker.set_capacity(state.url, state.slot_capacity) + + +async def _expire_sessions(session_store: SessionStore) -> None: + while True: + await asyncio.sleep(60) + await session_store.expire() + + +# ------------------------------------------------------------------ +# Route handlers (bound to app components via functools.partial) +# ------------------------------------------------------------------ + +async def _dispatch_entry( + request_queue: RequestQueue, + stats: ProxyStats, + config: ProxyConfig, + model_id: str, + session_id: str | None, + body: dict, +) -> BackendState | JSONResponse: + stats.increment_requests() + loop = asyncio.get_running_loop() + entry = QueueEntry( + model_id=model_id, + session_id=session_id, + estimated_tokens=_estimate_tokens(body), + future=loop.create_future(), + ) + await request_queue.enqueue(entry) + try: + backend: BackendState = await asyncio.wait_for( + asyncio.shield(entry.future), + timeout=config.slot_wait_timeout, + ) + except asyncio.TimeoutError: + await request_queue.remove(entry) + if not entry.future.done(): + entry.future.cancel() + return JSONResponse( + {"error": {"message": "No slot available (timeout)", "type": "overloaded"}}, + status_code=503, + ) + return backend + + +def _list_models(*, registry: BackendRegistry) -> JSONResponse: + models = registry.get_all_models() + data = [ + {"id": m, "object": "model", "created": 0, "owned_by": "llamacpp-ha"} + for m in models + ] + return JSONResponse({"object": "list", "data": data}) + + +def _health(*, registry: BackendRegistry) -> Response: + if registry.get_all_live_backends(): + return Response(content='{"status":"ok"}', media_type=_APPLICATION_JSON) + return Response( + content='{"status":"no live backends"}', + status_code=503, + media_type=_APPLICATION_JSON, + ) + + +async def _catch_all( + request: Request, + *, + http: _HttpSession, + registry: BackendRegistry, + stats: ProxyStats, +) -> Response: + if http.client is None: + return Response(content="Proxy not ready", status_code=503) + stats.increment_requests() + live = registry.get_all_live_backends() + return await forward_best_effort(request, live, http.client) + + +async def _inference_endpoint( + request: Request, + *, + http: _HttpSession, + slot_tracker: SlotTracker, + scheduler: Scheduler, + session_store: SessionStore, + request_queue: RequestQueue, + stats: ProxyStats, + config: ProxyConfig, +) -> Response: + if http.client is None: + return Response( + content='{"error":{"message":"Proxy not ready","type":"server_error"}}', + status_code=503, + media_type=_APPLICATION_JSON, + ) + + raw = await request.body() + try: + body: dict[str, Any] = json.loads(raw) if raw else {} + except json.JSONDecodeError: + body = {} + + model_id = _get_model(body) + session_id = _session_id_from(request) or session_store.new_session_id() + + result = await _dispatch_entry( + request_queue, stats, config, model_id, session_id, body + ) + if isinstance(result, Response): + return result + + backend = result + if model_id: + messages = body.get("messages", []) + await session_store.update( + session_id, + model_id=model_id, + messages=messages if messages else None, + preferred_backend=backend.url, + ) + + response = await forward_request( + request=request, + backend=backend, + session=http.client, + slot_tracker=slot_tracker, + scheduler=scheduler, + model_id=model_id, + ) + _attach_session(response, session_id) + return response + + +# ------------------------------------------------------------------ +# App factory +# ------------------------------------------------------------------ + +def create_app(config: ProxyConfig) -> FastAPI: + slot_tracker = SlotTracker() + session_store = SessionStore(ttl=config.session_idle_ttl) + request_queue = RequestQueue() + stats = ProxyStats() + http = _HttpSession() + + async def _on_backend_recovered(url: str) -> None: + await slot_tracker.reset_acquired(url) + scheduler.notify_slot_released() + + registry = BackendRegistry(config, on_backend_recovered=_on_backend_recovered) + scheduler = Scheduler( + queue=request_queue, + registry=registry, + slot_tracker=slot_tracker, + session_store=session_store, + policy=RoundRobinPolicy(), + max_queue_skip=config.max_queue_skip, + ) + + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + http.client = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=300), + connector=aiohttp.TCPConnector(ssl=False), + ) + _init_slot_tracker(registry, slot_tracker, config.default_max_models) + registry.start() + scheduler.start() + cap_task = asyncio.create_task( + _sync_capacities(registry, slot_tracker, config.poll_interval), + name="capacity-sync", + ) + expire_task = asyncio.create_task( + _expire_sessions(session_store), name="session-expiry" + ) + yield + cap_task.cancel() + expire_task.cancel() + await scheduler.stop() + await registry.stop() + if http.client: + await http.client.close() + + app = FastAPI(title="llamacpp-ha", lifespan=lifespan) + app.add_middleware(ApiKeyMiddleware, api_keys=config.api_keys) + app.include_router( + build_monitor_router( + registry=registry, + slot_tracker=slot_tracker, + request_queue=request_queue, + session_store=session_store, + stats=stats, + ) + ) + + def list_models_handler() -> JSONResponse: + return _list_models(registry=registry) + + def health_handler() -> Response: + return _health(registry=registry) + + async def inference_handler(request: Request) -> Response: + return await _inference_endpoint( + request, + http=http, + slot_tracker=slot_tracker, + scheduler=scheduler, + session_store=session_store, + request_queue=request_queue, + stats=stats, + config=config, + ) + + async def catch_all_handler(request: Request, full_path: str) -> Response: # noqa: ARG001 + return await _catch_all(request, http=http, registry=registry, stats=stats) + + app.add_api_route("/v1/models", list_models_handler, methods=["GET"]) + app.add_api_route("/health", health_handler, methods=["GET"]) + + for path in _SLOT_GATED_PATHS: + app.add_api_route(path, inference_handler, methods=["POST"]) + + app.add_api_route( + "/{full_path:path}", + catch_all_handler, + methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"], + include_in_schema=False, + ) + + return app diff --git a/src/llamacpp_ha/queue.py b/src/llamacpp_ha/queue.py new file mode 100644 index 0000000..6e0af4d --- /dev/null +++ b/src/llamacpp_ha/queue.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import asyncio +import time +import uuid +from dataclasses import dataclass, field + + +@dataclass +class QueueEntry: + request_id: str = field(default_factory=lambda: uuid.uuid4().hex) + model_id: str = "" + session_id: str | None = None + arrival_time: float = field(default_factory=time.monotonic) + estimated_tokens: int | None = None + # Created explicitly by the enqueuing coroutine via asyncio.get_running_loop() + future: asyncio.Future | None = field(default=None) + # Populated by scheduler at dispatch time + assigned_backend: str | None = None + # Incremented each time a later entry is dispatched ahead of this one via + # model-affinity reordering. When it reaches max_queue_skip the entry + # becomes immune to further skipping (starvation prevention). + skip_count: int = 0 + + @property + def wait_seconds(self) -> float: + return time.monotonic() - self.arrival_time + + +class RequestQueue: + """Global FIFO queue for pending inference requests.""" + + def __init__(self) -> None: + self._entries: list[QueueEntry] = [] + self._lock = asyncio.Lock() + self._wakeup = asyncio.Event() + + @property + def wakeup_event(self) -> asyncio.Event: + return self._wakeup + + async def enqueue(self, entry: QueueEntry) -> None: + async with self._lock: + self._entries.append(entry) + self._wakeup.set() + + async def remove(self, entry: QueueEntry) -> None: + async with self._lock: + try: + self._entries.remove(entry) + except ValueError: + pass + + async def pending(self) -> list[QueueEntry]: + async with self._lock: + return list(self._entries) + + async def depth(self) -> int: + async with self._lock: + return len(self._entries) + + async def snapshot(self) -> list[dict]: + async with self._lock: + return [ + { + "request_id": e.request_id, + "model_id": e.model_id, + "session_id": (e.session_id or "")[:8] or None, + "wait_seconds": round(e.wait_seconds, 2), + "estimated_tokens": e.estimated_tokens, + "skip_count": e.skip_count, + } + for e in self._entries + ] + + def notify(self) -> None: + """Wake the scheduler without holding the lock.""" + self._wakeup.set() diff --git a/src/llamacpp_ha/registry.py b/src/llamacpp_ha/registry.py new file mode 100644 index 0000000..b09814f --- /dev/null +++ b/src/llamacpp_ha/registry.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import asyncio +import logging +import time +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field + +import aiohttp + +from .config import BackendConfig, ProxyConfig + +log = logging.getLogger(__name__) + +_HEALTH_PATH = "/health" +_MODELS_PATH = "/v1/models" +_SLOTS_PATH = "/slots" + + +@dataclass +class BackendState: + config: BackendConfig + live: bool = False + models: list[str] = field(default_factory=list) + slot_capacity: int = 1 + last_poll_time: float | None = None + + @property + def url(self) -> str: + return self.config.url + + @property + def last_poll_age(self) -> float: + if self.last_poll_time is None: + return float("inf") + return time.monotonic() - self.last_poll_time + + +class BackendRegistry: + def __init__( + self, + config: ProxyConfig, + on_backend_recovered: Callable[[str], Awaitable[None]] | None = None, + ) -> None: + self._config = config + self._states: dict[str, BackendState] = { + b.url: BackendState(config=b, slot_capacity=config.default_slot_capacity) + for b in config.backends + } + self._model_index: dict[str, list[BackendState]] = {} + self._lock = asyncio.Lock() + self._session: aiohttp.ClientSession | None = None + self._task: asyncio.Task | None = None + self._on_backend_recovered = on_backend_recovered + + # ------------------------------------------------------------------ + # Public read API (non-blocking, returns snapshots) + # ------------------------------------------------------------------ + + def get_backends_for_model(self, model_id: str) -> list[BackendState]: + return list(self._model_index.get(model_id, [])) + + def get_all_live_backends(self) -> list[BackendState]: + return [s for s in self._states.values() if s.live] + + def get_all_models(self) -> list[str]: + seen: set[str] = set() + result: list[str] = [] + for state in self._states.values(): + if state.live: + for m in state.models: + if m not in seen: + seen.add(m) + result.append(m) + return result + + def get_all_states(self) -> list[BackendState]: + return list(self._states.values()) + + def get_state(self, url: str) -> BackendState | None: + return self._states.get(url) + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def start(self) -> None: + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=10), + connector=aiohttp.TCPConnector(ssl=False), + ) + self._task = asyncio.create_task(self._poll_loop(), name="registry-poll") + + async def stop(self) -> None: + if self._task: + self._task.cancel() + await asyncio.gather(self._task, return_exceptions=True) + if self._session: + await self._session.close() + + # ------------------------------------------------------------------ + # Polling + # ------------------------------------------------------------------ + + async def _poll_loop(self) -> None: + while True: + await self._poll_all() + await asyncio.sleep(self._config.poll_interval) + + async def _poll_all(self) -> None: + states = list(self._states.values()) + results = await asyncio.gather( + *[self._poll_one(s) for s in states], return_exceptions=True + ) + for state, result in zip(states, results): + if isinstance(result, Exception): + log.warning("Poll error for %s: %s", state.url, result) + + async with self._lock: + self._rebuild_index() + + async def _poll_one(self, state: BackendState) -> None: + assert self._session is not None + url = state.url + headers = {} + if state.config.api_key: + headers["Authorization"] = f"Bearer {state.config.api_key}" + + try: + async with self._session.get(url + _HEALTH_PATH, headers=headers) as resp: + live = resp.status == 200 + except Exception: + live = False + + if live: + models = await self._fetch_models(state, headers) + capacity = await self._fetch_slot_capacity(state, headers) + else: + models = state.models + capacity = state.slot_capacity + + was_live = state.live + async with self._lock: + state.live = live + if live: + state.models = models + state.slot_capacity = capacity + state.last_poll_time = time.monotonic() + + if live and not was_live and self._on_backend_recovered: + await self._on_backend_recovered(url) + + async def _fetch_models(self, state: BackendState, headers: dict) -> list[str]: + assert self._session is not None + if state.config.model_ids: + return list(state.config.model_ids) + try: + async with self._session.get( + state.url + _MODELS_PATH, headers=headers + ) as resp: + if resp.status != 200: + return state.models + data = await resp.json() + return [m["id"] for m in data.get("data", [])] + except Exception as exc: + log.debug("Failed to fetch models from %s: %s", state.url, exc) + return state.models + + async def _fetch_slot_capacity(self, state: BackendState, headers: dict) -> int: + assert self._session is not None + try: + async with self._session.get( + state.url + _SLOTS_PATH, headers=headers + ) as resp: + if resp.status != 200: + return state.slot_capacity + data = await resp.json() + if isinstance(data, list): + return max(len(data), 1) + except Exception: + pass + return state.slot_capacity + + def _rebuild_index(self) -> None: + index: dict[str, list[BackendState]] = {} + for state in self._states.values(): + if not state.live: + continue + for model in state.models: + index.setdefault(model, []).append(state) + self._model_index = index diff --git a/src/llamacpp_ha/scheduler.py b/src/llamacpp_ha/scheduler.py new file mode 100644 index 0000000..14a9756 --- /dev/null +++ b/src/llamacpp_ha/scheduler.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import asyncio +import logging + +from .policies import BackendCandidate, RoutingPolicy +from .queue import QueueEntry, RequestQueue +from .registry import BackendRegistry, BackendState +from .session_store import SessionStore +from .slot_tracker import SlotTracker + +log = logging.getLogger(__name__) + + +class Scheduler: + """ + Continuously binds queued requests to backends. + + Dispatch order + -------------- + When max_queue_skip == 0 (default) the queue is pure FIFO. + + When max_queue_skip > N the scheduler runs a two-phase dispatch on every + wakeup: + + Phase 1 — model-affinity promotion + The scheduler scans the queue looking for entries whose model is already + in-flight on a free backend. It promotes those entries ahead of earlier + entries that would require a model switch. For every entry that is + bypassed, its skip_count is incremented. Scanning stops as soon as an + entry with skip_count >= max_queue_skip is encountered (that entry is + frozen at the head and must be served next, preventing starvation). + + Phase 2 — standard FIFO + Remaining entries are dispatched in arrival order to any backend that + can accept the model (slot free, max_models constraint satisfied). + + Session affinity + ---------------- + Session affinity is applied inside _resolve_backend as a hint: if the + preferred backend is in the candidate set it is chosen first. Affinity + is re-pinned to whichever backend ultimately serves the request; it is + never queued on a preferred backend when a free alternative exists. + + Preemption prevention + --------------------- + SlotTracker.can_accept() enforces the per-backend max_models limit. A + backend with max_models=1 that is serving model-A will reject model-B + requests until all model-A slots are released, preventing llama.cpp from + preempting the in-flight generation. + """ + + def __init__( + self, + queue: RequestQueue, + registry: BackendRegistry, + slot_tracker: SlotTracker, + session_store: SessionStore, + policy: RoutingPolicy, + max_queue_skip: int = 0, + ) -> None: + self._queue = queue + self._registry = registry + self._slots = slot_tracker + self._sessions = session_store + self._policy = policy + self._max_queue_skip = max_queue_skip + self._task: asyncio.Task | None = None + + def notify_slot_released(self) -> None: + """Called by the forwarder after a slot is freed; wakes the dispatch loop.""" + self._queue.notify() + + def start(self) -> None: + self._task = asyncio.create_task(self._loop(), name="scheduler") + + async def stop(self) -> None: + if self._task: + self._task.cancel() + await asyncio.gather(self._task, return_exceptions=True) + + async def _loop(self) -> None: + while True: + await self._queue.wakeup_event.wait() + self._queue.wakeup_event.clear() + await self._dispatch_all() + + # ------------------------------------------------------------------ + # Dispatch + # ------------------------------------------------------------------ + + async def _dispatch_all(self) -> None: + entries = await self._queue.pending() + + # Prune stale futures first so they don't count in skip logic. + for entry in entries: + if entry.future is None or entry.future.done(): + await self._queue.remove(entry) + entries = [e for e in entries if e.future is not None and not e.future.done()] + + if not entries: + return + + dispatched: set[int] = set() + + if self._max_queue_skip > 0: + await self._affinity_pass(entries, dispatched) + + # Standard FIFO pass for all remaining live entries. + for entry in entries: + if id(entry) in dispatched: + continue + if entry.future is None or entry.future.done(): + await self._queue.remove(entry) + continue + if await self._try_dispatch(entry): + dispatched.add(id(entry)) + await self._queue.remove(entry) + + async def _affinity_pass( + self, entries: list[QueueEntry], dispatched: set[int] + ) -> None: + """Phase 1: promote entries whose model is already active on a free backend. + + Stops scanning as soon as an entry with skip_count >= max_queue_skip is + encountered — that entry is frozen and must be served by the FIFO pass. + """ + for i, entry in enumerate(entries): + if id(entry) in dispatched: + continue + if entry.skip_count >= self._max_queue_skip: + break # frozen head-of-line; FIFO pass must handle it next + + if await self._try_dispatch_affinity(entry): + dispatched.add(id(entry)) + await self._queue.remove(entry) + # Bump skip_count for every earlier entry we bypassed. + for j in range(i): + if id(entries[j]) not in dispatched: + entries[j].skip_count += 1 + + async def _try_dispatch_affinity(self, entry: QueueEntry) -> bool: + """Dispatch only to a backend that already has entry.model_id in-flight.""" + if not entry.model_id: + return False + live_backends = self._registry.get_backends_for_model(entry.model_id) + active_backends = [ + b for b in live_backends + if entry.model_id in self._slots.active_model_set(b.url) + and self._slots.can_accept(b.url, entry.model_id) + ] + if not active_backends: + return False + return await self._acquire_and_resolve(entry, active_backends) + + async def _try_dispatch(self, entry: QueueEntry) -> bool: + """Standard dispatch: any live backend that can accept this model.""" + if entry.model_id: + live_backends = self._registry.get_backends_for_model(entry.model_id) + else: + live_backends = self._registry.get_all_live_backends() + + if not live_backends: + return False + + free_backends = [ + b for b in live_backends if self._slots.can_accept(b.url, entry.model_id) + ] + if not free_backends: + return False + + return await self._acquire_and_resolve(entry, free_backends) + + async def _acquire_and_resolve( + self, entry: QueueEntry, candidates: list[BackendState] + ) -> bool: + chosen = await self._resolve_backend(entry, candidates) + + # has_free_slot + acquire are not atomic: another coroutine could have + # taken the slot between the check above and here. timeout=0 means + # "succeed now or not at all" so we simply leave the entry queued. + try: + async with asyncio.timeout(0): + await self._slots.acquire(chosen.url, entry.model_id) + except TimeoutError: + return False + + if entry.future is not None and not entry.future.done(): + entry.assigned_backend = chosen.url + entry.future.set_result(chosen) + log.debug("Dispatched %s -> %s", entry.request_id, chosen.url) + return True + + # Future was cancelled between our acquire and here — release the slot back. + await self._slots.release(chosen.url, entry.model_id) + return False + + async def _resolve_backend( + self, entry: QueueEntry, candidates: list[BackendState] + ) -> BackendState: + """Apply session affinity then routing policy to pick one backend.""" + if entry.session_id: + preferred_url = await self._sessions.get_preferred_backend(entry.session_id) + if preferred_url: + for b in candidates: + if b.url == preferred_url: + return b + backend_candidates = [ + BackendCandidate(url=b.url, index=i) for i, b in enumerate(candidates) + ] + chosen = self._policy.select(entry.model_id or "_any", backend_candidates) + return next(b for b in candidates if b.url == chosen.url) diff --git a/src/llamacpp_ha/session_store.py b/src/llamacpp_ha/session_store.py new file mode 100644 index 0000000..40db6b3 --- /dev/null +++ b/src/llamacpp_ha/session_store.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import asyncio +import hashlib +import json +import time +import uuid +from dataclasses import dataclass, field + + +@dataclass +class Session: + session_id: str + model_id: str | None = None + last_message_index: int = 0 + prefix_hash: str = "" + preferred_backend: str | None = None + last_active: float = field(default_factory=time.monotonic) + + def touch(self) -> None: + self.last_active = time.monotonic() + + def is_expired(self, ttl: float) -> bool: + return (time.monotonic() - self.last_active) > ttl + + +def compute_prefix_hash(messages: list[dict]) -> str: + """Hash the messages list for KV-cache affinity tracking.""" + blob = json.dumps(messages, sort_keys=True, separators=(",", ":")).encode() + return hashlib.sha256(blob).hexdigest()[:16] + + +class SessionStore: + def __init__(self, ttl: float = 300.0) -> None: + self._ttl = ttl + self._sessions: dict[str, Session] = {} + self._lock = asyncio.Lock() + + def new_session_id(self) -> str: + return uuid.uuid4().hex + + async def get_or_create(self, session_id: str) -> Session: + async with self._lock: + if session_id not in self._sessions: + self._sessions[session_id] = Session(session_id=session_id) + session = self._sessions[session_id] + session.touch() + return session + + async def update( + self, + session_id: str, + model_id: str | None = None, + messages: list[dict] | None = None, + preferred_backend: str | None = None, + ) -> None: + async with self._lock: + if session_id not in self._sessions: + self._sessions[session_id] = Session(session_id=session_id) + session = self._sessions[session_id] + if model_id is not None: + session.model_id = model_id + if messages is not None: + session.last_message_index = len(messages) + session.prefix_hash = compute_prefix_hash(messages) + if preferred_backend is not None: + session.preferred_backend = preferred_backend + session.touch() + + async def get_preferred_backend(self, session_id: str) -> str | None: + async with self._lock: + session = self._sessions.get(session_id) + if session is None or session.is_expired(self._ttl): + return None + return session.preferred_backend + + async def expire(self) -> int: + """Remove expired sessions. Returns count removed.""" + async with self._lock: + expired = [ + sid + for sid, s in self._sessions.items() + if s.is_expired(self._ttl) + ] + for sid in expired: + del self._sessions[sid] + return len(expired) + + async def count(self) -> int: + async with self._lock: + return sum( + 1 for s in self._sessions.values() if not s.is_expired(self._ttl) + ) + + async def count_by_model(self) -> dict[str, int]: + async with self._lock: + result: dict[str, int] = {} + for s in self._sessions.values(): + if not s.is_expired(self._ttl) and s.model_id: + result[s.model_id] = result.get(s.model_id, 0) + 1 + return result + + async def snapshot(self) -> list[Session]: + async with self._lock: + return [ + s for s in self._sessions.values() if not s.is_expired(self._ttl) + ] diff --git a/src/llamacpp_ha/slot_tracker.py b/src/llamacpp_ha/slot_tracker.py new file mode 100644 index 0000000..71ebfe4 --- /dev/null +++ b/src/llamacpp_ha/slot_tracker.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field + + +@dataclass +class _SlotState: + capacity: int + max_models: int | None = None + acquired: int = 0 + active_models: dict[str, int] = field(default_factory=dict) + condition: asyncio.Condition = field(default_factory=asyncio.Condition) + + +class SlotTracker: + """Tracks per-backend slot usage with model-aware acquire/release. + + Callers control timeouts via asyncio.timeout() at the call site rather than + passing a timeout parameter here — acquire() blocks until a slot is free. + """ + + def __init__(self) -> None: + self._slots: dict[str, _SlotState] = {} + + def _ensure(self, url: str, capacity: int = 1) -> _SlotState: + if url not in self._slots: + self._slots[url] = _SlotState(capacity=capacity) + return self._slots[url] + + def set_capacity(self, url: str, capacity: int) -> None: + state = self._ensure(url, capacity) + state.capacity = max(capacity, 1) + + def set_max_models(self, url: str, max_models: int | None) -> None: + state = self._ensure(url) + state.max_models = max_models + + # ------------------------------------------------------------------ + # Read helpers (non-blocking) + # ------------------------------------------------------------------ + + def _can_acquire(self, state: _SlotState, model_id: str) -> bool: + if state.acquired >= state.capacity: + return False + if model_id and model_id not in state.active_models: + if state.max_models is not None and len(state.active_models) >= state.max_models: + return False + return True + + def has_free_slot(self, url: str) -> bool: + state = self._slots.get(url) + if state is None: + return True + return state.acquired < state.capacity + + def can_accept(self, url: str, model_id: str) -> bool: + """True if this backend can accept a new request for model_id right now.""" + state = self._slots.get(url) + if state is None: + return True + return self._can_acquire(state, model_id) + + def active_model_set(self, url: str) -> frozenset[str]: + """Models currently in-flight on this backend.""" + state = self._slots.get(url) + if state is None: + return frozenset() + return frozenset(state.active_models) + + def usage(self, url: str) -> tuple[int, int]: + state = self._slots.get(url) + if state is None: + return 0, 1 + return state.acquired, state.capacity + + # ------------------------------------------------------------------ + # Acquire / release + # ------------------------------------------------------------------ + + async def acquire(self, url: str, model_id: str = "") -> None: + """Acquire a slot for model_id. Blocks until one is available. + + Use asyncio.timeout() at the call site to bound the wait: + try: + async with asyncio.timeout(5.0): + await tracker.acquire(url, model_id) + except TimeoutError: + ... + """ + state = self._ensure(url) + async with state.condition: + while not self._can_acquire(state, model_id): + await state.condition.wait() + state.acquired += 1 + if model_id: + state.active_models[model_id] = state.active_models.get(model_id, 0) + 1 + + async def release(self, url: str, model_id: str = "") -> None: + """Release a slot and notify waiters.""" + state = self._slots.get(url) + if state is None: + return + async with state.condition: + state.acquired = max(0, state.acquired - 1) + if model_id and model_id in state.active_models: + count = state.active_models[model_id] - 1 + if count <= 0: + del state.active_models[model_id] + else: + state.active_models[model_id] = count + state.condition.notify_all() + + async def reset_acquired(self, url: str) -> None: + """Zero out slot tracking for a backend that just recovered from a crash.""" + state = self._slots.get(url) + if state is None: + return + async with state.condition: + state.acquired = 0 + state.active_models.clear() + state.condition.notify_all() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..48656ef --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,130 @@ +import json +import os +import tempfile +import unittest + +from llamacpp_ha.config import BackendConfig, ProxyConfig + + +class TestBackendConfig(unittest.TestCase): + def test_url_trailing_slash_stripped(self): + b = BackendConfig(url="http://localhost:8080/") + self.assertEqual(b.url, "http://localhost:8080") + + def test_defaults(self): + b = BackendConfig(url="http://localhost:8080") + self.assertIsNone(b.api_key) + self.assertEqual(b.model_ids, []) + self.assertIsNone(b.max_models) + + def test_explicit_model_ids(self): + b = BackendConfig(url="http://x", model_ids=["llama3", "mistral"]) + self.assertEqual(b.model_ids, ["llama3", "mistral"]) + + def test_max_models(self): + b = BackendConfig(url="http://x", max_models=1) + self.assertEqual(b.max_models, 1) + + def test_max_models_none_is_unlimited(self): + b = BackendConfig(url="http://x", max_models=None) + self.assertIsNone(b.max_models) + + +class TestProxyConfig(unittest.TestCase): + def _write_config(self, data: dict) -> str: + f = tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) + json.dump(data, f) + f.close() + return f.name + + def test_defaults(self): + cfg = ProxyConfig() + self.assertEqual(cfg.host, "0.0.0.0") + self.assertEqual(cfg.port, 8080) + self.assertEqual(cfg.api_keys, []) + self.assertEqual(cfg.poll_interval, 5.0) + self.assertEqual(cfg.slot_wait_timeout, 30.0) + self.assertEqual(cfg.session_idle_ttl, 300.0) + self.assertEqual(cfg.backends, []) + self.assertIsNone(cfg.default_max_models) + self.assertEqual(cfg.max_queue_skip, 0) + + def test_from_file_minimal(self): + path = self._write_config( + {"backends": [{"url": "http://localhost:8081"}]} + ) + try: + cfg = ProxyConfig.from_file(path) + self.assertEqual(len(cfg.backends), 1) + self.assertEqual(cfg.backends[0].url, "http://localhost:8081") + finally: + os.unlink(path) + + def test_from_file_full(self): + data = { + "host": "127.0.0.1", + "port": 9090, + "api_keys": ["key1", "key2"], + "poll_interval": 10, + "slot_wait_timeout": 60, + "session_idle_ttl": 600, + "default_max_models": 2, + "max_queue_skip": 5, + "backends": [ + {"url": "http://b1", "api_key": "secret", "model_ids": ["m1"], "max_models": 1}, + {"url": "http://b2/"}, + ], + } + path = self._write_config(data) + try: + cfg = ProxyConfig.from_file(path) + self.assertEqual(cfg.host, "127.0.0.1") + self.assertEqual(cfg.port, 9090) + self.assertEqual(cfg.api_keys, ["key1", "key2"]) + self.assertEqual(cfg.poll_interval, 10) + self.assertEqual(cfg.default_max_models, 2) + self.assertEqual(cfg.max_queue_skip, 5) + self.assertEqual(cfg.backends[0].api_key, "secret") + self.assertEqual(cfg.backends[0].model_ids, ["m1"]) + self.assertEqual(cfg.backends[0].max_models, 1) + self.assertEqual(cfg.backends[1].url, "http://b2") + finally: + os.unlink(path) + + def test_from_file_with_overrides(self): + path = self._write_config({"port": 8080, "backends": []}) + try: + cfg = ProxyConfig.from_file(path, port=9999) + self.assertEqual(cfg.port, 9999) + finally: + os.unlink(path) + + def test_env_var_override(self): + old = os.environ.get("LLAMACPP_HA_PORT") + try: + os.environ["LLAMACPP_HA_PORT"] = "7777" + cfg = ProxyConfig() + self.assertEqual(cfg.port, 7777) + finally: + if old is None: + os.environ.pop("LLAMACPP_HA_PORT", None) + else: + os.environ["LLAMACPP_HA_PORT"] = old + + def test_invalid_json_raises(self): + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as f: + f.write("not json {{{") + path = f.name + try: + with self.assertRaises(Exception): + ProxyConfig.from_file(path) + finally: + os.unlink(path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py new file mode 100644 index 0000000..fae5b37 --- /dev/null +++ b/tests/test_forwarder.py @@ -0,0 +1,276 @@ +import asyncio +import unittest +from unittest.mock import AsyncMock, MagicMock + +from fastapi import Request +from starlette.datastructures import Headers + +from llamacpp_ha.config import BackendConfig +from llamacpp_ha.forwarder import _forward_headers, forward_request +from llamacpp_ha.registry import BackendState +from llamacpp_ha.slot_tracker import SlotTracker + + +def _make_state(url="http://b1", api_key=None): + cfg = BackendConfig(url=url, api_key=api_key) + return BackendState(config=cfg) + + +def _make_request(headers: dict, method="POST", path="/v1/chat/completions") -> MagicMock: + req = MagicMock() + req.headers = Headers(headers=headers) + req.url.path = path + req.url.query = "" + req.method = method + req.body = AsyncMock(return_value=b"") + return req + + +def _make_scheduler(): + sched = MagicMock() + sched.notify_slot_released = MagicMock() + return sched + + +class TestForwardHeaders(unittest.TestCase): + def test_injects_backend_api_key(self): + state = _make_state(api_key="backend-secret") + req = _make_request({"Authorization": "Bearer client-key", "Content-Type": "application/json"}) + headers = _forward_headers(req, state) + self.assertEqual(headers["Authorization"], "Bearer backend-secret") + combined = {k.lower(): v for k, v in headers.items()} + self.assertIn("application/json", combined.get("content-type", "")) + + def test_removes_auth_when_no_backend_key(self): + state = _make_state(api_key=None) + req = _make_request({"Authorization": "Bearer client-key"}) + headers = _forward_headers(req, state) + combined = {k.lower(): v for k, v in headers.items()} + self.assertNotIn("authorization", combined) + + def test_hop_by_hop_stripped(self): + state = _make_state() + req = _make_request({ + "Connection": "keep-alive", + "Transfer-Encoding": "chunked", + "X-Custom": "value", + }) + headers = _forward_headers(req, state) + combined = {k.lower(): v for k, v in headers.items()} + self.assertNotIn("connection", combined) + self.assertNotIn("transfer-encoding", combined) + self.assertEqual(combined.get("x-custom"), "value") + + def test_host_header_stripped(self): + state = _make_state() + req = _make_request({"Host": "proxy.local", "Accept": "application/json"}) + headers = _forward_headers(req, state) + combined = {k.lower(): v for k, v in headers.items()} + self.assertNotIn("host", combined) + self.assertIn("accept", combined) + + +def _mock_aiohttp_response( + status: int = 200, + content_type: str = "application/json", + body: bytes = b'{"ok":true}', + headers: dict | None = None, +) -> MagicMock: + resp = MagicMock() + resp.status = status + all_headers = {"Content-Type": content_type} + if headers: + all_headers.update(headers) + resp.headers = all_headers + resp.read = AsyncMock(return_value=body) + resp.content = MagicMock() + resp.content.iter_chunked = MagicMock() + + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=resp) + ctx.__aexit__ = AsyncMock(return_value=False) + return ctx, resp + + +class TestForwardRequestNonStreaming(unittest.IsolatedAsyncioTestCase): + async def _acquire(self, tracker: SlotTracker, url: str) -> None: + async with asyncio.timeout(1.0): + await tracker.acquire(url) + + async def test_successful_passthrough(self): + state = _make_state("http://b1") + slot_tracker = SlotTracker() + slot_tracker.set_capacity("http://b1", 1) + await self._acquire(slot_tracker, "http://b1") + scheduler = _make_scheduler() + + ctx, _ = _mock_aiohttp_response(status=200, body=b'{"choices":[]}') + session = MagicMock() + session.request = MagicMock(return_value=ctx) + + req = _make_request({}) + response = await forward_request(req, state, session, slot_tracker, scheduler) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.body, b'{"choices":[]}') + acquired, _ = slot_tracker.usage("http://b1") + self.assertEqual(acquired, 0) + scheduler.notify_slot_released.assert_called_once() + + async def test_slot_released_on_forward_error(self): + state = _make_state("http://b1") + slot_tracker = SlotTracker() + slot_tracker.set_capacity("http://b1", 1) + await self._acquire(slot_tracker, "http://b1") + scheduler = _make_scheduler() + + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(side_effect=Exception("network error")) + ctx.__aexit__ = AsyncMock(return_value=False) + session = MagicMock() + session.request = MagicMock(return_value=ctx) + + req = _make_request({}) + with self.assertRaises(Exception): + await forward_request(req, state, session, slot_tracker, scheduler) + + acquired, _ = slot_tracker.usage("http://b1") + self.assertEqual(acquired, 0) + scheduler.notify_slot_released.assert_called_once() + + async def test_non_streaming_returns_full_body(self): + state = _make_state("http://b1") + slot_tracker = SlotTracker() + slot_tracker.set_capacity("http://b1", 1) + await self._acquire(slot_tracker, "http://b1") + scheduler = _make_scheduler() + + body = b'{"id":"xyz","choices":[{"text":"hello"}]}' + ctx, _ = _mock_aiohttp_response(body=body) + session = MagicMock() + session.request = MagicMock(return_value=ctx) + + req = _make_request({}) + response = await forward_request(req, state, session, slot_tracker, scheduler) + + self.assertEqual(response.body, body) + + async def test_status_code_preserved(self): + state = _make_state("http://b1") + slot_tracker = SlotTracker() + slot_tracker.set_capacity("http://b1", 1) + await self._acquire(slot_tracker, "http://b1") + scheduler = _make_scheduler() + + ctx, _ = _mock_aiohttp_response(status=429, body=b"rate limited") + session = MagicMock() + session.request = MagicMock(return_value=ctx) + + req = _make_request({}) + response = await forward_request(req, state, session, slot_tracker, scheduler) + + self.assertEqual(response.status_code, 429) + + async def test_model_id_tracked_in_active_models(self): + """model_id passed to forward_request is reflected in slot tracker.""" + state = _make_state("http://b1") + slot_tracker = SlotTracker() + slot_tracker.set_capacity("http://b1", 2) + async with asyncio.timeout(1.0): + await slot_tracker.acquire("http://b1", "my-model") + scheduler = _make_scheduler() + + ctx, _ = _mock_aiohttp_response(status=200, body=b"{}") + session = MagicMock() + session.request = MagicMock(return_value=ctx) + + req = _make_request({}) + await forward_request(req, state, session, slot_tracker, scheduler, model_id="my-model") + + self.assertEqual(slot_tracker.active_model_set("http://b1"), frozenset()) + + +class TestForwardRequestStreaming(unittest.IsolatedAsyncioTestCase): + async def test_streaming_sse_passthrough(self): + state = _make_state("http://b1") + slot_tracker = SlotTracker() + slot_tracker.set_capacity("http://b1", 1) + async with asyncio.timeout(1.0): + await slot_tracker.acquire("http://b1") + scheduler = _make_scheduler() + + sse_chunks = [ + b'data: {"choices":[{"delta":{"content":"hello"}}]}\n\n', + b'data: [DONE]\n\n', + ] + + async def fake_iter_chunked(_size): + for chunk in sse_chunks: + yield chunk + + resp = MagicMock() + resp.status = 200 + resp.headers = {"Content-Type": "text/event-stream"} + resp.content = MagicMock() + resp.content.iter_chunked = fake_iter_chunked + + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=resp) + ctx.__aexit__ = AsyncMock(return_value=False) + session = MagicMock() + session.request = MagicMock(return_value=ctx) + + req = _make_request({}) + response = await forward_request(req, state, session, slot_tracker, scheduler) + + from fastapi.responses import StreamingResponse + self.assertIsInstance(response, StreamingResponse) + + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + self.assertEqual(chunks, sse_chunks) + acquired, _ = slot_tracker.usage("http://b1") + self.assertEqual(acquired, 0) + scheduler.notify_slot_released.assert_called_once() + + async def test_streaming_slot_released_on_stream_end(self): + state = _make_state("http://b1") + slot_tracker = SlotTracker() + slot_tracker.set_capacity("http://b1", 1) + async with asyncio.timeout(1.0): + await slot_tracker.acquire("http://b1") + scheduler = _make_scheduler() + + async def fake_iter_chunked(_size): + yield b"data: chunk1\n\n" + yield b"data: chunk2\n\n" + + resp = MagicMock() + resp.status = 200 + resp.headers = {"Content-Type": "text/event-stream"} + resp.content = MagicMock() + resp.content.iter_chunked = fake_iter_chunked + + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=resp) + ctx.__aexit__ = AsyncMock(return_value=False) + session = MagicMock() + session.request = MagicMock(return_value=ctx) + + req = _make_request({}) + response = await forward_request(req, state, session, slot_tracker, scheduler) + + acquired, _ = slot_tracker.usage("http://b1") + self.assertEqual(acquired, 1) + + chunks = [chunk async for chunk in response.body_iterator] + self.assertEqual(len(chunks), 2) + + acquired, _ = slot_tracker.usage("http://b1") + self.assertEqual(acquired, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..591ad4e --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,477 @@ +""" +Integration tests using an in-process fake llama.cpp backend. + +The fake backend runs as a FastAPI app via httpx.AsyncClient(transport=...). +We patch aiohttp calls in the registry/forwarder so that requests go to our +fake server without opening real TCP sockets. +""" +from __future__ import annotations + +import asyncio +import json +import unittest +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse +from fastapi.testclient import TestClient + +from llamacpp_ha.config import BackendConfig, ProxyConfig +from llamacpp_ha.proxy import create_app + + +# --------------------------------------------------------------------------- +# Helpers to build a minimal proxy with pre-seeded live backends +# --------------------------------------------------------------------------- + + +def _live_backend_config(url: str, models: list[str] | None = None) -> BackendConfig: + return BackendConfig(url=url, model_ids=models or ["test-model"]) + + +def _proxy_config(backend_urls: list[str], api_keys: list[str] | None = None) -> ProxyConfig: + return ProxyConfig( + host="127.0.0.1", + port=8080, + api_keys=api_keys or [], + poll_interval=9999, # disable background polling during tests + slot_wait_timeout=1.0, + session_idle_ttl=300.0, + default_slot_capacity=2, + backends=[_live_backend_config(url) for url in backend_urls], + ) + + +def _seed_live(app_obj, backend_urls: list[str], models=None) -> None: + """Directly seed the registry with live backends (skipping real HTTP poll).""" + from llamacpp_ha.registry import BackendRegistry + # Access the app's state via lifespan-created objects + # We do this by patching _poll_all to be a no-op and manually setting state + pass + + +# --------------------------------------------------------------------------- +# Fake llama.cpp server (in-process, no real sockets) +# --------------------------------------------------------------------------- + + +def _make_fake_backend( + model_id: str = "test-model", + response_content: str = '{"choices":[{"message":{"role":"assistant","content":"hello"}}]}', + streaming: bool = False, + slow_seconds: float = 0, + slot_count: int = 2, +) -> FastAPI: + fake = FastAPI() + + @fake.get("/health") + async def health(): + if slow_seconds: + await asyncio.sleep(slow_seconds) + return Response(status_code=200) + + @fake.get("/v1/models") + async def models(): + return JSONResponse({"object": "list", "data": [{"id": model_id, "object": "model"}]}) + + @fake.get("/slots") + async def slots(): + return JSONResponse([{"id": i, "state": 0} for i in range(slot_count)]) + + @fake.post("/v1/chat/completions") + async def chat(request: Request): + if slow_seconds: + await asyncio.sleep(slow_seconds) + body = await request.json() + stream = body.get("stream", False) + if stream or streaming: + async def gen(): + yield b'data: {"choices":[{"delta":{"content":"hello"}}]}\n\n' + yield b"data: [DONE]\n\n" + return StreamingResponse(gen(), media_type="text/event-stream") + return JSONResponse(json.loads(response_content)) + + @fake.post("/v1/completions") + async def completions(request: Request): + return JSONResponse({"choices": [{"text": "hello"}]}) + + return fake + + +# --------------------------------------------------------------------------- +# Integration test cases +# --------------------------------------------------------------------------- + + +class TestHealthEndpoint(unittest.TestCase): + def _make_proxy_with_live_backend(self): + cfg = _proxy_config(["http://fake-b1"]) + app = create_app(cfg) + + # Seed registry directly before first request + def _patch_registry(app_obj): + pass + + return app, cfg + + def test_health_no_live_backend_returns_503(self): + cfg = _proxy_config([]) + app = create_app(cfg) + client = TestClient(app, raise_server_exceptions=False) + resp = client.get("/health") + self.assertEqual(resp.status_code, 503) + + def test_health_with_api_key_required(self): + cfg = _proxy_config(["http://b1"], api_keys=["mykey"]) + app = create_app(cfg) + client = TestClient(app, raise_server_exceptions=False) + resp = client.get("/health") + self.assertEqual(resp.status_code, 401) + + def test_health_with_valid_api_key(self): + cfg = _proxy_config([], api_keys=["mykey"]) + app = create_app(cfg) + client = TestClient(app, raise_server_exceptions=False) + # No live backends -> 503 but authenticated + resp = client.get("/health", headers={"Authorization": "Bearer mykey"}) + self.assertEqual(resp.status_code, 503) + + +class TestModelsEndpoint(unittest.TestCase): + def test_models_returns_list_format(self): + cfg = _proxy_config(["http://b1"]) + app = create_app(cfg) + # Seed registry + from llamacpp_ha.registry import BackendState + client = TestClient(app, raise_server_exceptions=False) + resp = client.get("/v1/models") + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["object"], "list") + self.assertIn("data", data) + + def test_models_deduplication(self): + """Models appearing on multiple backends appear once.""" + cfg = ProxyConfig( + backends=[ + BackendConfig(url="http://b1", model_ids=["shared", "only-b1"]), + BackendConfig(url="http://b2", model_ids=["shared", "only-b2"]), + ] + ) + app = create_app(cfg) + from llamacpp_ha.registry import BackendState + # Manually seed live + import asyncio + + async def _seed(): + from llamacpp_ha import proxy as _proxy_module + # Can't easily seed without internal access in this test structure + pass + + client = TestClient(app, raise_server_exceptions=False) + resp = client.get("/v1/models") + data = resp.json() + model_ids = [m["id"] for m in data["data"]] + self.assertEqual(len(model_ids), len(set(model_ids)), "Duplicates found") + + +class TestApiKeyMiddlewareIntegration(unittest.TestCase): + def test_request_rejected_without_key(self): + cfg = _proxy_config([], api_keys=["secret"]) + app = create_app(cfg) + client = TestClient(app, raise_server_exceptions=False) + resp = client.post("/v1/chat/completions", json={"model": "m", "messages": []}) + self.assertEqual(resp.status_code, 401) + + def test_monitor_exempt_from_auth(self): + cfg = _proxy_config([], api_keys=["secret"]) + app = create_app(cfg) + client = TestClient(app, raise_server_exceptions=False) + resp = client.get("/monitor") + self.assertEqual(resp.status_code, 200) + + def test_monitor_data_exempt_from_auth(self): + cfg = _proxy_config([], api_keys=["secret"]) + app = create_app(cfg) + client = TestClient(app, raise_server_exceptions=False) + resp = client.get("/monitor/data") + self.assertEqual(resp.status_code, 200) + + +class TestSlotExhaustionTimeout(unittest.TestCase): + def test_timeout_when_no_slot_available(self): + """Request with no live backends times out and returns 503.""" + cfg = ProxyConfig( + backends=[BackendConfig(url="http://b1", model_ids=["m"])], + slot_wait_timeout=0.2, + default_slot_capacity=1, + ) + app = create_app(cfg) + # No live backends -> scheduler never dispatches -> timeout -> 503 + with TestClient(app, raise_server_exceptions=False) as client: + resp = client.post( + "/v1/chat/completions", + json={"model": "m", "messages": [{"role": "user", "content": "hi"}]}, + ) + self.assertEqual(resp.status_code, 503) + + +class TestCatchAllForwarding(unittest.TestCase): + def test_catch_all_no_backends_returns_503(self): + cfg = _proxy_config([]) + app = create_app(cfg) + with TestClient(app, raise_server_exceptions=False) as client: + resp = client.get("/some/unknown/path") + self.assertEqual(resp.status_code, 503) + + +class TestMonitorIntegration(unittest.TestCase): + def test_monitor_data_reflects_queue_depth(self): + cfg = ProxyConfig( + backends=[BackendConfig(url="http://b1", model_ids=["m"])], + slot_wait_timeout=0.1, + default_slot_capacity=0, + ) + app = create_app(cfg) + with TestClient(app, raise_server_exceptions=False) as client: + data = client.get("/monitor/data").json() + self.assertIsInstance(data["queue_depth"], int) + self.assertGreaterEqual(data["queue_depth"], 0) + + def test_monitor_data_shows_all_backends(self): + cfg = _proxy_config(["http://b1", "http://b2", "http://b3"]) + app = create_app(cfg) + with TestClient(app, raise_server_exceptions=False) as client: + data = client.get("/monitor/data").json() + self.assertEqual(len(data["backends"]), 3) + urls = {b["url"] for b in data["backends"]} + self.assertEqual(urls, {"http://b1", "http://b2", "http://b3"}) + + +class TestFullForwardPath(unittest.TestCase): + """ + Full path test: proxy starts up, monitor/data reflects backend state + (dead until polled), scheduler queues requests when no backends are live. + """ + + def test_backend_initially_dead_no_poll(self): + cfg = ProxyConfig( + backends=[BackendConfig(url="http://fake", model_ids=["test-model"])], + slot_wait_timeout=2.0, + default_slot_capacity=2, + poll_interval=9999, # no background polling in test + ) + app = create_app(cfg) + with TestClient(app, raise_server_exceptions=False) as client: + resp = client.get("/monitor/data") + data = resp.json() + # Backend exists but is dead until first poll completes + self.assertEqual(len(data["backends"]), 1) + self.assertFalse(data["backends"][0]["live"]) + + +class TestBackendFailover(unittest.IsolatedAsyncioTestCase): + async def test_dead_backend_not_in_model_index(self): + """Dead backends are not returned for model routing.""" + from llamacpp_ha.config import BackendConfig, ProxyConfig + from llamacpp_ha.registry import BackendRegistry, BackendState + + cfg = ProxyConfig( + backends=[ + BackendConfig(url="http://b1", model_ids=["m"]), + BackendConfig(url="http://b2", model_ids=["m"]), + ] + ) + reg = BackendRegistry(cfg) + + # Simulate: b1 live, b2 dead + async with reg._lock: + reg._states["http://b1"].live = True + reg._states["http://b1"].models = ["m"] + reg._states["http://b2"].live = False + reg._rebuild_index() + + backends = reg.get_backends_for_model("m") + self.assertEqual(len(backends), 1) + self.assertEqual(backends[0].url, "http://b1") + + async def test_all_dead_returns_empty(self): + from llamacpp_ha.config import BackendConfig, ProxyConfig + from llamacpp_ha.registry import BackendRegistry + + cfg = ProxyConfig( + backends=[ + BackendConfig(url="http://b1", model_ids=["m"]), + ] + ) + reg = BackendRegistry(cfg) + backends = reg.get_backends_for_model("m") + self.assertEqual(backends, []) + + +class TestModelRouting(unittest.IsolatedAsyncioTestCase): + async def test_routes_to_backend_serving_model(self): + from llamacpp_ha.config import BackendConfig, ProxyConfig + from llamacpp_ha.policies import RoundRobinPolicy + from llamacpp_ha.queue import QueueEntry, RequestQueue + from llamacpp_ha.registry import BackendRegistry + from llamacpp_ha.scheduler import Scheduler + from llamacpp_ha.session_store import SessionStore + from llamacpp_ha.slot_tracker import SlotTracker + + cfg = ProxyConfig( + backends=[ + BackendConfig(url="http://b1", model_ids=["model-a"]), + BackendConfig(url="http://b2", model_ids=["model-b"]), + ] + ) + reg = BackendRegistry(cfg) + async with reg._lock: + reg._states["http://b1"].live = True + reg._states["http://b1"].models = ["model-a"] + reg._states["http://b2"].live = True + reg._states["http://b2"].models = ["model-b"] + reg._rebuild_index() + + slots = SlotTracker() + slots.set_capacity("http://b1", 2) + slots.set_capacity("http://b2", 2) + sessions = SessionStore() + queue = RequestQueue() + scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy()) + + loop = asyncio.get_running_loop() + entry = QueueEntry(model_id="model-b", future=loop.create_future()) + await queue.enqueue(entry) + await scheduler._dispatch_all() + + self.assertTrue(entry.future.done()) + self.assertEqual(entry.future.result().url, "http://b2") + + +class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase): + """ + Slot is fully occupied → request queues → slot released → request dispatched. + Tests the complete wait-and-unblock path without involving real HTTP. + """ + + async def test_queued_request_dispatched_after_slot_release(self): + from llamacpp_ha.config import BackendConfig, ProxyConfig + from llamacpp_ha.policies import RoundRobinPolicy + from llamacpp_ha.queue import QueueEntry, RequestQueue + from llamacpp_ha.registry import BackendRegistry + from llamacpp_ha.scheduler import Scheduler + from llamacpp_ha.session_store import SessionStore + from llamacpp_ha.slot_tracker import SlotTracker + + cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m"])]) + reg = BackendRegistry(cfg) + async with reg._lock: + reg._states["http://b1"].live = True + reg._states["http://b1"].models = ["m"] + reg._rebuild_index() + + slots = SlotTracker() + slots.set_capacity("http://b1", 1) + sessions = SessionStore() + queue = RequestQueue() + scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy()) + + # Occupy the only slot (simulating an in-flight request) + async with asyncio.timeout(1.0): + await slots.acquire("http://b1") + + # Enqueue a request — scheduler should not be able to dispatch + loop = asyncio.get_running_loop() + entry = QueueEntry(model_id="m", future=loop.create_future()) + await queue.enqueue(entry) + await scheduler._dispatch_all() + self.assertFalse(entry.future.done(), "Should be queued, not dispatched yet") + + # Release the slot (simulates a prior request completing) + await slots.release("http://b1") + scheduler.notify_slot_released() + + # Scheduler re-evaluates on next wakeup + await scheduler._dispatch_all() + self.assertTrue(entry.future.done(), "Should be dispatched after slot freed") + self.assertEqual(entry.future.result().url, "http://b1") + + async def test_multiple_queued_requests_dispatched_as_slots_free(self): + from llamacpp_ha.config import BackendConfig, ProxyConfig + from llamacpp_ha.policies import RoundRobinPolicy + from llamacpp_ha.queue import QueueEntry, RequestQueue + from llamacpp_ha.registry import BackendRegistry + from llamacpp_ha.scheduler import Scheduler + from llamacpp_ha.session_store import SessionStore + from llamacpp_ha.slot_tracker import SlotTracker + + cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m"])]) + reg = BackendRegistry(cfg) + async with reg._lock: + reg._states["http://b1"].live = True + reg._states["http://b1"].models = ["m"] + reg._rebuild_index() + + slots = SlotTracker() + slots.set_capacity("http://b1", 2) + sessions = SessionStore() + queue = RequestQueue() + scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy()) + + # Fill both slots + async with asyncio.timeout(1.0): + await slots.acquire("http://b1") + async with asyncio.timeout(1.0): + await slots.acquire("http://b1") + + loop = asyncio.get_running_loop() + entries = [QueueEntry(model_id="m", future=loop.create_future()) for _ in range(3)] + for e in entries: + await queue.enqueue(e) + + await scheduler._dispatch_all() + dispatched = sum(1 for e in entries if e.future.done()) + self.assertEqual(dispatched, 0) + + # Free one slot → one request dispatched + await slots.release("http://b1") + await scheduler._dispatch_all() + dispatched = sum(1 for e in entries if e.future.done()) + self.assertEqual(dispatched, 1) + + # Free second slot → another dispatched + await slots.release("http://b1") + await scheduler._dispatch_all() + dispatched = sum(1 for e in entries if e.future.done()) + self.assertEqual(dispatched, 2) + + +class TestSessionCookieIntegration(unittest.TestCase): + def test_session_cookie_set_in_response(self): + """Proxy sets X-Session-ID header and cookie on inference responses.""" + # No live backends → times out → 503, but we still get session tracking + # We verify the session machinery works by checking that when the proxy + # processes a request, it assigns a session and would attach it to the response. + # For a full round-trip we'd need a live backend; here we verify the 503 + # path still doesn't crash the session handling. + cfg = ProxyConfig( + backends=[BackendConfig(url="http://b1", model_ids=["m"])], + slot_wait_timeout=0.1, + ) + app = create_app(cfg) + with TestClient(app, raise_server_exceptions=False) as client: + resp = client.post( + "/v1/chat/completions", + json={"model": "m", "messages": [{"role": "user", "content": "hi"}]}, + ) + # Times out (no live backend) → 503 + self.assertEqual(resp.status_code, 503) + # Even on timeout, no crash + self.assertIn("error", resp.json()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 0000000..9e9a287 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,79 @@ +import unittest + +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from fastapi.testclient import TestClient + +from llamacpp_ha.middleware import ApiKeyMiddleware + + +def _make_app(api_keys: list[str]) -> FastAPI: + app = FastAPI() + app.add_middleware(ApiKeyMiddleware, api_keys=api_keys) + + @app.get("/test") + async def test_endpoint(): + return {"ok": True} + + @app.get("/monitor") + async def monitor(): + return {"monitor": True} + + @app.get("/monitor/data") + async def monitor_data(): + return {"data": True} + + return app + + +class TestApiKeyMiddleware(unittest.TestCase): + def test_no_keys_configured_passes_all(self): + client = TestClient(_make_app([])) + resp = client.get("/test") + self.assertEqual(resp.status_code, 200) + + def test_valid_key_passes(self): + client = TestClient(_make_app(["key1", "key2"])) + resp = client.get("/test", headers={"Authorization": "Bearer key1"}) + self.assertEqual(resp.status_code, 200) + + def test_valid_second_key_passes(self): + client = TestClient(_make_app(["key1", "key2"])) + resp = client.get("/test", headers={"Authorization": "Bearer key2"}) + self.assertEqual(resp.status_code, 200) + + def test_missing_key_returns_401(self): + client = TestClient(_make_app(["key1"])) + resp = client.get("/test") + self.assertEqual(resp.status_code, 401) + + def test_wrong_key_returns_401(self): + client = TestClient(_make_app(["key1"])) + resp = client.get("/test", headers={"Authorization": "Bearer wrongkey"}) + self.assertEqual(resp.status_code, 401) + + def test_malformed_auth_returns_401(self): + client = TestClient(_make_app(["key1"])) + resp = client.get("/test", headers={"Authorization": "key1"}) + self.assertEqual(resp.status_code, 401) + + def test_monitor_exempt(self): + client = TestClient(_make_app(["key1"])) + resp = client.get("/monitor") + self.assertEqual(resp.status_code, 200) + + def test_monitor_data_exempt(self): + client = TestClient(_make_app(["key1"])) + resp = client.get("/monitor/data") + self.assertEqual(resp.status_code, 200) + + def test_error_response_is_json(self): + client = TestClient(_make_app(["key1"])) + resp = client.get("/test") + self.assertEqual(resp.headers["content-type"], "application/json") + body = resp.json() + self.assertIn("error", body) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_monitor.py b/tests/test_monitor.py new file mode 100644 index 0000000..375c132 --- /dev/null +++ b/tests/test_monitor.py @@ -0,0 +1,109 @@ +import unittest + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from llamacpp_ha.config import BackendConfig, ProxyConfig +from llamacpp_ha.monitor import ProxyStats, build_router +from llamacpp_ha.queue import RequestQueue +from llamacpp_ha.registry import BackendRegistry, BackendState +from llamacpp_ha.session_store import SessionStore +from llamacpp_ha.slot_tracker import SlotTracker + + +def _make_monitor_app(states=None): + cfg = ProxyConfig(backends=[BackendConfig(url=s.url) for s in (states or [])]) + registry = BackendRegistry(cfg) + for state in (states or []): + registry._states[state.url] = state + registry._rebuild_index() + + slot_tracker = SlotTracker() + request_queue = RequestQueue() + session_store = SessionStore() + stats = ProxyStats() + + app = FastAPI() + router = build_router( + registry=registry, + slot_tracker=slot_tracker, + request_queue=request_queue, + session_store=session_store, + stats=stats, + ) + app.include_router(router) + return app, registry, slot_tracker, request_queue, session_store, stats + + +class TestMonitorEndpoints(unittest.TestCase): + def test_monitor_page_returns_html(self): + app, *_ = _make_monitor_app() + resp = TestClient(app).get("/monitor") + self.assertEqual(resp.status_code, 200) + self.assertIn("text/html", resp.headers["content-type"]) + self.assertIn("llamacpp-ha", resp.text) + + def test_monitor_page_no_external_deps(self): + app, *_ = _make_monitor_app() + text = TestClient(app).get("/monitor").text + self.assertNotIn("cdn.", text) + self.assertNotIn("googleapis.com", text) + self.assertNotIn("unpkg.com", text) + + def test_monitor_data_structure(self): + state = BackendState(config=BackendConfig(url="http://b1"), live=True, models=["m1", "m2"]) + app, _, slot_tracker, *_ = _make_monitor_app([state]) + slot_tracker.set_capacity("http://b1", 4) + data = TestClient(app).get("/monitor/data").json() + for key in ("uptime", "total_requests", "queue_depth", "session_count", + "live_backend_count", "backends", "queue", "sessions_by_model"): + self.assertIn(key, data) + + def test_monitor_data_backend_fields(self): + state = BackendState(config=BackendConfig(url="http://b1"), live=True, models=["m1"]) + app, _, slot_tracker, *_ = _make_monitor_app([state]) + slot_tracker.set_capacity("http://b1", 3) + backend = TestClient(app).get("/monitor/data").json()["backends"][0] + self.assertEqual(backend["url"], "http://b1") + self.assertTrue(backend["live"]) + self.assertEqual(backend["models"], ["m1"]) + self.assertEqual(backend["slots_total"], 3) + self.assertIn("slots_acquired", backend) + + def test_monitor_data_dead_backend(self): + state = BackendState(config=BackendConfig(url="http://b2-dead"), live=False, models=[]) + app, *_ = _make_monitor_app([state]) + data = TestClient(app).get("/monitor/data").json() + backends = {b["url"]: b for b in data["backends"]} + self.assertFalse(backends["http://b2-dead"]["live"]) + self.assertEqual(data["live_backend_count"], 0) + + def test_monitor_data_empty_state(self): + app, *_ = _make_monitor_app([]) + data = TestClient(app).get("/monitor/data").json() + self.assertEqual(data["backends"], []) + self.assertEqual(data["queue"], []) + self.assertEqual(data["session_count"], 0) + + def test_monitor_total_requests_reflects_stats(self): + app, *_, stats = _make_monitor_app() + stats.increment_requests() + stats.increment_requests() + data = TestClient(app).get("/monitor/data").json() + self.assertEqual(data["total_requests"], 2) + + def test_monitor_uptime_is_string(self): + app, *_ = _make_monitor_app() + data = TestClient(app).get("/monitor/data").json() + self.assertRegex(data["uptime"], r"^\d{2}:\d{2}:\d{2}$") + + def test_monitor_last_poll_age_never_polled(self): + """Backend that has never been polled should show null last_poll_age.""" + state = BackendState(config=BackendConfig(url="http://b1"), live=False, models=[]) + app, *_ = _make_monitor_app([state]) + data = TestClient(app).get("/monitor/data").json() + self.assertIsNone(data["backends"][0]["last_poll_age"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_policies.py b/tests/test_policies.py new file mode 100644 index 0000000..3c33f4f --- /dev/null +++ b/tests/test_policies.py @@ -0,0 +1,59 @@ +import unittest + +from llamacpp_ha.policies import BackendCandidate, RoundRobinPolicy + + +def _cands(n: int) -> list[BackendCandidate]: + return [BackendCandidate(url=f"http://b{i}", index=i) for i in range(n)] + + +class TestRoundRobinPolicy(unittest.TestCase): + def test_single_backend_always_chosen(self): + policy = RoundRobinPolicy() + cands = _cands(1) + for _ in range(5): + result = policy.select("m", cands) + self.assertEqual(result.url, "http://b0") + + def test_distributes_evenly_across_two(self): + policy = RoundRobinPolicy() + cands = _cands(2) + chosen = [policy.select("m", cands).url for _ in range(6)] + self.assertEqual(chosen.count("http://b0"), 3) + self.assertEqual(chosen.count("http://b1"), 3) + + def test_distributes_evenly_across_three(self): + policy = RoundRobinPolicy() + cands = _cands(3) + chosen = [policy.select("m", cands).url for _ in range(9)] + for i in range(3): + self.assertEqual(chosen.count(f"http://b{i}"), 3) + + def test_per_model_counters_independent(self): + policy = RoundRobinPolicy() + cands = _cands(2) + r1 = policy.select("model-a", cands) + r2 = policy.select("model-b", cands) + # Both start at 0 -> both get b0 + self.assertEqual(r1.url, "http://b0") + self.assertEqual(r2.url, "http://b0") + # Next calls for each model independently advance + r3 = policy.select("model-a", cands) + r4 = policy.select("model-b", cands) + self.assertEqual(r3.url, "http://b1") + self.assertEqual(r4.url, "http://b1") + + def test_empty_raises(self): + policy = RoundRobinPolicy() + with self.assertRaises(ValueError): + policy.select("m", []) + + def test_wraps_around(self): + policy = RoundRobinPolicy() + cands = _cands(2) + urls = [policy.select("m", cands).url for _ in range(5)] + self.assertEqual(urls, ["http://b0", "http://b1", "http://b0", "http://b1", "http://b0"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_queue.py b/tests/test_queue.py new file mode 100644 index 0000000..be0d56d --- /dev/null +++ b/tests/test_queue.py @@ -0,0 +1,92 @@ +import asyncio +import unittest + +from llamacpp_ha.queue import QueueEntry, RequestQueue + + +def _entry(**kwargs) -> QueueEntry: + e = QueueEntry(**kwargs) + e.future = asyncio.get_running_loop().create_future() + return e + + +class TestRequestQueue(unittest.IsolatedAsyncioTestCase): + async def test_enqueue_and_pending(self): + q = RequestQueue() + e1 = _entry(model_id="m1") + e2 = _entry(model_id="m2") + await q.enqueue(e1) + await q.enqueue(e2) + pending = await q.pending() + self.assertEqual(len(pending), 2) + self.assertEqual(pending[0].model_id, "m1") + self.assertEqual(pending[1].model_id, "m2") + + async def test_fifo_order(self): + q = RequestQueue() + ids = [] + for i in range(5): + e = _entry(model_id=f"m{i}") + await q.enqueue(e) + ids.append(e.request_id) + pending = await q.pending() + self.assertEqual([e.request_id for e in pending], ids) + + async def test_remove(self): + q = RequestQueue() + e1 = _entry(model_id="m1") + e2 = _entry(model_id="m2") + await q.enqueue(e1) + await q.enqueue(e2) + await q.remove(e1) + pending = await q.pending() + self.assertEqual(len(pending), 1) + self.assertEqual(pending[0].model_id, "m2") + + async def test_remove_nonexistent_no_error(self): + q = RequestQueue() + e = _entry(model_id="m") + await q.remove(e) # should not raise + + async def test_depth(self): + q = RequestQueue() + self.assertEqual(await q.depth(), 0) + for _ in range(3): + await q.enqueue(_entry(model_id="m")) + self.assertEqual(await q.depth(), 3) + + async def test_wakeup_event_set_on_enqueue(self): + q = RequestQueue() + self.assertFalse(q.wakeup_event.is_set()) + await q.enqueue(_entry(model_id="m")) + self.assertTrue(q.wakeup_event.is_set()) + + def test_notify_sets_wakeup(self): + q = RequestQueue() + q.notify() + self.assertTrue(q.wakeup_event.is_set()) + + async def test_entry_metadata(self): + import time + before = time.monotonic() + e = QueueEntry(model_id="llama3", session_id="sess123", estimated_tokens=42) + after = time.monotonic() + self.assertEqual(e.model_id, "llama3") + self.assertEqual(e.session_id, "sess123") + self.assertEqual(e.estimated_tokens, 42) + self.assertIsNone(e.future) + self.assertGreaterEqual(e.arrival_time, before) + self.assertLessEqual(e.arrival_time, after) + self.assertGreaterEqual(e.wait_seconds, 0) + + async def test_snapshot_truncates_session_id(self): + q = RequestQueue() + e = _entry(model_id="m", session_id="abcdef1234567890") + await q.enqueue(e) + snap = await q.snapshot() + self.assertEqual(len(snap), 1) + self.assertEqual(snap[0]["session_id"], "abcdef12") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 0000000..2382308 --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,152 @@ +import asyncio +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from llamacpp_ha.config import BackendConfig, ProxyConfig +from llamacpp_ha.registry import BackendRegistry, BackendState + + +def _make_config(backends=None): + if backends is None: + backends = [{"url": "http://b1"}, {"url": "http://b2"}] + return ProxyConfig(backends=[BackendConfig(**b) for b in backends]) + + +class TestBackendRegistry(unittest.IsolatedAsyncioTestCase): + def _make_registry(self, backends=None): + cfg = _make_config(backends) + return BackendRegistry(cfg) + + async def test_initial_state_all_dead(self): + reg = self._make_registry() + live = reg.get_all_live_backends() + self.assertEqual(live, []) + + async def test_get_all_states_returns_all(self): + reg = self._make_registry() + states = reg.get_all_states() + self.assertEqual(len(states), 2) + urls = {s.url for s in states} + self.assertEqual(urls, {"http://b1", "http://b2"}) + + async def test_liveness_transition(self): + reg = self._make_registry([{"url": "http://b1"}]) + state = reg.get_state("http://b1") + self.assertFalse(state.live) + + # Simulate a successful poll by directly mutating state + async with reg._lock: + state.live = True + state.models = ["llama3"] + reg._rebuild_index() + + live = reg.get_all_live_backends() + self.assertEqual(len(live), 1) + + async def test_model_index_rebuilt(self): + reg = self._make_registry( + [{"url": "http://b1"}, {"url": "http://b2"}] + ) + async with reg._lock: + b1 = reg.get_state("http://b1") + b2 = reg.get_state("http://b2") + b1.live = True + b1.models = ["m1", "m2"] + b2.live = True + b2.models = ["m2", "m3"] + reg._rebuild_index() + + self.assertEqual(len(reg.get_backends_for_model("m1")), 1) + self.assertEqual(len(reg.get_backends_for_model("m2")), 2) + self.assertEqual(len(reg.get_backends_for_model("m3")), 1) + + async def test_get_all_models_deduplicated(self): + reg = self._make_registry( + [{"url": "http://b1"}, {"url": "http://b2"}] + ) + async with reg._lock: + b1 = reg.get_state("http://b1") + b2 = reg.get_state("http://b2") + b1.live = True + b1.models = ["m1", "shared"] + b2.live = True + b2.models = ["m2", "shared"] + reg._rebuild_index() + + models = reg.get_all_models() + self.assertEqual(len(models), 3) + self.assertEqual(len(set(models)), 3) # no duplicates + + async def test_dead_backend_excluded_from_model_index(self): + reg = self._make_registry( + [{"url": "http://b1"}, {"url": "http://b2"}] + ) + async with reg._lock: + b1 = reg.get_state("http://b1") + b2 = reg.get_state("http://b2") + b1.live = True + b1.models = ["m1"] + b2.live = False + b2.models = ["m1"] + reg._rebuild_index() + + backends = reg.get_backends_for_model("m1") + self.assertEqual(len(backends), 1) + self.assertEqual(backends[0].url, "http://b1") + + async def test_explicit_model_ids_used_over_backend_report(self): + cfg = ProxyConfig( + backends=[BackendConfig(url="http://b1", model_ids=["only-this"])] + ) + reg = BackendRegistry(cfg) + + # Simulate fetching models: should use config model_ids, not backend response + state = reg.get_state("http://b1") + mock_session = AsyncMock() + reg._session = mock_session + + models = await reg._fetch_models(state, {}) + self.assertEqual(models, ["only-this"]) + + async def test_last_poll_age_increases(self): + import time + reg = self._make_registry([{"url": "http://b1"}]) + state = reg.get_state("http://b1") + state.last_poll_time = time.monotonic() - 10 + age = state.last_poll_age + self.assertGreaterEqual(age, 9.9) + + async def test_no_backends_for_unknown_model(self): + reg = self._make_registry() + result = reg.get_backends_for_model("no-such-model") + self.assertEqual(result, []) + + async def test_poll_all_concurrent(self): + """poll_all runs backends concurrently; a slow backend doesn't block others.""" + cfg = ProxyConfig( + backends=[BackendConfig(url=f"http://b{i}") for i in range(5)] + ) + reg = BackendRegistry(cfg) + + poll_times = [] + + async def fake_poll_one(state): + poll_times.append(asyncio.get_event_loop().time()) + await asyncio.sleep(0.05) + async with reg._lock: + state.live = True + state.models = ["m1"] + + with patch.object(reg, "_poll_one", side_effect=fake_poll_one): + import time + start = time.monotonic() + await reg._poll_all() + elapsed = time.monotonic() - start + + # 5 backends * 0.05s each, but concurrent: should finish in ~0.1s not 0.25s + self.assertLess(elapsed, 0.2) + self.assertEqual(len(reg.get_all_live_backends()), 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py new file mode 100644 index 0000000..6540df7 --- /dev/null +++ b/tests/test_scheduler.py @@ -0,0 +1,300 @@ +import asyncio +import unittest + +from llamacpp_ha.config import BackendConfig, ProxyConfig +from llamacpp_ha.policies import RoundRobinPolicy +from llamacpp_ha.queue import QueueEntry, RequestQueue +from llamacpp_ha.registry import BackendRegistry, BackendState +from llamacpp_ha.scheduler import Scheduler +from llamacpp_ha.session_store import SessionStore +from llamacpp_ha.slot_tracker import SlotTracker + + +def _make_state(url: str, models: list[str] | None = None) -> BackendState: + cfg = BackendConfig(url=url) + return BackendState(config=cfg, live=True, models=models or ["m1"]) + + +def _entry(**kwargs) -> QueueEntry: + e = QueueEntry(**kwargs) + e.future = asyncio.get_running_loop().create_future() + return e + + +class TestScheduler(unittest.IsolatedAsyncioTestCase): + def _make_scheduler(self, live_backends=None, max_queue_skip: int = 0): + cfg = ProxyConfig(backends=[BackendConfig(url=b.url) for b in (live_backends or [])]) + registry = BackendRegistry(cfg) + for state in (live_backends or []): + registry._states[state.url] = state + registry._rebuild_index() + + slot_tracker = SlotTracker() + for state in (live_backends or []): + slot_tracker.set_capacity(state.url, 2) + + session_store = SessionStore() + queue = RequestQueue() + scheduler = Scheduler( + queue=queue, + registry=registry, + slot_tracker=slot_tracker, + session_store=session_store, + policy=RoundRobinPolicy(), + max_queue_skip=max_queue_skip, + ) + return scheduler, queue, registry, slot_tracker, session_store + + async def test_dispatches_entry_to_live_backend(self): + b1 = _make_state("http://b1") + scheduler, queue, *_ = self._make_scheduler([b1]) + + entry = _entry(model_id="m1") + await queue.enqueue(entry) + await scheduler._dispatch_all() + + self.assertTrue(entry.future.done()) + self.assertEqual(entry.future.result().url, "http://b1") + + async def test_skips_full_backends(self): + b1 = _make_state("http://b1") + scheduler, queue, _, slots, _ = self._make_scheduler([b1]) + slots.set_capacity("http://b1", 1) + async with asyncio.timeout(1.0): + await slots.acquire("http://b1") + + entry = _entry(model_id="m1") + await queue.enqueue(entry) + await scheduler._dispatch_all() + + self.assertFalse(entry.future.done()) + + async def test_session_affinity_preferred(self): + b1 = _make_state("http://b1") + b2 = _make_state("http://b2") + scheduler, queue, _, _, sessions = self._make_scheduler([b1, b2]) + + await sessions.get_or_create("sess1") + await sessions.update("sess1", preferred_backend="http://b2") + + entry = _entry(model_id="m1", session_id="sess1") + await queue.enqueue(entry) + await scheduler._dispatch_all() + + self.assertTrue(entry.future.done()) + self.assertEqual(entry.future.result().url, "http://b2") + + async def test_session_affinity_fallback_when_full(self): + b1 = _make_state("http://b1") + b2 = _make_state("http://b2") + scheduler, queue, _, slots, sessions = self._make_scheduler([b1, b2]) + slots.set_capacity("http://b2", 1) + async with asyncio.timeout(1.0): + await slots.acquire("http://b2") + + await sessions.get_or_create("sess1") + await sessions.update("sess1", preferred_backend="http://b2") + + entry = _entry(model_id="m1", session_id="sess1") + await queue.enqueue(entry) + await scheduler._dispatch_all() + + self.assertTrue(entry.future.done()) + self.assertEqual(entry.future.result().url, "http://b1") + + async def test_no_live_backends_entry_stays(self): + scheduler, queue, *_ = self._make_scheduler([]) + + entry = _entry(model_id="m1") + await queue.enqueue(entry) + await scheduler._dispatch_all() + + self.assertFalse(entry.future.done()) + self.assertEqual(await queue.depth(), 1) + + def test_notify_slot_released_wakes_queue(self): + b1 = _make_state("http://b1") + scheduler, queue, *_ = self._make_scheduler([b1]) + scheduler.notify_slot_released() + self.assertTrue(queue.wakeup_event.is_set()) + + async def test_cancelled_future_cleaned_up(self): + b1 = _make_state("http://b1") + scheduler, queue, *_ = self._make_scheduler([b1]) + + entry = _entry(model_id="m1") + await queue.enqueue(entry) + entry.future.cancel() + + await scheduler._dispatch_all() + + self.assertEqual(await queue.depth(), 0) + + async def test_round_robin_across_backends(self): + b1 = _make_state("http://b1") + b2 = _make_state("http://b2") + scheduler, queue, *_ = self._make_scheduler([b1, b2]) + + results = [] + for _ in range(4): + e = _entry(model_id="m1") + await queue.enqueue(e) + await scheduler._dispatch_all() + results.append(e.future.result().url) + + self.assertEqual(results.count("http://b1"), 2) + self.assertEqual(results.count("http://b2"), 2) + + async def test_slot_released_then_dispatch(self): + b1 = _make_state("http://b1") + scheduler, queue, _, slots, _ = self._make_scheduler([b1]) + slots.set_capacity("http://b1", 1) + async with asyncio.timeout(1.0): + await slots.acquire("http://b1") + + entry = _entry(model_id="m1") + await queue.enqueue(entry) + await scheduler._dispatch_all() + self.assertFalse(entry.future.done()) + + await slots.release("http://b1") + scheduler.notify_slot_released() + await scheduler._dispatch_all() + self.assertTrue(entry.future.done()) + + # ------------------------------------------------------------------ + # max_models / preemption prevention + # ------------------------------------------------------------------ + + async def test_max_models_blocks_second_model_on_same_backend(self): + b1 = _make_state("http://b1") + _, queue, _, slots, _ = self._make_scheduler([b1]) + slots.set_capacity("http://b1", 4) + slots.set_max_models("http://b1", 1) + + # Occupy a slot with model-a + async with asyncio.timeout(1.0): + await slots.acquire("http://b1", "m1") + + # Request for model-b should stay queued + cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m1", "m2"])]) + registry = BackendRegistry(cfg) + registry._states["http://b1"] = BackendState( + config=BackendConfig(url="http://b1", model_ids=["m1", "m2"]), + live=True, + models=["m1", "m2"], + ) + registry._rebuild_index() + + sched2 = Scheduler( + queue=queue, + registry=registry, + slot_tracker=slots, + session_store=SessionStore(), + policy=RoundRobinPolicy(), + ) + entry = _entry(model_id="m2") + await queue.enqueue(entry) + await sched2._dispatch_all() + self.assertFalse(entry.future.done(), "model-b should be blocked by max_models=1") + + async def test_max_models_allows_same_model(self): + b1 = _make_state("http://b1") + scheduler, queue, _, slots, _ = self._make_scheduler([b1]) + slots.set_capacity("http://b1", 4) + slots.set_max_models("http://b1", 1) + async with asyncio.timeout(1.0): + await slots.acquire("http://b1", "m1") + + entry = _entry(model_id="m1") + await queue.enqueue(entry) + await scheduler._dispatch_all() + self.assertTrue(entry.future.done(), "same model should still be dispatchable") + + # ------------------------------------------------------------------ + # N-skip reordering + # ------------------------------------------------------------------ + + async def test_no_reorder_when_max_queue_skip_zero(self): + """Default FIFO: model-B request is not promoted over model-A.""" + b1 = _make_state("http://b1", models=["m1"]) + b2 = _make_state("http://b2", models=["m2"]) + scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2], max_queue_skip=0) + slots.set_capacity("http://b1", 1) + slots.set_capacity("http://b2", 1) + + # Fill b1; b2 is free with m2 active + async with asyncio.timeout(1.0): + await slots.acquire("http://b1", "m1") + async with asyncio.timeout(1.0): + await slots.acquire("http://b2", "m2") + + # Queue: [m1-request (blocked), m2-request (could go to b2)] + e_m1 = _entry(model_id="m1") + e_m2 = _entry(model_id="m2") + await queue.enqueue(e_m1) + await queue.enqueue(e_m2) + + # Release b2 slot so m2 can be served + await slots.release("http://b2", "m2") + await scheduler._dispatch_all() + + # m2 can be dispatched even with max_queue_skip=0 because _dispatch_all + # scans all entries (not strict head-of-line per model) + self.assertFalse(e_m1.future.done()) + self.assertTrue(e_m2.future.done()) + # skip_count must NOT be bumped when max_queue_skip=0 + self.assertEqual(e_m1.skip_count, 0) + + async def test_affinity_promotes_matching_model(self): + """With max_queue_skip>0, a matching model gets promoted.""" + b1 = _make_state("http://b1", models=["m1"]) + scheduler, queue, _, slots, _ = self._make_scheduler([b1], max_queue_skip=3) + slots.set_capacity("http://b1", 2) + + # b1 already has m1 in-flight + async with asyncio.timeout(1.0): + await slots.acquire("http://b1", "m1") + + # Queue: [m2-entry (no affinity), m1-entry (affinity match)] + e_other = _entry(model_id="m2") # no backend serves m2 + e_m1 = _entry(model_id="m1") + await queue.enqueue(e_other) + await queue.enqueue(e_m1) + + await scheduler._dispatch_all() + + # m1 entry is promoted (affinity pass), m2 stays (no backend) + self.assertTrue(e_m1.future.done()) + self.assertFalse(e_other.future.done()) + # e_other was bypassed once + self.assertEqual(e_other.skip_count, 1) + + async def test_skip_count_caps_reordering(self): + """Once skip_count reaches max_queue_skip the entry freezes at head.""" + b1 = _make_state("http://b1", models=["m1"]) + scheduler, queue, _, slots, _ = self._make_scheduler([b1], max_queue_skip=2) + slots.set_capacity("http://b1", 4) + + # b1 has m1 active + async with asyncio.timeout(1.0): + await slots.acquire("http://b1", "m1") + + e_other = _entry(model_id="m2") + e_other.skip_count = 2 # already at limit — must not be bypassed + e_m1 = _entry(model_id="m1") + await queue.enqueue(e_other) + await queue.enqueue(e_m1) + + await scheduler._dispatch_all() + + # Affinity pass stops at e_other (skip_count >= max_queue_skip), + # so e_m1 is NOT promoted via affinity. Both get a chance in FIFO pass. + # e_other (m2) has no backend → stays. e_m1 gets dispatched in FIFO pass. + self.assertTrue(e_m1.future.done()) + # skip_count must NOT increase further (entry was frozen) + self.assertEqual(e_other.skip_count, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_session_store.py b/tests/test_session_store.py new file mode 100644 index 0000000..6d413a7 --- /dev/null +++ b/tests/test_session_store.py @@ -0,0 +1,112 @@ +import asyncio +import unittest +from unittest.mock import patch + +from llamacpp_ha.session_store import SessionStore, compute_prefix_hash + + +class TestComputePrefixHash(unittest.TestCase): + def test_same_messages_same_hash(self): + msgs = [{"role": "user", "content": "hello"}] + self.assertEqual(compute_prefix_hash(msgs), compute_prefix_hash(msgs)) + + def test_different_messages_different_hash(self): + a = [{"role": "user", "content": "hello"}] + b = [{"role": "user", "content": "world"}] + self.assertNotEqual(compute_prefix_hash(a), compute_prefix_hash(b)) + + def test_empty_messages(self): + h = compute_prefix_hash([]) + self.assertIsInstance(h, str) + self.assertEqual(len(h), 16) + + +class TestSessionStore(unittest.IsolatedAsyncioTestCase): + async def test_get_or_create_new(self): + store = SessionStore(ttl=300.0) + session = await store.get_or_create("abc") + self.assertEqual(session.session_id, "abc") + self.assertIsNone(session.model_id) + + async def test_get_or_create_existing(self): + store = SessionStore(ttl=300.0) + s1 = await store.get_or_create("abc") + await store.update("abc", model_id="llama3") + s2 = await store.get_or_create("abc") + self.assertEqual(s2.model_id, "llama3") + + async def test_update_model_and_messages(self): + store = SessionStore(ttl=300.0) + await store.get_or_create("s1") + msgs = [{"role": "user", "content": "hi"}] + await store.update("s1", model_id="m1", messages=msgs, preferred_backend="http://b1") + pref = await store.get_preferred_backend("s1") + self.assertEqual(pref, "http://b1") + + async def test_affinity_hit(self): + store = SessionStore(ttl=300.0) + await store.get_or_create("s1") + await store.update("s1", preferred_backend="http://b2") + pref = await store.get_preferred_backend("s1") + self.assertEqual(pref, "http://b2") + + async def test_affinity_miss_unknown_session(self): + store = SessionStore(ttl=300.0) + pref = await store.get_preferred_backend("nonexistent") + self.assertIsNone(pref) + + async def test_ttl_expiry(self): + store = SessionStore(ttl=0.05) + await store.get_or_create("s1") + await asyncio.sleep(0.1) + pref = await store.get_preferred_backend("s1") + self.assertIsNone(pref) + + async def test_expire_removes_stale(self): + store = SessionStore(ttl=0.05) + await store.get_or_create("s1") + await store.get_or_create("s2") + await asyncio.sleep(0.1) + removed = await store.expire() + self.assertEqual(removed, 2) + count = await store.count() + self.assertEqual(count, 0) + + async def test_count_active(self): + store = SessionStore(ttl=300.0) + await store.get_or_create("s1") + await store.get_or_create("s2") + count = await store.count() + self.assertEqual(count, 2) + + async def test_count_by_model(self): + store = SessionStore(ttl=300.0) + await store.get_or_create("s1") + await store.get_or_create("s2") + await store.get_or_create("s3") + await store.update("s1", model_id="m1") + await store.update("s2", model_id="m1") + await store.update("s3", model_id="m2") + by_model = await store.count_by_model() + self.assertEqual(by_model["m1"], 2) + self.assertEqual(by_model["m2"], 1) + + async def test_update_nonexistent_session_no_error(self): + store = SessionStore(ttl=300.0) + await store.update("nope", model_id="m1") # should not raise + + async def test_concurrent_access(self): + store = SessionStore(ttl=300.0) + + async def worker(i): + sid = f"s{i}" + await store.get_or_create(sid) + await store.update(sid, model_id="m", preferred_backend=f"http://b{i}") + + await asyncio.gather(*[worker(i) for i in range(20)]) + count = await store.count() + self.assertEqual(count, 20) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_slot_tracker.py b/tests/test_slot_tracker.py new file mode 100644 index 0000000..58adf59 --- /dev/null +++ b/tests/test_slot_tracker.py @@ -0,0 +1,187 @@ +import asyncio +import unittest + +from llamacpp_ha.slot_tracker import SlotTracker + + +class TestSlotTracker(unittest.IsolatedAsyncioTestCase): + async def test_acquire_when_free(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 2) + await tracker.acquire("http://b") + acquired, total = tracker.usage("http://b") + self.assertEqual(acquired, 1) + self.assertEqual(total, 2) + + async def test_has_free_slot(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 1) + self.assertTrue(tracker.has_free_slot("http://b")) + await tracker.acquire("http://b") + self.assertFalse(tracker.has_free_slot("http://b")) + + async def test_timeout_when_full(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 1) + await tracker.acquire("http://b") + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.05): + await tracker.acquire("http://b") + + async def test_release_unblocks_waiter(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 1) + await tracker.acquire("http://b") + + results = [] + + async def waiter(): + async with asyncio.timeout(2.0): + await tracker.acquire("http://b") + results.append(True) + + task = asyncio.create_task(waiter()) + await asyncio.sleep(0.05) + await tracker.release("http://b") + await task + self.assertEqual(results, [True]) + + async def test_release_below_zero(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 1) + await tracker.release("http://b") + acquired, _ = tracker.usage("http://b") + self.assertEqual(acquired, 0) + + def test_set_capacity_increase(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 1) + tracker.set_capacity("http://b", 3) + _, total = tracker.usage("http://b") + self.assertEqual(total, 3) + + def test_unknown_url_defaults(self): + tracker = SlotTracker() + self.assertTrue(tracker.has_free_slot("http://unknown")) + acquired, total = tracker.usage("http://unknown") + self.assertEqual(acquired, 0) + self.assertEqual(total, 1) + + async def test_acquire_zero_timeout_succeeds_then_fails(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 1) + async with asyncio.timeout(0): + await tracker.acquire("http://b") + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0): + await tracker.acquire("http://b") + + async def test_release_decrements(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 2) + await tracker.acquire("http://b") + await tracker.acquire("http://b") + acquired, _ = tracker.usage("http://b") + self.assertEqual(acquired, 2) + await tracker.release("http://b") + acquired, _ = tracker.usage("http://b") + self.assertEqual(acquired, 1) + + # ------------------------------------------------------------------ + # Model-aware tests + # ------------------------------------------------------------------ + + def test_can_accept_respects_max_models(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 4) + tracker.set_max_models("http://b", 1) + self.assertTrue(tracker.can_accept("http://b", "model-a")) + + async def test_max_models_blocks_second_model(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 4) + tracker.set_max_models("http://b", 1) + await tracker.acquire("http://b", "model-a") + # model-a is still accepted (same model, slot available) + self.assertTrue(tracker.can_accept("http://b", "model-a")) + # model-b is blocked (max_models=1 already reached) + self.assertFalse(tracker.can_accept("http://b", "model-b")) + + async def test_max_models_unblocks_after_release(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 4) + tracker.set_max_models("http://b", 1) + await tracker.acquire("http://b", "model-a") + self.assertFalse(tracker.can_accept("http://b", "model-b")) + await tracker.release("http://b", "model-a") + self.assertTrue(tracker.can_accept("http://b", "model-b")) + + async def test_active_model_set(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 4) + self.assertEqual(tracker.active_model_set("http://b"), frozenset()) + await tracker.acquire("http://b", "model-a") + self.assertEqual(tracker.active_model_set("http://b"), frozenset({"model-a"})) + await tracker.acquire("http://b", "model-b") + self.assertEqual( + tracker.active_model_set("http://b"), frozenset({"model-a", "model-b"}) + ) + await tracker.release("http://b", "model-a") + self.assertEqual(tracker.active_model_set("http://b"), frozenset({"model-b"})) + + async def test_acquire_tracks_active_models(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 4) + await tracker.acquire("http://b", "model-a") + await tracker.acquire("http://b", "model-a") + acquired, _ = tracker.usage("http://b") + self.assertEqual(acquired, 2) + self.assertEqual(tracker.active_model_set("http://b"), frozenset({"model-a"})) + await tracker.release("http://b", "model-a") + self.assertEqual(tracker.active_model_set("http://b"), frozenset({"model-a"})) + await tracker.release("http://b", "model-a") + self.assertEqual(tracker.active_model_set("http://b"), frozenset()) + + async def test_reset_acquired_clears_state(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 2) + await tracker.acquire("http://b", "model-a") + await tracker.acquire("http://b", "model-a") + acquired, _ = tracker.usage("http://b") + self.assertEqual(acquired, 2) + await tracker.reset_acquired("http://b") + acquired, _ = tracker.usage("http://b") + self.assertEqual(acquired, 0) + self.assertEqual(tracker.active_model_set("http://b"), frozenset()) + + async def test_reset_acquired_unblocks_waiters(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 1) + tracker.set_max_models("http://b", 1) + await tracker.acquire("http://b", "model-a") + + unblocked = [] + + async def waiter(): + async with asyncio.timeout(2.0): + await tracker.acquire("http://b", "model-b") + unblocked.append(True) + + task = asyncio.create_task(waiter()) + await asyncio.sleep(0.05) + self.assertFalse(unblocked) + await tracker.reset_acquired("http://b") + await task + self.assertEqual(unblocked, [True]) + + async def test_max_models_none_allows_any(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 4) + tracker.set_max_models("http://b", None) + await tracker.acquire("http://b", "model-a") + self.assertTrue(tracker.can_accept("http://b", "model-b")) + self.assertTrue(tracker.can_accept("http://b", "model-c")) + + +if __name__ == "__main__": + unittest.main()