fix statistics
This commit is contained in:
@@ -126,11 +126,13 @@ async def forward_best_effort(
|
||||
request: Request,
|
||||
registry_backends: list[BackendState],
|
||||
session: aiohttp.ClientSession,
|
||||
preferred_url: str | None = None,
|
||||
) -> Response:
|
||||
"""Forward without slot gating to any live backend (catch-all paths).
|
||||
|
||||
Backends are tried in round-robin order; on connection failure the next
|
||||
backend is attempted until all live backends are exhausted.
|
||||
backend is attempted until all live backends are exhausted. When
|
||||
preferred_url is given (session affinity hint) that backend is tried first.
|
||||
"""
|
||||
if not registry_backends:
|
||||
return Response(content="No live backends", status_code=503)
|
||||
@@ -139,6 +141,11 @@ async def forward_best_effort(
|
||||
start = next(_best_effort_counter) % n
|
||||
ordered = registry_backends[start:] + registry_backends[:start]
|
||||
|
||||
if preferred_url:
|
||||
preferred = [b for b in ordered if b.url == preferred_url]
|
||||
rest = [b for b in ordered if b.url != preferred_url]
|
||||
ordered = preferred + rest
|
||||
|
||||
path = request.url.path
|
||||
query = request.url.query
|
||||
body = await request.body()
|
||||
|
||||
@@ -232,13 +232,19 @@ async def _catch_all(
|
||||
*,
|
||||
http: _HttpSession,
|
||||
registry: BackendRegistry,
|
||||
stats: ProxyStats,
|
||||
session_store: SessionStore,
|
||||
) -> Response:
|
||||
if http.client is None:
|
||||
return Response(content="Proxy not ready", status_code=503)
|
||||
stats.increment_requests()
|
||||
# Non-inference paths are not counted as "requests served".
|
||||
# If the client has a session cookie, prefer the backend that is already
|
||||
# holding its KV-cache so web-UI API calls stay consistent with the chat.
|
||||
session_id = _session_id_from(request)
|
||||
preferred_url: str | None = None
|
||||
if session_id:
|
||||
preferred_url = await session_store.get_preferred_backend(session_id)
|
||||
live = registry.get_all_live_backends()
|
||||
return await forward_best_effort(request, live, http.client)
|
||||
return await forward_best_effort(request, live, http.client, preferred_url=preferred_url)
|
||||
|
||||
|
||||
async def _inference_endpoint(
|
||||
@@ -394,7 +400,7 @@ def create_app(config: ProxyConfig) -> FastAPI:
|
||||
)
|
||||
|
||||
async def catch_all_handler(request: Request, full_path: str) -> Response: # noqa: ARG001
|
||||
return await _catch_all(request, http=http, registry=registry, stats=stats)
|
||||
return await _catch_all(request, http=http, registry=registry, session_store=session_store)
|
||||
|
||||
app.add_api_route("/v1/models", list_models_handler, methods=["GET"])
|
||||
app.add_api_route("/health", health_handler, methods=["GET"])
|
||||
|
||||
@@ -324,6 +324,45 @@ class TestForwardBestEffort(unittest.IsolatedAsyncioTestCase):
|
||||
response = await forward_best_effort(req, [state1, state2], session)
|
||||
self.assertEqual(response.status_code, 502)
|
||||
|
||||
async def test_preferred_url_tried_first(self):
    """The backend named by ``preferred_url`` receives the first attempt.

    Two backends are offered (b1, b2) with b2 as the preferred URL.  Every
    outgoing request is recorded and answered with a 200 response, and the
    test checks that the first recorded host is b2.
    """
    backends = [_make_state("http://b1"), _make_state("http://b2")]

    # Hosts contacted through the mocked session, in call order.
    contacted_hosts: list[str] = []

    def record_and_succeed(method, url, **kwargs):
        # Keep only the host portion of "scheme://host/path".
        after_scheme = url.split("//")[1]
        contacted_hosts.append(after_scheme.split("/")[0])
        ok_ctx, _ = _mock_aiohttp_response(status=200, body=b"ok")
        return ok_ctx

    session = MagicMock()
    session.request = MagicMock(side_effect=record_and_succeed)
    incoming = _make_request({})

    await forward_best_effort(
        incoming,
        backends,
        session,
        preferred_url="http://b2",
    )

    # Only the first attempt matters here; later attempts are not inspected.
    self.assertEqual(contacted_hosts[0], "b2")
|
||||
|
||||
async def test_preferred_url_fallback_on_failure(self):
    """A failing preferred backend does not prevent a successful response.

    b2 is the preferred backend, but entering its request context raises.
    b1 answers with a 200, and the test checks that the overall response
    carries that 200 status.
    """
    backends = [_make_state("http://b1"), _make_state("http://b2")]

    # Context manager whose __aenter__ raises, standing in for a backend
    # that is down.
    broken_ctx = MagicMock()
    broken_ctx.__aenter__ = AsyncMock(side_effect=Exception("down"))
    broken_ctx.__aexit__ = AsyncMock(return_value=False)

    healthy_ctx, _ = _mock_aiohttp_response(status=200, body=b"ok")

    def choose_context(method, url, **kwargs):
        # Requests aimed at b2 fail; everything else succeeds.
        if "b2" in url:
            return broken_ctx
        return healthy_ctx

    session = MagicMock()
    session.request = MagicMock(side_effect=choose_context)
    incoming = _make_request({})

    response = await forward_best_effort(
        incoming,
        backends,
        session,
        preferred_url="http://b2",
    )

    self.assertEqual(response.status_code, 200)
|
||||
|
||||
async def test_status_code_and_body_preserved(self):
|
||||
state = _make_state("http://b1")
|
||||
ctx, _ = _mock_aiohttp_response(status=404, body=b"not found")
|
||||
|
||||
Reference in New Issue
Block a user