Files
llamacpp-ha/tests/test_session_store.py
2026-05-17 09:54:18 +02:00

113 lines
4.0 KiB
Python

import asyncio
import unittest
from unittest.mock import patch
from llamacpp_ha.session_store import SessionStore, compute_prefix_hash
class TestComputePrefixHash(unittest.TestCase):
    """Unit tests for ``compute_prefix_hash``."""

    def test_same_messages_same_hash(self):
        """Equal message lists hash identically, independent of object identity."""
        msgs = [{"role": "user", "content": "hello"}]
        # A separately built but equal list: hashing the *same* object twice
        # would also pass for an identity-based hash, so compare against a copy.
        equal_copy = [{"role": "user", "content": "hello"}]
        self.assertEqual(compute_prefix_hash(msgs), compute_prefix_hash(msgs))
        self.assertEqual(compute_prefix_hash(msgs), compute_prefix_hash(equal_copy))

    def test_different_messages_different_hash(self):
        """Messages differing only in content produce different hashes."""
        a = [{"role": "user", "content": "hello"}]
        b = [{"role": "user", "content": "world"}]
        self.assertNotEqual(compute_prefix_hash(a), compute_prefix_hash(b))

    def test_empty_messages(self):
        """An empty message list still yields a 16-character string."""
        h = compute_prefix_hash([])
        self.assertIsInstance(h, str)
        self.assertEqual(len(h), 16)
class TestSessionStore(unittest.IsolatedAsyncioTestCase):
    """Async tests for ``SessionStore``: creation, updates, affinity, TTL and counts."""

    async def test_get_or_create_new(self):
        """A previously unseen id creates a session with no model assigned."""
        store = SessionStore(ttl=300.0)
        session = await store.get_or_create("abc")
        self.assertEqual(session.session_id, "abc")
        self.assertIsNone(session.model_id)

    async def test_get_or_create_existing(self):
        """A second ``get_or_create`` returns the stored session, including updates."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("abc")
        await store.update("abc", model_id="llama3")
        session = await store.get_or_create("abc")
        self.assertEqual(session.session_id, "abc")
        self.assertEqual(session.model_id, "llama3")

    async def test_update_model_and_messages(self):
        """A combined update applies both the model id and the preferred backend."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        msgs = [{"role": "user", "content": "hi"}]
        await store.update(
            "s1", model_id="m1", messages=msgs, preferred_backend="http://b1"
        )
        pref = await store.get_preferred_backend("s1")
        self.assertEqual(pref, "http://b1")
        # Check the model update as well, not only the backend.
        session = await store.get_or_create("s1")
        self.assertEqual(session.model_id, "m1")
        # NOTE(review): the stored messages / prefix hash are not asserted
        # because the session attribute holding them is not visible here.

    async def test_affinity_hit(self):
        """A preferred backend set via ``update`` is returned for that session."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        await store.update("s1", preferred_backend="http://b2")
        pref = await store.get_preferred_backend("s1")
        self.assertEqual(pref, "http://b2")

    async def test_affinity_miss_unknown_session(self):
        """An unknown session id has no preferred backend."""
        store = SessionStore(ttl=300.0)
        pref = await store.get_preferred_backend("nonexistent")
        self.assertIsNone(pref)

    async def test_ttl_expiry(self):
        """A session's preferred backend is no longer returned once its TTL passes."""
        store = SessionStore(ttl=0.05)
        await store.get_or_create("s1")
        # Set a backend first: without one, get_preferred_backend would return
        # None regardless of expiry and the test could never fail.
        await store.update("s1", preferred_backend="http://b1")
        self.assertEqual(await store.get_preferred_backend("s1"), "http://b1")
        await asyncio.sleep(0.1)  # sleep comfortably past the 0.05 s TTL
        pref = await store.get_preferred_backend("s1")
        self.assertIsNone(pref)

    async def test_expire_removes_stale(self):
        """``expire()`` removes and reports every session older than the TTL."""
        store = SessionStore(ttl=0.05)
        await store.get_or_create("s1")
        await store.get_or_create("s2")
        await asyncio.sleep(0.1)
        removed = await store.expire()
        self.assertEqual(removed, 2)
        count = await store.count()
        self.assertEqual(count, 0)

    async def test_count_active(self):
        """``count()`` reports the number of live sessions."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        await store.get_or_create("s2")
        count = await store.count()
        self.assertEqual(count, 2)

    async def test_count_by_model(self):
        """``count_by_model()`` groups live sessions by their model id."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        await store.get_or_create("s2")
        await store.get_or_create("s3")
        await store.update("s1", model_id="m1")
        await store.update("s2", model_id="m1")
        await store.update("s3", model_id="m2")
        by_model = await store.count_by_model()
        self.assertEqual(by_model["m1"], 2)
        self.assertEqual(by_model["m2"], 1)

    async def test_update_nonexistent_session_no_error(self):
        """Updating an id that was never created must not raise."""
        store = SessionStore(ttl=300.0)
        await store.update("nope", model_id="m1")  # should not raise

    async def test_concurrent_access(self):
        """Concurrent create+update tasks each leave their own session intact."""
        store = SessionStore(ttl=300.0)
        n_workers = 20

        async def worker(i):
            sid = f"s{i}"
            await store.get_or_create(sid)
            await store.update(sid, model_id="m", preferred_backend=f"http://b{i}")

        await asyncio.gather(*[worker(i) for i in range(n_workers)])
        count = await store.count()
        self.assertEqual(count, n_workers)
        # Verify no worker's update was lost or applied to the wrong session.
        for i in range(n_workers):
            pref = await store.get_preferred_backend(f"s{i}")
            self.assertEqual(pref, f"http://b{i}")
        by_model = await store.count_by_model()
        self.assertEqual(by_model["m"], n_workers)
# Allow running this module directly: ``python test_session_store.py``.
if __name__ == "__main__":
    unittest.main()