Files
llamacpp-ha/tests/test_scheduler.py

321 lines
12 KiB
Python

import asyncio
import unittest
from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.policies import RoundRobinPolicy
from llamacpp_ha.queue import QueueEntry, RequestQueue
from llamacpp_ha.registry import BackendRegistry, BackendState
from llamacpp_ha.scheduler import Scheduler
from llamacpp_ha.session_store import SessionStore
from llamacpp_ha.slot_tracker import SlotTracker
def _make_state(url: str, models: list[str] | None = None) -> BackendState:
    """Build a live ``BackendState`` for *url* that advertises *models*.

    Args:
        url: Backend base URL; also used to build the state's ``BackendConfig``.
        models: Model ids the backend reports. ``None`` (the default) means
            ``["m1"]``. An explicit empty list is kept as-is rather than being
            silently replaced by the default.

    Returns:
        A ``BackendState`` with ``live=True``.
    """
    cfg = BackendConfig(url=url)
    # ``is None`` instead of ``or``: an explicit ``[]`` must not become ["m1"].
    backend_models = ["m1"] if models is None else models
    return BackendState(config=cfg, live=True, models=backend_models)
def _entry(**kwargs) -> QueueEntry:
    """Return a ``QueueEntry`` built from *kwargs* with a fresh pending future.

    Must be called while an event loop is running: the future is created on
    the loop returned by ``asyncio.get_running_loop()``.
    """
    new_entry = QueueEntry(**kwargs)
    pending = asyncio.get_running_loop().create_future()
    new_entry.future = pending
    return new_entry
class TestScheduler(unittest.IsolatedAsyncioTestCase):
    """Tests for ``Scheduler`` dispatch, session affinity and queue-skip logic.

    Each test calls ``Scheduler._dispatch_all()`` directly so that a single
    dispatch round can be inspected deterministically.
    """

    def _make_scheduler(self, live_backends=None, max_queue_skip: int = 0):
        """Wire a ``Scheduler`` to fresh collaborators for *live_backends*.

        Every given backend state is injected into the registry and gets a
        slot capacity of 2.

        Returns:
            ``(scheduler, queue, registry, slot_tracker, session_store)``.
        """
        backends = live_backends or []
        cfg = ProxyConfig(backends=[BackendConfig(url=b.url) for b in backends])
        registry = BackendRegistry(cfg)
        # Inject the prepared states into the registry's private state map,
        # then rebuild its index so the scheduler can see them.
        for state in backends:
            registry._states[state.url] = state
        registry._rebuild_index()
        slot_tracker = SlotTracker()
        for state in backends:
            slot_tracker.set_capacity(state.url, 2)
        session_store = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(
            queue=queue,
            registry=registry,
            slot_tracker=slot_tracker,
            session_store=session_store,
            policy=RoundRobinPolicy(),
            max_queue_skip=max_queue_skip,
        )
        return scheduler, queue, registry, slot_tracker, session_store

    async def test_dispatches_entry_to_live_backend(self):
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_skips_full_backends(self):
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        # Capacity 1 with the only slot taken: b1 is full.
        slots.set_capacity("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done())

    async def test_session_affinity_preferred(self):
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, _, _, sessions = self._make_scheduler([b1, b2])
        await sessions.get_or_create("sess1")
        await sessions.update("sess1", preferred_backend="http://b2")
        entry = _entry(model_id="m1", session_id="sess1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b2")

    async def test_session_affinity_fallback_when_full(self):
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, _, slots, sessions = self._make_scheduler([b1, b2])
        # The session's preferred backend (b2) is full.
        slots.set_capacity("http://b2", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b2")
        await sessions.get_or_create("sess1")
        await sessions.update("sess1", preferred_backend="http://b2")
        entry = _entry(model_id="m1", session_id="sess1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_no_live_backends_entry_stays(self):
        scheduler, queue, *_ = self._make_scheduler([])
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done())
        self.assertEqual(await queue.depth(), 1)

    def test_notify_slot_released_wakes_queue(self):
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])
        scheduler.notify_slot_released()
        self.assertTrue(queue.wakeup_event.is_set())

    async def test_cancelled_future_cleaned_up(self):
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        entry.future.cancel()
        await scheduler._dispatch_all()
        self.assertEqual(await queue.depth(), 0)

    async def test_round_robin_across_backends(self):
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, *_ = self._make_scheduler([b1, b2])
        results = []
        for _ in range(4):
            e = _entry(model_id="m1")
            await queue.enqueue(e)
            await scheduler._dispatch_all()
            # Fail cleanly instead of raising InvalidStateError from result().
            self.assertTrue(e.future.done())
            results.append(e.future.result().url)
        self.assertEqual(results.count("http://b1"), 2)
        self.assertEqual(results.count("http://b2"), 2)

    async def test_slot_released_then_dispatch(self):
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done())
        await slots.release("http://b1")
        scheduler.notify_slot_released()
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())

    # ------------------------------------------------------------------
    # max_models / preemption prevention
    # ------------------------------------------------------------------

    async def test_max_models_blocks_second_model_on_same_backend(self):
        # This test needs a backend serving both m1 and m2, so it wires its
        # own queue, slot tracker and registry instead of _make_scheduler().
        queue = RequestQueue()
        slots = SlotTracker()
        slots.set_capacity("http://b1", 4)
        slots.set_max_models("http://b1", 1)
        # Occupy a slot with m1.
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")
        # A request for m2 must stay queued while m1 is in flight.
        cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m1", "m2"])])
        registry = BackendRegistry(cfg)
        registry._states["http://b1"] = BackendState(
            config=BackendConfig(url="http://b1", model_ids=["m1", "m2"]),
            live=True,
            models=["m1", "m2"],
        )
        registry._rebuild_index()
        scheduler = Scheduler(
            queue=queue,
            registry=registry,
            slot_tracker=slots,
            session_store=SessionStore(),
            policy=RoundRobinPolicy(),
        )
        entry = _entry(model_id="m2")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done(), "model-b should be blocked by max_models=1")

    async def test_max_models_allows_same_model(self):
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 4)
        slots.set_max_models("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done(), "same model should still be dispatchable")

    # ------------------------------------------------------------------
    # N-skip reordering
    # ------------------------------------------------------------------

    async def test_no_reorder_when_max_queue_skip_zero(self):
        """With max_queue_skip=0, no entry's skip_count is ever bumped.

        _dispatch_all still scans every queued entry (it is not strict
        head-of-line), so an m2 request that can be served is dispatched
        past a blocked m1 request. What must not happen is the skip_count
        bookkeeping used by N-skip reordering.
        """
        b1 = _make_state("http://b1", models=["m1"])
        b2 = _make_state("http://b2", models=["m2"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2], max_queue_skip=0)
        slots.set_capacity("http://b1", 1)
        slots.set_capacity("http://b2", 1)
        # Fill both single-slot backends.
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b2", "m2")
        # Queue: [m1-request (b1 full), m2-request]
        e_m1 = _entry(model_id="m1")
        e_m2 = _entry(model_id="m2")
        await queue.enqueue(e_m1)
        await queue.enqueue(e_m2)
        # Free b2 so the m2 request can be served while b1 stays full.
        await slots.release("http://b2", "m2")
        await scheduler._dispatch_all()
        self.assertFalse(e_m1.future.done())
        self.assertTrue(e_m2.future.done())
        # skip_count must NOT be bumped when max_queue_skip=0.
        self.assertEqual(e_m1.skip_count, 0)

    async def test_affinity_promotes_matching_model(self):
        """With max_queue_skip>0, a matching model gets promoted."""
        b1 = _make_state("http://b1", models=["m1"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1], max_queue_skip=3)
        slots.set_capacity("http://b1", 2)
        # b1 already has m1 in flight.
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")
        # Queue: [m2-entry (no affinity), m1-entry (affinity match)]
        e_other = _entry(model_id="m2")  # no backend serves m2
        e_m1 = _entry(model_id="m1")
        await queue.enqueue(e_other)
        await queue.enqueue(e_m1)
        await scheduler._dispatch_all()
        # m1 entry is promoted (affinity pass), m2 stays (no backend).
        self.assertTrue(e_m1.future.done())
        self.assertFalse(e_other.future.done())
        # e_other was bypassed once.
        self.assertEqual(e_other.skip_count, 1)

    async def test_affinity_skips_when_idle_backend_available(self):
        """Warm-model routing is bypassed when a completely idle backend exists."""
        b1 = _make_state("http://b1", models=["m1"])
        b2 = _make_state("http://b2", models=["m1"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2], max_queue_skip=3)
        # b2 is warm (m1 active, 1/2 slots used); b1 is completely idle (0/2).
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b2", "m1")
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        # The affinity pass must not force the request onto the warm backend
        # (b2). Round-robin picks b1 first (index 0 in the registry), which is
        # correct: b1 is idle and should absorb the load.
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_skip_count_caps_reordering(self):
        """Once skip_count reaches max_queue_skip the entry freezes at head."""
        b1 = _make_state("http://b1", models=["m1"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1], max_queue_skip=2)
        slots.set_capacity("http://b1", 4)
        # b1 has m1 active.
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")
        e_other = _entry(model_id="m2")
        e_other.skip_count = 2  # already at the limit; must not be bypassed
        e_m1 = _entry(model_id="m1")
        await queue.enqueue(e_other)
        await queue.enqueue(e_m1)
        await scheduler._dispatch_all()
        # The affinity pass stops at e_other (skip_count >= max_queue_skip), so
        # e_m1 is NOT promoted via affinity. Both get a chance in the FIFO
        # pass: e_other (m2) has no backend and stays; e_m1 is dispatched.
        self.assertTrue(e_m1.future.done())
        self.assertFalse(e_other.future.done())
        # skip_count must NOT increase further (the entry was frozen).
        self.assertEqual(e_other.skip_count, 2)
if __name__ == "__main__":
    # Allow running this test module directly with the stdlib runner.
    unittest.main(module="__main__")