feat: model unload
This commit is contained in:
@@ -89,6 +89,7 @@ Configuration is a JSON file. All fields also accept environment variable overri
|
|||||||
| `default_slot_capacity` | `1` | Initial slot count per backend used before the first `/slots` poll completes |
|
| `default_slot_capacity` | `1` | Initial slot count per backend used before the first `/slots` poll completes |
|
||||||
| `default_max_models` | `null` | Maximum concurrent models per backend (null = unlimited). Applied to backends that do not set their own `max_models`. |
|
| `default_max_models` | `null` | Maximum concurrent models per backend (null = unlimited). Applied to backends that do not set their own `max_models`. |
|
||||||
| `max_queue_skip` | `0` | How many times a queued request may be bypassed by a model-affinity promotion before it is frozen at head-of-line. `0` disables reordering. |
|
| `max_queue_skip` | `0` | How many times a queued request may be bypassed by a model-affinity promotion before it is frozen at head-of-line. `0` disables reordering. |
|
||||||
|
| `model_unload_delay` | `3.0` | Seconds a backend stays reserved for ("sticky to") its last model after all of its slots drain. During this window, requests for any other model are not placed on that backend, which prevents unnecessary model swaps for follow-up requests (title generation, suggestions) that arrive shortly after the main response. `0` disables the window. |
|
||||||
| `model_limits` | `{}` | Per-model global concurrency cap across all backends (e.g. `{"my-large-model": 1}`). Use for models too large to run simultaneously due to RAM constraints. |
|
| `model_limits` | `{}` | Per-model global concurrency cap across all backends (e.g. `{"my-large-model": 1}`). Use for models too large to run simultaneously due to RAM constraints. |
|
||||||
|
|
||||||
### Per-backend fields
|
### Per-backend fields
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
"default_slot_capacity": 1,
|
"default_slot_capacity": 1,
|
||||||
"default_max_models": 1,
|
"default_max_models": 1,
|
||||||
"max_queue_skip": 4,
|
"max_queue_skip": 4,
|
||||||
|
"model_unload_delay": 3.0,
|
||||||
"model_limits": {
|
"model_limits": {
|
||||||
"my-very-large-model": 1
|
"my-very-large-model": 1
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -44,6 +44,10 @@ class ProxyConfig(BaseSettings):
|
|||||||
# positions; each bypassed entry accumulates a skip count and is immune to
|
# positions; each bypassed entry accumulates a skip count and is immune to
|
||||||
# further skipping once it reaches N.
|
# further skipping once it reaches N.
|
||||||
max_queue_skip: int = 0
|
max_queue_skip: int = 0
|
||||||
|
# Seconds to keep a backend sticky to its last model after all slots drain.
|
||||||
|
# Prevents unnecessary model swaps for follow-up requests (e.g. title/suggestion
|
||||||
|
# generation) that arrive shortly after the main response. 0 = disabled.
|
||||||
|
model_unload_delay: float = 3.0
|
||||||
# Per-model global concurrency cap across all backends.
|
# Per-model global concurrency cap across all backends.
|
||||||
# Use for very large models that cannot run concurrently due to RAM constraints.
|
# Use for very large models that cannot run concurrently due to RAM constraints.
|
||||||
model_limits: dict[str, int] = {}
|
model_limits: dict[str, int] = {}
|
||||||
|
|||||||
@@ -96,7 +96,9 @@ def _init_slot_tracker(
|
|||||||
registry: BackendRegistry,
|
registry: BackendRegistry,
|
||||||
slot_tracker: SlotTracker,
|
slot_tracker: SlotTracker,
|
||||||
default_max_models: int | None,
|
default_max_models: int | None,
|
||||||
|
model_unload_delay: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
slot_tracker.set_model_unload_delay(model_unload_delay)
|
||||||
for state in registry.get_all_states():
|
for state in registry.get_all_states():
|
||||||
slot_tracker.set_capacity(state.url, state.slot_capacity)
|
slot_tracker.set_capacity(state.url, state.slot_capacity)
|
||||||
effective_max = (
|
effective_max = (
|
||||||
@@ -363,7 +365,7 @@ def create_app(config: ProxyConfig) -> FastAPI:
|
|||||||
timeout=aiohttp.ClientTimeout(total=300),
|
timeout=aiohttp.ClientTimeout(total=300),
|
||||||
connector=aiohttp.TCPConnector(ssl=False),
|
connector=aiohttp.TCPConnector(ssl=False),
|
||||||
)
|
)
|
||||||
_init_slot_tracker(registry, slot_tracker, config.default_max_models)
|
_init_slot_tracker(registry, slot_tracker, config.default_max_models, config.model_unload_delay)
|
||||||
_init_global_model_limits(config, slot_tracker)
|
_init_global_model_limits(config, slot_tracker)
|
||||||
registry.start()
|
registry.start()
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
@@ -11,6 +12,10 @@ class _SlotState:
|
|||||||
acquired: int = 0
|
acquired: int = 0
|
||||||
active_models: dict[str, int] = field(default_factory=dict)
|
active_models: dict[str, int] = field(default_factory=dict)
|
||||||
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
|
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
|
||||||
|
# Warm-hold window: once the backend has no active models left, block other models
|
||||||
|
# for a brief period so follow-up requests (titles, suggestions) hit the same backend.
|
||||||
|
sticky_model: str | None = None
|
||||||
|
sticky_until: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
class SlotTracker:
|
class SlotTracker:
|
||||||
@@ -24,6 +29,11 @@ class SlotTracker:
|
|||||||
self._slots: dict[str, _SlotState] = {}
|
self._slots: dict[str, _SlotState] = {}
|
||||||
self._global_limits: dict[str, int] = {}
|
self._global_limits: dict[str, int] = {}
|
||||||
self._global_counts: dict[str, int] = {}
|
self._global_counts: dict[str, int] = {}
|
||||||
|
self._model_unload_delay: float = 0.0
|
||||||
|
|
||||||
|
def set_model_unload_delay(self, delay: float) -> None:
|
||||||
|
"""Seconds to keep a backend sticky to its last model after all slots drain."""
|
||||||
|
self._model_unload_delay = max(0.0, delay)
|
||||||
|
|
||||||
def _ensure(self, url: str, capacity: int = 1) -> _SlotState:
|
def _ensure(self, url: str, capacity: int = 1) -> _SlotState:
|
||||||
if url not in self._slots:
|
if url not in self._slots:
|
||||||
@@ -50,6 +60,9 @@ class SlotTracker:
|
|||||||
def _can_acquire(self, state: _SlotState, model_id: str) -> bool:
|
def _can_acquire(self, state: _SlotState, model_id: str) -> bool:
|
||||||
if state.acquired >= state.capacity:
|
if state.acquired >= state.capacity:
|
||||||
return False
|
return False
|
||||||
|
if model_id and state.sticky_model and time.monotonic() < state.sticky_until:
|
||||||
|
if model_id != state.sticky_model:
|
||||||
|
return False
|
||||||
if model_id and model_id not in state.active_models:
|
if model_id and model_id not in state.active_models:
|
||||||
if state.max_models is not None and len(state.active_models) >= state.max_models:
|
if state.max_models is not None and len(state.active_models) >= state.max_models:
|
||||||
return False
|
return False
|
||||||
@@ -127,6 +140,9 @@ class SlotTracker:
|
|||||||
count = state.active_models[model_id] - 1
|
count = state.active_models[model_id] - 1
|
||||||
if count <= 0:
|
if count <= 0:
|
||||||
del state.active_models[model_id]
|
del state.active_models[model_id]
|
||||||
|
if self._model_unload_delay > 0 and not state.active_models:
|
||||||
|
state.sticky_model = model_id
|
||||||
|
state.sticky_until = time.monotonic() + self._model_unload_delay
|
||||||
else:
|
else:
|
||||||
state.active_models[model_id] = count
|
state.active_models[model_id] = count
|
||||||
if model_id and model_id in self._global_limits:
|
if model_id and model_id in self._global_limits:
|
||||||
@@ -146,4 +162,6 @@ class SlotTracker:
|
|||||||
)
|
)
|
||||||
state.acquired = 0
|
state.acquired = 0
|
||||||
state.active_models.clear()
|
state.active_models.clear()
|
||||||
|
state.sticky_model = None
|
||||||
|
state.sticky_until = 0.0
|
||||||
state.condition.notify_all()
|
state.condition.notify_all()
|
||||||
|
|||||||
@@ -248,6 +248,65 @@ class TestSlotTracker(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.assertEqual(tracker.global_model_usage("bigmodel"), (1, 2))
|
self.assertEqual(tracker.global_model_usage("bigmodel"), (1, 2))
|
||||||
self.assertIsNone(tracker.global_model_usage("othermodel"))
|
self.assertIsNone(tracker.global_model_usage("othermodel"))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Warm-hold / model_unload_delay tests
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def test_sticky_window_blocks_other_model(self):
|
||||||
|
tracker = SlotTracker()
|
||||||
|
tracker.set_capacity("http://b", 2)
|
||||||
|
tracker.set_model_unload_delay(60.0)
|
||||||
|
await tracker.acquire("http://b", "model-a")
|
||||||
|
await tracker.release("http://b", "model-a")
|
||||||
|
# Window active: model-b should be rejected
|
||||||
|
self.assertFalse(tracker.can_accept("http://b", "model-b"))
|
||||||
|
|
||||||
|
async def test_sticky_window_allows_same_model(self):
|
||||||
|
tracker = SlotTracker()
|
||||||
|
tracker.set_capacity("http://b", 2)
|
||||||
|
tracker.set_model_unload_delay(60.0)
|
||||||
|
await tracker.acquire("http://b", "model-a")
|
||||||
|
await tracker.release("http://b", "model-a")
|
||||||
|
self.assertTrue(tracker.can_accept("http://b", "model-a"))
|
||||||
|
|
||||||
|
async def test_sticky_window_expires(self):
|
||||||
|
tracker = SlotTracker()
|
||||||
|
tracker.set_capacity("http://b", 2)
|
||||||
|
tracker.set_model_unload_delay(0.05)
|
||||||
|
await tracker.acquire("http://b", "model-a")
|
||||||
|
await tracker.release("http://b", "model-a")
|
||||||
|
self.assertFalse(tracker.can_accept("http://b", "model-b"))
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
self.assertTrue(tracker.can_accept("http://b", "model-b"))
|
||||||
|
|
||||||
|
async def test_sticky_window_not_started_when_delay_zero(self):
|
||||||
|
tracker = SlotTracker()
|
||||||
|
tracker.set_capacity("http://b", 2)
|
||||||
|
tracker.set_model_unload_delay(0.0)
|
||||||
|
await tracker.acquire("http://b", "model-a")
|
||||||
|
await tracker.release("http://b", "model-a")
|
||||||
|
self.assertTrue(tracker.can_accept("http://b", "model-b"))
|
||||||
|
|
||||||
|
async def test_sticky_window_not_started_while_slots_remain(self):
|
||||||
|
"""Window must not start until ALL slots for the model drain."""
|
||||||
|
tracker = SlotTracker()
|
||||||
|
tracker.set_capacity("http://b", 4)
|
||||||
|
tracker.set_model_unload_delay(60.0)
|
||||||
|
await tracker.acquire("http://b", "model-a")
|
||||||
|
await tracker.acquire("http://b", "model-a")
|
||||||
|
await tracker.release("http://b", "model-a") # one slot still held
|
||||||
|
self.assertTrue(tracker.can_accept("http://b", "model-b"))
|
||||||
|
|
||||||
|
async def test_reset_acquired_clears_sticky_state(self):
|
||||||
|
tracker = SlotTracker()
|
||||||
|
tracker.set_capacity("http://b", 2)
|
||||||
|
tracker.set_model_unload_delay(60.0)
|
||||||
|
await tracker.acquire("http://b", "model-a")
|
||||||
|
await tracker.release("http://b", "model-a")
|
||||||
|
self.assertFalse(tracker.can_accept("http://b", "model-b"))
|
||||||
|
await tracker.reset_acquired("http://b")
|
||||||
|
self.assertTrue(tracker.can_accept("http://b", "model-b"))
|
||||||
|
|
||||||
async def test_reset_acquired_updates_global_counts(self):
|
async def test_reset_acquired_updates_global_counts(self):
|
||||||
tracker = SlotTracker()
|
tracker = SlotTracker()
|
||||||
tracker.set_capacity("http://b1", 4)
|
tracker.set_capacity("http://b1", 4)
|
||||||
|
|||||||
Reference in New Issue
Block a user