"""Unit tests for ``QueueEntry`` and ``RequestQueue`` from ``llamacpp_ha.queue``."""
import asyncio
|
|
import unittest
|
|
|
|
from llamacpp_ha.queue import QueueEntry, RequestQueue
|
|
|
|
|
|
def _entry(**kwargs) -> QueueEntry:
    """Build a ``QueueEntry`` from ``kwargs`` and give it a fresh future.

    The future comes from the currently running event loop, so this helper
    has to be called from inside a coroutine: ``asyncio.get_running_loop``
    raises ``RuntimeError`` when no loop is running.

    Args:
        **kwargs: Forwarded unchanged to the ``QueueEntry`` constructor.

    Returns:
        The new entry with its ``future`` attribute populated.
    """
    entry = QueueEntry(**kwargs)
    running_loop = asyncio.get_running_loop()
    entry.future = running_loop.create_future()
    return entry
|
|
|
|
|
|
class TestRequestQueue(unittest.IsolatedAsyncioTestCase):
    """Exercise ``RequestQueue`` and ``QueueEntry`` from ``llamacpp_ha.queue``.

    Cases that call the ``_entry`` helper are coroutines because the helper
    creates each entry's future on the running event loop.
    """

    async def test_enqueue_and_pending(self):
        # Everything that was enqueued is reported by pending(), in the
        # order it was added.
        queue = RequestQueue()
        first = _entry(model_id="m1")
        second = _entry(model_id="m2")
        await queue.enqueue(first)
        await queue.enqueue(second)

        waiting = await queue.pending()

        self.assertEqual(len(waiting), 2)
        head, tail = waiting[0], waiting[1]
        self.assertEqual(head.model_id, "m1")
        self.assertEqual(tail.model_id, "m2")

    async def test_fifo_order(self):
        # Request ids come back from pending() in the same sequence in which
        # their entries were enqueued.
        queue = RequestQueue()
        expected_ids = []
        for index in range(5):
            entry = _entry(model_id=f"m{index}")
            await queue.enqueue(entry)
            expected_ids.append(entry.request_id)

        waiting = await queue.pending()
        actual_ids = [item.request_id for item in waiting]

        self.assertEqual(actual_ids, expected_ids)

    async def test_remove(self):
        # Removing one entry leaves the other one pending.
        queue = RequestQueue()
        doomed = _entry(model_id="m1")
        survivor = _entry(model_id="m2")
        await queue.enqueue(doomed)
        await queue.enqueue(survivor)

        await queue.remove(doomed)

        remaining = await queue.pending()
        self.assertEqual(len(remaining), 1)
        self.assertEqual(remaining[0].model_id, "m2")

    async def test_remove_nonexistent_no_error(self):
        # The entry is never enqueued; the test passes as long as remove()
        # does not raise.
        queue = RequestQueue()
        never_enqueued = _entry(model_id="m")
        await queue.remove(never_enqueued)

    async def test_depth(self):
        # depth() is 0 for a new queue and 3 after three enqueues.
        queue = RequestQueue()
        initial_depth = await queue.depth()
        self.assertEqual(initial_depth, 0)

        for _ in range(3):
            await queue.enqueue(_entry(model_id="m"))

        filled_depth = await queue.depth()
        self.assertEqual(filled_depth, 3)

    async def test_wakeup_event_set_on_enqueue(self):
        # The wakeup event is clear on a new queue and set after an enqueue.
        queue = RequestQueue()
        self.assertFalse(queue.wakeup_event.is_set())

        await queue.enqueue(_entry(model_id="m"))

        self.assertTrue(queue.wakeup_event.is_set())

    def test_notify_sets_wakeup(self):
        # notify() is called directly (it is not awaited) and must leave the
        # wakeup event set.
        queue = RequestQueue()
        queue.notify()
        self.assertTrue(queue.wakeup_event.is_set())

    async def test_entry_metadata(self):
        # A freshly constructed entry echoes its constructor arguments, has
        # no future yet, and carries an arrival_time that falls between two
        # time.monotonic() readings taken around its construction.
        import time

        expected_fields = {
            "model_id": "llama3",
            "session_id": "sess123",
            "estimated_tokens": 42,
        }

        lower_bound = time.monotonic()
        entry = QueueEntry(**expected_fields)
        upper_bound = time.monotonic()

        for attribute, value in expected_fields.items():
            self.assertEqual(getattr(entry, attribute), value)
        self.assertIsNone(entry.future)
        self.assertGreaterEqual(entry.arrival_time, lower_bound)
        self.assertLessEqual(entry.arrival_time, upper_bound)
        self.assertGreaterEqual(entry.wait_seconds, 0)

    async def test_snapshot_truncates_session_id(self):
        # snapshot() reports only the first eight characters of a
        # sixteen-character session id.
        queue = RequestQueue()
        await queue.enqueue(_entry(model_id="m", session_id="abcdef1234567890"))

        rows = await queue.snapshot()

        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0]["session_id"], "abcdef12")
|
|
|
|
|
|
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|