Compare commits
4 Commits
71b3b20a4d
...
5211b2f1a0
| Author | SHA1 | Date | |
|---|---|---|---|
| 5211b2f1a0 | |||
| 13fb341354 | |||
| bcebaf0e93 | |||
| 981272e6ca |
@@ -4,7 +4,8 @@
|
||||
"Bash(python -m pytest)",
|
||||
"Bash(python -m pytest --tb=short -q)",
|
||||
"Bash(python -m pytest tests/test_slot_tracker.py -v)",
|
||||
"Bash(python -m pytest tests/test_scheduler.py -v)"
|
||||
"Bash(python -m pytest tests/test_scheduler.py -v)",
|
||||
"Bash(python -m pytest tests/test_monitor.py::TestMonitorEndpoints::test_monitor_data_structure -v)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,6 +89,7 @@ Configuration is a JSON file. All fields also accept environment variable overri
|
||||
| `default_slot_capacity` | `1` | Initial slot count per backend used before the first `/slots` poll completes |
|
||||
| `default_max_models` | `null` | Maximum concurrent models per backend (null = unlimited). Applied to backends that do not set their own `max_models`. |
|
||||
| `max_queue_skip` | `0` | How many times a queued request may be bypassed by a model-affinity promotion before it is frozen at head-of-line. `0` disables reordering. |
|
||||
| `model_unload_delay` | `3.0` | Seconds a backend stays sticky to its last model after all slots drain. Prevents unnecessary model swaps for follow-up requests (title generation, suggestions) that arrive shortly after the main response. `0` disables. |
|
||||
| `model_limits` | `{}` | Per-model global concurrency cap across all backends (e.g. `{"my-large-model": 1}`). Use for models too large to run simultaneously due to RAM constraints. |
|
||||
|
||||
### Per-backend fields
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
"default_slot_capacity": 1,
|
||||
"default_max_models": 1,
|
||||
"max_queue_skip": 4,
|
||||
"model_unload_delay": 3.0,
|
||||
"model_limits": {
|
||||
"my-very-large-model": 1
|
||||
},
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "llamacpp-ha"
|
||||
version = "0.3.0"
|
||||
version = "0.5.0"
|
||||
description = "Smart load balancer for llama.cpp servers"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
|
||||
@@ -44,6 +44,10 @@ class ProxyConfig(BaseSettings):
|
||||
# positions; each bypassed entry accumulates a skip count and is immune to
|
||||
# further skipping once it reaches N.
|
||||
max_queue_skip: int = 0
|
||||
# Seconds to keep a backend sticky to its last model after all slots drain.
|
||||
# Prevents unnecessary model swaps for follow-up requests (e.g. title/suggestion
|
||||
# generation) that arrive shortly after the main response. 0 = disabled.
|
||||
model_unload_delay: float = 3.0
|
||||
# Per-model global concurrency cap across all backends.
|
||||
# Use for very large models that cannot run concurrently due to RAM constraints.
|
||||
model_limits: dict[str, int] = {}
|
||||
|
||||
@@ -16,6 +16,14 @@ from .slot_tracker import SlotTracker
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_STREAM_CHUNK = 8192
|
||||
_STREAMING_CONTENT_TYPES = (
|
||||
"text/event-stream",
|
||||
"audio/mpeg",
|
||||
"audio/ogg",
|
||||
"audio/wav",
|
||||
"audio/webm",
|
||||
"audio/aac",
|
||||
)
|
||||
_HOP_BY_HOP = frozenset(
|
||||
[
|
||||
"connection",
|
||||
@@ -74,7 +82,7 @@ async def forward_request(
|
||||
resp = await resp_ctx.__aenter__()
|
||||
|
||||
content_type = resp.headers.get("Content-Type", "")
|
||||
is_streaming = "text/event-stream" in content_type
|
||||
is_streaming = any(ct in content_type for ct in _STREAMING_CONTENT_TYPES)
|
||||
|
||||
response_headers = {
|
||||
k: v
|
||||
|
||||
@@ -17,10 +17,39 @@ class ProxyStats:
|
||||
"""Per-app counters and timing. Created by create_app, passed to build_router."""
|
||||
start_time: float = field(default_factory=time.monotonic)
|
||||
total_requests: int = 0
|
||||
session_hits: int = 0
|
||||
session_misses: int = 0
|
||||
new_sessions: int = 0
|
||||
model_requests: dict[str, int] = field(default_factory=dict)
|
||||
model_tokens: dict[str, int] = field(default_factory=dict)
|
||||
backend_requests: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def increment_requests(self) -> None:
    """Add one to the ``total_requests`` counter."""
    self.total_requests = self.total_requests + 1
|
||||
|
||||
def record_model(self, model_id: str, tokens: int | None) -> None:
|
||||
if not model_id:
|
||||
return
|
||||
self.model_requests[model_id] = self.model_requests.get(model_id, 0) + 1
|
||||
if tokens:
|
||||
self.model_tokens[model_id] = self.model_tokens.get(model_id, 0) + tokens
|
||||
|
||||
def record_backend(self, url: str) -> None:
    """Count one request against the backend identified by ``url``."""
    served = self.backend_requests.get(url, 0)
    self.backend_requests[url] = served + 1
|
||||
|
||||
def record_session(self, had_session: bool, preferred_url: str | None, actual_url: str) -> None:
|
||||
if had_session and preferred_url:
|
||||
if actual_url == preferred_url:
|
||||
self.session_hits += 1
|
||||
else:
|
||||
self.session_misses += 1
|
||||
elif not had_session:
|
||||
self.new_sessions += 1
|
||||
|
||||
def session_hit_rate(self) -> int | None:
|
||||
total = self.session_hits + self.session_misses
|
||||
return round(self.session_hits / total * 100) if total else None
|
||||
|
||||
def uptime_str(self) -> str:
|
||||
secs = int(time.monotonic() - self.start_time)
|
||||
h, remainder = divmod(secs, 3600)
|
||||
@@ -48,11 +77,14 @@ _HTML = """<!DOCTYPE html>
|
||||
.badge-dead { background: #2c0f0f; color: #f85149; }
|
||||
.slots { color: #d29922; }
|
||||
.empty { color: #484f58; font-style: italic; }
|
||||
.hit { color: #3fb950; }
|
||||
.miss { color: #f85149; }
|
||||
#status { float: right; font-size: 0.8em; color: #8b949e; }
|
||||
.summary { display: flex; gap: 20px; flex-wrap: wrap; margin: 10px 0 20px; }
|
||||
.stat { background: #161b22; border: 1px solid #30363d; border-radius: 6px; padding: 10px 16px; }
|
||||
.stat-val { font-size: 1.6em; color: #58a6ff; }
|
||||
.stat-label { font-size: 0.75em; color: #8b949e; margin-top: 2px; }
|
||||
.num { text-align: right; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
@@ -65,12 +97,13 @@ _HTML = """<!DOCTYPE html>
|
||||
<div class="stat"><div class="stat-val" id="queue-depth">-</div><div class="stat-label">Queue Depth</div></div>
|
||||
<div class="stat"><div class="stat-val" id="session-count">-</div><div class="stat-label">Active Sessions</div></div>
|
||||
<div class="stat"><div class="stat-val" id="live-count">-</div><div class="stat-label">Live Backends</div></div>
|
||||
<div class="stat"><div class="stat-val" id="hit-rate">-</div><div class="stat-label">Session Hit Rate</div></div>
|
||||
</div>
|
||||
|
||||
<h2>Backends</h2>
|
||||
<table>
|
||||
<thead><tr><th>URL</th><th>Status</th><th>Active Model</th><th>Models</th><th>Slots</th><th>Last Poll</th></tr></thead>
|
||||
<tbody id="backends-body"><tr><td colspan="6" class="empty">Loading...</td></tr></tbody>
|
||||
<thead><tr><th>URL</th><th>Status</th><th>Active Model</th><th>Models</th><th>Slots</th><th class="num">Requests</th><th>Last Poll</th></tr></thead>
|
||||
<tbody id="backends-body"><tr><td colspan="7" class="empty">Loading...</td></tr></tbody>
|
||||
</table>
|
||||
|
||||
<h2>Queue</h2>
|
||||
@@ -79,10 +112,16 @@ _HTML = """<!DOCTYPE html>
|
||||
<tbody id="queue-body"><tr><td colspan="6" class="empty">Queue is empty</td></tr></tbody>
|
||||
</table>
|
||||
|
||||
<h2>Sessions by Model</h2>
|
||||
<h2>Model Stats</h2>
|
||||
<table>
|
||||
<thead><tr><th>Model</th><th>Active Sessions</th></tr></thead>
|
||||
<tbody id="sessions-body"><tr><td colspan="2" class="empty">No active sessions</td></tr></tbody>
|
||||
<thead><tr><th>Model</th><th class="num">Requests</th><th class="num">Est. Tokens</th><th class="num">Active Sessions</th></tr></thead>
|
||||
<tbody id="model-body"><tr><td colspan="4" class="empty">No data yet</td></tr></tbody>
|
||||
</table>
|
||||
|
||||
<h2>Backend Stats</h2>
|
||||
<table>
|
||||
<thead><tr><th>Backend</th><th class="num">Requests</th><th class="num">Share</th><th>Session Affinity</th></tr></thead>
|
||||
<tbody id="backend-stats-body"><tr><td colspan="4" class="empty">No data yet</td></tr></tbody>
|
||||
</table>
|
||||
|
||||
<script>
|
||||
@@ -91,6 +130,10 @@ _HTML = """<!DOCTYPE html>
|
||||
return String(s).replace(/&/g,'&').replace(/</g,'<').replace(/>/g,'>');
|
||||
}
|
||||
|
||||
// Compact number formatter for table cells.
//   below 1000      -> the number as-is
//   thousands       -> one decimal with a 'k' suffix (1500 -> "1.5k")
//   ~a million and up -> one decimal with an 'M' suffix (1234567 -> "1.2M")
// The million tier starts at 999950 because that is where one-decimal
// rounding of n/1000 would otherwise print "1000.0k".
function fmt(n) {
  if (n >= 999950) {
    return (n / 1e6).toFixed(1) + 'M';
  }
  if (n >= 1000) {
    return (n / 1000).toFixed(1) + 'k';
  }
  return String(n);
}
|
||||
|
||||
function render(data) {
|
||||
document.getElementById('uptime').textContent = data.uptime;
|
||||
document.getElementById('total-req').textContent = data.total_requests;
|
||||
@@ -98,9 +141,19 @@ _HTML = """<!DOCTYPE html>
|
||||
document.getElementById('session-count').textContent = data.session_count;
|
||||
document.getElementById('live-count').textContent = data.live_backend_count;
|
||||
|
||||
const hr = data.session_hit_rate;
|
||||
const hrEl = document.getElementById('hit-rate');
|
||||
if (hr == null) {
|
||||
hrEl.textContent = 'N/A';
|
||||
hrEl.className = 'stat-val';
|
||||
} else {
|
||||
hrEl.textContent = hr + '%';
|
||||
hrEl.className = 'stat-val ' + (hr >= 80 ? 'hit' : hr >= 50 ? 'slots' : 'miss');
|
||||
}
|
||||
|
||||
const bBody = document.getElementById('backends-body');
|
||||
if (!data.backends.length) {
|
||||
bBody.innerHTML = '<tr><td colspan="6" class="empty">No backends configured</td></tr>';
|
||||
bBody.innerHTML = '<tr><td colspan="7" class="empty">No backends configured</td></tr>';
|
||||
} else {
|
||||
bBody.innerHTML = data.backends.map(b => {
|
||||
const badge = b.live
|
||||
@@ -112,7 +165,8 @@ _HTML = """<!DOCTYPE html>
|
||||
const models = b.models.length ? esc(b.models.join(', ')) : '<span class="empty">none</span>';
|
||||
const slots = `<span class="slots">${b.slots_acquired}/${b.slots_total}</span>`;
|
||||
const age = b.last_poll_age == null ? '<span class="empty">never</span>' : esc(b.last_poll_age.toFixed(1)) + 's';
|
||||
return `<tr><td>${esc(b.url)}</td><td>${badge}</td><td>${active}</td><td>${models}</td><td>${slots}</td><td>${age}</td></tr>`;
|
||||
const reqs = b.requests > 0 ? fmt(b.requests) : '<span class="empty">0</span>';
|
||||
return `<tr><td>${esc(b.url)}</td><td>${badge}</td><td>${active}</td><td>${models}</td><td>${slots}</td><td class="num">${reqs}</td><td>${age}</td></tr>`;
|
||||
}).join('');
|
||||
}
|
||||
|
||||
@@ -127,15 +181,36 @@ _HTML = """<!DOCTYPE html>
|
||||
}).join('');
|
||||
}
|
||||
|
||||
const sBody = document.getElementById('sessions-body');
|
||||
const sbm = data.sessions_by_model;
|
||||
const keys = Object.keys(sbm);
|
||||
if (!keys.length) {
|
||||
sBody.innerHTML = '<tr><td colspan="2" class="empty">No active sessions</td></tr>';
|
||||
const mBody = document.getElementById('model-body');
|
||||
const ms = data.model_stats;
|
||||
const mKeys = Object.keys(ms).sort((a,b) => ms[b].requests - ms[a].requests);
|
||||
if (!mKeys.length) {
|
||||
mBody.innerHTML = '<tr><td colspan="4" class="empty">No data yet</td></tr>';
|
||||
} else {
|
||||
sBody.innerHTML = keys.map(m =>
|
||||
`<tr><td>${esc(m)}</td><td>${esc(sbm[m])}</td></tr>`
|
||||
).join('');
|
||||
mBody.innerHTML = mKeys.map(m => {
|
||||
const s = ms[m];
|
||||
const tok = s.estimated_tokens > 0 ? fmt(s.estimated_tokens) : '<span class="empty">-</span>';
|
||||
const sess = s.active_sessions > 0 ? s.active_sessions : '<span class="empty">0</span>';
|
||||
return `<tr><td>${esc(m)}</td><td class="num">${fmt(s.requests)}</td><td class="num">${tok}</td><td class="num">${sess}</td></tr>`;
|
||||
}).join('');
|
||||
}
|
||||
|
||||
const bsBody = document.getElementById('backend-stats-body');
|
||||
const bs = data.backend_stats;
|
||||
const bsKeys = Object.keys(bs).sort((a,b) => bs[b].requests - bs[a].requests);
|
||||
if (!bsKeys.length || data.total_requests === 0) {
|
||||
bsBody.innerHTML = '<tr><td colspan="4" class="empty">No data yet</td></tr>';
|
||||
} else {
|
||||
bsBody.innerHTML = bsKeys.map(url => {
|
||||
const s = bs[url];
|
||||
const share = data.total_requests > 0
|
||||
? Math.round(s.requests / data.total_requests * 100) + '%'
|
||||
: '<span class="empty">-</span>';
|
||||
const affinity = s.session_hits + s.session_misses > 0
|
||||
? Math.round(s.session_hits / (s.session_hits + s.session_misses) * 100) + '% hit'
|
||||
: '<span class="empty">-</span>';
|
||||
return `<tr><td>${esc(url)}</td><td class="num">${fmt(s.requests)}</td><td class="num">${share}</td><td>${affinity}</td></tr>`;
|
||||
}).join('');
|
||||
}
|
||||
|
||||
document.getElementById('status').textContent = 'updated ' + new Date().toLocaleTimeString();
|
||||
@@ -187,6 +262,7 @@ def build_router(
|
||||
"models": list(state.models),
|
||||
"slots_acquired": acquired,
|
||||
"slots_total": total,
|
||||
"requests": stats.backend_requests.get(state.url, 0),
|
||||
"last_poll_age": None if age == float("inf") else round(age, 1),
|
||||
}
|
||||
)
|
||||
@@ -196,6 +272,27 @@ def build_router(
|
||||
sessions_by_model = await session_store.count_by_model()
|
||||
live_count = sum(1 for s in states if s.live)
|
||||
|
||||
# Merge per-model request stats with active session counts.
|
||||
all_models = set(stats.model_requests) | set(sessions_by_model)
|
||||
model_stats = {
|
||||
m: {
|
||||
"requests": stats.model_requests.get(m, 0),
|
||||
"estimated_tokens": stats.model_tokens.get(m, 0),
|
||||
"active_sessions": sessions_by_model.get(m, 0),
|
||||
}
|
||||
for m in all_models
|
||||
}
|
||||
|
||||
# Per-backend cumulative stats with session affinity breakdown.
|
||||
backend_stats = {
|
||||
url: {
|
||||
"requests": count,
|
||||
"session_hits": 0,
|
||||
"session_misses": 0,
|
||||
}
|
||||
for url, count in stats.backend_requests.items()
|
||||
}
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"uptime": stats.uptime_str(),
|
||||
@@ -203,9 +300,13 @@ def build_router(
|
||||
"queue_depth": len(queue_snapshot),
|
||||
"session_count": session_count,
|
||||
"live_backend_count": live_count,
|
||||
"session_hits": stats.session_hits,
|
||||
"session_misses": stats.session_misses,
|
||||
"session_hit_rate": stats.session_hit_rate(),
|
||||
"backends": backends_data,
|
||||
"queue": queue_snapshot,
|
||||
"sessions_by_model": sessions_by_model,
|
||||
"model_stats": model_stats,
|
||||
"backend_stats": backend_stats,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -96,7 +96,9 @@ def _init_slot_tracker(
|
||||
registry: BackendRegistry,
|
||||
slot_tracker: SlotTracker,
|
||||
default_max_models: int | None,
|
||||
model_unload_delay: float,
|
||||
) -> None:
|
||||
slot_tracker.set_model_unload_delay(model_unload_delay)
|
||||
for state in registry.get_all_states():
|
||||
slot_tracker.set_capacity(state.url, state.slot_capacity)
|
||||
effective_max = (
|
||||
@@ -279,6 +281,10 @@ async def _inference_endpoint(
|
||||
if not incoming_session_id:
|
||||
await _recover_session_affinity(session_id, body.get("messages") or [], session_store)
|
||||
|
||||
preferred_url: str | None = None
|
||||
if incoming_session_id:
|
||||
preferred_url = await session_store.get_preferred_backend(session_id)
|
||||
|
||||
result = await _dispatch_entry(
|
||||
request_queue, stats, config, slot_tracker, scheduler, model_id, session_id, body
|
||||
)
|
||||
@@ -311,6 +317,10 @@ async def _inference_endpoint(
|
||||
tried.add(failover.url)
|
||||
backend = failover
|
||||
|
||||
stats.record_model(model_id, _estimate_tokens(body))
|
||||
stats.record_backend(backend.url)
|
||||
stats.record_session(bool(incoming_session_id), preferred_url, backend.url)
|
||||
|
||||
if model_id:
|
||||
messages = body.get("messages", [])
|
||||
await session_store.update(
|
||||
@@ -355,7 +365,7 @@ def create_app(config: ProxyConfig) -> FastAPI:
|
||||
timeout=aiohttp.ClientTimeout(total=300),
|
||||
connector=aiohttp.TCPConnector(ssl=False),
|
||||
)
|
||||
_init_slot_tracker(registry, slot_tracker, config.default_max_models)
|
||||
_init_slot_tracker(registry, slot_tracker, config.default_max_models, config.model_unload_delay)
|
||||
_init_global_model_limits(config, slot_tracker)
|
||||
registry.start()
|
||||
scheduler.start()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@@ -11,6 +12,10 @@ class _SlotState:
|
||||
acquired: int = 0
|
||||
active_models: dict[str, int] = field(default_factory=dict)
|
||||
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
|
||||
# Warm-hold window: after the last slot of a model drains, block other models
|
||||
# for a brief period so follow-up requests (titles, suggestions) hit the same backend.
|
||||
sticky_model: str | None = None
|
||||
sticky_until: float = 0.0
|
||||
|
||||
|
||||
class SlotTracker:
|
||||
@@ -24,6 +29,11 @@ class SlotTracker:
|
||||
self._slots: dict[str, _SlotState] = {}
|
||||
self._global_limits: dict[str, int] = {}
|
||||
self._global_counts: dict[str, int] = {}
|
||||
self._model_unload_delay: float = 0.0
|
||||
|
||||
def set_model_unload_delay(self, delay: float) -> None:
    """Set how many seconds a backend stays sticky to its last model after all slots drain.

    Values that are not greater than zero are stored as ``0.0``.
    """
    self._model_unload_delay = delay if delay > 0.0 else 0.0
|
||||
|
||||
def _ensure(self, url: str, capacity: int = 1) -> _SlotState:
|
||||
if url not in self._slots:
|
||||
@@ -50,6 +60,9 @@ class SlotTracker:
|
||||
def _can_acquire(self, state: _SlotState, model_id: str) -> bool:
|
||||
if state.acquired >= state.capacity:
|
||||
return False
|
||||
if model_id and state.sticky_model and time.monotonic() < state.sticky_until:
|
||||
if model_id != state.sticky_model:
|
||||
return False
|
||||
if model_id and model_id not in state.active_models:
|
||||
if state.max_models is not None and len(state.active_models) >= state.max_models:
|
||||
return False
|
||||
@@ -127,6 +140,9 @@ class SlotTracker:
|
||||
count = state.active_models[model_id] - 1
|
||||
if count <= 0:
|
||||
del state.active_models[model_id]
|
||||
if self._model_unload_delay > 0 and not state.active_models:
|
||||
state.sticky_model = model_id
|
||||
state.sticky_until = time.monotonic() + self._model_unload_delay
|
||||
else:
|
||||
state.active_models[model_id] = count
|
||||
if model_id and model_id in self._global_limits:
|
||||
@@ -146,4 +162,6 @@ class SlotTracker:
|
||||
)
|
||||
state.acquired = 0
|
||||
state.active_models.clear()
|
||||
state.sticky_model = None
|
||||
state.sticky_until = 0.0
|
||||
state.condition.notify_all()
|
||||
|
||||
@@ -56,7 +56,9 @@ class TestMonitorEndpoints(unittest.TestCase):
|
||||
slot_tracker.set_capacity("http://b1", 4)
|
||||
data = TestClient(app).get("/monitor/data").json()
|
||||
for key in ("uptime", "total_requests", "queue_depth", "session_count",
|
||||
"live_backend_count", "backends", "queue", "sessions_by_model"):
|
||||
"live_backend_count", "backends", "queue",
|
||||
"model_stats", "backend_stats",
|
||||
"session_hits", "session_misses", "session_hit_rate"):
|
||||
self.assertIn(key, data)
|
||||
|
||||
def test_monitor_data_backend_fields(self):
|
||||
|
||||
@@ -248,6 +248,65 @@ class TestSlotTracker(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(tracker.global_model_usage("bigmodel"), (1, 2))
|
||||
self.assertIsNone(tracker.global_model_usage("othermodel"))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Warm-hold / model_unload_delay tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def test_sticky_window_blocks_other_model(self):
    """A different model is refused while the warm-hold window is open."""
    backend = "http://b"
    slots = SlotTracker()
    slots.set_capacity(backend, 2)
    slots.set_model_unload_delay(60.0)
    # Take and return the only model-a slot so the backend fully drains.
    await slots.acquire(backend, "model-a")
    await slots.release(backend, "model-a")
    # The 60 s window is still open, so model-b must be turned away.
    self.assertFalse(slots.can_accept(backend, "model-b"))
|
||||
|
||||
async def test_sticky_window_allows_same_model(self):
    """The model that just drained is still accepted during the window."""
    backend = "http://b"
    slots = SlotTracker()
    slots.set_capacity(backend, 2)
    slots.set_model_unload_delay(60.0)
    await slots.acquire(backend, "model-a")
    await slots.release(backend, "model-a")
    # Same model as the one that just released: must not be blocked.
    self.assertTrue(slots.can_accept(backend, "model-a"))
|
||||
|
||||
async def test_sticky_window_expires(self):
    """Another model is blocked inside the window and accepted after it lapses."""
    backend = "http://b"
    slots = SlotTracker()
    slots.set_capacity(backend, 2)
    slots.set_model_unload_delay(0.05)
    await slots.acquire(backend, "model-a")
    await slots.release(backend, "model-a")
    # Immediately after release the 50 ms window is still open.
    self.assertFalse(slots.can_accept(backend, "model-b"))
    # Sleep past the window, then the other model must be admitted.
    await asyncio.sleep(0.1)
    self.assertTrue(slots.can_accept(backend, "model-b"))
|
||||
|
||||
async def test_sticky_window_not_started_when_delay_zero(self):
    """With a delay of 0.0 no window opens, so another model is accepted at once."""
    backend = "http://b"
    slots = SlotTracker()
    slots.set_capacity(backend, 2)
    slots.set_model_unload_delay(0.0)
    await slots.acquire(backend, "model-a")
    await slots.release(backend, "model-a")
    self.assertTrue(slots.can_accept(backend, "model-b"))
|
||||
|
||||
async def test_sticky_window_not_started_while_slots_remain(self):
    """Window must not start until ALL slots for the model drain."""
    backend = "http://b"
    slots = SlotTracker()
    slots.set_capacity(backend, 4)
    slots.set_model_unload_delay(60.0)
    # Hold two model-a slots, then give back only one of them.
    await slots.acquire(backend, "model-a")
    await slots.acquire(backend, "model-a")
    await slots.release(backend, "model-a")  # one slot still held
    # model-a is still active, so no window has opened and model-b is accepted.
    self.assertTrue(slots.can_accept(backend, "model-b"))
|
||||
|
||||
async def test_reset_acquired_clears_sticky_state(self):
    """reset_acquired drops an open window so other models are accepted again."""
    backend = "http://b"
    slots = SlotTracker()
    slots.set_capacity(backend, 2)
    slots.set_model_unload_delay(60.0)
    await slots.acquire(backend, "model-a")
    await slots.release(backend, "model-a")
    # Window is open: model-b is blocked before the reset.
    self.assertFalse(slots.can_accept(backend, "model-b"))
    await slots.reset_acquired(backend)
    # After the reset, model-b is accepted.
    self.assertTrue(slots.can_accept(backend, "model-b"))
|
||||
|
||||
async def test_reset_acquired_updates_global_counts(self):
|
||||
tracker = SlotTracker()
|
||||
tracker.set_capacity("http://b1", 4)
|
||||
|
||||
Reference in New Issue
Block a user