321 lines
12 KiB
Python
321 lines
12 KiB
Python
import asyncio
|
|
import unittest
|
|
|
|
from llamacpp_ha.config import BackendConfig, ProxyConfig
|
|
from llamacpp_ha.policies import RoundRobinPolicy
|
|
from llamacpp_ha.queue import QueueEntry, RequestQueue
|
|
from llamacpp_ha.registry import BackendRegistry, BackendState
|
|
from llamacpp_ha.scheduler import Scheduler
|
|
from llamacpp_ha.session_store import SessionStore
|
|
from llamacpp_ha.slot_tracker import SlotTracker
|
|
|
|
|
|
def _make_state(url: str, models: list[str] | None = None) -> BackendState:
    """Return a live ``BackendState`` for *url* that serves *models*.

    Args:
        url: Backend base URL; also used to build the ``BackendConfig``.
        models: Model ids the backend reports. Defaults to ``["m1"]`` only
            when omitted (``None``). An explicit empty list is honoured
            instead of being silently replaced. The list is copied so the
            state never aliases the caller's list.

    Returns:
        A ``BackendState`` with ``live=True``.
    """
    cfg = BackendConfig(url=url)
    # `models or [...]` would also replace [] — only substitute when omitted.
    served = ["m1"] if models is None else list(models)
    return BackendState(config=cfg, live=True, models=served)
|
|
|
|
|
|
def _entry(**kwargs) -> QueueEntry:
    """Create a ``QueueEntry`` from *kwargs* with a fresh pending future.

    The future is created on the currently running event loop, so this must
    be called from inside a coroutine (e.g. an async test method).
    """
    queued = QueueEntry(**kwargs)
    running_loop = asyncio.get_running_loop()
    queued.future = running_loop.create_future()
    return queued
|
|
|
|
|
|
class TestScheduler(unittest.IsolatedAsyncioTestCase):
    """Tests for ``Scheduler`` dispatch decisions.

    Each test calls ``Scheduler._dispatch_all()`` directly so dispatching
    happens at a well-defined point inside the test body.
    """

    def _make_scheduler(self, live_backends=None, max_queue_skip: int = 0):
        """Wire a ``Scheduler`` over fresh components for *live_backends*.

        Every state in *live_backends* is injected into the registry's
        internal ``_states`` map (followed by ``_rebuild_index()``) and given
        a slot capacity of 2 in the ``SlotTracker``.

        Returns:
            ``(scheduler, queue, registry, slot_tracker, session_store)``.
        """
        backends = list(live_backends or [])

        cfg = ProxyConfig(backends=[BackendConfig(url=b.url) for b in backends])
        registry = BackendRegistry(cfg)
        for state in backends:
            registry._states[state.url] = state
        registry._rebuild_index()

        slot_tracker = SlotTracker()
        for state in backends:
            slot_tracker.set_capacity(state.url, 2)

        session_store = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(
            queue=queue,
            registry=registry,
            slot_tracker=slot_tracker,
            session_store=session_store,
            policy=RoundRobinPolicy(),
            max_queue_skip=max_queue_skip,
        )
        return scheduler, queue, registry, slot_tracker, session_store

    async def test_dispatches_entry_to_live_backend(self):
        """A queued request is resolved with the only live backend."""
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_skips_full_backends(self):
        """A backend with no free slots receives no new requests."""
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertFalse(entry.future.done())

    async def test_session_affinity_preferred(self):
        """A session's preferred backend wins when it has capacity."""
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, _, _, sessions = self._make_scheduler([b1, b2])

        await sessions.get_or_create("sess1")
        await sessions.update("sess1", preferred_backend="http://b2")

        entry = _entry(model_id="m1", session_id="sess1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b2")

    async def test_session_affinity_fallback_when_full(self):
        """If the preferred backend is full, another live backend is used."""
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, _, slots, sessions = self._make_scheduler([b1, b2])
        slots.set_capacity("http://b2", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b2")

        await sessions.get_or_create("sess1")
        await sessions.update("sess1", preferred_backend="http://b2")

        entry = _entry(model_id="m1", session_id="sess1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_no_live_backends_entry_stays(self):
        """With no backends the entry remains queued and unresolved."""
        scheduler, queue, *_ = self._make_scheduler([])

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertFalse(entry.future.done())
        self.assertEqual(await queue.depth(), 1)

    def test_notify_slot_released_wakes_queue(self):
        """``notify_slot_released`` sets the queue's wakeup event."""
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])
        scheduler.notify_slot_released()
        self.assertTrue(queue.wakeup_event.is_set())

    async def test_cancelled_future_cleaned_up(self):
        """Entries whose future was cancelled are dropped from the queue."""
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        entry.future.cancel()

        await scheduler._dispatch_all()

        self.assertEqual(await queue.depth(), 0)

    async def test_round_robin_across_backends(self):
        """Four requests over two backends split evenly (2 slots each)."""
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, *_ = self._make_scheduler([b1, b2])

        results = []
        for _ in range(4):
            e = _entry(model_id="m1")
            await queue.enqueue(e)
            await scheduler._dispatch_all()
            results.append(e.future.result().url)

        self.assertEqual(results.count("http://b1"), 2)
        self.assertEqual(results.count("http://b2"), 2)

    async def test_slot_released_then_dispatch(self):
        """A blocked entry dispatches once a slot is released."""
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done())

        await slots.release("http://b1")
        scheduler.notify_slot_released()
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())

    # ------------------------------------------------------------------
    # max_models / preemption prevention
    # ------------------------------------------------------------------

    async def test_max_models_blocks_second_model_on_same_backend(self):
        """With max_models=1, a second model is not routed to a busy backend."""
        b1 = _make_state("http://b1")
        # Only the queue and slot tracker are reused. A dedicated scheduler
        # is built below over a registry in which b1 serves both m1 and m2.
        _unused_scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 4)
        slots.set_max_models("http://b1", 1)

        # Occupy a slot with m1.
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")

        # A request for m2 must stay queued even though b1 has free slots.
        cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m1", "m2"])])
        registry = BackendRegistry(cfg)
        registry._states["http://b1"] = BackendState(
            config=BackendConfig(url="http://b1", model_ids=["m1", "m2"]),
            live=True,
            models=["m1", "m2"],
        )
        registry._rebuild_index()

        sched2 = Scheduler(
            queue=queue,
            registry=registry,
            slot_tracker=slots,
            session_store=SessionStore(),
            policy=RoundRobinPolicy(),
        )
        entry = _entry(model_id="m2")
        await queue.enqueue(entry)
        await sched2._dispatch_all()
        self.assertFalse(entry.future.done(), "model-b should be blocked by max_models=1")

    async def test_max_models_allows_same_model(self):
        """max_models=1 still admits more requests for the active model."""
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 4)
        slots.set_max_models("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done(), "same model should still be dispatchable")

    # ------------------------------------------------------------------
    # N-skip reordering
    # ------------------------------------------------------------------

    async def test_no_reorder_when_max_queue_skip_zero(self):
        """With max_queue_skip=0, bypassing a blocked head is not counted.

        The m2 request behind a blocked m1 head is still dispatched, because
        ``_dispatch_all`` scans every entry. However, the head's
        ``skip_count`` must stay 0.
        """
        b1 = _make_state("http://b1", models=["m1"])
        b2 = _make_state("http://b2", models=["m2"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2], max_queue_skip=0)
        slots.set_capacity("http://b1", 1)
        slots.set_capacity("http://b2", 1)

        # Fill both single-slot backends: b1 with m1, b2 with m2.
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b2", "m2")

        # Queue: [m1-request (blocked), m2-request (could go to b2)]
        e_m1 = _entry(model_id="m1")
        e_m2 = _entry(model_id="m2")
        await queue.enqueue(e_m1)
        await queue.enqueue(e_m2)

        # Free b2's slot so the m2 request can be served; b1 stays full.
        await slots.release("http://b2", "m2")
        await scheduler._dispatch_all()

        # m2 can be dispatched even with max_queue_skip=0 because _dispatch_all
        # scans all entries (not strict head-of-line per model).
        self.assertFalse(e_m1.future.done())
        self.assertTrue(e_m2.future.done())
        # skip_count must NOT be bumped when max_queue_skip=0.
        self.assertEqual(e_m1.skip_count, 0)

    async def test_affinity_promotes_matching_model(self):
        """With max_queue_skip>0, a matching model gets promoted."""
        b1 = _make_state("http://b1", models=["m1"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1], max_queue_skip=3)
        slots.set_capacity("http://b1", 2)

        # b1 already has m1 in-flight.
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")

        # Queue: [m2-entry (no affinity), m1-entry (affinity match)]
        e_other = _entry(model_id="m2")  # no backend serves m2
        e_m1 = _entry(model_id="m1")
        await queue.enqueue(e_other)
        await queue.enqueue(e_m1)

        await scheduler._dispatch_all()

        # The m1 entry is promoted (affinity pass); m2 stays (no backend).
        self.assertTrue(e_m1.future.done())
        self.assertFalse(e_other.future.done())
        # e_other was bypassed once.
        self.assertEqual(e_other.skip_count, 1)

    async def test_affinity_skips_when_idle_backend_available(self):
        """Warm-model routing is bypassed when a completely idle backend exists."""
        b1 = _make_state("http://b1", models=["m1"])
        b2 = _make_state("http://b2", models=["m1"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2], max_queue_skip=3)

        # b2 is warm (m1 active, 1/2 slots used); b1 is completely idle (0/2).
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b2", "m1")

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertTrue(entry.future.done())
        # The affinity pass must not force the request onto the warm backend
        # (b2). Round-robin picks b1 first (index 0 in the registry), which is
        # correct: b1 is idle and should absorb the load.
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_skip_count_caps_reordering(self):
        """Once skip_count reaches max_queue_skip the entry freezes at head."""
        b1 = _make_state("http://b1", models=["m1"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1], max_queue_skip=2)
        slots.set_capacity("http://b1", 4)

        # b1 has m1 active.
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")

        e_other = _entry(model_id="m2")
        e_other.skip_count = 2  # already at limit — must not be bypassed
        e_m1 = _entry(model_id="m1")
        await queue.enqueue(e_other)
        await queue.enqueue(e_m1)

        await scheduler._dispatch_all()

        # The affinity pass stops at e_other (skip_count >= max_queue_skip),
        # so e_m1 is NOT promoted via affinity. Both get a chance in the FIFO
        # pass: e_other (m2) has no backend and stays, e_m1 is dispatched.
        self.assertTrue(e_m1.future.done())
        # skip_count must NOT increase further (the entry was frozen).
        self.assertEqual(e_other.skip_count, 2)
|
|
|
|
|
|
if __name__ == "__main__":
    # Only runs when this file is executed as a script, not when a test
    # runner imports it; "__main__" is unittest.main's default module.
    unittest.main(module="__main__")
|