Files
llamacpp-ha/tests/test_registry.py
2026-05-17 22:15:13 +02:00

153 lines
5.2 KiB
Python

import asyncio
import time
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.registry import BackendRegistry, BackendState
def _make_config(backends=None):
    """Build a ProxyConfig whose backends come from keyword-argument dicts.

    Each dict in *backends* is splatted into ``BackendConfig``. When
    *backends* is None, two backends (``http://b1`` and ``http://b2``)
    are used.
    """
    specs = (
        [{"url": "http://b1"}, {"url": "http://b2"}]
        if backends is None
        else backends
    )
    backend_configs = [BackendConfig(**spec) for spec in specs]
    return ProxyConfig(backends=backend_configs)
class TestBackendRegistry(unittest.IsolatedAsyncioTestCase):
    """Unit tests for BackendRegistry liveness tracking and its model index."""

    def _make_registry(self, backends=None):
        """Return a BackendRegistry built from ``_make_config(backends)``."""
        cfg = _make_config(backends)
        return BackendRegistry(cfg)

    @staticmethod
    async def _set_backends(reg, states):
        """Simulate poll results by mutating backend state directly.

        *states* maps a backend URL to a ``(live, models)`` pair. Every state
        is updated and ``_rebuild_index`` is called while ``reg._lock`` is
        held.
        """
        async with reg._lock:
            for url, (live, models) in states.items():
                state = reg.get_state(url)
                state.live = live
                state.models = models
            reg._rebuild_index()

    async def test_initial_state_all_dead(self):
        reg = self._make_registry()
        self.assertEqual(reg.get_all_live_backends(), [])

    async def test_get_all_states_returns_all(self):
        reg = self._make_registry()
        states = reg.get_all_states()
        self.assertEqual(len(states), 2)
        self.assertEqual({s.url for s in states}, {"http://b1", "http://b2"})

    async def test_liveness_transition(self):
        reg = self._make_registry([{"url": "http://b1"}])
        self.assertFalse(reg.get_state("http://b1").live)

        await self._set_backends(reg, {"http://b1": (True, ["llama3"])})

        self.assertEqual(len(reg.get_all_live_backends()), 1)

    async def test_model_index_rebuilt(self):
        reg = self._make_registry([{"url": "http://b1"}, {"url": "http://b2"}])
        await self._set_backends(
            reg,
            {
                "http://b1": (True, ["m1", "m2"]),
                "http://b2": (True, ["m2", "m3"]),
            },
        )

        def urls_for(model):
            return sorted(b.url for b in reg.get_backends_for_model(model))

        # Check *which* backends serve each model, not just how many.
        self.assertEqual(urls_for("m1"), ["http://b1"])
        self.assertEqual(urls_for("m2"), ["http://b1", "http://b2"])
        self.assertEqual(urls_for("m3"), ["http://b2"])

    async def test_get_all_models_deduplicated(self):
        reg = self._make_registry([{"url": "http://b1"}, {"url": "http://b2"}])
        await self._set_backends(
            reg,
            {
                "http://b1": (True, ["m1", "shared"]),
                "http://b2": (True, ["m2", "shared"]),
            },
        )

        # assertCountEqual compares element multiplicity, so it fails both
        # on a missing model and on a duplicated "shared" entry.
        self.assertCountEqual(reg.get_all_models(), ["m1", "m2", "shared"])

    async def test_dead_backend_excluded_from_model_index(self):
        reg = self._make_registry([{"url": "http://b1"}, {"url": "http://b2"}])
        await self._set_backends(
            reg,
            {
                "http://b1": (True, ["m1"]),
                "http://b2": (False, ["m1"]),
            },
        )

        backends = reg.get_backends_for_model("m1")
        self.assertEqual([b.url for b in backends], ["http://b1"])

    async def test_explicit_model_ids_used_over_backend_report(self):
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["only-this"])]
        )
        reg = BackendRegistry(cfg)
        # Simulate fetching models: should use config model_ids, not the
        # backend response. The session is a mock so no network I/O happens.
        state = reg.get_state("http://b1")
        reg._session = AsyncMock()
        models = await reg._fetch_models(state, {})
        self.assertEqual(models, ["only-this"])

    async def test_last_poll_age_increases(self):
        reg = self._make_registry([{"url": "http://b1"}])
        state = reg.get_state("http://b1")
        state.last_poll_time = time.monotonic() - 10

        first_age = state.last_poll_age
        self.assertGreaterEqual(first_age, 9.9)
        # A later reading must not be smaller than an earlier one.
        self.assertGreaterEqual(state.last_poll_age, first_age)

    async def test_poll_all_concurrent(self):
        """poll_all runs backends concurrently; a slow backend doesn't block others."""
        urls = [f"http://b{i}" for i in range(5)]
        cfg = ProxyConfig(backends=[BackendConfig(url=u) for u in urls])
        reg = BackendRegistry(cfg)

        polled_urls = []
        in_flight = 0
        max_in_flight = 0

        async def fake_poll_one(state):
            nonlocal in_flight, max_in_flight
            polled_urls.append(state.url)
            in_flight += 1
            max_in_flight = max(max_in_flight, in_flight)
            try:
                # Yield to the event loop. Sequential polling would never
                # have more than one poll in flight at this point.
                await asyncio.sleep(0.01)
            finally:
                in_flight -= 1
            async with reg._lock:
                state.live = True
                state.models = ["m1"]

        with patch.object(reg, "_poll_one", side_effect=fake_poll_one):
            await reg._poll_all()

        self.assertEqual(sorted(polled_urls), sorted(urls))
        # A deterministic concurrency check instead of a wall-clock
        # threshold, which is flaky on loaded CI machines.
        self.assertGreater(max_in_flight, 1)
        self.assertEqual(len(reg.get_all_live_backends()), 5)
if __name__ == "__main__":
    # Allow running this module directly; unittest discovers the tests above.
    unittest.main()