Files
llamacpp-ha/tests/test_integration.py
2026-05-19 21:21:42 +02:00

346 lines
13 KiB
Python

from __future__ import annotations
import asyncio
import unittest
from fastapi.testclient import TestClient
from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.proxy import create_app
# ---------------------------------------------------------------------------
# Config helpers
# ---------------------------------------------------------------------------
def _live_backend_config(url: str, models: list[str] | None = None) -> BackendConfig:
    """Build a ``BackendConfig`` for *url*.

    If *models* is falsy (``None`` or an empty list) the backend is configured
    with the single placeholder model id ``"test-model"``.
    """
    model_ids = models if models else ["test-model"]
    return BackendConfig(url=url, model_ids=model_ids)
def _proxy_config(backend_urls: list[str], api_keys: list[str] | None = None) -> ProxyConfig:
    """Build a ``ProxyConfig`` for integration tests.

    Args:
        backend_urls: one backend is configured per URL, each advertising the
            placeholder model ``"test-model"``.
        api_keys: accepted API keys; ``None`` or empty means no keys.

    The proxy binds 127.0.0.1:8080, waits at most 1 s for a slot, keeps idle
    sessions for 300 s, assumes 2 slots per backend, and uses a very long poll
    interval so background polling does not run during tests.
    """
    keys = api_keys if api_keys else []
    backends = [_live_backend_config(u) for u in backend_urls]
    return ProxyConfig(
        host="127.0.0.1",
        port=8080,
        api_keys=keys,
        poll_interval=9999,  # disable background polling during tests
        slot_wait_timeout=1.0,
        session_idle_ttl=300.0,
        default_slot_capacity=2,
        backends=backends,
    )
# ---------------------------------------------------------------------------
# Integration test cases
# ---------------------------------------------------------------------------
class TestHealthEndpoint(unittest.TestCase):
    """Integration tests for ``GET /health``."""

    def test_health_no_live_backend_returns_503(self):
        """With no backends configured, /health answers 503."""
        cfg = _proxy_config([])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/health")
        self.assertEqual(resp.status_code, 503)

    def test_health_exempt_from_auth(self):
        """/health is reachable without an API key even when keys are configured."""
        cfg = _proxy_config(["http://b1"], api_keys=["mykey"])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/health")
        # /health is auth-exempt, so the request must never be rejected with 401.
        self.assertNotEqual(resp.status_code, 401)
        # No backend is live here, so the expected answer is 503. Pinning the
        # exact code also catches a 500 from a crashing handler, which the
        # `!= 401` check alone would accept because the client is built with
        # raise_server_exceptions=False.
        self.assertEqual(resp.status_code, 503)

    def test_health_with_valid_api_key(self):
        """Supplying a valid key still yields 503 when no backends exist."""
        cfg = _proxy_config([], api_keys=["mykey"])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/health", headers={"Authorization": "Bearer mykey"})
        self.assertEqual(resp.status_code, 503)
class TestModelsEndpoint(unittest.TestCase):
    """Integration tests for ``GET /v1/models``."""

    def test_models_returns_list_format(self):
        """GET /v1/models returns the OpenAI list envelope even with no live backends."""
        client = TestClient(
            create_app(_proxy_config(["http://b1"])),
            raise_server_exceptions=False,
        )
        response = client.get("/v1/models")
        self.assertEqual(response.status_code, 200)
        payload = response.json()
        self.assertEqual(payload["object"], "list")
        self.assertIsInstance(payload["data"], list)

    # Model deduplication is tested at the registry level in
    # test_registry.py::test_get_all_models_deduplicated.
class TestApiKeyMiddlewareIntegration(unittest.TestCase):
    """Integration tests for API-key enforcement and its exempt paths."""

    @staticmethod
    def _client() -> TestClient:
        """Return a client for an app with no backends and the single key ``secret``."""
        app = create_app(_proxy_config([], api_keys=["secret"]))
        return TestClient(app, raise_server_exceptions=False)

    def test_request_rejected_without_key(self):
        """A protected endpoint rejects a request that carries no API key."""
        resp = self._client().post(
            "/v1/chat/completions", json={"model": "m", "messages": []}
        )
        self.assertEqual(resp.status_code, 401)

    def test_monitor_exempt_from_auth(self):
        """/monitor is served without an API key."""
        resp = self._client().get("/monitor")
        self.assertEqual(resp.status_code, 200)

    def test_monitor_data_exempt_from_auth(self):
        """/monitor/data is served without an API key."""
        resp = self._client().get("/monitor/data")
        self.assertEqual(resp.status_code, 200)

    def test_monitor_trailing_slash_exempt_from_auth(self):
        """/monitor/ is auth-exempt and must never be rejected with 401."""
        resp = self._client().get("/monitor/")
        self.assertNotEqual(resp.status_code, 401)
        # The path either reaches the monitor page (200) or, if the router does
        # not redirect it, falls through to the catch-all, which answers 503
        # when no backends are configured. Anything else (e.g. a 500 swallowed
        # by raise_server_exceptions=False) is a failure.
        self.assertIn(resp.status_code, (200, 503))
class TestSlotExhaustionTimeout(unittest.TestCase):
    """End-to-end check of the slot-wait timeout path."""

    def test_timeout_when_no_slot_available(self):
        """Request with no live backends times out and returns 503."""
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["m"])],
            slot_wait_timeout=0.2,
            default_slot_capacity=1,
            # Mirror _proxy_config and keep background polling disabled. The
            # client below is entered as a context manager, so any startup-time
            # poller would otherwise be free to probe http://b1 mid-test.
            poll_interval=9999,
        )
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.post(
                "/v1/chat/completions",
                json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
            )
            self.assertEqual(resp.status_code, 503)
class TestCatchAllForwarding(unittest.TestCase):
    """Requests to unrecognised paths handled by the proxy's catch-all route."""

    def test_catch_all_no_backends_returns_503(self):
        """An unknown path with zero configured backends yields 503."""
        application = create_app(_proxy_config([]))
        with TestClient(application, raise_server_exceptions=False) as http:
            status = http.get("/some/unknown/path").status_code
            self.assertEqual(status, 503)
class TestMonitorIntegration(unittest.TestCase):
    """Integration tests for the ``/monitor/data`` JSON snapshot."""

    def test_monitor_data_reflects_queue_depth(self):
        """/monitor/data exposes a non-negative integer ``queue_depth``."""
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["m"])],
            slot_wait_timeout=0.1,
            default_slot_capacity=0,
            # Keep background polling disabled, as _proxy_config does, so the
            # test does not depend on probing http://b1 while the app runs.
            poll_interval=9999,
        )
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.get("/monitor/data")
            # Check the status first so a server error fails with a clear
            # message instead of a JSON/KeyError further down.
            self.assertEqual(resp.status_code, 200)
            data = resp.json()
            self.assertIsInstance(data["queue_depth"], int)
            self.assertGreaterEqual(data["queue_depth"], 0)

    def test_monitor_data_shows_all_backends(self):
        """Every configured backend URL appears in the snapshot, live or not."""
        cfg = _proxy_config(["http://b1", "http://b2", "http://b3"])
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.get("/monitor/data")
            self.assertEqual(resp.status_code, 200)
            data = resp.json()
            self.assertEqual(len(data["backends"]), 3)
            urls = {b["url"] for b in data["backends"]}
            self.assertEqual(urls, {"http://b1", "http://b2", "http://b3"})
class TestFullForwardPath(unittest.TestCase):
    """Backend state as observed through a running proxy."""

    def test_backend_initially_dead_no_poll(self):
        """With polling effectively disabled, a configured backend is reported not live."""
        only_backend = BackendConfig(url="http://fake", model_ids=["test-model"])
        cfg = ProxyConfig(
            backends=[only_backend],
            slot_wait_timeout=2.0,
            default_slot_capacity=2,
            poll_interval=9999,
        )
        application = create_app(cfg)
        with TestClient(application, raise_server_exceptions=False) as http:
            snapshot = http.get("/monitor/data").json()
            reported = snapshot["backends"]
            self.assertEqual(len(reported), 1)
            self.assertFalse(reported[0]["live"])
class TestBackendFailover(unittest.IsolatedAsyncioTestCase):
    """Registry-level checks that non-live backends are excluded from routing."""

    async def test_dead_backend_not_in_model_index(self):
        """Dead backends are not returned for model routing."""
        from llamacpp_ha.registry import BackendRegistry
        cfg = ProxyConfig(
            backends=[
                BackendConfig(url="http://b1", model_ids=["m"]),
                BackendConfig(url="http://b2", model_ids=["m"]),
            ]
        )
        reg = BackendRegistry(cfg)
        async with reg._lock:
            reg._states["http://b1"].live = True
            reg._states["http://b1"].models = ["m"]
            # b2 advertises the same model but is dead. Recording its models
            # explicitly means the test only passes if b2 is excluded because
            # of liveness, not merely because it has no models.
            reg._states["http://b2"].live = False
            reg._states["http://b2"].models = ["m"]
            reg._rebuild_index()
        backends = reg.get_backends_for_model("m")
        self.assertEqual(len(backends), 1)
        self.assertEqual(backends[0].url, "http://b1")

    def test_all_dead_returns_empty(self):
        """Without any backend marked live, the registry returns no backends for the model."""
        from llamacpp_ha.registry import BackendRegistry
        cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m"])])
        reg = BackendRegistry(cfg)
        backends = reg.get_backends_for_model("m")
        self.assertEqual(backends, [])
class TestModelRouting(unittest.IsolatedAsyncioTestCase):
    """Scheduler routes each request to a backend that serves its model."""

    async def test_routes_to_backend_serving_model(self):
        """A request for ``model-b`` is dispatched to the backend serving it."""
        from llamacpp_ha.policies import RoundRobinPolicy
        from llamacpp_ha.queue import QueueEntry, RequestQueue
        from llamacpp_ha.registry import BackendRegistry
        from llamacpp_ha.scheduler import Scheduler
        from llamacpp_ha.session_store import SessionStore
        from llamacpp_ha.slot_tracker import SlotTracker

        # Each backend serves exactly one distinct model.
        served_by = {"http://b1": "model-a", "http://b2": "model-b"}
        cfg = ProxyConfig(
            backends=[
                BackendConfig(url=url, model_ids=[model])
                for url, model in served_by.items()
            ]
        )
        registry = BackendRegistry(cfg)
        async with registry._lock:
            for url, model in served_by.items():
                registry._states[url].live = True
                registry._states[url].models = [model]
            registry._rebuild_index()

        tracker = SlotTracker()
        for url in served_by:
            tracker.set_capacity(url, 2)
        session_store = SessionStore()
        request_queue = RequestQueue()
        sched = Scheduler(request_queue, registry, tracker, session_store, RoundRobinPolicy())

        pending = QueueEntry(
            model_id="model-b", future=asyncio.get_running_loop().create_future()
        )
        await request_queue.enqueue(pending)
        await sched._dispatch_all()

        self.assertTrue(pending.future.done())
        self.assertEqual(pending.future.result().url, "http://b2")
class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase):
    """Scheduler behaviour when every slot is taken and then freed one by one."""

    async def _make_scheduler(self, capacity: int):
        """Build a scheduler over one live backend ``http://b1`` serving model ``m``.

        Args:
            capacity: slot capacity registered for ``http://b1``.

        Returns:
            ``(scheduler, queue, slots)`` wired together with a round-robin policy.
        """
        from llamacpp_ha.policies import RoundRobinPolicy
        from llamacpp_ha.queue import RequestQueue
        from llamacpp_ha.registry import BackendRegistry
        from llamacpp_ha.scheduler import Scheduler
        from llamacpp_ha.session_store import SessionStore
        from llamacpp_ha.slot_tracker import SlotTracker
        cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m"])])
        reg = BackendRegistry(cfg)
        async with reg._lock:
            reg._states["http://b1"].live = True
            reg._states["http://b1"].models = ["m"]
            reg._rebuild_index()
        slots = SlotTracker()
        slots.set_capacity("http://b1", capacity)
        sessions = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy())
        return scheduler, queue, slots

    @staticmethod
    async def _occupy(slots, count: int) -> None:
        """Acquire *count* slots on ``http://b1``, each bounded by a 1 s timeout."""
        for _ in range(count):
            async with asyncio.timeout(1.0):
                await slots.acquire("http://b1")

    @staticmethod
    def _new_entries(count: int) -> list:
        """Create *count* queue entries for model ``m``, each with a fresh future."""
        from llamacpp_ha.queue import QueueEntry
        loop = asyncio.get_running_loop()
        return [QueueEntry(model_id="m", future=loop.create_future()) for _ in range(count)]

    async def test_queued_request_dispatched_after_slot_release(self):
        """A queued request is dispatched once the only slot is released."""
        scheduler, queue, slots = await self._make_scheduler(capacity=1)
        await self._occupy(slots, 1)
        (entry,) = self._new_entries(1)
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        # The only slot is held, so the request must stay queued.
        self.assertFalse(entry.future.done())
        await slots.release("http://b1")
        scheduler.notify_slot_released()
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_multiple_queued_requests_dispatched_as_slots_free(self):
        """Each freed slot lets exactly one more queued request through."""
        scheduler, queue, slots = await self._make_scheduler(capacity=2)
        await self._occupy(slots, 2)
        entries = self._new_entries(3)
        for entry in entries:
            await queue.enqueue(entry)

        def dispatched() -> int:
            return sum(1 for x in entries if x.future.done())

        await scheduler._dispatch_all()
        self.assertEqual(dispatched(), 0)
        for expected in (1, 2):
            # Same release sequence as the single-request test above:
            # free the slot, notify the scheduler, then dispatch.
            await slots.release("http://b1")
            scheduler.notify_slot_released()
            await scheduler._dispatch_all()
            self.assertEqual(dispatched(), expected)
class TestSessionCookieIntegration(unittest.TestCase):
    """End-to-end check of the slot-timeout error response.

    NOTE(review): despite the class name, the test here exercises the 503
    error body on slot timeout rather than session cookies.
    """

    def test_session_timeout_returns_503_with_error_body(self):
        """On slot timeout the 503 body is valid JSON with an error field."""
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["m"])],
            slot_wait_timeout=0.1,
            # Keep background polling disabled, as _proxy_config does, so the
            # test does not depend on probing http://b1 while the app runs.
            poll_interval=9999,
        )
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.post(
                "/v1/chat/completions",
                json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
            )
            self.assertEqual(resp.status_code, 503)
            self.assertIn("error", resp.json())
if __name__ == "__main__":
    # Running this file directly hands control to the unittest runner,
    # which collects the TestCase classes in this module.
    unittest.main(module="__main__")