188 lines
7.2 KiB
Python
188 lines
7.2 KiB
Python
import asyncio
|
|
import unittest
|
|
|
|
from llamacpp_ha.slot_tracker import SlotTracker
|
|
|
|
|
|
class TestSlotTracker(unittest.IsolatedAsyncioTestCase):
    """Exercise ``SlotTracker`` through its public calls.

    The first group of tests covers plain slot accounting (``set_capacity``,
    ``acquire``, ``release``, ``usage``, ``has_free_slot``).  The second group
    covers the model-aware calls (``set_max_models``, ``can_accept``,
    ``active_model_set``, ``reset_acquired``).
    """

    # Backend URL shared by every scenario that configures a backend.
    _URL = "http://b"

    def _new_tracker(self, slots):
        """Return a fresh tracker whose ``_URL`` backend has ``slots`` capacity."""
        fresh = SlotTracker()
        fresh.set_capacity(self._URL, slots)
        return fresh

    # One acquire on a two-slot backend is reported as usage (1, 2).
    async def test_acquire_when_free(self):
        tracker = self._new_tracker(2)
        await tracker.acquire(self._URL)
        in_use, capacity = tracker.usage(self._URL)
        self.assertEqual(in_use, 1)
        self.assertEqual(capacity, 2)

    # A single-slot backend reports a free slot only until it is acquired.
    async def test_has_free_slot(self):
        tracker = self._new_tracker(1)
        self.assertTrue(tracker.has_free_slot(self._URL))
        await tracker.acquire(self._URL)
        self.assertFalse(tracker.has_free_slot(self._URL))

    # A second acquire on a full backend is still waiting when the
    # surrounding timeout expires.
    async def test_timeout_when_full(self):
        tracker = self._new_tracker(1)
        await tracker.acquire(self._URL)
        with self.assertRaises(TimeoutError):
            async with asyncio.timeout(0.05):
                await tracker.acquire(self._URL)

    # Releasing the only slot lets a pending acquire complete.
    async def test_release_unblocks_waiter(self):
        tracker = self._new_tracker(1)
        await tracker.acquire(self._URL)

        woken = []

        async def blocked_acquire():
            async with asyncio.timeout(2.0):
                await tracker.acquire(self._URL)
            woken.append(True)

        pending = asyncio.create_task(blocked_acquire())
        # Give the background task time to start waiting before releasing.
        await asyncio.sleep(0.05)
        await tracker.release(self._URL)
        await pending
        self.assertEqual(woken, [True])

    # A release without a prior acquire leaves the acquired count at 0.
    async def test_release_below_zero(self):
        tracker = self._new_tracker(1)
        await tracker.release(self._URL)
        in_use, _ = tracker.usage(self._URL)
        self.assertEqual(in_use, 0)

    # A later set_capacity call replaces the earlier total.
    def test_set_capacity_increase(self):
        tracker = self._new_tracker(1)
        tracker.set_capacity(self._URL, 3)
        _, capacity = tracker.usage(self._URL)
        self.assertEqual(capacity, 3)

    # A URL that was never configured reports a free slot and usage (0, 1).
    def test_unknown_url_defaults(self):
        unknown = "http://unknown"
        tracker = SlotTracker()
        self.assertTrue(tracker.has_free_slot(unknown))
        in_use, capacity = tracker.usage(unknown)
        self.assertEqual(in_use, 0)
        self.assertEqual(capacity, 1)

    # With a zero timeout, acquiring a free slot succeeds while acquiring
    # a taken slot raises TimeoutError.
    async def test_acquire_zero_timeout_succeeds_then_fails(self):
        tracker = self._new_tracker(1)
        async with asyncio.timeout(0):
            await tracker.acquire(self._URL)
        with self.assertRaises(TimeoutError):
            async with asyncio.timeout(0):
                await tracker.acquire(self._URL)

    # Each release lowers the acquired count by one.
    async def test_release_decrements(self):
        tracker = self._new_tracker(2)
        for _ in range(2):
            await tracker.acquire(self._URL)
        in_use, _ = tracker.usage(self._URL)
        self.assertEqual(in_use, 2)
        await tracker.release(self._URL)
        in_use, _ = tracker.usage(self._URL)
        self.assertEqual(in_use, 1)

    # ------------------------------------------------------------------
    # Model-aware tests
    # ------------------------------------------------------------------

    # With max_models=1 and nothing acquired, a model is accepted.
    def test_can_accept_respects_max_models(self):
        tracker = self._new_tracker(4)
        tracker.set_max_models(self._URL, 1)
        self.assertTrue(tracker.can_accept(self._URL, "model-a"))

    async def test_max_models_blocks_second_model(self):
        tracker = self._new_tracker(4)
        tracker.set_max_models(self._URL, 1)
        await tracker.acquire(self._URL, "model-a")
        # The already-active model is still accepted (slots remain).
        self.assertTrue(tracker.can_accept(self._URL, "model-a"))
        # A different model is refused because max_models=1 is reached.
        self.assertFalse(tracker.can_accept(self._URL, "model-b"))

    # Releasing the only active model lets a different model be accepted.
    async def test_max_models_unblocks_after_release(self):
        tracker = self._new_tracker(4)
        tracker.set_max_models(self._URL, 1)
        await tracker.acquire(self._URL, "model-a")
        self.assertFalse(tracker.can_accept(self._URL, "model-b"))
        await tracker.release(self._URL, "model-a")
        self.assertTrue(tracker.can_accept(self._URL, "model-b"))

    # active_model_set follows acquires and releases of distinct models.
    async def test_active_model_set(self):
        tracker = self._new_tracker(4)
        self.assertEqual(tracker.active_model_set(self._URL), frozenset())

        await tracker.acquire(self._URL, "model-a")
        self.assertEqual(
            tracker.active_model_set(self._URL), frozenset({"model-a"})
        )

        await tracker.acquire(self._URL, "model-b")
        self.assertEqual(
            tracker.active_model_set(self._URL),
            frozenset({"model-a", "model-b"}),
        )

        await tracker.release(self._URL, "model-a")
        self.assertEqual(
            tracker.active_model_set(self._URL), frozenset({"model-b"})
        )

    # A model acquired twice stays in the active set until both
    # acquisitions have been released.
    async def test_acquire_tracks_active_models(self):
        only_a = frozenset({"model-a"})
        tracker = self._new_tracker(4)
        for _ in range(2):
            await tracker.acquire(self._URL, "model-a")
        in_use, _ = tracker.usage(self._URL)
        self.assertEqual(in_use, 2)
        self.assertEqual(tracker.active_model_set(self._URL), only_a)

        await tracker.release(self._URL, "model-a")
        self.assertEqual(tracker.active_model_set(self._URL), only_a)

        await tracker.release(self._URL, "model-a")
        self.assertEqual(tracker.active_model_set(self._URL), frozenset())

    # reset_acquired zeroes the acquired count and empties the active set.
    async def test_reset_acquired_clears_state(self):
        tracker = self._new_tracker(2)
        for _ in range(2):
            await tracker.acquire(self._URL, "model-a")
        in_use, _ = tracker.usage(self._URL)
        self.assertEqual(in_use, 2)

        await tracker.reset_acquired(self._URL)
        in_use, _ = tracker.usage(self._URL)
        self.assertEqual(in_use, 0)
        self.assertEqual(tracker.active_model_set(self._URL), frozenset())

    # reset_acquired lets a pending acquire for another model complete.
    async def test_reset_acquired_unblocks_waiters(self):
        tracker = self._new_tracker(1)
        tracker.set_max_models(self._URL, 1)
        await tracker.acquire(self._URL, "model-a")

        unblocked = []

        async def blocked_acquire():
            async with asyncio.timeout(2.0):
                await tracker.acquire(self._URL, "model-b")
            unblocked.append(True)

        pending = asyncio.create_task(blocked_acquire())
        # Give the background task time to start waiting; it must not
        # have finished before the reset.
        await asyncio.sleep(0.05)
        self.assertFalse(unblocked)
        await tracker.reset_acquired(self._URL)
        await pending
        self.assertEqual(unblocked, [True])

    # max_models=None places no limit on additional models.
    async def test_max_models_none_allows_any(self):
        tracker = self._new_tracker(4)
        tracker.set_max_models(self._URL, None)
        await tracker.acquire(self._URL, "model-a")
        for other_model in ("model-b", "model-c"):
            self.assertTrue(tracker.can_accept(self._URL, other_model))
|
|
|
|
|
|
# Run the suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|