diff --git a/src/llamacpp_ha/forwarder.py b/src/llamacpp_ha/forwarder.py index 4bcdf55..df46564 100644 --- a/src/llamacpp_ha/forwarder.py +++ b/src/llamacpp_ha/forwarder.py @@ -126,11 +126,13 @@ async def forward_best_effort( request: Request, registry_backends: list[BackendState], session: aiohttp.ClientSession, + preferred_url: str | None = None, ) -> Response: """Forward without slot gating to any live backend (catch-all paths). Backends are tried in round-robin order; on connection failure the next - backend is attempted until all live backends are exhausted. + backend is attempted until all live backends are exhausted. When + preferred_url is given (session affinity hint) that backend is tried first. """ if not registry_backends: return Response(content="No live backends", status_code=503) @@ -139,6 +141,11 @@ async def forward_best_effort( start = next(_best_effort_counter) % n ordered = registry_backends[start:] + registry_backends[:start] + if preferred_url: + preferred = [b for b in ordered if b.url == preferred_url] + rest = [b for b in ordered if b.url != preferred_url] + ordered = preferred + rest + path = request.url.path query = request.url.query body = await request.body() diff --git a/src/llamacpp_ha/proxy.py b/src/llamacpp_ha/proxy.py index db3f6b5..4f29407 100644 --- a/src/llamacpp_ha/proxy.py +++ b/src/llamacpp_ha/proxy.py @@ -232,13 +232,19 @@ async def _catch_all( *, http: _HttpSession, registry: BackendRegistry, - stats: ProxyStats, + session_store: SessionStore, ) -> Response: if http.client is None: return Response(content="Proxy not ready", status_code=503) - stats.increment_requests() + # Non-inference paths are not counted as "requests served". + # If the client has a session cookie, prefer the backend that is already + # holding its KV-cache so web-UI API calls stay consistent with the chat. + session_id = _session_id_from(request) + preferred_url: str | None = None + if session_id: + preferred_url = await session_store.get_preferred_backend(session_id) live = registry.get_all_live_backends() - return await forward_best_effort(request, live, http.client) + return await forward_best_effort(request, live, http.client, preferred_url=preferred_url) async def _inference_endpoint( @@ -394,7 +400,7 @@ def create_app(config: ProxyConfig) -> FastAPI: ) async def catch_all_handler(request: Request, full_path: str) -> Response: # noqa: ARG001 - return await _catch_all(request, http=http, registry=registry, stats=stats) + return await _catch_all(request, http=http, registry=registry, session_store=session_store) app.add_api_route("/v1/models", list_models_handler, methods=["GET"]) app.add_api_route("/health", health_handler, methods=["GET"]) diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 7f4fd5d..cfe1ec0 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -324,6 +324,45 @@ class TestForwardBestEffort(unittest.IsolatedAsyncioTestCase): response = await forward_best_effort(req, [state1, state2], session) self.assertEqual(response.status_code, 502) + async def test_preferred_url_tried_first(self): + """When preferred_url is given, that backend is attempted before others.""" + state1 = _make_state("http://b1") + state2 = _make_state("http://b2") + + calls: list[str] = [] + + def fake_iter(method, url, **kwargs): + calls.append(url.split("//")[1].split("/")[0]) # extract host + ctx, _ = _mock_aiohttp_response(status=200, body=b"ok") + return ctx + + session = MagicMock() + session.request = MagicMock(side_effect=fake_iter) + req = _make_request({}) + + await forward_best_effort(req, [state1, state2], session, preferred_url="http://b2") + self.assertEqual(calls[0], "b2") + + async def test_preferred_url_fallback_on_failure(self): + """If preferred backend fails, the next one is still tried.""" + state1 = _make_state("http://b1") + state2 = _make_state("http://b2") + + ctx_fail = MagicMock() + ctx_fail.__aenter__ = AsyncMock(side_effect=Exception("down")) + ctx_fail.__aexit__ = AsyncMock(return_value=False) + ctx_ok, _ = _mock_aiohttp_response(status=200, body=b"ok") + + def pick(method, url, **kwargs): + return ctx_fail if "b2" in url else ctx_ok + + session = MagicMock() + session.request = MagicMock(side_effect=pick) + req = _make_request({}) + + response = await forward_best_effort(req, [state1, state2], session, preferred_url="http://b2") + self.assertEqual(response.status_code, 200) + async def test_status_code_and_body_preserved(self): state = _make_state("http://b1") ctx, _ = _mock_aiohttp_response(status=404, body=b"not found")