feat: model unload

This commit is contained in:
2026-05-18 00:34:27 +02:00
parent bcebaf0e93
commit 13fb341354
6 changed files with 86 additions and 1 deletions

View File

@@ -89,6 +89,7 @@ Configuration is a JSON file. All fields also accept environment variable overri
| `default_slot_capacity` | `1` | Initial slot count per backend used before the first `/slots` poll completes |
| `default_max_models` | `null` | Maximum concurrent models per backend (null = unlimited). Applied to backends that do not set their own `max_models`. |
| `max_queue_skip` | `0` | How many times a queued request may be bypassed by a model-affinity promotion before it is frozen at head-of-line. `0` disables reordering. |
| `model_unload_delay` | `3.0` | Seconds a backend stays sticky to its last model after all of its slots drain. While the window is open, the backend does not accept requests for other models. This prevents unnecessary model swaps for follow-up requests (title generation, suggestions) that arrive shortly after the main response. `0` (or any negative value) disables it. |
| `model_limits` | `{}` | Per-model global concurrency cap across all backends (e.g. `{"my-large-model": 1}`). Use for models too large to run simultaneously due to RAM constraints. |
### Per-backend fields

View File

@@ -8,6 +8,7 @@
"default_slot_capacity": 1,
"default_max_models": 1,
"max_queue_skip": 4,
"model_unload_delay": 3.0,
"model_limits": {
"my-very-large-model": 1
},

View File

@@ -44,6 +44,10 @@ class ProxyConfig(BaseSettings):
# positions; each bypassed entry accumulates a skip count and is immune to
# further skipping once it reaches N.
max_queue_skip: int = 0
# Seconds to keep a backend sticky to its last model after all slots drain.
# Prevents unnecessary model swaps for follow-up requests (e.g. title/suggestion
# generation) that arrive shortly after the main response. 0 = disabled.
model_unload_delay: float = 3.0
# Per-model global concurrency cap across all backends.
# Use for very large models that cannot run concurrently due to RAM constraints.
model_limits: dict[str, int] = {}

View File

@@ -96,7 +96,9 @@ def _init_slot_tracker(
registry: BackendRegistry,
slot_tracker: SlotTracker,
default_max_models: int | None,
model_unload_delay: float,
) -> None:
slot_tracker.set_model_unload_delay(model_unload_delay)
for state in registry.get_all_states():
slot_tracker.set_capacity(state.url, state.slot_capacity)
effective_max = (
@@ -363,7 +365,7 @@ def create_app(config: ProxyConfig) -> FastAPI:
timeout=aiohttp.ClientTimeout(total=300),
connector=aiohttp.TCPConnector(ssl=False),
)
_init_slot_tracker(registry, slot_tracker, config.default_max_models)
_init_slot_tracker(registry, slot_tracker, config.default_max_models, config.model_unload_delay)
_init_global_model_limits(config, slot_tracker)
registry.start()
scheduler.start()

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import time
from dataclasses import dataclass, field
@@ -11,6 +12,10 @@ class _SlotState:
acquired: int = 0
active_models: dict[str, int] = field(default_factory=dict)
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
# Warm-hold window: once the backend's last active model has fully drained, requests for
# other models are refused for a brief period so follow-up requests for the same model
# (titles, suggestions) do not trigger a model swap.
sticky_model: str | None = None
sticky_until: float = 0.0
class SlotTracker:
@@ -24,6 +29,11 @@ class SlotTracker:
self._slots: dict[str, _SlotState] = {}
self._global_limits: dict[str, int] = {}
self._global_counts: dict[str, int] = {}
self._model_unload_delay: float = 0.0
def set_model_unload_delay(self, delay: float) -> None:
    """Set how many seconds a backend stays sticky to its last model once its slots drain.

    Values below zero are stored as ``0.0``.
    """
    clamped = delay if delay > 0.0 else 0.0
    self._model_unload_delay = clamped
def _ensure(self, url: str, capacity: int = 1) -> _SlotState:
if url not in self._slots:
@@ -50,6 +60,9 @@ class SlotTracker:
def _can_acquire(self, state: _SlotState, model_id: str) -> bool:
if state.acquired >= state.capacity:
return False
if model_id and state.sticky_model and time.monotonic() < state.sticky_until:
if model_id != state.sticky_model:
return False
if model_id and model_id not in state.active_models:
if state.max_models is not None and len(state.active_models) >= state.max_models:
return False
@@ -127,6 +140,9 @@ class SlotTracker:
count = state.active_models[model_id] - 1
if count <= 0:
del state.active_models[model_id]
if self._model_unload_delay > 0 and not state.active_models:
state.sticky_model = model_id
state.sticky_until = time.monotonic() + self._model_unload_delay
else:
state.active_models[model_id] = count
if model_id and model_id in self._global_limits:
@@ -146,4 +162,6 @@ class SlotTracker:
)
state.acquired = 0
state.active_models.clear()
state.sticky_model = None
state.sticky_until = 0.0
state.condition.notify_all()

View File

@@ -248,6 +248,65 @@ class TestSlotTracker(unittest.IsolatedAsyncioTestCase):
self.assertEqual(tracker.global_model_usage("bigmodel"), (1, 2))
self.assertIsNone(tracker.global_model_usage("othermodel"))
# ------------------------------------------------------------------
# Warm-hold / model_unload_delay tests
# ------------------------------------------------------------------
async def test_sticky_window_blocks_other_model(self):
    """A different model is refused while the warm-hold window is open."""
    backend = "http://b"
    slots = SlotTracker()
    slots.set_capacity(backend, 2)
    slots.set_model_unload_delay(60.0)

    # Acquire and release a single model-a slot with a 60 s delay configured.
    await slots.acquire(backend, "model-a")
    await slots.release(backend, "model-a")

    # The window is still open, so model-b must be turned away.
    accepted = slots.can_accept(backend, "model-b")
    self.assertFalse(accepted)
async def test_sticky_window_allows_same_model(self):
    """The model that just drained is still accepted during the warm-hold window."""
    backend = "http://b"
    slots = SlotTracker()
    slots.set_capacity(backend, 2)
    slots.set_model_unload_delay(60.0)

    await slots.acquire(backend, "model-a")
    await slots.release(backend, "model-a")

    accepted = slots.can_accept(backend, "model-a")
    self.assertTrue(accepted)
async def test_sticky_window_expires(self):
    """Another model is refused inside the window and accepted once it has elapsed."""
    backend = "http://b"
    slots = SlotTracker()
    slots.set_capacity(backend, 2)
    slots.set_model_unload_delay(0.05)

    await slots.acquire(backend, "model-a")
    await slots.release(backend, "model-a")

    # Checked right after release, i.e. within the 0.05 s window.
    blocked = slots.can_accept(backend, "model-b")
    self.assertFalse(blocked)

    # Sleep for longer than the configured delay, then check again.
    await asyncio.sleep(0.1)
    accepted = slots.can_accept(backend, "model-b")
    self.assertTrue(accepted)
async def test_sticky_window_not_started_when_delay_zero(self):
    """With a delay of 0.0, another model is accepted immediately after release."""
    backend = "http://b"
    slots = SlotTracker()
    slots.set_capacity(backend, 2)
    slots.set_model_unload_delay(0.0)

    await slots.acquire(backend, "model-a")
    await slots.release(backend, "model-a")

    accepted = slots.can_accept(backend, "model-b")
    self.assertTrue(accepted)
async def test_sticky_window_not_started_while_slots_remain(self):
    """The window must not open while the model still holds a slot."""
    backend = "http://b"
    slots = SlotTracker()
    slots.set_capacity(backend, 4)
    slots.set_model_unload_delay(60.0)

    # Take two model-a slots, then give back only one of them.
    await slots.acquire(backend, "model-a")
    await slots.acquire(backend, "model-a")
    await slots.release(backend, "model-a")

    accepted = slots.can_accept(backend, "model-b")
    self.assertTrue(accepted)
async def test_reset_acquired_clears_sticky_state(self):
    """reset_acquired() ends an open warm-hold window for the backend."""
    backend = "http://b"
    slots = SlotTracker()
    slots.set_capacity(backend, 2)
    slots.set_model_unload_delay(60.0)

    await slots.acquire(backend, "model-a")
    await slots.release(backend, "model-a")

    # Window is open: a different model is refused.
    blocked = slots.can_accept(backend, "model-b")
    self.assertFalse(blocked)

    # After the reset the same request is accepted.
    await slots.reset_acquired(backend)
    accepted = slots.can_accept(backend, "model-b")
    self.assertTrue(accepted)
async def test_reset_acquired_updates_global_counts(self):
tracker = SlotTracker()
tracker.set_capacity("http://b1", 4)