Files
llamacpp-ha/tests/test_queue.py
2026-05-17 09:54:18 +02:00

93 lines
2.9 KiB
Python

import asyncio
import time
import unittest

from llamacpp_ha.queue import QueueEntry, RequestQueue
def _entry(**kwargs) -> QueueEntry:
    """Build a ``QueueEntry`` from ``kwargs`` and attach a fresh future.

    The future is created on the currently running event loop, so this
    helper must be called from inside a coroutine;
    ``asyncio.get_running_loop()`` raises ``RuntimeError`` otherwise.
    """
    entry = QueueEntry(**kwargs)
    running_loop = asyncio.get_running_loop()
    entry.future = running_loop.create_future()
    return entry
class TestRequestQueue(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``RequestQueue`` and ``QueueEntry``.

    Every test is a coroutine so that queue objects (and any futures created
    via ``_entry``) are constructed while the test's event loop is running.
    """

    async def test_enqueue_and_pending(self):
        """Enqueued entries are returned by ``pending()`` in insertion order."""
        q = RequestQueue()
        e1 = _entry(model_id="m1")
        e2 = _entry(model_id="m2")
        await q.enqueue(e1)
        await q.enqueue(e2)
        pending = await q.pending()
        self.assertEqual(len(pending), 2)
        self.assertEqual(pending[0].model_id, "m1")
        self.assertEqual(pending[1].model_id, "m2")

    async def test_fifo_order(self):
        """``pending()`` preserves FIFO order across several enqueues."""
        q = RequestQueue()
        expected_ids = []
        for i in range(5):
            entry = _entry(model_id=f"m{i}")
            await q.enqueue(entry)
            expected_ids.append(entry.request_id)
        pending = await q.pending()
        self.assertEqual([p.request_id for p in pending], expected_ids)

    async def test_remove(self):
        """``remove()`` drops the given entry and keeps the others."""
        q = RequestQueue()
        e1 = _entry(model_id="m1")
        e2 = _entry(model_id="m2")
        await q.enqueue(e1)
        await q.enqueue(e2)
        await q.remove(e1)
        pending = await q.pending()
        self.assertEqual(len(pending), 1)
        self.assertEqual(pending[0].model_id, "m2")

    async def test_remove_nonexistent_no_error(self):
        """Removing an entry that was never enqueued does not raise or evict."""
        q = RequestQueue()
        # Empty queue: must not raise, and the queue must stay empty.
        await q.remove(_entry(model_id="m"))
        self.assertEqual(await q.depth(), 0)
        # Non-empty queue: an unknown entry must not take a queued one with it.
        await q.enqueue(_entry(model_id="kept"))
        await q.remove(_entry(model_id="ghost"))
        pending = await q.pending()
        self.assertEqual(len(pending), 1)
        self.assertEqual(pending[0].model_id, "kept")

    async def test_depth(self):
        """``depth()`` counts the queued entries."""
        q = RequestQueue()
        self.assertEqual(await q.depth(), 0)
        for _ in range(3):
            await q.enqueue(_entry(model_id="m"))
        self.assertEqual(await q.depth(), 3)

    async def test_wakeup_event_set_on_enqueue(self):
        """Enqueuing sets ``wakeup_event``, which starts out unset."""
        q = RequestQueue()
        self.assertFalse(q.wakeup_event.is_set())
        await q.enqueue(_entry(model_id="m"))
        self.assertTrue(q.wakeup_event.is_set())

    async def test_notify_sets_wakeup(self):
        """``notify()`` sets ``wakeup_event`` on a fresh queue."""
        # Async like its siblings so the queue is built inside the test loop.
        q = RequestQueue()
        # Precondition: without it this test would pass even if the event
        # were already set at construction time.
        self.assertFalse(q.wakeup_event.is_set())
        q.notify()
        self.assertTrue(q.wakeup_event.is_set())

    async def test_entry_metadata(self):
        """``QueueEntry`` stores its fields and stamps a monotonic arrival time."""
        before = time.monotonic()
        # Built directly (not via _entry) so the default ``future`` is observed.
        e = QueueEntry(model_id="llama3", session_id="sess123", estimated_tokens=42)
        after = time.monotonic()
        self.assertEqual(e.model_id, "llama3")
        self.assertEqual(e.session_id, "sess123")
        self.assertEqual(e.estimated_tokens, 42)
        self.assertIsNone(e.future)
        self.assertGreaterEqual(e.arrival_time, before)
        self.assertLessEqual(e.arrival_time, after)
        self.assertGreaterEqual(e.wait_seconds, 0)

    async def test_snapshot_truncates_session_id(self):
        """``snapshot()`` reports only the first 8 characters of ``session_id``."""
        q = RequestQueue()
        e = _entry(model_id="m", session_id="abcdef1234567890")
        await q.enqueue(e)
        snap = await q.snapshot()
        self.assertEqual(len(snap), 1)
        self.assertEqual(snap[0]["session_id"], "abcdef12")
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()