first commit

2026-05-17 09:54:18 +02:00
commit 7344aa4ef4
32 changed files with 3921 additions and 0 deletions


@@ -0,0 +1,7 @@
{
  "permissions": {
    "allow": [
      "Bash(python -m pytest)"
    ]
  }
}

1
.gitignore vendored Normal file

@@ -0,0 +1 @@
__pycache__

70
AGENTS.md Normal file

@@ -0,0 +1,70 @@
# llamacpp-ha — Agent instructions
## Project overview
`llamacpp-ha` is a slot-aware load balancer for llama.cpp servers. It exposes a single OpenAI-compatible HTTP API and distributes inference requests across multiple backends using a global request queue, per-backend slot tracking, and session affinity. The codebase is pure Python 3.13 async with FastAPI + aiohttp.
## Commands
```bash
# Install (editable, with test deps)
pip install -e ".[test]"
# Run all tests
python -m pytest
# Run a single test file
python -m pytest tests/test_scheduler.py -v
# Start the proxy
llamacpp-ha --config config.json
```
## Architecture
All source lives in `src/llamacpp_ha/`. Module responsibilities:
| Module | Responsibility |
|---|---|
| `config.py` | `BackendConfig` + `ProxyConfig` (pydantic-settings, `LLAMACPP_HA_` env prefix) |
| `registry.py` | Polls `/health`, `/v1/models`, `/slots` on each backend; maintains live/dead state |
| `slot_tracker.py` | Per-backend `asyncio.Condition`; `acquire()` blocks until a slot is free |
| `queue.py` | FIFO `asyncio.Future` queue; `asyncio.Event` wakeup for the scheduler |
| `scheduler.py` | Drains the queue, resolves futures with a chosen `BackendState`; `notify_slot_released()` triggers re-drain |
| `policies.py` | `RoundRobinPolicy` — per-model atomic counter over `BackendCandidate` list |
| `session_store.py` | SHA-256 prefix hash → preferred backend; TTL eviction |
| `forwarder.py` | aiohttp outbound request; streaming SSE via `iter_chunked`; releases slot in `finally` |
| `middleware.py` | `ApiKeyMiddleware`: `Bearer` token validation; `/monitor` and `/monitor/data` are exempt |
| `monitor.py` | `ProxyStats` dataclass + `build_router()` for `/monitor` (HTML) and `/monitor/data` (JSON) |
| `proxy.py` | `create_app()` — wires all components; registers FastAPI routes; lifespan manages tasks |
| `__main__.py` | CLI entry point (`llamacpp-ha` script) |
## Critical design invariants
**Slot release must be awaited, not task-spawned.** `SlotTracker.release()` is `async` and must be `await`ed directly. Scheduling it as a task causes a race where the scheduler checks `has_free_slot()` before the release runs. This applies in `forwarder.py` (both the streaming `finally` block and the error path).
**`QueueEntry.future` is not created automatically.** The `future` field defaults to `None`, so callers must create it explicitly: `entry.future = asyncio.get_running_loop().create_future()`. Never use `asyncio.get_event_loop()`; inside coroutines and callbacks, `asyncio.get_running_loop()` is the preferred way to obtain the loop.
**`build_router()` creates a new `APIRouter` instance each call.** The router must not be a module-level singleton; doing so causes route handlers to share state across test cases.
**`ProxyStats` is per-app.** Created in `create_app()` and passed into `build_router()`. Never use module-level counters or timestamps in `monitor.py`.
**`scheduler.start()` is `def`, not `async def`.** It creates a task; it does not need to be awaited.
**`asyncio.shield` in `_dispatch`.** The queue entry future is shielded so that a client timeout doesn't cancel the backend dispatch. On timeout, the entry is removed from the queue and the future is cancelled manually.
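
The shield-and-timeout invariant can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in `proxy.py`: `wait_for_backend` is a hypothetical helper, the future resolves to a plain URL string instead of a `BackendState`, and queue removal is only indicated by a comment.

```python
from __future__ import annotations

import asyncio


async def wait_for_backend(future: asyncio.Future, timeout: float) -> str | None:
    """Wait for the scheduler to resolve `future`; return None on timeout."""
    try:
        # shield() stops wait_for() from cancelling `future` itself on timeout,
        # so the caller decides what happens to it afterwards.
        return await asyncio.wait_for(asyncio.shield(future), timeout)
    except asyncio.TimeoutError:
        # Hypothetical cleanup: the real code also removes the queue entry here.
        future.cancel()
        return None


async def demo() -> tuple[str | None, str | None, bool]:
    loop = asyncio.get_running_loop()

    # Case 1: the "scheduler" resolves the future before the timeout.
    resolved = loop.create_future()
    loop.call_later(0.01, resolved.set_result, "http://localhost:8081")
    ok = await wait_for_backend(resolved, timeout=1.0)

    # Case 2: nothing resolves the future, so the wait times out.
    never = loop.create_future()
    timed_out = await wait_for_backend(never, timeout=0.01)
    return ok, timed_out, never.cancelled()


result = asyncio.run(demo())
```

The manual `future.cancel()` matters because the shield leaves the inner future pending after a timeout; without it, a later scheduler pass could still resolve a future nobody is waiting on.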
## Testing
- Tests are in `tests/` using `unittest` (not pytest-native style), with `IsolatedAsyncioTestCase` for async tests.
- `[tool.pytest.ini_options]` in `pyproject.toml` sets `asyncio_mode = "auto"` so async test methods run without additional decorators.
- Integration tests (`test_integration.py`) use `with TestClient(app) as client:` — the context manager form is required to run the FastAPI lifespan (which starts the scheduler and opens the aiohttp session).
- Test helpers: `_entry(**kwargs)` creates a `QueueEntry` with a future attached via `asyncio.get_running_loop().create_future()`.
- All 113 tests must pass: `python -m pytest` should exit 0.
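
A test in this style might look like the sketch below. The `QueueEntry` here is a hypothetical stand-in with only two fields (the real class lives in `queue.py`), and the suite is run explicitly so the snippet is self-contained; in the repository, `python -m pytest` would collect the test class instead.

```python
from __future__ import annotations

import asyncio
import unittest
from dataclasses import dataclass


@dataclass
class QueueEntry:
    """Hypothetical stand-in for the real QueueEntry; only the fields used here."""

    model_id: str = ""
    future: asyncio.Future | None = None


def _entry(**kwargs) -> QueueEntry:
    # Must be called while a loop is running, which IsolatedAsyncioTestCase guarantees.
    entry = QueueEntry(**kwargs)
    entry.future = asyncio.get_running_loop().create_future()
    return entry


class EntryTest(unittest.IsolatedAsyncioTestCase):
    async def test_future_is_pending_then_resolves(self) -> None:
        entry = _entry(model_id="llama3")
        self.assertFalse(entry.future.done())
        entry.future.set_result("http://localhost:8081")
        self.assertEqual(await entry.future, "http://localhost:8081")


suite = unittest.defaultTestLoader.loadTestsFromTestCase(EntryTest)
outcome = unittest.TextTestRunner(verbosity=0).run(suite)
```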
## Style
- Python 3.13+; use `from __future__ import annotations` in all modules.
- No comments unless the WHY is non-obvious (hidden constraint, asyncio race, workaround).
- No docstrings on internal helpers.
- Type annotations on all public functions and dataclass fields.
- `asyncio.get_running_loop()` everywhere; never `asyncio.get_event_loop()`.

1
CLAUDE.md Normal file

@@ -0,0 +1 @@
@AGENTS.md

225
README.md Normal file

@@ -0,0 +1,225 @@
# llamacpp-ha
Smart load balancer for [llama.cpp](https://github.com/ggerganov/llama.cpp) servers. Presents a single OpenAI-compatible API endpoint while distributing inference requests across multiple backends with slot-aware scheduling, session affinity, model-affinity reordering, and a live monitor page.
## Features
- **OpenAI-compatible API** — drop-in replacement for any client using `/v1/chat/completions`, `/v1/completions`, `/v1/embeddings`, `/v1/models`, etc.
- **Slot-aware scheduling** — tracks each backend's inference slot capacity (from `/slots`) and queues requests instead of sending them to overloaded backends
- **Preemption prevention** — optional `max_models` limit per backend prevents llama.cpp from evicting a running model's KV cache when a new model request arrives
- **Model-affinity reordering** — queued requests for an already-loaded model can be promoted ahead of waiting requests (configurable with starvation protection)
- **Session affinity** — routes follow-up turns in a conversation back to the same backend, improving KV-cache reuse
- **Round-robin policy** — distributes load evenly across backends per model
- **Backend health polling** — continuously polls `/health` and `/v1/models`; removes dead backends from rotation automatically; resets slot counters on recovery
- **Request queue** — FIFO queue with configurable timeout; returns 503 with a JSON error body when no slot becomes available in time
- **API key auth** — optional `Bearer` token validation; the proxy rewrites outbound auth to match each backend's own key
- **Monitor page** — self-contained HTML dashboard at `/monitor` (no CDN dependencies); auto-refreshes every 3 seconds
- **Catch-all proxy** — non-inference paths are forwarded best-effort to a live backend
## Requirements
- Python 3.13+
- llama.cpp server(s) with the `/slots` endpoint enabled (started with `--slots`)
## Installation
```bash
pip install .
# or in editable mode for development:
pip install -e ".[test]"
```
## Quick start
Copy the example config and edit it:
```bash
cp config.json.example config.json
llamacpp-ha --config config.json
```
The proxy starts on `http://0.0.0.0:8080` by default.
## Configuration
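Configuration is a JSON file. Fields can also be supplied through environment variables with the `LLAMACPP_HA_` prefix (nested fields use `__` as the delimiter). Because `ProxyConfig.from_file()` passes the file's values as initializer arguments, which pydantic-settings ranks above environment variables, an environment variable takes effect only for a field the file omits.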
```json
{
  "host": "0.0.0.0",
  "port": 8080,
  "api_keys": ["your-secret-key"],
  "poll_interval": 5,
  "slot_wait_timeout": 30,
  "session_idle_ttl": 300,
  "default_max_models": 1,
  "max_queue_skip": 4,
  "backends": [
    {
      "url": "http://localhost:8081",
      "api_key": null,
      "model_ids": [],
      "max_models": 1
    },
    {
      "url": "http://localhost:8082",
      "api_key": "backend-secret",
      "model_ids": ["llama3"]
    }
  ]
}
```
### Global fields
| Field | Default | Description |
|---|---|---|
| `host` | `0.0.0.0` | Listen address |
| `port` | `8080` | Listen port |
| `api_keys` | `[]` | Accepted bearer tokens. Empty = no auth. |
| `poll_interval` | `5.0` | Seconds between backend health polls |
| `slot_wait_timeout` | `30.0` | Max seconds a request waits for a free slot |
| `session_idle_ttl` | `300.0` | Seconds before an idle session is evicted |
| `default_max_models` | `null` | Maximum concurrent models per backend (null = unlimited). Applied to backends that do not set their own `max_models`. |
| `max_queue_skip` | `0` | How many times a queued request may be bypassed by a model-affinity promotion before it is frozen at head-of-line. `0` disables reordering. |
### Per-backend fields
| Field | Default | Description |
|---|---|---|
| `url` | required | Backend base URL |
| `api_key` | `null` | Injected as `Authorization: Bearer <key>` on outbound requests; client key is stripped |
| `model_ids` | `[]` | Override the model list instead of polling `/v1/models` |
| `max_models` | `null` | Maximum concurrent distinct models on this backend. Overrides `default_max_models`. `null` = unlimited. |
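
The fallback between the per-backend `max_models` and the global `default_max_models` can be written as a one-line rule. The sketch below mirrors the logic in `_init_slot_tracker` from `proxy.py`; the function name `effective_max_models` is hypothetical and used only for illustration.

```python
from __future__ import annotations


def effective_max_models(backend_max: int | None, default_max: int | None) -> int | None:
    """Return the limit the proxy applies to one backend (None = unlimited)."""
    # A per-backend value wins whenever it is set; otherwise use the global default.
    return backend_max if backend_max is not None else default_max


limits = [
    effective_max_models(None, 1),     # backend unset, global default applies
    effective_max_models(2, 1),        # backend override wins
    effective_max_models(None, None),  # nothing set anywhere: unlimited
]
```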
### Environment variable overrides
```bash
LLAMACPP_HA_PORT=9090 llamacpp-ha --config config.json
LLAMACPP_HA_API_KEYS='["key1","key2"]' llamacpp-ha --config config.json
```
## Preemption prevention (`max_models`)
llama.cpp evicts the current model's KV cache when a different model is loaded, which can interrupt an in-flight request. Setting `max_models: 1` on a backend tells the proxy to block requests for a second model until all slots for the first model are released.
```json
{
  "default_max_models": 1,
  "backends": [
    {"url": "http://localhost:8081"},
    {"url": "http://localhost:8082"}
  ]
}
```
With this configuration each backend serves at most one model at a time. When both backends are busy with different models, a request for a third model waits in the queue until one backend has released all of its slots.
Set `default_max_models: 2` (or higher) to allow two models to share a backend's slots simultaneously when hardware permits.
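
The admission rule described above can be sketched as a small synchronous check. This is an illustration under assumptions, not the real `SlotTracker` (which is async and blocks on an `asyncio.Condition`): the class `ModelGate` and its methods are hypothetical names.

```python
from __future__ import annotations

from collections import Counter


class ModelGate:
    """Hypothetical admission check for a single backend."""

    def __init__(self, capacity: int, max_models: int | None) -> None:
        self.capacity = capacity
        self.max_models = max_models
        self.in_flight: Counter[str] = Counter()  # model_id -> slots in use

    def can_admit(self, model_id: str) -> bool:
        if sum(self.in_flight.values()) >= self.capacity:
            return False  # no free slot at all
        if self.max_models is None or model_id in self.in_flight:
            return True  # unlimited, or the model is already active here
        # A new model is admitted only while the distinct-model count is below the cap.
        return len(self.in_flight) < self.max_models

    def acquire(self, model_id: str) -> None:
        if not self.can_admit(model_id):
            raise RuntimeError("backend cannot admit this model right now")
        self.in_flight[model_id] += 1

    def release(self, model_id: str) -> None:
        self.in_flight[model_id] -= 1
        if self.in_flight[model_id] <= 0:
            del self.in_flight[model_id]  # model no longer active on this backend


gate = ModelGate(capacity=4, max_models=1)
gate.acquire("llama3")
same_model_ok = gate.can_admit("llama3")
other_model_blocked = not gate.can_admit("mistral")
gate.release("llama3")
other_model_ok_after_release = gate.can_admit("mistral")
```

With `max_models=1`, a second model is blocked even though three slots are free, and becomes admissible only once the first model has released every slot it holds.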
## Model-affinity reordering (`max_queue_skip`)
By default the queue is strict FIFO. Setting `max_queue_skip` to a positive integer enables a two-phase dispatch:
1. **Affinity pass** — scans the queue for requests whose model is already active on a free backend. Matching requests are promoted and dispatched immediately, bypassing earlier entries.
2. **FIFO pass** — remaining entries are dispatched in arrival order.
Each time a request is bypassed, its `skip_count` is incremented. Once `skip_count` reaches `max_queue_skip` the entry is frozen at head-of-line and blocks the affinity pass — preventing indefinite starvation.
```json
{
"max_queue_skip": 4
}
```
This trades strict fairness for throughput: a warm model serves back-to-back requests without stalling for a cold-start on another model, while the `max_queue_skip` cap guarantees every request is eventually served.
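
The two-phase dispatch and the skip cap can be sketched as follows. This is a simplified model, not the scheduler's code: `Entry`, `pick_next`, and the `warm_models` set are hypothetical, and the sketch assumes a slot is available for whichever entry is chosen.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Entry:
    request_id: str
    model_id: str
    skip_count: int = 0


def pick_next(queue: list[Entry], warm_models: set[str], max_queue_skip: int) -> Entry | None:
    """Remove and return the next entry to dispatch, or None if the queue is empty."""
    if not queue:
        return None
    if max_queue_skip > 0:
        # Affinity pass: promote the first entry whose model is already warm.
        for pos, entry in enumerate(queue):
            if entry.model_id in warm_models:
                for bypassed in queue[:pos]:
                    bypassed.skip_count += 1
                return queue.pop(pos)
            if entry.skip_count >= max_queue_skip:
                break  # frozen entry: nothing behind it may be promoted
    # FIFO pass: fall back to arrival order.
    return queue.pop(0)


queue = [
    Entry("a", "mistral"),
    Entry("b", "mistral"),
    Entry("c", "llama3"),
    Entry("d", "llama3"),
    Entry("e", "llama3"),
]
order = []
while (nxt := pick_next(queue, warm_models={"llama3"}, max_queue_skip=2)) is not None:
    order.append(nxt.request_id)
```

In this run `c` and `d` are promoted past `a` and `b`; after two bypasses `a` is frozen, so `e` has to wait behind both cold-model requests even though its model is warm.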
## CLI reference
```
llamacpp-ha [--config PATH] [--host HOST] [--port PORT] [--log-level LEVEL]
```
`--config` defaults to `config.json` in the current directory. `--host` and `--port` override the values in the config file.
## API endpoints
| Method | Path | Description |
|---|---|---|
| `GET` | `/health` | Returns `{"status":"ok"}` if at least one backend is live, 503 otherwise |
| `GET` | `/v1/models` | Aggregated model list across all live backends |
| `POST` | `/v1/chat/completions` | Slot-gated, session-aware inference (streaming supported) |
| `POST` | `/v1/completions` | Same as above |
| `POST` | `/v1/embeddings` | Slot-gated pass-through |
| `POST` | `/v1/images/*`, `/v1/audio/*` | Slot-gated pass-through |
| `*` | `/*` | Best-effort forward to any live backend |
| `GET` | `/monitor` | HTML dashboard |
| `GET` | `/monitor/data` | Dashboard data as JSON (exempt from API key auth) |
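
A client call against the chat endpoint could be built as in this sketch, which uses only the standard library. The helper `build_chat_request` is hypothetical, the URL and key are the placeholder values from the example config, and the request is constructed but not sent.

```python
from __future__ import annotations

import json
import urllib.request


def build_chat_request(
    base_url: str,
    api_key: str,
    model: str,
    user_message: str,
    session_id: str | None = None,
) -> urllib.request.Request:
    """Build (but do not send) a POST to /v1/chat/completions."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": user_message}],
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    if session_id:
        # Echoing the session ID asks the proxy to reuse the same backend.
        headers["X-Session-ID"] = session_id
    return urllib.request.Request(
        base_url.rstrip("/") + "/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers=headers,
        method="POST",
    )


req = build_chat_request(
    "http://localhost:8080", "your-secret-key", "llama3", "Hello", session_id="abc123"
)
```

Passing `req` to `urllib.request.urlopen` would send it; any OpenAI-compatible client library pointed at the proxy's base URL should work the same way.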
### Session affinity
The proxy assigns a session ID to every request. It is sent back via both a cookie (`x-llm-session`) and a response header (`X-Session-ID`). Clients can echo either on subsequent requests to pin their conversation to the same backend. The affinity record expires after `session_idle_ttl` seconds of inactivity.
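
The idle-expiry behaviour can be sketched with a small TTL map. This is not the real `SessionStore` (which is async and has its own keying scheme): `AffinityStore`, its method names, and the injectable `clock` parameter are hypothetical and exist only to make the timing visible.

```python
from __future__ import annotations

import time
from typing import Callable


class AffinityStore:
    """Hypothetical TTL map from session ID to preferred backend URL."""

    def __init__(self, idle_ttl: float, clock: Callable[[], float] = time.monotonic) -> None:
        self._idle_ttl = idle_ttl
        self._clock = clock
        self._entries: dict[str, tuple[str, float]] = {}

    def remember(self, session_id: str, backend_url: str) -> None:
        self._entries[session_id] = (backend_url, self._clock())

    def preferred(self, session_id: str) -> str | None:
        record = self._entries.get(session_id)
        if record is None:
            return None
        backend_url, last_seen = record
        if self._clock() - last_seen > self._idle_ttl:
            del self._entries[session_id]  # idle too long: drop the pin
            return None
        self._entries[session_id] = (backend_url, self._clock())  # refresh on use
        return backend_url


now = [0.0]
store = AffinityStore(idle_ttl=300.0, clock=lambda: now[0])
store.remember("abc123", "http://localhost:8081")

now[0] = 200.0
first = store.preferred("abc123")   # within the TTL, and refreshes the timer
now[0] = 450.0
second = store.preferred("abc123")  # 250 s since the refresh: still pinned
now[0] = 800.0
third = store.preferred("abc123")   # 350 s idle: expired
```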
### Streaming
SSE streaming responses (`text/event-stream`) are passed through transparently. The backend slot is held for the duration of the stream and released when the final chunk is sent.
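
Holding the slot for the lifetime of the stream comes down to releasing it in the generator's `finally` block. The sketch below shows the pattern with stand-ins: `stream_and_release`, `backend_chunks`, and `release` are hypothetical, whereas the real forwarder iterates an aiohttp response and calls `SlotTracker.release()`.

```python
from __future__ import annotations

import asyncio
from typing import AsyncIterator, Awaitable, Callable


async def stream_and_release(
    chunks: AsyncIterator[bytes],
    release: Callable[[], Awaitable[None]],
) -> AsyncIterator[bytes]:
    """Pass chunks through and release the slot when the stream ends, however it ends."""
    try:
        async for chunk in chunks:
            yield chunk
    finally:
        # Awaited directly rather than spawned as a task, so the slot is free
        # before anything else observes the end of the stream.
        await release()


async def demo() -> tuple[list[bytes], list[str]]:
    events: list[str] = []

    async def backend_chunks() -> AsyncIterator[bytes]:
        yield b"data: one\n\n"
        yield b"data: two\n\n"

    async def release() -> None:
        events.append("released")

    received = [chunk async for chunk in stream_and_release(backend_chunks(), release)]
    return received, events


received, events = asyncio.run(demo())
```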
## Monitor
Open `http://localhost:8080/monitor` in a browser. The page polls `/monitor/data` every 3 seconds and shows:
- Uptime, total requests served, queue depth, active sessions, live backend count
- Per-backend: URL, live/dead status, models, slot usage (`acquired/total`), time since last poll
- Current queue contents with wait time and estimated token count
- Active sessions grouped by model
The monitor page and its data endpoint are exempt from API key authentication.
## Architecture
```
client
  │
  ▼
ApiKeyMiddleware
  │
  ▼
FastAPI app (proxy.py)
  ├── GET  /v1/models      ──► BackendRegistry.get_all_models()
  ├── POST /v1/chat/...    ──► RequestQueue ──► Scheduler ──► SlotTracker
  │                                                 │
  │                                  BackendState ◄─┘
  │                                       │
  │                                 forwarder.py ──► aiohttp ──► backend
  ├── GET  /health
  ├── GET  /monitor[/data] ──► monitor.py
  └── /*   catch-all       ──► forwarder.forward_best_effort()

BackendRegistry   polls /health + /v1/models + /slots every poll_interval;
                  calls on_backend_recovered when a dead backend comes back live
SlotTracker       asyncio.Condition per backend; acquire blocks until slot free;
                  enforces max_models (preemption prevention)
SessionStore      SHA-256 prefix hash → preferred backend URL; TTL eviction
RequestQueue      FIFO asyncio.Future queue; asyncio.Event for wakeup
Scheduler         two-phase dispatch (affinity pass + FIFO); N-skip reordering
RoundRobinPolicy  per-model atomic counter for backend selection
```
## Development
```bash
# Install with test dependencies
pip install -e ".[test]"
# Run all tests
python -m pytest
# Run a specific module
python -m pytest tests/test_scheduler.py -v
```
Tests use `unittest.IsolatedAsyncioTestCase` for async tests and Starlette's `TestClient` as a context manager for integration tests that require the full lifespan (aiohttp session, scheduler, registry).
## License
MIT

20
config.json.example Normal file

@@ -0,0 +1,20 @@
{
  "host": "0.0.0.0",
  "port": 8080,
  "api_keys": ["your-secret-key"],
  "poll_interval": 5,
  "slot_wait_timeout": 30,
  "session_idle_ttl": 300,
  "backends": [
    {
      "url": "http://localhost:8081",
      "api_key": null,
      "model_ids": []
    },
    {
      "url": "http://localhost:8082",
      "api_key": "backend-secret",
      "model_ids": ["llama3"]
    }
  ]
}

33
pyproject.toml Normal file

@@ -0,0 +1,33 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "llamacpp-ha"
version = "0.1.0"
description = "Smart load balancer for llama.cpp servers"
requires-python = ">=3.13"
dependencies = [
"fastapi>=0.115",
"uvicorn[standard]>=0.32",
"aiohttp>=3.11",
"pydantic-settings>=2.7",
"pydantic>=2.10",
]
[project.optional-dependencies]
test = [
"httpx>=0.28",
"pytest>=8",
"pytest-asyncio>=0.24",
]
[project.scripts]
llamacpp-ha = "llamacpp_ha.__main__:main"
[tool.hatch.build.targets.wheel]
packages = ["src/llamacpp_ha"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]


@@ -0,0 +1,51 @@
from __future__ import annotations

import argparse
import logging
import sys

import uvicorn

from .config import ProxyConfig
from .proxy import create_app


def main() -> None:
    parser = argparse.ArgumentParser(description="llamacpp-ha load balancer")
    parser.add_argument(
        "--config",
        default="config.json",
        help="Path to JSON config file (default: config.json)",
    )
    parser.add_argument("--host", help="Override listen host")
    parser.add_argument("--port", type=int, help="Override listen port")
    parser.add_argument(
        "--log-level",
        default="info",
        choices=["debug", "info", "warning", "error"],
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=args.log_level.upper(),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    overrides = {}
    if args.host:
        overrides["host"] = args.host
    # Compare against None so an explicit `--port 0` is not silently ignored.
    if args.port is not None:
        overrides["port"] = args.port

    try:
        config = ProxyConfig.from_file(args.config, **overrides)
    except FileNotFoundError:
        print(f"Config file not found: {args.config}", file=sys.stderr)
        sys.exit(1)

    app = create_app(config)
    uvicorn.run(app, host=config.host, port=config.port, log_level=args.log_level)


if __name__ == "__main__":
    main()

54
src/llamacpp_ha/config.py Normal file

@@ -0,0 +1,54 @@
from __future__ import annotations

import json
from typing import Any

from pydantic import BaseModel, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict


class BackendConfig(BaseModel):
    url: str
    api_key: str | None = None
    model_ids: list[str] = []
    # Maximum number of distinct models allowed in-flight simultaneously on this
    # backend. Set to 1 for llama.cpp to prevent mid-request model preemption.
    # None means the backend handles concurrency itself (no proxy-level limit).
    max_models: int | None = None

    @field_validator("url")
    @classmethod
    def strip_trailing_slash(cls, v: str) -> str:
        return v.rstrip("/")


class ProxyConfig(BaseSettings):
    model_config = SettingsConfigDict(
        env_prefix="LLAMACPP_HA_",
        env_nested_delimiter="__",
    )

    host: str = "0.0.0.0"
    port: int = 8080
    api_keys: list[str] = []
    poll_interval: float = 5.0
    slot_wait_timeout: float = 30.0
    session_idle_ttl: float = 300.0
    default_slot_capacity: int = 1
    # Fallback max_models applied to any backend that does not set its own.
    # None = unlimited. Set to 1 globally when all backends are llama.cpp.
    default_max_models: int | None = None
    # How many queue positions a model-affinity request may skip ahead.
    # 0 = pure FIFO (default). N > 0 enables reordering: the scheduler looks
    # for a request matching an already-active model and can promote it up to N
    # positions; each bypassed entry accumulates a skip count and is immune to
    # further skipping once it reaches N.
    max_queue_skip: int = 0
    backends: list[BackendConfig] = []

    @classmethod
    def from_file(cls, path: str, **overrides: Any) -> "ProxyConfig":
        with open(path) as f:
            data = json.load(f)
        data.update(overrides)
        # File values and CLI overrides are passed as init kwargs, which
        # pydantic-settings ranks above environment variables.
        return cls(**data)


@@ -0,0 +1,160 @@
from __future__ import annotations
import logging
from typing import AsyncIterator
import aiohttp
from fastapi import Request
from fastapi.responses import Response, StreamingResponse
from starlette.datastructures import Headers
from .registry import BackendState
from .scheduler import Scheduler
from .slot_tracker import SlotTracker
log = logging.getLogger(__name__)
_STREAM_CHUNK = 8192
_HOP_BY_HOP = frozenset(
    [
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-encoding",
        "content-length",
    ]
)


def _forward_headers(request: Request, backend: BackendState) -> dict[str, str]:
    headers: dict[str, str] = {}
    for name, value in request.headers.items():
        lowered = name.lower()
        # Drop the client's Authorization here. ASGI header names arrive
        # lowercased, so setting "Authorization" below would otherwise leave the
        # client's "authorization" entry in the dict next to the backend key.
        if lowered in _HOP_BY_HOP or lowered in ("host", "authorization"):
            continue
        headers[name] = value
    if backend.config.api_key:
        headers["Authorization"] = f"Bearer {backend.config.api_key}"
    return headers


async def forward_request(
    request: Request,
    backend: BackendState,
    session: aiohttp.ClientSession,
    slot_tracker: SlotTracker,
    scheduler: Scheduler,
    model_id: str = "",
    path_override: str | None = None,
) -> Response:
    path = path_override or request.url.path
    query = request.url.query
    target = backend.url + path + (f"?{query}" if query else "")
    headers = _forward_headers(request, backend)
    body = await request.body()

    released = False

    async def release_slot() -> None:
        # Guarded so a failed non-streaming read cannot release the slot twice
        # (once in its `finally`, once in the outer `except`).
        nonlocal released
        if released:
            return
        released = True
        await slot_tracker.release(backend.url, model_id)
        scheduler.notify_slot_released()

    try:
        resp_ctx = session.request(
            method=request.method,
            url=target,
            headers=headers,
            data=body if body else None,
            allow_redirects=False,
        )
        resp = await resp_ctx.__aenter__()
        content_type = resp.headers.get("Content-Type", "")
        is_streaming = "text/event-stream" in content_type
        response_headers = {
            k: v
            for k, v in resp.headers.items()
            if k.lower() not in _HOP_BY_HOP
        }
        if is_streaming:

            async def generate() -> AsyncIterator[bytes]:
                try:
                    async for chunk in resp.content.iter_chunked(_STREAM_CHUNK):
                        yield chunk
                finally:
                    await resp_ctx.__aexit__(None, None, None)
                    await release_slot()

            return StreamingResponse(
                generate(),
                status_code=resp.status,
                headers=response_headers,
                media_type=content_type,
            )
        else:
            try:
                data = await resp.read()
            finally:
                await resp_ctx.__aexit__(None, None, None)
                await release_slot()
            return Response(
                content=data,
                status_code=resp.status,
                headers=response_headers,
                media_type=content_type,
            )
    except Exception as exc:
        log.error("Forward error to %s: %s", backend.url, exc)
        await release_slot()
        raise


async def forward_best_effort(
    request: Request,
    registry_backends: list[BackendState],
    session: aiohttp.ClientSession,
) -> Response:
    """Forward without slot gating to any live backend (catch-all paths)."""
    if not registry_backends:
        return Response(content="No live backends", status_code=503)
    backend = registry_backends[0]
    path = request.url.path
    query = request.url.query
    target = backend.url + path + (f"?{query}" if query else "")
    headers = _forward_headers(request, backend)
    body = await request.body()
    try:
        async with session.request(
            method=request.method,
            url=target,
            headers=headers,
            data=body if body else None,
            allow_redirects=False,
        ) as resp:
            content_type = resp.headers.get("Content-Type", "")
            response_headers = {
                k: v
                for k, v in resp.headers.items()
                if k.lower() not in _HOP_BY_HOP
            }
            data = await resp.read()
            return Response(
                content=data,
                status_code=resp.status,
                headers=response_headers,
                media_type=content_type,
            )
    except Exception as exc:
        log.error("Best-effort forward error: %s", exc)
        return Response(content="Backend error", status_code=502)


@@ -0,0 +1,36 @@
from __future__ import annotations

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

_EXEMPT_PATHS = frozenset(["/monitor", "/monitor/data"])


class ApiKeyMiddleware(BaseHTTPMiddleware):
    def __init__(self, app: ASGIApp, api_keys: list[str]) -> None:
        super().__init__(app)
        self._keys: frozenset[str] = frozenset(api_keys)

    async def dispatch(self, request: Request, call_next) -> Response:
        if not self._keys:
            return await call_next(request)
        if request.url.path in _EXEMPT_PATHS:
            return await call_next(request)
        auth = request.headers.get("Authorization", "")
        if not auth.startswith("Bearer "):
            return JSONResponse(
                {"error": {"message": "Missing or invalid API key", "type": "auth_error"}},
                status_code=401,
            )
        token = auth[len("Bearer "):]
        if token not in self._keys:
            return JSONResponse(
                {"error": {"message": "Invalid API key", "type": "auth_error"}},
                status_code=401,
            )
        return await call_next(request)

208
src/llamacpp_ha/monitor.py Normal file

@@ -0,0 +1,208 @@
from __future__ import annotations
import time
from dataclasses import dataclass, field
from fastapi import APIRouter
from fastapi.responses import HTMLResponse, JSONResponse
from .queue import RequestQueue
from .registry import BackendRegistry
from .session_store import SessionStore
from .slot_tracker import SlotTracker


@dataclass
class ProxyStats:
    """Per-app counters and timing. Created by create_app, passed to build_router."""

    start_time: float = field(default_factory=time.monotonic)
    total_requests: int = 0

    def increment_requests(self) -> None:
        self.total_requests += 1

    def uptime_str(self) -> str:
        secs = int(time.monotonic() - self.start_time)
        h, remainder = divmod(secs, 3600)
        m, s = divmod(remainder, 60)
        return f"{h:02d}:{m:02d}:{s:02d}"


_HTML = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>llamacpp-ha Monitor</title>
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body { font-family: monospace; background: #0d1117; color: #c9d1d9; padding: 20px; }
h1 { color: #58a6ff; margin-bottom: 4px; font-size: 1.4em; }
.subtitle { color: #8b949e; font-size: 0.85em; margin-bottom: 20px; }
h2 { color: #79c0ff; margin: 20px 0 8px; font-size: 1em; text-transform: uppercase; letter-spacing: 1px; }
table { width: 100%; border-collapse: collapse; font-size: 0.85em; }
th { background: #161b22; color: #8b949e; text-align: left; padding: 6px 10px; border-bottom: 1px solid #30363d; }
td { padding: 6px 10px; border-bottom: 1px solid #21262d; }
tr:hover td { background: #161b22; }
.badge { display: inline-block; padding: 2px 8px; border-radius: 10px; font-size: 0.8em; }
.badge-live { background: #0f2c12; color: #3fb950; }
.badge-dead { background: #2c0f0f; color: #f85149; }
.slots { color: #d29922; }
.empty { color: #484f58; font-style: italic; }
#status { float: right; font-size: 0.8em; color: #8b949e; }
.summary { display: flex; gap: 20px; flex-wrap: wrap; margin: 10px 0 20px; }
.stat { background: #161b22; border: 1px solid #30363d; border-radius: 6px; padding: 10px 16px; }
.stat-val { font-size: 1.6em; color: #58a6ff; }
.stat-label { font-size: 0.75em; color: #8b949e; margin-top: 2px; }
</style>
</head>
<body>
<h1>llamacpp-ha <span id="status">loading...</span></h1>
<div class="subtitle">Smart Load Balancer for llama.cpp</div>
<div class="summary">
<div class="stat"><div class="stat-val" id="uptime">-</div><div class="stat-label">Uptime</div></div>
<div class="stat"><div class="stat-val" id="total-req">-</div><div class="stat-label">Requests Served</div></div>
<div class="stat"><div class="stat-val" id="queue-depth">-</div><div class="stat-label">Queue Depth</div></div>
<div class="stat"><div class="stat-val" id="session-count">-</div><div class="stat-label">Active Sessions</div></div>
<div class="stat"><div class="stat-val" id="live-count">-</div><div class="stat-label">Live Backends</div></div>
</div>
<h2>Backends</h2>
<table>
<thead><tr><th>URL</th><th>Status</th><th>Models</th><th>Slots</th><th>Last Poll</th></tr></thead>
<tbody id="backends-body"><tr><td colspan="5" class="empty">Loading...</td></tr></tbody>
</table>
<h2>Queue</h2>
<table>
<thead><tr><th>Request ID</th><th>Model</th><th>Session</th><th>Wait (s)</th><th>Est. Tokens</th></tr></thead>
<tbody id="queue-body"><tr><td colspan="5" class="empty">Queue is empty</td></tr></tbody>
</table>
<h2>Sessions by Model</h2>
<table>
<thead><tr><th>Model</th><th>Active Sessions</th></tr></thead>
<tbody id="sessions-body"><tr><td colspan="2" class="empty">No active sessions</td></tr></tbody>
</table>
<script>
(function() {
function esc(s) {
return String(s).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;');
}
function render(data) {
document.getElementById('uptime').textContent = data.uptime;
document.getElementById('total-req').textContent = data.total_requests;
document.getElementById('queue-depth').textContent = data.queue_depth;
document.getElementById('session-count').textContent = data.session_count;
document.getElementById('live-count').textContent = data.live_backend_count;
const bBody = document.getElementById('backends-body');
if (!data.backends.length) {
bBody.innerHTML = '<tr><td colspan="5" class="empty">No backends configured</td></tr>';
} else {
bBody.innerHTML = data.backends.map(b => {
const badge = b.live
? '<span class="badge badge-live">live</span>'
: '<span class="badge badge-dead">dead</span>';
const models = b.models.length ? esc(b.models.join(', ')) : '<span class="empty">none</span>';
const slots = `<span class="slots">${b.slots_acquired}/${b.slots_total}</span>`;
const age = b.last_poll_age == null ? '<span class="empty">never</span>' : esc(b.last_poll_age.toFixed(1)) + 's';
return `<tr><td>${esc(b.url)}</td><td>${badge}</td><td>${models}</td><td>${slots}</td><td>${age}</td></tr>`;
}).join('');
}
const qBody = document.getElementById('queue-body');
if (!data.queue.length) {
qBody.innerHTML = '<tr><td colspan="5" class="empty">Queue is empty</td></tr>';
} else {
qBody.innerHTML = data.queue.map(e => {
const tok = e.estimated_tokens != null ? esc(e.estimated_tokens) : '<span class="empty">-</span>';
const sid = e.session_id ? esc(e.session_id) : '<span class="empty">-</span>';
return `<tr><td>${esc(e.request_id.slice(0,12))}</td><td>${esc(e.model_id||'-')}</td><td>${sid}</td><td>${esc(e.wait_seconds.toFixed(2))}</td><td>${tok}</td></tr>`;
}).join('');
}
const sBody = document.getElementById('sessions-body');
const sbm = data.sessions_by_model;
const keys = Object.keys(sbm);
if (!keys.length) {
sBody.innerHTML = '<tr><td colspan="2" class="empty">No active sessions</td></tr>';
} else {
sBody.innerHTML = keys.map(m =>
`<tr><td>${esc(m)}</td><td>${esc(sbm[m])}</td></tr>`
).join('');
}
document.getElementById('status').textContent = 'updated ' + new Date().toLocaleTimeString();
}
function poll() {
fetch('/monitor/data')
.then(r => r.json())
.then(render)
.catch(err => {
document.getElementById('status').textContent = 'error: ' + err.message;
});
}
poll();
setInterval(poll, 3000);
})();
</script>
</body>
</html>
"""


def build_router(
    registry: BackendRegistry,
    slot_tracker: SlotTracker,
    request_queue: RequestQueue,
    session_store: SessionStore,
    stats: ProxyStats,
) -> APIRouter:
    router = APIRouter()

    @router.get("/monitor", response_class=HTMLResponse, include_in_schema=False)
    async def monitor_page() -> HTMLResponse:
        return HTMLResponse(content=_HTML)

    @router.get("/monitor/data", include_in_schema=False)
    async def monitor_data() -> JSONResponse:
        states = registry.get_all_states()
        backends_data = []
        for state in states:
            acquired, total = slot_tracker.usage(state.url)
            age = state.last_poll_age
            backends_data.append(
                {
                    "url": state.url,
                    "live": state.live,
                    "models": list(state.models),
                    "slots_acquired": acquired,
                    "slots_total": total,
                    "last_poll_age": None if age == float("inf") else round(age, 1),
                }
            )
        queue_snapshot = await request_queue.snapshot()
        session_count = await session_store.count()
        sessions_by_model = await session_store.count_by_model()
        live_count = sum(1 for s in states if s.live)
        return JSONResponse(
            {
                "uptime": stats.uptime_str(),
                "total_requests": stats.total_requests,
                "queue_depth": len(queue_snapshot),
                "session_count": session_count,
                "live_backend_count": live_count,
                "backends": backends_data,
                "queue": queue_snapshot,
                "sessions_by_model": sessions_by_model,
            }
        )

    return router


@@ -0,0 +1,30 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class BackendCandidate:
    """Minimal view of a backend passed to routing policies."""

    url: str
    index: int  # position in the original backends list


class RoutingPolicy(ABC):
    @abstractmethod
    def select(self, model_id: str, candidates: list[BackendCandidate]) -> BackendCandidate:
        """Select one backend from the candidate list for the given model."""


class RoundRobinPolicy(RoutingPolicy):
    def __init__(self) -> None:
        self._counters: dict[str, int] = {}

    def select(self, model_id: str, candidates: list[BackendCandidate]) -> BackendCandidate:
        if not candidates:
            raise ValueError("No candidates available")
        count = self._counters.get(model_id, 0)
        chosen = candidates[count % len(candidates)]
        self._counters[model_id] = count + 1
        return chosen

342
src/llamacpp_ha/proxy.py Normal file

@@ -0,0 +1,342 @@
from __future__ import annotations
import asyncio
import json
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator
import aiohttp
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from .config import ProxyConfig
from .forwarder import forward_best_effort, forward_request
from .middleware import ApiKeyMiddleware
from .monitor import ProxyStats, build_router as build_monitor_router
from .policies import RoundRobinPolicy
from .queue import QueueEntry, RequestQueue
from .registry import BackendRegistry, BackendState
from .scheduler import Scheduler
from .session_store import SessionStore
from .slot_tracker import SlotTracker
log = logging.getLogger(__name__)
_APPLICATION_JSON = "application/json"
_SESSION_COOKIE = "x-llm-session"
_SESSION_HEADER = "X-Session-ID"
_SLOT_GATED_PATHS = [
    "/v1/chat/completions",
    "/v1/completions",
    "/v1/embeddings",
    "/v1/images/generations",
    "/v1/images/edits",
    "/v1/images/variations",
    "/v1/audio/speech",
    "/v1/audio/transcriptions",
    "/v1/audio/translations",
]


class _HttpSession:
    """Mutable holder for the shared aiohttp session, created inside lifespan."""

    __slots__ = ("client",)

    def __init__(self) -> None:
        self.client: aiohttp.ClientSession | None = None


# ------------------------------------------------------------------
# Pure helpers
# ------------------------------------------------------------------


def _estimate_tokens(body: dict) -> int | None:
    msgs = body.get("messages", [])
    if not msgs:
        prompt = body.get("prompt", "")
        if isinstance(prompt, str):
            return max(1, len(prompt) // 4)
        return None
    total = 0
    for m in msgs:
        content = m.get("content", "")
        if isinstance(content, str):
            total += max(1, len(content) // 4)
    return total or None


def _get_model(body: dict) -> str:
    return body.get("model", "")


def _session_id_from(request: Request) -> str | None:
    return (
        request.cookies.get(_SESSION_COOKIE)
        or request.headers.get(_SESSION_HEADER)
    )


def _attach_session(response: Response, session_id: str) -> None:
    response.set_cookie(
        _SESSION_COOKIE, session_id, httponly=True, samesite="lax", secure=True
    )
    response.headers[_SESSION_HEADER] = session_id


def _init_slot_tracker(
    registry: BackendRegistry,
    slot_tracker: SlotTracker,
    default_max_models: int | None,
) -> None:
    for state in registry.get_all_states():
        slot_tracker.set_capacity(state.url, state.slot_capacity)
        effective_max = (
            state.config.max_models
            if state.config.max_models is not None
            else default_max_models
        )
        slot_tracker.set_max_models(state.url, effective_max)


# ------------------------------------------------------------------
# Background tasks
# ------------------------------------------------------------------


async def _sync_capacities(
    registry: BackendRegistry, slot_tracker: SlotTracker, interval: float
) -> None:
    while True:
        await asyncio.sleep(interval)
        for state in registry.get_all_states():
            slot_tracker.set_capacity(state.url, state.slot_capacity)


async def _expire_sessions(session_store: SessionStore) -> None:
    while True:
        await asyncio.sleep(60)
        await session_store.expire()


# ------------------------------------------------------------------
# Route handlers (bound to app components via closures in create_app)
# ------------------------------------------------------------------
async def _dispatch_entry(
request_queue: RequestQueue,
stats: ProxyStats,
config: ProxyConfig,
model_id: str,
session_id: str | None,
body: dict,
) -> BackendState | JSONResponse:
stats.increment_requests()
loop = asyncio.get_running_loop()
entry = QueueEntry(
model_id=model_id,
session_id=session_id,
estimated_tokens=_estimate_tokens(body),
future=loop.create_future(),
)
await request_queue.enqueue(entry)
try:
backend: BackendState = await asyncio.wait_for(
asyncio.shield(entry.future),
timeout=config.slot_wait_timeout,
)
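except asyncio.TimeoutError:
await request_queue.remove(entry)
if entry.future.done() and not entry.future.cancelled():
# The scheduler resolved the future while the timeout fired; it has
# already acquired a slot, so use that backend instead of leaking it.
return entry.future.result()
entry.future.cancel()
return JSONResponse(
{"error": {"message": "No slot available (timeout)", "type": "overloaded"}},
status_code=503,
)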
return backend
def _list_models(*, registry: BackendRegistry) -> JSONResponse:
models = registry.get_all_models()
data = [
{"id": m, "object": "model", "created": 0, "owned_by": "llamacpp-ha"}
for m in models
]
return JSONResponse({"object": "list", "data": data})
def _health(*, registry: BackendRegistry) -> Response:
if registry.get_all_live_backends():
return Response(content='{"status":"ok"}', media_type=_APPLICATION_JSON)
return Response(
content='{"status":"no live backends"}',
status_code=503,
media_type=_APPLICATION_JSON,
)
async def _catch_all(
request: Request,
*,
http: _HttpSession,
registry: BackendRegistry,
stats: ProxyStats,
) -> Response:
if http.client is None:
return Response(content="Proxy not ready", status_code=503)
stats.increment_requests()
live = registry.get_all_live_backends()
return await forward_best_effort(request, live, http.client)
async def _inference_endpoint(
request: Request,
*,
http: _HttpSession,
slot_tracker: SlotTracker,
scheduler: Scheduler,
session_store: SessionStore,
request_queue: RequestQueue,
stats: ProxyStats,
config: ProxyConfig,
) -> Response:
if http.client is None:
return Response(
content='{"error":{"message":"Proxy not ready","type":"server_error"}}',
status_code=503,
media_type=_APPLICATION_JSON,
)
raw = await request.body()
try:
body: dict[str, Any] = json.loads(raw) if raw else {}
except json.JSONDecodeError:
body = {}
model_id = _get_model(body)
session_id = _session_id_from(request) or session_store.new_session_id()
result = await _dispatch_entry(
request_queue, stats, config, model_id, session_id, body
)
if isinstance(result, Response):
return result
backend = result
if model_id:
messages = body.get("messages", [])
await session_store.update(
session_id,
model_id=model_id,
messages=messages if messages else None,
preferred_backend=backend.url,
)
response = await forward_request(
request=request,
backend=backend,
session=http.client,
slot_tracker=slot_tracker,
scheduler=scheduler,
model_id=model_id,
)
_attach_session(response, session_id)
return response
# ------------------------------------------------------------------
# App factory
# ------------------------------------------------------------------
def create_app(config: ProxyConfig) -> FastAPI:
slot_tracker = SlotTracker()
session_store = SessionStore(ttl=config.session_idle_ttl)
request_queue = RequestQueue()
stats = ProxyStats()
http = _HttpSession()
async def _on_backend_recovered(url: str) -> None:
await slot_tracker.reset_acquired(url)
scheduler.notify_slot_released()
registry = BackendRegistry(config, on_backend_recovered=_on_backend_recovered)
scheduler = Scheduler(
queue=request_queue,
registry=registry,
slot_tracker=slot_tracker,
session_store=session_store,
policy=RoundRobinPolicy(),
max_queue_skip=config.max_queue_skip,
)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
http.client = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=300),
connector=aiohttp.TCPConnector(ssl=False),
)
_init_slot_tracker(registry, slot_tracker, config.default_max_models)
registry.start()
scheduler.start()
cap_task = asyncio.create_task(
_sync_capacities(registry, slot_tracker, config.poll_interval),
name="capacity-sync",
)
expire_task = asyncio.create_task(
_expire_sessions(session_store), name="session-expiry"
)
yield
cap_task.cancel()
expire_task.cancel()
await scheduler.stop()
await registry.stop()
if http.client:
await http.client.close()
app = FastAPI(title="llamacpp-ha", lifespan=lifespan)
app.add_middleware(ApiKeyMiddleware, api_keys=config.api_keys)
app.include_router(
build_monitor_router(
registry=registry,
slot_tracker=slot_tracker,
request_queue=request_queue,
session_store=session_store,
stats=stats,
)
)
def list_models_handler() -> JSONResponse:
return _list_models(registry=registry)
def health_handler() -> Response:
return _health(registry=registry)
async def inference_handler(request: Request) -> Response:
return await _inference_endpoint(
request,
http=http,
slot_tracker=slot_tracker,
scheduler=scheduler,
session_store=session_store,
request_queue=request_queue,
stats=stats,
config=config,
)
async def catch_all_handler(request: Request, full_path: str) -> Response: # noqa: ARG001
return await _catch_all(request, http=http, registry=registry, stats=stats)
app.add_api_route("/v1/models", list_models_handler, methods=["GET"])
app.add_api_route("/health", health_handler, methods=["GET"])
for path in _SLOT_GATED_PATHS:
app.add_api_route(path, inference_handler, methods=["POST"])
app.add_api_route(
"/{full_path:path}",
catch_all_handler,
methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"],
include_in_schema=False,
)
return app

78
src/llamacpp_ha/queue.py Normal file

@@ -0,0 +1,78 @@
from __future__ import annotations
import asyncio
import time
import uuid
from dataclasses import dataclass, field
@dataclass
class QueueEntry:
request_id: str = field(default_factory=lambda: uuid.uuid4().hex)
model_id: str = ""
session_id: str | None = None
arrival_time: float = field(default_factory=time.monotonic)
estimated_tokens: int | None = None
# Created explicitly by the enqueuing coroutine via asyncio.get_running_loop()
future: asyncio.Future | None = field(default=None)
# Populated by scheduler at dispatch time
assigned_backend: str | None = None
# Incremented each time a later entry is dispatched ahead of this one via
# model-affinity reordering. When it reaches max_queue_skip the entry
# becomes immune to further skipping (starvation prevention).
skip_count: int = 0
@property
def wait_seconds(self) -> float:
return time.monotonic() - self.arrival_time
class RequestQueue:
"""Global FIFO queue for pending inference requests."""
def __init__(self) -> None:
self._entries: list[QueueEntry] = []
self._lock = asyncio.Lock()
self._wakeup = asyncio.Event()
@property
def wakeup_event(self) -> asyncio.Event:
return self._wakeup
async def enqueue(self, entry: QueueEntry) -> None:
async with self._lock:
self._entries.append(entry)
self._wakeup.set()
async def remove(self, entry: QueueEntry) -> None:
async with self._lock:
try:
self._entries.remove(entry)
except ValueError:
pass
async def pending(self) -> list[QueueEntry]:
async with self._lock:
return list(self._entries)
async def depth(self) -> int:
async with self._lock:
return len(self._entries)
async def snapshot(self) -> list[dict]:
async with self._lock:
return [
{
"request_id": e.request_id,
"model_id": e.model_id,
"session_id": (e.session_id or "")[:8] or None,
"wait_seconds": round(e.wait_seconds, 2),
"estimated_tokens": e.estimated_tokens,
"skip_count": e.skip_count,
}
for e in self._entries
]
def notify(self) -> None:
"""Wake the scheduler without holding the lock."""
self._wakeup.set()
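
Each `QueueEntry` carries an `asyncio.Future` that the enqueuing coroutine awaits and some other task resolves. The sketch below shows that handoff in isolation, using `asyncio.wait_for` around `asyncio.shield` so that a timeout does not cancel the future itself. The helper names and the `"http://b1"` result are hypothetical; a `loop.call_later` callback stands in for the scheduler.

```python
import asyncio


async def wait_for_assignment(fut: asyncio.Future, timeout: float):
    """Wait for an assignment; None means the wait timed out."""
    try:
        # shield() keeps the timeout from cancelling fut itself, so the
        # caller decides what to do with the still-pending future.
        return await asyncio.wait_for(asyncio.shield(fut), timeout=timeout)
    except asyncio.TimeoutError:
        return None


async def demo():
    loop = asyncio.get_running_loop()

    # Case 1: something resolves the future before the timeout.
    served = loop.create_future()
    loop.call_later(0.01, served.set_result, "http://b1")
    first = await wait_for_assignment(served, timeout=1.0)

    # Case 2: nothing resolves the future, so the wait times out.
    starved = loop.create_future()
    second = await wait_for_assignment(starved, timeout=0.05)
    still_pending = not starved.done()
    starved.cancel()  # explicit cleanup is left to the caller
    return first, second, still_pending


first, second, still_pending = asyncio.run(demo())
```

Since the future survives the timeout, the waiter has to cancel it (or otherwise mark it abandoned) so that a later resolver can tell nobody is listening.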

191
src/llamacpp_ha/registry.py Normal file

@@ -0,0 +1,191 @@
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
import aiohttp
from .config import BackendConfig, ProxyConfig
log = logging.getLogger(__name__)
_HEALTH_PATH = "/health"
_MODELS_PATH = "/v1/models"
_SLOTS_PATH = "/slots"
@dataclass
class BackendState:
config: BackendConfig
live: bool = False
models: list[str] = field(default_factory=list)
slot_capacity: int = 1
last_poll_time: float | None = None
@property
def url(self) -> str:
return self.config.url
@property
def last_poll_age(self) -> float:
if self.last_poll_time is None:
return float("inf")
return time.monotonic() - self.last_poll_time
class BackendRegistry:
def __init__(
self,
config: ProxyConfig,
on_backend_recovered: Callable[[str], Awaitable[None]] | None = None,
) -> None:
self._config = config
self._states: dict[str, BackendState] = {
b.url: BackendState(config=b, slot_capacity=config.default_slot_capacity)
for b in config.backends
}
self._model_index: dict[str, list[BackendState]] = {}
self._lock = asyncio.Lock()
self._session: aiohttp.ClientSession | None = None
self._task: asyncio.Task | None = None
self._on_backend_recovered = on_backend_recovered
# ------------------------------------------------------------------
# Public read API (non-blocking, returns snapshots)
# ------------------------------------------------------------------
def get_backends_for_model(self, model_id: str) -> list[BackendState]:
return list(self._model_index.get(model_id, []))
def get_all_live_backends(self) -> list[BackendState]:
return [s for s in self._states.values() if s.live]
def get_all_models(self) -> list[str]:
seen: set[str] = set()
result: list[str] = []
for state in self._states.values():
if state.live:
for m in state.models:
if m not in seen:
seen.add(m)
result.append(m)
return result
def get_all_states(self) -> list[BackendState]:
return list(self._states.values())
def get_state(self, url: str) -> BackendState | None:
return self._states.get(url)
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def start(self) -> None:
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=10),
connector=aiohttp.TCPConnector(ssl=False),
)
self._task = asyncio.create_task(self._poll_loop(), name="registry-poll")
async def stop(self) -> None:
if self._task:
self._task.cancel()
await asyncio.gather(self._task, return_exceptions=True)
if self._session:
await self._session.close()
# ------------------------------------------------------------------
# Polling
# ------------------------------------------------------------------
async def _poll_loop(self) -> None:
while True:
await self._poll_all()
await asyncio.sleep(self._config.poll_interval)
async def _poll_all(self) -> None:
states = list(self._states.values())
results = await asyncio.gather(
*[self._poll_one(s) for s in states], return_exceptions=True
)
for state, result in zip(states, results):
if isinstance(result, Exception):
log.warning("Poll error for %s: %s", state.url, result)
async with self._lock:
self._rebuild_index()
async def _poll_one(self, state: BackendState) -> None:
assert self._session is not None
url = state.url
headers = {}
if state.config.api_key:
headers["Authorization"] = f"Bearer {state.config.api_key}"
try:
async with self._session.get(url + _HEALTH_PATH, headers=headers) as resp:
live = resp.status == 200
except Exception:
live = False
if live:
models = await self._fetch_models(state, headers)
capacity = await self._fetch_slot_capacity(state, headers)
else:
models = state.models
capacity = state.slot_capacity
was_live = state.live
async with self._lock:
state.live = live
if live:
state.models = models
state.slot_capacity = capacity
state.last_poll_time = time.monotonic()
if live and not was_live and self._on_backend_recovered:
await self._on_backend_recovered(url)
async def _fetch_models(self, state: BackendState, headers: dict) -> list[str]:
assert self._session is not None
if state.config.model_ids:
return list(state.config.model_ids)
try:
async with self._session.get(
state.url + _MODELS_PATH, headers=headers
) as resp:
if resp.status != 200:
return state.models
data = await resp.json()
return [m["id"] for m in data.get("data", [])]
except Exception as exc:
log.debug("Failed to fetch models from %s: %s", state.url, exc)
return state.models
async def _fetch_slot_capacity(self, state: BackendState, headers: dict) -> int:
assert self._session is not None
try:
async with self._session.get(
state.url + _SLOTS_PATH, headers=headers
) as resp:
if resp.status != 200:
return state.slot_capacity
data = await resp.json()
if isinstance(data, list):
return max(len(data), 1)
except Exception:
pass
return state.slot_capacity
def _rebuild_index(self) -> None:
index: dict[str, list[BackendState]] = {}
for state in self._states.values():
if not state.live:
continue
for model in state.models:
index.setdefault(model, []).append(state)
self._model_index = index

src/llamacpp_ha/scheduler.py Normal file

@@ -0,0 +1,212 @@
from __future__ import annotations
import asyncio
import logging
from .policies import BackendCandidate, RoutingPolicy
from .queue import QueueEntry, RequestQueue
from .registry import BackendRegistry, BackendState
from .session_store import SessionStore
from .slot_tracker import SlotTracker
log = logging.getLogger(__name__)
class Scheduler:
"""
Continuously binds queued requests to backends.
Dispatch order
--------------
When max_queue_skip == 0 (default) the queue is pure FIFO.
When max_queue_skip > 0 the scheduler runs a two-phase dispatch on every
wakeup:
Phase 1 — model-affinity promotion
The scheduler scans the queue looking for entries whose model is already
in-flight on a free backend. It promotes those entries ahead of earlier
entries that would require a model switch. For every entry that is
bypassed, its skip_count is incremented. Scanning stops as soon as an
entry with skip_count >= max_queue_skip is encountered (that entry is
frozen at the head and must be served next, preventing starvation).
Phase 2 — standard FIFO
Remaining entries are dispatched in arrival order to any backend that
can accept the model (slot free, max_models constraint satisfied).
Session affinity
----------------
Session affinity is applied inside _resolve_backend as a hint: if the
preferred backend is in the candidate set it is chosen first. Affinity
is re-pinned to whichever backend ultimately serves the request; a request
is never held back waiting for its preferred backend when a free
alternative exists.
Preemption prevention
---------------------
SlotTracker.can_accept() enforces the per-backend max_models limit. A
backend with max_models=1 that is serving model-A will reject model-B
requests until all model-A slots are released, preventing llama.cpp from
preempting the in-flight generation.
"""
def __init__(
self,
queue: RequestQueue,
registry: BackendRegistry,
slot_tracker: SlotTracker,
session_store: SessionStore,
policy: RoutingPolicy,
max_queue_skip: int = 0,
) -> None:
self._queue = queue
self._registry = registry
self._slots = slot_tracker
self._sessions = session_store
self._policy = policy
self._max_queue_skip = max_queue_skip
self._task: asyncio.Task | None = None
def notify_slot_released(self) -> None:
"""Called by the forwarder after a slot is freed; wakes the dispatch loop."""
self._queue.notify()
def start(self) -> None:
self._task = asyncio.create_task(self._loop(), name="scheduler")
async def stop(self) -> None:
if self._task:
self._task.cancel()
await asyncio.gather(self._task, return_exceptions=True)
async def _loop(self) -> None:
while True:
await self._queue.wakeup_event.wait()
self._queue.wakeup_event.clear()
await self._dispatch_all()
# ------------------------------------------------------------------
# Dispatch
# ------------------------------------------------------------------
async def _dispatch_all(self) -> None:
entries = await self._queue.pending()
# Prune stale futures first so they don't count in skip logic.
for entry in entries:
if entry.future is None or entry.future.done():
await self._queue.remove(entry)
entries = [e for e in entries if e.future is not None and not e.future.done()]
if not entries:
return
dispatched: set[int] = set()
if self._max_queue_skip > 0:
await self._affinity_pass(entries, dispatched)
# Standard FIFO pass for all remaining live entries.
for entry in entries:
if id(entry) in dispatched:
continue
if entry.future is None or entry.future.done():
await self._queue.remove(entry)
continue
if await self._try_dispatch(entry):
dispatched.add(id(entry))
await self._queue.remove(entry)
async def _affinity_pass(
self, entries: list[QueueEntry], dispatched: set[int]
) -> None:
"""Phase 1: promote entries whose model is already active on a free backend.
Stops scanning as soon as an entry with skip_count >= max_queue_skip is
encountered — that entry is frozen and must be served by the FIFO pass.
"""
for i, entry in enumerate(entries):
if id(entry) in dispatched:
continue
if entry.skip_count >= self._max_queue_skip:
break # frozen head-of-line; FIFO pass must handle it next
if await self._try_dispatch_affinity(entry):
dispatched.add(id(entry))
await self._queue.remove(entry)
# Bump skip_count for every earlier entry we bypassed.
for j in range(i):
if id(entries[j]) not in dispatched:
entries[j].skip_count += 1
async def _try_dispatch_affinity(self, entry: QueueEntry) -> bool:
"""Dispatch only to a backend that already has entry.model_id in-flight."""
if not entry.model_id:
return False
live_backends = self._registry.get_backends_for_model(entry.model_id)
active_backends = [
b for b in live_backends
if entry.model_id in self._slots.active_model_set(b.url)
and self._slots.can_accept(b.url, entry.model_id)
]
if not active_backends:
return False
return await self._acquire_and_resolve(entry, active_backends)
async def _try_dispatch(self, entry: QueueEntry) -> bool:
"""Standard dispatch: any live backend that can accept this model."""
if entry.model_id:
live_backends = self._registry.get_backends_for_model(entry.model_id)
else:
live_backends = self._registry.get_all_live_backends()
if not live_backends:
return False
free_backends = [
b for b in live_backends if self._slots.can_accept(b.url, entry.model_id)
]
if not free_backends:
return False
return await self._acquire_and_resolve(entry, free_backends)
async def _acquire_and_resolve(
self, entry: QueueEntry, candidates: list[BackendState]
) -> bool:
chosen = await self._resolve_backend(entry, candidates)
# can_accept() and acquire() are not atomic: another coroutine could have
# taken the slot between the check above and here. asyncio.timeout(0) means
# "succeed now or not at all", so on failure we simply leave the entry queued.
try:
async with asyncio.timeout(0):
await self._slots.acquire(chosen.url, entry.model_id)
except TimeoutError:
return False
if entry.future is not None and not entry.future.done():
entry.assigned_backend = chosen.url
entry.future.set_result(chosen)
log.debug("Dispatched %s -> %s", entry.request_id, chosen.url)
return True
# Future was cancelled between our acquire and here — release the slot back.
await self._slots.release(chosen.url, entry.model_id)
return False
async def _resolve_backend(
self, entry: QueueEntry, candidates: list[BackendState]
) -> BackendState:
"""Apply session affinity then routing policy to pick one backend."""
if entry.session_id:
preferred_url = await self._sessions.get_preferred_backend(entry.session_id)
if preferred_url:
for b in candidates:
if b.url == preferred_url:
return b
backend_candidates = [
BackendCandidate(url=b.url, index=i) for i, b in enumerate(candidates)
]
chosen = self._policy.select(entry.model_id or "_any", backend_candidates)
return next(b for b in candidates if b.url == chosen.url)
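
The starvation guard in the affinity pass can be illustrated with a small synchronous simulation. This is a sketch under simplifying assumptions: `Entry` is a hypothetical, trimmed-down stand-in for `QueueEntry`, a single `warm_model` string replaces the slot tracker, and every warm-model entry is assumed to find a free slot.

```python
from dataclasses import dataclass


@dataclass
class Entry:
    """Hypothetical, trimmed-down stand-in for QueueEntry."""
    name: str
    model_id: str
    skip_count: int = 0


def affinity_pass(entries: list[Entry], warm_model: str, max_queue_skip: int) -> list[str]:
    """Promote entries for warm_model; return the names dispatched in this pass."""
    dispatched: list[str] = []
    for i, entry in enumerate(entries):
        if entry.skip_count >= max_queue_skip:
            break  # frozen entry: nothing behind it may be promoted
        if entry.model_id != warm_model:
            continue
        dispatched.append(entry.name)
        # Every earlier entry that is still waiting has just been bypassed.
        for earlier in entries[:i]:
            if earlier.name not in dispatched:
                earlier.skip_count += 1
    return dispatched


queue = [Entry("a", "cold"), Entry("b", "warm"), Entry("c", "warm")]
first_pass = affinity_pass(queue, warm_model="warm", max_queue_skip=2)
cold_skips = queue[0].skip_count

# On the next wakeup "a" has reached the limit, so a new warm entry must wait.
remaining = [queue[0], Entry("e", "warm")]
second_pass = affinity_pass(remaining, warm_model="warm", max_queue_skip=2)
```

In the second pass nothing is promoted: the scan stops at the frozen entry, which leaves it to the ordinary FIFO pass.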

src/llamacpp_ha/session_store.py Normal file

@@ -0,0 +1,107 @@
from __future__ import annotations
import asyncio
import hashlib
import json
import time
import uuid
from dataclasses import dataclass, field
@dataclass
class Session:
session_id: str
model_id: str | None = None
last_message_index: int = 0
prefix_hash: str = ""
preferred_backend: str | None = None
last_active: float = field(default_factory=time.monotonic)
def touch(self) -> None:
self.last_active = time.monotonic()
def is_expired(self, ttl: float) -> bool:
return (time.monotonic() - self.last_active) > ttl
def compute_prefix_hash(messages: list[dict]) -> str:
"""Hash the messages list for KV-cache affinity tracking."""
blob = json.dumps(messages, sort_keys=True, separators=(",", ":")).encode()
return hashlib.sha256(blob).hexdigest()[:16]
class SessionStore:
def __init__(self, ttl: float = 300.0) -> None:
self._ttl = ttl
self._sessions: dict[str, Session] = {}
self._lock = asyncio.Lock()
def new_session_id(self) -> str:
return uuid.uuid4().hex
async def get_or_create(self, session_id: str) -> Session:
async with self._lock:
if session_id not in self._sessions:
self._sessions[session_id] = Session(session_id=session_id)
session = self._sessions[session_id]
session.touch()
return session
async def update(
self,
session_id: str,
model_id: str | None = None,
messages: list[dict] | None = None,
preferred_backend: str | None = None,
) -> None:
async with self._lock:
if session_id not in self._sessions:
self._sessions[session_id] = Session(session_id=session_id)
session = self._sessions[session_id]
if model_id is not None:
session.model_id = model_id
if messages is not None:
session.last_message_index = len(messages)
session.prefix_hash = compute_prefix_hash(messages)
if preferred_backend is not None:
session.preferred_backend = preferred_backend
session.touch()
async def get_preferred_backend(self, session_id: str) -> str | None:
async with self._lock:
session = self._sessions.get(session_id)
if session is None or session.is_expired(self._ttl):
return None
return session.preferred_backend
async def expire(self) -> int:
"""Remove expired sessions. Returns count removed."""
async with self._lock:
expired = [
sid
for sid, s in self._sessions.items()
if s.is_expired(self._ttl)
]
for sid in expired:
del self._sessions[sid]
return len(expired)
async def count(self) -> int:
async with self._lock:
return sum(
1 for s in self._sessions.values() if not s.is_expired(self._ttl)
)
async def count_by_model(self) -> dict[str, int]:
async with self._lock:
result: dict[str, int] = {}
for s in self._sessions.values():
if not s.is_expired(self._ttl) and s.model_id:
result[s.model_id] = result.get(s.model_id, 0) + 1
return result
async def snapshot(self) -> list[Session]:
async with self._lock:
return [
s for s in self._sessions.values() if not s.is_expired(self._ttl)
]
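
A short sketch of how the prefix hash behaves, restating the canonicalisation used by `compute_prefix_hash` under the hypothetical name `prefix_hash`. The sample messages are made up for illustration.

```python
import hashlib
import json


def prefix_hash(messages: list[dict]) -> str:
    """Canonical JSON (sorted keys, compact separators) hashed with SHA-256."""
    blob = json.dumps(messages, sort_keys=True, separators=(",", ":")).encode()
    return hashlib.sha256(blob).hexdigest()[:16]


turn_1 = [{"role": "user", "content": "Hello"}]
turn_1_reordered = [{"content": "Hello", "role": "user"}]
turn_2 = turn_1 + [{"role": "assistant", "content": "Hi!"}]

# Key order inside each message does not matter because keys are sorted.
same_despite_key_order = prefix_hash(turn_1) == prefix_hash(turn_1_reordered)

# The digest covers the whole list, so appending a message changes it.
changes_when_history_grows = prefix_hash(turn_1) != prefix_hash(turn_2)

digest_length = len(prefix_hash(turn_1))
```

Truncating to 16 hex characters keeps 64 bits of the digest, which is compact for bookkeeping but should not be treated as collision-proof.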

src/llamacpp_ha/slot_tracker.py Normal file

@@ -0,0 +1,122 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
@dataclass
class _SlotState:
capacity: int
max_models: int | None = None
acquired: int = 0
active_models: dict[str, int] = field(default_factory=dict)
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
class SlotTracker:
"""Tracks per-backend slot usage with model-aware acquire/release.
Callers control timeouts via asyncio.timeout() at the call site rather than
passing a timeout parameter here — acquire() blocks until a slot is free.
"""
def __init__(self) -> None:
self._slots: dict[str, _SlotState] = {}
def _ensure(self, url: str, capacity: int = 1) -> _SlotState:
if url not in self._slots:
self._slots[url] = _SlotState(capacity=capacity)
return self._slots[url]
def set_capacity(self, url: str, capacity: int) -> None:
state = self._ensure(url, capacity)
state.capacity = max(capacity, 1)
def set_max_models(self, url: str, max_models: int | None) -> None:
state = self._ensure(url)
state.max_models = max_models
# ------------------------------------------------------------------
# Read helpers (non-blocking)
# ------------------------------------------------------------------
def _can_acquire(self, state: _SlotState, model_id: str) -> bool:
if state.acquired >= state.capacity:
return False
if model_id and model_id not in state.active_models:
if state.max_models is not None and len(state.active_models) >= state.max_models:
return False
return True
def has_free_slot(self, url: str) -> bool:
state = self._slots.get(url)
if state is None:
return True
return state.acquired < state.capacity
def can_accept(self, url: str, model_id: str) -> bool:
"""True if this backend can accept a new request for model_id right now."""
state = self._slots.get(url)
if state is None:
return True
return self._can_acquire(state, model_id)
def active_model_set(self, url: str) -> frozenset[str]:
"""Models currently in-flight on this backend."""
state = self._slots.get(url)
if state is None:
return frozenset()
return frozenset(state.active_models)
def usage(self, url: str) -> tuple[int, int]:
state = self._slots.get(url)
if state is None:
return 0, 1
return state.acquired, state.capacity
# ------------------------------------------------------------------
# Acquire / release
# ------------------------------------------------------------------
async def acquire(self, url: str, model_id: str = "") -> None:
"""Acquire a slot for model_id. Blocks until one is available.
Use asyncio.timeout() at the call site to bound the wait:
try:
async with asyncio.timeout(5.0):
await tracker.acquire(url, model_id)
except TimeoutError:
...
"""
state = self._ensure(url)
async with state.condition:
while not self._can_acquire(state, model_id):
await state.condition.wait()
state.acquired += 1
if model_id:
state.active_models[model_id] = state.active_models.get(model_id, 0) + 1
async def release(self, url: str, model_id: str = "") -> None:
"""Release a slot and notify waiters."""
state = self._slots.get(url)
if state is None:
return
async with state.condition:
state.acquired = max(0, state.acquired - 1)
if model_id and model_id in state.active_models:
count = state.active_models[model_id] - 1
if count <= 0:
del state.active_models[model_id]
else:
state.active_models[model_id] = count
state.condition.notify_all()
async def reset_acquired(self, url: str) -> None:
"""Zero out slot tracking for a backend that just recovered from a crash."""
state = self._slots.get(url)
if state is None:
return
async with state.condition:
state.acquired = 0
state.active_models.clear()
state.condition.notify_all()

0
tests/__init__.py Normal file

130
tests/test_config.py Normal file

@@ -0,0 +1,130 @@
import json
import os
import tempfile
import unittest
from llamacpp_ha.config import BackendConfig, ProxyConfig
class TestBackendConfig(unittest.TestCase):
def test_url_trailing_slash_stripped(self):
b = BackendConfig(url="http://localhost:8080/")
self.assertEqual(b.url, "http://localhost:8080")
def test_defaults(self):
b = BackendConfig(url="http://localhost:8080")
self.assertIsNone(b.api_key)
self.assertEqual(b.model_ids, [])
self.assertIsNone(b.max_models)
def test_explicit_model_ids(self):
b = BackendConfig(url="http://x", model_ids=["llama3", "mistral"])
self.assertEqual(b.model_ids, ["llama3", "mistral"])
def test_max_models(self):
b = BackendConfig(url="http://x", max_models=1)
self.assertEqual(b.max_models, 1)
def test_max_models_none_is_unlimited(self):
b = BackendConfig(url="http://x", max_models=None)
self.assertIsNone(b.max_models)
class TestProxyConfig(unittest.TestCase):
def _write_config(self, data: dict) -> str:
f = tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
)
json.dump(data, f)
f.close()
return f.name
def test_defaults(self):
cfg = ProxyConfig()
self.assertEqual(cfg.host, "0.0.0.0")
self.assertEqual(cfg.port, 8080)
self.assertEqual(cfg.api_keys, [])
self.assertEqual(cfg.poll_interval, 5.0)
self.assertEqual(cfg.slot_wait_timeout, 30.0)
self.assertEqual(cfg.session_idle_ttl, 300.0)
self.assertEqual(cfg.backends, [])
self.assertIsNone(cfg.default_max_models)
self.assertEqual(cfg.max_queue_skip, 0)
def test_from_file_minimal(self):
path = self._write_config(
{"backends": [{"url": "http://localhost:8081"}]}
)
try:
cfg = ProxyConfig.from_file(path)
self.assertEqual(len(cfg.backends), 1)
self.assertEqual(cfg.backends[0].url, "http://localhost:8081")
finally:
os.unlink(path)
def test_from_file_full(self):
data = {
"host": "127.0.0.1",
"port": 9090,
"api_keys": ["key1", "key2"],
"poll_interval": 10,
"slot_wait_timeout": 60,
"session_idle_ttl": 600,
"default_max_models": 2,
"max_queue_skip": 5,
"backends": [
{"url": "http://b1", "api_key": "secret", "model_ids": ["m1"], "max_models": 1},
{"url": "http://b2/"},
],
}
path = self._write_config(data)
try:
cfg = ProxyConfig.from_file(path)
self.assertEqual(cfg.host, "127.0.0.1")
self.assertEqual(cfg.port, 9090)
self.assertEqual(cfg.api_keys, ["key1", "key2"])
self.assertEqual(cfg.poll_interval, 10)
self.assertEqual(cfg.default_max_models, 2)
self.assertEqual(cfg.max_queue_skip, 5)
self.assertEqual(cfg.backends[0].api_key, "secret")
self.assertEqual(cfg.backends[0].model_ids, ["m1"])
self.assertEqual(cfg.backends[0].max_models, 1)
self.assertEqual(cfg.backends[1].url, "http://b2")
finally:
os.unlink(path)
def test_from_file_with_overrides(self):
path = self._write_config({"port": 8080, "backends": []})
try:
cfg = ProxyConfig.from_file(path, port=9999)
self.assertEqual(cfg.port, 9999)
finally:
os.unlink(path)
def test_env_var_override(self):
old = os.environ.get("LLAMACPP_HA_PORT")
try:
os.environ["LLAMACPP_HA_PORT"] = "7777"
cfg = ProxyConfig()
self.assertEqual(cfg.port, 7777)
finally:
if old is None:
os.environ.pop("LLAMACPP_HA_PORT", None)
else:
os.environ["LLAMACPP_HA_PORT"] = old
def test_invalid_json_raises(self):
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as f:
f.write("not json {{{")
path = f.name
try:
with self.assertRaises(Exception):
ProxyConfig.from_file(path)
finally:
os.unlink(path)
if __name__ == "__main__":
unittest.main()

276
tests/test_forwarder.py Normal file

@@ -0,0 +1,276 @@
import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock

from fastapi import Request
from starlette.datastructures import Headers

from llamacpp_ha.config import BackendConfig
from llamacpp_ha.forwarder import _forward_headers, forward_request
from llamacpp_ha.registry import BackendState
from llamacpp_ha.slot_tracker import SlotTracker


def _make_state(url="http://b1", api_key=None):
    cfg = BackendConfig(url=url, api_key=api_key)
    return BackendState(config=cfg)


def _make_request(headers: dict, method="POST", path="/v1/chat/completions") -> MagicMock:
    req = MagicMock()
    req.headers = Headers(headers=headers)
    req.url.path = path
    req.url.query = ""
    req.method = method
    req.body = AsyncMock(return_value=b"")
    return req


def _make_scheduler():
    sched = MagicMock()
    sched.notify_slot_released = MagicMock()
    return sched


class TestForwardHeaders(unittest.TestCase):
    def test_injects_backend_api_key(self):
        state = _make_state(api_key="backend-secret")
        req = _make_request({"Authorization": "Bearer client-key", "Content-Type": "application/json"})
        headers = _forward_headers(req, state)
        self.assertEqual(headers["Authorization"], "Bearer backend-secret")
        combined = {k.lower(): v for k, v in headers.items()}
        self.assertIn("application/json", combined.get("content-type", ""))

    def test_removes_auth_when_no_backend_key(self):
        state = _make_state(api_key=None)
        req = _make_request({"Authorization": "Bearer client-key"})
        headers = _forward_headers(req, state)
        combined = {k.lower(): v for k, v in headers.items()}
        self.assertNotIn("authorization", combined)

    def test_hop_by_hop_stripped(self):
        state = _make_state()
        req = _make_request({
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked",
            "X-Custom": "value",
        })
        headers = _forward_headers(req, state)
        combined = {k.lower(): v for k, v in headers.items()}
        self.assertNotIn("connection", combined)
        self.assertNotIn("transfer-encoding", combined)
        self.assertEqual(combined.get("x-custom"), "value")

    def test_host_header_stripped(self):
        state = _make_state()
        req = _make_request({"Host": "proxy.local", "Accept": "application/json"})
        headers = _forward_headers(req, state)
        combined = {k.lower(): v for k, v in headers.items()}
        self.assertNotIn("host", combined)
        self.assertIn("accept", combined)


def _mock_aiohttp_response(
    status: int = 200,
    content_type: str = "application/json",
    body: bytes = b'{"ok":true}',
    headers: dict | None = None,
) -> tuple[MagicMock, MagicMock]:
    """Return (context manager mock, response mock) for session.request()."""
    resp = MagicMock()
    resp.status = status
    all_headers = {"Content-Type": content_type}
    if headers:
        all_headers.update(headers)
    resp.headers = all_headers
    resp.read = AsyncMock(return_value=body)
    resp.content = MagicMock()
    resp.content.iter_chunked = MagicMock()
    ctx = MagicMock()
    ctx.__aenter__ = AsyncMock(return_value=resp)
    ctx.__aexit__ = AsyncMock(return_value=False)
    return ctx, resp


class TestForwardRequestNonStreaming(unittest.IsolatedAsyncioTestCase):
    async def _acquire(self, tracker: SlotTracker, url: str) -> None:
        async with asyncio.timeout(1.0):
            await tracker.acquire(url)

    async def test_successful_passthrough(self):
        state = _make_state("http://b1")
        slot_tracker = SlotTracker()
        slot_tracker.set_capacity("http://b1", 1)
        await self._acquire(slot_tracker, "http://b1")
        scheduler = _make_scheduler()
        ctx, _ = _mock_aiohttp_response(status=200, body=b'{"choices":[]}')
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        req = _make_request({})
        response = await forward_request(req, state, session, slot_tracker, scheduler)
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.body, b'{"choices":[]}')
        acquired, _ = slot_tracker.usage("http://b1")
        self.assertEqual(acquired, 0)
        scheduler.notify_slot_released.assert_called_once()

    async def test_slot_released_on_forward_error(self):
        state = _make_state("http://b1")
        slot_tracker = SlotTracker()
        slot_tracker.set_capacity("http://b1", 1)
        await self._acquire(slot_tracker, "http://b1")
        scheduler = _make_scheduler()
        ctx = MagicMock()
        ctx.__aenter__ = AsyncMock(side_effect=Exception("network error"))
        ctx.__aexit__ = AsyncMock(return_value=False)
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        req = _make_request({})
        with self.assertRaises(Exception):
            await forward_request(req, state, session, slot_tracker, scheduler)
        acquired, _ = slot_tracker.usage("http://b1")
        self.assertEqual(acquired, 0)
        scheduler.notify_slot_released.assert_called_once()

    async def test_non_streaming_returns_full_body(self):
        state = _make_state("http://b1")
        slot_tracker = SlotTracker()
        slot_tracker.set_capacity("http://b1", 1)
        await self._acquire(slot_tracker, "http://b1")
        scheduler = _make_scheduler()
        body = b'{"id":"xyz","choices":[{"text":"hello"}]}'
        ctx, _ = _mock_aiohttp_response(body=body)
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        req = _make_request({})
        response = await forward_request(req, state, session, slot_tracker, scheduler)
        self.assertEqual(response.body, body)

    async def test_status_code_preserved(self):
        state = _make_state("http://b1")
        slot_tracker = SlotTracker()
        slot_tracker.set_capacity("http://b1", 1)
        await self._acquire(slot_tracker, "http://b1")
        scheduler = _make_scheduler()
        ctx, _ = _mock_aiohttp_response(status=429, body=b"rate limited")
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        req = _make_request({})
        response = await forward_request(req, state, session, slot_tracker, scheduler)
        self.assertEqual(response.status_code, 429)

    async def test_model_id_tracked_in_active_models(self):
        """model_id passed to forward_request is cleared from the active model set on release."""
        state = _make_state("http://b1")
        slot_tracker = SlotTracker()
        slot_tracker.set_capacity("http://b1", 2)
        async with asyncio.timeout(1.0):
            await slot_tracker.acquire("http://b1", "my-model")
        scheduler = _make_scheduler()
        ctx, _ = _mock_aiohttp_response(status=200, body=b"{}")
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        req = _make_request({})
        await forward_request(req, state, session, slot_tracker, scheduler, model_id="my-model")
        self.assertEqual(slot_tracker.active_model_set("http://b1"), frozenset())


class TestForwardRequestStreaming(unittest.IsolatedAsyncioTestCase):
    async def test_streaming_sse_passthrough(self):
        state = _make_state("http://b1")
        slot_tracker = SlotTracker()
        slot_tracker.set_capacity("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slot_tracker.acquire("http://b1")
        scheduler = _make_scheduler()
        sse_chunks = [
            b'data: {"choices":[{"delta":{"content":"hello"}}]}\n\n',
            b'data: [DONE]\n\n',
        ]

        async def fake_iter_chunked(_size):
            for chunk in sse_chunks:
                yield chunk

        resp = MagicMock()
        resp.status = 200
        resp.headers = {"Content-Type": "text/event-stream"}
        resp.content = MagicMock()
        resp.content.iter_chunked = fake_iter_chunked
        ctx = MagicMock()
        ctx.__aenter__ = AsyncMock(return_value=resp)
        ctx.__aexit__ = AsyncMock(return_value=False)
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        req = _make_request({})
        response = await forward_request(req, state, session, slot_tracker, scheduler)
        from fastapi.responses import StreamingResponse
        self.assertIsInstance(response, StreamingResponse)
        chunks = []
        async for chunk in response.body_iterator:
            chunks.append(chunk)
        self.assertEqual(chunks, sse_chunks)
        acquired, _ = slot_tracker.usage("http://b1")
        self.assertEqual(acquired, 0)
        scheduler.notify_slot_released.assert_called_once()

    async def test_streaming_slot_released_on_stream_end(self):
        state = _make_state("http://b1")
        slot_tracker = SlotTracker()
        slot_tracker.set_capacity("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slot_tracker.acquire("http://b1")
        scheduler = _make_scheduler()

        async def fake_iter_chunked(_size):
            yield b"data: chunk1\n\n"
            yield b"data: chunk2\n\n"

        resp = MagicMock()
        resp.status = 200
        resp.headers = {"Content-Type": "text/event-stream"}
        resp.content = MagicMock()
        resp.content.iter_chunked = fake_iter_chunked
        ctx = MagicMock()
        ctx.__aenter__ = AsyncMock(return_value=resp)
        ctx.__aexit__ = AsyncMock(return_value=False)
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        req = _make_request({})
        response = await forward_request(req, state, session, slot_tracker, scheduler)
        acquired, _ = slot_tracker.usage("http://b1")
        self.assertEqual(acquired, 1)
        chunks = [chunk async for chunk in response.body_iterator]
        self.assertEqual(len(chunks), 2)
        acquired, _ = slot_tracker.usage("http://b1")
        self.assertEqual(acquired, 0)


if __name__ == "__main__":
    unittest.main()
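
The forwarder tests above check that a slot is released whether forwarding succeeds or raises. One way to get that guarantee is a `try`/`finally` around the outbound call. The sketch below is a hypothetical, standard-library-only illustration; `SketchSlotTracker` and `forward` are invented names and do not reproduce the actual `SlotTracker` or `forward_request` implementations:

```python
import asyncio


class SketchSlotTracker:
    """Hypothetical counter-based tracker guarded by an asyncio.Condition."""

    def __init__(self, capacity: int) -> None:
        self._capacity = capacity
        self._acquired = 0
        self._cond = asyncio.Condition()

    async def acquire(self) -> None:
        async with self._cond:
            # Block until fewer than `capacity` slots are in use.
            await self._cond.wait_for(lambda: self._acquired < self._capacity)
            self._acquired += 1

    async def release(self) -> None:
        async with self._cond:
            self._acquired -= 1
            self._cond.notify(1)  # wake one waiter, if any

    @property
    def acquired(self) -> int:
        return self._acquired


async def forward(tracker: SketchSlotTracker, send):
    """Run `send()` for a request whose slot is already held; always release it."""
    try:
        return await send()
    finally:
        await tracker.release()


async def demo() -> tuple[int, int]:
    tracker = SketchSlotTracker(capacity=1)

    async def ok() -> int:
        return 200

    async def boom() -> int:
        raise ConnectionError("network error")

    await tracker.acquire()
    status = await forward(tracker, ok)

    # The second acquire only succeeds because the first slot was released.
    await tracker.acquire()
    try:
        await forward(tracker, boom)
    except ConnectionError:
        pass
    return status, tracker.acquired
```

Running `asyncio.run(demo())` exercises both the success path and the error path; in each case the slot count returns to zero.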

477
tests/test_integration.py Normal file

@@ -0,0 +1,477 @@
"""
Integration tests for the proxy, with an in-process fake llama.cpp backend.

The fake backend is a FastAPI app (see _make_fake_backend) meant to be driven
without opening real TCP sockets. The tests in this file do not yet route
traffic to it: they exercise the proxy through TestClient with no live
backends and drive the registry, scheduler, and slot tracker directly.
"""
from __future__ import annotations

import asyncio
import json
import unittest
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch

import aiohttp
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from fastapi.testclient import TestClient

from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.proxy import create_app

# ---------------------------------------------------------------------------
# Helpers to build a minimal proxy with pre-seeded live backends
# ---------------------------------------------------------------------------


def _live_backend_config(url: str, models: list[str] | None = None) -> BackendConfig:
    return BackendConfig(url=url, model_ids=models or ["test-model"])


def _proxy_config(backend_urls: list[str], api_keys: list[str] | None = None) -> ProxyConfig:
    return ProxyConfig(
        host="127.0.0.1",
        port=8080,
        api_keys=api_keys or [],
        poll_interval=9999,  # disable background polling during tests
        slot_wait_timeout=1.0,
        session_idle_ttl=300.0,
        default_slot_capacity=2,
        backends=[_live_backend_config(url) for url in backend_urls],
    )


def _seed_live(app_obj, backend_urls: list[str], models=None) -> None:
    """Placeholder for seeding the registry with live backends (skipping the real HTTP poll).

    Currently a no-op: the registry is created inside the app lifespan and is
    not reachable from here.
    """
    from llamacpp_ha.registry import BackendRegistry
    # Access the app's state via lifespan-created objects
    # We do this by patching _poll_all to be a no-op and manually setting state
    pass


# ---------------------------------------------------------------------------
# Fake llama.cpp server (in-process, no real sockets)
# ---------------------------------------------------------------------------


def _make_fake_backend(
    model_id: str = "test-model",
    response_content: str = '{"choices":[{"message":{"role":"assistant","content":"hello"}}]}',
    streaming: bool = False,
    slow_seconds: float = 0,
    slot_count: int = 2,
) -> FastAPI:
    fake = FastAPI()

    @fake.get("/health")
    async def health():
        if slow_seconds:
            await asyncio.sleep(slow_seconds)
        return Response(status_code=200)

    @fake.get("/v1/models")
    async def models():
        return JSONResponse({"object": "list", "data": [{"id": model_id, "object": "model"}]})

    @fake.get("/slots")
    async def slots():
        return JSONResponse([{"id": i, "state": 0} for i in range(slot_count)])

    @fake.post("/v1/chat/completions")
    async def chat(request: Request):
        if slow_seconds:
            await asyncio.sleep(slow_seconds)
        body = await request.json()
        stream = body.get("stream", False)
        if stream or streaming:
            async def gen():
                yield b'data: {"choices":[{"delta":{"content":"hello"}}]}\n\n'
                yield b"data: [DONE]\n\n"
            return StreamingResponse(gen(), media_type="text/event-stream")
        return JSONResponse(json.loads(response_content))

    @fake.post("/v1/completions")
    async def completions(request: Request):
        return JSONResponse({"choices": [{"text": "hello"}]})

    return fake


# ---------------------------------------------------------------------------
# Integration test cases
# ---------------------------------------------------------------------------


class TestHealthEndpoint(unittest.TestCase):
    def _make_proxy_with_live_backend(self):
        cfg = _proxy_config(["http://fake-b1"])
        app = create_app(cfg)

        # Seed registry directly before first request
        def _patch_registry(app_obj):
            pass

        return app, cfg

    def test_health_no_live_backend_returns_503(self):
        cfg = _proxy_config([])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/health")
        self.assertEqual(resp.status_code, 503)

    def test_health_with_api_key_required(self):
        cfg = _proxy_config(["http://b1"], api_keys=["mykey"])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/health")
        self.assertEqual(resp.status_code, 401)

    def test_health_with_valid_api_key(self):
        cfg = _proxy_config([], api_keys=["mykey"])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        # No live backends -> 503 but authenticated
        resp = client.get("/health", headers={"Authorization": "Bearer mykey"})
        self.assertEqual(resp.status_code, 503)


class TestModelsEndpoint(unittest.TestCase):
    def test_models_returns_list_format(self):
        cfg = _proxy_config(["http://b1"])
        app = create_app(cfg)
        # Seed registry
        from llamacpp_ha.registry import BackendState
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/v1/models")
        self.assertEqual(resp.status_code, 200)
        data = resp.json()
        self.assertEqual(data["object"], "list")
        self.assertIn("data", data)

    def test_models_deduplication(self):
        """Models appearing on multiple backends appear once."""
        cfg = ProxyConfig(
            backends=[
                BackendConfig(url="http://b1", model_ids=["shared", "only-b1"]),
                BackendConfig(url="http://b2", model_ids=["shared", "only-b2"]),
            ]
        )
        app = create_app(cfg)
        from llamacpp_ha.registry import BackendState
        # Manually seed live
        import asyncio

        async def _seed():
            from llamacpp_ha import proxy as _proxy_module
            # Can't easily seed without internal access in this test structure
            pass

        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/v1/models")
        data = resp.json()
        model_ids = [m["id"] for m in data["data"]]
        self.assertEqual(len(model_ids), len(set(model_ids)), "Duplicates found")


class TestApiKeyMiddlewareIntegration(unittest.TestCase):
    def test_request_rejected_without_key(self):
        cfg = _proxy_config([], api_keys=["secret"])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.post("/v1/chat/completions", json={"model": "m", "messages": []})
        self.assertEqual(resp.status_code, 401)

    def test_monitor_exempt_from_auth(self):
        cfg = _proxy_config([], api_keys=["secret"])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/monitor")
        self.assertEqual(resp.status_code, 200)

    def test_monitor_data_exempt_from_auth(self):
        cfg = _proxy_config([], api_keys=["secret"])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/monitor/data")
        self.assertEqual(resp.status_code, 200)


class TestSlotExhaustionTimeout(unittest.TestCase):
    def test_timeout_when_no_slot_available(self):
        """Request with no live backends times out and returns 503."""
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["m"])],
            slot_wait_timeout=0.2,
            default_slot_capacity=1,
        )
        app = create_app(cfg)
        # No live backends -> scheduler never dispatches -> timeout -> 503
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.post(
                "/v1/chat/completions",
                json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
            )
        self.assertEqual(resp.status_code, 503)


class TestCatchAllForwarding(unittest.TestCase):
    def test_catch_all_no_backends_returns_503(self):
        cfg = _proxy_config([])
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.get("/some/unknown/path")
        self.assertEqual(resp.status_code, 503)


class TestMonitorIntegration(unittest.TestCase):
    def test_monitor_data_reflects_queue_depth(self):
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["m"])],
            slot_wait_timeout=0.1,
            default_slot_capacity=0,
        )
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            data = client.get("/monitor/data").json()
        self.assertIsInstance(data["queue_depth"], int)
        self.assertGreaterEqual(data["queue_depth"], 0)

    def test_monitor_data_shows_all_backends(self):
        cfg = _proxy_config(["http://b1", "http://b2", "http://b3"])
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            data = client.get("/monitor/data").json()
        self.assertEqual(len(data["backends"]), 3)
        urls = {b["url"] for b in data["backends"]}
        self.assertEqual(urls, {"http://b1", "http://b2", "http://b3"})


class TestFullForwardPath(unittest.TestCase):
    """
    Full path test: proxy starts up, monitor/data reflects backend state
    (dead until polled), scheduler queues requests when no backends are live.
    """

    def test_backend_initially_dead_no_poll(self):
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://fake", model_ids=["test-model"])],
            slot_wait_timeout=2.0,
            default_slot_capacity=2,
            poll_interval=9999,  # no background polling in test
        )
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.get("/monitor/data")
            data = resp.json()
        # Backend exists but is dead until first poll completes
        self.assertEqual(len(data["backends"]), 1)
        self.assertFalse(data["backends"][0]["live"])


class TestBackendFailover(unittest.IsolatedAsyncioTestCase):
    async def test_dead_backend_not_in_model_index(self):
        """Dead backends are not returned for model routing."""
        from llamacpp_ha.config import BackendConfig, ProxyConfig
        from llamacpp_ha.registry import BackendRegistry, BackendState

        cfg = ProxyConfig(
            backends=[
                BackendConfig(url="http://b1", model_ids=["m"]),
                BackendConfig(url="http://b2", model_ids=["m"]),
            ]
        )
        reg = BackendRegistry(cfg)
        # Simulate: b1 live, b2 dead
        async with reg._lock:
            reg._states["http://b1"].live = True
            reg._states["http://b1"].models = ["m"]
            reg._states["http://b2"].live = False
            reg._rebuild_index()
        backends = reg.get_backends_for_model("m")
        self.assertEqual(len(backends), 1)
        self.assertEqual(backends[0].url, "http://b1")

    async def test_all_dead_returns_empty(self):
        from llamacpp_ha.config import BackendConfig, ProxyConfig
        from llamacpp_ha.registry import BackendRegistry

        cfg = ProxyConfig(
            backends=[
                BackendConfig(url="http://b1", model_ids=["m"]),
            ]
        )
        reg = BackendRegistry(cfg)
        backends = reg.get_backends_for_model("m")
        self.assertEqual(backends, [])


class TestModelRouting(unittest.IsolatedAsyncioTestCase):
    async def test_routes_to_backend_serving_model(self):
        from llamacpp_ha.config import BackendConfig, ProxyConfig
        from llamacpp_ha.policies import RoundRobinPolicy
        from llamacpp_ha.queue import QueueEntry, RequestQueue
        from llamacpp_ha.registry import BackendRegistry
        from llamacpp_ha.scheduler import Scheduler
        from llamacpp_ha.session_store import SessionStore
        from llamacpp_ha.slot_tracker import SlotTracker

        cfg = ProxyConfig(
            backends=[
                BackendConfig(url="http://b1", model_ids=["model-a"]),
                BackendConfig(url="http://b2", model_ids=["model-b"]),
            ]
        )
        reg = BackendRegistry(cfg)
        async with reg._lock:
            reg._states["http://b1"].live = True
            reg._states["http://b1"].models = ["model-a"]
            reg._states["http://b2"].live = True
            reg._states["http://b2"].models = ["model-b"]
            reg._rebuild_index()
        slots = SlotTracker()
        slots.set_capacity("http://b1", 2)
        slots.set_capacity("http://b2", 2)
        sessions = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy())
        loop = asyncio.get_running_loop()
        entry = QueueEntry(model_id="model-b", future=loop.create_future())
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b2")


class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase):
    """
    Slot is fully occupied → request queues → slot released → request dispatched.
    Tests the complete wait-and-unblock path without involving real HTTP.
    """

    async def test_queued_request_dispatched_after_slot_release(self):
        from llamacpp_ha.config import BackendConfig, ProxyConfig
        from llamacpp_ha.policies import RoundRobinPolicy
        from llamacpp_ha.queue import QueueEntry, RequestQueue
        from llamacpp_ha.registry import BackendRegistry
        from llamacpp_ha.scheduler import Scheduler
        from llamacpp_ha.session_store import SessionStore
        from llamacpp_ha.slot_tracker import SlotTracker

        cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m"])])
        reg = BackendRegistry(cfg)
        async with reg._lock:
            reg._states["http://b1"].live = True
            reg._states["http://b1"].models = ["m"]
            reg._rebuild_index()
        slots = SlotTracker()
        slots.set_capacity("http://b1", 1)
        sessions = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy())
        # Occupy the only slot (simulating an in-flight request)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")
        # Enqueue a request; the scheduler should not be able to dispatch it
        loop = asyncio.get_running_loop()
        entry = QueueEntry(model_id="m", future=loop.create_future())
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done(), "Should be queued, not dispatched yet")
        # Release the slot (simulates a prior request completing)
        await slots.release("http://b1")
        scheduler.notify_slot_released()
        # Scheduler re-evaluates on next wakeup
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done(), "Should be dispatched after slot freed")
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_multiple_queued_requests_dispatched_as_slots_free(self):
        from llamacpp_ha.config import BackendConfig, ProxyConfig
        from llamacpp_ha.policies import RoundRobinPolicy
        from llamacpp_ha.queue import QueueEntry, RequestQueue
        from llamacpp_ha.registry import BackendRegistry
        from llamacpp_ha.scheduler import Scheduler
        from llamacpp_ha.session_store import SessionStore
        from llamacpp_ha.slot_tracker import SlotTracker

        cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m"])])
        reg = BackendRegistry(cfg)
        async with reg._lock:
            reg._states["http://b1"].live = True
            reg._states["http://b1"].models = ["m"]
            reg._rebuild_index()
        slots = SlotTracker()
        slots.set_capacity("http://b1", 2)
        sessions = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy())
        # Fill both slots
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")
        loop = asyncio.get_running_loop()
        entries = [QueueEntry(model_id="m", future=loop.create_future()) for _ in range(3)]
        for e in entries:
            await queue.enqueue(e)
        await scheduler._dispatch_all()
        dispatched = sum(1 for e in entries if e.future.done())
        self.assertEqual(dispatched, 0)
        # Free one slot → one request dispatched
        await slots.release("http://b1")
        await scheduler._dispatch_all()
        dispatched = sum(1 for e in entries if e.future.done())
        self.assertEqual(dispatched, 1)
        # Free second slot → another dispatched
        await slots.release("http://b1")
        await scheduler._dispatch_all()
        dispatched = sum(1 for e in entries if e.future.done())
        self.assertEqual(dispatched, 2)


class TestSessionCookieIntegration(unittest.TestCase):
    def test_session_cookie_set_in_response(self):
        """Session handling does not crash on the 503 timeout path."""
        # The proxy is expected to set an X-Session-ID header and cookie on
        # inference responses, but verifying that needs a full round-trip
        # through a live backend. With no live backends the request times out,
        # so this test only checks that the 503 path returns a JSON error
        # without the session handling raising.
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["m"])],
            slot_wait_timeout=0.1,
        )
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.post(
                "/v1/chat/completions",
                json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
            )
        # Times out (no live backend) → 503
        self.assertEqual(resp.status_code, 503)
        # Even on timeout, no crash
        self.assertIn("error", resp.json())


if __name__ == "__main__":
    unittest.main()
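
The wait-and-unblock path tested in `TestSlotExhaustionAndUnblock` (slot full, request queues, slot released, request dispatched) can be illustrated without any llamacpp-ha classes. The sketch below is a hypothetical, standard-library-only model; `SketchDispatcher` is an invented name, and unlike the real scheduler it re-drains synchronously inside `release()` rather than through a wakeup event:

```python
import asyncio
from collections import deque


class SketchDispatcher:
    """Hypothetical FIFO dispatcher: resolves waiting futures while slots are free."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.in_use = 0
        self.pending: deque[asyncio.Future] = deque()

    def enqueue(self) -> asyncio.Future:
        fut = asyncio.get_running_loop().create_future()
        self.pending.append(fut)
        return fut

    def dispatch_all(self) -> None:
        # Hand out free slots in arrival order; stop when none are left.
        while self.pending and self.in_use < self.capacity:
            fut = self.pending.popleft()
            if fut.cancelled():
                continue  # the caller gave up (for example, timed out) while waiting
            self.in_use += 1
            fut.set_result("slot")

    def release(self) -> None:
        self.in_use -= 1
        self.dispatch_all()  # re-drain the queue on every release


async def demo() -> list[int]:
    d = SketchDispatcher(capacity=2)
    d.in_use = 2  # both slots busy with in-flight requests
    futures = [d.enqueue() for _ in range(3)]
    counts = []
    d.dispatch_all()
    counts.append(sum(f.done() for f in futures))  # nothing can be dispatched yet
    d.release()
    counts.append(sum(f.done() for f in futures))  # one slot freed
    d.release()
    counts.append(sum(f.done() for f in futures))  # second slot freed
    return counts
```

The counts collected by `demo()` follow the same progression the test asserts: each released slot unblocks exactly one queued request, in arrival order.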

79
tests/test_middleware.py Normal file

@@ -0,0 +1,79 @@
import unittest

from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient

from llamacpp_ha.middleware import ApiKeyMiddleware


def _make_app(api_keys: list[str]) -> FastAPI:
    app = FastAPI()
    app.add_middleware(ApiKeyMiddleware, api_keys=api_keys)

    @app.get("/test")
    async def test_endpoint():
        return {"ok": True}

    @app.get("/monitor")
    async def monitor():
        return {"monitor": True}

    @app.get("/monitor/data")
    async def monitor_data():
        return {"data": True}

    return app


class TestApiKeyMiddleware(unittest.TestCase):
    def test_no_keys_configured_passes_all(self):
        client = TestClient(_make_app([]))
        resp = client.get("/test")
        self.assertEqual(resp.status_code, 200)

    def test_valid_key_passes(self):
        client = TestClient(_make_app(["key1", "key2"]))
        resp = client.get("/test", headers={"Authorization": "Bearer key1"})
        self.assertEqual(resp.status_code, 200)

    def test_valid_second_key_passes(self):
        client = TestClient(_make_app(["key1", "key2"]))
        resp = client.get("/test", headers={"Authorization": "Bearer key2"})
        self.assertEqual(resp.status_code, 200)

    def test_missing_key_returns_401(self):
        client = TestClient(_make_app(["key1"]))
        resp = client.get("/test")
        self.assertEqual(resp.status_code, 401)

    def test_wrong_key_returns_401(self):
        client = TestClient(_make_app(["key1"]))
        resp = client.get("/test", headers={"Authorization": "Bearer wrongkey"})
        self.assertEqual(resp.status_code, 401)

    def test_malformed_auth_returns_401(self):
        client = TestClient(_make_app(["key1"]))
        resp = client.get("/test", headers={"Authorization": "key1"})
        self.assertEqual(resp.status_code, 401)

    def test_monitor_exempt(self):
        client = TestClient(_make_app(["key1"]))
        resp = client.get("/monitor")
        self.assertEqual(resp.status_code, 200)

    def test_monitor_data_exempt(self):
        client = TestClient(_make_app(["key1"]))
        resp = client.get("/monitor/data")
        self.assertEqual(resp.status_code, 200)

    def test_error_response_is_json(self):
        client = TestClient(_make_app(["key1"]))
        resp = client.get("/test")
        self.assertEqual(resp.headers["content-type"], "application/json")
        body = resp.json()
        self.assertIn("error", body)


if __name__ == "__main__":
    unittest.main()

109
tests/test_monitor.py Normal file

@@ -0,0 +1,109 @@
import unittest

from fastapi import FastAPI
from fastapi.testclient import TestClient

from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.monitor import ProxyStats, build_router
from llamacpp_ha.queue import RequestQueue
from llamacpp_ha.registry import BackendRegistry, BackendState
from llamacpp_ha.session_store import SessionStore
from llamacpp_ha.slot_tracker import SlotTracker


def _make_monitor_app(states=None):
    cfg = ProxyConfig(backends=[BackendConfig(url=s.url) for s in (states or [])])
    registry = BackendRegistry(cfg)
    for state in (states or []):
        registry._states[state.url] = state
    registry._rebuild_index()
    slot_tracker = SlotTracker()
    request_queue = RequestQueue()
    session_store = SessionStore()
    stats = ProxyStats()
    app = FastAPI()
    router = build_router(
        registry=registry,
        slot_tracker=slot_tracker,
        request_queue=request_queue,
        session_store=session_store,
        stats=stats,
    )
    app.include_router(router)
    return app, registry, slot_tracker, request_queue, session_store, stats


class TestMonitorEndpoints(unittest.TestCase):
    def test_monitor_page_returns_html(self):
        app, *_ = _make_monitor_app()
        resp = TestClient(app).get("/monitor")
        self.assertEqual(resp.status_code, 200)
        self.assertIn("text/html", resp.headers["content-type"])
        self.assertIn("llamacpp-ha", resp.text)

    def test_monitor_page_no_external_deps(self):
        app, *_ = _make_monitor_app()
        text = TestClient(app).get("/monitor").text
        self.assertNotIn("cdn.", text)
        self.assertNotIn("googleapis.com", text)
        self.assertNotIn("unpkg.com", text)

    def test_monitor_data_structure(self):
        state = BackendState(config=BackendConfig(url="http://b1"), live=True, models=["m1", "m2"])
        app, _, slot_tracker, *_ = _make_monitor_app([state])
        slot_tracker.set_capacity("http://b1", 4)
        data = TestClient(app).get("/monitor/data").json()
        for key in ("uptime", "total_requests", "queue_depth", "session_count",
                    "live_backend_count", "backends", "queue", "sessions_by_model"):
            self.assertIn(key, data)

    def test_monitor_data_backend_fields(self):
        state = BackendState(config=BackendConfig(url="http://b1"), live=True, models=["m1"])
        app, _, slot_tracker, *_ = _make_monitor_app([state])
        slot_tracker.set_capacity("http://b1", 3)
        backend = TestClient(app).get("/monitor/data").json()["backends"][0]
        self.assertEqual(backend["url"], "http://b1")
        self.assertTrue(backend["live"])
        self.assertEqual(backend["models"], ["m1"])
        self.assertEqual(backend["slots_total"], 3)
        self.assertIn("slots_acquired", backend)

    def test_monitor_data_dead_backend(self):
        state = BackendState(config=BackendConfig(url="http://b2-dead"), live=False, models=[])
        app, *_ = _make_monitor_app([state])
        data = TestClient(app).get("/monitor/data").json()
        backends = {b["url"]: b for b in data["backends"]}
        self.assertFalse(backends["http://b2-dead"]["live"])
        self.assertEqual(data["live_backend_count"], 0)

    def test_monitor_data_empty_state(self):
        app, *_ = _make_monitor_app([])
        data = TestClient(app).get("/monitor/data").json()
        self.assertEqual(data["backends"], [])
        self.assertEqual(data["queue"], [])
        self.assertEqual(data["session_count"], 0)

    def test_monitor_total_requests_reflects_stats(self):
        app, *_, stats = _make_monitor_app()
        stats.increment_requests()
        stats.increment_requests()
        data = TestClient(app).get("/monitor/data").json()
        self.assertEqual(data["total_requests"], 2)

    def test_monitor_uptime_is_string(self):
        app, *_ = _make_monitor_app()
        data = TestClient(app).get("/monitor/data").json()
        self.assertRegex(data["uptime"], r"^\d{2}:\d{2}:\d{2}$")

    def test_monitor_last_poll_age_never_polled(self):
        """Backend that has never been polled should show null last_poll_age."""
        state = BackendState(config=BackendConfig(url="http://b1"), live=False, models=[])
        app, *_ = _make_monitor_app([state])
        data = TestClient(app).get("/monitor/data").json()
        self.assertIsNone(data["backends"][0]["last_poll_age"])


if __name__ == "__main__":
    unittest.main()
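
`test_monitor_uptime_is_string` expects the uptime as zero-padded `HH:MM:SS`. A formatter that satisfies that pattern might look like the sketch below; `format_uptime` is a hypothetical name, and how `ProxyStats` actually renders uptime is not shown in this commit view:

```python
def format_uptime(seconds: float) -> str:
    """Hypothetical formatter producing zero-padded HH:MM:SS from elapsed seconds."""
    total = int(seconds)  # drop fractional seconds
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
```

With this approach an uptime of 100 hours or more renders three hour digits, so it would no longer match the two-digit regex used in the test.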

59
tests/test_policies.py Normal file

@@ -0,0 +1,59 @@
import unittest

from llamacpp_ha.policies import BackendCandidate, RoundRobinPolicy


def _cands(n: int) -> list[BackendCandidate]:
    return [BackendCandidate(url=f"http://b{i}", index=i) for i in range(n)]


class TestRoundRobinPolicy(unittest.TestCase):
    def test_single_backend_always_chosen(self):
        policy = RoundRobinPolicy()
        cands = _cands(1)
        for _ in range(5):
            result = policy.select("m", cands)
            self.assertEqual(result.url, "http://b0")

    def test_distributes_evenly_across_two(self):
        policy = RoundRobinPolicy()
        cands = _cands(2)
        chosen = [policy.select("m", cands).url for _ in range(6)]
        self.assertEqual(chosen.count("http://b0"), 3)
        self.assertEqual(chosen.count("http://b1"), 3)

    def test_distributes_evenly_across_three(self):
        policy = RoundRobinPolicy()
        cands = _cands(3)
        chosen = [policy.select("m", cands).url for _ in range(9)]
        for i in range(3):
            self.assertEqual(chosen.count(f"http://b{i}"), 3)

    def test_per_model_counters_independent(self):
        policy = RoundRobinPolicy()
        cands = _cands(2)
        r1 = policy.select("model-a", cands)
        r2 = policy.select("model-b", cands)
        # Both start at 0 -> both get b0
        self.assertEqual(r1.url, "http://b0")
        self.assertEqual(r2.url, "http://b0")
        # Next calls for each model independently advance
        r3 = policy.select("model-a", cands)
        r4 = policy.select("model-b", cands)
        self.assertEqual(r3.url, "http://b1")
        self.assertEqual(r4.url, "http://b1")

    def test_empty_raises(self):
        policy = RoundRobinPolicy()
        with self.assertRaises(ValueError):
            policy.select("m", [])

    def test_wraps_around(self):
        policy = RoundRobinPolicy()
        cands = _cands(2)
        urls = [policy.select("m", cands).url for _ in range(5)]
        self.assertEqual(urls, ["http://b0", "http://b1", "http://b0", "http://b1", "http://b0"])


if __name__ == "__main__":
    unittest.main()
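
The policy tests describe round-robin selection with an independent counter per model. A minimal sketch of that technique follows; `Candidate` and `SketchRoundRobin` are hypothetical stand-ins for `BackendCandidate` and `RoundRobinPolicy`, whose actual implementations are not reproduced here:

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Candidate:
    """Hypothetical stand-in for BackendCandidate."""
    url: str


class SketchRoundRobin:
    def __init__(self) -> None:
        # One counter per model id, each starting at 0.
        self._counters: defaultdict[str, int] = defaultdict(int)

    def select(self, model_id: str, candidates: list[Candidate]) -> Candidate:
        if not candidates:
            raise ValueError("no candidates")
        # The modulo wraps the counter around the candidate list.
        index = self._counters[model_id] % len(candidates)
        self._counters[model_id] += 1
        return candidates[index]
```

Keeping the counters keyed by model id means traffic for one model does not skew the rotation for another, which is the property `test_per_model_counters_independent` checks.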

92
tests/test_queue.py Normal file

@@ -0,0 +1,92 @@
import asyncio
import unittest

from llamacpp_ha.queue import QueueEntry, RequestQueue


def _entry(**kwargs) -> QueueEntry:
    e = QueueEntry(**kwargs)
    e.future = asyncio.get_running_loop().create_future()
    return e


class TestRequestQueue(unittest.IsolatedAsyncioTestCase):
    async def test_enqueue_and_pending(self):
        q = RequestQueue()
        e1 = _entry(model_id="m1")
        e2 = _entry(model_id="m2")
        await q.enqueue(e1)
        await q.enqueue(e2)
        pending = await q.pending()
        self.assertEqual(len(pending), 2)
        self.assertEqual(pending[0].model_id, "m1")
        self.assertEqual(pending[1].model_id, "m2")

    async def test_fifo_order(self):
        q = RequestQueue()
        ids = []
        for i in range(5):
            e = _entry(model_id=f"m{i}")
            await q.enqueue(e)
            ids.append(e.request_id)
        pending = await q.pending()
        self.assertEqual([e.request_id for e in pending], ids)

    async def test_remove(self):
        q = RequestQueue()
        e1 = _entry(model_id="m1")
        e2 = _entry(model_id="m2")
        await q.enqueue(e1)
        await q.enqueue(e2)
        await q.remove(e1)
        pending = await q.pending()
        self.assertEqual(len(pending), 1)
        self.assertEqual(pending[0].model_id, "m2")

    async def test_remove_nonexistent_no_error(self):
        q = RequestQueue()
        e = _entry(model_id="m")
        await q.remove(e)  # should not raise

    async def test_depth(self):
        q = RequestQueue()
        self.assertEqual(await q.depth(), 0)
        for _ in range(3):
            await q.enqueue(_entry(model_id="m"))
        self.assertEqual(await q.depth(), 3)

    async def test_wakeup_event_set_on_enqueue(self):
        q = RequestQueue()
        self.assertFalse(q.wakeup_event.is_set())
        await q.enqueue(_entry(model_id="m"))
        self.assertTrue(q.wakeup_event.is_set())

    def test_notify_sets_wakeup(self):
        q = RequestQueue()
        q.notify()
        self.assertTrue(q.wakeup_event.is_set())

    async def test_entry_metadata(self):
        import time

        before = time.monotonic()
        e = QueueEntry(model_id="llama3", session_id="sess123", estimated_tokens=42)
        after = time.monotonic()
        self.assertEqual(e.model_id, "llama3")
        self.assertEqual(e.session_id, "sess123")
        self.assertEqual(e.estimated_tokens, 42)
        self.assertIsNone(e.future)
        self.assertGreaterEqual(e.arrival_time, before)
        self.assertLessEqual(e.arrival_time, after)
        self.assertGreaterEqual(e.wait_seconds, 0)

    async def test_snapshot_truncates_session_id(self):
        q = RequestQueue()
        e = _entry(model_id="m", session_id="abcdef1234567890")
        await q.enqueue(e)
        snap = await q.snapshot()
        self.assertEqual(len(snap), 1)
        self.assertEqual(snap[0]["session_id"], "abcdef12")


if __name__ == "__main__":
    unittest.main()

152
tests/test_registry.py Normal file

@@ -0,0 +1,152 @@
import asyncio
import unittest
from unittest.mock import AsyncMock, patch

from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.registry import BackendRegistry


def _make_config(backends=None):
    if backends is None:
        backends = [{"url": "http://b1"}, {"url": "http://b2"}]
    return ProxyConfig(backends=[BackendConfig(**b) for b in backends])


class TestBackendRegistry(unittest.IsolatedAsyncioTestCase):
    def _make_registry(self, backends=None):
        cfg = _make_config(backends)
        return BackendRegistry(cfg)

    async def test_initial_state_all_dead(self):
        reg = self._make_registry()
        live = reg.get_all_live_backends()
        self.assertEqual(live, [])

    async def test_get_all_states_returns_all(self):
        reg = self._make_registry()
        states = reg.get_all_states()
        self.assertEqual(len(states), 2)
        urls = {s.url for s in states}
        self.assertEqual(urls, {"http://b1", "http://b2"})

    async def test_liveness_transition(self):
        reg = self._make_registry([{"url": "http://b1"}])
        state = reg.get_state("http://b1")
        self.assertFalse(state.live)
        # Simulate a successful poll by directly mutating state
        async with reg._lock:
            state.live = True
            state.models = ["llama3"]
            reg._rebuild_index()
        live = reg.get_all_live_backends()
        self.assertEqual(len(live), 1)

    async def test_model_index_rebuilt(self):
        reg = self._make_registry(
            [{"url": "http://b1"}, {"url": "http://b2"}]
        )
        async with reg._lock:
            b1 = reg.get_state("http://b1")
            b2 = reg.get_state("http://b2")
            b1.live = True
            b1.models = ["m1", "m2"]
            b2.live = True
            b2.models = ["m2", "m3"]
            reg._rebuild_index()
        self.assertEqual(len(reg.get_backends_for_model("m1")), 1)
        self.assertEqual(len(reg.get_backends_for_model("m2")), 2)
        self.assertEqual(len(reg.get_backends_for_model("m3")), 1)

    async def test_get_all_models_deduplicated(self):
        reg = self._make_registry(
            [{"url": "http://b1"}, {"url": "http://b2"}]
        )
        async with reg._lock:
            b1 = reg.get_state("http://b1")
            b2 = reg.get_state("http://b2")
            b1.live = True
            b1.models = ["m1", "shared"]
            b2.live = True
            b2.models = ["m2", "shared"]
            reg._rebuild_index()
        models = reg.get_all_models()
        self.assertEqual(len(models), 3)
        self.assertEqual(len(set(models)), 3)  # no duplicates

    async def test_dead_backend_excluded_from_model_index(self):
        reg = self._make_registry(
            [{"url": "http://b1"}, {"url": "http://b2"}]
        )
        async with reg._lock:
            b1 = reg.get_state("http://b1")
            b2 = reg.get_state("http://b2")
            b1.live = True
            b1.models = ["m1"]
            b2.live = False
            b2.models = ["m1"]
            reg._rebuild_index()
        backends = reg.get_backends_for_model("m1")
        self.assertEqual(len(backends), 1)
        self.assertEqual(backends[0].url, "http://b1")

    async def test_explicit_model_ids_used_over_backend_report(self):
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["only-this"])]
        )
        reg = BackendRegistry(cfg)
        # Simulate fetching models: should use config model_ids, not backend response
        state = reg.get_state("http://b1")
        mock_session = AsyncMock()
        reg._session = mock_session
        models = await reg._fetch_models(state, {})
        self.assertEqual(models, ["only-this"])

    async def test_last_poll_age_increases(self):
        import time

        reg = self._make_registry([{"url": "http://b1"}])
        state = reg.get_state("http://b1")
        state.last_poll_time = time.monotonic() - 10
        age = state.last_poll_age
        self.assertGreaterEqual(age, 9.9)

    async def test_no_backends_for_unknown_model(self):
        reg = self._make_registry()
        result = reg.get_backends_for_model("no-such-model")
        self.assertEqual(result, [])

    async def test_poll_all_concurrent(self):
        """poll_all runs backends concurrently; a slow backend doesn't block others."""
        cfg = ProxyConfig(
            backends=[BackendConfig(url=f"http://b{i}") for i in range(5)]
        )
        reg = BackendRegistry(cfg)
        poll_times = []

        async def fake_poll_one(state):
            # get_running_loop() is the supported way to reach the loop from a coroutine
            poll_times.append(asyncio.get_running_loop().time())
            await asyncio.sleep(0.05)
            async with reg._lock:
                state.live = True
                state.models = ["m1"]

        with patch.object(reg, "_poll_one", side_effect=fake_poll_one):
            import time

            start = time.monotonic()
            await reg._poll_all()
            elapsed = time.monotonic() - start
        # 5 backends * 0.05s each: ~0.05s when polled concurrently, 0.25s if sequential
        self.assertLess(elapsed, 0.2)
        self.assertEqual(len(reg.get_all_live_backends()), 5)


if __name__ == "__main__":
    unittest.main()
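
`test_poll_all_concurrent` depends on the registry polling every backend at once rather than one after another. The sketch below illustrates that pattern with `asyncio.gather`. The probe is simulated with a sleep, and `poll_one` and `poll_all` are hypothetical helpers, not the registry's `_poll_one` or `_poll_all`.

```python
import asyncio


async def poll_one(url: str, delay: float) -> tuple[str, bool]:
    """Hypothetical health probe; the sleep stands in for an HTTP round trip."""
    await asyncio.sleep(delay)
    return url, True


async def poll_all(urls: list[str], delay: float = 0.05) -> dict[str, bool]:
    """Probe every backend concurrently and report which ones answered."""
    results = await asyncio.gather(
        *(poll_one(url, delay) for url in urls),
        return_exceptions=True,  # one failing probe must not cancel the others
    )
    live: dict[str, bool] = {}
    for url, result in zip(urls, results):
        live[url] = (not isinstance(result, BaseException)) and result[1]
    return live
```

With `return_exceptions=True`, a probe that raises is recorded as not live instead of aborting the whole polling round.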

300
tests/test_scheduler.py Normal file

@@ -0,0 +1,300 @@
import asyncio
import unittest

from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.policies import RoundRobinPolicy
from llamacpp_ha.queue import QueueEntry, RequestQueue
from llamacpp_ha.registry import BackendRegistry, BackendState
from llamacpp_ha.scheduler import Scheduler
from llamacpp_ha.session_store import SessionStore
from llamacpp_ha.slot_tracker import SlotTracker


def _make_state(url: str, models: list[str] | None = None) -> BackendState:
    cfg = BackendConfig(url=url)
    return BackendState(config=cfg, live=True, models=models or ["m1"])


def _entry(**kwargs) -> QueueEntry:
    e = QueueEntry(**kwargs)
    e.future = asyncio.get_running_loop().create_future()
    return e


class TestScheduler(unittest.IsolatedAsyncioTestCase):
    def _make_scheduler(self, live_backends=None, max_queue_skip: int = 0):
        cfg = ProxyConfig(backends=[BackendConfig(url=b.url) for b in (live_backends or [])])
        registry = BackendRegistry(cfg)
        for state in (live_backends or []):
            registry._states[state.url] = state
        registry._rebuild_index()
        slot_tracker = SlotTracker()
        for state in (live_backends or []):
            slot_tracker.set_capacity(state.url, 2)
        session_store = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(
            queue=queue,
            registry=registry,
            slot_tracker=slot_tracker,
            session_store=session_store,
            policy=RoundRobinPolicy(),
            max_queue_skip=max_queue_skip,
        )
        return scheduler, queue, registry, slot_tracker, session_store

    async def test_dispatches_entry_to_live_backend(self):
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_skips_full_backends(self):
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done())

    async def test_session_affinity_preferred(self):
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, _, _, sessions = self._make_scheduler([b1, b2])
        await sessions.get_or_create("sess1")
        await sessions.update("sess1", preferred_backend="http://b2")
        entry = _entry(model_id="m1", session_id="sess1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b2")

    async def test_session_affinity_fallback_when_full(self):
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, _, slots, sessions = self._make_scheduler([b1, b2])
        slots.set_capacity("http://b2", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b2")
        await sessions.get_or_create("sess1")
        await sessions.update("sess1", preferred_backend="http://b2")
        entry = _entry(model_id="m1", session_id="sess1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_no_live_backends_entry_stays(self):
        scheduler, queue, *_ = self._make_scheduler([])
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done())
        self.assertEqual(await queue.depth(), 1)

    def test_notify_slot_released_wakes_queue(self):
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])
        scheduler.notify_slot_released()
        self.assertTrue(queue.wakeup_event.is_set())

    async def test_cancelled_future_cleaned_up(self):
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        entry.future.cancel()
        await scheduler._dispatch_all()
        self.assertEqual(await queue.depth(), 0)

    async def test_round_robin_across_backends(self):
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, *_ = self._make_scheduler([b1, b2])
        results = []
        for _ in range(4):
            e = _entry(model_id="m1")
            await queue.enqueue(e)
            await scheduler._dispatch_all()
            results.append(e.future.result().url)
        self.assertEqual(results.count("http://b1"), 2)
        self.assertEqual(results.count("http://b2"), 2)

    async def test_slot_released_then_dispatch(self):
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done())
        await slots.release("http://b1")
        scheduler.notify_slot_released()
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())

    # ------------------------------------------------------------------
    # max_models / preemption prevention
    # ------------------------------------------------------------------

    async def test_max_models_blocks_second_model_on_same_backend(self):
        b1 = _make_state("http://b1")
        _, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 4)
        slots.set_max_models("http://b1", 1)
        # Occupy a slot with m1
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")
        # A request for m2 should stay queued; build a registry that serves both models
        cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m1", "m2"])])
        registry = BackendRegistry(cfg)
        registry._states["http://b1"] = BackendState(
            config=BackendConfig(url="http://b1", model_ids=["m1", "m2"]),
            live=True,
            models=["m1", "m2"],
        )
        registry._rebuild_index()
        sched2 = Scheduler(
            queue=queue,
            registry=registry,
            slot_tracker=slots,
            session_store=SessionStore(),
            policy=RoundRobinPolicy(),
        )
        entry = _entry(model_id="m2")
        await queue.enqueue(entry)
        await sched2._dispatch_all()
        self.assertFalse(entry.future.done(), "m2 should be blocked by max_models=1")

    async def test_max_models_allows_same_model(self):
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 4)
        slots.set_max_models("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done(), "same model should still be dispatchable")

    # ------------------------------------------------------------------
    # N-skip reordering
    # ------------------------------------------------------------------

    async def test_no_reorder_when_max_queue_skip_zero(self):
        """max_queue_skip=0: no affinity promotion, and skip_count is never bumped."""
        b1 = _make_state("http://b1", models=["m1"])
        b2 = _make_state("http://b2", models=["m2"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2], max_queue_skip=0)
        slots.set_capacity("http://b1", 1)
        slots.set_capacity("http://b2", 1)
        # Fill both backends: b1 with m1, b2 with m2
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b2", "m2")
        # Queue: [m1-request (blocked), m2-request (can go to b2 once it frees up)]
        e_m1 = _entry(model_id="m1")
        e_m2 = _entry(model_id="m2")
        await queue.enqueue(e_m1)
        await queue.enqueue(e_m2)
        # Release the b2 slot so m2 can be served
        await slots.release("http://b2", "m2")
        await scheduler._dispatch_all()
        # m2 is dispatched even with max_queue_skip=0 because _dispatch_all
        # scans all entries (not strict head-of-line per model)
        self.assertFalse(e_m1.future.done())
        self.assertTrue(e_m2.future.done())
        # skip_count must NOT be bumped when max_queue_skip=0
        self.assertEqual(e_m1.skip_count, 0)

    async def test_affinity_promotes_matching_model(self):
        """With max_queue_skip>0, a matching model gets promoted."""
        b1 = _make_state("http://b1", models=["m1"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1], max_queue_skip=3)
        slots.set_capacity("http://b1", 2)
        # b1 already has m1 in-flight
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")
        # Queue: [m2-entry (no affinity), m1-entry (affinity match)]
        e_other = _entry(model_id="m2")  # no backend serves m2
        e_m1 = _entry(model_id="m1")
        await queue.enqueue(e_other)
        await queue.enqueue(e_m1)
        await scheduler._dispatch_all()
        # m1 entry is promoted (affinity pass), m2 stays (no backend)
        self.assertTrue(e_m1.future.done())
        self.assertFalse(e_other.future.done())
        # e_other was bypassed once
        self.assertEqual(e_other.skip_count, 1)

    async def test_skip_count_caps_reordering(self):
        """Once skip_count reaches max_queue_skip the entry freezes at head."""
        b1 = _make_state("http://b1", models=["m1"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1], max_queue_skip=2)
        slots.set_capacity("http://b1", 4)
        # b1 has m1 active
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")
        e_other = _entry(model_id="m2")
        e_other.skip_count = 2  # already at limit, so it must not be bypassed
        e_m1 = _entry(model_id="m1")
        await queue.enqueue(e_other)
        await queue.enqueue(e_m1)
        await scheduler._dispatch_all()
        # Affinity pass stops at e_other (skip_count >= max_queue_skip),
        # so e_m1 is NOT promoted via affinity. Both get a chance in FIFO pass.
        # e_other (m2) has no backend → stays. e_m1 gets dispatched in FIFO pass.
        self.assertTrue(e_m1.future.done())
        # skip_count must NOT increase further (entry was frozen)
        self.assertEqual(e_other.skip_count, 2)


if __name__ == "__main__":
    unittest.main()

112
tests/test_session_store.py Normal file

@@ -0,0 +1,112 @@
import asyncio
import unittest

from llamacpp_ha.session_store import SessionStore, compute_prefix_hash


class TestComputePrefixHash(unittest.TestCase):
    def test_same_messages_same_hash(self):
        msgs = [{"role": "user", "content": "hello"}]
        self.assertEqual(compute_prefix_hash(msgs), compute_prefix_hash(msgs))

    def test_different_messages_different_hash(self):
        a = [{"role": "user", "content": "hello"}]
        b = [{"role": "user", "content": "world"}]
        self.assertNotEqual(compute_prefix_hash(a), compute_prefix_hash(b))

    def test_empty_messages(self):
        h = compute_prefix_hash([])
        self.assertIsInstance(h, str)
        self.assertEqual(len(h), 16)


class TestSessionStore(unittest.IsolatedAsyncioTestCase):
    async def test_get_or_create_new(self):
        store = SessionStore(ttl=300.0)
        session = await store.get_or_create("abc")
        self.assertEqual(session.session_id, "abc")
        self.assertIsNone(session.model_id)

    async def test_get_or_create_existing(self):
        store = SessionStore(ttl=300.0)
        await store.get_or_create("abc")
        await store.update("abc", model_id="llama3")
        s2 = await store.get_or_create("abc")
        self.assertEqual(s2.model_id, "llama3")

    async def test_update_model_and_messages(self):
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        msgs = [{"role": "user", "content": "hi"}]
        await store.update("s1", model_id="m1", messages=msgs, preferred_backend="http://b1")
        pref = await store.get_preferred_backend("s1")
        self.assertEqual(pref, "http://b1")

    async def test_affinity_hit(self):
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        await store.update("s1", preferred_backend="http://b2")
        pref = await store.get_preferred_backend("s1")
        self.assertEqual(pref, "http://b2")

    async def test_affinity_miss_unknown_session(self):
        store = SessionStore(ttl=300.0)
        pref = await store.get_preferred_backend("nonexistent")
        self.assertIsNone(pref)

    async def test_ttl_expiry(self):
        store = SessionStore(ttl=0.05)
        await store.get_or_create("s1")
        await asyncio.sleep(0.1)
        pref = await store.get_preferred_backend("s1")
        self.assertIsNone(pref)

    async def test_expire_removes_stale(self):
        store = SessionStore(ttl=0.05)
        await store.get_or_create("s1")
        await store.get_or_create("s2")
        await asyncio.sleep(0.1)
        removed = await store.expire()
        self.assertEqual(removed, 2)
        count = await store.count()
        self.assertEqual(count, 0)

    async def test_count_active(self):
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        await store.get_or_create("s2")
        count = await store.count()
        self.assertEqual(count, 2)

    async def test_count_by_model(self):
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        await store.get_or_create("s2")
        await store.get_or_create("s3")
        await store.update("s1", model_id="m1")
        await store.update("s2", model_id="m1")
        await store.update("s3", model_id="m2")
        by_model = await store.count_by_model()
        self.assertEqual(by_model["m1"], 2)
        self.assertEqual(by_model["m2"], 1)

    async def test_update_nonexistent_session_no_error(self):
        store = SessionStore(ttl=300.0)
        await store.update("nope", model_id="m1")  # should not raise

    async def test_concurrent_access(self):
        store = SessionStore(ttl=300.0)

        async def worker(i):
            sid = f"s{i}"
            await store.get_or_create(sid)
            await store.update(sid, model_id="m", preferred_backend=f"http://b{i}")

        await asyncio.gather(*[worker(i) for i in range(20)])
        count = await store.count()
        self.assertEqual(count, 20)


if __name__ == "__main__":
    unittest.main()
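
`TestComputePrefixHash` requires a deterministic 16-character string, and the project overview describes the session key as a SHA-256 prefix hash. The sketch below is one way to get such a value. The JSON canonicalisation and the hex truncation are assumptions, and `sketch_prefix_hash` is a hypothetical name, not the real `compute_prefix_hash`.

```python
import hashlib
import json


def sketch_prefix_hash(messages: list[dict], length: int = 16) -> str:
    """Hypothetical: hash a message list into a short, stable hex key."""
    # sort_keys and fixed separators make the encoding independent of dict order.
    canonical = json.dumps(messages, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
    return digest[:length]
```

Truncating to 16 hex characters keeps 64 bits of the digest, which is short enough for logs while making accidental collisions between sessions unlikely.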

187
tests/test_slot_tracker.py Normal file

@@ -0,0 +1,187 @@
import asyncio
import unittest

from llamacpp_ha.slot_tracker import SlotTracker


class TestSlotTracker(unittest.IsolatedAsyncioTestCase):
    async def test_acquire_when_free(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 2)
        await tracker.acquire("http://b")
        acquired, total = tracker.usage("http://b")
        self.assertEqual(acquired, 1)
        self.assertEqual(total, 2)

    async def test_has_free_slot(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 1)
        self.assertTrue(tracker.has_free_slot("http://b"))
        await tracker.acquire("http://b")
        self.assertFalse(tracker.has_free_slot("http://b"))

    async def test_timeout_when_full(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 1)
        await tracker.acquire("http://b")
        with self.assertRaises(TimeoutError):
            async with asyncio.timeout(0.05):
                await tracker.acquire("http://b")

    async def test_release_unblocks_waiter(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 1)
        await tracker.acquire("http://b")
        results = []

        async def waiter():
            async with asyncio.timeout(2.0):
                await tracker.acquire("http://b")
            results.append(True)

        task = asyncio.create_task(waiter())
        await asyncio.sleep(0.05)
        await tracker.release("http://b")
        await task
        self.assertEqual(results, [True])

    async def test_release_below_zero(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 1)
        await tracker.release("http://b")
        acquired, _ = tracker.usage("http://b")
        self.assertEqual(acquired, 0)

    def test_set_capacity_increase(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 1)
        tracker.set_capacity("http://b", 3)
        _, total = tracker.usage("http://b")
        self.assertEqual(total, 3)

    def test_unknown_url_defaults(self):
        tracker = SlotTracker()
        self.assertTrue(tracker.has_free_slot("http://unknown"))
        acquired, total = tracker.usage("http://unknown")
        self.assertEqual(acquired, 0)
        self.assertEqual(total, 1)

    async def test_acquire_zero_timeout_succeeds_then_fails(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 1)
        async with asyncio.timeout(0):
            await tracker.acquire("http://b")
        with self.assertRaises(TimeoutError):
            async with asyncio.timeout(0):
                await tracker.acquire("http://b")

    async def test_release_decrements(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 2)
        await tracker.acquire("http://b")
        await tracker.acquire("http://b")
        acquired, _ = tracker.usage("http://b")
        self.assertEqual(acquired, 2)
        await tracker.release("http://b")
        acquired, _ = tracker.usage("http://b")
        self.assertEqual(acquired, 1)

    # ------------------------------------------------------------------
    # Model-aware tests
    # ------------------------------------------------------------------

    def test_can_accept_respects_max_models(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 4)
        tracker.set_max_models("http://b", 1)
        self.assertTrue(tracker.can_accept("http://b", "model-a"))

    async def test_max_models_blocks_second_model(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 4)
        tracker.set_max_models("http://b", 1)
        await tracker.acquire("http://b", "model-a")
        # model-a is still accepted (same model, slot available)
        self.assertTrue(tracker.can_accept("http://b", "model-a"))
        # model-b is blocked (max_models=1 already reached)
        self.assertFalse(tracker.can_accept("http://b", "model-b"))

    async def test_max_models_unblocks_after_release(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 4)
        tracker.set_max_models("http://b", 1)
        await tracker.acquire("http://b", "model-a")
        self.assertFalse(tracker.can_accept("http://b", "model-b"))
        await tracker.release("http://b", "model-a")
        self.assertTrue(tracker.can_accept("http://b", "model-b"))

    async def test_active_model_set(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 4)
        self.assertEqual(tracker.active_model_set("http://b"), frozenset())
        await tracker.acquire("http://b", "model-a")
        self.assertEqual(tracker.active_model_set("http://b"), frozenset({"model-a"}))
        await tracker.acquire("http://b", "model-b")
        self.assertEqual(
            tracker.active_model_set("http://b"), frozenset({"model-a", "model-b"})
        )
        await tracker.release("http://b", "model-a")
        self.assertEqual(tracker.active_model_set("http://b"), frozenset({"model-b"}))

    async def test_acquire_tracks_active_models(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 4)
        await tracker.acquire("http://b", "model-a")
        await tracker.acquire("http://b", "model-a")
        acquired, _ = tracker.usage("http://b")
        self.assertEqual(acquired, 2)
        self.assertEqual(tracker.active_model_set("http://b"), frozenset({"model-a"}))
        await tracker.release("http://b", "model-a")
        self.assertEqual(tracker.active_model_set("http://b"), frozenset({"model-a"}))
        await tracker.release("http://b", "model-a")
        self.assertEqual(tracker.active_model_set("http://b"), frozenset())

    async def test_reset_acquired_clears_state(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 2)
        await tracker.acquire("http://b", "model-a")
        await tracker.acquire("http://b", "model-a")
        acquired, _ = tracker.usage("http://b")
        self.assertEqual(acquired, 2)
        await tracker.reset_acquired("http://b")
        acquired, _ = tracker.usage("http://b")
        self.assertEqual(acquired, 0)
        self.assertEqual(tracker.active_model_set("http://b"), frozenset())

    async def test_reset_acquired_unblocks_waiters(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 1)
        tracker.set_max_models("http://b", 1)
        await tracker.acquire("http://b", "model-a")
        unblocked = []

        async def waiter():
            async with asyncio.timeout(2.0):
                await tracker.acquire("http://b", "model-b")
            unblocked.append(True)

        task = asyncio.create_task(waiter())
        await asyncio.sleep(0.05)
        self.assertFalse(unblocked)
        await tracker.reset_acquired("http://b")
        await task
        self.assertEqual(unblocked, [True])

    async def test_max_models_none_allows_any(self):
        tracker = SlotTracker()
        tracker.set_capacity("http://b", 4)
        tracker.set_max_models("http://b", None)
        await tracker.acquire("http://b", "model-a")
        self.assertTrue(tracker.can_accept("http://b", "model-b"))
        self.assertTrue(tracker.can_accept("http://b", "model-c"))


if __name__ == "__main__":
    unittest.main()