345 lines
13 KiB
Python
345 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import unittest
|
|
|
|
from fastapi.testclient import TestClient
|
|
|
|
from llamacpp_ha.config import BackendConfig, ProxyConfig
|
|
from llamacpp_ha.proxy import create_app
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Config helpers
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _live_backend_config(url: str, models: list[str] | None = None) -> BackendConfig:
    """Build a ``BackendConfig`` pointing at *url*.

    A falsy *models* (``None`` or an empty list) falls back to the single
    placeholder model id ``"test-model"``.
    """
    model_ids = models if models else ["test-model"]
    return BackendConfig(url=url, model_ids=model_ids)
|
|
|
|
|
|
def _proxy_config(backend_urls: list[str], api_keys: list[str] | None = None) -> ProxyConfig:
    """Return a ``ProxyConfig`` suitable for tests.

    One backend (via ``_live_backend_config``) is configured per entry of
    *backend_urls*. A falsy *api_keys* becomes an empty key list. The poll
    interval is set to 9999 so background polling does not kick in while a
    test runs.
    """
    backends = [_live_backend_config(backend_url) for backend_url in backend_urls]
    keys = api_keys if api_keys else []
    return ProxyConfig(
        host="127.0.0.1",
        port=8080,
        api_keys=keys,
        poll_interval=9999,  # effectively disables background polling during tests
        slot_wait_timeout=1.0,
        session_idle_ttl=300.0,
        default_slot_capacity=2,
        backends=backends,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Integration test cases
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestHealthEndpoint(unittest.TestCase):
    """Status codes returned by ``GET /health`` with and without API keys."""

    @staticmethod
    def _client_for(cfg: ProxyConfig) -> TestClient:
        """Wrap a fresh app built from *cfg* in a non-raising test client.

        The client is not entered as a context manager, so Starlette does not
        run the app's lifespan (startup/shutdown) handlers for these tests.
        """
        return TestClient(create_app(cfg), raise_server_exceptions=False)

    def test_health_no_live_backend_returns_503(self):
        client = self._client_for(_proxy_config([]))
        self.assertEqual(client.get("/health").status_code, 503)

    def test_health_with_api_key_required(self):
        # A key is configured but the request carries no Authorization header.
        client = self._client_for(_proxy_config(["http://b1"], api_keys=["mykey"]))
        self.assertEqual(client.get("/health").status_code, 401)

    def test_health_with_valid_api_key(self):
        # With the right bearer token the answer is 503 (no backends configured)
        # rather than 401.
        auth_header = {"Authorization": "Bearer mykey"}
        client = self._client_for(_proxy_config([], api_keys=["mykey"]))
        self.assertEqual(client.get("/health", headers=auth_header).status_code, 503)
|
|
|
|
|
|
class TestModelsEndpoint(unittest.TestCase):
    """Envelope shape of the ``GET /v1/models`` response."""

    def test_models_returns_list_format(self):
        """GET /v1/models returns the OpenAI list envelope even with no live backends."""
        app = create_app(_proxy_config(["http://b1"]))
        client = TestClient(app, raise_server_exceptions=False)

        response = client.get("/v1/models")
        self.assertEqual(response.status_code, 200)

        payload = response.json()
        self.assertEqual(payload["object"], "list")
        self.assertIsInstance(payload["data"], list)

    # Model deduplication is covered at the registry level in
    # test_registry.py::test_get_all_models_deduplicated.
|
|
|
|
|
|
class TestApiKeyMiddlewareIntegration(unittest.TestCase):
    """API-key middleware: protected routes reject, monitor routes are exempt."""

    def _secured_client(self) -> TestClient:
        """Client for an app with no backends and the single API key ``"secret"``."""
        app = create_app(_proxy_config([], api_keys=["secret"]))
        return TestClient(app, raise_server_exceptions=False)

    def test_request_rejected_without_key(self):
        response = self._secured_client().post(
            "/v1/chat/completions", json={"model": "m", "messages": []}
        )
        self.assertEqual(response.status_code, 401)

    def test_monitor_exempt_from_auth(self):
        self.assertEqual(self._secured_client().get("/monitor").status_code, 200)

    def test_monitor_data_exempt_from_auth(self):
        self.assertEqual(self._secured_client().get("/monitor/data").status_code, 200)

    def test_monitor_trailing_slash_exempt_from_auth(self):
        status = self._secured_client().get("/monitor/").status_code
        # The trailing-slash path is auth-exempt; it may hit the catch-all (503)
        # if the router doesn't redirect it, but it must never return 401.
        self.assertNotEqual(status, 401)
|
|
|
|
|
|
class TestSlotExhaustionTimeout(unittest.TestCase):
    """A request that cannot get a backend slot times out with 503."""

    def test_timeout_when_no_slot_available(self):
        """Request with no live backends times out and returns 503."""
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["m"])],
            slot_wait_timeout=0.2,
            default_slot_capacity=1,
            # Keep background polling out of the test, as _proxy_config does;
            # otherwise the poller may try to reach the placeholder http://b1
            # while the request is waiting for a slot.
            poll_interval=9999,
        )
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.post(
                "/v1/chat/completions",
                json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
            )
            self.assertEqual(resp.status_code, 503)
|
|
|
|
|
|
class TestCatchAllForwarding(unittest.TestCase):
    """Unmatched paths when no backend is configured."""

    def test_catch_all_no_backends_returns_503(self):
        app = create_app(_proxy_config([]))
        with TestClient(app, raise_server_exceptions=False) as client:
            status = client.get("/some/unknown/path").status_code
        self.assertEqual(status, 503)
|
|
|
|
|
|
class TestMonitorIntegration(unittest.TestCase):
    """Contents of the ``/monitor/data`` JSON payload."""

    def test_monitor_data_reflects_queue_depth(self):
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["m"])],
            slot_wait_timeout=0.1,
            default_slot_capacity=0,
            poll_interval=9999,  # keep background polling out of the test, as _proxy_config does
        )
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            data = client.get("/monitor/data").json()
            self.assertIsInstance(data["queue_depth"], int)
            # No request has been submitted, so the queue must be empty. The
            # previous ">= 0" check would have accepted any counter value.
            self.assertEqual(data["queue_depth"], 0)

    def test_monitor_data_shows_all_backends(self):
        expected_urls = {"http://b1", "http://b2", "http://b3"}
        cfg = _proxy_config(["http://b1", "http://b2", "http://b3"])
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            data = client.get("/monitor/data").json()
            self.assertEqual(len(data["backends"]), 3)
            urls = {b["url"] for b in data["backends"]}
            self.assertEqual(urls, expected_urls)
|
|
|
|
|
|
class TestFullForwardPath(unittest.TestCase):
    """Backend liveness as reported by ``/monitor/data`` with polling pushed out."""

    def test_backend_initially_dead_no_poll(self):
        only_backend = BackendConfig(url="http://fake", model_ids=["test-model"])
        cfg = ProxyConfig(
            backends=[only_backend],
            slot_wait_timeout=2.0,
            default_slot_capacity=2,
            poll_interval=9999,
        )
        with TestClient(create_app(cfg), raise_server_exceptions=False) as client:
            reported = client.get("/monitor/data").json()["backends"]
            self.assertEqual(len(reported), 1)
            self.assertFalse(reported[0]["live"])
|
|
|
|
|
|
class TestBackendFailover(unittest.IsolatedAsyncioTestCase):
    """Liveness filtering in ``BackendRegistry.get_backends_for_model``."""

    async def test_dead_backend_not_in_model_index(self):
        """Dead backends are not returned for model routing."""
        from llamacpp_ha.registry import BackendRegistry

        cfg = ProxyConfig(
            backends=[
                BackendConfig(url="http://b1", model_ids=["m"]),
                BackendConfig(url="http://b2", model_ids=["m"]),
            ]
        )
        reg = BackendRegistry(cfg)

        async with reg._lock:
            reg._states["http://b1"].live = True
            reg._states["http://b1"].models = ["m"]
            reg._states["http://b2"].live = False
            # Give b2 the same model list as b1 so that liveness is the *only*
            # difference between them. Without this, b2 could be excluded simply
            # for not advertising "m", and the test would not exercise the
            # liveness filter at all.
            reg._states["http://b2"].models = ["m"]
            reg._rebuild_index()

        backends = reg.get_backends_for_model("m")
        self.assertEqual(len(backends), 1)
        self.assertEqual(backends[0].url, "http://b1")

    def test_all_dead_returns_empty(self):
        from llamacpp_ha.registry import BackendRegistry

        cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m"])])
        reg = BackendRegistry(cfg)
        # Relies on the freshly constructed registry's initial state (no poll
        # has run); presumably every backend starts out not live.
        backends = reg.get_backends_for_model("m")
        self.assertEqual(backends, [])
|
|
|
|
|
|
class TestModelRouting(unittest.IsolatedAsyncioTestCase):
    """The scheduler hands a queued request to a backend serving its model."""

    async def test_routes_to_backend_serving_model(self):
        from llamacpp_ha.policies import RoundRobinPolicy
        from llamacpp_ha.queue import QueueEntry, RequestQueue
        from llamacpp_ha.registry import BackendRegistry
        from llamacpp_ha.scheduler import Scheduler
        from llamacpp_ha.session_store import SessionStore
        from llamacpp_ha.slot_tracker import SlotTracker

        # Each backend serves exactly one distinct model.
        served_model = {"http://b1": "model-a", "http://b2": "model-b"}
        cfg = ProxyConfig(
            backends=[
                BackendConfig(url=backend_url, model_ids=[model])
                for backend_url, model in served_model.items()
            ]
        )
        reg = BackendRegistry(cfg)
        async with reg._lock:
            for backend_url, model in served_model.items():
                reg._states[backend_url].live = True
                reg._states[backend_url].models = [model]
            reg._rebuild_index()

        slots = SlotTracker()
        for backend_url in served_model:
            slots.set_capacity(backend_url, 2)
        sessions = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy())

        entry = QueueEntry(
            model_id="model-b", future=asyncio.get_running_loop().create_future()
        )
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b2")
|
|
|
|
|
|
class TestSlotExhaustionAndUnblock(unittest.IsolatedAsyncioTestCase):
    """Queued requests wait while slots are exhausted and dispatch once freed."""

    async def _single_backend_scheduler(self, capacity: int):
        """Build a scheduler over one live backend ``http://b1`` serving ``"m"``.

        *capacity* is the slot count configured for ``http://b1``.
        Returns ``(slots, queue, scheduler)``.
        """
        from llamacpp_ha.policies import RoundRobinPolicy
        from llamacpp_ha.queue import RequestQueue
        from llamacpp_ha.registry import BackendRegistry
        from llamacpp_ha.scheduler import Scheduler
        from llamacpp_ha.session_store import SessionStore
        from llamacpp_ha.slot_tracker import SlotTracker

        cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m"])])
        reg = BackendRegistry(cfg)
        async with reg._lock:
            reg._states["http://b1"].live = True
            reg._states["http://b1"].models = ["m"]
            reg._rebuild_index()

        slots = SlotTracker()
        slots.set_capacity("http://b1", capacity)
        sessions = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(queue, reg, slots, sessions, RoundRobinPolicy())
        return slots, queue, scheduler

    async def test_queued_request_dispatched_after_slot_release(self):
        from llamacpp_ha.queue import QueueEntry

        slots, queue, scheduler = await self._single_backend_scheduler(capacity=1)

        # Occupy the only slot; the timeout keeps a misbehaving acquire() from hanging.
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")

        loop = asyncio.get_running_loop()
        entry = QueueEntry(model_id="m", future=loop.create_future())
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done())

        await slots.release("http://b1")
        scheduler.notify_slot_released()
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_multiple_queued_requests_dispatched_as_slots_free(self):
        from llamacpp_ha.queue import QueueEntry

        slots, queue, scheduler = await self._single_backend_scheduler(capacity=2)

        # Occupy both slots.
        for _ in range(2):
            async with asyncio.timeout(1.0):
                await slots.acquire("http://b1")

        loop = asyncio.get_running_loop()
        entries = [QueueEntry(model_id="m", future=loop.create_future()) for _ in range(3)]
        for e in entries:
            await queue.enqueue(e)

        await scheduler._dispatch_all()
        self.assertEqual(sum(1 for e in entries if e.future.done()), 0)

        # Each freed slot should let exactly one more queued entry through.
        for expected_done in (1, 2):
            await slots.release("http://b1")
            # Same release protocol as the single-request test: release the slot,
            # then notify the scheduler before dispatching.
            scheduler.notify_slot_released()
            await scheduler._dispatch_all()

            done = [e for e in entries if e.future.done()]
            self.assertEqual(len(done), expected_done)
            self.assertTrue(all(e.future.result().url == "http://b1" for e in done))
|
|
|
|
|
|
class TestSessionCookieIntegration(unittest.TestCase):
    """Shape of the error response when a request cannot obtain a slot."""

    def test_session_timeout_returns_503_with_error_body(self):
        """On slot timeout the 503 body is valid JSON with an error field."""
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["m"])],
            slot_wait_timeout=0.1,
            # Keep background polling out of the test, as _proxy_config does, so
            # the placeholder http://b1 is not probed while the request waits.
            poll_interval=9999,
        )
        app = create_app(cfg)
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.post(
                "/v1/chat/completions",
                json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
            )
            self.assertEqual(resp.status_code, 503)
            self.assertIn("error", resp.json())
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests with the standard unittest runner when executed directly.
    unittest.main(module="__main__")
|