feat: model unload
This commit is contained in:
@@ -89,6 +89,7 @@ Configuration is a JSON file. All fields also accept environment variable overri
|
||||
| `default_slot_capacity` | `1` | Initial slot count per backend used before the first `/slots` poll completes |
|
||||
| `default_max_models` | `null` | Maximum concurrent models per backend (null = unlimited). Applied to backends that do not set their own `max_models`. |
|
||||
| `max_queue_skip` | `0` | How many times a queued request may be bypassed by a model-affinity promotion before it is frozen at head-of-line. `0` disables reordering. |
|
||||
| `model_unload_delay` | `3.0` | Seconds a backend stays sticky to its last model after all slots drain. Prevents unnecessary model swaps for follow-up requests (title generation, suggestions) that arrive shortly after the main response. `0` disables. |
|
||||
| `model_limits` | `{}` | Per-model global concurrency cap across all backends (e.g. `{"my-large-model": 1}`). Use for models too large to run simultaneously due to RAM constraints. |
|
||||
|
||||
### Per-backend fields
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
"default_slot_capacity": 1,
|
||||
"default_max_models": 1,
|
||||
"max_queue_skip": 4,
|
||||
"model_unload_delay": 3.0,
|
||||
"model_limits": {
|
||||
"my-very-large-model": 1
|
||||
},
|
||||
|
||||
@@ -44,6 +44,10 @@ class ProxyConfig(BaseSettings):
|
||||
# positions; each bypassed entry accumulates a skip count and is immune to
|
||||
# further skipping once it reaches N.
|
||||
max_queue_skip: int = 0
|
||||
# Seconds to keep a backend sticky to its last model after all slots drain.
|
||||
# Prevents unnecessary model swaps for follow-up requests (e.g. title/suggestion
|
||||
# generation) that arrive shortly after the main response. 0 = disabled.
|
||||
model_unload_delay: float = 3.0
|
||||
# Per-model global concurrency cap across all backends.
|
||||
# Use for very large models that cannot run concurrently due to RAM constraints.
|
||||
model_limits: dict[str, int] = {}
|
||||
|
||||
@@ -96,7 +96,9 @@ def _init_slot_tracker(
|
||||
registry: BackendRegistry,
|
||||
slot_tracker: SlotTracker,
|
||||
default_max_models: int | None,
|
||||
model_unload_delay: float,
|
||||
) -> None:
|
||||
slot_tracker.set_model_unload_delay(model_unload_delay)
|
||||
for state in registry.get_all_states():
|
||||
slot_tracker.set_capacity(state.url, state.slot_capacity)
|
||||
effective_max = (
|
||||
@@ -363,7 +365,7 @@ def create_app(config: ProxyConfig) -> FastAPI:
|
||||
timeout=aiohttp.ClientTimeout(total=300),
|
||||
connector=aiohttp.TCPConnector(ssl=False),
|
||||
)
|
||||
_init_slot_tracker(registry, slot_tracker, config.default_max_models)
|
||||
_init_slot_tracker(registry, slot_tracker, config.default_max_models, config.model_unload_delay)
|
||||
_init_global_model_limits(config, slot_tracker)
|
||||
registry.start()
|
||||
scheduler.start()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@@ -11,6 +12,10 @@ class _SlotState:
|
||||
acquired: int = 0
|
||||
active_models: dict[str, int] = field(default_factory=dict)
|
||||
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
|
||||
# Warm-hold window: after the last slot of a model drains, block other models
|
||||
# for a brief period so follow-up requests (titles, suggestions) hit the same backend.
|
||||
sticky_model: str | None = None
|
||||
sticky_until: float = 0.0
|
||||
|
||||
|
||||
class SlotTracker:
|
||||
@@ -24,6 +29,11 @@ class SlotTracker:
|
||||
self._slots: dict[str, _SlotState] = {}
|
||||
self._global_limits: dict[str, int] = {}
|
||||
self._global_counts: dict[str, int] = {}
|
||||
self._model_unload_delay: float = 0.0
|
||||
|
||||
def set_model_unload_delay(self, delay: float) -> None:
|
||||
"""Seconds to keep a backend sticky to its last model after all slots drain."""
|
||||
self._model_unload_delay = max(0.0, delay)
|
||||
|
||||
def _ensure(self, url: str, capacity: int = 1) -> _SlotState:
|
||||
if url not in self._slots:
|
||||
@@ -50,6 +60,9 @@ class SlotTracker:
|
||||
def _can_acquire(self, state: _SlotState, model_id: str) -> bool:
|
||||
if state.acquired >= state.capacity:
|
||||
return False
|
||||
if model_id and state.sticky_model and time.monotonic() < state.sticky_until:
|
||||
if model_id != state.sticky_model:
|
||||
return False
|
||||
if model_id and model_id not in state.active_models:
|
||||
if state.max_models is not None and len(state.active_models) >= state.max_models:
|
||||
return False
|
||||
@@ -127,6 +140,9 @@ class SlotTracker:
|
||||
count = state.active_models[model_id] - 1
|
||||
if count <= 0:
|
||||
del state.active_models[model_id]
|
||||
if self._model_unload_delay > 0 and not state.active_models:
|
||||
state.sticky_model = model_id
|
||||
state.sticky_until = time.monotonic() + self._model_unload_delay
|
||||
else:
|
||||
state.active_models[model_id] = count
|
||||
if model_id and model_id in self._global_limits:
|
||||
@@ -146,4 +162,6 @@ class SlotTracker:
|
||||
)
|
||||
state.acquired = 0
|
||||
state.active_models.clear()
|
||||
state.sticky_model = None
|
||||
state.sticky_until = 0.0
|
||||
state.condition.notify_all()
|
||||
|
||||
@@ -248,6 +248,65 @@ class TestSlotTracker(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(tracker.global_model_usage("bigmodel"), (1, 2))
|
||||
self.assertIsNone(tracker.global_model_usage("othermodel"))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Warm-hold / model_unload_delay tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def test_sticky_window_blocks_other_model(self):
|
||||
tracker = SlotTracker()
|
||||
tracker.set_capacity("http://b", 2)
|
||||
tracker.set_model_unload_delay(60.0)
|
||||
await tracker.acquire("http://b", "model-a")
|
||||
await tracker.release("http://b", "model-a")
|
||||
# Window active: model-b should be rejected
|
||||
self.assertFalse(tracker.can_accept("http://b", "model-b"))
|
||||
|
||||
async def test_sticky_window_allows_same_model(self):
|
||||
tracker = SlotTracker()
|
||||
tracker.set_capacity("http://b", 2)
|
||||
tracker.set_model_unload_delay(60.0)
|
||||
await tracker.acquire("http://b", "model-a")
|
||||
await tracker.release("http://b", "model-a")
|
||||
self.assertTrue(tracker.can_accept("http://b", "model-a"))
|
||||
|
||||
async def test_sticky_window_expires(self):
|
||||
tracker = SlotTracker()
|
||||
tracker.set_capacity("http://b", 2)
|
||||
tracker.set_model_unload_delay(0.05)
|
||||
await tracker.acquire("http://b", "model-a")
|
||||
await tracker.release("http://b", "model-a")
|
||||
self.assertFalse(tracker.can_accept("http://b", "model-b"))
|
||||
await asyncio.sleep(0.1)
|
||||
self.assertTrue(tracker.can_accept("http://b", "model-b"))
|
||||
|
||||
async def test_sticky_window_not_started_when_delay_zero(self):
|
||||
tracker = SlotTracker()
|
||||
tracker.set_capacity("http://b", 2)
|
||||
tracker.set_model_unload_delay(0.0)
|
||||
await tracker.acquire("http://b", "model-a")
|
||||
await tracker.release("http://b", "model-a")
|
||||
self.assertTrue(tracker.can_accept("http://b", "model-b"))
|
||||
|
||||
async def test_sticky_window_not_started_while_slots_remain(self):
|
||||
"""Window must not start until ALL slots for the model drain."""
|
||||
tracker = SlotTracker()
|
||||
tracker.set_capacity("http://b", 4)
|
||||
tracker.set_model_unload_delay(60.0)
|
||||
await tracker.acquire("http://b", "model-a")
|
||||
await tracker.acquire("http://b", "model-a")
|
||||
await tracker.release("http://b", "model-a") # one slot still held
|
||||
self.assertTrue(tracker.can_accept("http://b", "model-b"))
|
||||
|
||||
async def test_reset_acquired_clears_sticky_state(self):
|
||||
tracker = SlotTracker()
|
||||
tracker.set_capacity("http://b", 2)
|
||||
tracker.set_model_unload_delay(60.0)
|
||||
await tracker.acquire("http://b", "model-a")
|
||||
await tracker.release("http://b", "model-a")
|
||||
self.assertFalse(tracker.can_accept("http://b", "model-b"))
|
||||
await tracker.reset_acquired("http://b")
|
||||
self.assertTrue(tracker.can_accept("http://b", "model-b"))
|
||||
|
||||
async def test_reset_acquired_updates_global_counts(self):
|
||||
tracker = SlotTracker()
|
||||
tracker.set_capacity("http://b1", 4)
|
||||
|
||||
Reference in New Issue
Block a user