This commit is contained in:
2026-05-17 22:15:13 +02:00
parent 7344aa4ef4
commit d826b038ab
12 changed files with 323 additions and 214 deletions

View File

@@ -1,7 +1,8 @@
{
"permissions": {
"allow": [
"Bash(python -m pytest)"
"Bash(python -m pytest)",
"Bash(python -m pytest --tb=short -q)"
]
}
}

View File

@@ -34,7 +34,7 @@ def main() -> None:
overrides = {}
if args.host:
overrides["host"] = args.host
if args.port:
if args.port is not None:
overrides["port"] = args.port
try:

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import itertools
import logging
from typing import AsyncIterator
@@ -118,43 +119,56 @@ async def forward_request(
raise
# Module-wide rotation counter: every best-effort call consumes one value so
# that consecutive calls start from a different position in the backend list.
_best_effort_counter = itertools.count()


async def forward_best_effort(
    request: Request,
    registry_backends: list[BackendState],
    session: aiohttp.ClientSession,
) -> Response:
    """Forward without slot gating to any live backend (catch-all paths).

    Backends are tried in round-robin order; when an attempt raises, the next
    backend is attempted until every backend in the list has been tried once.

    Args:
        request: Incoming request; its method, path, query string, headers
            (via ``_forward_headers``) and body are replayed to the backend.
        registry_backends: Candidate backends. An empty list yields a 503.
        session: aiohttp client session used for the outbound request.

    Returns:
        The first backend response that could be read completely (status,
        body, and all headers not listed in ``_HOP_BY_HOP``), a 503 when no
        backends were supplied, or a 502 when every attempt raised.
    """
    if not registry_backends:
        return Response(content="No live backends", status_code=503)

    # Rotate the list so the starting backend advances with every call.
    start = next(_best_effort_counter) % len(registry_backends)
    ordered = registry_backends[start:] + registry_backends[:start]

    path = request.url.path
    query = request.url.query
    path_and_query = path + (f"?{query}" if query else "")
    # The body is read once and re-sent unchanged on every attempt.
    body = await request.body()

    for backend in ordered:
        target = backend.url + path_and_query
        headers = _forward_headers(request, backend)
        try:
            async with session.request(
                method=request.method,
                url=target,
                headers=headers,
                data=body if body else None,
                allow_redirects=False,
            ) as resp:
                response_headers = {
                    k: v
                    for k, v in resp.headers.items()
                    if k.lower() not in _HOP_BY_HOP
                }
                data = await resp.read()
                return Response(
                    content=data,
                    status_code=resp.status,
                    headers=response_headers,
                    # None (not "") when the backend sent no Content-Type, so
                    # no empty Content-Type header is synthesised downstream.
                    media_type=resp.headers.get("Content-Type"),
                )
        except Exception as exc:
            # Deliberately broad: any failure on this backend moves on to the
            # next one instead of failing the whole request.
            # NOTE(review): this also covers failures while reading the
            # response body, so a non-idempotent request may end up being
            # delivered to more than one backend.
            log.warning("Best-effort forward to %s failed: %s", backend.url, exc)
            continue

    return Response(content="All backends failed", status_code=502)

View File

@@ -5,7 +5,7 @@ from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
_EXEMPT_PATHS = frozenset(["/monitor", "/monitor/data"])
_EXEMPT_PATHS = frozenset(["/monitor", "/monitor/", "/monitor/data", "/monitor/data/"])
class ApiKeyMiddleware(BaseHTTPMiddleware):

View File

@@ -75,8 +75,8 @@ _HTML = """<!DOCTYPE html>
<h2>Queue</h2>
<table>
<thead><tr><th>Request ID</th><th>Model</th><th>Session</th><th>Wait (s)</th><th>Est. Tokens</th></tr></thead>
<tbody id="queue-body"><tr><td colspan="5" class="empty">Queue is empty</td></tr></tbody>
<thead><tr><th>Request ID</th><th>Model</th><th>Session</th><th>Wait (s)</th><th>Est. Tokens</th><th>Skips</th></tr></thead>
<tbody id="queue-body"><tr><td colspan="6" class="empty">Queue is empty</td></tr></tbody>
</table>
<h2>Sessions by Model</h2>
@@ -115,12 +115,12 @@ _HTML = """<!DOCTYPE html>
const qBody = document.getElementById('queue-body');
if (!data.queue.length) {
qBody.innerHTML = '<tr><td colspan="5" class="empty">Queue is empty</td></tr>';
qBody.innerHTML = '<tr><td colspan="6" class="empty">Queue is empty</td></tr>';
} else {
qBody.innerHTML = data.queue.map(e => {
const tok = e.estimated_tokens != null ? esc(e.estimated_tokens) : '<span class="empty">-</span>';
const sid = e.session_id ? esc(e.session_id) : '<span class="empty">-</span>';
return `<tr><td>${esc(e.request_id.slice(0,12))}</td><td>${esc(e.model_id||'-')}</td><td>${sid}</td><td>${esc(e.wait_seconds.toFixed(2))}</td><td>${tok}</td></tr>`;
return `<tr><td>${esc(e.request_id.slice(0,12))}</td><td>${esc(e.model_id||'-')}</td><td>${sid}</td><td>${esc(e.wait_seconds.toFixed(2))}</td><td>${tok}</td><td>${esc(e.skip_count)}</td></tr>`;
}).join('');
}

View File

@@ -80,11 +80,18 @@ def _session_id_from(request: Request) -> str | None:
def _attach_session(
    response: Response, session_id: str, *, secure: bool = False
) -> None:
    """Attach the session id to ``response`` as a cookie and as a header.

    Args:
        response: Outgoing response to decorate in place.
        session_id: Identifier written to the ``_SESSION_COOKIE`` cookie and
            the ``_SESSION_HEADER`` response header.
        secure: Value of the cookie's ``Secure`` attribute. Defaults to
            ``False``; presumably this keeps the cookie usable when the proxy
            is reached over plain HTTP (TODO confirm). Pass ``True`` when the
            proxy is only reachable over TLS.
    """
    response.set_cookie(
        _SESSION_COOKIE, session_id, httponly=True, samesite="lax", secure=secure
    )
    response.headers[_SESSION_HEADER] = session_id
def _parse_body(raw: bytes) -> dict[str, Any]:
try:
return json.loads(raw) if raw else {}
except json.JSONDecodeError:
return {}
def _init_slot_tracker(
registry: BackendRegistry,
slot_tracker: SlotTracker,
@@ -127,6 +134,8 @@ async def _dispatch_entry(
request_queue: RequestQueue,
stats: ProxyStats,
config: ProxyConfig,
slot_tracker: SlotTracker,
scheduler: Scheduler,
model_id: str,
session_id: str | None,
body: dict,
@@ -147,7 +156,13 @@ async def _dispatch_entry(
)
except asyncio.TimeoutError:
await request_queue.remove(entry)
if not entry.future.done():
if entry.future.done() and not entry.future.cancelled():
# Scheduler won the race: it acquired a slot and resolved the future
# between our timeout and our remove. Release that slot now or it leaks.
resolved: BackendState = entry.future.result()
await slot_tracker.release(resolved.url, model_id)
scheduler.notify_slot_released()
elif not entry.future.done():
entry.future.cancel()
return JSONResponse(
{"error": {"message": "No slot available (timeout)", "type": "overloaded"}},
@@ -156,6 +171,43 @@ async def _dispatch_entry(
return backend
async def _recover_session_affinity(
    session_id: str,
    messages: list,
    session_store: SessionStore,
) -> None:
    """Seed a fresh session's preferred backend from a message-prefix match.

    Nothing happens when fewer than two messages are supplied or when the
    store finds no session whose conversation is a prefix of ``messages``.
    Otherwise the backend returned by ``find_by_prefix`` is stored as the
    preferred backend of ``session_id``.
    """
    if len(messages) >= 2:
        backend_hint = await session_store.find_by_prefix(messages)
        if backend_hint:
            await session_store.update(session_id, preferred_backend=backend_hint)
async def _find_failover(
    model_id: str,
    excluded_urls: set[str],
    registry: BackendRegistry,
    slot_tracker: SlotTracker,
) -> BackendState | None:
    """Acquire a slot on a live backend that has not been tried yet.

    Candidates are the backends serving ``model_id``, or every live backend
    when ``model_id`` is empty. Backends listed in ``excluded_urls`` and those
    the slot tracker reports as unable to accept are skipped.

    Returns:
        The first backend whose slot was acquired, or ``None`` if there is none.
    """
    if model_id:
        pool = registry.get_backends_for_model(model_id)
    else:
        pool = registry.get_all_live_backends()

    for candidate in pool:
        if candidate.url in excluded_urls:
            continue
        if not slot_tracker.can_accept(candidate.url, model_id):
            continue
        try:
            # Zero timeout: give up on this backend rather than wait for a slot.
            async with asyncio.timeout(0):
                await slot_tracker.acquire(candidate.url, model_id)
        except TimeoutError:
            continue
        return candidate

    return None
def _list_models(*, registry: BackendRegistry) -> JSONResponse:
models = registry.get_all_models()
data = [
@@ -199,6 +251,7 @@ async def _inference_endpoint(
request_queue: RequestQueue,
stats: ProxyStats,
config: ProxyConfig,
registry: BackendRegistry,
) -> Response:
if http.client is None:
return Response(
@@ -207,22 +260,46 @@ async def _inference_endpoint(
media_type=_APPLICATION_JSON,
)
raw = await request.body()
try:
body: dict[str, Any] = json.loads(raw) if raw else {}
except json.JSONDecodeError:
body = {}
body = _parse_body(await request.body())
model_id = _get_model(body)
session_id = _session_id_from(request) or session_store.new_session_id()
incoming_session_id = _session_id_from(request)
session_id = incoming_session_id or session_store.new_session_id()
if not incoming_session_id:
await _recover_session_affinity(session_id, body.get("messages") or [], session_store)
result = await _dispatch_entry(
request_queue, stats, config, model_id, session_id, body
request_queue, stats, config, slot_tracker, scheduler, model_id, session_id, body
)
if isinstance(result, Response):
return result
backend = result
tried: set[str] = {backend.url}
while True:
try:
response = await forward_request(
request=request,
backend=backend,
session=http.client,
slot_tracker=slot_tracker,
scheduler=scheduler,
model_id=model_id,
)
break
except aiohttp.ClientError as exc:
log.warning("Backend %s unreachable (%s); trying failover", backend.url, exc)
failover = await _find_failover(model_id, tried, registry, slot_tracker)
if failover is None:
return JSONResponse(
{"error": {"message": "Backend unreachable, no failover available", "type": "server_error"}},
status_code=502,
media_type=_APPLICATION_JSON,
)
tried.add(failover.url)
backend = failover
if model_id:
messages = body.get("messages", [])
await session_store.update(
@@ -232,14 +309,6 @@ async def _inference_endpoint(
preferred_backend=backend.url,
)
response = await forward_request(
request=request,
backend=backend,
session=http.client,
slot_tracker=slot_tracker,
scheduler=scheduler,
model_id=model_id,
)
_attach_session(response, session_id)
return response
@@ -321,6 +390,7 @@ def create_app(config: ProxyConfig) -> FastAPI:
request_queue=request_queue,
stats=stats,
config=config,
registry=registry,
)
async def catch_all_handler(request: Request, full_path: str) -> Response: # noqa: ARG001

View File

@@ -173,6 +173,10 @@ class BackendRegistry:
state.url + _SLOTS_PATH, headers=headers
) as resp:
if resp.status != 200:
log.debug(
"Slot capacity endpoint %s returned %d; keeping previous count %d",
state.url, resp.status, state.slot_capacity,
)
return state.slot_capacity
data = await resp.json()
if isinstance(data, list):

View File

@@ -100,6 +100,33 @@ class SessionStore:
result[s.model_id] = result.get(s.model_id, 0) + 1
return result
def _prefix_match_length(self, s: Session, messages: list[dict]) -> int:
    """Return how many leading messages of ``messages`` session ``s`` covers.

    The result is ``s.last_message_index`` when the session is not expired,
    has both a preferred backend and a prefix hash, and the hash of the first
    ``last_message_index`` messages equals the stored prefix hash. In every
    other case the result is 0.
    """
    k = s.last_message_index
    if s.is_expired(self._ttl):
        return 0
    if not s.preferred_backend or not s.prefix_hash:
        return 0
    if k == 0 or k > len(messages):
        return 0
    if compute_prefix_hash(messages[:k]) != s.prefix_hash:
        return 0
    return k
async def find_by_prefix(self, messages: list[dict]) -> str | None:
    """Return the preferred backend of the session that best prefixes ``messages``.

    Each stored session is checked with ``_prefix_match_length``; among the
    sessions that match, the one with the largest ``last_message_index`` wins
    and is touched before its ``preferred_backend`` is returned. This lets a
    client that omits the session cookie be matched to an earlier session.

    Returns:
        The winning session's preferred backend, or ``None`` when ``messages``
        is empty or no session matches.
    """
    if not messages:
        return None

    winner: Session | None = None
    async with self._lock:
        for candidate in self._sessions.values():
            matched = self._prefix_match_length(candidate, messages)
            if matched <= 0:
                continue
            if winner is None or matched > winner.last_message_index:
                winner = candidate
        if winner is not None:
            winner.touch()

    if winner is None:
        return None
    return winner.preferred_backend
async def snapshot(self) -> list[Session]:
async with self._lock:
return [

View File

@@ -6,7 +6,7 @@ from fastapi import Request
from starlette.datastructures import Headers
from llamacpp_ha.config import BackendConfig
from llamacpp_ha.forwarder import _forward_headers, forward_request
from llamacpp_ha.forwarder import _forward_headers, forward_best_effort, forward_request
from llamacpp_ha.registry import BackendState
from llamacpp_ha.slot_tracker import SlotTracker
@@ -272,5 +272,68 @@ class TestForwardRequestStreaming(unittest.IsolatedAsyncioTestCase):
self.assertEqual(acquired, 0)
class TestForwardBestEffort(unittest.IsolatedAsyncioTestCase):
    """Tests for ``forward_best_effort``: empty list, success, fallback, total failure."""

    async def test_no_backends_returns_503(self):
        """An empty backend list is answered with 503 without any request."""
        req = _make_request({})
        session = MagicMock()
        response = await forward_best_effort(req, [], session)
        self.assertEqual(response.status_code, 503)

    async def test_single_backend_success(self):
        """A single healthy backend's 200 is passed through."""
        state = _make_state("http://b1")
        ctx, _ = _mock_aiohttp_response(status=200, body=b'{"ok":true}')
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        req = _make_request({})
        response = await forward_best_effort(req, [state], session)
        self.assertEqual(response.status_code, 200)

    async def test_falls_back_on_connection_error(self):
        """If the first selected backend fails, the next one is tried."""
        import itertools
        from unittest.mock import patch

        state1 = _make_state("http://b1")
        state2 = _make_state("http://b2")

        ctx_fail = MagicMock()
        ctx_fail.__aenter__ = AsyncMock(side_effect=Exception("connection refused"))
        ctx_fail.__aexit__ = AsyncMock(return_value=False)
        ctx_ok, _ = _mock_aiohttp_response(status=200, body=b"ok")

        def pick_ctx(method, url, **kwargs):
            return ctx_fail if "b1" in url else ctx_ok

        session = MagicMock()
        session.request = MagicMock(side_effect=pick_ctx)
        req = _make_request({})

        # The rotation counter is module-global, so its value depends on how
        # many calls earlier tests made. Pin it to 0 so the failing backend
        # (b1) is always tried first and the fallback path really runs.
        with patch(
            "llamacpp_ha.forwarder._best_effort_counter", itertools.count(0)
        ):
            response = await forward_best_effort(req, [state1, state2], session)

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.body, b"ok")
        self.assertEqual(session.request.call_count, 2)
        self.assertIn("b1", session.request.call_args_list[0].kwargs["url"])
        self.assertIn("b2", session.request.call_args_list[1].kwargs["url"])

    async def test_all_backends_fail_returns_502(self):
        """When every backend raises, each is tried once and 502 is returned."""
        state1 = _make_state("http://b1")
        state2 = _make_state("http://b2")

        ctx_fail = MagicMock()
        ctx_fail.__aenter__ = AsyncMock(side_effect=Exception("network error"))
        ctx_fail.__aexit__ = AsyncMock(return_value=False)

        session = MagicMock()
        session.request = MagicMock(return_value=ctx_fail)
        req = _make_request({})
        response = await forward_best_effort(req, [state1, state2], session)
        self.assertEqual(response.status_code, 502)
        self.assertEqual(session.request.call_count, 2)

    async def test_status_code_and_body_preserved(self):
        """A non-2xx backend answer is relayed unchanged, not retried."""
        state = _make_state("http://b1")
        ctx, _ = _mock_aiohttp_response(status=404, body=b"not found")
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        req = _make_request({})
        response = await forward_best_effort(req, [state], session)
        self.assertEqual(response.status_code, 404)
        self.assertEqual(response.body, b"not found")
        self.assertEqual(session.request.call_count, 1)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,21 +1,8 @@
"""
Integration tests using an in-process fake llama.cpp backend.
The fake backend runs as a FastAPI app via httpx.AsyncClient(transport=...).
We patch aiohttp calls in the registry/forwarder so that requests go to our
fake server without opening real TCP sockets.
"""
from __future__ import annotations
import asyncio
import json
import unittest
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from fastapi.testclient import TestClient
from llamacpp_ha.config import BackendConfig, ProxyConfig
@@ -23,7 +10,7 @@ from llamacpp_ha.proxy import create_app
# ---------------------------------------------------------------------------
# Helpers to build a minimal proxy with pre-seeded live backends
# Config helpers
# ---------------------------------------------------------------------------
@@ -44,78 +31,12 @@ def _proxy_config(backend_urls: list[str], api_keys: list[str] | None = None) ->
)
def _seed_live(app_obj, backend_urls: list[str], models=None) -> None:
"""Directly seed the registry with live backends (skipping real HTTP poll)."""
from llamacpp_ha.registry import BackendRegistry
# Access the app's state via lifespan-created objects
# We do this by patching _poll_all to be a no-op and manually setting state
pass
# ---------------------------------------------------------------------------
# Fake llama.cpp server (in-process, no real sockets)
# ---------------------------------------------------------------------------
def _make_fake_backend(
model_id: str = "test-model",
response_content: str = '{"choices":[{"message":{"role":"assistant","content":"hello"}}]}',
streaming: bool = False,
slow_seconds: float = 0,
slot_count: int = 2,
) -> FastAPI:
fake = FastAPI()
@fake.get("/health")
async def health():
if slow_seconds:
await asyncio.sleep(slow_seconds)
return Response(status_code=200)
@fake.get("/v1/models")
async def models():
return JSONResponse({"object": "list", "data": [{"id": model_id, "object": "model"}]})
@fake.get("/slots")
async def slots():
return JSONResponse([{"id": i, "state": 0} for i in range(slot_count)])
@fake.post("/v1/chat/completions")
async def chat(request: Request):
if slow_seconds:
await asyncio.sleep(slow_seconds)
body = await request.json()
stream = body.get("stream", False)
if stream or streaming:
async def gen():
yield b'data: {"choices":[{"delta":{"content":"hello"}}]}\n\n'
yield b"data: [DONE]\n\n"
return StreamingResponse(gen(), media_type="text/event-stream")
return JSONResponse(json.loads(response_content))
@fake.post("/v1/completions")
async def completions(request: Request):
return JSONResponse({"choices": [{"text": "hello"}]})
return fake
# ---------------------------------------------------------------------------
# Integration test cases
# ---------------------------------------------------------------------------
class TestHealthEndpoint(unittest.TestCase):
def _make_proxy_with_live_backend(self):
cfg = _proxy_config(["http://fake-b1"])
app = create_app(cfg)
# Seed registry directly before first request
def _patch_registry(app_obj):
pass
return app, cfg
def test_health_no_live_backend_returns_503(self):
cfg = _proxy_config([])
app = create_app(cfg)
@@ -134,47 +55,24 @@ class TestHealthEndpoint(unittest.TestCase):
cfg = _proxy_config([], api_keys=["mykey"])
app = create_app(cfg)
client = TestClient(app, raise_server_exceptions=False)
# No live backends -> 503 but authenticated
resp = client.get("/health", headers={"Authorization": "Bearer mykey"})
self.assertEqual(resp.status_code, 503)
class TestModelsEndpoint(unittest.TestCase):
def test_models_returns_list_format(self):
"""GET /v1/models returns the OpenAI list envelope even with no live backends."""
cfg = _proxy_config(["http://b1"])
app = create_app(cfg)
# Seed registry
from llamacpp_ha.registry import BackendState
client = TestClient(app, raise_server_exceptions=False)
resp = client.get("/v1/models")
self.assertEqual(resp.status_code, 200)
data = resp.json()
self.assertEqual(data["object"], "list")
self.assertIn("data", data)
self.assertIsInstance(data["data"], list)
def test_models_deduplication(self):
"""Models appearing on multiple backends appear once."""
cfg = ProxyConfig(
backends=[
BackendConfig(url="http://b1", model_ids=["shared", "only-b1"]),
BackendConfig(url="http://b2", model_ids=["shared", "only-b2"]),
]
)
app = create_app(cfg)
from llamacpp_ha.registry import BackendState
# Manually seed live
import asyncio
async def _seed():
from llamacpp_ha import proxy as _proxy_module
# Can't easily seed without internal access in this test structure
pass
client = TestClient(app, raise_server_exceptions=False)
resp = client.get("/v1/models")
data = resp.json()
model_ids = [m["id"] for m in data["data"]]
self.assertEqual(len(model_ids), len(set(model_ids)), "Duplicates found")
# Model deduplication is tested at the registry level in
# test_registry.py::test_get_all_models_deduplicated.
class TestApiKeyMiddlewareIntegration(unittest.TestCase):
@@ -199,6 +97,15 @@ class TestApiKeyMiddlewareIntegration(unittest.TestCase):
resp = client.get("/monitor/data")
self.assertEqual(resp.status_code, 200)
def test_monitor_trailing_slash_exempt_from_auth(self):
    """``/monitor/`` must not be rejected with 401 when API keys are configured."""
    app = create_app(_proxy_config([], api_keys=["secret"]))
    response = TestClient(app, raise_server_exceptions=False).get("/monitor/")
    # The trailing-slash path is auth-exempt. Depending on whether the router
    # redirects it, the request may reach the catch-all and yield 503; the
    # only requirement here is that it is never answered with 401.
    self.assertNotEqual(response.status_code, 401)
class TestSlotExhaustionTimeout(unittest.TestCase):
def test_timeout_when_no_slot_available(self):
@@ -209,7 +116,6 @@ class TestSlotExhaustionTimeout(unittest.TestCase):
default_slot_capacity=1,
)
app = create_app(cfg)
# No live backends -> scheduler never dispatches -> timeout -> 503
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.post(
"/v1/chat/completions",
@@ -251,23 +157,17 @@ class TestMonitorIntegration(unittest.TestCase):
class TestFullForwardPath(unittest.TestCase):
"""
Full path test: proxy starts up, monitor/data reflects backend state
(dead until polled), scheduler queues requests when no backends are live.
"""
def test_backend_initially_dead_no_poll(self):
cfg = ProxyConfig(
backends=[BackendConfig(url="http://fake", model_ids=["test-model"])],
slot_wait_timeout=2.0,
default_slot_capacity=2,
poll_interval=9999, # no background polling in test
poll_interval=9999,
)
app = create_app(cfg)
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/monitor/data")
data = resp.json()
# Backend exists but is dead until first poll completes
self.assertEqual(len(data["backends"]), 1)
self.assertFalse(data["backends"][0]["live"])
@@ -275,8 +175,7 @@ class TestFullForwardPath(unittest.TestCase):
class TestBackendFailover(unittest.IsolatedAsyncioTestCase):
async def test_dead_backend_not_in_model_index(self):
"""Dead backends are not returned for model routing."""
from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.registry import BackendRegistry, BackendState
from llamacpp_ha.registry import BackendRegistry
cfg = ProxyConfig(
backends=[
@@ -286,7 +185,6 @@ class TestBackendFailover(unittest.IsolatedAsyncioTestCase):
)
reg = BackendRegistry(cfg)
# Simulate: b1 live, b2 dead
async with reg._lock:
reg._states["http://b1"].live = True
reg._states["http://b1"].models = ["m"]
@@ -297,15 +195,10 @@ class TestBackendFailover(unittest.IsolatedAsyncioTestCase):
self.assertEqual(len(backends), 1)
self.assertEqual(backends[0].url, "http://b1")
def test_all_dead_returns_empty(self):
    """A registry whose only backend was never marked live returns no backends."""
    from llamacpp_ha.registry import BackendRegistry

    only_backend = BackendConfig(url="http://b1", model_ids=["m"])
    registry = BackendRegistry(ProxyConfig(backends=[only_backend]))
    self.assertEqual(registry.get_backends_for_model("m"), [])
@@ -313,7 +206,6 @@ class TestBackendFailover(unittest.IsolatedAsyncioTestCase):
class TestModelRouting(unittest.IsolatedAsyncioTestCase):
async def test_routes_to_backend_serving_model(self):
from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.policies import RoundRobinPolicy
from llamacpp_ha.queue import QueueEntry, RequestQueue
from llamacpp_ha.registry import BackendRegistry
@@ -352,13 +244,7 @@ class TestModelRouting(unittest.IsolatedAsyncioTestCase):
class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase):
"""
Slot is fully occupied → request queues → slot released → request dispatched.
Tests the complete wait-and-unblock path without involving real HTTP.
"""
async def test_queued_request_dispatched_after_slot_release(self):
from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.policies import RoundRobinPolicy
from llamacpp_ha.queue import QueueEntry, RequestQueue
from llamacpp_ha.registry import BackendRegistry
@@ -379,28 +265,22 @@ class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase):
queue = RequestQueue()
scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy())
# Occupy the only slot (simulating an in-flight request)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1")
# Enqueue a request — scheduler should not be able to dispatch
loop = asyncio.get_running_loop()
entry = QueueEntry(model_id="m", future=loop.create_future())
await queue.enqueue(entry)
await scheduler._dispatch_all()
self.assertFalse(entry.future.done(), "Should be queued, not dispatched yet")
self.assertFalse(entry.future.done())
# Release the slot (simulates a prior request completing)
await slots.release("http://b1")
scheduler.notify_slot_released()
# Scheduler re-evaluates on next wakeup
await scheduler._dispatch_all()
self.assertTrue(entry.future.done(), "Should be dispatched after slot freed")
self.assertTrue(entry.future.done())
self.assertEqual(entry.future.result().url, "http://b1")
async def test_multiple_queued_requests_dispatched_as_slots_free(self):
from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.policies import RoundRobinPolicy
from llamacpp_ha.queue import QueueEntry, RequestQueue
from llamacpp_ha.registry import BackendRegistry
@@ -421,7 +301,6 @@ class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase):
queue = RequestQueue()
scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy())
# Fill both slots
async with asyncio.timeout(1.0):
await slots.acquire("http://b1")
async with asyncio.timeout(1.0):
@@ -433,30 +312,20 @@ class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase):
await queue.enqueue(e)
await scheduler._dispatch_all()
dispatched = sum(1 for e in entries if e.future.done())
self.assertEqual(dispatched, 0)
self.assertEqual(sum(1 for e in entries if e.future.done()), 0)
# Free one slot → one request dispatched
await slots.release("http://b1")
await scheduler._dispatch_all()
dispatched = sum(1 for e in entries if e.future.done())
self.assertEqual(dispatched, 1)
self.assertEqual(sum(1 for e in entries if e.future.done()), 1)
# Free second slot → another dispatched
await slots.release("http://b1")
await scheduler._dispatch_all()
dispatched = sum(1 for e in entries if e.future.done())
self.assertEqual(dispatched, 2)
self.assertEqual(sum(1 for e in entries if e.future.done()), 2)
class TestSessionCookieIntegration(unittest.TestCase):
def test_session_cookie_set_in_response(self):
"""Proxy sets X-Session-ID header and cookie on inference responses."""
# No live backends → times out → 503, but we still get session tracking
# We verify the session machinery works by checking that when the proxy
# processes a request, it assigns a session and would attach it to the response.
# For a full round-trip we'd need a live backend; here we verify the 503
# path still doesn't crash the session handling.
def test_session_timeout_returns_503_with_error_body(self):
"""On slot timeout the 503 body is valid JSON with an error field."""
cfg = ProxyConfig(
backends=[BackendConfig(url="http://b1", model_ids=["m"])],
slot_wait_timeout=0.1,
@@ -467,9 +336,7 @@ class TestSessionCookieIntegration(unittest.TestCase):
"/v1/chat/completions",
json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
)
# Times out (no live backend) → 503
self.assertEqual(resp.status_code, 503)
# Even on timeout, no crash
self.assertIn("error", resp.json())

View File

@@ -131,7 +131,7 @@ class TestBackendRegistry(unittest.IsolatedAsyncioTestCase):
poll_times = []
async def fake_poll_one(state):
poll_times.append(asyncio.get_event_loop().time())
poll_times.append(asyncio.get_running_loop().time())
await asyncio.sleep(0.05)
async with reg._lock:
state.live = True

View File

@@ -30,7 +30,7 @@ class TestSessionStore(unittest.IsolatedAsyncioTestCase):
async def test_get_or_create_existing(self):
store = SessionStore(ttl=300.0)
s1 = await store.get_or_create("abc")
await store.get_or_create("abc")
await store.update("abc", model_id="llama3")
s2 = await store.get_or_create("abc")
self.assertEqual(s2.model_id, "llama3")
@@ -95,6 +95,69 @@ class TestSessionStore(unittest.IsolatedAsyncioTestCase):
store = SessionStore(ttl=300.0)
await store.update("nope", model_id="m1") # should not raise
async def test_find_by_prefix_no_messages(self):
    """An empty message list yields no backend."""
    empty_store = SessionStore(ttl=300.0)
    self.assertIsNone(await empty_store.find_by_prefix([]))
async def test_find_by_prefix_exact_match(self):
    """Single-turn session is found when messages match exactly."""
    conversation = [{"role": "user", "content": "hello"}]
    store = SessionStore(ttl=300.0)
    await store.update("s1", messages=conversation, preferred_backend="http://b1")
    self.assertEqual(await store.find_by_prefix(conversation), "http://b1")
async def test_find_by_prefix_continuation(self):
    """Turn-1 session is found when turn-2 messages extend the stored prefix."""
    first_turn = [{"role": "user", "content": "hello"}]
    second_turn = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hi"},
        {"role": "user", "content": "tell me more"},
    ]
    store = SessionStore(ttl=300.0)
    await store.update("s1", messages=first_turn, preferred_backend="http://b1")
    self.assertEqual(await store.find_by_prefix(second_turn), "http://b1")
async def test_find_by_prefix_prefers_longest_match(self):
    """When multiple sessions match, the one with the longer stored prefix wins."""
    user_hello = {"role": "user", "content": "hello"}
    assistant_hi = {"role": "assistant", "content": "hi"}
    user_more = {"role": "user", "content": "more"}

    store = SessionStore(ttl=300.0)
    await store.update(
        "s1", messages=[dict(user_hello)], preferred_backend="http://b1"
    )
    await store.update(
        "s2",
        messages=[dict(user_hello), dict(assistant_hi)],
        preferred_backend="http://b2",
    )

    query = [dict(user_hello), dict(assistant_hi), dict(user_more)]
    # Both sessions are prefixes of the query; the two-message one must win.
    self.assertEqual(await store.find_by_prefix(query), "http://b2")
async def test_find_by_prefix_no_match(self):
    """A conversation with different first-message content matches nothing."""
    stored = [{"role": "user", "content": "hello"}]
    unrelated = [{"role": "user", "content": "completely different"}]
    store = SessionStore(ttl=300.0)
    await store.update("s1", messages=stored, preferred_backend="http://b1")
    self.assertIsNone(await store.find_by_prefix(unrelated))
async def test_find_by_prefix_ignores_expired(self):
    """A session older than the store's TTL is not returned."""
    import asyncio

    conversation = [{"role": "user", "content": "hello"}]
    store = SessionStore(ttl=0.05)
    await store.update("s1", messages=conversation, preferred_backend="http://b1")
    # Sleep past the 0.05 s TTL before looking the conversation up again.
    await asyncio.sleep(0.1)
    self.assertIsNone(await store.find_by_prefix(conversation))
async def test_find_by_prefix_ignores_session_without_backend(self):
    """A matching session with no preferred backend yields no result."""
    conversation = [{"role": "user", "content": "hello"}]
    store = SessionStore(ttl=300.0)
    # Deliberately no preferred_backend on this update.
    await store.update("s1", messages=conversation)
    self.assertIsNone(await store.find_by_prefix(conversation))
async def test_concurrent_access(self):
store = SessionStore(ttl=300.0)