153 lines
5.2 KiB
Python
153 lines
5.2 KiB
Python
import asyncio
import time
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.registry import BackendRegistry, BackendState


def _make_config(backends=None):
    """Build a ProxyConfig whose backends come from keyword dicts.

    Each dict in *backends* is expanded into ``BackendConfig(**spec)``.
    When *backends* is None, two backends (``http://b1`` and ``http://b2``)
    are used.
    """
    specs = (
        [{"url": "http://b1"}, {"url": "http://b2"}]
        if backends is None
        else backends
    )
    backend_configs = [BackendConfig(**spec) for spec in specs]
    return ProxyConfig(backends=backend_configs)


class TestBackendRegistry(unittest.IsolatedAsyncioTestCase):
    """Tests for BackendRegistry liveness tracking and its model index.

    Several tests simulate poll results by mutating BackendState objects
    under the registry's private ``_lock`` and then calling
    ``_rebuild_index()``, instead of talking to real backends.
    """

    def _make_registry(self, backends=None):
        """Return a BackendRegistry built from ``_make_config(backends)``."""
        cfg = _make_config(backends)
        return BackendRegistry(cfg)

    async def test_initial_state_all_dead(self):
        """A freshly built registry reports no live backends."""
        reg = self._make_registry()
        live = reg.get_all_live_backends()
        self.assertEqual(live, [])

    async def test_get_all_states_returns_all(self):
        """get_all_states() returns one state per configured backend URL."""
        reg = self._make_registry()
        states = reg.get_all_states()
        self.assertEqual(len(states), 2)
        urls = {s.url for s in states}
        self.assertEqual(urls, {"http://b1", "http://b2"})

    async def test_liveness_transition(self):
        """A backend marked live (and re-indexed) shows up as live."""
        reg = self._make_registry([{"url": "http://b1"}])
        state = reg.get_state("http://b1")
        self.assertFalse(state.live)

        # Simulate a successful poll by directly mutating state
        async with reg._lock:
            state.live = True
            state.models = ["llama3"]
            reg._rebuild_index()

        live = reg.get_all_live_backends()
        self.assertEqual(len(live), 1)

    async def test_model_index_rebuilt(self):
        """Each model maps to exactly the live backends that report it."""
        reg = self._make_registry(
            [{"url": "http://b1"}, {"url": "http://b2"}]
        )
        async with reg._lock:
            b1 = reg.get_state("http://b1")
            b2 = reg.get_state("http://b2")
            b1.live = True
            b1.models = ["m1", "m2"]
            b2.live = True
            b2.models = ["m2", "m3"]
            reg._rebuild_index()

        def urls_for(model):
            return {b.url for b in reg.get_backends_for_model(model)}

        self.assertEqual(len(reg.get_backends_for_model("m1")), 1)
        self.assertEqual(len(reg.get_backends_for_model("m2")), 2)
        self.assertEqual(len(reg.get_backends_for_model("m3")), 1)
        # Counts alone would pass with the wrong backends; pin the URLs too.
        self.assertEqual(urls_for("m1"), {"http://b1"})
        self.assertEqual(urls_for("m2"), {"http://b1", "http://b2"})
        self.assertEqual(urls_for("m3"), {"http://b2"})

    async def test_get_all_models_deduplicated(self):
        """A model served by several backends is listed only once."""
        reg = self._make_registry(
            [{"url": "http://b1"}, {"url": "http://b2"}]
        )
        async with reg._lock:
            b1 = reg.get_state("http://b1")
            b2 = reg.get_state("http://b2")
            b1.live = True
            b1.models = ["m1", "shared"]
            b2.live = True
            b2.models = ["m2", "shared"]
            reg._rebuild_index()

        models = reg.get_all_models()
        self.assertEqual(len(models), 3)
        self.assertEqual(len(set(models)), 3)  # no duplicates
        self.assertEqual(set(models), {"m1", "m2", "shared"})

    async def test_dead_backend_excluded_from_model_index(self):
        """A backend that is not live is left out of the model index."""
        reg = self._make_registry(
            [{"url": "http://b1"}, {"url": "http://b2"}]
        )
        async with reg._lock:
            b1 = reg.get_state("http://b1")
            b2 = reg.get_state("http://b2")
            b1.live = True
            b1.models = ["m1"]
            b2.live = False
            b2.models = ["m1"]
            reg._rebuild_index()

        backends = reg.get_backends_for_model("m1")
        self.assertEqual(len(backends), 1)
        self.assertEqual(backends[0].url, "http://b1")

    async def test_explicit_model_ids_used_over_backend_report(self):
        """Configured model_ids take precedence over what a backend reports."""
        cfg = ProxyConfig(
            backends=[BackendConfig(url="http://b1", model_ids=["only-this"])]
        )
        reg = BackendRegistry(cfg)

        # Simulate fetching models: should use config model_ids, not backend response
        state = reg.get_state("http://b1")
        mock_session = AsyncMock()
        reg._session = mock_session

        models = await reg._fetch_models(state, {})
        self.assertEqual(models, ["only-this"])

    async def test_last_poll_age_increases(self):
        """last_poll_age reflects time elapsed since last_poll_time."""
        reg = self._make_registry([{"url": "http://b1"}])
        state = reg.get_state("http://b1")
        state.last_poll_time = time.monotonic() - 10
        age = state.last_poll_age
        self.assertGreaterEqual(age, 9.9)

    async def test_no_backends_for_unknown_model(self):
        """Looking up an unknown model yields an empty list."""
        reg = self._make_registry()
        result = reg.get_backends_for_model("no-such-model")
        self.assertEqual(result, [])

    async def test_poll_all_concurrent(self):
        """poll_all runs backends concurrently; a slow backend doesn't block others."""
        cfg = ProxyConfig(
            backends=[BackendConfig(url=f"http://b{i}") for i in range(5)]
        )
        reg = BackendRegistry(cfg)

        in_flight = 0
        max_in_flight = 0

        async def fake_poll_one(state):
            nonlocal in_flight, max_in_flight
            in_flight += 1
            max_in_flight = max(max_in_flight, in_flight)
            try:
                # Suspend so that, if _poll_all is concurrent, other polls
                # start before this one finishes.
                await asyncio.sleep(0.05)
                async with reg._lock:
                    state.live = True
                    state.models = ["m1"]
            finally:
                in_flight -= 1

        with patch.object(reg, "_poll_one", side_effect=fake_poll_one):
            await reg._poll_all()

        # Observe overlap directly instead of comparing wall-clock time
        # against a threshold, which can be flaky on loaded machines.
        # Sequential polling would never exceed one poll in flight.
        self.assertGreater(max_in_flight, 1)
        self.assertEqual(len(reg.get_all_live_backends()), 5)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()