304 lines
11 KiB
Python
304 lines
11 KiB
Python
import asyncio
|
|
import unittest
|
|
|
|
from llamacpp_ha.config import BackendConfig, ProxyConfig
|
|
from llamacpp_ha.policies import RoundRobinPolicy
|
|
from llamacpp_ha.queue import QueueEntry, RequestQueue
|
|
from llamacpp_ha.registry import BackendRegistry, BackendState
|
|
from llamacpp_ha.scheduler import Scheduler
|
|
from llamacpp_ha.session_store import SessionStore
|
|
from llamacpp_ha.slot_tracker import SlotTracker
|
|
|
|
|
|
def _make_state(url: str, models: list[str] | None = None) -> BackendState:
    """Build a live ``BackendState`` for *url* with a default ``BackendConfig``.

    Args:
        url: Backend base URL, used for the ``BackendConfig``.
        models: Model ids the backend advertises. ``None`` (the default)
            means ``["m1"]``. An explicit empty list is kept as-is, so tests
            can describe a live backend that serves no models; the previous
            ``models or [...]`` form silently replaced ``[]`` with ``["m1"]``.

    Returns:
        A ``BackendState`` with ``live=True``.
    """
    cfg = BackendConfig(url=url)
    if models is None:
        models = ["m1"]
    return BackendState(config=cfg, live=True, models=models)
|
|
|
|
|
|
def _entry(**kwargs) -> QueueEntry:
    """Return a ``QueueEntry`` built from *kwargs* with a pending future attached.

    The future is created on the currently running event loop, so this helper
    must be called from inside a coroutine (e.g. an async test method);
    ``asyncio.get_running_loop`` raises ``RuntimeError`` otherwise.
    """
    new_entry = QueueEntry(**kwargs)
    running_loop = asyncio.get_running_loop()
    new_entry.future = running_loop.create_future()
    return new_entry
|
|
|
|
|
|
class TestScheduler(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``Scheduler`` dispatch, affinity, and priority behaviour.

    Each test builds an isolated scheduler with ``_make_scheduler``. It then
    runs a single pass of ``Scheduler._dispatch_all`` directly, so outcomes
    can be asserted right away without a background loop.
    """

    def _make_scheduler(
        self,
        live_backends=None,
        model_affinity_sched_bonus: int = 0,
        queue_aging_equalization: float = 30.0,
    ):
        """Wire a ``Scheduler`` to fresh queue/registry/slot/session components.

        Args:
            live_backends: ``BackendState`` objects to inject into the
                registry. ``None`` or empty means no backends.
            model_affinity_sched_bonus: Forwarded to ``Scheduler``.
            queue_aging_equalization: Forwarded to ``Scheduler``.

        Returns:
            ``(scheduler, queue, registry, slot_tracker, session_store)``.
            Every injected backend starts with a slot capacity of 2.
        """
        # Normalise once, so the config, the registry and the slot tracker all
        # see the same backends, even if a one-shot iterable was passed.
        backends = list(live_backends or [])

        cfg = ProxyConfig(backends=[BackendConfig(url=b.url) for b in backends])
        registry = BackendRegistry(cfg)
        # Overwrite the registry's internal states with the supplied ones and
        # rebuild its index. Tests then control liveness and model lists directly.
        for state in backends:
            registry._states[state.url] = state
        registry._rebuild_index()

        slot_tracker = SlotTracker()
        for state in backends:
            slot_tracker.set_capacity(state.url, 2)

        session_store = SessionStore()
        queue = RequestQueue()
        scheduler = Scheduler(
            queue=queue,
            registry=registry,
            slot_tracker=slot_tracker,
            session_store=session_store,
            policy=RoundRobinPolicy(),
            model_affinity_sched_bonus=model_affinity_sched_bonus,
            queue_aging_equalization=queue_aging_equalization,
        )
        return scheduler, queue, registry, slot_tracker, session_store

    async def test_dispatches_entry_to_live_backend(self):
        """A queued request is resolved with the only live backend."""
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_skips_full_backends(self):
        """A backend with no free slot is not used; the entry stays unresolved."""
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        # Shrink b1 to one slot and occupy it. The timeout bounds the await,
        # so a misbehaving tracker fails the test instead of hanging it.
        slots.set_capacity("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertFalse(entry.future.done())

    async def test_session_affinity_preferred(self):
        """A request is routed to its session's preferred backend (b2)."""
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, _, _, sessions = self._make_scheduler([b1, b2])

        await sessions.get_or_create("sess1")
        await sessions.update("sess1", preferred_backend="http://b2")

        entry = _entry(model_id="m1", session_id="sess1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b2")

    async def test_session_affinity_fallback_when_full(self):
        """If the preferred backend (b2) is full, the request goes to b1."""
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, _, slots, sessions = self._make_scheduler([b1, b2])
        # Make the preferred backend unavailable by filling its single slot.
        slots.set_capacity("http://b2", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b2")

        await sessions.get_or_create("sess1")
        await sessions.update("sess1", preferred_backend="http://b2")

        entry = _entry(model_id="m1", session_id="sess1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertTrue(entry.future.done())
        self.assertEqual(entry.future.result().url, "http://b1")

    async def test_no_live_backends_entry_stays(self):
        """With no backends at all, the entry stays queued and unresolved."""
        scheduler, queue, *_ = self._make_scheduler([])

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()

        self.assertFalse(entry.future.done())
        self.assertEqual(await queue.depth(), 1)

    def test_notify_slot_released_wakes_queue(self):
        """``notify_slot_released`` sets the queue's wakeup event."""
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])
        scheduler.notify_slot_released()
        self.assertTrue(queue.wakeup_event.is_set())

    async def test_cancelled_future_cleaned_up(self):
        """An entry whose future was cancelled no longer counts toward queue depth."""
        b1 = _make_state("http://b1")
        scheduler, queue, *_ = self._make_scheduler([b1])

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        # Cancel while still queued, as a client that gave up would.
        entry.future.cancel()

        await scheduler._dispatch_all()

        self.assertEqual(await queue.depth(), 0)

    async def test_round_robin_across_backends(self):
        """Four one-at-a-time dispatches split evenly between two backends."""
        b1 = _make_state("http://b1")
        b2 = _make_state("http://b2")
        scheduler, queue, *_ = self._make_scheduler([b1, b2])

        results = []
        for _ in range(4):
            e = _entry(model_id="m1")
            await queue.enqueue(e)
            await scheduler._dispatch_all()
            results.append(e.future.result().url)

        self.assertEqual(results.count("http://b1"), 2)
        self.assertEqual(results.count("http://b2"), 2)

    async def test_slot_released_then_dispatch(self):
        """A blocked entry is dispatched after its backend releases a slot."""
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1")

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertFalse(entry.future.done())

        await slots.release("http://b1")
        scheduler.notify_slot_released()
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done())

    # ------------------------------------------------------------------
    # max_models / preemption prevention
    # ------------------------------------------------------------------

    async def test_max_models_blocks_second_model_on_same_backend(self):
        """With max_models=1, a second model (m2) cannot share b1 with m1."""
        b1 = _make_state("http://b1")
        # Only the queue and slot tracker are reused; a dedicated scheduler
        # is built below over a registry that advertises both models.
        _, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 4)
        slots.set_max_models("http://b1", 1)

        # Occupy a slot on b1 with m1.
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")

        # The helper's default state advertises only m1. Build a registry in
        # which b1 serves both m1 and m2, so m2 is routable to b1 and can only
        # be held back by the max_models limit.
        cfg = ProxyConfig(backends=[BackendConfig(url="http://b1", model_ids=["m1", "m2"])])
        registry = BackendRegistry(cfg)
        registry._states["http://b1"] = BackendState(
            config=BackendConfig(url="http://b1", model_ids=["m1", "m2"]),
            live=True,
            models=["m1", "m2"],
        )
        registry._rebuild_index()

        sched2 = Scheduler(
            queue=queue,
            registry=registry,
            slot_tracker=slots,
            session_store=SessionStore(),
            policy=RoundRobinPolicy(),
        )
        # A request for m2 must stay queued despite free capacity on b1.
        entry = _entry(model_id="m2")
        await queue.enqueue(entry)
        await sched2._dispatch_all()
        self.assertFalse(entry.future.done(), "model-b should be blocked by max_models=1")

    async def test_max_models_allows_same_model(self):
        """With max_models=1, more requests for the resident model still dispatch."""
        b1 = _make_state("http://b1")
        scheduler, queue, _, slots, _ = self._make_scheduler([b1])
        slots.set_capacity("http://b1", 4)
        slots.set_max_models("http://b1", 1)
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")

        entry = _entry(model_id="m1")
        await queue.enqueue(entry)
        await scheduler._dispatch_all()
        self.assertTrue(entry.future.done(), "same model should still be dispatchable")

    # ------------------------------------------------------------------
    # Priority scheduling
    # ------------------------------------------------------------------

    async def test_pure_fifo_when_bonus_zero(self):
        """Default (bonus=0): blocked head-of-queue does not prevent later entries."""
        b1 = _make_state("http://b1", models=["m1"])
        b2 = _make_state("http://b2", models=["m2"])
        scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2])
        slots.set_capacity("http://b1", 1)
        slots.set_capacity("http://b2", 1)

        # Fill both single-slot backends.
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b2", "m2")

        e_m1 = _entry(model_id="m1")
        e_m2 = _entry(model_id="m2")
        await queue.enqueue(e_m1)
        await queue.enqueue(e_m2)

        # Free only b2. The m1 entry at the head is still blocked on b1.
        await slots.release("http://b2", "m2")
        await scheduler._dispatch_all()

        self.assertFalse(e_m1.future.done())
        self.assertTrue(e_m2.future.done())

    async def test_warm_model_gets_priority(self):
        """With bonus>0, a warm-model request is dispatched before a cold one."""
        b1 = _make_state("http://b1", models=["m1", "m2"])
        scheduler, queue, _, slots, _ = self._make_scheduler(
            [b1], model_affinity_sched_bonus=10
        )
        # capacity=2: one slot occupied by m1, one free, so only one more
        # request can be dispatched.
        slots.set_capacity("http://b1", 2)

        # m1 is in flight on b1 (warm); m2 is cold.
        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")

        # Queue order: the cold m2 request arrives first, the warm m1 request second.
        e_m2 = _entry(model_id="m2")
        e_m1 = _entry(model_id="m1")
        await queue.enqueue(e_m2)
        await queue.enqueue(e_m1)

        await scheduler._dispatch_all()

        # m1 has the warm bonus, so it wins the single free slot despite arriving later.
        self.assertTrue(e_m1.future.done())
        self.assertFalse(e_m2.future.done())

    async def test_aging_overtakes_warm_bonus(self):
        """After the equalization time, an aged cold request outranks the warm bonus."""
        b1 = _make_state("http://b1", models=["m1", "m2"])
        # equalization=0.1s, so a short real sleep (below) is enough for aging
        # to overtake the bonus. The test does wait on the wall clock.
        scheduler, queue, _, slots, _ = self._make_scheduler(
            [b1], model_affinity_sched_bonus=10, queue_aging_equalization=0.1
        )
        # capacity=2: one slot occupied by m1, one free, so only one more
        # request can be dispatched.
        slots.set_capacity("http://b1", 2)

        async with asyncio.timeout(1.0):
            await slots.acquire("http://b1", "m1")

        # The m2 request arrives and waits long enough for its age to exceed the bonus.
        e_m2 = _entry(model_id="m2")
        await queue.enqueue(e_m2)
        await asyncio.sleep(0.15)  # 1.5x the equalization window

        # The warm m1 request arrives after m2 has already aged past equalization.
        e_m1 = _entry(model_id="m1")
        await queue.enqueue(e_m1)

        await scheduler._dispatch_all()

        # m2's age score now exceeds the warm bonus, so m2 takes the free slot.
        self.assertTrue(e_m2.future.done())
        self.assertFalse(e_m1.future.done())
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this module directly as a script discovers and runs the tests
    # defined above; "__main__" is also unittest.main's default module.
    unittest.main(module="__main__")
|