# llamacpp-ha/tests/test_scheduler.py
# (Python; the file viewer reported 304 lines, 11 KiB.)

import asyncio
import unittest
from llamacpp_ha.config import BackendConfig, ProxyConfig
from llamacpp_ha.policies import RoundRobinPolicy
from llamacpp_ha.queue import QueueEntry, RequestQueue
from llamacpp_ha.registry import BackendRegistry, BackendState
from llamacpp_ha.scheduler import Scheduler
from llamacpp_ha.session_store import SessionStore
from llamacpp_ha.slot_tracker import SlotTracker
def _make_state(url: str, models: list[str] | None = None) -> BackendState:
    """Return a live ``BackendState`` for *url* that advertises *models*.

    Args:
        url: Backend base URL; also used to build the state's ``BackendConfig``.
        models: Model ids the backend reports. ``None`` (the default) means
            ``["m1"]``. An explicit empty list is kept as-is rather than being
            silently replaced by the default.

    Returns:
        A ``BackendState`` with ``live=True``.
    """
    cfg = BackendConfig(url=url)
    if models is None:
        models = ["m1"]
    # Copy so the state never aliases a list the caller may keep mutating.
    return BackendState(config=cfg, live=True, models=list(models))
def _entry(**kwargs) -> QueueEntry:
    """Build a ``QueueEntry`` from *kwargs* and attach a fresh pending future.

    The future comes from ``asyncio.get_running_loop()``, so this helper must
    be called while an event loop is running (e.g. inside an async test);
    otherwise ``RuntimeError`` is raised.
    """
    queued = QueueEntry(**kwargs)
    pending = asyncio.get_running_loop().create_future()
    queued.future = pending
    return queued
class TestScheduler(unittest.IsolatedAsyncioTestCase):
    """Tests for ``Scheduler`` dispatch decisions.

    Each test calls ``Scheduler._dispatch_all()`` directly and then inspects
    the ``QueueEntry.future`` of the queued entries to see whether, and to
    which backend, they were dispatched.
    """

    def _make_scheduler(
        self,
        live_backends=None,
        model_affinity_sched_bonus: int = 0,
        queue_aging_equalization: float = 30.0,
    ):
        """Build a ``Scheduler`` with a fresh queue, registry, slots and sessions.

        Args:
            live_backends: ``BackendState`` objects to register. Each one is
                written directly into the registry and given 2 slots.
            model_affinity_sched_bonus: Forwarded to ``Scheduler``.
            queue_aging_equalization: Forwarded to ``Scheduler``.

        Returns:
            ``(scheduler, queue, registry, slot_tracker, session_store)``.
        """
        # Materialize once: the backends are iterated three times below.
        backends = list(live_backends or [])
        cfg = ProxyConfig(backends=[BackendConfig(url=b.url) for b in backends])
        registry = BackendRegistry(cfg)
        for state in backends:
            # Inject the prepared state via the private table, then reindex.
            registry._states[state.url] = state
        registry._rebuild_index()
        slot_tracker = SlotTracker()
        for state in backends:
            slot_tracker.set_capacity(state.url, 2)
        session_store = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(
            queue=queue,
            registry=registry,
            slot_tracker=slot_tracker,
            session_store=session_store,
            policy=RoundRobinPolicy(),
            model_affinity_sched_bonus=model_affinity_sched_bonus,
            queue_aging_equalization=queue_aging_equalization,
        )
        return scheduler, queue, registry, slot_tracker, session_store

    @staticmethod
    async def _occupy_slot(slots: SlotTracker, url: str, *model: str) -> None:
        """Acquire one slot on *url*, optionally tagged with a model id.

        Wrapped in a 1 s ``asyncio.timeout`` so that an unexpectedly full
        backend fails the test with ``TimeoutError`` instead of waiting
        indefinitely (presumably ``acquire`` blocks when no slot is free).
        """
        async with asyncio.timeout(1.0):
            await slots.acquire(url, *model)

    async def test_dispatches_entry_to_live_backend(self):
        """A queued entry is resolved with the only live backend."""
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_skips_full_backends(self):
        """An entry stays pending while its only backend has no free slot."""
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 1)
        await self._occupy_slot(slots, "http://b1")
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done())

    async def test_session_affinity_preferred(self):
        """A session's preferred backend wins over the round-robin choice."""
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, _, _, sessions = self._make_scheduler([b1, b2])
        await sessions.get_or_create("sess1")
        await sessions.update("sess1", preferred_backend="http://b2")
        entry = _entry(model_id="m1", session_id="sess1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b2")

    async def test_session_affinity_fallback_when_full(self):
        """If the preferred backend is full, another live backend is used."""
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, _, slots, sessions = self._make_scheduler([b1, b2])
        slots.set_capacity("http://b2", 1)
        await self._occupy_slot(slots, "http://b2")
        await sessions.get_or_create("sess1")
        await sessions.update("sess1", preferred_backend="http://b2")
        entry = _entry(model_id="m1", session_id="sess1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_no_live_backends_entry_stays(self):
        """With no backends registered the entry remains queued."""
        scheduler, queue, *_ = self._make_scheduler([])
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done())
        self.assertEqual(await queue.depth(), 1)

    def test_notify_slot_released_wakes_queue(self):
        """``notify_slot_released`` sets the queue's wakeup event."""
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])
        scheduler.notify_slot_released()
        self.assertTrue(queue.wakeup_event.is_set())

    async def test_cancelled_future_cleaned_up(self):
        """An entry whose future was cancelled is removed from the queue."""
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        entry.future.cancel()
        await scheduler._dispatch_all()
        self.assertEqual(await queue.depth(), 0)

    async def test_round_robin_across_backends(self):
        """Four sequential dispatches are split evenly (2/2) across two backends."""
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, *_ = self._make_scheduler([b1, b2])
        results = []
        for _ in range(4):
            e = _entry(model_id="m1")
            await queue.enqueue(e)
            await scheduler._dispatch_all()
            # Assert first so a missed dispatch fails clearly instead of
            # raising InvalidStateError from result().
            self.assertTrue(e.future.done())
            results.append(e.future.result().url)
        self.assertEqual(results.count("http://b1"), 2)
        self.assertEqual(results.count("http://b2"), 2)

    async def test_slot_released_then_dispatch(self):
        """A pending entry dispatches once a slot is released and signalled."""
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 1)
        await self._occupy_slot(slots, "http://b1")
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done())
        await slots.release("http://b1")
        scheduler.notify_slot_released()
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())

    # ------------------------------------------------------------------
    # max_models / preemption prevention
    # ------------------------------------------------------------------

    async def test_max_models_blocks_second_model_on_same_backend(self):
        """With max_models=1, a second model waits until the first one drains."""
        b1 = _make_state("http://b1", models=["m1", "m2"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 4)
        slots.set_max_models("http://b1", 1)
        # m1 holds one of four slots, so capacity alone would still admit m2.
        await self._occupy_slot(slots, "http://b1", "m1")
        entry = _entry(model_id="m2")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done(), "m2 should be blocked by max_models=1")
        # Positive control: once m1's slot is released the same entry must
        # dispatch. This proves the block above came from max_models and not
        # from m2 being unroutable.
        await slots.release("http://b1", "m1")
        scheduler.notify_slot_released()
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done(), "m2 should dispatch once m1 has drained")
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_max_models_allows_same_model(self):
        """max_models=1 does not block more requests for the already-active model."""
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 4)
        slots.set_max_models("http://b1", 1)
        await self._occupy_slot(slots, "http://b1", "m1")
        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done(), "same model should still be dispatchable")

    # ------------------------------------------------------------------
    # Priority scheduling
    # ------------------------------------------------------------------

    async def test_pure_fifo_when_bonus_zero(self):
        """Default (bonus=0): blocked head-of-queue does not prevent later entries."""
        b1 = _make_state("http://b1", models=["m1"])
        b2 = _make_state("http://b2", models=["m2"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2])
        slots.set_capacity("http://b1", 1)
        slots.set_capacity("http://b2", 1)
        await self._occupy_slot(slots, "http://b1", "m1")
        await self._occupy_slot(slots, "http://b2", "m2")
        e_m1 = _entry(model_id="m1")
        e_m2 = _entry(model_id="m2")
        await queue.enqueue(e_m1)
        await queue.enqueue(e_m2)
        # Free only b2: the m1 entry at the head stays blocked, m2 behind it can go.
        await slots.release("http://b2", "m2")
        await scheduler._dispatch_all()
        self.assertFalse(e_m1.future.done())
        self.assertTrue(e_m2.future.done())

    async def test_warm_model_gets_priority(self):
        """With bonus>0, a warm-model request is dispatched before a cold one."""
        b1 = _make_state("http://b1", models=["m1", "m2"])
        scheduler, queue, _, slots, _ = self._make_scheduler(
            [b1], model_affinity_sched_bonus=10
        )
        # capacity=2: one slot occupied by m1, one free, so only one more dispatch.
        slots.set_capacity("http://b1", 2)
        # m1 is in flight on b1 (warm); m2 is cold.
        await self._occupy_slot(slots, "http://b1", "m1")
        # Queue order: the cold m2 request arrives first, the warm m1 request second.
        e_m2 = _entry(model_id="m2")
        e_m1 = _entry(model_id="m1")
        await queue.enqueue(e_m2)
        await queue.enqueue(e_m1)
        await scheduler._dispatch_all()
        # m1's warm bonus outranks arrival order, so it takes the free slot.
        self.assertTrue(e_m1.future.done())
        self.assertFalse(e_m2.future.done())

    async def test_aging_overtakes_warm_bonus(self):
        """After equalization time, an aged cold request outranks the warm bonus."""
        b1 = _make_state("http://b1", models=["m1", "m2"])
        # equalization=0.1s keeps the required wait short enough for a unit test.
        scheduler, queue, _, slots, _ = self._make_scheduler(
            [b1], model_affinity_sched_bonus=10, queue_aging_equalization=0.1
        )
        # capacity=2: one slot occupied by m1, one free, so only one more dispatch.
        slots.set_capacity("http://b1", 2)
        await self._occupy_slot(slots, "http://b1", "m1")
        # The m2 request arrives and waits past the equalization time.
        e_m2 = _entry(model_id="m2")
        await queue.enqueue(e_m2)
        await asyncio.sleep(0.15)  # longer than equalization=0.1s
        # The warm m1 request arrives only after m2 has aged.
        e_m1 = _entry(model_id="m1")
        await queue.enqueue(e_m1)
        await scheduler._dispatch_all()
        # m2's accumulated age now outweighs m1's warm bonus.
        self.assertTrue(e_m2.future.done())
        self.assertFalse(e_m1.future.done())
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()