Compare commits

...

11 Commits

Author SHA1 Message Date
9685ce97f9 update monitoring and scheduling balance over time. 2026-05-19 23:01:39 +02:00
0557964442 improve and test scheduler 2026-05-19 22:39:50 +02:00
123d550808 add more stats 2026-05-19 22:18:44 +02:00
28a6aa739d again 2026-05-19 22:00:26 +02:00
6e35208a66 improve backend sharing across models 2026-05-19 21:38:07 +02:00
579185654b polish 2026-05-19 21:21:42 +02:00
3c536241f6 integration 2026-05-19 21:11:35 +02:00
aad8854393 bump version 2026-05-19 20:57:57 +02:00
b379faebdb change scheduling strategy to hibrid priority queue with aging 2026-05-19 20:57:38 +02:00
0f5aabbf15 bump version 2026-05-18 01:03:06 +02:00
7cf16dcace improve cache 2026-05-18 01:02:57 +02:00
19 changed files with 1003 additions and 241 deletions

View File

@@ -5,7 +5,10 @@
"Bash(python -m pytest --tb=short -q)",
"Bash(python -m pytest tests/test_slot_tracker.py -v)",
"Bash(python -m pytest tests/test_scheduler.py -v)",
"Bash(python -m pytest tests/test_monitor.py::TestMonitorEndpoints::test_monitor_data_structure -v)"
"Bash(python -m pytest tests/test_monitor.py::TestMonitorEndpoints::test_monitor_data_structure -v)",
"Bash(python -m pytest -q)",
"Bash(python -m pytest tests/test_config.py -q)",
"Bash(python -m pytest -x -q)"
]
}
}

View File

@@ -7,7 +7,8 @@
"session_idle_ttl": 300,
"default_slot_capacity": 1,
"default_max_models": 1,
"max_queue_skip": 4,
"model_affinity_sched_bonus": 10,
"queue_aging_equalization": 30.0,
"model_unload_delay": 3.0,
"model_limits": {
"my-very-large-model": 1

View File

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "llamacpp-ha"
version = "0.5.0"
version = "0.14.0"
description = "Smart load balancer for llama.cpp servers"
requires-python = ">=3.13"
dependencies = [

View File

@@ -38,12 +38,15 @@ class ProxyConfig(BaseSettings):
# Fallback max_models applied to any backend that does not set its own.
# None = unlimited. Set to 1 globally when all backends are llama.cpp.
default_max_models: int | None = None
# How many queue positions a model-affinity request may skip ahead.
# 0 = pure FIFO (default). N > 0 enables reordering: the scheduler looks
# for a request matching an already-active model and can promote it up to N
# positions; each bypassed entry accumulates a skip count and is immune to
# further skipping once it reaches N.
max_queue_skip: int = 0
# Priority bonus added to requests whose model is currently warm (active or in
# the warm-hold window) on an available backend. 0 = pure FIFO.
# Works together with queue_aging_equalization: after that many seconds a
# waiting request's age score catches up to the bonus, restoring FIFO order.
model_affinity_sched_bonus: int = 0
# Seconds after which an aging request's accumulated score equals
# model_affinity_sched_bonus, ensuring starvation is impossible.
# Only meaningful when model_affinity_sched_bonus > 0.
queue_aging_equalization: float = 30.0
# Seconds to keep a backend sticky to its last model after all slots drain.
# Prevents unnecessary model swaps for follow-up requests (e.g. title/suggestion
# generation) that arrive shortly after the main response. 0 = disabled.

View File

@@ -2,6 +2,8 @@ from __future__ import annotations
import itertools
import logging
import time
from collections.abc import Callable
from typing import AsyncIterator
import aiohttp
@@ -63,12 +65,18 @@ async def forward_request(
scheduler: Scheduler,
model_id: str = "",
path_override: str | None = None,
on_done: Callable[[float], None] | None = None,
) -> Response:
path = path_override or request.url.path
query = request.url.query
target = backend.url + path + (f"?{query}" if query else "")
headers = _forward_headers(request, backend)
body = await request.body()
_start = time.monotonic()
def _record() -> None:
if on_done:
on_done(time.monotonic() - _start)
try:
resp_ctx = session.request(
@@ -96,6 +104,7 @@ async def forward_request(
async for chunk in resp.content.iter_chunked(_STREAM_CHUNK):
yield chunk
finally:
_record()
await resp_ctx.__aexit__(None, None, None)
await slot_tracker.release(backend.url, model_id)
scheduler.notify_slot_released()
@@ -111,6 +120,7 @@ async def forward_request(
data = await resp.read()
finally:
await resp_ctx.__aexit__(None, None, None)
_record()
await slot_tracker.release(backend.url, model_id)
scheduler.notify_slot_released()
return Response(
@@ -122,6 +132,7 @@ async def forward_request(
except Exception as exc:
log.error("Forward error to %s: %s", backend.url, exc)
_record()
await slot_tracker.release(backend.url, model_id)
scheduler.notify_slot_released()
raise

View File

@@ -5,7 +5,7 @@ from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
_EXEMPT_PATHS = frozenset(["/monitor", "/monitor/", "/monitor/data", "/monitor/data/"])
_EXEMPT_PATHS = frozenset(["/monitor", "/monitor/", "/monitor/data", "/monitor/data/", "/health", "/health/"])
class ApiKeyMiddleware(BaseHTTPMiddleware):

View File

@@ -1,13 +1,16 @@
from __future__ import annotations
import math
import time
from collections import deque
from collections.abc import Callable
from dataclasses import dataclass, field
from fastapi import APIRouter
from fastapi.responses import HTMLResponse, JSONResponse
from .queue import RequestQueue
from .registry import BackendRegistry
from .queue import QueueEntry, RequestQueue
from .registry import BackendRegistry, BackendState
from .session_store import SessionStore
from .slot_tracker import SlotTracker
@@ -22,7 +25,12 @@ class ProxyStats:
new_sessions: int = 0
model_requests: dict[str, int] = field(default_factory=dict)
model_tokens: dict[str, int] = field(default_factory=dict)
backend_requests: dict[str, int] = field(default_factory=dict)
backend_session_hits: dict[str, int] = field(default_factory=dict)
backend_session_misses: dict[str, int] = field(default_factory=dict)
# Queue wait time: enqueue → scheduler dispatch (rolling 1-hour window)
accept_samples: deque = field(default_factory=deque)
# Per-backend rolling samples: deque of (finish_timestamp, duration_seconds)
backend_busy_samples: dict[str, deque] = field(default_factory=dict)
def increment_requests(self) -> None:
self.total_requests += 1
@@ -34,18 +42,64 @@ class ProxyStats:
if tokens:
self.model_tokens[model_id] = self.model_tokens.get(model_id, 0) + tokens
def record_backend(self, url: str) -> None:
self.backend_requests[url] = self.backend_requests.get(url, 0) + 1
def record_session(self, had_session: bool, preferred_url: str | None, actual_url: str) -> None:
if had_session and preferred_url:
def record_session(self, preferred_url: str | None, actual_url: str) -> None:
if preferred_url:
if actual_url == preferred_url:
self.session_hits += 1
self.backend_session_hits[actual_url] = self.backend_session_hits.get(actual_url, 0) + 1
else:
self.session_misses += 1
elif not had_session:
# Count miss against the preferred backend (the one that was expected but missed).
self.backend_session_misses[preferred_url] = self.backend_session_misses.get(preferred_url, 0) + 1
else:
self.new_sessions += 1
def record_accept_time(self, wait_seconds: float) -> None:
self.accept_samples.append((time.monotonic(), wait_seconds))
def record_request_duration(self, url: str, duration: float) -> None:
if url not in self.backend_busy_samples:
self.backend_busy_samples[url] = deque()
self.backend_busy_samples[url].append((time.monotonic(), duration))
def rolling_backend_stats(self, capacity_map: dict[str, int], window: float = 3600.0) -> dict[str, dict]:
cutoff = time.monotonic() - window
per_backend: dict[str, dict] = {}
total_reqs = 0
for url, samples in self.backend_busy_samples.items():
while samples and samples[0][0] < cutoff:
samples.popleft()
n = len(samples)
busy = sum(d for _, d in samples)
capacity = capacity_map.get(url, 1)
util = min(busy / (capacity * window) * 100, 100.0) if capacity > 0 else None
per_backend[url] = {
"requests": n,
"utilization_pct": util,
"avg_duration_s": busy / n if n else None,
}
total_reqs += n
for url, s in per_backend.items():
s["share_pct"] = round(s["requests"] / total_reqs * 100) if total_reqs else None
return per_backend
def accept_time_stats(self, window: float = 3600.0) -> dict | None:
cutoff = time.monotonic() - window
while self.accept_samples and self.accept_samples[0][0] < cutoff:
self.accept_samples.popleft()
if not self.accept_samples:
return None
values = sorted(v for _, v in self.accept_samples)
n = len(values)
p10_idx = max(0, math.ceil(0.10 * n) - 1)
p90_idx = max(0, math.ceil(0.90 * n) - 1)
return {
"p10": round(values[p10_idx], 3),
"mean": round(sum(values) / n, 3),
"p90": round(values[p90_idx], 3),
"sample_count": n,
}
def session_hit_rate(self) -> int | None:
total = self.session_hits + self.session_misses
return round(self.session_hits / total * 100) if total else None
@@ -98,31 +152,29 @@ _HTML = """<!DOCTYPE html>
<div class="stat"><div class="stat-val" id="session-count">-</div><div class="stat-label">Active Sessions</div></div>
<div class="stat"><div class="stat-val" id="live-count">-</div><div class="stat-label">Live Backends</div></div>
<div class="stat"><div class="stat-val" id="hit-rate">-</div><div class="stat-label">Session Hit Rate</div></div>
<div class="stat"><div class="stat-val" id="wait-p10">-</div><div class="stat-label">p10 Queue Wait (1h)</div></div>
<div class="stat"><div class="stat-val" id="wait-mean">-</div><div class="stat-label">Mean Queue Wait (1h)</div></div>
<div class="stat"><div class="stat-val" id="wait-p90">-</div><div class="stat-label">p90 Queue Wait (1h)</div></div>
</div>
<h2>Backends</h2>
<table>
<thead><tr><th>URL</th><th>Status</th><th>Active Model</th><th>Models</th><th>Slots</th><th class="num">Requests</th><th>Last Poll</th></tr></thead>
<tbody id="backends-body"><tr><td colspan="7" class="empty">Loading...</td></tr></tbody>
<thead><tr><th>URL</th><th>Status</th><th>Active Model</th><th>Slots</th><th class="num">Reqs (1h)</th><th class="num">Share (1h)</th><th class="num">Util (1h)</th><th class="num">Avg Req (1h)</th><th>Session Affinity</th><th>Last Poll</th></tr></thead>
<tbody id="backends-body"><tr><td colspan="10" class="empty">Loading...</td></tr></tbody>
</table>
<h2>Queue</h2>
<table>
<thead><tr><th>Request ID</th><th>Model</th><th>Session</th><th>Wait (s)</th><th>Est. Tokens</th><th>Skips</th></tr></thead>
<thead><tr><th>Request ID</th><th>Model</th><th>Session</th><th>Age</th><th>Est. Tokens</th><th class="num">Priority</th></tr></thead>
<tbody id="queue-body"><tr><td colspan="6" class="empty">Queue is empty</td></tr></tbody>
</table>
<h2>Model Stats</h2>
<table>
<thead><tr><th>Model</th><th class="num">Requests</th><th class="num">Est. Tokens</th><th class="num">Active Sessions</th></tr></thead>
<thead><tr><th>Model</th><th class="num">Requests</th><th class="num">Est. Tokens</th><th class="num">Tracked Sessions</th></tr></thead>
<tbody id="model-body"><tr><td colspan="4" class="empty">No data yet</td></tr></tbody>
</table>
<h2>Backend Stats</h2>
<table>
<thead><tr><th>Backend</th><th class="num">Requests</th><th class="num">Share</th><th>Session Affinity</th></tr></thead>
<tbody id="backend-stats-body"><tr><td colspan="4" class="empty">No data yet</td></tr></tbody>
</table>
<script>
(function() {
@@ -134,6 +186,22 @@ _HTML = """<!DOCTYPE html>
return n >= 1000 ? (n/1000).toFixed(1) + 'k' : String(n);
}
function fmtAge(s) {
if (s < 60) return s.toFixed(1) + 's';
const m = Math.floor(s / 60);
const r = Math.floor(s % 60);
return m + 'm ' + String(r).padStart(2, '0') + 's';
}
function fmtDur(s) {
if (s == null) return '<span class="empty">-</span>';
if (s < 1) return Math.round(s * 1000) + 'ms';
if (s < 60) return s.toFixed(1) + 's';
const m = Math.floor(s / 60);
const r = Math.floor(s % 60);
return m + 'm ' + String(r).padStart(2, '0') + 's';
}
function render(data) {
document.getElementById('uptime').textContent = data.uptime;
document.getElementById('total-req').textContent = data.total_requests;
@@ -151,9 +219,14 @@ _HTML = """<!DOCTYPE html>
hrEl.className = 'stat-val ' + (hr >= 80 ? 'hit' : hr >= 50 ? 'slots' : 'miss');
}
const at = data.accept_time;
document.getElementById('wait-p10').textContent = at ? fmtAge(at.p10) : 'N/A';
document.getElementById('wait-mean').textContent = at ? fmtAge(at.mean) : 'N/A';
document.getElementById('wait-p90').textContent = at ? fmtAge(at.p90) : 'N/A';
const bBody = document.getElementById('backends-body');
if (!data.backends.length) {
bBody.innerHTML = '<tr><td colspan="7" class="empty">No backends configured</td></tr>';
bBody.innerHTML = '<tr><td colspan="10" class="empty">No backends configured</td></tr>';
} else {
bBody.innerHTML = data.backends.map(b => {
const badge = b.live
@@ -162,11 +235,18 @@ _HTML = """<!DOCTYPE html>
const active = b.active_models.length
? b.active_models.map(m => `<span class="badge badge-live">${esc(m)}</span>`).join(' ')
: '<span class="empty">idle</span>';
const models = b.models.length ? esc(b.models.join(', ')) : '<span class="empty">none</span>';
const slots = `<span class="slots">${b.slots_acquired}/${b.slots_total}</span>`;
const reqs = b.stat_requests > 0 ? fmt(b.stat_requests) : '<span class="empty">0</span>';
const share = b.stat_share_pct != null ? b.stat_share_pct + '%' : '<span class="empty">-</span>';
const util = b.stat_utilization_pct != null
? `<span class="${b.stat_utilization_pct >= 80 ? 'miss' : b.stat_utilization_pct >= 50 ? 'slots' : 'hit'}">${b.stat_utilization_pct.toFixed(1)}%</span>`
: '<span class="empty">-</span>';
const avgReq = fmtDur(b.stat_avg_duration_s);
const affinity = b.stat_affinity != null
? `<span class="${b.stat_affinity >= 80 ? 'hit' : b.stat_affinity >= 50 ? 'slots' : 'miss'}">${b.stat_affinity}% hit</span>`
: '<span class="empty">-</span>';
const age = b.last_poll_age == null ? '<span class="empty">never</span>' : esc(b.last_poll_age.toFixed(1)) + 's';
const reqs = b.requests > 0 ? fmt(b.requests) : '<span class="empty">0</span>';
return `<tr><td>${esc(b.url)}</td><td>${badge}</td><td>${active}</td><td>${models}</td><td>${slots}</td><td class="num">${reqs}</td><td>${age}</td></tr>`;
return `<tr><td>${esc(b.url)}</td><td>${badge}</td><td>${active}</td><td>${slots}</td><td class="num">${reqs}</td><td class="num">${share}</td><td class="num">${util}</td><td class="num">${avgReq}</td><td>${affinity}</td><td>${age}</td></tr>`;
}).join('');
}
@@ -177,7 +257,8 @@ _HTML = """<!DOCTYPE html>
qBody.innerHTML = data.queue.map(e => {
const tok = e.estimated_tokens != null ? esc(e.estimated_tokens) : '<span class="empty">-</span>';
const sid = e.session_id ? esc(e.session_id) : '<span class="empty">-</span>';
return `<tr><td>${esc(e.request_id.slice(0,12))}</td><td>${esc(e.model_id||'-')}</td><td>${sid}</td><td>${esc(e.wait_seconds.toFixed(2))}</td><td>${tok}</td><td>${esc(e.skip_count)}</td></tr>`;
const pri = e.priority != null ? `<span class="slots">${e.priority.toFixed(1)}</span>` : '<span class="empty">-</span>';
return `<tr><td>${esc(e.request_id.slice(0,12))}</td><td>${esc(e.model_id||'-')}</td><td>${sid}</td><td>${fmtAge(e.wait_seconds)}</td><td>${tok}</td><td class="num">${pri}</td></tr>`;
}).join('');
}
@@ -190,29 +271,11 @@ _HTML = """<!DOCTYPE html>
mBody.innerHTML = mKeys.map(m => {
const s = ms[m];
const tok = s.estimated_tokens > 0 ? fmt(s.estimated_tokens) : '<span class="empty">-</span>';
const sess = s.active_sessions > 0 ? s.active_sessions : '<span class="empty">0</span>';
const sess = s.tracked_sessions > 0 ? s.tracked_sessions : '<span class="empty">0</span>';
return `<tr><td>${esc(m)}</td><td class="num">${fmt(s.requests)}</td><td class="num">${tok}</td><td class="num">${sess}</td></tr>`;
}).join('');
}
const bsBody = document.getElementById('backend-stats-body');
const bs = data.backend_stats;
const bsKeys = Object.keys(bs).sort((a,b) => bs[b].requests - bs[a].requests);
if (!bsKeys.length || data.total_requests === 0) {
bsBody.innerHTML = '<tr><td colspan="4" class="empty">No data yet</td></tr>';
} else {
bsBody.innerHTML = bsKeys.map(url => {
const s = bs[url];
const share = data.total_requests > 0
? Math.round(s.requests / data.total_requests * 100) + '%'
: '<span class="empty">-</span>';
const affinity = s.session_hits + s.session_misses > 0
? Math.round(s.session_hits / (s.session_hits + s.session_misses) * 100) + '% hit'
: '<span class="empty">-</span>';
return `<tr><td>${esc(url)}</td><td class="num">${fmt(s.requests)}</td><td class="num">${share}</td><td>${affinity}</td></tr>`;
}).join('');
}
document.getElementById('status').textContent = 'updated ' + new Date().toLocaleTimeString();
}
@@ -234,12 +297,42 @@ _HTML = """<!DOCTYPE html>
"""
def _build_backend_entry(state: BackendState, slot_tracker: SlotTracker) -> dict:
acquired, total = slot_tracker.usage(state.url)
age = state.last_poll_age
return {
"url": state.url,
"live": state.live,
"active_models": sorted(slot_tracker.active_model_set(state.url)),
"models": list(state.models),
"slots_acquired": acquired,
"slots_total": total,
"last_poll_age": None if age == float("inf") else round(age, 1),
}
def _enrich_backend_stats(entry: dict, stats: ProxyStats, rolling: dict[str, dict]) -> None:
url = entry["url"]
rs = rolling.get(url, {})
hits = stats.backend_session_hits.get(url, 0)
misses = stats.backend_session_misses.get(url, 0)
n = rs.get("requests", 0)
util = rs.get("utilization_pct")
avg = rs.get("avg_duration_s")
entry["stat_requests"] = n
entry["stat_share_pct"] = rs.get("share_pct")
entry["stat_utilization_pct"] = round(util, 1) if util is not None else None
entry["stat_avg_duration_s"] = round(avg, 2) if avg is not None else None
entry["stat_affinity"] = round(hits / (hits + misses) * 100) if hits + misses else None
def build_router(
registry: BackendRegistry,
slot_tracker: SlotTracker,
request_queue: RequestQueue,
session_store: SessionStore,
stats: ProxyStats,
priority_fn: Callable[[QueueEntry], float] | None = None,
) -> APIRouter:
router = APIRouter()
@@ -250,48 +343,27 @@ def build_router(
@router.get("/monitor/data", include_in_schema=False)
async def monitor_data() -> JSONResponse:
states = registry.get_all_states()
backends_data = []
for state in states:
acquired, total = slot_tracker.usage(state.url)
age = state.last_poll_age
backends_data.append(
{
"url": state.url,
"live": state.live,
"active_models": sorted(slot_tracker.active_model_set(state.url)),
"models": list(state.models),
"slots_acquired": acquired,
"slots_total": total,
"requests": stats.backend_requests.get(state.url, 0),
"last_poll_age": None if age == float("inf") else round(age, 1),
}
)
backends_data = [_build_backend_entry(s, slot_tracker) for s in states]
queue_snapshot = await request_queue.snapshot()
queue_snapshot = await request_queue.snapshot(score_fn=priority_fn)
session_count = await session_store.count()
sessions_by_model = await session_store.count_by_model()
live_count = sum(1 for s in states if s.live)
# Merge per-model request stats with active session counts.
all_models = set(stats.model_requests) | set(sessions_by_model)
model_stats = {
m: {
"requests": stats.model_requests.get(m, 0),
"estimated_tokens": stats.model_tokens.get(m, 0),
"active_sessions": sessions_by_model.get(m, 0),
"tracked_sessions": sessions_by_model.get(m, 0),
}
for m in all_models
}
# Per-backend cumulative stats with session affinity breakdown.
backend_stats = {
url: {
"requests": count,
"session_hits": 0,
"session_misses": 0,
}
for url, count in stats.backend_requests.items()
}
capacity_map = {e["url"]: e["slots_total"] for e in backends_data}
rolling = stats.rolling_backend_stats(capacity_map)
for entry in backends_data:
_enrich_backend_stats(entry, stats, rolling)
return JSONResponse(
{
@@ -303,10 +375,10 @@ def build_router(
"session_hits": stats.session_hits,
"session_misses": stats.session_misses,
"session_hit_rate": stats.session_hit_rate(),
"accept_time": stats.accept_time_stats(),
"backends": backends_data,
"queue": queue_snapshot,
"model_stats": model_stats,
"backend_stats": backend_stats,
}
)

View File

@@ -8,7 +8,9 @@ from dataclasses import dataclass, field
class BackendCandidate:
"""Minimal view of a backend passed to routing policies."""
url: str
index: int # position in the original backends list
index: int
slots_acquired: int = 0
slots_total: int = 1 # position in the original backends list
class RoutingPolicy(ABC):
@@ -28,3 +30,26 @@ class RoundRobinPolicy(RoutingPolicy):
chosen = candidates[count % len(candidates)]
self._counters[model_id] = count + 1
return chosen
class LeastConnectionsPolicy(RoutingPolicy):
"""Pick the least-loaded backend; round-robin among ties.
Prefers the backend with the lowest slot utilization ratio so that an idle
backend is always chosen over a busy one. When multiple backends share the
same load (the common all-idle case) a per-model round-robin counter spreads
new requests evenly.
"""
def __init__(self) -> None:
self._counters: dict[str, int] = {}
def select(self, model_id: str, candidates: list[BackendCandidate]) -> BackendCandidate:
if not candidates:
raise ValueError("No candidates available")
min_ratio = min(c.slots_acquired / max(c.slots_total, 1) for c in candidates)
least_loaded = [c for c in candidates if c.slots_acquired / max(c.slots_total, 1) == min_ratio]
count = self._counters.get(model_id, 0)
chosen = least_loaded[count % len(least_loaded)]
self._counters[model_id] = count + 1
return chosen

View File

@@ -14,7 +14,7 @@ from .config import ProxyConfig
from .forwarder import forward_best_effort, forward_request
from .middleware import ApiKeyMiddleware
from .monitor import ProxyStats, build_router as build_monitor_router
from .policies import RoundRobinPolicy
from .policies import LeastConnectionsPolicy
from .queue import QueueEntry, RequestQueue
from .registry import BackendRegistry, BackendState
from .scheduler import Scheduler
@@ -175,6 +175,7 @@ async def _dispatch_entry(
{"error": {"message": "No slot available (timeout)", "type": "overloaded"}},
status_code=503,
)
stats.record_accept_time(entry.wait_seconds)
return backend
@@ -281,9 +282,7 @@ async def _inference_endpoint(
if not incoming_session_id:
await _recover_session_affinity(session_id, body.get("messages") or [], session_store)
preferred_url: str | None = None
if incoming_session_id:
preferred_url = await session_store.get_preferred_backend(session_id)
preferred_url = await session_store.get_preferred_backend(session_id)
result = await _dispatch_entry(
request_queue, stats, config, slot_tracker, scheduler, model_id, session_id, body
@@ -303,6 +302,7 @@ async def _inference_endpoint(
slot_tracker=slot_tracker,
scheduler=scheduler,
model_id=model_id,
on_done=lambda dur, _url=backend.url: stats.record_request_duration(_url, dur),
)
break
except aiohttp.ClientError as exc:
@@ -318,8 +318,7 @@ async def _inference_endpoint(
backend = failover
stats.record_model(model_id, _estimate_tokens(body))
stats.record_backend(backend.url)
stats.record_session(bool(incoming_session_id), preferred_url, backend.url)
stats.record_session(preferred_url, backend.url)
if model_id:
messages = body.get("messages", [])
@@ -355,8 +354,10 @@ def create_app(config: ProxyConfig) -> FastAPI:
registry=registry,
slot_tracker=slot_tracker,
session_store=session_store,
policy=RoundRobinPolicy(),
max_queue_skip=config.max_queue_skip,
policy=LeastConnectionsPolicy(),
model_affinity_sched_bonus=config.model_affinity_sched_bonus,
queue_aging_equalization=config.queue_aging_equalization,
wakeup_interval=config.model_unload_delay if config.model_unload_delay > 0 else 5.0,
)
@asynccontextmanager
@@ -393,6 +394,7 @@ def create_app(config: ProxyConfig) -> FastAPI:
request_queue=request_queue,
session_store=session_store,
stats=stats,
priority_fn=scheduler._priority if config.model_affinity_sched_bonus > 0 else None,
)
)

View File

@@ -17,10 +17,6 @@ class QueueEntry:
future: asyncio.Future | None = field(default=None)
# Populated by scheduler at dispatch time
assigned_backend: str | None = None
# Incremented each time a later entry is dispatched ahead of this one via
# model-affinity reordering. When it reaches max_queue_skip the entry
# becomes immune to further skipping (starvation prevention).
skip_count: int = 0
@property
def wait_seconds(self) -> float:
@@ -59,19 +55,25 @@ class RequestQueue:
async with self._lock:
return len(self._entries)
async def snapshot(self) -> list[dict]:
async def snapshot(self, score_fn=None) -> list[dict]:
async with self._lock:
return [
{
entries = (
sorted(self._entries, key=score_fn, reverse=True)
if score_fn
else self._entries
)
rows = []
for e in entries:
d: dict = {
"request_id": e.request_id,
"model_id": e.model_id,
"session_id": (e.session_id or "")[:8] or None,
"wait_seconds": round(e.wait_seconds, 2),
"estimated_tokens": e.estimated_tokens,
"skip_count": e.skip_count,
"priority": round(score_fn(e), 1) if score_fn else None,
}
for e in self._entries
]
rows.append(d)
return rows
def notify(self) -> None:
"""Wake the scheduler without holding the lock."""

View File

@@ -139,8 +139,8 @@ class BackendRegistry:
models = state.models
capacity = state.slot_capacity
was_live = state.live
async with self._lock:
was_live = state.live
state.live = live
if live:
state.models = models

View File

@@ -18,36 +18,41 @@ class Scheduler:
Dispatch order
--------------
When max_queue_skip == 0 (default) the queue is pure FIFO.
When model_affinity_sched_bonus == 0 (default) the queue is pure FIFO.
When max_queue_skip > N the scheduler runs a two-phase dispatch on every
wakeup:
When model_affinity_sched_bonus > 0 each dispatch cycle runs two passes:
Phase 1 — model-affinity promotion
The scheduler scans the queue looking for entries whose model is already
in-flight on a free backend. It promotes those entries ahead of earlier
entries that would require a model switch. For every entry that is
bypassed, its skip_count is incremented. Scanning stops as soon as an
entry with skip_count >= max_queue_skip is encountered (that entry is
frozen at the head and must be served next, preventing starvation).
1. Diversity pass — for each distinct model that has NO active or warm
backend, the highest-priority entry for that model is dispatched first.
This guarantees that N distinct models in the queue will occupy N
backends (up to available capacity) before any model gets a second slot.
Respects SlotTracker.can_accept() so max_models still applies.
Phase 2 — standard FIFO
Remaining entries are dispatched in arrival order to any backend that
can accept the model (slot free, max_models constraint satisfied).
2. Priority pass — remaining entries are dispatched by effective priority:
priority = warm_bonus + age_score
warm_bonus = model_affinity_sched_bonus when the requested model is
currently warm (active or in warm-hold window) on any
backend; 0 otherwise.
age_score = wait_seconds × (bonus / queue_aging_equalization)
grows linearly with queue time. After
queue_aging_equalization seconds a cold request's
age_score equals the warm bonus, guaranteeing starvation
is impossible.
Python's sort is stable, so equal-priority entries stay in arrival order
(FIFO within the same priority band).
Session affinity
----------------
Session affinity is applied inside _resolve_backend as a hint: if the
preferred backend is in the candidate set it is chosen first. Affinity
is re-pinned to whichever backend ultimately serves the request; it is
never queued on a preferred backend when a free alternative exists.
preferred backend is in the candidate set it is chosen first.
Preemption prevention
---------------------
SlotTracker.can_accept() enforces the per-backend max_models limit. A
backend with max_models=1 that is serving model-A will reject model-B
requests until all model-A slots are released, preventing llama.cpp from
preempting the in-flight generation.
SlotTracker.can_accept() enforces the per-backend max_models limit.
"""
def __init__(
@@ -57,14 +62,24 @@ class Scheduler:
slot_tracker: SlotTracker,
session_store: SessionStore,
policy: RoutingPolicy,
max_queue_skip: int = 0,
model_affinity_sched_bonus: int = 0,
queue_aging_equalization: float = 30.0,
wakeup_interval: float = 5.0,
) -> None:
self._queue = queue
self._registry = registry
self._slots = slot_tracker
self._sessions = session_store
self._policy = policy
self._max_queue_skip = max_queue_skip
self._affinity_bonus = model_affinity_sched_bonus
# Points per second so that age_score == bonus after queue_aging_equalization s.
self._aging_rate: float = (
model_affinity_sched_bonus / queue_aging_equalization
if model_affinity_sched_bonus > 0 and queue_aging_equalization > 0
else 0.0
)
# Periodic retry covers cases where no explicit wakeup fires (e.g. sticky-window expiry).
self._wakeup_interval = max(wakeup_interval, 1.0)
self._task: asyncio.Task | None = None
def notify_slot_released(self) -> None:
@@ -81,10 +96,37 @@ class Scheduler:
async def _loop(self) -> None:
while True:
await self._queue.wakeup_event.wait()
try:
async with asyncio.timeout(self._wakeup_interval):
await self._queue.wakeup_event.wait()
except TimeoutError:
pass
self._queue.wakeup_event.clear()
await self._dispatch_all()
# ------------------------------------------------------------------
# Priority
# ------------------------------------------------------------------
def _priority(self, entry: QueueEntry) -> float:
age_score = entry.wait_seconds * self._aging_rate
warm_bonus = self._affinity_bonus if self._is_warm_for(entry) else 0
return age_score + warm_bonus
def _is_warm_for(self, entry: QueueEntry) -> bool:
"""True if the requested model is warm (active or in hold window) on any backend.
Slot availability is intentionally not checked here: the bonus reflects that
the model is already in GPU memory, so the request should sit at high priority
before a slot opens rather than only after one becomes free.
"""
if not entry.model_id:
return False
return any(
self._slots.is_warm(b.url, entry.model_id)
for b in self._registry.get_backends_for_model(entry.model_id)
)
# ------------------------------------------------------------------
# Dispatch
# ------------------------------------------------------------------
@@ -92,7 +134,6 @@ class Scheduler:
async def _dispatch_all(self) -> None:
entries = await self._queue.pending()
# Prune stale futures first so they don't count in skip logic.
for entry in entries:
if entry.future is None or entry.future.done():
await self._queue.remove(entry)
@@ -101,70 +142,95 @@ class Scheduler:
if not entries:
return
dispatched: set[int] = set()
draining_for: dict[str, str] = {}
if self._max_queue_skip > 0:
await self._affinity_pass(entries, dispatched)
if self._affinity_bonus > 0:
entries = sorted(entries, key=self._priority, reverse=True)
entries, draining_for = await self._dispatch_unrepresented(entries)
# Standard FIFO pass for all remaining live entries.
for entry in entries:
if id(entry) in dispatched:
if entry.future is None or entry.future.done():
await self._queue.remove(entry)
continue
# Skip backends reserved for draining by a different model.
skip_urls = frozenset(
url for url, reserved in draining_for.items()
if reserved != entry.model_id
)
if await self._try_dispatch(entry, skip_urls=skip_urls):
await self._queue.remove(entry)
async def _dispatch_unrepresented(
self, entries: list[QueueEntry]
) -> tuple[list[QueueEntry], dict[str, str]]:
"""Diversity pass: dispatch one entry per model with no active/warm backend.
Processes entries in caller-supplied priority order. Returns the entries
that were not dispatched and a draining_for map (backend url → model) for
backends that should be reserved: a free physical slot exists but max_models
blocks the unrepresented model, so the warm model must not refill that slot.
"""
dispatched_models: set[str] = set()
remaining: list[QueueEntry] = []
draining_for: dict[str, str] = {}
for entry in entries:
model = entry.model_id
if not model or self._is_warm_for(entry) or model in dispatched_models:
remaining.append(entry)
continue
if entry.future is None or entry.future.done():
await self._queue.remove(entry)
continue
# Clear any sticky window that is the sole barrier for this model so that
# a continuously-renewing sticky (model A keeps re-queuing) cannot starve model B.
for b in self._registry.get_backends_for_model(model):
self._slots.waive_sticky_if_idle(b.url, model)
if await self._try_dispatch(entry):
dispatched.add(id(entry))
await self._queue.remove(entry)
dispatched_models.add(model)
else:
drain_url = self._find_drain_backend(model)
if drain_url:
draining_for[drain_url] = model
remaining.append(entry)
async def _affinity_pass(
self, entries: list[QueueEntry], dispatched: set[int]
) -> None:
"""Phase 1: promote entries whose model is already active on a free backend.
return remaining, draining_for
Stops scanning as soon as an entry with skip_count >= max_queue_skip is
encountered — that entry is frozen and must be served by the FIFO pass.
def _find_drain_backend(self, model_id: str) -> str | None:
"""Return the url of the backend with the most free slots blocked only by max_models.
Called when the diversity pass cannot dispatch model_id. Reserving this backend
prevents the warm model from refilling the free slot, allowing it to drain so
model_id can start. Returns None if no such backend exists.
A backend is only eligible if every model currently active on it has at least one
*other* live backend it can use, so we never completely starve a model by blocking
its only available backend.
"""
for i, entry in enumerate(entries):
if id(entry) in dispatched:
best_url: str | None = None
best_free = 0
for b in self._registry.get_backends_for_model(model_id):
if not self._slots.is_blocked_by_max_models(b.url, model_id):
continue
if entry.skip_count >= self._max_queue_skip:
break # frozen head-of-line; FIFO pass must handle it next
free = self._slots.free_slot_count(b.url)
if free <= best_free:
continue
# Skip if any active model would have no alternative backend after this reserve.
active = self._slots.active_model_set(b.url)
if not all(
any(alt.url != b.url for alt in self._registry.get_backends_for_model(m))
for m in active
):
continue
best_free = free
best_url = b.url
return best_url
if await self._try_dispatch_affinity(entry):
dispatched.add(id(entry))
await self._queue.remove(entry)
# Bump skip_count for every earlier entry we bypassed.
for j in range(i):
if id(entries[j]) not in dispatched:
entries[j].skip_count += 1
async def _try_dispatch_affinity(self, entry: QueueEntry) -> bool:
"""Dispatch only to a backend that already has entry.model_id in-flight.
Skipped when a completely idle backend exists — piling onto a warm backend
while another is idle reduces effective capacity without KV-cache benefit
(different conversations don't share KV cache across sessions).
"""
if not entry.model_id:
return False
live_backends = self._registry.get_backends_for_model(entry.model_id)
if any(
self._slots.usage(b.url)[0] == 0 and self._slots.can_accept(b.url, entry.model_id)
for b in live_backends
):
return False
active_backends = [
b for b in live_backends
if entry.model_id in self._slots.active_model_set(b.url)
and self._slots.can_accept(b.url, entry.model_id)
]
if not active_backends:
return False
return await self._acquire_and_resolve(entry, active_backends)
async def _try_dispatch(self, entry: QueueEntry) -> bool:
"""Standard dispatch: any live backend that can accept this model."""
async def _try_dispatch(
self, entry: QueueEntry, skip_urls: frozenset[str] = frozenset()
) -> bool:
"""Dispatch to any live backend that can accept this model."""
if entry.model_id:
live_backends = self._registry.get_backends_for_model(entry.model_id)
else:
@@ -174,7 +240,8 @@ class Scheduler:
return False
free_backends = [
b for b in live_backends if self._slots.can_accept(b.url, entry.model_id)
b for b in live_backends
if b.url not in skip_urls and self._slots.can_accept(b.url, entry.model_id)
]
if not free_backends:
return False
@@ -216,7 +283,13 @@ class Scheduler:
if b.url == preferred_url:
return b
backend_candidates = [
BackendCandidate(url=b.url, index=i) for i, b in enumerate(candidates)
BackendCandidate(
url=b.url,
index=i,
slots_acquired=self._slots.usage(b.url)[0],
slots_total=self._slots.usage(b.url)[1],
)
for i, b in enumerate(candidates)
]
chosen = self._policy.select(entry.model_id or "_any", backend_candidates)
return next(b for b in candidates if b.url == chosen.url)

View File

@@ -60,9 +60,9 @@ class SessionStore:
session = self._sessions[session_id]
if model_id is not None:
session.model_id = model_id
if messages is not None:
if messages:
session.last_message_index = len(messages)
session.prefix_hash = compute_prefix_hash(messages)
session.prefix_hash = compute_prefix_hash([messages[-1]])
if preferred_backend is not None:
session.preferred_backend = preferred_backend
session.touch()
@@ -105,7 +105,7 @@ class SessionStore:
k = s.last_message_index
if s.is_expired(self._ttl) or not s.preferred_backend or not s.prefix_hash or k == 0 or k > len(messages):
return 0
return k if compute_prefix_hash(messages[:k]) == s.prefix_hash else 0
return k if compute_prefix_hash([messages[k - 1]]) == s.prefix_hash else 0
async def find_by_prefix(self, messages: list[dict]) -> str | None:
"""Return the preferred backend whose stored conversation is a prefix of messages.

View File

@@ -86,6 +86,15 @@ class SlotTracker:
return True
return self._can_acquire(state, model_id)
def is_warm(self, url: str, model_id: str) -> bool:
"""True if model_id is active or in warm-hold window on this backend."""
state = self._slots.get(url)
if state is None:
return False
if model_id in state.active_models:
return True
return state.sticky_model == model_id and time.monotonic() < state.sticky_until
def active_model_set(self, url: str) -> frozenset[str]:
"""Models currently in-flight on this backend."""
state = self._slots.get(url)
@@ -93,6 +102,46 @@ class SlotTracker:
return frozenset()
return frozenset(state.active_models)
def free_slot_count(self, url: str) -> int:
state = self._slots.get(url)
if state is None:
return 0
return max(0, state.capacity - state.acquired)
def is_blocked_by_max_models(self, url: str, model_id: str) -> bool:
"""True if a physical slot is free but max_models is the sole blocker for model_id."""
state = self._slots.get(url)
if state is None or state.acquired >= state.capacity:
return False
if model_id in state.active_models:
return False
return state.max_models is not None and len(state.active_models) >= state.max_models
def waive_sticky_if_idle(self, url: str, model_id: str) -> bool:
"""Clear the sticky window if it is the sole barrier for model_id on this backend.
Only waives when all three conditions hold:
1. Physical slot is free (acquired < capacity).
2. The sticky window is active for a *different* model.
3. max_models would NOT independently block model_id.
Returns True if the window was cleared.
"""
state = self._slots.get(url)
if state is None:
return False
if state.acquired >= state.capacity:
return False
if not state.sticky_model or time.monotonic() >= state.sticky_until:
return False
if state.sticky_model == model_id:
return False
if model_id not in state.active_models:
if state.max_models is not None and len(state.active_models) >= state.max_models:
return False
state.sticky_model = None
state.sticky_until = 0.0
return True
def usage(self, url: str) -> tuple[int, int]:
state = self._slots.get(url)
if state is None:

View File

@@ -49,7 +49,8 @@ class TestProxyConfig(unittest.TestCase):
self.assertEqual(cfg.session_idle_ttl, 300.0)
self.assertEqual(cfg.backends, [])
self.assertIsNone(cfg.default_max_models)
self.assertEqual(cfg.max_queue_skip, 0)
self.assertEqual(cfg.model_affinity_sched_bonus, 0)
self.assertEqual(cfg.queue_aging_equalization, 30.0)
def test_from_file_minimal(self):
path = self._write_config(
@@ -71,7 +72,8 @@ class TestProxyConfig(unittest.TestCase):
"slot_wait_timeout": 60,
"session_idle_ttl": 600,
"default_max_models": 2,
"max_queue_skip": 5,
"model_affinity_sched_bonus": 10,
"queue_aging_equalization": 30.0,
"backends": [
{"url": "http://b1", "api_key": "secret", "model_ids": ["m1"], "max_models": 1},
{"url": "http://b2/"},
@@ -85,7 +87,8 @@ class TestProxyConfig(unittest.TestCase):
self.assertEqual(cfg.api_keys, ["key1", "key2"])
self.assertEqual(cfg.poll_interval, 10)
self.assertEqual(cfg.default_max_models, 2)
self.assertEqual(cfg.max_queue_skip, 5)
self.assertEqual(cfg.model_affinity_sched_bonus, 10)
self.assertEqual(cfg.queue_aging_equalization, 30.0)
self.assertEqual(cfg.backends[0].api_key, "secret")
self.assertEqual(cfg.backends[0].model_ids, ["m1"])
self.assertEqual(cfg.backends[0].max_models, 1)

View File

@@ -44,12 +44,13 @@ class TestHealthEndpoint(unittest.TestCase):
resp = client.get("/health")
self.assertEqual(resp.status_code, 503)
def test_health_with_api_key_required(self):
def test_health_exempt_from_auth(self):
cfg = _proxy_config(["http://b1"], api_keys=["mykey"])
app = create_app(cfg)
client = TestClient(app, raise_server_exceptions=False)
resp = client.get("/health")
self.assertEqual(resp.status_code, 401)
# /health is auth-exempt; no live backends so 503, never 401
self.assertNotEqual(resp.status_code, 401)
def test_health_with_valid_api_key(self):
cfg = _proxy_config([], api_keys=["mykey"])

View File

@@ -57,9 +57,14 @@ class TestMonitorEndpoints(unittest.TestCase):
data = TestClient(app).get("/monitor/data").json()
for key in ("uptime", "total_requests", "queue_depth", "session_count",
"live_backend_count", "backends", "queue",
"model_stats", "backend_stats",
"model_stats", "accept_time",
"session_hits", "session_misses", "session_hit_rate"):
self.assertIn(key, data)
# backend cumulative stats are now merged into the backends list
backend = data["backends"][0]
for stat_key in ("stat_requests", "stat_share_pct", "stat_utilization_pct",
"stat_avg_duration_s", "stat_affinity"):
self.assertIn(stat_key, backend)
def test_monitor_data_backend_fields(self):
state = BackendState(config=BackendConfig(url="http://b1"), live=True, models=["m1"])

View File

@@ -22,7 +22,12 @@ def _entry(**kwargs) -> QueueEntry:
class TestScheduler(unittest.IsolatedAsyncioTestCase):
def _make_scheduler(self, live_backends=None, max_queue_skip: int = 0):
def _make_scheduler(
self,
live_backends=None,
model_affinity_sched_bonus: int = 0,
queue_aging_equalization: float = 30.0,
):
cfg = ProxyConfig(backends=[BackendConfig(url=b.url) for b in (live_backends or [])])
registry = BackendRegistry(cfg)
for state in (live_backends or []):
@@ -41,7 +46,8 @@ class TestScheduler(unittest.IsolatedAsyncioTestCase):
slot_tracker=slot_tracker,
session_store=session_store,
policy=RoundRobinPolicy(),
max_queue_skip=max_queue_skip,
model_affinity_sched_bonus=model_affinity_sched_bonus,
queue_aging_equalization=queue_aging_equalization,
)
return scheduler, queue, registry, slot_tracker, session_store
@@ -212,108 +218,538 @@ class TestScheduler(unittest.IsolatedAsyncioTestCase):
self.assertTrue(entry.future.done(), "same model should still be dispatchable")
# ------------------------------------------------------------------
# N-skip reordering
# Priority scheduling
# ------------------------------------------------------------------
async def test_no_reorder_when_max_queue_skip_zero(self):
"""Default FIFO: model-B request is not promoted over model-A."""
async def test_pure_fifo_when_bonus_zero(self):
"""Default (bonus=0): blocked head-of-queue does not prevent later entries."""
b1 = _make_state("http://b1", models=["m1"])
b2 = _make_state("http://b2", models=["m2"])
scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2], max_queue_skip=0)
scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2])
slots.set_capacity("http://b1", 1)
slots.set_capacity("http://b2", 1)
# Fill b1; b2 is free with m2 active
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
async with asyncio.timeout(1.0):
await slots.acquire("http://b2", "m2")
# Queue: [m1-request (blocked), m2-request (could go to b2)]
e_m1 = _entry(model_id="m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m1)
await queue.enqueue(e_m2)
# Release b2 slot so m2 can be served
await slots.release("http://b2", "m2")
await scheduler._dispatch_all()
# m2 can be dispatched even with max_queue_skip=0 because _dispatch_all
# scans all entries (not strict head-of-line per model)
self.assertFalse(e_m1.future.done())
self.assertTrue(e_m2.future.done())
# skip_count must NOT be bumped when max_queue_skip=0
self.assertEqual(e_m1.skip_count, 0)
async def test_affinity_promotes_matching_model(self):
"""With max_queue_skip>0, a matching model gets promoted."""
async def test_warm_bonus_applies_when_only_one_model_in_queue(self):
"""With bonus>0 and a single model type in the queue, the warm model fills both free slots."""
b1 = _make_state("http://b1", models=["m1"])
scheduler, queue, _, slots, _ = self._make_scheduler([b1], max_queue_skip=3)
slots.set_capacity("http://b1", 2)
# b1 already has m1 in-flight
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1], model_affinity_sched_bonus=10
)
# capacity=3: one slot occupied, two free
slots.set_capacity("http://b1", 3)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
# Queue: [m2-entry (no affinity), m1-entry (affinity match)]
e_other = _entry(model_id="m2") # no backend serves m2
# Two m1 requests — no competing model, diversity pass is a no-op
e1 = _entry(model_id="m1")
e2 = _entry(model_id="m1")
await queue.enqueue(e1)
await queue.enqueue(e2)
await scheduler._dispatch_all()
# Both dispatched: warm model freely fills available slots when nothing else is waiting
self.assertTrue(e1.future.done())
self.assertTrue(e2.future.done())
async def test_warm_model_does_not_block_unrepresented_model(self):
"""With bonus>0, an unrepresented cold model gets the diversity slot before a warm one."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1], model_affinity_sched_bonus=10
)
# capacity=2: one slot occupied by m1, one free — only one more can be dispatched
slots.set_capacity("http://b1", 2)
# m1 is in-flight on b1 (warm); m2 is cold and unrepresented
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
# Queue: cold m2 request arrives first, warm m1 request arrives second
e_m2 = _entry(model_id="m2")
e_m1 = _entry(model_id="m1")
await queue.enqueue(e_other)
await queue.enqueue(e_m2)
await queue.enqueue(e_m1)
await scheduler._dispatch_all()
# m1 entry is promoted (affinity pass), m2 stays (no backend)
# Diversity pass gives the free slot to m2 (unrepresented) regardless of warm bonus
self.assertTrue(e_m2.future.done())
self.assertFalse(e_m1.future.done())
# ------------------------------------------------------------------
# Model diversity (unrepresented-first pass)
# ------------------------------------------------------------------
async def test_diversity_dispatches_cold_model_over_warm(self):
"""With one free slot, the diversity pass gives it to a cold model, not another warm request."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1], model_affinity_sched_bonus=10
)
# 2 slots: 1 occupied by m1 (warm), 1 free
slots.set_capacity("http://b1", 2)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
# Queue m1 (gets warm_bonus=10) then m2 (cold, bonus=0)
e_m1 = _entry(model_id="m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m1)
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
# m2 dispatched by diversity pass despite lower priority; m1 has no free slot left
self.assertTrue(e_m2.future.done(), "cold m2 should win the diversity slot")
self.assertFalse(e_m1.future.done(), "warm m1 has no slot left after diversity pass")
async def test_diversity_loads_each_model_on_separate_backend(self):
"""Two free backends and two queued models → each model lands on its own backend."""
b1 = _make_state("http://b1", models=["m1", "m2"])
b2 = _make_state("http://b2", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1, b2], model_affinity_sched_bonus=10
)
slots.set_capacity("http://b1", 1)
slots.set_capacity("http://b2", 1)
# m1 is warm on b1; m2 is cold; only b2 is free
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
e_m1 = _entry(model_id="m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m1)
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
# Diversity pass dispatches m2 to b2; priority pass dispatches m1 — no free slot remains
self.assertTrue(e_m2.future.done())
self.assertEqual(e_m2.future.result().url, "http://b2")
self.assertFalse(e_m1.future.done(), "b1 is full, b2 taken by m2")
async def test_diversity_respects_max_models(self):
"""max_models=1 prevents the diversity pass from loading a second model simultaneously."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1], model_affinity_sched_bonus=10
)
slots.set_capacity("http://b1", 2)
slots.set_max_models("http://b1", 1)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
e_m1 = _entry(model_id="m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m1)
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
# Diversity pass tries m2 but can_accept returns False (max_models=1 with m1 active)
# Priority pass dispatches m1 (same model, warm, 1 slot free)
self.assertTrue(e_m1.future.done())
self.assertFalse(e_other.future.done())
# e_other was bypassed once
self.assertEqual(e_other.skip_count, 1)
self.assertFalse(e_m2.future.done())
async def test_affinity_skips_when_idle_backend_available(self):
"""Warm-model routing is bypassed when a completely idle backend exists."""
b1 = _make_state("http://b1", models=["m1"])
b2 = _make_state("http://b2", models=["m1"])
scheduler, queue, _, slots, _ = self._make_scheduler([b1, b2], max_queue_skip=3)
async def test_diversity_skipped_when_bonus_zero(self):
"""Pure FIFO mode (bonus=0): no diversity pass, FIFO order holds."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler([b1])
slots.set_capacity("http://b1", 2)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
# b2 is warm (m1 active, 1/2 slots used); b1 is completely idle (0/2)
# m1 arrives first (FIFO should serve it first), m2 is cold
e_m1 = _entry(model_id="m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m1)
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
# FIFO: m1 dispatched first (it arrived first), m2 has no slot left
self.assertTrue(e_m1.future.done())
self.assertFalse(e_m2.future.done())
async def test_loop_retries_after_sticky_window_expires(self):
"""Scheduler loop dispatches a blocked entry once the sticky window expires."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler([b1])
slots.set_capacity("http://b1", 2)
slots.set_model_unload_delay(0.05)
# Acquire and release m1 — starts a 0.05 s sticky window that blocks m2.
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
await slots.release("http://b1", "m1")
entry = _entry(model_id="m2")
await queue.enqueue(entry)
await scheduler._dispatch_all()
self.assertFalse(entry.future.done(), "m2 should be blocked by sticky window")
# Run the loop with a short interval so it retries before the test times out.
scheduler._wakeup_interval = 0.1
scheduler.start()
try:
async with asyncio.timeout(1.0):
result = await entry.future
finally:
await scheduler.stop()
self.assertEqual(result.url, "http://b1")
async def test_diversity_waives_sticky_for_unrepresented_model(self):
"""Diversity pass clears a sticky window blocking an unrepresented model."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1], model_affinity_sched_bonus=10
)
slots.set_capacity("http://b1", 2)
slots.set_model_unload_delay(60.0)
# m1 was active and released — sticky window now blocks m2
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
await slots.release("http://b1", "m1")
self.assertFalse(slots.can_accept("http://b1", "m2"))
e_m1 = _entry(model_id="m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m1)
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
# Diversity pass waives sticky for m2 and dispatches it
self.assertTrue(e_m2.future.done())
self.assertEqual(e_m2.future.result().url, "http://b1")
async def test_aging_overtakes_warm_bonus(self):
"""After equalization time, an aged cold request outranks the warm bonus."""
b1 = _make_state("http://b1", models=["m1", "m2"])
# equalization=0.1s so aging is fast enough to test synchronously
scheduler, queue, _, slots, _ = self._make_scheduler(
[b1], model_affinity_sched_bonus=10, queue_aging_equalization=0.1
)
# capacity=2: one slot occupied by m1, one free — only one more can be dispatched
slots.set_capacity("http://b1", 2)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
# m2 request arrives and waits long enough to exceed the bonus
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m2)
await asyncio.sleep(0.15) # age_score > bonus after equalization
# m1 warm request arrives after m2 has already aged past equalization
e_m1 = _entry(model_id="m1")
await queue.enqueue(e_m1)
await scheduler._dispatch_all()
# m2's age_score now exceeds the warm bonus → m2 dispatched first
self.assertTrue(e_m2.future.done())
self.assertFalse(e_m1.future.done())
class TestDrainProtection(unittest.IsolatedAsyncioTestCase):
"""Drain protection: when a free slot exists but max_models blocks an unrepresented
model, the backend is reserved so the warm model cannot refill it."""
def _make(self, backends, **kw):
cfg = ProxyConfig(backends=[BackendConfig(url=b.url) for b in backends])
registry = BackendRegistry(cfg)
for state in backends:
registry._states[state.url] = state
registry._rebuild_index()
slots = SlotTracker()
for state in backends:
slots.set_capacity(state.url, 2)
queue = RequestQueue()
scheduler = Scheduler(
queue=queue,
registry=registry,
slot_tracker=slots,
session_store=SessionStore(),
policy=RoundRobinPolicy(),
model_affinity_sched_bonus=10,
**kw,
)
return scheduler, queue, slots
# ------------------------------------------------------------------
# 1 backend — no drain (warm model has no alternative)
# ------------------------------------------------------------------
async def test_single_backend_no_drain_reservation(self):
"""With 1 backend, drain is not reserved: the active model must still progress."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, slots = self._make([b1])
slots.set_max_models("http://b1", 1)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
e_m2 = _entry(model_id="m2")
e_m1 = _entry(model_id="m1")
await queue.enqueue(e_m2)
await queue.enqueue(e_m1)
await scheduler._dispatch_all()
# m2 can't dispatch (max_models), but m1 is NOT blocked because b1 is
# the only backend and we must not starve m1 completely.
self.assertFalse(e_m2.future.done())
self.assertTrue(e_m1.future.done())
self.assertEqual(e_m1.future.result().url, "http://b1")
async def test_single_backend_second_model_starts_after_full_drain(self):
"""With 1 backend, m2 dispatches once m1 fully drains (no drain protection needed)."""
b1 = _make_state("http://b1", models=["m1", "m2"])
scheduler, queue, slots = self._make([b1])
slots.set_max_models("http://b1", 1)
slots.set_model_unload_delay(60.0)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
self.assertFalse(e_m2.future.done())
await slots.release("http://b1", "m1") # m1 fully drains → sticky starts
await scheduler._dispatch_all()
# Diversity pass waives sticky for m2 → dispatches
self.assertTrue(e_m2.future.done())
self.assertEqual(e_m2.future.result().url, "http://b1")
# ------------------------------------------------------------------
# 2 backends — drain reserved on the one with the free slot
# ------------------------------------------------------------------
async def test_two_backends_drain_redirects_warm_model(self):
"""Drain protection: warm model is redirected to the other backend while one drains."""
b1 = _make_state("http://b1", models=["m1", "m2"])
b2 = _make_state("http://b2", models=["m1"]) # m1 only on b2
scheduler, queue, slots = self._make([b1, b2])
slots.set_max_models("http://b1", 1)
# b1: m1 active (1/2), 1 slot free but max_models blocks m2
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
e_m2 = _entry(model_id="m2")
e_m1 = _entry(model_id="m1")
await queue.enqueue(e_m2) # processed first by diversity pass
await queue.enqueue(e_m1)
await scheduler._dispatch_all()
self.assertFalse(e_m2.future.done()) # m2 blocked by max_models
self.assertTrue(e_m1.future.done()) # m1 redirected to b2
self.assertEqual(e_m1.future.result().url, "http://b2")
async def test_two_backends_both_saturated_then_second_model_arrives(self):
"""Screenshot scenario: both backends full with m1; m2 eventually starts once one drains.
Step-by-step:
1. Both at 2/2 → m2 queued, no free slots, nothing dispatches.
2. One slot frees on b1; b2 still 2/2 → drain reserved on b1, but m1 also has
nowhere to go, so nothing dispatches.
3. A slot on b2 also frees → m1 goes to b2 (b1 reserved), not b1.
4. b1's last slot drains → sticky cleared → m2 dispatches to b1.
"""
b1 = _make_state("http://b1", models=["m1", "m2"])
b2 = _make_state("http://b2", models=["m1", "m2"])
scheduler, queue, slots = self._make([b1, b2])
slots.set_max_models("http://b1", 1)
slots.set_max_models("http://b2", 1)
slots.set_model_unload_delay(60.0)
for url in ("http://b1", "http://b2"):
async with asyncio.timeout(1.0):
await slots.acquire(url, "m1")
async with asyncio.timeout(1.0):
await slots.acquire(url, "m1")
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
self.assertFalse(e_m2.future.done())
# Step 2: one slot on b1 frees; b2 still 2/2 → m1 has nowhere to go
await slots.release("http://b1", "m1")
await scheduler._dispatch_all()
self.assertFalse(e_m2.future.done())
# Step 3: slot on b2 frees → new m1 must go to b2, not refill b1
await slots.release("http://b2", "m1")
e_m1 = _entry(model_id="m1")
await queue.enqueue(e_m1)
await scheduler._dispatch_all()
self.assertFalse(e_m2.future.done())
self.assertTrue(e_m1.future.done())
self.assertEqual(e_m1.future.result().url, "http://b2")
# Step 4: b1's last m1 slot drains → sticky cleared → m2 dispatches
await slots.release("http://b1", "m1")
await scheduler._dispatch_all()
self.assertTrue(e_m2.future.done())
self.assertEqual(e_m2.future.result().url, "http://b1")
async def test_drain_only_reserves_one_backend(self):
"""Only the backend with the most free slots is reserved; the other stays open."""
b1 = _make_state("http://b1", models=["m1", "m2"])
b2 = _make_state("http://b2", models=["m1", "m2"])
scheduler, queue, slots = self._make([b1, b2])
slots.set_capacity("http://b1", 3) # b1 has more free slots
slots.set_capacity("http://b2", 2)
slots.set_max_models("http://b1", 1)
slots.set_max_models("http://b2", 1)
# b1: 1/3 with m1 (2 free); b2: 1/2 with m1 (1 free)
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
async with asyncio.timeout(1.0):
await slots.acquire("http://b2", "m1")
entry = _entry(model_id="m1")
await queue.enqueue(entry)
e_m2 = _entry(model_id="m2")
e_m1 = _entry(model_id="m1")
await queue.enqueue(e_m2)
await queue.enqueue(e_m1)
await scheduler._dispatch_all()
self.assertTrue(entry.future.done())
# Affinity pass must not force the request onto the warm backend (b2).
# Round-robin picks b1 first (b1 is index 0 in the registry), which is
# correct: b1 is idle and should absorb the load.
self.assertEqual(entry.future.result().url, "http://b1")
self.assertFalse(e_m2.future.done())
# b1 reserved (most free), b2 not → m1 goes to b2
self.assertTrue(e_m1.future.done())
self.assertEqual(e_m1.future.result().url, "http://b2")
async def test_skip_count_caps_reordering(self):
"""Once skip_count reaches max_queue_skip the entry freezes at head."""
b1 = _make_state("http://b1", models=["m1"])
scheduler, queue, _, slots, _ = self._make_scheduler([b1], max_queue_skip=2)
slots.set_capacity("http://b1", 4)
# ------------------------------------------------------------------
# 3 backends
# ------------------------------------------------------------------
async def test_three_backends_drain_redirects_warm_model(self):
"""With 3 backends all full with m1, drain protection redirects m1 away from the
reserved backend when a slot frees, letting that backend drain for m2."""
b1 = _make_state("http://b1", models=["m1", "m2"])
b2 = _make_state("http://b2", models=["m1", "m2"])
b3 = _make_state("http://b3", models=["m1", "m2"])
scheduler, queue, slots = self._make([b1, b2, b3])
for url in ("http://b1", "http://b2", "http://b3"):
slots.set_max_models(url, 1)
# b1 full at 2/2; b2 and b3 each hold 1/2 with m1 (one free slot each)
for url in ("http://b1", "http://b2", "http://b3"):
async with asyncio.timeout(1.0):
await slots.acquire(url, "m1")
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1") # b1 → 2/2
# m2 arrives — all backends have m1 active (max_models=1 blocks m2 everywhere)
e_m2 = _entry(model_id="m2")
await queue.enqueue(e_m2)
await scheduler._dispatch_all()
self.assertFalse(e_m2.future.done())
# One slot on b1 frees → b1 has 1/2 with m1 active (max_models still blocks m2)
await slots.release("http://b1", "m1")
e_m1_new = _entry(model_id="m1")
await queue.enqueue(e_m1_new)
await scheduler._dispatch_all()
# m2 still blocked (m1 still active in b1's last slot)
self.assertFalse(e_m2.future.done())
# m1 is redirected to b2 or b3 (not b1, which is reserved for draining)
self.assertTrue(e_m1_new.future.done())
self.assertNotEqual(e_m1_new.future.result().url, "http://b1")
# ------------------------------------------------------------------
# num_parallel > 1 (max_models not set / = None)
# ------------------------------------------------------------------
async def test_no_max_models_no_drain_reservation(self):
"""When max_models is not set, is_blocked_by_max_models is always False: no draining."""
b1 = _make_state("http://b1", models=["m1", "m2"])
b2 = _make_state("http://b2", models=["m1"])
scheduler, queue, slots = self._make([b1, b2])
# No set_max_models call → max_models=None → can_accept never fails on model count
# b1 has m1 active
async with asyncio.timeout(1.0):
await slots.acquire("http://b1", "m1")
e_other = _entry(model_id="m2")
e_other.skip_count = 2 # already at limit — must not be bypassed
e_m2 = _entry(model_id="m2")
e_m1 = _entry(model_id="m1")
await queue.enqueue(e_other)
await queue.enqueue(e_m2)
await queue.enqueue(e_m1)
await scheduler._dispatch_all()
# Affinity pass stops at e_other (skip_count >= max_queue_skip),
# so e_m1 is NOT promoted via affinity. Both get a chance in FIFO pass.
# e_other (m2) has no backend → stays. e_m1 gets dispatched in FIFO pass.
self.assertTrue(e_m1.future.done())
# skip_count must NOT increase further (entry was frozen)
self.assertEqual(e_other.skip_count, 2)
# Without max_models, m2 can take b1's free slot directly (diversity pass)
self.assertTrue(e_m2.future.done())
self.assertEqual(e_m2.future.result().url, "http://b1")
# ------------------------------------------------------------------
# Many-same-model then second model: sustained load scenario
# ------------------------------------------------------------------
async def test_sustained_m1_load_then_m2_eventually_served(self):
"""Many m1 requests fill both backends; m2 requests eventually get served once one
backend drains, even with a continuous stream of new m1 requests arriving."""
b1 = _make_state("http://b1", models=["m1", "m2"])
b2 = _make_state("http://b2", models=["m1", "m2"])
scheduler, queue, slots = self._make([b1, b2])
slots.set_max_models("http://b1", 1)
slots.set_max_models("http://b2", 1)
slots.set_model_unload_delay(60.0)
# Phase 1: saturate both backends with m1
for url in ("http://b1", "http://b2"):
async with asyncio.timeout(1.0):
await slots.acquire(url, "m1")
async with asyncio.timeout(1.0):
await slots.acquire(url, "m1")
# Phase 2: m2 requests arrive — both backends full, nothing dispatches
m2_entries = [_entry(model_id="m2") for _ in range(3)]
for e in m2_entries:
await queue.enqueue(e)
await scheduler._dispatch_all()
self.assertEqual(sum(1 for e in m2_entries if e.future.done()), 0)
# Phase 3: one slot on each backend frees; new m1 arrives
await slots.release("http://b1", "m1")
await slots.release("http://b2", "m1")
# b1: 1/2 with m1 (1 free, max_models blocks m2 → drain reserved)
# b2: 1/2 with m1 (1 free, m1 can still go here)
new_m1 = _entry(model_id="m1")
await queue.enqueue(new_m1)
await scheduler._dispatch_all()
# Drain protection: b1 reserved for m2, m1 redirected to b2
self.assertEqual(sum(1 for e in m2_entries if e.future.done()), 0)
self.assertTrue(new_m1.future.done())
self.assertEqual(new_m1.future.result().url, "http://b2")
# Phase 4: b1's last m1 slot finishes → b1 fully drains → m2 dispatches
await slots.release("http://b1", "m1")
await scheduler._dispatch_all()
self.assertGreater(sum(1 for e in m2_entries if e.future.done()), 0)
dispatched_url = next(e.future.result().url for e in m2_entries if e.future.done())
self.assertEqual(dispatched_url, "http://b1")
if __name__ == "__main__":

View File

@@ -319,5 +319,81 @@ class TestSlotTracker(unittest.IsolatedAsyncioTestCase):
self.assertEqual(tracker.global_model_usage("bigmodel"), (0, 1))
# ------------------------------------------------------------------
# waive_sticky_if_idle tests
# ------------------------------------------------------------------
async def test_waive_sticky_clears_window_for_unrepresented_model(self):
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
tracker.set_model_unload_delay(60.0)
await tracker.acquire("http://b", "model-a")
await tracker.release("http://b", "model-a")
self.assertFalse(tracker.can_accept("http://b", "model-b"))
waived = tracker.waive_sticky_if_idle("http://b", "model-b")
self.assertTrue(waived)
self.assertTrue(tracker.can_accept("http://b", "model-b"))
async def test_waive_sticky_noop_when_no_free_slot(self):
tracker = SlotTracker()
tracker.set_capacity("http://b", 1)
tracker.set_model_unload_delay(60.0)
await tracker.acquire("http://b", "model-a")
await tracker.release("http://b", "model-a")
# Slot is now free but let's fill it first
await tracker.acquire("http://b", "model-a")
# sticky is cleared when re-acquired; set up sticky again manually is tricky,
# so use a fresh tracker with the slot held before release
tracker2 = SlotTracker()
tracker2.set_capacity("http://b", 1)
tracker2.set_model_unload_delay(60.0)
await tracker2.acquire("http://b", "model-a")
await tracker2.release("http://b", "model-a")
# Now exhaust the slot
await tracker2.acquire("http://b", "model-a")
waived = tracker2.waive_sticky_if_idle("http://b", "model-b")
self.assertFalse(waived)
async def test_waive_sticky_noop_for_same_model(self):
"""waive_sticky_if_idle must not clear the window for the sticky model itself."""
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
tracker.set_model_unload_delay(60.0)
await tracker.acquire("http://b", "model-a")
await tracker.release("http://b", "model-a")
waived = tracker.waive_sticky_if_idle("http://b", "model-a")
self.assertFalse(waived)
self.assertTrue(tracker.can_accept("http://b", "model-a"))
def test_waive_sticky_noop_when_max_models_also_blocks(self):
"""Do not waive if max_models would still block the requesting model.
The sticky window is only set when active_models empties, so this scenario
can only be created by direct state manipulation (not through normal acquire/release).
The guard is still present in waive_sticky_if_idle for defensive correctness.
"""
import time as _time
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
tracker.set_max_models("http://b", 1)
state = tracker._ensure("http://b")
# Manually inject: sticky=model-a, active_models={model-c: 1}
state.sticky_model = "model-a"
state.sticky_until = _time.monotonic() + 60.0
state.active_models["model-c"] = 1
state.acquired = 1
waived = tracker.waive_sticky_if_idle("http://b", "model-b")
self.assertFalse(waived)
def test_waive_sticky_noop_when_no_active_window(self):
"""Returns False and has no effect when there is no sticky window."""
tracker = SlotTracker()
tracker.set_capacity("http://b", 2)
waived = tracker.waive_sticky_if_idle("http://b", "model-b")
self.assertFalse(waived)
if __name__ == "__main__":
unittest.main()