113 lines
4.0 KiB
Python
113 lines
4.0 KiB
Python
import asyncio
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
from llamacpp_ha.session_store import SessionStore, compute_prefix_hash
|
|
|
|
|
|
class TestComputePrefixHash(unittest.TestCase):
    """Tests for ``compute_prefix_hash``."""

    def test_same_messages_same_hash(self):
        """Equal message lists hash identically, including independent copies.

        The second comparison uses a freshly built list of fresh dicts with the
        same contents, so it fails if the hash depends on object identity
        rather than on the messages' content.
        """
        msgs = [{"role": "user", "content": "hello"}]
        same_content = [dict(m) for m in msgs]
        self.assertEqual(compute_prefix_hash(msgs), compute_prefix_hash(msgs))
        self.assertEqual(compute_prefix_hash(msgs), compute_prefix_hash(same_content))

    def test_different_messages_different_hash(self):
        """Lists differing only in message content hash differently."""
        a = [{"role": "user", "content": "hello"}]
        b = [{"role": "user", "content": "world"}]
        self.assertNotEqual(compute_prefix_hash(a), compute_prefix_hash(b))

    def test_empty_messages(self):
        """An empty message list still yields a 16-character string."""
        h = compute_prefix_hash([])
        self.assertIsInstance(h, str)
        self.assertEqual(len(h), 16)
|
|
|
|
|
|
class TestSessionStore(unittest.IsolatedAsyncioTestCase):
    """Tests for ``SessionStore``: creation, updates, affinity, TTL and counts."""

    async def test_get_or_create_new(self):
        """An unseen id yields a session carrying that id and no model."""
        store = SessionStore(ttl=300.0)
        session = await store.get_or_create("abc")
        self.assertEqual(session.session_id, "abc")
        self.assertIsNone(session.model_id)

    async def test_get_or_create_existing(self):
        """A repeat lookup of the same id sees state written by ``update``."""
        store = SessionStore(ttl=300.0)
        s1 = await store.get_or_create("abc")
        # Checked before the update: the store may return the same object on
        # the second lookup, in which case s1 would reflect the update too.
        self.assertIsNone(s1.model_id)
        await store.update("abc", model_id="llama3")
        s2 = await store.get_or_create("abc")
        self.assertEqual(s2.session_id, "abc")
        self.assertEqual(s2.model_id, "llama3")

    async def test_update_model_and_messages(self):
        """A combined update stores both the model id and the preferred backend."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        msgs = [{"role": "user", "content": "hi"}]
        await store.update("s1", model_id="m1", messages=msgs, preferred_backend="http://b1")
        pref = await store.get_preferred_backend("s1")
        self.assertEqual(pref, "http://b1")
        session = await store.get_or_create("s1")
        self.assertEqual(session.model_id, "m1")

    async def test_affinity_hit(self):
        """A backend set via ``update`` is returned for that session."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        await store.update("s1", preferred_backend="http://b2")
        pref = await store.get_preferred_backend("s1")
        self.assertEqual(pref, "http://b2")

    async def test_affinity_miss_unknown_session(self):
        """An unknown session id has no preferred backend."""
        store = SessionStore(ttl=300.0)
        pref = await store.get_preferred_backend("nonexistent")
        self.assertIsNone(pref)

    async def test_ttl_expiry(self):
        """A preferred backend stops being returned once the TTL has lapsed.

        A backend is set and confirmed first, so the final ``None`` can only
        come from expiry rather than from the session never having had one.
        """
        store = SessionStore(ttl=0.1)
        await store.get_or_create("s1")
        await store.update("s1", preferred_backend="http://b1")
        self.assertEqual(await store.get_preferred_backend("s1"), "http://b1")
        # Sleep well past the TTL to keep the check robust on a loaded machine.
        await asyncio.sleep(0.25)
        self.assertIsNone(await store.get_preferred_backend("s1"))

    async def test_expire_removes_stale(self):
        """``expire`` reports and removes every session older than the TTL."""
        store = SessionStore(ttl=0.05)
        await store.get_or_create("s1")
        await store.get_or_create("s2")
        await asyncio.sleep(0.1)
        removed = await store.expire()
        self.assertEqual(removed, 2)
        count = await store.count()
        self.assertEqual(count, 0)

    async def test_count_active(self):
        """``count`` reflects the number of created sessions."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        await store.get_or_create("s2")
        count = await store.count()
        self.assertEqual(count, 2)

    async def test_count_by_model(self):
        """``count_by_model`` groups sessions by their assigned model id."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        await store.get_or_create("s2")
        await store.get_or_create("s3")
        await store.update("s1", model_id="m1")
        await store.update("s2", model_id="m1")
        await store.update("s3", model_id="m2")
        by_model = await store.count_by_model()
        self.assertEqual(by_model["m1"], 2)
        self.assertEqual(by_model["m2"], 1)

    async def test_update_nonexistent_session_no_error(self):
        """Updating an id that was never created does not raise."""
        store = SessionStore(ttl=300.0)
        await store.update("nope", model_id="m1")  # should not raise

    async def test_concurrent_access(self):
        """Concurrent create/update on distinct ids keeps each session intact."""
        store = SessionStore(ttl=300.0)
        n_workers = 20

        async def worker(i):
            sid = f"s{i}"
            await store.get_or_create(sid)
            await store.update(sid, model_id="m", preferred_backend=f"http://b{i}")

        await asyncio.gather(*(worker(i) for i in range(n_workers)))
        self.assertEqual(await store.count(), n_workers)
        # Every session must keep its own backend: no cross-talk between workers.
        for i in range(n_workers):
            with self.subTest(session=f"s{i}"):
                self.assertEqual(await store.get_preferred_backend(f"s{i}"), f"http://b{i}")
        by_model = await store.count_by_model()
        self.assertEqual(by_model["m"], n_workers)
|
|
|
|
|
|
if __name__ == "__main__":
    # Discover and run every TestCase defined in this module when it is
    # executed directly as a script (``module`` names the default explicitly).
    unittest.main(module="__main__")
|