"""
Integration tests for the llama.cpp HA proxy.

``_make_fake_backend`` builds an in-process fake llama.cpp backend (a FastAPI
app) intended to be served via ``httpx.AsyncClient(transport=...)``, with
aiohttp calls in the registry/forwarder patched so that requests reach it
without opening real TCP sockets. The tests below do not wire it in yet: they
exercise the proxy with no live backends, or drive the registry and scheduler
directly.
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import unittest
|
|
from contextlib import asynccontextmanager
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import aiohttp
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
|
from fastapi.testclient import TestClient
|
|
|
|
from llamacpp_ha.config import BackendConfig, ProxyConfig
|
|
from llamacpp_ha.proxy import create_app
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Helpers to build a minimal proxy configuration for tests
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _live_backend_config(url: str, models: list[str] | None = None) -> BackendConfig:
    """Return a BackendConfig for *url* advertising *models*.

    A missing or empty *models* falls back to ``["test-model"]``. Despite the
    name, this only builds configuration; it does not mark the backend live.
    """
    if not models:
        models = ["test-model"]
    return BackendConfig(url=url, model_ids=models)
|
|
|
|
|
|
def _proxy_config(backend_urls: list[str], api_keys: list[str] | None = None) -> ProxyConfig:
    """Build a loopback ProxyConfig with short timeouts for tests.

    Args:
        backend_urls: Each URL becomes a backend via ``_live_backend_config``
            (default model ``test-model``).
        api_keys: Accepted API keys; a missing or empty value falls back to
            an empty key list.

    Returns:
        A ProxyConfig bound to 127.0.0.1:8080 with a 1 s slot wait timeout
        and a default slot capacity of 2.
    """
    backends = [_live_backend_config(backend_url) for backend_url in backend_urls]
    return ProxyConfig(
        host="127.0.0.1",
        port=8080,
        api_keys=api_keys if api_keys else [],
        # Interval far longer than any test (presumably seconds), which
        # effectively disables background polling while tests run.
        poll_interval=9999,
        slot_wait_timeout=1.0,
        session_idle_ttl=300.0,
        default_slot_capacity=2,
        backends=backends,
    )
|
|
|
|
|
|
def _seed_live(app_obj, backend_urls: list[str], models=None) -> None:
|
|
"""Directly seed the registry with live backends (skipping real HTTP poll)."""
|
|
from llamacpp_ha.registry import BackendRegistry
|
|
# Access the app's state via lifespan-created objects
|
|
# We do this by patching _poll_all to be a no-op and manually setting state
|
|
pass
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Fake llama.cpp server (in-process, no real sockets)
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_fake_backend(
    model_id: str = "test-model",
    response_content: str = '{"choices":[{"message":{"role":"assistant","content":"hello"}}]}',
    streaming: bool = False,
    slow_seconds: float = 0,
    slot_count: int = 2,
) -> FastAPI:
    """Build an in-process FastAPI app imitating a llama.cpp server.

    Args:
        model_id: The single model id listed by ``GET /v1/models``.
        response_content: JSON text parsed and returned by non-streaming
            ``POST /v1/chat/completions``.
        streaming: When true, every chat completion is answered as an SSE
            stream, whatever the request's ``stream`` flag says.
        slow_seconds: When non-zero, ``/health`` and chat completions sleep
            this long before answering.
        slot_count: Number of slots (all with ``state`` 0) listed by ``GET /slots``.

    Returns:
        The configured FastAPI application.
    """
    fake = FastAPI()
    sse_chunks = (
        b'data: {"choices":[{"delta":{"content":"hello"}}]}\n\n',
        b"data: [DONE]\n\n",
    )

    async def _maybe_delay():
        # Optional artificial latency shared by /health and chat completions.
        if slow_seconds:
            await asyncio.sleep(slow_seconds)

    @fake.get("/health")
    async def health():
        await _maybe_delay()
        return Response(status_code=200)

    @fake.get("/v1/models")
    async def models():
        listing = [{"id": model_id, "object": "model"}]
        return JSONResponse({"object": "list", "data": listing})

    @fake.get("/slots")
    async def slots():
        idle_slots = [{"id": slot_id, "state": 0} for slot_id in range(slot_count)]
        return JSONResponse(idle_slots)

    @fake.post("/v1/chat/completions")
    async def chat(request: Request):
        await _maybe_delay()
        payload = await request.json()
        # Read the client's flag first (it is always looked up), then combine
        # it with the factory-wide override.
        wants_stream = payload.get("stream", False)
        if not (wants_stream or streaming):
            return JSONResponse(json.loads(response_content))

        async def gen():
            for chunk in sse_chunks:
                yield chunk

        return StreamingResponse(gen(), media_type="text/event-stream")

    @fake.post("/v1/completions")
    async def completions(request: Request):
        return JSONResponse({"choices": [{"text": "hello"}]})

    return fake
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Integration test cases
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestHealthEndpoint(unittest.TestCase):
    """``GET /health`` status codes, with and without API-key auth."""

    def _make_proxy_with_live_backend(self):
        """Return ``(app, cfg)`` for a proxy configured with one backend.

        NOTE(review): despite the name, nothing here marks the backend live;
        no registry seeding is performed. Not used by the tests below.
        """
        cfg = _proxy_config(["http://fake-b1"])
        app = create_app(cfg)
        return app, cfg

    def test_health_no_live_backend_returns_503(self):
        """With no backends configured nothing is live, so health reports 503."""
        cfg = _proxy_config([])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/health")
        self.assertEqual(resp.status_code, 503)

    def test_health_with_api_key_required(self):
        """A key-protected proxy rejects an unauthenticated health check with 401."""
        cfg = _proxy_config(["http://b1"], api_keys=["mykey"])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/health")
        self.assertEqual(resp.status_code, 401)

    def test_health_with_valid_api_key(self):
        """A valid bearer key passes auth; with no backends the result is 503."""
        cfg = _proxy_config([], api_keys=["mykey"])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        # No live backends -> 503, but not 401: the request was authenticated.
        resp = client.get("/health", headers={"Authorization": "Bearer mykey"})
        self.assertEqual(resp.status_code, 503)
|
|
|
|
|
|
class TestModelsEndpoint(unittest.TestCase):
    """Response shape of ``GET /v1/models``."""

    def test_models_returns_list_format(self):
        """The endpoint answers 200 with a ``{"object": "list", "data": ...}`` body."""
        cfg = _proxy_config(["http://b1"])
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/v1/models")
        self.assertEqual(resp.status_code, 200)
        data = resp.json()
        self.assertEqual(data["object"], "list")
        self.assertIn("data", data)

    def test_models_deduplication(self):
        """Models appearing on multiple backends appear once.

        NOTE(review): no backend is marked live here, because the registry
        cannot be seeded without internal access. Unless the endpoint also
        lists models of configured-but-dead backends, ``data["data"]`` is
        likely empty and the assertion holds vacuously. TODO: seed the
        registry once a test hook exists.
        """
        cfg = ProxyConfig(
            backends=[
                BackendConfig(url="http://b1", model_ids=["shared", "only-b1"]),
                BackendConfig(url="http://b2", model_ids=["shared", "only-b2"]),
            ]
        )
        app = create_app(cfg)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/v1/models")
        data = resp.json()
        model_ids = [m["id"] for m in data["data"]]
        self.assertEqual(len(model_ids), len(set(model_ids)), "Duplicates found")
|
|
|
|
|
|
class TestApiKeyMiddlewareIntegration(unittest.TestCase):
    """API-key middleware: inference routes need a key; monitor routes do not."""

    @staticmethod
    def _secured_client():
        """Return a client for a backend-less proxy that accepts only the key ``secret``."""
        secured_app = create_app(_proxy_config([], api_keys=["secret"]))
        return TestClient(secured_app, raise_server_exceptions=False)

    def test_request_rejected_without_key(self):
        response = self._secured_client().post(
            "/v1/chat/completions", json={"model": "m", "messages": []}
        )
        self.assertEqual(response.status_code, 401)

    def test_monitor_exempt_from_auth(self):
        response = self._secured_client().get("/monitor")
        self.assertEqual(response.status_code, 200)

    def test_monitor_data_exempt_from_auth(self):
        response = self._secured_client().get("/monitor/data")
        self.assertEqual(response.status_code, 200)
|
|
|
|
|
|
class TestSlotExhaustionTimeout(unittest.TestCase):
    def test_timeout_when_no_slot_available(self):
        """Request with no live backends times out and returns 503.

        With no live backend the scheduler never dispatches, so the request
        waits out ``slot_wait_timeout`` and the proxy answers 503.
        """
        only_backend = BackendConfig(url="http://b1", model_ids=["m"])
        app = create_app(
            ProxyConfig(
                backends=[only_backend],
                slot_wait_timeout=0.2,
                default_slot_capacity=1,
            )
        )
        request_body = {"model": "m", "messages": [{"role": "user", "content": "hi"}]}
        with TestClient(app, raise_server_exceptions=False) as client:
            response = client.post("/v1/chat/completions", json=request_body)
            self.assertEqual(response.status_code, 503)
|
|
|
|
|
|
class TestCatchAllForwarding(unittest.TestCase):
    """Paths without a dedicated proxy route (the catch-all)."""

    def test_catch_all_no_backends_returns_503(self):
        """With no backends configured, an unknown path is answered with 503."""
        backendless_app = create_app(_proxy_config([]))
        with TestClient(backendless_app, raise_server_exceptions=False) as client:
            status = client.get("/some/unknown/path").status_code
            self.assertEqual(status, 503)
|
|
|
|
|
|
class TestMonitorIntegration(unittest.TestCase):
    """Shape of the ``/monitor/data`` JSON payload."""

    def test_monitor_data_reflects_queue_depth(self):
        """``queue_depth`` is reported as a non-negative int.

        No request is enqueued here, so only the value's type and sign are
        checked, not a specific depth.
        """
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["m"])],
            slot_wait_timeout=0.1,
            default_slot_capacity=0,
        )
        with TestClient(create_app(cfg), raise_server_exceptions=False) as client:
            depth = client.get("/monitor/data").json()["queue_depth"]
            self.assertIsInstance(depth, int)
            self.assertGreaterEqual(depth, 0)

    def test_monitor_data_shows_all_backends(self):
        """Every configured backend is listed, identified by its URL."""
        app = create_app(_proxy_config(["http://b1", "http://b2", "http://b3"]))
        expected_urls = {"http://b1", "http://b2", "http://b3"}
        with TestClient(app, raise_server_exceptions=False) as client:
            reported = client.get("/monitor/data").json()["backends"]
            self.assertEqual(len(reported), 3)
            self.assertEqual({entry["url"] for entry in reported}, expected_urls)
|
|
|
|
|
|
class TestFullForwardPath(unittest.TestCase):
    """
    Start-up state: after the proxy starts, ``/monitor/data`` lists the
    configured backend as not live (dead until a poll marks it live).
    """

    def test_backend_initially_dead_no_poll(self):
        only_backend = BackendConfig(url="http://fake", model_ids=["test-model"])
        cfg = ProxyConfig(
            backends=[only_backend],
            slot_wait_timeout=2.0,
            default_slot_capacity=2,
            poll_interval=9999,  # keep background polling out of the test
        )
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            listed = client.get("/monitor/data").json()["backends"]
            # The backend is known to the proxy but has not been marked live.
            self.assertEqual(len(listed), 1)
            self.assertFalse(listed[0]["live"])
|
|
|
|
|
|
class TestBackendFailover(unittest.IsolatedAsyncioTestCase):
    """Model routing in the registry ignores backends that are not live."""

    async def test_dead_backend_not_in_model_index(self):
        """Dead backends are not returned for model routing."""
        from llamacpp_ha.registry import BackendRegistry

        cfg = ProxyConfig(
            backends=[
                BackendConfig(url="http://b1", model_ids=["m"]),
                BackendConfig(url="http://b2", model_ids=["m"]),
            ]
        )
        reg = BackendRegistry(cfg)

        # Simulate a poll outcome by hand: b1 live and serving "m", b2 dead.
        async with reg._lock:
            reg._states["http://b1"].live = True
            reg._states["http://b1"].models = ["m"]
            reg._states["http://b2"].live = False
            reg._rebuild_index()

        backends = reg.get_backends_for_model("m")
        self.assertEqual(len(backends), 1)
        self.assertEqual(backends[0].url, "http://b1")

    async def test_all_dead_returns_empty(self):
        """A freshly built registry (nothing polled live) routes no backends."""
        from llamacpp_ha.registry import BackendRegistry

        cfg = ProxyConfig(
            backends=[
                BackendConfig(url="http://b1", model_ids=["m"]),
            ]
        )
        reg = BackendRegistry(cfg)
        backends = reg.get_backends_for_model("m")
        self.assertEqual(backends, [])
|
|
|
|
|
|
class TestModelRouting(unittest.IsolatedAsyncioTestCase):
    """Scheduler dispatch honours the requested model id."""

    async def test_routes_to_backend_serving_model(self):
        """A request for ``model-b`` goes to the only backend serving it (b2)."""
        from llamacpp_ha.policies import RoundRobinPolicy
        from llamacpp_ha.queue import QueueEntry, RequestQueue
        from llamacpp_ha.registry import BackendRegistry
        from llamacpp_ha.scheduler import Scheduler
        from llamacpp_ha.session_store import SessionStore
        from llamacpp_ha.slot_tracker import SlotTracker

        served_by = {"http://b1": "model-a", "http://b2": "model-b"}
        reg = BackendRegistry(
            ProxyConfig(
                backends=[
                    BackendConfig(url=url, model_ids=[model])
                    for url, model in served_by.items()
                ]
            )
        )
        # Mark both backends live with their model, without polling over HTTP.
        async with reg._lock:
            for url, model in served_by.items():
                state = reg._states[url]
                state.live = True
                state.models = [model]
            reg._rebuild_index()

        slots = SlotTracker()
        for url in served_by:
            slots.set_capacity(url, 2)
        sessions = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy())

        pending = asyncio.get_running_loop().create_future()
        entry = QueueEntry(model_id="model-b", future=pending)
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b2")
|
|
|
|
|
|
class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase):
    """
    Slot is fully occupied → request queues → slot released → request dispatched.

    Tests the complete wait-and-unblock path without involving real HTTP.
    """

    async def _make_scheduler(self, capacity: int):
        """Build a scheduler over one live backend ``http://b1`` serving model ``m``.

        Args:
            capacity: Slot capacity registered for ``http://b1``.

        Returns:
            ``(slots, queue, scheduler)``.
        """
        from llamacpp_ha.policies import RoundRobinPolicy
        from llamacpp_ha.queue import RequestQueue
        from llamacpp_ha.registry import BackendRegistry
        from llamacpp_ha.scheduler import Scheduler
        from llamacpp_ha.session_store import SessionStore
        from llamacpp_ha.slot_tracker import SlotTracker

        cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m"])])
        reg = BackendRegistry(cfg)
        # Mark b1 live by hand instead of polling it over HTTP.
        async with reg._lock:
            reg._states["http://b1"].live = True
            reg._states["http://b1"].models = ["m"]
            reg._rebuild_index()

        slots = SlotTracker()
        slots.set_capacity("http://b1", capacity)
        sessions = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy())
        return slots, queue, scheduler

    async def test_queued_request_dispatched_after_slot_release(self):
        from llamacpp_ha.queue import QueueEntry

        slots, queue, scheduler = await self._make_scheduler(capacity=1)

        # Occupy the only slot (simulating an in-flight request).
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")

        # Enqueue a request; the scheduler should not be able to dispatch it.
        loop = asyncio.get_running_loop()
        entry = QueueEntry(model_id="m", future=loop.create_future())
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done(), "Should be queued, not dispatched yet")

        # Release the slot (simulates a prior request completing).
        await slots.release("http://b1")
        scheduler.notify_slot_released()

        # The scheduler re-evaluates on its next pass.
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done(), "Should be dispatched after slot freed")
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_multiple_queued_requests_dispatched_as_slots_free(self):
        from llamacpp_ha.queue import QueueEntry

        slots, queue, scheduler = await self._make_scheduler(capacity=2)

        # Fill both slots, each acquisition bounded by its own timeout.
        for _ in range(2):
            async with asyncio.timeout(1.0):
                await slots.acquire("http://b1")

        loop = asyncio.get_running_loop()
        entries = [QueueEntry(model_id="m", future=loop.create_future()) for _ in range(3)]
        for e in entries:
            await queue.enqueue(e)

        def dispatched_count() -> int:
            return sum(1 for e in entries if e.future.done())

        await scheduler._dispatch_all()
        self.assertEqual(dispatched_count(), 0)

        # Free one slot -> one request dispatched. Notify the scheduler the
        # same way as in the single-slot test, to mirror the real release path.
        await slots.release("http://b1")
        scheduler.notify_slot_released()
        await scheduler._dispatch_all()
        self.assertEqual(dispatched_count(), 1)

        # Free the second slot -> another dispatched; the third stays queued.
        await slots.release("http://b1")
        scheduler.notify_slot_released()
        await scheduler._dispatch_all()
        self.assertEqual(dispatched_count(), 2)
|
|
|
|
|
|
class TestSessionCookieIntegration(unittest.TestCase):
    def test_session_cookie_set_in_response(self):
        """Smoke-test session handling on a request that cannot be served.

        With no live backend the request times out and the proxy answers 503
        with an ``error`` body. This only checks that the session path does not
        crash. NOTE(review): the ``X-Session-ID`` header and cookie are not
        asserted; a full round-trip needs a live backend.
        """
        request_body = {"model": "m", "messages": [{"role": "user", "content": "hi"}]}
        app = create_app(
            ProxyConfig(
                backends=[BackendConfig(url="http://b1", model_ids=["m"])],
                slot_wait_timeout=0.1,
            )
        )
        with TestClient(app, raise_server_exceptions=False) as client:
            response = client.post("/v1/chat/completions", json=request_body)
            # Timed out (no live backend) -> 503, still with a well-formed error body.
            self.assertEqual(response.status_code, 503)
            self.assertIn("error", response.json())
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every TestCase in this module through unittest's command-line runner.
    unittest.main(module="__main__")
|