"""Unit tests for ``llamacpp_ha.forwarder``.

Source path: llamacpp-ha/tests/test_forwarder.py (snapshot dated 2026-05-17 22:38:36 +02:00).
"""

import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock
from fastapi import Request
from starlette.datastructures import Headers
from llamacpp_ha.config import BackendConfig
from llamacpp_ha.forwarder import _forward_headers, forward_best_effort, forward_request
from llamacpp_ha.registry import BackendState
from llamacpp_ha.slot_tracker import SlotTracker
def _make_state(url="http://b1", api_key=None):
    """Return a ``BackendState`` for a backend at *url* with optional *api_key*."""
    return BackendState(config=BackendConfig(url=url, api_key=api_key))
def _make_request(headers: dict, method="POST", path="/v1/chat/completions") -> MagicMock:
    """Build a stand-in for an incoming FastAPI ``Request``.

    Only the attributes these tests read are populated: ``headers`` (a real
    starlette ``Headers`` built from *headers*), ``url.path``, an empty
    ``url.query``, ``method``, and an awaitable ``body()`` yielding ``b""``.
    """
    request = MagicMock()
    # Dotted keys configure attributes on the auto-created ``url`` child mock.
    request.configure_mock(
        **{
            "headers": Headers(headers=headers),
            "method": method,
            "body": AsyncMock(return_value=b""),
            "url.path": path,
            "url.query": "",
        }
    )
    return request
def _make_scheduler():
    """Return a scheduler double whose ``notify_slot_released`` records its calls."""
    return MagicMock(notify_slot_released=MagicMock())
class TestForwardHeaders(unittest.TestCase):
    """Header rewriting done by ``_forward_headers`` before a request is proxied."""

    @staticmethod
    def _lowered(headers) -> dict:
        """Return *headers* as a plain dict keyed by lower-cased header name."""
        return {name.lower(): value for name, value in headers.items()}

    def test_injects_backend_api_key(self):
        request = _make_request(
            {"Authorization": "Bearer client-key", "Content-Type": "application/json"}
        )
        out = _forward_headers(request, _make_state(api_key="backend-secret"))
        # The client's bearer token is replaced by the backend's configured key.
        self.assertEqual(out["Authorization"], "Bearer backend-secret")
        self.assertIn("application/json", self._lowered(out).get("content-type", ""))

    def test_removes_auth_when_no_backend_key(self):
        request = _make_request({"Authorization": "Bearer client-key"})
        out = _forward_headers(request, _make_state(api_key=None))
        # With no backend key configured, the client's credential is not passed on.
        self.assertNotIn("authorization", self._lowered(out))

    def test_hop_by_hop_stripped(self):
        request = _make_request(
            {"Connection": "keep-alive", "Transfer-Encoding": "chunked", "X-Custom": "value"}
        )
        lowered = self._lowered(_forward_headers(request, _make_state()))
        for hop_header in ("connection", "transfer-encoding"):
            self.assertNotIn(hop_header, lowered)
        # Ordinary end-to-end headers survive unchanged.
        self.assertEqual(lowered.get("x-custom"), "value")

    def test_host_header_stripped(self):
        request = _make_request({"Host": "proxy.local", "Accept": "application/json"})
        lowered = self._lowered(_forward_headers(request, _make_state()))
        self.assertNotIn("host", lowered)
        self.assertIn("accept", lowered)
def _mock_aiohttp_response(
status: int = 200,
content_type: str = "application/json",
body: bytes = b'{"ok":true}',
headers: dict | None = None,
) -> MagicMock:
resp = MagicMock()
resp.status = status
all_headers = {"Content-Type": content_type}
if headers:
all_headers.update(headers)
resp.headers = all_headers
resp.read = AsyncMock(return_value=body)
resp.content = MagicMock()
resp.content.iter_chunked = MagicMock()
ctx = MagicMock()
ctx.__aenter__ = AsyncMock(return_value=resp)
ctx.__aexit__ = AsyncMock(return_value=False)
return ctx, resp
class TestForwardRequestNonStreaming(unittest.IsolatedAsyncioTestCase):
    """``forward_request`` against a backend that returns a buffered body."""

    async def _acquire(
        self, tracker: SlotTracker, url: str, model_id: str | None = None
    ) -> None:
        """Acquire one slot on *url*, failing after 1s instead of hanging.

        When *model_id* is given it is passed to ``SlotTracker.acquire`` as the
        second positional argument; otherwise ``acquire(url)`` is called.
        """
        async with asyncio.timeout(1.0):
            if model_id is None:
                await tracker.acquire(url)
            else:
                await tracker.acquire(url, model_id)

    async def _tracker_with_held_slot(
        self, capacity: int = 1, model_id: str | None = None
    ) -> SlotTracker:
        """Return a tracker for ``http://b1`` with one slot already acquired.

        The tests below expect ``forward_request`` to release this slot.
        """
        tracker = SlotTracker()
        tracker.set_capacity("http://b1", capacity)
        await self._acquire(tracker, "http://b1", model_id)
        return tracker

    @staticmethod
    def _session_for(ctx: MagicMock) -> MagicMock:
        """Return a fake aiohttp session whose ``request`` always returns *ctx*."""
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        return session

    async def test_successful_passthrough(self):
        state = _make_state("http://b1")
        slot_tracker = await self._tracker_with_held_slot()
        scheduler = _make_scheduler()
        ctx, _ = _mock_aiohttp_response(status=200, body=b'{"choices":[]}')
        session = self._session_for(ctx)
        response = await forward_request(
            _make_request({}), state, session, slot_tracker, scheduler
        )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.body, b'{"choices":[]}')
        acquired, _ = slot_tracker.usage("http://b1")
        self.assertEqual(acquired, 0)
        scheduler.notify_slot_released.assert_called_once()

    async def test_slot_released_on_forward_error(self):
        state = _make_state("http://b1")
        slot_tracker = await self._tracker_with_held_slot()
        scheduler = _make_scheduler()
        # Entering the request context raises, simulating a network failure.
        ctx = MagicMock()
        ctx.__aenter__ = AsyncMock(side_effect=Exception("network error"))
        ctx.__aexit__ = AsyncMock(return_value=False)
        session = self._session_for(ctx)
        with self.assertRaises(Exception):
            await forward_request(_make_request({}), state, session, slot_tracker, scheduler)
        # The slot must be released even though forwarding raised.
        acquired, _ = slot_tracker.usage("http://b1")
        self.assertEqual(acquired, 0)
        scheduler.notify_slot_released.assert_called_once()

    async def test_non_streaming_returns_full_body(self):
        state = _make_state("http://b1")
        slot_tracker = await self._tracker_with_held_slot()
        scheduler = _make_scheduler()
        body = b'{"id":"xyz","choices":[{"text":"hello"}]}'
        ctx, _ = _mock_aiohttp_response(body=body)
        session = self._session_for(ctx)
        response = await forward_request(
            _make_request({}), state, session, slot_tracker, scheduler
        )
        self.assertEqual(response.body, body)

    async def test_status_code_preserved(self):
        state = _make_state("http://b1")
        slot_tracker = await self._tracker_with_held_slot()
        scheduler = _make_scheduler()
        ctx, _ = _mock_aiohttp_response(status=429, body=b"rate limited")
        session = self._session_for(ctx)
        response = await forward_request(
            _make_request({}), state, session, slot_tracker, scheduler
        )
        self.assertEqual(response.status_code, 429)

    async def test_model_id_tracked_in_active_models(self):
        """A slot's model_id is tracked until forward_request releases it.

        The slot is acquired for "my-model". forward_request is given the same
        model_id, and once it completes the backend reports no active models.
        """
        state = _make_state("http://b1")
        slot_tracker = await self._tracker_with_held_slot(capacity=2, model_id="my-model")
        # Precondition: without this, the final empty-set check could pass vacuously.
        self.assertIn("my-model", slot_tracker.active_model_set("http://b1"))
        scheduler = _make_scheduler()
        ctx, _ = _mock_aiohttp_response(status=200, body=b"{}")
        session = self._session_for(ctx)
        await forward_request(
            _make_request({}), state, session, slot_tracker, scheduler, model_id="my-model"
        )
        self.assertEqual(slot_tracker.active_model_set("http://b1"), frozenset())
class TestForwardRequestStreaming(unittest.IsolatedAsyncioTestCase):
    """``forward_request`` against a backend replying with ``text/event-stream``."""

    @staticmethod
    def _sse_session(chunks: list[bytes]) -> MagicMock:
        """Return a fake aiohttp session whose single response streams *chunks*.

        The response has status 200 and ``Content-Type: text/event-stream``.
        ``resp.content.iter_chunked(size)`` is an async generator that ignores
        *size* and yields each item of *chunks* in order.
        """

        async def fake_iter_chunked(_size):
            for chunk in chunks:
                yield chunk

        resp = MagicMock()
        resp.status = 200
        resp.headers = {"Content-Type": "text/event-stream"}
        resp.content = MagicMock()
        resp.content.iter_chunked = fake_iter_chunked
        ctx = MagicMock()
        ctx.__aenter__ = AsyncMock(return_value=resp)
        ctx.__aexit__ = AsyncMock(return_value=False)
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        return session

    @staticmethod
    async def _tracker_with_held_slot() -> SlotTracker:
        """Return a tracker for ``http://b1`` (capacity 1) whose only slot is held."""
        tracker = SlotTracker()
        tracker.set_capacity("http://b1", 1)
        async with asyncio.timeout(1.0):
            await tracker.acquire("http://b1")
        return tracker

    async def test_streaming_sse_passthrough(self):
        from fastapi.responses import StreamingResponse

        sse_chunks = [
            b'data: {"choices":[{"delta":{"content":"hello"}}]}\n\n',
            b'data: [DONE]\n\n',
        ]
        slot_tracker = await self._tracker_with_held_slot()
        scheduler = _make_scheduler()
        response = await forward_request(
            _make_request({}),
            _make_state("http://b1"),
            self._sse_session(sse_chunks),
            slot_tracker,
            scheduler,
        )
        self.assertIsInstance(response, StreamingResponse)
        chunks = [chunk async for chunk in response.body_iterator]
        self.assertEqual(chunks, sse_chunks)
        acquired, _ = slot_tracker.usage("http://b1")
        self.assertEqual(acquired, 0)
        scheduler.notify_slot_released.assert_called_once()

    async def test_streaming_slot_released_on_stream_end(self):
        sse_chunks = [b"data: chunk1\n\n", b"data: chunk2\n\n"]
        slot_tracker = await self._tracker_with_held_slot()
        scheduler = _make_scheduler()
        response = await forward_request(
            _make_request({}),
            _make_state("http://b1"),
            self._sse_session(sse_chunks),
            slot_tracker,
            scheduler,
        )
        # The slot stays held while the body has not yet been consumed.
        acquired, _ = slot_tracker.usage("http://b1")
        self.assertEqual(acquired, 1)
        chunks = [chunk async for chunk in response.body_iterator]
        self.assertEqual(chunks, sse_chunks)
        # Draining the stream releases the slot and notifies the scheduler once.
        acquired, _ = slot_tracker.usage("http://b1")
        self.assertEqual(acquired, 0)
        scheduler.notify_slot_released.assert_called_once()
class TestForwardBestEffort(unittest.IsolatedAsyncioTestCase):
    """``forward_best_effort``: try candidate backends until one answers."""

    @staticmethod
    def _failing_ctx(message: str) -> MagicMock:
        """Return a fake ``session.request`` result that raises *message* on entry."""
        ctx = MagicMock()
        ctx.__aenter__ = AsyncMock(side_effect=Exception(message))
        ctx.__aexit__ = AsyncMock(return_value=False)
        return ctx

    @staticmethod
    def _host(url: str) -> str | None:
        """Return the host name of *url*, e.g. ``"http://b2/v1/x"`` -> ``"b2"``."""
        from urllib.parse import urlsplit

        return urlsplit(url).hostname

    async def test_no_backends_returns_503(self):
        response = await forward_best_effort(_make_request({}), [], MagicMock())
        self.assertEqual(response.status_code, 503)

    async def test_single_backend_success(self):
        ctx, _ = _mock_aiohttp_response(status=200, body=b'{"ok":true}')
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        response = await forward_best_effort(
            _make_request({}), [_make_state("http://b1")], session
        )
        self.assertEqual(response.status_code, 200)

    async def test_falls_back_on_connection_error(self):
        """If the first selected backend fails, the next one is tried."""
        state1 = _make_state("http://b1")
        state2 = _make_state("http://b2")
        ctx_fail = self._failing_ctx("connection refused")
        ctx_ok, _ = _mock_aiohttp_response(status=200, body=b"ok")

        def pick_ctx(method, url, **kwargs):
            return ctx_fail if "b1" in url else ctx_ok

        session = MagicMock()
        session.request = MagicMock(side_effect=pick_ctx)
        response = await forward_best_effort(_make_request({}), [state1, state2], session)
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.body, b"ok")

    async def test_all_backends_fail_returns_502(self):
        session = MagicMock()
        session.request = MagicMock(return_value=self._failing_ctx("network error"))
        states = [_make_state("http://b1"), _make_state("http://b2")]
        response = await forward_best_effort(_make_request({}), states, session)
        self.assertEqual(response.status_code, 502)

    async def test_preferred_url_tried_first(self):
        """When preferred_url is given, that backend is attempted before others."""
        calls: list[str | None] = []

        def record_and_succeed(method, url, **kwargs):
            calls.append(self._host(url))
            ctx, _ = _mock_aiohttp_response(status=200, body=b"ok")
            return ctx

        session = MagicMock()
        session.request = MagicMock(side_effect=record_and_succeed)
        states = [_make_state("http://b1"), _make_state("http://b2")]
        await forward_best_effort(
            _make_request({}), states, session, preferred_url="http://b2"
        )
        self.assertEqual(calls[0], "b2")

    async def test_preferred_url_fallback_on_failure(self):
        """If the preferred backend fails, the next one is still tried."""
        calls: list[str | None] = []
        ctx_fail = self._failing_ctx("down")
        ctx_ok, _ = _mock_aiohttp_response(status=200, body=b"ok")

        def pick(method, url, **kwargs):
            host = self._host(url)
            calls.append(host)
            return ctx_fail if host == "b2" else ctx_ok

        session = MagicMock()
        session.request = MagicMock(side_effect=pick)
        states = [_make_state("http://b1"), _make_state("http://b2")]
        response = await forward_best_effort(
            _make_request({}), states, session, preferred_url="http://b2"
        )
        # The preferred backend (b2) was attempted first; the reply came from b1.
        self.assertEqual(calls[0], "b2")
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.body, b"ok")

    async def test_status_code_and_body_preserved(self):
        ctx, _ = _mock_aiohttp_response(status=404, body=b"not found")
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        response = await forward_best_effort(
            _make_request({}), [_make_state("http://b1")], session
        )
        self.assertEqual(response.status_code, 404)
        self.assertEqual(response.body, b"not found")
if __name__ == "__main__":
    # Run every test case in this module when it is executed as a script.
    unittest.main(module="__main__")