Compare commits

...

17 Commits

Author SHA1 Message Date
28a6aa739d again 2026-05-19 22:00:26 +02:00
6e35208a66 improve backend sharing across models 2026-05-19 21:38:07 +02:00
579185654b polish 2026-05-19 21:21:42 +02:00
3c536241f6 integration 2026-05-19 21:11:35 +02:00
aad8854393 bump version 2026-05-19 20:57:57 +02:00
b379faebdb change scheduling strategy to hibrid priority queue with aging 2026-05-19 20:57:38 +02:00
0f5aabbf15 bump version 2026-05-18 01:03:06 +02:00
7cf16dcace improve cache 2026-05-18 01:02:57 +02:00
5211b2f1a0 bump version 2026-05-18 00:43:53 +02:00
13fb341354 feat: model unload 2026-05-18 00:34:27 +02:00
bcebaf0e93 improve monitoring 2026-05-18 00:25:10 +02:00
981272e6ca fix: audio streaming 2026-05-18 00:12:57 +02:00
71b3b20a4d bump version 2026-05-17 23:41:45 +02:00
8120357034 UI: show loaded model 2026-05-17 23:40:53 +02:00
47d4b4e4fc scheduler : improve backend usage 2026-05-17 23:38:43 +02:00
68b5229621 update license 2026-05-17 23:20:04 +02:00
8f54fc1740 add license 2026-05-17 23:16:06 +02:00
20 changed files with 733 additions and 157 deletions

View File

@@ -3,7 +3,12 @@
"allow": [
"Bash(python -m pytest)",
"Bash(python -m pytest --tb=short -q)",
"Bash(python -m pytest tests/test_slot_tracker.py -v)"
"Bash(python -m pytest tests/test_slot_tracker.py -v)",
"Bash(python -m pytest tests/test_scheduler.py -v)",
"Bash(python -m pytest tests/test_monitor.py::TestMonitorEndpoints::test_monitor_data_structure -v)",
"Bash(python -m pytest -q)",
"Bash(python -m pytest tests/test_config.py -q)",
"Bash(python -m pytest -x -q)"
]
}
}

21
LICENSE.md Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 chacha
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -89,6 +89,7 @@ Configuration is a JSON file. All fields also accept environment variable overri
| `default_slot_capacity` | `1` | Initial slot count per backend used before the first `/slots` poll completes |
| `default_max_models` | `null` | Maximum concurrent models per backend (null = unlimited). Applied to backends that do not set their own `max_models`. |
| `max_queue_skip` | `0` | How many times a queued request may be bypassed by a model-affinity promotion before it is frozen at head-of-line. `0` disables reordering. |
| `model_unload_delay` | `3.0` | Seconds a backend stays sticky to its last model after all slots drain. Prevents unnecessary model swaps for follow-up requests (title generation, suggestions) that arrive shortly after the main response. `0` disables. |
| `model_limits` | `{}` | Per-model global concurrency cap across all backends (e.g. `{"my-large-model": 1}`). Use for models too large to run simultaneously due to RAM constraints. |
### Per-backend fields

View File

@@ -7,7 +7,9 @@
"session_idle_ttl": 300,
"default_slot_capacity": 1,
"default_max_models": 1,
"max_queue_skip": 4,
"model_affinity_sched_bonus": 10,
"queue_aging_equalization": 30.0,
"model_unload_delay": 3.0,
"model_limits": {
"my-very-large-model": 1
},

View File

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "llamacpp-ha"
version = "0.2.0"
version = "0.11.0"
description = "Smart load balancer for llama.cpp servers"
requires-python = ">=3.13"
dependencies = [

View File

@@ -38,12 +38,19 @@ class ProxyConfig(BaseSettings):
# Fallback max_models applied to any backend that does not set its own.
# None = unlimited. Set to 1 globally when all backends are llama.cpp.
default_max_models: int | None = None
# How many queue positions a model-affinity request may skip ahead.
# 0 = pure FIFO (default). N > 0 enables reordering: the scheduler looks
# for a request matching an already-active model and can promote it up to N
# positions; each bypassed entry accumulates a skip count and is immune to
# further skipping once it reaches N.
max_queue_skip: int = 0
# Priority bonus added to requests whose model is currently warm (active or in
# the warm-hold window) on an available backend. 0 = pure FIFO.
# Works together with queue_aging_equalization: after that many seconds a
# waiting request's age score catches up to the bonus, restoring FIFO order.
model_affinity_sched_bonus: int = 0
# Seconds after which an aging request's accumulated score equals
# model_affinity_sched_bonus, ensuring starvation is impossible.
# Only meaningful when model_affinity_sched_bonus > 0.
queue_aging_equalization: float = 30.0
# Seconds to keep a backend sticky to its last model after all slots drain.
# Prevents unnecessary model swaps for follow-up requests (e.g. title/suggestion
# generation) that arrive shortly after the main response. 0 = disabled.
model_unload_delay: float = 3.0
# Per-model global concurrency cap across all backends.
# Use for very large models that cannot run concurrently due to RAM constraints.
model_limits: dict[str, int] = {}

View File

@@ -16,6 +16,14 @@ from .slot_tracker import SlotTracker
log = logging.getLogger(__name__)
_STREAM_CHUNK = 8192
_STREAMING_CONTENT_TYPES = (
"text/event-stream",
"audio/mpeg",
"audio/ogg",
"audio/wav",
"audio/webm",
"audio/aac",
)
_HOP_BY_HOP = frozenset(
[
"connection",
@@ -74,7 +82,7 @@ async def forward_request(
resp = await resp_ctx.__aenter__()
content_type = resp.headers.get("Content-Type", "")
is_streaming = "text/event-stream" in content_type
is_streaming = any(ct in content_type for ct in _STREAMING_CONTENT_TYPES)
response_headers = {
k: v

View File

@@ -5,7 +5,7 @@ from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
_EXEMPT_PATHS = frozenset(["/monitor", "/monitor/", "/monitor/data", "/monitor/data/"])
_EXEMPT_PATHS = frozenset(["/monitor", "/monitor/", "/monitor/data", "/monitor/data/", "/health", "/health/"])
class ApiKeyMiddleware(BaseHTTPMiddleware):

View File

@@ -1,12 +1,14 @@
from __future__ import annotations
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
from fastapi import APIRouter
from fastapi.responses import HTMLResponse, JSONResponse
from .queue import RequestQueue
from .queue import QueueEntry, RequestQueue
from .registry import BackendRegistry
from .session_store import SessionStore
from .slot_tracker import SlotTracker
@@ -17,10 +19,44 @@ class ProxyStats:
"""Per-app counters and timing. Created by create_app, passed to build_router."""
start_time: float = field(default_factory=time.monotonic)
total_requests: int = 0
session_hits: int = 0
session_misses: int = 0
new_sessions: int = 0
model_requests: dict[str, int] = field(default_factory=dict)
model_tokens: dict[str, int] = field(default_factory=dict)
backend_requests: dict[str, int] = field(default_factory=dict)
backend_session_hits: dict[str, int] = field(default_factory=dict)
backend_session_misses: dict[str, int] = field(default_factory=dict)
def increment_requests(self) -> None:
self.total_requests += 1
def record_model(self, model_id: str, tokens: int | None) -> None:
if not model_id:
return
self.model_requests[model_id] = self.model_requests.get(model_id, 0) + 1
if tokens:
self.model_tokens[model_id] = self.model_tokens.get(model_id, 0) + tokens
def record_backend(self, url: str) -> None:
self.backend_requests[url] = self.backend_requests.get(url, 0) + 1
def record_session(self, preferred_url: str | None, actual_url: str) -> None:
if preferred_url:
if actual_url == preferred_url:
self.session_hits += 1
self.backend_session_hits[actual_url] = self.backend_session_hits.get(actual_url, 0) + 1
else:
self.session_misses += 1
# Count miss against the preferred backend (the one that was expected but missed).
self.backend_session_misses[preferred_url] = self.backend_session_misses.get(preferred_url, 0) + 1
else:
self.new_sessions += 1
def session_hit_rate(self) -> int | None:
total = self.session_hits + self.session_misses
return round(self.session_hits / total * 100) if total else None
def uptime_str(self) -> str:
secs = int(time.monotonic() - self.start_time)
h, remainder = divmod(secs, 3600)
@@ -48,11 +84,14 @@ _HTML = """<!DOCTYPE html>
.badge-dead { background: #2c0f0f; color: #f85149; }
.slots { color: #d29922; }
.empty { color: #484f58; font-style: italic; }
.hit { color: #3fb950; }
.miss { color: #f85149; }
#status { float: right; font-size: 0.8em; color: #8b949e; }
.summary { display: flex; gap: 20px; flex-wrap: wrap; margin: 10px 0 20px; }
.stat { background: #161b22; border: 1px solid #30363d; border-radius: 6px; padding: 10px 16px; }
.stat-val { font-size: 1.6em; color: #58a6ff; }
.stat-label { font-size: 0.75em; color: #8b949e; margin-top: 2px; }
.num { text-align: right; }
</style>
</head>
<body>
@@ -65,24 +104,31 @@ _HTML = """<!DOCTYPE html>
<div class="stat"><div class="stat-val" id="queue-depth">-</div><div class="stat-label">Queue Depth</div></div>
<div class="stat"><div class="stat-val" id="session-count">-</div><div class="stat-label">Active Sessions</div></div>
<div class="stat"><div class="stat-val" id="live-count">-</div><div class="stat-label">Live Backends</div></div>
<div class="stat"><div class="stat-val" id="hit-rate">-</div><div class="stat-label">Session Hit Rate</div></div>
</div>
<h2>Backends</h2>
<table>
<thead><tr><th>URL</th><th>Status</th><th>Models</th><th>Slots</th><th>Last Poll</th></tr></thead>
<tbody id="backends-body"><tr><td colspan="5" class="empty">Loading...</td></tr></tbody>
<thead><tr><th>URL</th><th>Status</th><th>Active Model</th><th>Models</th><th>Slots</th><th class="num">Requests</th><th>Last Poll</th></tr></thead>
<tbody id="backends-body"><tr><td colspan="7" class="empty">Loading...</td></tr></tbody>
</table>
<h2>Queue</h2>
<table>
<thead><tr><th>Request ID</th><th>Model</th><th>Session</th><th>Wait (s)</th><th>Est. Tokens</th><th>Skips</th></tr></thead>
<thead><tr><th>Request ID</th><th>Model</th><th>Session</th><th>Age</th><th>Est. Tokens</th><th class="num">Priority</th></tr></thead>
<tbody id="queue-body"><tr><td colspan="6" class="empty">Queue is empty</td></tr></tbody>
</table>
<h2>Sessions by Model</h2>
<h2>Model Stats</h2>
<table>
<thead><tr><th>Model</th><th>Active Sessions</th></tr></thead>
<tbody id="sessions-body"><tr><td colspan="2" class="empty">No active sessions</td></tr></tbody>
<thead><tr><th>Model</th><th class="num">Requests</th><th class="num">Est. Tokens</th><th class="num">Active Sessions</th></tr></thead>
<tbody id="model-body"><tr><td colspan="4" class="empty">No data yet</td></tr></tbody>
</table>
<h2>Backend Stats</h2>
<table>
<thead><tr><th>Backend</th><th class="num">Requests</th><th class="num">Share</th><th>Session Affinity</th></tr></thead>
<tbody id="backend-stats-body"><tr><td colspan="4" class="empty">No data yet</td></tr></tbody>
</table>
<script>
@@ -91,6 +137,17 @@ _HTML = """<!DOCTYPE html>
return String(s).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;');
}
function fmt(n) {
return n >= 1000 ? (n/1000).toFixed(1) + 'k' : String(n);
}
function fmtAge(s) {
if (s < 60) return s.toFixed(1) + 's';
const m = Math.floor(s / 60);
const r = Math.floor(s % 60);
return m + 'm ' + String(r).padStart(2, '0') + 's';
}
function render(data) {
document.getElementById('uptime').textContent = data.uptime;
document.getElementById('total-req').textContent = data.total_requests;
@@ -98,18 +155,32 @@ _HTML = """<!DOCTYPE html>
document.getElementById('session-count').textContent = data.session_count;
document.getElementById('live-count').textContent = data.live_backend_count;
const hr = data.session_hit_rate;
const hrEl = document.getElementById('hit-rate');
if (hr == null) {
hrEl.textContent = 'N/A';
hrEl.className = 'stat-val';
} else {
hrEl.textContent = hr + '%';
hrEl.className = 'stat-val ' + (hr >= 80 ? 'hit' : hr >= 50 ? 'slots' : 'miss');
}
const bBody = document.getElementById('backends-body');
if (!data.backends.length) {
bBody.innerHTML = '<tr><td colspan="5" class="empty">No backends configured</td></tr>';
bBody.innerHTML = '<tr><td colspan="7" class="empty">No backends configured</td></tr>';
} else {
bBody.innerHTML = data.backends.map(b => {
const badge = b.live
? '<span class="badge badge-live">live</span>'
: '<span class="badge badge-dead">dead</span>';
const active = b.active_models.length
? b.active_models.map(m => `<span class="badge badge-live">${esc(m)}</span>`).join(' ')
: '<span class="empty">idle</span>';
const models = b.models.length ? esc(b.models.join(', ')) : '<span class="empty">none</span>';
const slots = `<span class="slots">${b.slots_acquired}/${b.slots_total}</span>`;
const age = b.last_poll_age == null ? '<span class="empty">never</span>' : esc(b.last_poll_age.toFixed(1)) + 's';
return `<tr><td>${esc(b.url)}</td><td>${badge}</td><td>${models}</td><td>${slots}</td><td>${age}</td></tr>`;
const reqs = b.requests > 0 ? fmt(b.requests) : '<span class="empty">0</span>';
return `<tr><td>${esc(b.url)}</td><td>${badge}</td><td>${active}</td><td>${models}</td><td>${slots}</td><td class="num">${reqs}</td><td>${age}</td></tr>`;
}).join('');
}
@@ -120,19 +191,41 @@ _HTML = """<!DOCTYPE html>
qBody.innerHTML = data.queue.map(e => {
const tok = e.estimated_tokens != null ? esc(e.estimated_tokens) : '<span class="empty">-</span>';
const sid = e.session_id ? esc(e.session_id) : '<span class="empty">-</span>';
return `<tr><td>${esc(e.request_id.slice(0,12))}</td><td>${esc(e.model_id||'-')}</td><td>${sid}</td><td>${esc(e.wait_seconds.toFixed(2))}</td><td>${tok}</td><td>${esc(e.skip_count)}</td></tr>`;
const pri = e.priority != null ? `<span class="slots">${e.priority.toFixed(1)}</span>` : '<span class="empty">-</span>';
return `<tr><td>${esc(e.request_id.slice(0,12))}</td><td>${esc(e.model_id||'-')}</td><td>${sid}</td><td>${fmtAge(e.wait_seconds)}</td><td>${tok}</td><td class="num">${pri}</td></tr>`;
}).join('');
}
const sBody = document.getElementById('sessions-body');
const sbm = data.sessions_by_model;
const keys = Object.keys(sbm);
if (!keys.length) {
sBody.innerHTML = '<tr><td colspan="2" class="empty">No active sessions</td></tr>';
const mBody = document.getElementById('model-body');
const ms = data.model_stats;
const mKeys = Object.keys(ms).sort((a,b) => ms[b].requests - ms[a].requests);
if (!mKeys.length) {
mBody.innerHTML = '<tr><td colspan="4" class="empty">No data yet</td></tr>';
} else {
sBody.innerHTML = keys.map(m =>
`<tr><td>${esc(m)}</td><td>${esc(sbm[m])}</td></tr>`
).join('');
mBody.innerHTML = mKeys.map(m => {
const s = ms[m];
const tok = s.estimated_tokens > 0 ? fmt(s.estimated_tokens) : '<span class="empty">-</span>';
const sess = s.active_sessions > 0 ? s.active_sessions : '<span class="empty">0</span>';
return `<tr><td>${esc(m)}</td><td class="num">${fmt(s.requests)}</td><td class="num">${tok}</td><td class="num">${sess}</td></tr>`;
}).join('');
}
const bsBody = document.getElementById('backend-stats-body');
const bs = data.backend_stats;
const bsKeys = Object.keys(bs).sort((a,b) => bs[b].requests - bs[a].requests);
if (!bsKeys.length || data.total_requests === 0) {
bsBody.innerHTML = '<tr><td colspan="4" class="empty">No data yet</td></tr>';
} else {
bsBody.innerHTML = bsKeys.map(url => {
const s = bs[url];
const share = data.total_requests > 0
? Math.round(s.requests / data.total_requests * 100) + '%'
: '<span class="empty">-</span>';
const affinity = s.session_hits + s.session_misses > 0
? Math.round(s.session_hits / (s.session_hits + s.session_misses) * 100) + '% hit'
: '<span class="empty">-</span>';
return `<tr><td>${esc(url)}</td><td class="num">${fmt(s.requests)}</td><td class="num">${share}</td><td>${affinity}</td></tr>`;
}).join('');
}
document.getElementById('status').textContent = 'updated ' + new Date().toLocaleTimeString();
@@ -162,6 +255,7 @@ def build_router(
request_queue: RequestQueue,
session_store: SessionStore,
stats: ProxyStats,
priority_fn: Callable[[QueueEntry], float] | None = None,
) -> APIRouter:
router = APIRouter()
@@ -180,18 +274,41 @@ def build_router(
{
"url": state.url,
"live": state.live,
"active_models": sorted(slot_tracker.active_model_set(state.url)),
"models": list(state.models),
"slots_acquired": acquired,
"slots_total": total,
"requests": stats.backend_requests.get(state.url, 0),
"last_poll_age": None if age == float("inf") else round(age, 1),
}
)
queue_snapshot = await request_queue.snapshot()
queue_snapshot = await request_queue.snapshot(score_fn=priority_fn)
session_count = await session_store.count()
sessions_by_model = await session_store.count_by_model()
live_count = sum(1 for s in states if s.live)
# Merge per-model request stats with active session counts.
all_models = set(stats.model_requests) | set(sessions_by_model)
model_stats = {
m: {
"requests": stats.model_requests.get(m, 0),
"estimated_tokens": stats.model_tokens.get(m, 0),
"active_sessions": sessions_by_model.get(m, 0),
}
for m in all_models
}
# Per-backend cumulative stats with session affinity breakdown.
backend_stats = {
url: {
"requests": count,
"session_hits": stats.backend_session_hits.get(url, 0),
"session_misses": stats.backend_session_misses.get(url, 0),
}
for url, count in stats.backend_requests.items()
}
return JSONResponse(
{
"uptime": stats.uptime_str(),
@@ -199,9 +316,13 @@ def build_router(
"queue_depth": len(queue_snapshot),
"session_count": session_count,
"live_backend_count": live_count,
"session_hits": stats.session_hits,
"session_misses": stats.session_misses,
"session_hit_rate": stats.session_hit_rate(),
"backends": backends_data,
"queue": queue_snapshot,
"sessions_by_model": sessions_by_model,
"model_stats": model_stats,
"backend_stats": backend_stats,
}
)

View File

@@ -96,7 +96,9 @@ def _init_slot_tracker(
registry: BackendRegistry,
slot_tracker: SlotTracker,
default_max_models: int | None,
model_unload_delay: float,
) -> None:
slot_tracker.set_model_unload_delay(model_unload_delay)
for state in registry.get_all_states():
slot_tracker.set_capacity(state.url, state.slot_capacity)
effective_max = (
@@ -279,6 +281,8 @@ async def _inference_endpoint(
if not incoming_session_id:
await _recover_session_affinity(session_id, body.get("messages") or [], session_store)
preferred_url = await session_store.get_preferred_backend(session_id)
result = await _dispatch_entry(
request_queue, stats, config, slot_tracker, scheduler, model_id, session_id, body
)
@@ -311,6 +315,10 @@ async def _inference_endpoint(
tried.add(failover.url)
backend = failover
stats.record_model(model_id, _estimate_tokens(body))
stats.record_backend(backend.url)
stats.record_session(preferred_url, backend.url)
if model_id:
messages = body.get("messages", [])
await session_store.update(
@@ -346,7 +354,9 @@ def create_app(config: ProxyConfig) -> FastAPI:
slot_tracker=slot_tracker,
session_store=session_store,
policy=RoundRobinPolicy(),
max_queue_skip=config.max_queue_skip,
model_affinity_sched_bonus=config.model_affinity_sched_bonus,
queue_aging_equalization=config.queue_aging_equalization,
wakeup_interval=config.model_unload_delay if config.model_unload_delay > 0 else 5.0,
)
@asynccontextmanager
@@ -355,7 +365,7 @@ def create_app(config: ProxyConfig) -> FastAPI:
timeout=aiohttp.ClientTimeout(total=300),
connector=aiohttp.TCPConnector(ssl=False),
)
_init_slot_tracker(registry, slot_tracker, config.default_max_models)
_init_slot_tracker(registry, slot_tracker, config.default_max_models, config.model_unload_delay)
_init_global_model_limits(config, slot_tracker)
registry.start()
scheduler.start()
@@ -383,6 +393,7 @@ def create_app(config: ProxyConfig) -> FastAPI:
request_queue=request_queue,
session_store=session_store,
stats=stats,
priority_fn=scheduler._priority if config.model_affinity_sched_bonus > 0 else None,
)
)

View File

@@ -17,10 +17,6 @@ class QueueEntry:
future: asyncio.Future | None = field(default=None)
# Populated by scheduler at dispatch time
assigned_backend: str | None = None
# Incremented each time a later entry is dispatched ahead of this one via
# model-affinity reordering. When it reaches max_queue_skip the entry
# becomes immune to further skipping (starvation prevention).
skip_count: int = 0
@property
def wait_seconds(self) -> float:
@@ -59,19 +55,25 @@ class RequestQueue:
async with self._lock:
return len(self._entries)
async def snapshot(self) -> list[dict]:
async def snapshot(self, score_fn=None) -> list[dict]:
async with self._lock:
return [
{
entries = (
sorted(self._entries, key=score_fn, reverse=True)
if score_fn
else self._entries
)
rows = []
for e in entries:
d: dict = {
"request_id": e.request_id,
"model_id": e.model_id,
"session_id": (e.session_id or "")[:8] or None,
"wait_seconds": round(e.wait_seconds, 2),
"estimated_tokens": e.estimated_tokens,
"skip_count": e.skip_count,
"priority": round(score_fn(e), 1) if score_fn else None,
}
for e in self._entries
]
rows.append(d)
return rows
def notify(self) -> None:
"""Wake the scheduler without holding the lock."""

View File

@@ -139,8 +139,8 @@ class BackendRegistry:
models = state.models
capacity = state.slot_capacity
was_live = state.live
async with self._lock:
was_live = state.live
state.live = live
if live:
state.models = models

View File

@@ -18,36 +18,41 @@ class Scheduler:
Dispatch order
--------------
When max_queue_skip == 0 (default) the queue is pure FIFO.
When model_affinity_sched_bonus == 0 (default) the queue is pure FIFO.
When max_queue_skip > N the scheduler runs a two-phase dispatch on every
wakeup:
When model_affinity_sched_bonus > 0 each dispatch cycle runs two passes:
Phase 1 — model-affinity promotion
The scheduler scans the queue looking for entries whose model is already
in-flight on a free backend. It promotes those entries ahead of earlier
entries that would require a model switch. For every entry that is
bypassed, its skip_count is incremented. Scanning stops as soon as an
entry with skip_count >= max_queue_skip is encountered (that entry is
frozen at the head and must be served next, preventing starvation).
1. Diversity pass — for each distinct model that has NO active or warm
backend, the highest-priority entry for that model is dispatched first.
This guarantees that N distinct models in the queue will occupy N
backends (up to available capacity) before any model gets a second slot.
Respects SlotTracker.can_accept() so max_models still applies.
Phase 2 — standard FIFO
Remaining entries are dispatched in arrival order to any backend that
can accept the model (slot free, max_models constraint satisfied).
2. Priority pass — remaining entries are dispatched by effective priority:
priority = warm_bonus + age_score
warm_bonus = model_affinity_sched_bonus when the requested model is
currently warm (active or in warm-hold window) on any
backend; 0 otherwise.
age_score = wait_seconds × (bonus / queue_aging_equalization)
grows linearly with queue time. After
queue_aging_equalization seconds a cold request's
age_score equals the warm bonus, guaranteeing starvation
is impossible.
Python's sort is stable, so equal-priority entries stay in arrival order
(FIFO within the same priority band).
Session affinity
----------------
Session affinity is applied inside _resolve_backend as a hint: if the
preferred backend is in the candidate set it is chosen first. Affinity
is re-pinned to whichever backend ultimately serves the request; it is
never queued on a preferred backend when a free alternative exists.
preferred backend is in the candidate set it is chosen first.
Preemption prevention
---------------------
SlotTracker.can_accept() enforces the per-backend max_models limit. A
backend with max_models=1 that is serving model-A will reject model-B
requests until all model-A slots are released, preventing llama.cpp from
preempting the in-flight generation.
SlotTracker.can_accept() enforces the per-backend max_models limit.
"""
def __init__(
@@ -57,14 +62,24 @@ class Scheduler:
slot_tracker: SlotTracker,
session_store: SessionStore,
policy: RoutingPolicy,
max_queue_skip: int = 0,
model_affinity_sched_bonus: int = 0,
queue_aging_equalization: float = 30.0,
wakeup_interval: float = 5.0,
) -> None:
self._queue = queue
self._registry = registry
self._slots = slot_tracker
self._sessions = session_store
self._policy = policy
self._max_queue_skip = max_queue_skip
self._affinity_bonus = model_affinity_sched_bonus
# Points per second so that age_score == bonus after queue_aging_equalization s.
self._aging_rate: float = (
model_affinity_sched_bonus / queue_aging_equalization
if model_affinity_sched_bonus > 0 and queue_aging_equalization > 0
else 0.0
)
# Periodic retry covers cases where no explicit wakeup fires (e.g. sticky-window expiry).
self._wakeup_interval = max(wakeup_interval, 1.0)
self._task: asyncio.Task | None = None
def notify_slot_released(self) -> None:
@@ -81,10 +96,37 @@ class Scheduler:
async def _loop(self) -> None:
while True:
await self._queue.wakeup_event.wait()
try:
async with asyncio.timeout(self._wakeup_interval):
await self._queue.wakeup_event.wait()
except TimeoutError:
pass
self._queue.wakeup_event.clear()
await self._dispatch_all()
# ------------------------------------------------------------------
# Priority
# ------------------------------------------------------------------
def _priority(self, entry: QueueEntry) -> float:
age_score = entry.wait_seconds * self._aging_rate
warm_bonus = self._affinity_bonus if self._is_warm_for(entry) else 0
return age_score + warm_bonus
def _is_warm_for(self, entry: QueueEntry) -> bool:
"""True if the requested model is warm (active or in hold window) on any backend.
Slot availability is intentionally not checked here: the bonus reflects that
the model is already in GPU memory, so the request should sit at high priority
before a slot opens rather than only after one becomes free.
"""
if not entry.model_id:
return False
return any(
self._slots.is_warm(b.url, entry.model_id)
for b in self._registry.get_backends_for_model(entry.model_id)
)
# ------------------------------------------------------------------
# Dispatch
# ------------------------------------------------------------------
@@ -92,7 +134,6 @@ class Scheduler:
async def _dispatch_all(self) -> None:
entries = await self._queue.pending()
# Prune stale futures first so they don't count in skip logic.
for entry in entries:
if entry.future is None or entry.future.done():
await self._queue.remove(entry)
@@ -101,60 +142,48 @@ class Scheduler:
if not entries:
return
dispatched: set[int] = set()
if self._affinity_bonus > 0:
entries = sorted(entries, key=self._priority, reverse=True)
entries = await self._dispatch_unrepresented(entries)
if self._max_queue_skip > 0:
await self._affinity_pass(entries, dispatched)
# Standard FIFO pass for all remaining live entries.
for entry in entries:
if id(entry) in dispatched:
continue
if entry.future is None or entry.future.done():
await self._queue.remove(entry)
continue
if await self._try_dispatch(entry):
dispatched.add(id(entry))
await self._queue.remove(entry)
async def _affinity_pass(
self, entries: list[QueueEntry], dispatched: set[int]
) -> None:
"""Phase 1: promote entries whose model is already active on a free backend.
async def _dispatch_unrepresented(
self, entries: list[QueueEntry]
) -> list[QueueEntry]:
"""Diversity pass: dispatch one entry per model with no active/warm backend.
Stops scanning as soon as an entry with skip_count >= max_queue_skip is
encountered — that entry is frozen and must be served by the FIFO pass.
Processes entries in caller-supplied priority order. Returns entries
that were not dispatched here, preserving their relative order.
"""
for i, entry in enumerate(entries):
if id(entry) in dispatched:
dispatched_models: set[str] = set()
remaining: list[QueueEntry] = []
for entry in entries:
model = entry.model_id
if not model or self._is_warm_for(entry) or model in dispatched_models:
remaining.append(entry)
continue
if entry.skip_count >= self._max_queue_skip:
break # frozen head-of-line; FIFO pass must handle it next
if await self._try_dispatch_affinity(entry):
dispatched.add(id(entry))
if entry.future is None or entry.future.done():
await self._queue.remove(entry)
# Bump skip_count for every earlier entry we bypassed.
for j in range(i):
if id(entries[j]) not in dispatched:
entries[j].skip_count += 1
async def _try_dispatch_affinity(self, entry: QueueEntry) -> bool:
"""Dispatch only to a backend that already has entry.model_id in-flight."""
if not entry.model_id:
return False
live_backends = self._registry.get_backends_for_model(entry.model_id)
active_backends = [
b for b in live_backends
if entry.model_id in self._slots.active_model_set(b.url)
and self._slots.can_accept(b.url, entry.model_id)
]
if not active_backends:
return False
return await self._acquire_and_resolve(entry, active_backends)
continue
# Clear any sticky window that is the sole barrier for this model so that
# a continuously-renewing sticky (model A keeps re-queuing) cannot starve model B.
for b in self._registry.get_backends_for_model(model):
self._slots.waive_sticky_if_idle(b.url, model)
if await self._try_dispatch(entry):
await self._queue.remove(entry)
dispatched_models.add(model)
else:
remaining.append(entry)
return remaining
async def _try_dispatch(self, entry: QueueEntry) -> bool:
"""Standard dispatch: any live backend that can accept this model."""
"""Dispatch to any live backend that can accept this model."""
if entry.model_id:
live_backends = self._registry.get_backends_for_model(entry.model_id)
else:

View File

@@ -60,9 +60,9 @@ class SessionStore:
session = self._sessions[session_id]
if model_id is not None:
session.model_id = model_id
if messages is not None:
if messages:
session.last_message_index = len(messages)
session.prefix_hash = compute_prefix_hash(messages)
session.prefix_hash = compute_prefix_hash([messages[-1]])
if preferred_backend is not None:
session.preferred_backend = preferred_backend
session.touch()
@@ -105,7 +105,7 @@ class SessionStore:
k = s.last_message_index
if s.is_expired(self._ttl) or not s.preferred_backend or not s.prefix_hash or k == 0 or k > len(messages):
return 0
return k if compute_prefix_hash(messages[:k]) == s.prefix_hash else 0
return k if compute_prefix_hash([messages[k - 1]]) == s.prefix_hash else 0
async def find_by_prefix(self, messages: list[dict]) -> str | None:
"""Return the preferred backend whose stored conversation is a prefix of messages.

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import time
from dataclasses import dataclass, field
@@ -11,6 +12,10 @@ class _SlotState:
acquired: int = 0
active_models: dict[str, int] = field(default_factory=dict)
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
# Warm-hold window: after the last slot of a model drains, block other models
# for a brief period so follow-up requests (titles, suggestions) hit the same backend.
sticky_model: str | None = None
sticky_until: float = 0.0
class SlotTracker:
@@ -24,6 +29,11 @@ class SlotTracker:
self._slots: dict[str, _SlotState] = {}
self._global_limits: dict[str, int] = {}
self._global_counts: dict[str, int] = {}
self._model_unload_delay: float = 0.0
def set_model_unload_delay(self, delay: float) -> None:
"""Seconds to keep a backend sticky to its last model after all slots drain."""
self._model_unload_delay = max(0.0, delay)
def _ensure(self, url: str, capacity: int = 1) -> _SlotState:
if url not in self._slots:
@@ -50,6 +60,9 @@ class SlotTracker:
def _can_acquire(self, state: _SlotState, model_id: str) -> bool:
if state.acquired >= state.capacity:
return False
if model_id and state.sticky_model and time.monotonic() < state.sticky_until:
if model_id != state.sticky_model:
return False
if model_id and model_id not in state.active_models:
if state.max_models is not None and len(state.active_models) >= state.max_models:
return False
@@ -73,6 +86,15 @@ class SlotTracker:
return True
return self._can_acquire(state, model_id)
def is_warm(self, url: str, model_id: str) -> bool:
"""True if model_id is active or in warm-hold window on this backend."""
state = self._slots.get(url)
if state is None:
return False
if model_id in state.active_models:
return True
return state.sticky_model == model_id and time.monotonic() < state.sticky_until
def active_model_set(self, url: str) -> frozenset[str]:
"""Models currently in-flight on this backend."""
state = self._slots.get(url)
@@ -80,6 +102,31 @@ class SlotTracker:
return frozenset()
return frozenset(state.active_models)
def waive_sticky_if_idle(self, url: str, model_id: str) -> bool:
"""Clear the sticky window if it is the sole barrier for model_id on this backend.
Only waives when all three conditions hold:
1. Physical slot is free (acquired < capacity).
2. The sticky window is active for a *different* model.
3. max_models would NOT independently block model_id.
Returns True if the window was cleared.
"""
state = self._slots.get(url)
if state is None:
return False
if state.acquired >= state.capacity:
return False
if not state.sticky_model or time.monotonic() >= state.sticky_until:
return False
if state.sticky_model == model_id:
return False
if model_id not in state.active_models:
if state.max_models is not None and len(state.active_models) >= state.max_models:
return False
state.sticky_model = None
state.sticky_until = 0.0
return True
def usage(self, url: str) -> tuple[int, int]:
state = self._slots.get(url)
if state is None:
@@ -127,6 +174,9 @@ class SlotTracker:
count = state.active_models[model_id] - 1
if count <= 0:
del state.active_models[model_id]
if self._model_unload_delay > 0 and not state.active_models:
state.sticky_model = model_id
state.sticky_until = time.monotonic() + self._model_unload_delay
else:
state.active_models[model_id] = count
if model_id and model_id in self._global_limits:
@@ -146,4 +196,6 @@ class SlotTracker:
)
state.acquired = 0
state.active_models.clear()
state.sticky_model = None
state.sticky_until = 0.0
state.condition.notify_all()

View File

@@ -49,7 +49,8 @@ class TestProxyConfig(unittest.TestCase):
self.assertEqual(cfg.session_idle_ttl, 300.0)
self.assertEqual(cfg.backends, [])
self.assertIsNone(cfg.default_max_models)
self.assertEqual(cfg.max_queue_skip, 0)
self.assertEqual(cfg.model_affinity_sched_bonus, 0)
self.assertEqual(cfg.queue_aging_equalization, 30.0)
def test_from_file_minimal(self):
path = self._write_config(
@@ -71,7 +72,8 @@ class TestProxyConfig(unittest.TestCase):
"slot_wait_timeout": 60,
"session_idle_ttl": 600,
"default_max_models": 2,
"max_queue_skip": 5,
"model_affinity_sched_bonus": 10,
"queue_aging_equalization": 30.0,
"backends": [
{"url": "http://b1", "api_key": "secret", "model_ids": ["m1"], "max_models": 1},
{"url": "http://b2/"},
@@ -85,7 +87,8 @@ class TestProxyConfig(unittest.TestCase):
self.assertEqual(cfg.api_keys, ["key1", "key2"])
self.assertEqual(cfg.poll_interval, 10)
self.assertEqual(cfg.default_max_models, 2)
self.assertEqual(cfg.max_queue_skip, 5)
self.assertEqual(cfg.model_affinity_sched_bonus, 10)
self.assertEqual(cfg.queue_aging_equalization, 30.0)
self.assertEqual(cfg.backends[0].api_key, "secret")
self.assertEqual(cfg.backends[0].model_ids, ["m1"])
self.assertEqual(cfg.backends[0].max_models, 1)

View File

@@ -44,12 +44,13 @@ class TestHealthEndpoint(unittest.TestCase):
resp = client.get("/health")
self.assertEqual(resp.status_code, 503)
def test_health_with_api_key_required(self):
def test_health_exempt_from_auth(self):
cfg = _proxy_config(["http://b1"], api_keys=["mykey"])
app = create_app(cfg)
client = TestClient(app, raise_server_exceptions=False)
resp = client.get("/health")
self.assertEqual(resp.status_code, 401)
# /health is auth-exempt; no live backends so 503, never 401
self.assertNotEqual(resp.status_code, 401)
def test_health_with_valid_api_key(self):
cfg = _proxy_config([], api_keys=["mykey"])

View File

@@ -56,7 +56,9 @@ class TestMonitorEndpoints(unittest.TestCase):
slot_tracker.set_capacity("http://b1", 4)
data = TestClient(app).get("/monitor/data").json()
for key in ("uptime", "total_requests", "queue_depth", "session_count",
"live_backend_count", "backends", "queue", "sessions_by_model"):
"live_backend_count", "backends", "queue",
"model_stats", "backend_stats",
"session_hits", "session_misses", "session_hit_rate"):
self.assertIn(key, data)
def test_monitor_data_backend_fields(self):

View File

@@ -22,7 +22,12 @@ def _entry(**kwargs) -> QueueEntry:
class TestScheduler(unittest.IsolatedAsyncioTestCase):
def _make_scheduler(self, live_backends=None, max_queue_skip: int = 0):
def _make_scheduler(
self,
live_backends=None,
model_affinity_sched_bonus: int = 0,
queue_aging_equalization: float = 30.0,
):
cfg = ProxyConfig(backends=[BackendConfig(url=b.url) for b in (live_backends or [])])
registry = BackendRegistry(cfg)
for state in (live_backends or []):
@@ -41,7 +46,8 @@ class TestScheduler(unittest.IsolatedAsyncioTestCase):
slot_tracker=slot_tracker,
session_store=session_store,
policy=RoundRobinPolicy(),
max_queue_skip=max_queue_skip,
model_affinity_sched_bonus=model_affinity_sched_bonus,
queue_aging_equalization=queue_aging_equalization,
)
return scheduler, queue, registry, slot_tracker, session_store
@@ -212,88 +218,258 @@ class TestScheduler(unittest.IsolatedAsyncioTestCase):
self.assertTrue(entry.future.done(), "same model should still be dispatchable")
# ------------------------------------------------------------------
# N-skip reordering
# Priority scheduling
# ------------------------------------------------------------------
async def test_no_reorder_when_max_queue_skip_zero(self):
"""Default FIFO: model-B request is not promoted over model-A."""
async def test_pure_fifo_when_bonus_zero(self):
"""Default (bonus=0): blocked head-of-queue does not prevent later entries."""
b1 = _make_state("http://b1", models=["m1"])
b2 = _make_state("http://b2", models=["m2"])
scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2], max_queue_skip=0)
scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2])
slots.set_capacity("http://b1", 1)
slots.set_capacity("http://b2", 1)
# Fill b1; b2 is free with m2 active
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
async with asyncio.timeout(1.0):
await slots.acquire("http://b2", "m2")
# Queue: [m1-request (blocked), m2-request (could go to b2)]
e_m1 = _entry(model_id="m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m1)
await queue.enqueue(e_m2)
# Release b2 slot so m2 can be served
await slots.release("http://b2", "m2")
await scheduler._dispatch_all()
# m2 can be dispatched even with max_queue_skip=0 because _dispatch_all
# scans all entries (not strict head-of-line per model)
self.assertFalse(e_m1.future.done())
self.assertTrue(e_m2.future.done())
# skip_count must NOT be bumped when max_queue_skip=0
self.assertEqual(e_m1.skip_count, 0)
async def test_affinity_promotes_matching_model(self):
"""With max_queue_skip>0, a matching model gets promoted."""
async def test_warm_bonus_applies_when_only_one_model_in_queue(self):
"""With bonus>0 and a single model type in the queue, the warm model fills both free slots."""
b1 = _make_state("http://b1", models=["m1"])
scheduler, queue, _, slots, _ = self._make_scheduler([b1], max_queue_skip=3)
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1], model_affinity_sched_bonus=10
)
# capacity=3: one slot occupied, two free
slots.set_capacity("http://b1", 3)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
# Two m1 requests — no competing model, diversity pass is a no-op
e1 = _entry(model_id="m1")
e2 = _entry(model_id="m1")
await queue.enqueue(e1)
await queue.enqueue(e2)
await scheduler._dispatch_all()
# Both dispatched: warm model freely fills available slots when nothing else is waiting
self.assertTrue(e1.future.done())
self.assertTrue(e2.future.done())
async def test_warm_model_does_not_block_unrepresented_model(self):
"""With bonus>0, an unrepresented cold model gets the diversity slot before a warm one."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1], model_affinity_sched_bonus=10
)
# capacity=2: one slot occupied by m1, one free — only one more can be dispatched
slots.set_capacity("http://b1", 2)
# b1 already has m1 in-flight
# m1 is in-flight on b1 (warm); m2 is cold and unrepresented
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
# Queue: [m2-entry (no affinity), m1-entry (affinity match)]
e_other = _entry(model_id="m2") # no backend serves m2
# Queue: cold m2 request arrives first, warm m1 request arrives second
e_m2 = _entry(model_id="m2")
e_m1 = _entry(model_id="m1")
await queue.enqueue(e_other)
await queue.enqueue(e_m2)
await queue.enqueue(e_m1)
await scheduler._dispatch_all()
# m1 entry is promoted (affinity pass), m2 stays (no backend)
self.assertTrue(e_m1.future.done())
self.assertFalse(e_other.future.done())
# e_other was bypassed once
self.assertEqual(e_other.skip_count, 1)
# Diversity pass gives the free slot to m2 (unrepresented) regardless of warm bonus
self.assertTrue(e_m2.future.done())
self.assertFalse(e_m1.future.done())
async def test_skip_count_caps_reordering(self):
"""Once skip_count reaches max_queue_skip the entry freezes at head."""
b1 = _make_state("http://b1", models=["m1"])
scheduler, queue, _, slots, _ = self._make_scheduler([b1], max_queue_skip=2)
slots.set_capacity("http://b1", 4)
# ------------------------------------------------------------------
# Model diversity (unrepresented-first pass)
# ------------------------------------------------------------------
# b1 has m1 active
async def test_diversity_dispatches_cold_model_over_warm(self):
"""With one free slot, the diversity pass gives it to a cold model, not another warm request."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1], model_affinity_sched_bonus=10
)
# 2 slots: 1 occupied by m1 (warm), 1 free
slots.set_capacity("http://b1", 2)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
e_other = _entry(model_id="m2")
e_other.skip_count = 2 # already at limit — must not be bypassed
# Queue m1 (gets warm_bonus=10) then m2 (cold, bonus=0)
e_m1 = _entry(model_id="m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m1)
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
# m2 dispatched by diversity pass despite lower priority; m1 has no free slot left
self.assertTrue(e_m2.future.done(), "cold m2 should win the diversity slot")
self.assertFalse(e_m1.future.done(), "warm m1 has no slot left after diversity pass")
async def test_diversity_loads_each_model_on_separate_backend(self):
"""Two free backends and two queued models → each model lands on its own backend."""
b1 = _make_state("http://b1", models=["m1", "m2"])
b2 = _make_state("http://b2", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1, b2], model_affinity_sched_bonus=10
)
slots.set_capacity("http://b1", 1)
slots.set_capacity("http://b2", 1)
# m1 is warm on b1; m2 is cold; only b2 is free
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
e_m1 = _entry(model_id="m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m1)
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
# Diversity pass dispatches m2 to b2; priority pass dispatches m1 — no free slot remains
self.assertTrue(e_m2.future.done())
self.assertEqual(e_m2.future.result().url, "http://b2")
self.assertFalse(e_m1.future.done(), "b1 is full, b2 taken by m2")
async def test_diversity_respects_max_models(self):
"""max_models=1 prevents the diversity pass from loading a second model simultaneously."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1], model_affinity_sched_bonus=10
)
slots.set_capacity("http://b1", 2)
slots.set_max_models("http://b1", 1)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
e_m1 = _entry(model_id="m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m1)
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
# Diversity pass tries m2 but can_accept returns False (max_models=1 with m1 active)
# Priority pass dispatches m1 (same model, warm, 1 slot free)
self.assertTrue(e_m1.future.done())
self.assertFalse(e_m2.future.done())
async def test_diversity_skipped_when_bonus_zero(self):
"""Pure FIFO mode (bonus=0): no diversity pass, FIFO order holds."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler([b1])
slots.set_capacity("http://b1", 2)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
# m1 arrives first (FIFO should serve it first), m2 is cold
e_m1 = _entry(model_id="m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m1)
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
# FIFO: m1 dispatched first (it arrived first), m2 has no slot left
self.assertTrue(e_m1.future.done())
self.assertFalse(e_m2.future.done())
async def test_loop_retries_after_sticky_window_expires(self):
"""Scheduler loop dispatches a blocked entry once the sticky window expires."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler([b1])
slots.set_capacity("http://b1", 2)
slots.set_model_unload_delay(0.05)
# Acquire and release m1 — starts a 0.05 s sticky window that blocks m2.
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
await slots.release("http://b1", "m1")
entry = _entry(model_id="m2")
await queue.enqueue(entry)
await scheduler._dispatch_all()
self.assertFalse(entry.future.done(), "m2 should be blocked by sticky window")
# Run the loop with a short interval so it retries before the test times out.
scheduler._wakeup_interval = 0.1
scheduler.start()
try:
async with asyncio.timeout(1.0):
result = await entry.future
finally:
await scheduler.stop()
self.assertEqual(result.url, "http://b1")
async def test_diversity_waives_sticky_for_unrepresented_model(self):
"""Diversity pass clears a sticky window blocking an unrepresented model."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1], model_affinity_sched_bonus=10
)
slots.set_capacity("http://b1", 2)
slots.set_model_unload_delay(60.0)
# m1 was active and released — sticky window now blocks m2
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
await slots.release("http://b1", "m1")
self.assertFalse(slots.can_accept("http://b1", "m2"))
e_m1 = _entry(model_id="m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m1)
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
# Diversity pass waives sticky for m2 and dispatches it
self.assertTrue(e_m2.future.done())
self.assertEqual(e_m2.future.result().url, "http://b1")
async def test_aging_overtakes_warm_bonus(self):
"""After equalization time, an aged cold request outranks the warm bonus."""
b1 = _make_state("http://b1", models=["m1", "m2"])
# equalization=0.1s so aging is fast enough to test synchronously
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1], model_affinity_sched_bonus=10, queue_aging_equalization=0.1
)
# capacity=2: one slot occupied by m1, one free — only one more can be dispatched
slots.set_capacity("http://b1", 2)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
# m2 request arrives and waits long enough to exceed the bonus
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m2)
await asyncio.sleep(0.15) # age_score > bonus after equalization
# m1 warm request arrives after m2 has already aged past equalization
e_m1 = _entry(model_id="m1")
await queue.enqueue(e_other)
await queue.enqueue(e_m1)
await scheduler._dispatch_all()
# Affinity pass stops at e_other (skip_count >= max_queue_skip),
# so e_m1 is NOT promoted via affinity. Both get a chance in FIFO pass.
# e_other (m2) has no backend → stays. e_m1 gets dispatched in FIFO pass.
self.assertTrue(e_m1.future.done())
# skip_count must NOT increase further (entry was frozen)
self.assertEqual(e_other.skip_count, 2)
# m2's age_score now exceeds the warm bonus → m2 dispatched first
self.assertTrue(e_m2.future.done())
self.assertFalse(e_m1.future.done())
if __name__ == "__main__":

View File

@@ -248,6 +248,65 @@ class TestSlotTracker(unittest.IsolatedAsyncioTestCase):
self.assertEqual(tracker.global_model_usage("bigmodel"), (1, 2))
self.assertIsNone(tracker.global_model_usage("othermodel"))
# ------------------------------------------------------------------
# Warm-hold / model_unload_delay tests
# ------------------------------------------------------------------
async def test_sticky_window_blocks_other_model(self):
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
tracker.set_model_unload_delay(60.0)
await tracker.acquire("http://b", "model-a")
await tracker.release("http://b", "model-a")
# Window active: model-b should be rejected
self.assertFalse(tracker.can_accept("http://b", "model-b"))
async def test_sticky_window_allows_same_model(self):
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
tracker.set_model_unload_delay(60.0)
await tracker.acquire("http://b", "model-a")
await tracker.release("http://b", "model-a")
self.assertTrue(tracker.can_accept("http://b", "model-a"))
async def test_sticky_window_expires(self):
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
tracker.set_model_unload_delay(0.05)
await tracker.acquire("http://b", "model-a")
await tracker.release("http://b", "model-a")
self.assertFalse(tracker.can_accept("http://b", "model-b"))
await asyncio.sleep(0.1)
self.assertTrue(tracker.can_accept("http://b", "model-b"))
async def test_sticky_window_not_started_when_delay_zero(self):
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
tracker.set_model_unload_delay(0.0)
await tracker.acquire("http://b", "model-a")
await tracker.release("http://b", "model-a")
self.assertTrue(tracker.can_accept("http://b", "model-b"))
async def test_sticky_window_not_started_while_slots_remain(self):
"""Window must not start until ALL slots for the model drain."""
tracker = SlotTracker()
tracker.set_capacity("http://b", 4)
tracker.set_model_unload_delay(60.0)
await tracker.acquire("http://b", "model-a")
await tracker.acquire("http://b", "model-a")
await tracker.release("http://b", "model-a") # one slot still held
self.assertTrue(tracker.can_accept("http://b", "model-b"))
async def test_reset_acquired_clears_sticky_state(self):
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
tracker.set_model_unload_delay(60.0)
await tracker.acquire("http://b", "model-a")
await tracker.release("http://b", "model-a")
self.assertFalse(tracker.can_accept("http://b", "model-b"))
await tracker.reset_acquired("http://b")
self.assertTrue(tracker.can_accept("http://b", "model-b"))
async def test_reset_acquired_updates_global_counts(self):
tracker = SlotTracker()
tracker.set_capacity("http://b1", 4)
@@ -260,5 +319,81 @@ class TestSlotTracker(unittest.IsolatedAsyncioTestCase):
self.assertEqual(tracker.global_model_usage("bigmodel"), (0, 1))
# ------------------------------------------------------------------
# waive_sticky_if_idle tests
# ------------------------------------------------------------------
async def test_waive_sticky_clears_window_for_unrepresented_model(self):
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
tracker.set_model_unload_delay(60.0)
await tracker.acquire("http://b", "model-a")
await tracker.release("http://b", "model-a")
self.assertFalse(tracker.can_accept("http://b", "model-b"))
waived = tracker.waive_sticky_if_idle("http://b", "model-b")
self.assertTrue(waived)
self.assertTrue(tracker.can_accept("http://b", "model-b"))
async def test_waive_sticky_noop_when_no_free_slot(self):
tracker = SlotTracker()
tracker.set_capacity("http://b", 1)
tracker.set_model_unload_delay(60.0)
await tracker.acquire("http://b", "model-a")
await tracker.release("http://b", "model-a")
# Slot is now free but let's fill it first
await tracker.acquire("http://b", "model-a")
# sticky is cleared when re-acquired; set up sticky again manually is tricky,
# so use a fresh tracker with the slot held before release
tracker2 = SlotTracker()
tracker2.set_capacity("http://b", 1)
tracker2.set_model_unload_delay(60.0)
await tracker2.acquire("http://b", "model-a")
await tracker2.release("http://b", "model-a")
# Now exhaust the slot
await tracker2.acquire("http://b", "model-a")
waived = tracker2.waive_sticky_if_idle("http://b", "model-b")
self.assertFalse(waived)
async def test_waive_sticky_noop_for_same_model(self):
"""waive_sticky_if_idle must not clear the window for the sticky model itself."""
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
tracker.set_model_unload_delay(60.0)
await tracker.acquire("http://b", "model-a")
await tracker.release("http://b", "model-a")
waived = tracker.waive_sticky_if_idle("http://b", "model-a")
self.assertFalse(waived)
self.assertTrue(tracker.can_accept("http://b", "model-a"))
def test_waive_sticky_noop_when_max_models_also_blocks(self):
"""Do not waive if max_models would still block the requesting model.
The sticky window is only set when active_models empties, so this scenario
can only be created by direct state manipulation (not through normal acquire/release).
The guard is still present in waive_sticky_if_idle for defensive correctness.
"""
import time as _time
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
tracker.set_max_models("http://b", 1)
state = tracker._ensure("http://b")
# Manually inject: sticky=model-a, active_models={model-c: 1}
state.sticky_model = "model-a"
state.sticky_until = _time.monotonic() + 60.0
state.active_models["model-c"] = 1
state.acquired = 1
waived = tracker.waive_sticky_if_idle("http://b", "model-b")
self.assertFalse(waived)
def test_waive_sticky_noop_when_no_active_window(self):
"""Returns False and has no effect when there is no sticky window."""
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
waived = tracker.waive_sticky_if_idle("http://b", "model-b")
self.assertFalse(waived)
if __name__ == "__main__":
unittest.main()