# File: llamacpp-ha/tests/test_forwarder.py
# Listing timestamp: 2026-05-17 09:54:18 +02:00
# Listing size: 277 lines, 9.8 KiB (Python)

import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock
from fastapi import Request
from starlette.datastructures import Headers
from llamacpp_ha.config import BackendConfig
from llamacpp_ha.forwarder import _forward_headers, forward_request
from llamacpp_ha.registry import BackendState
from llamacpp_ha.slot_tracker import SlotTracker
def _make_state(url="http://b1", api_key=None):
    """Return a ``BackendState`` wrapping a ``BackendConfig`` for *url* / *api_key*."""
    return BackendState(config=BackendConfig(url=url, api_key=api_key))
def _make_request(headers: dict, method="POST", path="/v1/chat/completions") -> MagicMock:
    """Build a mock standing in for an incoming request.

    The mock exposes ``headers`` (a starlette ``Headers`` built from *headers*),
    ``method``, ``url.path``, an empty ``url.query`` and an awaitable ``body()``
    that resolves to empty bytes.
    """
    request = MagicMock()
    request.configure_mock(
        headers=Headers(headers=headers),
        method=method,
        body=AsyncMock(return_value=b""),
        # Dotted keys configure attributes on the child ``url`` mock.
        **{"url.path": path, "url.query": ""},
    )
    return request
def _make_scheduler():
    """Return a scheduler mock with an explicit ``notify_slot_released`` mock attribute."""
    return MagicMock(notify_slot_released=MagicMock())
class TestForwardHeaders(unittest.TestCase):
    """Checks on the header mapping that ``_forward_headers`` produces."""

    @staticmethod
    def _by_lowercase_name(headers):
        """Return *headers* as a dict keyed by lower-cased header name."""
        return {name.lower(): value for name, value in headers.items()}

    def test_injects_backend_api_key(self):
        """A configured backend key is sent as the Authorization header."""
        backend = _make_state(api_key="backend-secret")
        incoming = _make_request(
            {"Authorization": "Bearer client-key", "Content-Type": "application/json"}
        )

        forwarded = _forward_headers(incoming, backend)

        self.assertEqual(forwarded["Authorization"], "Bearer backend-secret")
        content_type = self._by_lowercase_name(forwarded).get("content-type", "")
        self.assertIn("application/json", content_type)

    def test_removes_auth_when_no_backend_key(self):
        """Without a backend key, the client's Authorization header is not forwarded."""
        backend = _make_state(api_key=None)
        incoming = _make_request({"Authorization": "Bearer client-key"})

        forwarded = self._by_lowercase_name(_forward_headers(incoming, backend))

        self.assertNotIn("authorization", forwarded)

    def test_hop_by_hop_stripped(self):
        """Connection and Transfer-Encoding are dropped; a custom header survives."""
        backend = _make_state()
        incoming = _make_request(
            {
                "Connection": "keep-alive",
                "Transfer-Encoding": "chunked",
                "X-Custom": "value",
            }
        )

        forwarded = self._by_lowercase_name(_forward_headers(incoming, backend))

        for dropped in ("connection", "transfer-encoding"):
            self.assertNotIn(dropped, forwarded)
        self.assertEqual(forwarded.get("x-custom"), "value")

    def test_host_header_stripped(self):
        """The Host header is dropped while Accept is kept."""
        backend = _make_state()
        incoming = _make_request({"Host": "proxy.local", "Accept": "application/json"})

        forwarded = self._by_lowercase_name(_forward_headers(incoming, backend))

        self.assertNotIn("host", forwarded)
        self.assertIn("accept", forwarded)
def _mock_aiohttp_response(
status: int = 200,
content_type: str = "application/json",
body: bytes = b'{"ok":true}',
headers: dict | None = None,
) -> MagicMock:
resp = MagicMock()
resp.status = status
all_headers = {"Content-Type": content_type}
if headers:
all_headers.update(headers)
resp.headers = all_headers
resp.read = AsyncMock(return_value=body)
resp.content = MagicMock()
resp.content.iter_chunked = MagicMock()
ctx = MagicMock()
ctx.__aenter__ = AsyncMock(return_value=resp)
ctx.__aexit__ = AsyncMock(return_value=False)
return ctx, resp
class TestForwardRequestNonStreaming(unittest.IsolatedAsyncioTestCase):
    """Exercise ``forward_request`` against mocked backends that answer with a full body."""

    _URL = "http://b1"

    async def _acquire(self, tracker: SlotTracker, url: str) -> None:
        """Take one slot for *url* on *tracker*, giving up after one second."""
        async with asyncio.timeout(1.0):
            await tracker.acquire(url)

    @staticmethod
    def _session_yielding(ctx) -> MagicMock:
        """Return a session mock whose ``request(...)`` call hands back *ctx*."""
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        return session

    async def test_successful_passthrough(self):
        """Status and body come back unchanged and the held slot is released."""
        backend = _make_state(self._URL)
        tracker = SlotTracker()
        tracker.set_capacity(self._URL, 1)
        await self._acquire(tracker, self._URL)
        scheduler = _make_scheduler()
        ctx, _ = _mock_aiohttp_response(status=200, body=b'{"choices":[]}')
        session = self._session_yielding(ctx)

        response = await forward_request(
            _make_request({}), backend, session, tracker, scheduler
        )

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.body, b'{"choices":[]}')
        in_use, _ = tracker.usage(self._URL)
        self.assertEqual(in_use, 0)
        scheduler.notify_slot_released.assert_called_once()

    async def test_slot_released_on_forward_error(self):
        """A failure while opening the backend response still frees the slot."""
        backend = _make_state(self._URL)
        tracker = SlotTracker()
        tracker.set_capacity(self._URL, 1)
        await self._acquire(tracker, self._URL)
        scheduler = _make_scheduler()

        # Entering the response context raises, simulating a transport failure.
        failing_ctx = MagicMock()
        failing_ctx.__aenter__ = AsyncMock(side_effect=Exception("network error"))
        failing_ctx.__aexit__ = AsyncMock(return_value=False)
        session = self._session_yielding(failing_ctx)
        incoming = _make_request({})

        with self.assertRaises(Exception):
            await forward_request(incoming, backend, session, tracker, scheduler)

        in_use, _ = tracker.usage(self._URL)
        self.assertEqual(in_use, 0)
        scheduler.notify_slot_released.assert_called_once()

    async def test_non_streaming_returns_full_body(self):
        """The whole backend payload is returned as the response body."""
        payload = b'{"id":"xyz","choices":[{"text":"hello"}]}'
        backend = _make_state(self._URL)
        tracker = SlotTracker()
        tracker.set_capacity(self._URL, 1)
        await self._acquire(tracker, self._URL)
        scheduler = _make_scheduler()
        ctx, _ = _mock_aiohttp_response(body=payload)
        session = self._session_yielding(ctx)

        response = await forward_request(
            _make_request({}), backend, session, tracker, scheduler
        )

        self.assertEqual(response.body, payload)

    async def test_status_code_preserved(self):
        """A non-200 backend status (429 here) is passed through."""
        backend = _make_state(self._URL)
        tracker = SlotTracker()
        tracker.set_capacity(self._URL, 1)
        await self._acquire(tracker, self._URL)
        scheduler = _make_scheduler()
        ctx, _ = _mock_aiohttp_response(status=429, body=b"rate limited")
        session = self._session_yielding(ctx)

        response = await forward_request(
            _make_request({}), backend, session, tracker, scheduler
        )

        self.assertEqual(response.status_code, 429)

    async def test_model_id_tracked_in_active_models(self):
        """After a request forwarded with a model_id completes, no model is active."""
        backend = _make_state(self._URL)
        tracker = SlotTracker()
        tracker.set_capacity(self._URL, 2)
        async with asyncio.timeout(1.0):
            await tracker.acquire(self._URL, "my-model")
        scheduler = _make_scheduler()
        ctx, _ = _mock_aiohttp_response(status=200, body=b"{}")
        session = self._session_yielding(ctx)
        incoming = _make_request({})

        await forward_request(
            incoming, backend, session, tracker, scheduler, model_id="my-model"
        )

        self.assertEqual(tracker.active_model_set(self._URL), frozenset())
class TestForwardRequestStreaming(unittest.IsolatedAsyncioTestCase):
    """Exercise ``forward_request`` against mocked ``text/event-stream`` backends."""

    _URL = "http://b1"

    @staticmethod
    def _sse_session(chunks) -> MagicMock:
        """Return a session mock whose response streams *chunks* as an event stream."""

        async def iter_chunked(_size):
            for piece in chunks:
                yield piece

        resp = MagicMock()
        resp.status = 200
        resp.headers = {"Content-Type": "text/event-stream"}
        resp.content = MagicMock()
        resp.content.iter_chunked = iter_chunked

        ctx = MagicMock()
        ctx.__aenter__ = AsyncMock(return_value=resp)
        ctx.__aexit__ = AsyncMock(return_value=False)

        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        return session

    async def test_streaming_sse_passthrough(self):
        """SSE chunks are relayed unchanged; the slot is released once drained."""
        from fastapi.responses import StreamingResponse

        backend = _make_state(self._URL)
        tracker = SlotTracker()
        tracker.set_capacity(self._URL, 1)
        async with asyncio.timeout(1.0):
            await tracker.acquire(self._URL)
        scheduler = _make_scheduler()
        sse_chunks = [
            b'data: {"choices":[{"delta":{"content":"hello"}}]}\n\n',
            b'data: [DONE]\n\n',
        ]
        session = self._sse_session(sse_chunks)
        incoming = _make_request({})

        response = await forward_request(incoming, backend, session, tracker, scheduler)

        self.assertIsInstance(response, StreamingResponse)
        received = [piece async for piece in response.body_iterator]
        self.assertEqual(received, sse_chunks)
        in_use, _ = tracker.usage(self._URL)
        self.assertEqual(in_use, 0)
        scheduler.notify_slot_released.assert_called_once()

    async def test_streaming_slot_released_on_stream_end(self):
        """The slot stays held until the body iterator is exhausted, then is freed."""
        backend = _make_state(self._URL)
        tracker = SlotTracker()
        tracker.set_capacity(self._URL, 1)
        async with asyncio.timeout(1.0):
            await tracker.acquire(self._URL)
        scheduler = _make_scheduler()
        session = self._sse_session([b"data: chunk1\n\n", b"data: chunk2\n\n"])
        incoming = _make_request({})

        response = await forward_request(incoming, backend, session, tracker, scheduler)

        # Nothing has been consumed yet, so the slot must still be counted as in use.
        in_use, _ = tracker.usage(self._URL)
        self.assertEqual(in_use, 1)

        received = []
        async for piece in response.body_iterator:
            received.append(piece)
        self.assertEqual(len(received), 2)

        in_use, _ = tracker.usage(self._URL)
        self.assertEqual(in_use, 0)
# Run the suite when this file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()