135 lines
4.8 KiB
Python
135 lines
4.8 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
|
|
|
|
@dataclass
class Session:
    """Per-client routing state tracked by SessionStore.

    Fields are positional in declaration order: session_id, model_id,
    last_message_index, prefix_hash, preferred_backend, last_active.
    """

    # Identifier the owning store keys this session by.
    session_id: str
    # Model associated with the session, if one has been recorded.
    model_id: str | None = None
    # Message count recorded alongside prefix_hash (0 = nothing recorded).
    last_message_index: int = 0
    # Hash used for prefix matching ("" until one is recorded).
    prefix_hash: str = ""
    # Backend this session should preferably be routed to, if known.
    preferred_backend: str | None = None
    # time.monotonic() timestamp of the most recent activity.
    last_active: float = field(default_factory=time.monotonic)

    def touch(self) -> None:
        """Record activity on this session as of the current monotonic time."""
        now = time.monotonic()
        self.last_active = now

    def is_expired(self, ttl: float) -> bool:
        """Tell whether strictly more than *ttl* seconds have passed since the last touch."""
        idle_for = time.monotonic() - self.last_active
        return idle_for > ttl
|
|
|
|
|
|
def compute_prefix_hash(messages: list[dict]) -> str:
    """Return a short, stable fingerprint of *messages*.

    The list is serialized as canonical JSON (keys sorted, no whitespace), so
    dicts with the same content but different key order hash identically.
    The result is the first 16 hex characters (64 bits) of the SHA-256 digest.
    """
    canonical = json.dumps(messages, separators=(",", ":"), sort_keys=True)
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()[:16]
|
|
|
|
|
|
class SessionStore:
    """In-memory registry of client sessions with idle-time (TTL) expiry.

    All coroutine methods serialize access through one ``asyncio.Lock``.
    A session counts as expired once ``Session.is_expired(ttl)`` is true;
    expired entries are treated as absent by every lookup and are physically
    removed by ``expire()``.
    """

    def __init__(self, ttl: float = 300.0) -> None:
        """Create an empty store.

        Args:
            ttl: Idle time in seconds after which a session is considered expired.
        """
        self._ttl = ttl
        self._sessions: dict[str, Session] = {}
        self._lock = asyncio.Lock()

    def new_session_id(self) -> str:
        """Return a fresh random session id (hex form of a UUID4)."""
        return uuid.uuid4().hex

    def _get_or_create_locked(self, session_id: str) -> Session:
        """Return the live session for session_id, replacing it if absent or expired.

        Caller must already hold ``self._lock``. An expired entry that
        ``expire()`` has not yet swept is treated as gone, matching
        get_preferred_backend()/count()/snapshot(); otherwise its stale
        backend affinity and prefix hash would be silently revived.
        """
        session = self._sessions.get(session_id)
        if session is None or session.is_expired(self._ttl):
            session = Session(session_id=session_id)
            self._sessions[session_id] = session
        return session

    async def get_or_create(self, session_id: str) -> Session:
        """Return the live session for session_id (created if needed) and touch it."""
        async with self._lock:
            session = self._get_or_create_locked(session_id)
            session.touch()
            return session

    async def update(
        self,
        session_id: str,
        model_id: str | None = None,
        messages: list[dict] | None = None,
        preferred_backend: str | None = None,
    ) -> None:
        """Record new state for session_id, creating the session if needed.

        Only the arguments that are not None are applied. The session is
        touched in every case.

        Args:
            session_id: Session to update.
            model_id: New model id, if any.
            messages: Full conversation so far. Its length and a hash of the
                whole list are stored so find_by_prefix() can match later
                requests that extend this exact conversation.
            preferred_backend: Backend to prefer for this session, if any.
        """
        async with self._lock:
            session = self._get_or_create_locked(session_id)
            if model_id is not None:
                session.model_id = model_id
            if messages is not None:
                # Hash the whole conversation rather than only its last message,
                # so that only an identical history can match in find_by_prefix().
                # An empty list records index 0, which never matches.
                session.last_message_index = len(messages)
                session.prefix_hash = compute_prefix_hash(messages)
            if preferred_backend is not None:
                session.preferred_backend = preferred_backend
            session.touch()

    async def get_preferred_backend(self, session_id: str) -> str | None:
        """Return the preferred backend of a live session, or None if unknown/expired."""
        async with self._lock:
            session = self._sessions.get(session_id)
            if session is None or session.is_expired(self._ttl):
                return None
            return session.preferred_backend

    async def expire(self) -> int:
        """Remove expired sessions. Returns count removed."""
        async with self._lock:
            expired = [
                sid
                for sid, s in self._sessions.items()
                if s.is_expired(self._ttl)
            ]
            for sid in expired:
                del self._sessions[sid]
            return len(expired)

    async def count(self) -> int:
        """Return the number of non-expired sessions."""
        async with self._lock:
            return sum(
                1 for s in self._sessions.values() if not s.is_expired(self._ttl)
            )

    async def count_by_model(self) -> dict[str, int]:
        """Return non-expired session counts keyed by model_id (sessions without one are skipped)."""
        async with self._lock:
            result: dict[str, int] = {}
            for s in self._sessions.values():
                if not s.is_expired(self._ttl) and s.model_id:
                    result[s.model_id] = result.get(s.model_id, 0) + 1
            return result

    def _prefix_match_length(
        self,
        s: Session,
        messages: list[dict],
        hash_cache: dict[int, str] | None = None,
    ) -> int:
        """Return the matched prefix length for s against messages, or 0 if no match.

        A match means s.last_message_index == k and s's stored hash equals
        compute_prefix_hash(messages[:k]). Expired sessions, sessions without a
        preferred backend, and sessions with no recorded conversation never
        match.

        Args:
            s: Candidate session.
            messages: Incoming conversation.
            hash_cache: Optional memo of compute_prefix_hash(messages[:k]) keyed
                by k. It is only valid while ``messages`` is the same list.
        """
        k = s.last_message_index
        if (
            s.is_expired(self._ttl)
            or not s.preferred_backend
            or not s.prefix_hash
            or k == 0
            or k > len(messages)
        ):
            return 0
        if hash_cache is None:
            prefix_hash = compute_prefix_hash(messages[:k])
        else:
            prefix_hash = hash_cache.get(k)
            if prefix_hash is None:
                prefix_hash = compute_prefix_hash(messages[:k])
                hash_cache[k] = prefix_hash
        return k if prefix_hash == s.prefix_hash else 0

    async def find_by_prefix(self, messages: list[dict]) -> str | None:
        """Return the preferred backend whose stored conversation is a prefix of messages.

        Checks whether compute_prefix_hash(messages[:k]) equals a session's
        prefix_hash (k is that session's last_message_index). Returns the
        preferred_backend of the longest matching session, or None, and touches
        that session. This lets clients that omit the session cookie still
        land on the backend holding their KV-cache.
        """
        if not messages:
            return None
        best: Session | None = None
        best_k = 0
        # Many sessions may share a prefix length; hash each messages[:k] once.
        hash_cache: dict[int, str] = {}
        async with self._lock:
            for s in self._sessions.values():
                k = self._prefix_match_length(s, messages, hash_cache)
                if k > best_k:
                    best, best_k = s, k
            if best is None:
                return None
            best.touch()
            return best.preferred_backend

    async def snapshot(self) -> list[Session]:
        """Return a list of the non-expired sessions (the Session objects themselves, not copies)."""
        async with self._lock:
            return [
                s for s in self._sessions.values() if not s.is_expired(self._ttl)
            ]
|