fixes
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(python -m pytest)"
|
||||
"Bash(python -m pytest)",
|
||||
"Bash(python -m pytest --tb=short -q)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ def main() -> None:
|
||||
overrides = {}
|
||||
if args.host:
|
||||
overrides["host"] = args.host
|
||||
if args.port:
|
||||
if args.port is not None:
|
||||
overrides["port"] = args.port
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import AsyncIterator
|
||||
|
||||
@@ -118,43 +119,56 @@ async def forward_request(
|
||||
raise
|
||||
|
||||
|
||||
_best_effort_counter = itertools.count()
|
||||
|
||||
|
||||
async def forward_best_effort(
|
||||
request: Request,
|
||||
registry_backends: list[BackendState],
|
||||
session: aiohttp.ClientSession,
|
||||
) -> Response:
|
||||
"""Forward without slot gating to any live backend (catch-all paths)."""
|
||||
"""Forward without slot gating to any live backend (catch-all paths).
|
||||
|
||||
Backends are tried in round-robin order; on connection failure the next
|
||||
backend is attempted until all live backends are exhausted.
|
||||
"""
|
||||
if not registry_backends:
|
||||
return Response(content="No live backends", status_code=503)
|
||||
|
||||
backend = registry_backends[0]
|
||||
n = len(registry_backends)
|
||||
start = next(_best_effort_counter) % n
|
||||
ordered = registry_backends[start:] + registry_backends[:start]
|
||||
|
||||
path = request.url.path
|
||||
query = request.url.query
|
||||
target = backend.url + path + (f"?{query}" if query else "")
|
||||
headers = _forward_headers(request, backend)
|
||||
body = await request.body()
|
||||
|
||||
try:
|
||||
async with session.request(
|
||||
method=request.method,
|
||||
url=target,
|
||||
headers=headers,
|
||||
data=body if body else None,
|
||||
allow_redirects=False,
|
||||
) as resp:
|
||||
content_type = resp.headers.get("Content-Type", "")
|
||||
response_headers = {
|
||||
k: v
|
||||
for k, v in resp.headers.items()
|
||||
if k.lower() not in _HOP_BY_HOP
|
||||
}
|
||||
data = await resp.read()
|
||||
return Response(
|
||||
content=data,
|
||||
status_code=resp.status,
|
||||
headers=response_headers,
|
||||
media_type=content_type,
|
||||
)
|
||||
except Exception as exc:
|
||||
log.error("Best-effort forward error: %s", exc)
|
||||
return Response(content="Backend error", status_code=502)
|
||||
for backend in ordered:
|
||||
target = backend.url + path + (f"?{query}" if query else "")
|
||||
headers = _forward_headers(request, backend)
|
||||
try:
|
||||
async with session.request(
|
||||
method=request.method,
|
||||
url=target,
|
||||
headers=headers,
|
||||
data=body if body else None,
|
||||
allow_redirects=False,
|
||||
) as resp:
|
||||
content_type = resp.headers.get("Content-Type", "")
|
||||
response_headers = {
|
||||
k: v
|
||||
for k, v in resp.headers.items()
|
||||
if k.lower() not in _HOP_BY_HOP
|
||||
}
|
||||
data = await resp.read()
|
||||
return Response(
|
||||
content=data,
|
||||
status_code=resp.status,
|
||||
headers=response_headers,
|
||||
media_type=content_type,
|
||||
)
|
||||
except Exception as exc:
|
||||
log.warning("Best-effort forward to %s failed: %s", backend.url, exc)
|
||||
continue
|
||||
|
||||
return Response(content="All backends failed", status_code=502)
|
||||
|
||||
@@ -5,7 +5,7 @@ from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
_EXEMPT_PATHS = frozenset(["/monitor", "/monitor/data"])
|
||||
_EXEMPT_PATHS = frozenset(["/monitor", "/monitor/", "/monitor/data", "/monitor/data/"])
|
||||
|
||||
|
||||
class ApiKeyMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
@@ -75,8 +75,8 @@ _HTML = """<!DOCTYPE html>
|
||||
|
||||
<h2>Queue</h2>
|
||||
<table>
|
||||
<thead><tr><th>Request ID</th><th>Model</th><th>Session</th><th>Wait (s)</th><th>Est. Tokens</th></tr></thead>
|
||||
<tbody id="queue-body"><tr><td colspan="5" class="empty">Queue is empty</td></tr></tbody>
|
||||
<thead><tr><th>Request ID</th><th>Model</th><th>Session</th><th>Wait (s)</th><th>Est. Tokens</th><th>Skips</th></tr></thead>
|
||||
<tbody id="queue-body"><tr><td colspan="6" class="empty">Queue is empty</td></tr></tbody>
|
||||
</table>
|
||||
|
||||
<h2>Sessions by Model</h2>
|
||||
@@ -115,12 +115,12 @@ _HTML = """<!DOCTYPE html>
|
||||
|
||||
const qBody = document.getElementById('queue-body');
|
||||
if (!data.queue.length) {
|
||||
qBody.innerHTML = '<tr><td colspan="5" class="empty">Queue is empty</td></tr>';
|
||||
qBody.innerHTML = '<tr><td colspan="6" class="empty">Queue is empty</td></tr>';
|
||||
} else {
|
||||
qBody.innerHTML = data.queue.map(e => {
|
||||
const tok = e.estimated_tokens != null ? esc(e.estimated_tokens) : '<span class="empty">-</span>';
|
||||
const sid = e.session_id ? esc(e.session_id) : '<span class="empty">-</span>';
|
||||
return `<tr><td>${esc(e.request_id.slice(0,12))}</td><td>${esc(e.model_id||'-')}</td><td>${sid}</td><td>${esc(e.wait_seconds.toFixed(2))}</td><td>${tok}</td></tr>`;
|
||||
return `<tr><td>${esc(e.request_id.slice(0,12))}</td><td>${esc(e.model_id||'-')}</td><td>${sid}</td><td>${esc(e.wait_seconds.toFixed(2))}</td><td>${tok}</td><td>${esc(e.skip_count)}</td></tr>`;
|
||||
}).join('');
|
||||
}
|
||||
|
||||
|
||||
@@ -80,11 +80,18 @@ def _session_id_from(request: Request) -> str | None:
|
||||
|
||||
def _attach_session(response: Response, session_id: str) -> None:
|
||||
response.set_cookie(
|
||||
_SESSION_COOKIE, session_id, httponly=True, samesite="lax", secure=True
|
||||
_SESSION_COOKIE, session_id, httponly=True, samesite="lax", secure=False
|
||||
)
|
||||
response.headers[_SESSION_HEADER] = session_id
|
||||
|
||||
|
||||
def _parse_body(raw: bytes) -> dict[str, Any]:
    """Decode a request body as a JSON object, falling back to an empty dict.

    Args:
        raw: The raw request payload; may be empty.

    Returns:
        The decoded JSON object, or ``{}`` when the payload is empty, cannot
        be decoded as text, is not valid JSON, or decodes to something other
        than an object (callers use ``dict`` methods such as ``.get`` on the
        result).
    """
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError (raised for bytes that
        # are not decodable text) both derive from ValueError.
        return {}
    return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
|
||||
def _init_slot_tracker(
|
||||
registry: BackendRegistry,
|
||||
slot_tracker: SlotTracker,
|
||||
@@ -127,6 +134,8 @@ async def _dispatch_entry(
|
||||
request_queue: RequestQueue,
|
||||
stats: ProxyStats,
|
||||
config: ProxyConfig,
|
||||
slot_tracker: SlotTracker,
|
||||
scheduler: Scheduler,
|
||||
model_id: str,
|
||||
session_id: str | None,
|
||||
body: dict,
|
||||
@@ -147,7 +156,13 @@ async def _dispatch_entry(
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await request_queue.remove(entry)
|
||||
if not entry.future.done():
|
||||
if entry.future.done() and not entry.future.cancelled():
|
||||
# Scheduler won the race: it acquired a slot and resolved the future
|
||||
# between our timeout and our remove. Release that slot now or it leaks.
|
||||
resolved: BackendState = entry.future.result()
|
||||
await slot_tracker.release(resolved.url, model_id)
|
||||
scheduler.notify_slot_released()
|
||||
elif not entry.future.done():
|
||||
entry.future.cancel()
|
||||
return JSONResponse(
|
||||
{"error": {"message": "No slot available (timeout)", "type": "overloaded"}},
|
||||
@@ -156,6 +171,43 @@ async def _dispatch_entry(
|
||||
return backend
|
||||
|
||||
|
||||
async def _recover_session_affinity(
    session_id: str,
    messages: list,
    session_store: SessionStore,
) -> None:
    """Seed a newly created session with a backend found by prefix lookup.

    Conversations with fewer than two messages are left untouched. Otherwise
    the store is asked for a backend whose recorded conversation is a prefix
    of ``messages``; when one is returned it is written to the session as its
    preferred backend.

    Args:
        session_id: Identifier of the session to update.
        messages: The message list taken from the request body.
        session_store: Store providing ``find_by_prefix`` and ``update``.
    """
    if len(messages) >= 2:
        matched_backend = await session_store.find_by_prefix(messages)
        if matched_backend:
            await session_store.update(session_id, preferred_backend=matched_backend)
|
||||
|
||||
|
||||
async def _find_failover(
|
||||
model_id: str,
|
||||
excluded_urls: set[str],
|
||||
registry: BackendRegistry,
|
||||
slot_tracker: SlotTracker,
|
||||
) -> BackendState | None:
|
||||
"""Acquire a slot on any live backend not already tried. Returns None if none are free."""
|
||||
candidates = (
|
||||
registry.get_backends_for_model(model_id)
|
||||
if model_id
|
||||
else registry.get_all_live_backends()
|
||||
)
|
||||
for b in candidates:
|
||||
if b.url in excluded_urls or not slot_tracker.can_accept(b.url, model_id):
|
||||
continue
|
||||
try:
|
||||
async with asyncio.timeout(0):
|
||||
await slot_tracker.acquire(b.url, model_id)
|
||||
return b
|
||||
except TimeoutError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _list_models(*, registry: BackendRegistry) -> JSONResponse:
|
||||
models = registry.get_all_models()
|
||||
data = [
|
||||
@@ -199,6 +251,7 @@ async def _inference_endpoint(
|
||||
request_queue: RequestQueue,
|
||||
stats: ProxyStats,
|
||||
config: ProxyConfig,
|
||||
registry: BackendRegistry,
|
||||
) -> Response:
|
||||
if http.client is None:
|
||||
return Response(
|
||||
@@ -207,22 +260,46 @@ async def _inference_endpoint(
|
||||
media_type=_APPLICATION_JSON,
|
||||
)
|
||||
|
||||
raw = await request.body()
|
||||
try:
|
||||
body: dict[str, Any] = json.loads(raw) if raw else {}
|
||||
except json.JSONDecodeError:
|
||||
body = {}
|
||||
|
||||
body = _parse_body(await request.body())
|
||||
model_id = _get_model(body)
|
||||
session_id = _session_id_from(request) or session_store.new_session_id()
|
||||
incoming_session_id = _session_id_from(request)
|
||||
session_id = incoming_session_id or session_store.new_session_id()
|
||||
|
||||
if not incoming_session_id:
|
||||
await _recover_session_affinity(session_id, body.get("messages") or [], session_store)
|
||||
|
||||
result = await _dispatch_entry(
|
||||
request_queue, stats, config, model_id, session_id, body
|
||||
request_queue, stats, config, slot_tracker, scheduler, model_id, session_id, body
|
||||
)
|
||||
if isinstance(result, Response):
|
||||
return result
|
||||
|
||||
backend = result
|
||||
tried: set[str] = {backend.url}
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = await forward_request(
|
||||
request=request,
|
||||
backend=backend,
|
||||
session=http.client,
|
||||
slot_tracker=slot_tracker,
|
||||
scheduler=scheduler,
|
||||
model_id=model_id,
|
||||
)
|
||||
break
|
||||
except aiohttp.ClientError as exc:
|
||||
log.warning("Backend %s unreachable (%s); trying failover", backend.url, exc)
|
||||
failover = await _find_failover(model_id, tried, registry, slot_tracker)
|
||||
if failover is None:
|
||||
return JSONResponse(
|
||||
{"error": {"message": "Backend unreachable, no failover available", "type": "server_error"}},
|
||||
status_code=502,
|
||||
media_type=_APPLICATION_JSON,
|
||||
)
|
||||
tried.add(failover.url)
|
||||
backend = failover
|
||||
|
||||
if model_id:
|
||||
messages = body.get("messages", [])
|
||||
await session_store.update(
|
||||
@@ -232,14 +309,6 @@ async def _inference_endpoint(
|
||||
preferred_backend=backend.url,
|
||||
)
|
||||
|
||||
response = await forward_request(
|
||||
request=request,
|
||||
backend=backend,
|
||||
session=http.client,
|
||||
slot_tracker=slot_tracker,
|
||||
scheduler=scheduler,
|
||||
model_id=model_id,
|
||||
)
|
||||
_attach_session(response, session_id)
|
||||
return response
|
||||
|
||||
@@ -321,6 +390,7 @@ def create_app(config: ProxyConfig) -> FastAPI:
|
||||
request_queue=request_queue,
|
||||
stats=stats,
|
||||
config=config,
|
||||
registry=registry,
|
||||
)
|
||||
|
||||
async def catch_all_handler(request: Request, full_path: str) -> Response: # noqa: ARG001
|
||||
|
||||
@@ -173,6 +173,10 @@ class BackendRegistry:
|
||||
state.url + _SLOTS_PATH, headers=headers
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
log.debug(
|
||||
"Slot capacity endpoint %s returned %d; keeping previous count %d",
|
||||
state.url, resp.status, state.slot_capacity,
|
||||
)
|
||||
return state.slot_capacity
|
||||
data = await resp.json()
|
||||
if isinstance(data, list):
|
||||
|
||||
@@ -100,6 +100,33 @@ class SessionStore:
|
||||
result[s.model_id] = result.get(s.model_id, 0) + 1
|
||||
return result
|
||||
|
||||
def _prefix_match_length(self, s: Session, messages: list[dict]) -> int:
    """Return how many leading messages match session ``s``, or 0 for no match.

    A session is rejected when it is expired, has no preferred backend or
    prefix hash, has a zero ``last_message_index``, or recorded more messages
    than ``messages`` contains. Otherwise the hash of the first
    ``last_message_index`` messages must equal the session's ``prefix_hash``.
    """
    stored_len = s.last_message_index
    if s.is_expired(self._ttl):
        return 0
    if not (s.preferred_backend and s.prefix_hash):
        return 0
    if stored_len == 0 or stored_len > len(messages):
        return 0
    if compute_prefix_hash(messages[:stored_len]) != s.prefix_hash:
        return 0
    return stored_len
|
||||
|
||||
async def find_by_prefix(self, messages: list[dict]) -> str | None:
|
||||
"""Return the preferred backend whose stored conversation is a prefix of messages.
|
||||
|
||||
Checks whether hash(messages[:k]) equals a session's prefix_hash (k is
|
||||
that session's last_message_index). Returns the preferred_backend of the
|
||||
longest matching session, or None. This lets clients that omit the
|
||||
session cookie still land on the backend holding their KV-cache.
|
||||
"""
|
||||
if not messages:
|
||||
return None
|
||||
best: Session | None = None
|
||||
async with self._lock:
|
||||
for s in self._sessions.values():
|
||||
k = self._prefix_match_length(s, messages)
|
||||
if k > 0 and (best is None or k > best.last_message_index):
|
||||
best = s
|
||||
if best is not None:
|
||||
best.touch()
|
||||
return best.preferred_backend if best is not None else None
|
||||
|
||||
async def snapshot(self) -> list[Session]:
|
||||
async with self._lock:
|
||||
return [
|
||||
|
||||
@@ -6,7 +6,7 @@ from fastapi import Request
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
from llamacpp_ha.config import BackendConfig
|
||||
from llamacpp_ha.forwarder import _forward_headers, forward_request
|
||||
from llamacpp_ha.forwarder import _forward_headers, forward_best_effort, forward_request
|
||||
from llamacpp_ha.registry import BackendState
|
||||
from llamacpp_ha.slot_tracker import SlotTracker
|
||||
|
||||
@@ -272,5 +272,68 @@ class TestForwardRequestStreaming(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(acquired, 0)
|
||||
|
||||
|
||||
class TestForwardBestEffort(unittest.IsolatedAsyncioTestCase):
    """Checks for ``forward_best_effort``: status mapping and backend fallback."""

    @staticmethod
    def _failing_ctx(message):
        """Build a context manager mock whose entry raises ``Exception(message)``."""
        ctx = MagicMock()
        ctx.__aenter__ = AsyncMock(side_effect=Exception(message))
        ctx.__aexit__ = AsyncMock(return_value=False)
        return ctx

    @staticmethod
    def _session_returning(ctx):
        """Build a client-session mock whose ``request`` always returns ``ctx``."""
        client = MagicMock()
        client.request = MagicMock(return_value=ctx)
        return client

    async def test_no_backends_returns_503(self):
        request = _make_request({})
        client = MagicMock()
        result = await forward_best_effort(request, [], client)
        self.assertEqual(result.status_code, 503)

    async def test_single_backend_success(self):
        backend = _make_state("http://b1")
        ok_ctx, _ = _mock_aiohttp_response(status=200, body=b'{"ok":true}')
        client = self._session_returning(ok_ctx)
        request = _make_request({})
        result = await forward_best_effort(request, [backend], client)
        self.assertEqual(result.status_code, 200)

    async def test_falls_back_on_connection_error(self):
        """A failure on the first selected backend moves on to the next one."""
        first = _make_state("http://b1")
        second = _make_state("http://b2")

        broken_ctx = self._failing_ctx("connection refused")
        ok_ctx, _ = _mock_aiohttp_response(status=200, body=b"ok")

        def pick_ctx(method, url, **kwargs):
            # Requests aimed at b1 fail; everything else succeeds.
            if "b1" in url:
                return broken_ctx
            return ok_ctx

        client = MagicMock()
        client.request = MagicMock(side_effect=pick_ctx)
        request = _make_request({})

        result = await forward_best_effort(request, [first, second], client)
        self.assertEqual(result.status_code, 200)
        self.assertEqual(result.body, b"ok")

    async def test_all_backends_fail_returns_502(self):
        first = _make_state("http://b1")
        second = _make_state("http://b2")

        client = self._session_returning(self._failing_ctx("network error"))
        request = _make_request({})

        result = await forward_best_effort(request, [first, second], client)
        self.assertEqual(result.status_code, 502)

    async def test_status_code_and_body_preserved(self):
        backend = _make_state("http://b1")
        not_found_ctx, _ = _mock_aiohttp_response(status=404, body=b"not found")
        client = self._session_returning(not_found_ctx)
        request = _make_request({})
        result = await forward_best_effort(request, [backend], client)
        self.assertEqual(result.status_code, 404)
        self.assertEqual(result.body, b"not found")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,21 +1,8 @@
|
||||
"""
|
||||
Integration tests using an in-process fake llama.cpp backend.
|
||||
|
||||
The fake backend runs as a FastAPI app via httpx.AsyncClient(transport=...).
|
||||
We patch aiohttp calls in the registry/forwarder so that requests go to our
|
||||
fake server without opening real TCP sockets.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from llamacpp_ha.config import BackendConfig, ProxyConfig
|
||||
@@ -23,7 +10,7 @@ from llamacpp_ha.proxy import create_app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers to build a minimal proxy with pre-seeded live backends
|
||||
# Config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -44,78 +31,12 @@ def _proxy_config(backend_urls: list[str], api_keys: list[str] | None = None) ->
|
||||
)
|
||||
|
||||
|
||||
def _seed_live(app_obj, backend_urls: list[str], models=None) -> None:
|
||||
"""Directly seed the registry with live backends (skipping real HTTP poll)."""
|
||||
from llamacpp_ha.registry import BackendRegistry
|
||||
# Access the app's state via lifespan-created objects
|
||||
# We do this by patching _poll_all to be a no-op and manually setting state
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake llama.cpp server (in-process, no real sockets)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_fake_backend(
|
||||
model_id: str = "test-model",
|
||||
response_content: str = '{"choices":[{"message":{"role":"assistant","content":"hello"}}]}',
|
||||
streaming: bool = False,
|
||||
slow_seconds: float = 0,
|
||||
slot_count: int = 2,
|
||||
) -> FastAPI:
|
||||
fake = FastAPI()
|
||||
|
||||
@fake.get("/health")
|
||||
async def health():
|
||||
if slow_seconds:
|
||||
await asyncio.sleep(slow_seconds)
|
||||
return Response(status_code=200)
|
||||
|
||||
@fake.get("/v1/models")
|
||||
async def models():
|
||||
return JSONResponse({"object": "list", "data": [{"id": model_id, "object": "model"}]})
|
||||
|
||||
@fake.get("/slots")
|
||||
async def slots():
|
||||
return JSONResponse([{"id": i, "state": 0} for i in range(slot_count)])
|
||||
|
||||
@fake.post("/v1/chat/completions")
|
||||
async def chat(request: Request):
|
||||
if slow_seconds:
|
||||
await asyncio.sleep(slow_seconds)
|
||||
body = await request.json()
|
||||
stream = body.get("stream", False)
|
||||
if stream or streaming:
|
||||
async def gen():
|
||||
yield b'data: {"choices":[{"delta":{"content":"hello"}}]}\n\n'
|
||||
yield b"data: [DONE]\n\n"
|
||||
return StreamingResponse(gen(), media_type="text/event-stream")
|
||||
return JSONResponse(json.loads(response_content))
|
||||
|
||||
@fake.post("/v1/completions")
|
||||
async def completions(request: Request):
|
||||
return JSONResponse({"choices": [{"text": "hello"}]})
|
||||
|
||||
return fake
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration test cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHealthEndpoint(unittest.TestCase):
|
||||
def _make_proxy_with_live_backend(self):
|
||||
cfg = _proxy_config(["http://fake-b1"])
|
||||
app = create_app(cfg)
|
||||
|
||||
# Seed registry directly before first request
|
||||
def _patch_registry(app_obj):
|
||||
pass
|
||||
|
||||
return app, cfg
|
||||
|
||||
def test_health_no_live_backend_returns_503(self):
|
||||
cfg = _proxy_config([])
|
||||
app = create_app(cfg)
|
||||
@@ -134,47 +55,24 @@ class TestHealthEndpoint(unittest.TestCase):
|
||||
cfg = _proxy_config([], api_keys=["mykey"])
|
||||
app = create_app(cfg)
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
# No live backends -> 503 but authenticated
|
||||
resp = client.get("/health", headers={"Authorization": "Bearer mykey"})
|
||||
self.assertEqual(resp.status_code, 503)
|
||||
|
||||
|
||||
class TestModelsEndpoint(unittest.TestCase):
|
||||
def test_models_returns_list_format(self):
|
||||
"""GET /v1/models returns the OpenAI list envelope even with no live backends."""
|
||||
cfg = _proxy_config(["http://b1"])
|
||||
app = create_app(cfg)
|
||||
# Seed registry
|
||||
from llamacpp_ha.registry import BackendState
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
resp = client.get("/v1/models")
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
data = resp.json()
|
||||
self.assertEqual(data["object"], "list")
|
||||
self.assertIn("data", data)
|
||||
self.assertIsInstance(data["data"], list)
|
||||
|
||||
def test_models_deduplication(self):
|
||||
"""Models appearing on multiple backends appear once."""
|
||||
cfg = ProxyConfig(
|
||||
backends=[
|
||||
BackendConfig(url="http://b1", model_ids=["shared", "only-b1"]),
|
||||
BackendConfig(url="http://b2", model_ids=["shared", "only-b2"]),
|
||||
]
|
||||
)
|
||||
app = create_app(cfg)
|
||||
from llamacpp_ha.registry import BackendState
|
||||
# Manually seed live
|
||||
import asyncio
|
||||
|
||||
async def _seed():
|
||||
from llamacpp_ha import proxy as _proxy_module
|
||||
# Can't easily seed without internal access in this test structure
|
||||
pass
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
resp = client.get("/v1/models")
|
||||
data = resp.json()
|
||||
model_ids = [m["id"] for m in data["data"]]
|
||||
self.assertEqual(len(model_ids), len(set(model_ids)), "Duplicates found")
|
||||
# Model deduplication is tested at the registry level in
|
||||
# test_registry.py::test_get_all_models_deduplicated.
|
||||
|
||||
|
||||
class TestApiKeyMiddlewareIntegration(unittest.TestCase):
|
||||
@@ -199,6 +97,15 @@ class TestApiKeyMiddlewareIntegration(unittest.TestCase):
|
||||
resp = client.get("/monitor/data")
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
|
||||
def test_monitor_trailing_slash_exempt_from_auth(self):
|
||||
cfg = _proxy_config([], api_keys=["secret"])
|
||||
app = create_app(cfg)
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
resp = client.get("/monitor/")
|
||||
# The trailing-slash path is auth-exempt; it may hit the catch-all (503)
|
||||
# if the router doesn't redirect it, but it must never return 401.
|
||||
self.assertNotEqual(resp.status_code, 401)
|
||||
|
||||
|
||||
class TestSlotExhaustionTimeout(unittest.TestCase):
|
||||
def test_timeout_when_no_slot_available(self):
|
||||
@@ -209,7 +116,6 @@ class TestSlotExhaustionTimeout(unittest.TestCase):
|
||||
default_slot_capacity=1,
|
||||
)
|
||||
app = create_app(cfg)
|
||||
# No live backends -> scheduler never dispatches -> timeout -> 503
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
resp = client.post(
|
||||
"/v1/chat/completions",
|
||||
@@ -251,23 +157,17 @@ class TestMonitorIntegration(unittest.TestCase):
|
||||
|
||||
|
||||
class TestFullForwardPath(unittest.TestCase):
|
||||
"""
|
||||
Full path test: proxy starts up, monitor/data reflects backend state
|
||||
(dead until polled), scheduler queues requests when no backends are live.
|
||||
"""
|
||||
|
||||
def test_backend_initially_dead_no_poll(self):
|
||||
cfg = ProxyConfig(
|
||||
backends=[BackendConfig(url="http://fake", model_ids=["test-model"])],
|
||||
slot_wait_timeout=2.0,
|
||||
default_slot_capacity=2,
|
||||
poll_interval=9999, # no background polling in test
|
||||
poll_interval=9999,
|
||||
)
|
||||
app = create_app(cfg)
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
resp = client.get("/monitor/data")
|
||||
data = resp.json()
|
||||
# Backend exists but is dead until first poll completes
|
||||
self.assertEqual(len(data["backends"]), 1)
|
||||
self.assertFalse(data["backends"][0]["live"])
|
||||
|
||||
@@ -275,8 +175,7 @@ class TestFullForwardPath(unittest.TestCase):
|
||||
class TestBackendFailover(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_dead_backend_not_in_model_index(self):
|
||||
"""Dead backends are not returned for model routing."""
|
||||
from llamacpp_ha.config import BackendConfig, ProxyConfig
|
||||
from llamacpp_ha.registry import BackendRegistry, BackendState
|
||||
from llamacpp_ha.registry import BackendRegistry
|
||||
|
||||
cfg = ProxyConfig(
|
||||
backends=[
|
||||
@@ -286,7 +185,6 @@ class TestBackendFailover(unittest.IsolatedAsyncioTestCase):
|
||||
)
|
||||
reg = BackendRegistry(cfg)
|
||||
|
||||
# Simulate: b1 live, b2 dead
|
||||
async with reg._lock:
|
||||
reg._states["http://b1"].live = True
|
||||
reg._states["http://b1"].models = ["m"]
|
||||
@@ -297,15 +195,10 @@ class TestBackendFailover(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(len(backends), 1)
|
||||
self.assertEqual(backends[0].url, "http://b1")
|
||||
|
||||
async def test_all_dead_returns_empty(self):
|
||||
from llamacpp_ha.config import BackendConfig, ProxyConfig
|
||||
def test_all_dead_returns_empty(self):
|
||||
from llamacpp_ha.registry import BackendRegistry
|
||||
|
||||
cfg = ProxyConfig(
|
||||
backends=[
|
||||
BackendConfig(url="http://b1", model_ids=["m"]),
|
||||
]
|
||||
)
|
||||
cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m"])])
|
||||
reg = BackendRegistry(cfg)
|
||||
backends = reg.get_backends_for_model("m")
|
||||
self.assertEqual(backends, [])
|
||||
@@ -313,7 +206,6 @@ class TestBackendFailover(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
class TestModelRouting(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_routes_to_backend_serving_model(self):
|
||||
from llamacpp_ha.config import BackendConfig, ProxyConfig
|
||||
from llamacpp_ha.policies import RoundRobinPolicy
|
||||
from llamacpp_ha.queue import QueueEntry, RequestQueue
|
||||
from llamacpp_ha.registry import BackendRegistry
|
||||
@@ -352,13 +244,7 @@ class TestModelRouting(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
|
||||
class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase):
|
||||
"""
|
||||
Slot is fully occupied → request queues → slot released → request dispatched.
|
||||
Tests the complete wait-and-unblock path without involving real HTTP.
|
||||
"""
|
||||
|
||||
async def test_queued_request_dispatched_after_slot_release(self):
|
||||
from llamacpp_ha.config import BackendConfig, ProxyConfig
|
||||
from llamacpp_ha.policies import RoundRobinPolicy
|
||||
from llamacpp_ha.queue import QueueEntry, RequestQueue
|
||||
from llamacpp_ha.registry import BackendRegistry
|
||||
@@ -379,28 +265,22 @@ class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase):
|
||||
queue = RequestQueue()
|
||||
scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy())
|
||||
|
||||
# Occupy the only slot (simulating an in-flight request)
|
||||
async with asyncio.timeout(1.0):
|
||||
await slots.acquire("http://b1")
|
||||
|
||||
# Enqueue a request — scheduler should not be able to dispatch
|
||||
loop = asyncio.get_running_loop()
|
||||
entry = QueueEntry(model_id="m", future=loop.create_future())
|
||||
await queue.enqueue(entry)
|
||||
await scheduler._dispatch_all()
|
||||
self.assertFalse(entry.future.done(), "Should be queued, not dispatched yet")
|
||||
self.assertFalse(entry.future.done())
|
||||
|
||||
# Release the slot (simulates a prior request completing)
|
||||
await slots.release("http://b1")
|
||||
scheduler.notify_slot_released()
|
||||
|
||||
# Scheduler re-evaluates on next wakeup
|
||||
await scheduler._dispatch_all()
|
||||
self.assertTrue(entry.future.done(), "Should be dispatched after slot freed")
|
||||
self.assertTrue(entry.future.done())
|
||||
self.assertEqual(entry.future.result().url, "http://b1")
|
||||
|
||||
async def test_multiple_queued_requests_dispatched_as_slots_free(self):
|
||||
from llamacpp_ha.config import BackendConfig, ProxyConfig
|
||||
from llamacpp_ha.policies import RoundRobinPolicy
|
||||
from llamacpp_ha.queue import QueueEntry, RequestQueue
|
||||
from llamacpp_ha.registry import BackendRegistry
|
||||
@@ -421,7 +301,6 @@ class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase):
|
||||
queue = RequestQueue()
|
||||
scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy())
|
||||
|
||||
# Fill both slots
|
||||
async with asyncio.timeout(1.0):
|
||||
await slots.acquire("http://b1")
|
||||
async with asyncio.timeout(1.0):
|
||||
@@ -433,30 +312,20 @@ class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase):
|
||||
await queue.enqueue(e)
|
||||
|
||||
await scheduler._dispatch_all()
|
||||
dispatched = sum(1 for e in entries if e.future.done())
|
||||
self.assertEqual(dispatched, 0)
|
||||
self.assertEqual(sum(1 for e in entries if e.future.done()), 0)
|
||||
|
||||
# Free one slot → one request dispatched
|
||||
await slots.release("http://b1")
|
||||
await scheduler._dispatch_all()
|
||||
dispatched = sum(1 for e in entries if e.future.done())
|
||||
self.assertEqual(dispatched, 1)
|
||||
self.assertEqual(sum(1 for e in entries if e.future.done()), 1)
|
||||
|
||||
# Free second slot → another dispatched
|
||||
await slots.release("http://b1")
|
||||
await scheduler._dispatch_all()
|
||||
dispatched = sum(1 for e in entries if e.future.done())
|
||||
self.assertEqual(dispatched, 2)
|
||||
self.assertEqual(sum(1 for e in entries if e.future.done()), 2)
|
||||
|
||||
|
||||
class TestSessionCookieIntegration(unittest.TestCase):
|
||||
def test_session_cookie_set_in_response(self):
|
||||
"""Proxy sets X-Session-ID header and cookie on inference responses."""
|
||||
# No live backends → times out → 503, but we still get session tracking
|
||||
# We verify the session machinery works by checking that when the proxy
|
||||
# processes a request, it assigns a session and would attach it to the response.
|
||||
# For a full round-trip we'd need a live backend; here we verify the 503
|
||||
# path still doesn't crash the session handling.
|
||||
def test_session_timeout_returns_503_with_error_body(self):
|
||||
"""On slot timeout the 503 body is valid JSON with an error field."""
|
||||
cfg = ProxyConfig(
|
||||
backends=[BackendConfig(url="http://b1", model_ids=["m"])],
|
||||
slot_wait_timeout=0.1,
|
||||
@@ -467,9 +336,7 @@ class TestSessionCookieIntegration(unittest.TestCase):
|
||||
"/v1/chat/completions",
|
||||
json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
|
||||
)
|
||||
# Times out (no live backend) → 503
|
||||
self.assertEqual(resp.status_code, 503)
|
||||
# Even on timeout, no crash
|
||||
self.assertIn("error", resp.json())
|
||||
|
||||
|
||||
|
||||
@@ -131,7 +131,7 @@ class TestBackendRegistry(unittest.IsolatedAsyncioTestCase):
|
||||
poll_times = []
|
||||
|
||||
async def fake_poll_one(state):
|
||||
poll_times.append(asyncio.get_event_loop().time())
|
||||
poll_times.append(asyncio.get_running_loop().time())
|
||||
await asyncio.sleep(0.05)
|
||||
async with reg._lock:
|
||||
state.live = True
|
||||
|
||||
@@ -30,7 +30,7 @@ class TestSessionStore(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_get_or_create_existing(self):
|
||||
store = SessionStore(ttl=300.0)
|
||||
s1 = await store.get_or_create("abc")
|
||||
await store.get_or_create("abc")
|
||||
await store.update("abc", model_id="llama3")
|
||||
s2 = await store.get_or_create("abc")
|
||||
self.assertEqual(s2.model_id, "llama3")
|
||||
@@ -95,6 +95,69 @@ class TestSessionStore(unittest.IsolatedAsyncioTestCase):
|
||||
store = SessionStore(ttl=300.0)
|
||||
await store.update("nope", model_id="m1") # should not raise
|
||||
|
||||
async def test_find_by_prefix_no_messages(self):
    """An empty message list yields no backend."""
    sessions = SessionStore(ttl=300.0)
    self.assertIsNone(await sessions.find_by_prefix([]))
|
||||
|
||||
async def test_find_by_prefix_exact_match(self):
    """A stored single-turn conversation is found when queried with the same messages."""
    conversation = [{"role": "user", "content": "hello"}]
    sessions = SessionStore(ttl=300.0)
    await sessions.update("s1", messages=conversation, preferred_backend="http://b1")
    self.assertEqual(await sessions.find_by_prefix(conversation), "http://b1")
|
||||
|
||||
async def test_find_by_prefix_continuation(self):
    """A session stored after turn 1 is found when turn-2 messages extend it."""
    opening = {"role": "user", "content": "hello"}
    first_turn = [opening]
    second_turn = [
        dict(opening),
        {"role": "assistant", "content": "hi"},
        {"role": "user", "content": "tell me more"},
    ]
    sessions = SessionStore(ttl=300.0)
    await sessions.update("s1", messages=first_turn, preferred_backend="http://b1")
    self.assertEqual(await sessions.find_by_prefix(second_turn), "http://b1")
|
||||
|
||||
async def test_find_by_prefix_prefers_longest_match(self):
    """With several matching sessions, the longer stored prefix decides the backend."""
    one_message = [{"role": "user", "content": "hello"}]
    two_messages = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hi"},
    ]
    three_messages = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hi"},
        {"role": "user", "content": "more"},
    ]
    sessions = SessionStore(ttl=300.0)
    await sessions.update("s1", messages=one_message, preferred_backend="http://b1")
    await sessions.update("s2", messages=two_messages, preferred_backend="http://b2")
    found = await sessions.find_by_prefix(three_messages)
    # s2 recorded two messages, s1 only one, so s2's backend is expected.
    self.assertEqual(found, "http://b2")
|
||||
|
||||
async def test_find_by_prefix_no_match(self):
    """Messages unrelated to any stored conversation yield no backend."""
    stored = [{"role": "user", "content": "hello"}]
    unrelated = [{"role": "user", "content": "completely different"}]
    sessions = SessionStore(ttl=300.0)
    await sessions.update("s1", messages=stored, preferred_backend="http://b1")
    self.assertIsNone(await sessions.find_by_prefix(unrelated))
|
||||
|
||||
async def test_find_by_prefix_ignores_expired(self):
    """A session older than the store's TTL is not returned."""
    import asyncio

    conversation = [{"role": "user", "content": "hello"}]
    sessions = SessionStore(ttl=0.05)
    await sessions.update("s1", messages=conversation, preferred_backend="http://b1")
    # Sleep longer than the 0.05 s TTL before looking the conversation up.
    await asyncio.sleep(0.1)
    self.assertIsNone(await sessions.find_by_prefix(conversation))
|
||||
|
||||
async def test_find_by_prefix_ignores_session_without_backend(self):
    """A matching session that has no preferred backend yields nothing."""
    conversation = [{"role": "user", "content": "hello"}]
    sessions = SessionStore(ttl=300.0)
    # The update deliberately omits preferred_backend.
    await sessions.update("s1", messages=conversation)
    self.assertIsNone(await sessions.find_by_prefix(conversation))
|
||||
|
||||
async def test_concurrent_access(self):
|
||||
store = SessionStore(ttl=300.0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user