From 13fb341354a9e32b0ca44cd91dea6922d5052345 Mon Sep 17 00:00:00 2001 From: chacha Date: Mon, 18 May 2026 00:34:27 +0200 Subject: [PATCH] feat: model unload --- README.md | 1 + config.json.example | 1 + src/llamacpp_ha/config.py | 4 +++ src/llamacpp_ha/proxy.py | 4 ++- src/llamacpp_ha/slot_tracker.py | 18 ++++++++++ tests/test_slot_tracker.py | 59 +++++++++++++++++++++++++++++++++ 6 files changed, 86 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 15f736f..d83ba8a 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,7 @@ Configuration is a JSON file. All fields also accept environment variable overri | `default_slot_capacity` | `1` | Initial slot count per backend used before the first `/slots` poll completes | | `default_max_models` | `null` | Maximum concurrent models per backend (null = unlimited). Applied to backends that do not set their own `max_models`. | | `max_queue_skip` | `0` | How many times a queued request may be bypassed by a model-affinity promotion before it is frozen at head-of-line. `0` disables reordering. | +| `model_unload_delay` | `3.0` | Seconds a backend stays sticky to its last model after all slots drain. Prevents unnecessary model swaps for follow-up requests (title generation, suggestions) that arrive shortly after the main response. `0` disables. | | `model_limits` | `{}` | Per-model global concurrency cap across all backends (e.g. `{"my-large-model": 1}`). Use for models too large to run simultaneously due to RAM constraints. 
| ### Per-backend fields diff --git a/config.json.example b/config.json.example index 4f1cb06..4a31710 100644 --- a/config.json.example +++ b/config.json.example @@ -8,6 +8,7 @@ "default_slot_capacity": 1, "default_max_models": 1, "max_queue_skip": 4, + "model_unload_delay": 3.0, "model_limits": { "my-very-large-model": 1 }, diff --git a/src/llamacpp_ha/config.py b/src/llamacpp_ha/config.py index 3cac6ce..0e49675 100644 --- a/src/llamacpp_ha/config.py +++ b/src/llamacpp_ha/config.py @@ -44,6 +44,10 @@ class ProxyConfig(BaseSettings): # positions; each bypassed entry accumulates a skip count and is immune to # further skipping once it reaches N. max_queue_skip: int = 0 + # Seconds to keep a backend sticky to its last model after all slots drain. + # Prevents unnecessary model swaps for follow-up requests (e.g. title/suggestion + # generation) that arrive shortly after the main response. 0 = disabled. + model_unload_delay: float = 3.0 # Per-model global concurrency cap across all backends. # Use for very large models that cannot run concurrently due to RAM constraints. 
model_limits: dict[str, int] = {} diff --git a/src/llamacpp_ha/proxy.py b/src/llamacpp_ha/proxy.py index 0cf6817..bd743b4 100644 --- a/src/llamacpp_ha/proxy.py +++ b/src/llamacpp_ha/proxy.py @@ -96,7 +96,9 @@ def _init_slot_tracker( registry: BackendRegistry, slot_tracker: SlotTracker, default_max_models: int | None, + model_unload_delay: float, ) -> None: + slot_tracker.set_model_unload_delay(model_unload_delay) for state in registry.get_all_states(): slot_tracker.set_capacity(state.url, state.slot_capacity) effective_max = ( @@ -363,7 +365,7 @@ def create_app(config: ProxyConfig) -> FastAPI: timeout=aiohttp.ClientTimeout(total=300), connector=aiohttp.TCPConnector(ssl=False), ) - _init_slot_tracker(registry, slot_tracker, config.default_max_models) + _init_slot_tracker(registry, slot_tracker, config.default_max_models, config.model_unload_delay) _init_global_model_limits(config, slot_tracker) registry.start() scheduler.start() diff --git a/src/llamacpp_ha/slot_tracker.py b/src/llamacpp_ha/slot_tracker.py index ab104c3..53fbc61 100644 --- a/src/llamacpp_ha/slot_tracker.py +++ b/src/llamacpp_ha/slot_tracker.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import time from dataclasses import dataclass, field @@ -11,6 +12,10 @@ class _SlotState: acquired: int = 0 active_models: dict[str, int] = field(default_factory=dict) condition: asyncio.Condition = field(default_factory=asyncio.Condition) + # Warm-hold window: once no model remains active on this backend, block models other + # than the last one released for a brief period so its follow-up requests (titles, suggestions) can reuse the backend. 
+ sticky_model: str | None = None + sticky_until: float = 0.0 class SlotTracker: @@ -24,6 +29,11 @@ class SlotTracker: self._slots: dict[str, _SlotState] = {} self._global_limits: dict[str, int] = {} self._global_counts: dict[str, int] = {} + self._model_unload_delay: float = 0.0 + + def set_model_unload_delay(self, delay: float) -> None: + """Seconds to keep a backend sticky to its last model after all slots drain.""" + self._model_unload_delay = max(0.0, delay) def _ensure(self, url: str, capacity: int = 1) -> _SlotState: if url not in self._slots: @@ -50,6 +60,9 @@ class SlotTracker: def _can_acquire(self, state: _SlotState, model_id: str) -> bool: if state.acquired >= state.capacity: return False + if model_id and state.sticky_model and time.monotonic() < state.sticky_until: + if model_id != state.sticky_model: + return False if model_id and model_id not in state.active_models: if state.max_models is not None and len(state.active_models) >= state.max_models: return False @@ -127,6 +140,9 @@ class SlotTracker: count = state.active_models[model_id] - 1 if count <= 0: del state.active_models[model_id] + if self._model_unload_delay > 0 and not state.active_models: + state.sticky_model = model_id + state.sticky_until = time.monotonic() + self._model_unload_delay else: state.active_models[model_id] = count if model_id and model_id in self._global_limits: @@ -146,4 +162,6 @@ class SlotTracker: ) state.acquired = 0 state.active_models.clear() + state.sticky_model = None + state.sticky_until = 0.0 state.condition.notify_all() diff --git a/tests/test_slot_tracker.py b/tests/test_slot_tracker.py index 837f968..003fa81 100644 --- a/tests/test_slot_tracker.py +++ b/tests/test_slot_tracker.py @@ -248,6 +248,65 @@ class TestSlotTracker(unittest.IsolatedAsyncioTestCase): self.assertEqual(tracker.global_model_usage("bigmodel"), (1, 2)) self.assertIsNone(tracker.global_model_usage("othermodel")) + # ------------------------------------------------------------------ + # Warm-hold 
/ model_unload_delay tests + # ------------------------------------------------------------------ + + async def test_sticky_window_blocks_other_model(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 2) + tracker.set_model_unload_delay(60.0) + await tracker.acquire("http://b", "model-a") + await tracker.release("http://b", "model-a") + # Window active: model-b should be rejected + self.assertFalse(tracker.can_accept("http://b", "model-b")) + + async def test_sticky_window_allows_same_model(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 2) + tracker.set_model_unload_delay(60.0) + await tracker.acquire("http://b", "model-a") + await tracker.release("http://b", "model-a") + self.assertTrue(tracker.can_accept("http://b", "model-a")) + + async def test_sticky_window_expires(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 2) + tracker.set_model_unload_delay(0.05) + await tracker.acquire("http://b", "model-a") + await tracker.release("http://b", "model-a") + self.assertFalse(tracker.can_accept("http://b", "model-b")) + await asyncio.sleep(0.1) + self.assertTrue(tracker.can_accept("http://b", "model-b")) + + async def test_sticky_window_not_started_when_delay_zero(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 2) + tracker.set_model_unload_delay(0.0) + await tracker.acquire("http://b", "model-a") + await tracker.release("http://b", "model-a") + self.assertTrue(tracker.can_accept("http://b", "model-b")) + + async def test_sticky_window_not_started_while_slots_remain(self): + """Window must not start until ALL slots for the model drain.""" + tracker = SlotTracker() + tracker.set_capacity("http://b", 4) + tracker.set_model_unload_delay(60.0) + await tracker.acquire("http://b", "model-a") + await tracker.acquire("http://b", "model-a") + await tracker.release("http://b", "model-a") # one slot still held + self.assertTrue(tracker.can_accept("http://b", "model-b")) + + async def 
test_reset_acquired_clears_sticky_state(self): + tracker = SlotTracker() + tracker.set_capacity("http://b", 2) + tracker.set_model_unload_delay(60.0) + await tracker.acquire("http://b", "model-a") + await tracker.release("http://b", "model-a") + self.assertFalse(tracker.can_accept("http://b", "model-b")) + await tracker.reset_acquired("http://b") + self.assertTrue(tracker.can_accept("http://b", "model-b")) + async def test_reset_acquired_updates_global_counts(self): tracker = SlotTracker() tracker.set_capacity("http://b1", 4)