"""Unit tests for ``RequestQueue`` and ``QueueEntry`` from ``llamacpp_ha.queue``."""

import asyncio
import time
import unittest

from llamacpp_ha.queue import QueueEntry, RequestQueue


def _entry(**kwargs) -> QueueEntry:
    """Create a ``QueueEntry`` and attach a future from the running event loop.

    All keyword arguments are forwarded unchanged to ``QueueEntry``. Must be
    called while an event loop is running, because the future is created via
    ``asyncio.get_running_loop()``.
    """
    entry = QueueEntry(**kwargs)
    loop = asyncio.get_running_loop()
    entry.future = loop.create_future()
    return entry


class TestRequestQueue(unittest.IsolatedAsyncioTestCase):
    """Behavioural checks for the queue's enqueue/remove/inspect operations."""

    async def test_enqueue_and_pending(self):
        """Two enqueued entries are both reported by ``pending()`` in order."""
        queue = RequestQueue()
        first, second = _entry(model_id="m1"), _entry(model_id="m2")
        await queue.enqueue(first)
        await queue.enqueue(second)

        waiting = await queue.pending()

        self.assertEqual(len(waiting), 2)
        self.assertEqual(waiting[0].model_id, "m1")
        self.assertEqual(waiting[1].model_id, "m2")

    async def test_fifo_order(self):
        """``pending()`` lists entries in the same order they were enqueued."""
        queue = RequestQueue()
        entries = [_entry(model_id=f"m{index}") for index in range(5)]
        for item in entries:
            await queue.enqueue(item)
        expected_ids = [item.request_id for item in entries]

        waiting = await queue.pending()

        self.assertEqual([item.request_id for item in waiting], expected_ids)

    async def test_remove(self):
        """Removing one entry leaves only the other one pending."""
        queue = RequestQueue()
        removed, kept = _entry(model_id="m1"), _entry(model_id="m2")
        await queue.enqueue(removed)
        await queue.enqueue(kept)

        await queue.remove(removed)

        waiting = await queue.pending()
        self.assertEqual(len(waiting), 1)
        self.assertEqual(waiting[0].model_id, "m2")

    async def test_remove_nonexistent_no_error(self):
        """Removing an entry that was never enqueued must not raise."""
        queue = RequestQueue()
        stray = _entry(model_id="m")
        await queue.remove(stray)  # should not raise

    async def test_depth(self):
        """``depth()`` is 0 for a new queue and 3 after three enqueues."""
        queue = RequestQueue()
        self.assertEqual(await queue.depth(), 0)

        for _ in range(3):
            await queue.enqueue(_entry(model_id="m"))

        self.assertEqual(await queue.depth(), 3)

    async def test_wakeup_event_set_on_enqueue(self):
        """The wakeup event starts cleared and is set after an enqueue."""
        queue = RequestQueue()
        self.assertFalse(queue.wakeup_event.is_set())

        await queue.enqueue(_entry(model_id="m"))

        self.assertTrue(queue.wakeup_event.is_set())

    def test_notify_sets_wakeup(self):
        """``notify()`` sets the wakeup event (exercised from a sync test)."""
        queue = RequestQueue()
        queue.notify()
        self.assertTrue(queue.wakeup_event.is_set())

    async def test_entry_metadata(self):
        """A bare ``QueueEntry`` keeps its constructor fields and a sane timestamp."""
        # Bracket construction with monotonic readings so arrival_time can be
        # bounded from both sides.
        lower_bound = time.monotonic()
        entry = QueueEntry(
            model_id="llama3",
            session_id="sess123",
            estimated_tokens=42,
        )
        upper_bound = time.monotonic()

        self.assertEqual(entry.model_id, "llama3")
        self.assertEqual(entry.session_id, "sess123")
        self.assertEqual(entry.estimated_tokens, 42)
        # Built directly (not via _entry), so no future has been attached.
        self.assertIsNone(entry.future)
        self.assertGreaterEqual(entry.arrival_time, lower_bound)
        self.assertLessEqual(entry.arrival_time, upper_bound)
        self.assertGreaterEqual(entry.wait_seconds, 0)

    async def test_snapshot_truncates_session_id(self):
        """``snapshot()`` is expected to report a shortened ``session_id``."""
        queue = RequestQueue()
        item = _entry(model_id="m", session_id="abcdef1234567890")
        await queue.enqueue(item)

        rows = await queue.snapshot()

        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0]["session_id"], "abcdef12")


if __name__ == "__main__":
    unittest.main()