Files
llamacpp-ha/tests/test_session_store.py
2026-05-17 22:15:13 +02:00

176 lines
6.9 KiB
Python

import asyncio
import unittest
from unittest.mock import patch
from llamacpp_ha.session_store import SessionStore, compute_prefix_hash
class TestComputePrefixHash(unittest.TestCase):
    """Checks on the value returned by ``compute_prefix_hash``."""

    def test_same_messages_same_hash(self):
        """Hashing the same message list twice yields equal results."""
        conversation = [{"role": "user", "content": "hello"}]
        first = compute_prefix_hash(conversation)
        second = compute_prefix_hash(conversation)
        self.assertEqual(first, second)

    def test_different_messages_different_hash(self):
        """Message lists with different content hash to different values."""
        hello = [{"role": "user", "content": "hello"}]
        world = [{"role": "user", "content": "world"}]
        hello_hash = compute_prefix_hash(hello)
        world_hash = compute_prefix_hash(world)
        self.assertNotEqual(hello_hash, world_hash)

    def test_empty_messages(self):
        """An empty message list still hashes to a 16-character string."""
        digest = compute_prefix_hash([])
        self.assertIsInstance(digest, str)
        self.assertEqual(len(digest), 16)
class TestSessionStore(unittest.IsolatedAsyncioTestCase):
    """Async tests for ``SessionStore``: creation, updates, affinity, TTL and prefix lookup."""

    async def test_get_or_create_new(self):
        """A freshly created session carries its id and has no model yet."""
        store = SessionStore(ttl=300.0)
        session = await store.get_or_create("abc")
        self.assertEqual(session.session_id, "abc")
        self.assertIsNone(session.model_id)

    async def test_get_or_create_existing(self):
        """Fetching an existing id returns the session with its updated model."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("abc")
        await store.update("abc", model_id="llama3")
        s2 = await store.get_or_create("abc")
        self.assertEqual(s2.model_id, "llama3")

    async def test_update_model_and_messages(self):
        """A combined update stores both the preferred backend and the model id."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        msgs = [{"role": "user", "content": "hi"}]
        await store.update("s1", model_id="m1", messages=msgs, preferred_backend="http://b1")
        pref = await store.get_preferred_backend("s1")
        self.assertEqual(pref, "http://b1")
        # The test name promises a model check too; read it back the same way
        # test_get_or_create_existing does.
        session = await store.get_or_create("s1")
        self.assertEqual(session.model_id, "m1")

    async def test_affinity_hit(self):
        """The preferred backend set via update is returned for that session."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        await store.update("s1", preferred_backend="http://b2")
        pref = await store.get_preferred_backend("s1")
        self.assertEqual(pref, "http://b2")

    async def test_affinity_miss_unknown_session(self):
        """An unknown session id has no preferred backend."""
        store = SessionStore(ttl=300.0)
        pref = await store.get_preferred_backend("nonexistent")
        self.assertIsNone(pref)

    async def test_ttl_expiry(self):
        """Once the TTL has elapsed, the session's preferred backend is no longer returned."""
        store = SessionStore(ttl=0.05)
        await store.get_or_create("s1")
        # Give the session a backend first; without one the lookup would
        # return None even if expiry never happened, making the test vacuous.
        await store.update("s1", preferred_backend="http://b1")
        await asyncio.sleep(0.1)
        pref = await store.get_preferred_backend("s1")
        self.assertIsNone(pref)

    async def test_expire_removes_stale(self):
        """expire() reports how many stale sessions it removed and empties the store."""
        store = SessionStore(ttl=0.05)
        await store.get_or_create("s1")
        await store.get_or_create("s2")
        await asyncio.sleep(0.1)
        removed = await store.expire()
        self.assertEqual(removed, 2)
        count = await store.count()
        self.assertEqual(count, 0)

    async def test_count_active(self):
        """count() reflects the number of sessions created."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        await store.get_or_create("s2")
        count = await store.count()
        self.assertEqual(count, 2)

    async def test_count_by_model(self):
        """count_by_model() groups sessions by their model id."""
        store = SessionStore(ttl=300.0)
        await store.get_or_create("s1")
        await store.get_or_create("s2")
        await store.get_or_create("s3")
        await store.update("s1", model_id="m1")
        await store.update("s2", model_id="m1")
        await store.update("s3", model_id="m2")
        by_model = await store.count_by_model()
        self.assertEqual(by_model["m1"], 2)
        self.assertEqual(by_model["m2"], 1)

    async def test_update_nonexistent_session_no_error(self):
        """Updating a session id that was never created must not raise."""
        store = SessionStore(ttl=300.0)
        await store.update("nope", model_id="m1")  # should not raise

    async def test_find_by_prefix_no_messages(self):
        """An empty message list never matches anything."""
        store = SessionStore(ttl=300.0)
        result = await store.find_by_prefix([])
        self.assertIsNone(result)

    async def test_find_by_prefix_exact_match(self):
        """Single-turn session is found when messages match exactly."""
        store = SessionStore(ttl=300.0)
        msgs = [{"role": "user", "content": "hello"}]
        await store.update("s1", messages=msgs, preferred_backend="http://b1")
        result = await store.find_by_prefix(msgs)
        self.assertEqual(result, "http://b1")

    async def test_find_by_prefix_continuation(self):
        """Turn-1 session is found when turn-2 messages extend the stored prefix."""
        store = SessionStore(ttl=300.0)
        turn1 = [{"role": "user", "content": "hello"}]
        turn2 = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
            {"role": "user", "content": "tell me more"},
        ]
        await store.update("s1", messages=turn1, preferred_backend="http://b1")
        result = await store.find_by_prefix(turn2)
        self.assertEqual(result, "http://b1")

    async def test_find_by_prefix_prefers_longest_match(self):
        """When multiple sessions match, the one with the longer stored prefix wins."""
        store = SessionStore(ttl=300.0)
        turn1 = [{"role": "user", "content": "hello"}]
        turn2 = [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "hi"}]
        turn3 = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
            {"role": "user", "content": "more"},
        ]
        await store.update("s1", messages=turn1, preferred_backend="http://b1")
        await store.update("s2", messages=turn2, preferred_backend="http://b2")
        result = await store.find_by_prefix(turn3)
        self.assertEqual(result, "http://b2")  # longer match wins

    async def test_find_by_prefix_no_match(self):
        """Messages that share no stored prefix yield no backend."""
        store = SessionStore(ttl=300.0)
        await store.update("s1", messages=[{"role": "user", "content": "hello"}], preferred_backend="http://b1")
        result = await store.find_by_prefix([{"role": "user", "content": "completely different"}])
        self.assertIsNone(result)

    async def test_find_by_prefix_ignores_expired(self):
        """A matching session whose TTL has elapsed is not returned."""
        store = SessionStore(ttl=0.05)
        msgs = [{"role": "user", "content": "hello"}]
        await store.update("s1", messages=msgs, preferred_backend="http://b1")
        await asyncio.sleep(0.1)
        result = await store.find_by_prefix(msgs)
        self.assertIsNone(result)

    async def test_find_by_prefix_ignores_session_without_backend(self):
        """A matching session with no preferred backend produces no result."""
        store = SessionStore(ttl=300.0)
        msgs = [{"role": "user", "content": "hello"}]
        await store.update("s1", messages=msgs)  # no preferred_backend
        result = await store.find_by_prefix(msgs)
        self.assertIsNone(result)

    async def test_concurrent_access(self):
        """Twenty workers creating and updating distinct sessions leave twenty sessions."""
        store = SessionStore(ttl=300.0)

        async def worker(i):
            sid = f"s{i}"
            await store.get_or_create(sid)
            await store.update(sid, model_id="m", preferred_backend=f"http://b{i}")

        await asyncio.gather(*[worker(i) for i in range(20)])
        count = await store.count()
        self.assertEqual(count, 20)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()