Files
llamacpp-ha/src/llamacpp_ha/session_store.py
2026-05-18 01:02:57 +02:00

135 lines
4.8 KiB
Python

from __future__ import annotations
import asyncio
import hashlib
import json
import time
import uuid
from dataclasses import dataclass, field
@dataclass
class Session:
    """Routing state remembered for one client session.

    ``last_active`` holds a ``time.monotonic()`` reading: it is taken when the
    instance is created and refreshed whenever :meth:`touch` is called.
    """

    session_id: str
    model_id: str | None = None
    # Number of messages recorded by the most recent update (0 = none yet).
    last_message_index: int = 0
    # Hash used for prefix matching; the empty string means "nothing recorded".
    prefix_hash: str = ""
    preferred_backend: str | None = None
    last_active: float = field(default_factory=time.monotonic)

    def touch(self) -> None:
        """Record that the session was used just now."""
        now = time.monotonic()
        self.last_active = now

    def is_expired(self, ttl: float) -> bool:
        """Return True if more than *ttl* seconds have elapsed since last use."""
        idle_for = time.monotonic() - self.last_active
        return idle_for > ttl
def compute_prefix_hash(messages: list[dict]) -> str:
    """Return a short, deterministic digest of *messages*.

    The list is serialized as compact JSON with sorted keys, so dicts holding
    the same content in a different key order give the same digest. Only the
    first 16 hex characters of the SHA-256 digest are returned.
    """
    canonical = json.dumps(messages, separators=(",", ":"), sort_keys=True)
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()[:16]
class SessionStore:
    """In-memory, TTL-bounded map from session id to :class:`Session`.

    The coroutine methods serialize access to the map through a single
    ``asyncio.Lock``. A session idle for longer than ``ttl`` seconds is treated
    as gone by every method, whether or not :meth:`expire` has purged it yet.
    """

    def __init__(self, ttl: float = 300.0) -> None:
        """Create an empty store whose sessions expire after *ttl* idle seconds."""
        self._ttl = ttl
        self._sessions: dict[str, Session] = {}
        self._lock = asyncio.Lock()

    def new_session_id(self) -> str:
        """Return a new random session id (hex string of a UUID4)."""
        return uuid.uuid4().hex

    def _live_session(self, session_id: str) -> Session:
        """Return the live session for *session_id*, creating one if needed.

        A missing *or expired* entry is replaced by a fresh Session so that
        stale routing state (model, backend, prefix hash) is never revived.
        Caller must hold ``self._lock``.
        """
        session = self._sessions.get(session_id)
        if session is None or session.is_expired(self._ttl):
            session = Session(session_id=session_id)
            self._sessions[session_id] = session
        return session

    async def get_or_create(self, session_id: str) -> Session:
        """Return the live session for *session_id* (creating it) and mark it active."""
        async with self._lock:
            session = self._live_session(session_id)
            session.touch()
            return session

    async def update(
        self,
        session_id: str,
        model_id: str | None = None,
        messages: list[dict] | None = None,
        preferred_backend: str | None = None,
    ) -> None:
        """Record routing state for *session_id*, creating the session if needed.

        Arguments left as ``None`` keep their current value. When *messages* is
        given, the session stores its length and a hash of the whole list; this
        is what :meth:`find_by_prefix` compares incoming requests against. An
        empty list resets the index to 0 and clears the hash.
        """
        async with self._lock:
            session = self._live_session(session_id)
            if model_id is not None:
                session.model_id = model_id
            if messages is not None:
                session.last_message_index = len(messages)
                # Hash the full conversation (not just its last message) so a
                # match in find_by_prefix means the request really extends it.
                session.prefix_hash = (
                    compute_prefix_hash(messages) if messages else ""
                )
            if preferred_backend is not None:
                session.preferred_backend = preferred_backend
            session.touch()

    async def get_preferred_backend(self, session_id: str) -> str | None:
        """Return the session's preferred backend, or None if missing or expired."""
        async with self._lock:
            session = self._sessions.get(session_id)
            if session is None or session.is_expired(self._ttl):
                return None
            return session.preferred_backend

    async def expire(self) -> int:
        """Remove expired sessions. Returns count removed."""
        async with self._lock:
            expired = [
                sid
                for sid, s in self._sessions.items()
                if s.is_expired(self._ttl)
            ]
            for sid in expired:
                del self._sessions[sid]
            return len(expired)

    async def count(self) -> int:
        """Return the number of unexpired sessions."""
        async with self._lock:
            return sum(
                1 for s in self._sessions.values() if not s.is_expired(self._ttl)
            )

    async def count_by_model(self) -> dict[str, int]:
        """Return unexpired-session counts keyed by model_id (sessions without one are skipped)."""
        async with self._lock:
            result: dict[str, int] = {}
            for s in self._sessions.values():
                if not s.is_expired(self._ttl) and s.model_id:
                    result[s.model_id] = result.get(s.model_id, 0) + 1
            return result

    def _prefix_match_length(
        self,
        s: Session,
        messages: list[dict],
        hash_cache: dict[int, str] | None = None,
    ) -> int:
        """Return the matched prefix length for s against messages, or 0 if no match.

        *s* matches when it is live, has a preferred backend and a recorded
        hash, and ``compute_prefix_hash(messages[:k])`` equals that hash, where
        ``k = s.last_message_index``. *hash_cache* (k -> digest) lets callers
        that test many sessions against the same *messages* reuse digests.
        """
        k = s.last_message_index
        if (
            s.is_expired(self._ttl)
            or not s.preferred_backend
            or not s.prefix_hash
            or k == 0
            or k > len(messages)
        ):
            return 0
        cache = {} if hash_cache is None else hash_cache
        if k not in cache:
            cache[k] = compute_prefix_hash(messages[:k])
        return k if cache[k] == s.prefix_hash else 0

    async def find_by_prefix(self, messages: list[dict]) -> str | None:
        """Return the preferred backend whose stored conversation is a prefix of messages.

        A live session with ``last_message_index == k`` matches when
        ``compute_prefix_hash(messages[:k])`` equals its ``prefix_hash``. The
        longest match wins (the first one seen on ties) and is marked active.
        Returns None when *messages* is empty or nothing matches. This lets
        clients that omit the session cookie still land on the backend holding
        their KV-cache.
        """
        if not messages:
            return None
        best: Session | None = None
        # Sessions recorded at the same k need the same prefix digest;
        # compute each distinct one only once per call.
        hash_cache: dict[int, str] = {}
        async with self._lock:
            for s in self._sessions.values():
                k = self._prefix_match_length(s, messages, hash_cache)
                if k > 0 and (best is None or k > best.last_message_index):
                    best = s
            if best is not None:
                best.touch()
        return best.preferred_backend if best is not None else None

    async def snapshot(self) -> list[Session]:
        """Return a new list of the unexpired sessions (the store's own objects, not copies)."""
        async with self._lock:
            return [
                s for s in self._sessions.values() if not s.is_expired(self._ttl)
            ]