Files
llamacpp-ha/tests/test_slot_tracker.py
2026-05-17 09:54:18 +02:00

188 lines
7.2 KiB
Python

import asyncio
import unittest
from llamacpp_ha.slot_tracker import SlotTracker
class TestSlotTracker(unittest.IsolatedAsyncioTestCase):
    """Tests for ``SlotTracker`` slot accounting and model-aware admission.

    Every test builds its own tracker, so no state is shared between tests.
    Blocking behaviour is exercised with ``asyncio.timeout`` and short sleeps.
    """

    # Backend URL used as the tracker key in (almost) every test.
    URL = "http://b"

    async def test_acquire_when_free(self):
        """One acquire on a two-slot backend is reported as 1 of 2 in use."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 2)
        await tracker.acquire(self.URL)
        acquired, total = tracker.usage(self.URL)
        self.assertEqual(acquired, 1)
        self.assertEqual(total, 2)

    async def test_has_free_slot(self):
        """``has_free_slot`` flips to False once the only slot is taken."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 1)
        self.assertTrue(tracker.has_free_slot(self.URL))
        await tracker.acquire(self.URL)
        self.assertFalse(tracker.has_free_slot(self.URL))

    async def test_timeout_when_full(self):
        """A second acquire on a full backend blocks until the timeout fires."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 1)
        await tracker.acquire(self.URL)
        with self.assertRaises(TimeoutError):
            async with asyncio.timeout(0.05):
                await tracker.acquire(self.URL)

    async def test_release_unblocks_waiter(self):
        """Releasing the only slot lets a blocked acquirer proceed."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 1)
        await tracker.acquire(self.URL)
        results = []

        async def waiter():
            # Generous timeout: only a genuinely stuck waiter should hit it.
            async with asyncio.timeout(2.0):
                await tracker.acquire(self.URL)
            results.append(True)

        task = asyncio.create_task(waiter())
        # Yield long enough for the waiter to start and park on the tracker.
        await asyncio.sleep(0.05)
        # The slot is still held, so the waiter must not have got through yet;
        # without this check the test would pass even if acquire never blocked.
        self.assertFalse(results)
        await tracker.release(self.URL)
        await task
        self.assertEqual(results, [True])

    async def test_release_below_zero(self):
        """Releasing with nothing acquired leaves the count at zero."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 1)
        await tracker.release(self.URL)
        acquired, _ = tracker.usage(self.URL)
        self.assertEqual(acquired, 0)

    def test_set_capacity_increase(self):
        """A later ``set_capacity`` call replaces the earlier total."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 1)
        tracker.set_capacity(self.URL, 3)
        _, total = tracker.usage(self.URL)
        self.assertEqual(total, 3)

    def test_unknown_url_defaults(self):
        """An unconfigured URL reports a free slot and usage of 0 of 1."""
        tracker = SlotTracker()
        self.assertTrue(tracker.has_free_slot("http://unknown"))
        acquired, total = tracker.usage("http://unknown")
        self.assertEqual(acquired, 0)
        self.assertEqual(total, 1)

    async def test_acquire_zero_timeout_succeeds_then_fails(self):
        """With a zero timeout, acquire succeeds when free and fails when full."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 1)
        # Free slot: expected to complete even under a zero timeout.
        async with asyncio.timeout(0):
            await tracker.acquire(self.URL)
        # Full tracker: the same call must now time out.
        with self.assertRaises(TimeoutError):
            async with asyncio.timeout(0):
                await tracker.acquire(self.URL)

    async def test_release_decrements(self):
        """Each release lowers the acquired count by one."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 2)
        await tracker.acquire(self.URL)
        await tracker.acquire(self.URL)
        acquired, _ = tracker.usage(self.URL)
        self.assertEqual(acquired, 2)
        await tracker.release(self.URL)
        acquired, _ = tracker.usage(self.URL)
        self.assertEqual(acquired, 1)

    # ------------------------------------------------------------------
    # Model-aware tests
    # ------------------------------------------------------------------

    def test_can_accept_respects_max_models(self):
        """An idle backend with ``max_models=1`` accepts a first model."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 4)
        tracker.set_max_models(self.URL, 1)
        self.assertTrue(tracker.can_accept(self.URL, "model-a"))

    async def test_max_models_blocks_second_model(self):
        """With ``max_models=1`` an active model keeps other models out."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 4)
        tracker.set_max_models(self.URL, 1)
        await tracker.acquire(self.URL, "model-a")
        # model-a is still accepted (same model, slot available)
        self.assertTrue(tracker.can_accept(self.URL, "model-a"))
        # model-b is blocked (max_models=1 already reached)
        self.assertFalse(tracker.can_accept(self.URL, "model-b"))

    async def test_max_models_unblocks_after_release(self):
        """Releasing the active model lets a different model be accepted."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 4)
        tracker.set_max_models(self.URL, 1)
        await tracker.acquire(self.URL, "model-a")
        self.assertFalse(tracker.can_accept(self.URL, "model-b"))
        await tracker.release(self.URL, "model-a")
        self.assertTrue(tracker.can_accept(self.URL, "model-b"))

    async def test_active_model_set(self):
        """``active_model_set`` follows acquires and releases per model."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 4)
        self.assertEqual(tracker.active_model_set(self.URL), frozenset())
        await tracker.acquire(self.URL, "model-a")
        self.assertEqual(
            tracker.active_model_set(self.URL), frozenset({"model-a"})
        )
        await tracker.acquire(self.URL, "model-b")
        self.assertEqual(
            tracker.active_model_set(self.URL),
            frozenset({"model-a", "model-b"}),
        )
        await tracker.release(self.URL, "model-a")
        self.assertEqual(
            tracker.active_model_set(self.URL), frozenset({"model-b"})
        )

    async def test_acquire_tracks_active_models(self):
        """A model stays active until every slot it holds is released."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 4)
        await tracker.acquire(self.URL, "model-a")
        await tracker.acquire(self.URL, "model-a")
        acquired, _ = tracker.usage(self.URL)
        self.assertEqual(acquired, 2)
        self.assertEqual(
            tracker.active_model_set(self.URL), frozenset({"model-a"})
        )
        # First release: one slot is still held, so the model remains active.
        await tracker.release(self.URL, "model-a")
        self.assertEqual(
            tracker.active_model_set(self.URL), frozenset({"model-a"})
        )
        # Second release: nothing left, so the model drops out of the set.
        await tracker.release(self.URL, "model-a")
        self.assertEqual(tracker.active_model_set(self.URL), frozenset())

    async def test_reset_acquired_clears_state(self):
        """``reset_acquired`` zeroes the count and empties the active models."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 2)
        await tracker.acquire(self.URL, "model-a")
        await tracker.acquire(self.URL, "model-a")
        acquired, _ = tracker.usage(self.URL)
        self.assertEqual(acquired, 2)
        await tracker.reset_acquired(self.URL)
        acquired, _ = tracker.usage(self.URL)
        self.assertEqual(acquired, 0)
        self.assertEqual(tracker.active_model_set(self.URL), frozenset())

    async def test_reset_acquired_unblocks_waiters(self):
        """``reset_acquired`` wakes an acquirer that was waiting for a slot."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 1)
        tracker.set_max_models(self.URL, 1)
        await tracker.acquire(self.URL, "model-a")
        unblocked = []

        async def waiter():
            # Generous timeout: only a genuinely stuck waiter should hit it.
            async with asyncio.timeout(2.0):
                await tracker.acquire(self.URL, "model-b")
            unblocked.append(True)

        task = asyncio.create_task(waiter())
        # Yield long enough for the waiter to start and park on the tracker.
        await asyncio.sleep(0.05)
        self.assertFalse(unblocked)
        await tracker.reset_acquired(self.URL)
        await task
        self.assertEqual(unblocked, [True])

    async def test_max_models_none_allows_any(self):
        """``max_models=None`` puts no limit on distinct models."""
        tracker = SlotTracker()
        tracker.set_capacity(self.URL, 4)
        tracker.set_max_models(self.URL, None)
        await tracker.acquire(self.URL, "model-a")
        self.assertTrue(tracker.can_accept(self.URL, "model-b"))
        self.assertTrue(tracker.can_accept(self.URL, "model-c"))
# Allow running this test module directly as a script; when it is imported
# by a test runner the guard is false and nothing happens here.
if __name__ == "__main__":
    unittest.main()