"""Unit tests for ``BackendRegistry`` state tracking and model indexing."""

import asyncio
import time
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.registry import BackendRegistry, BackendState


def _make_config(backends=None):
    """Build a ``ProxyConfig`` from a list of backend keyword dicts.

    Args:
        backends: Iterable of dicts, each passed as keyword arguments to
            ``BackendConfig``. When ``None``, two backends with the URLs
            ``http://b1`` and ``http://b2`` are used.

    Returns:
        A ``ProxyConfig`` holding one ``BackendConfig`` per entry.
    """
    if backends is None:
        backends = [{"url": "http://b1"}, {"url": "http://b2"}]
    return ProxyConfig(backends=[BackendConfig(**b) for b in backends])


class TestBackendRegistry(unittest.IsolatedAsyncioTestCase):
    """Tests for liveness tracking, model lookup and polling in the registry."""

    def _make_registry(self, backends=None):
        """Return a ``BackendRegistry`` built from ``_make_config(backends)``."""
        cfg = _make_config(backends)
        return BackendRegistry(cfg)

    async def test_initial_state_all_dead(self):
        """A freshly constructed registry reports no live backends."""
        reg = self._make_registry()
        live = reg.get_all_live_backends()
        self.assertEqual(live, [])

    async def test_get_all_states_returns_all(self):
        """``get_all_states`` returns one state per configured backend URL."""
        reg = self._make_registry()
        states = reg.get_all_states()
        self.assertEqual(len(states), 2)
        urls = {s.url for s in states}
        self.assertEqual(urls, {"http://b1", "http://b2"})

    async def test_liveness_transition(self):
        """A backend marked live shows up in ``get_all_live_backends``."""
        reg = self._make_registry([{"url": "http://b1"}])
        state = reg.get_state("http://b1")
        self.assertFalse(state.live)

        # Simulate a successful poll by directly mutating state
        async with reg._lock:
            state.live = True
            state.models = ["llama3"]
            reg._rebuild_index()

        live = reg.get_all_live_backends()
        self.assertEqual(len(live), 1)

    async def test_model_index_rebuilt(self):
        """After a rebuild, each model maps to every live backend listing it."""
        reg = self._make_registry(
            [{"url": "http://b1"}, {"url": "http://b2"}]
        )
        async with reg._lock:
            b1 = reg.get_state("http://b1")
            b2 = reg.get_state("http://b2")
            b1.live = True
            b1.models = ["m1", "m2"]
            b2.live = True
            b2.models = ["m2", "m3"]
            reg._rebuild_index()

        self.assertEqual(len(reg.get_backends_for_model("m1")), 1)
        self.assertEqual(len(reg.get_backends_for_model("m2")), 2)
        self.assertEqual(len(reg.get_backends_for_model("m3")), 1)

    async def test_get_all_models_deduplicated(self):
        """A model listed by two backends appears once in ``get_all_models``."""
        reg = self._make_registry(
            [{"url": "http://b1"}, {"url": "http://b2"}]
        )
        async with reg._lock:
            b1 = reg.get_state("http://b1")
            b2 = reg.get_state("http://b2")
            b1.live = True
            b1.models = ["m1", "shared"]
            b2.live = True
            b2.models = ["m2", "shared"]
            reg._rebuild_index()

        models = reg.get_all_models()
        self.assertEqual(len(models), 3)
        self.assertEqual(len(set(models)), 3)  # no duplicates

    async def test_dead_backend_excluded_from_model_index(self):
        """Only the live backend is returned for a model both backends list."""
        reg = self._make_registry(
            [{"url": "http://b1"}, {"url": "http://b2"}]
        )
        async with reg._lock:
            b1 = reg.get_state("http://b1")
            b2 = reg.get_state("http://b2")
            b1.live = True
            b1.models = ["m1"]
            b2.live = False
            b2.models = ["m1"]
            reg._rebuild_index()

        backends = reg.get_backends_for_model("m1")
        self.assertEqual(len(backends), 1)
        self.assertEqual(backends[0].url, "http://b1")

    async def test_explicit_model_ids_used_over_backend_report(self):
        """``_fetch_models`` returns the ``model_ids`` given in the config."""
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["only-this"])]
        )
        reg = BackendRegistry(cfg)

        # Simulate fetching models: should use config model_ids, not backend
        # response
        state = reg.get_state("http://b1")
        mock_session = AsyncMock()
        reg._session = mock_session

        models = await reg._fetch_models(state, {})
        self.assertEqual(models, ["only-this"])

    async def test_last_poll_age_increases(self):
        """``last_poll_age`` reflects time elapsed since ``last_poll_time``."""
        reg = self._make_registry([{"url": "http://b1"}])
        state = reg.get_state("http://b1")
        # Pretend the last poll happened ten seconds ago.
        state.last_poll_time = time.monotonic() - 10
        age = state.last_poll_age
        self.assertGreaterEqual(age, 9.9)

    async def test_no_backends_for_unknown_model(self):
        """Looking up a model no backend serves yields an empty list."""
        reg = self._make_registry()
        result = reg.get_backends_for_model("no-such-model")
        self.assertEqual(result, [])

    async def test_poll_all_concurrent(self):
        """poll_all runs backends concurrently; a slow backend doesn't block others."""
        cfg = ProxyConfig(
            backends=[BackendConfig(url=f"http://b{i}") for i in range(5)]
        )
        reg = BackendRegistry(cfg)

        async def fake_poll_one(state):
            # Stand-in for a slow backend poll that ends with the backend live.
            await asyncio.sleep(0.05)
            async with reg._lock:
                state.live = True
                state.models = ["m1"]

        with patch.object(reg, "_poll_one", side_effect=fake_poll_one):
            start = time.monotonic()
            await reg._poll_all()
            elapsed = time.monotonic() - start

        # Sequential polling would need at least 5 * 0.05s = 0.25s, whereas
        # concurrent polling takes roughly one sleep (~0.05s); the 0.2s bound
        # leaves headroom for scheduling overhead.
        self.assertLess(elapsed, 0.2)
        self.assertEqual(len(reg.get_all_live_backends()), 5)


if __name__ == "__main__":
    unittest.main()