277 lines
9.8 KiB
Python
277 lines
9.8 KiB
Python
import asyncio
|
|
import unittest
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
from fastapi import Request
|
|
from starlette.datastructures import Headers
|
|
|
|
from llamacpp_ha.config import BackendConfig
|
|
from llamacpp_ha.forwarder import _forward_headers, forward_request
|
|
from llamacpp_ha.registry import BackendState
|
|
from llamacpp_ha.slot_tracker import SlotTracker
|
|
|
|
|
|
def _make_state(url="http://b1", api_key=None):
    """Return a ``BackendState`` built around a ``BackendConfig`` for *url*.

    Args:
        url: Backend URL passed to ``BackendConfig``.
        api_key: Backend API key passed to ``BackendConfig`` (``None`` by default).
    """
    return BackendState(config=BackendConfig(url=url, api_key=api_key))
|
|
|
|
|
|
def _make_request(headers: dict, method="POST", path="/v1/chat/completions") -> MagicMock:
    """Return a ``MagicMock`` standing in for an incoming request.

    The mock carries the attributes set here: ``headers`` (a starlette
    ``Headers`` built from *headers*), ``method``, ``url.path``, an empty
    ``url.query``, and an awaitable ``body()`` that yields ``b""``.
    """
    request = MagicMock(
        method=method,
        headers=Headers(headers=headers),
        body=AsyncMock(return_value=b""),
    )
    # Nested attributes are assigned after construction; ``url`` itself is an
    # auto-created child mock.
    request.url.query = ""
    request.url.path = path
    return request
|
|
|
|
|
|
def _make_scheduler():
    """Return a mock scheduler with an explicit ``notify_slot_released`` mock.

    Tests inspect ``notify_slot_released`` to check how often it was called.
    """
    return MagicMock(notify_slot_released=MagicMock())
|
|
|
|
|
|
class TestForwardHeaders(unittest.TestCase):
    """Checks on the header mapping that ``_forward_headers`` produces."""

    @staticmethod
    def _lowercased(headers):
        """Re-key *headers* by lower-cased name so lookups ignore case."""
        return {name.lower(): value for name, value in headers.items()}

    def test_injects_backend_api_key(self):
        """The backend's key replaces the client's Authorization value."""
        backend = _make_state(api_key="backend-secret")
        request = _make_request(
            {"Authorization": "Bearer client-key", "Content-Type": "application/json"}
        )

        forwarded = _forward_headers(request, backend)

        self.assertEqual(forwarded["Authorization"], "Bearer backend-secret")
        content_type = self._lowercased(forwarded).get("content-type", "")
        self.assertIn("application/json", content_type)

    def test_removes_auth_when_no_backend_key(self):
        """Without a backend key, no Authorization header is forwarded."""
        backend = _make_state(api_key=None)
        request = _make_request({"Authorization": "Bearer client-key"})

        forwarded = self._lowercased(_forward_headers(request, backend))

        self.assertNotIn("authorization", forwarded)

    def test_hop_by_hop_stripped(self):
        """Connection and Transfer-Encoding are dropped; X-Custom is kept."""
        backend = _make_state()
        incoming = {
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked",
            "X-Custom": "value",
        }

        forwarded = self._lowercased(_forward_headers(_make_request(incoming), backend))

        for dropped in ("connection", "transfer-encoding"):
            self.assertNotIn(dropped, forwarded)
        self.assertEqual(forwarded.get("x-custom"), "value")

    def test_host_header_stripped(self):
        """Host is dropped while Accept is passed along."""
        backend = _make_state()
        request = _make_request({"Host": "proxy.local", "Accept": "application/json"})

        forwarded = self._lowercased(_forward_headers(request, backend))

        self.assertNotIn("host", forwarded)
        self.assertIn("accept", forwarded)
|
|
|
|
|
|
def _mock_aiohttp_response(
|
|
status: int = 200,
|
|
content_type: str = "application/json",
|
|
body: bytes = b'{"ok":true}',
|
|
headers: dict | None = None,
|
|
) -> MagicMock:
|
|
resp = MagicMock()
|
|
resp.status = status
|
|
all_headers = {"Content-Type": content_type}
|
|
if headers:
|
|
all_headers.update(headers)
|
|
resp.headers = all_headers
|
|
resp.read = AsyncMock(return_value=body)
|
|
resp.content = MagicMock()
|
|
resp.content.iter_chunked = MagicMock()
|
|
|
|
ctx = MagicMock()
|
|
ctx.__aenter__ = AsyncMock(return_value=resp)
|
|
ctx.__aexit__ = AsyncMock(return_value=False)
|
|
return ctx, resp
|
|
|
|
|
|
class TestForwardRequestNonStreaming(unittest.IsolatedAsyncioTestCase):
    """``forward_request`` against a mocked session returning buffered replies."""

    BACKEND_URL = "http://b1"

    async def _acquire(self, tracker: SlotTracker, url: str) -> None:
        """Acquire a slot on *url*, bounded by a one-second asyncio timeout."""
        async with asyncio.timeout(1.0):
            await tracker.acquire(url)

    async def _tracker_holding_slot(self) -> SlotTracker:
        """Return a tracker with capacity 1 on the backend and one slot acquired."""
        tracker = SlotTracker()
        tracker.set_capacity(self.BACKEND_URL, 1)
        await self._acquire(tracker, self.BACKEND_URL)
        return tracker

    @staticmethod
    def _session_returning(ctx) -> MagicMock:
        """Return a mock session whose ``request(...)`` call hands back *ctx*."""
        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        return session

    async def test_successful_passthrough(self):
        """Status and body pass through; the slot is released and the scheduler notified once."""
        tracker = await self._tracker_holding_slot()
        scheduler = _make_scheduler()
        ctx, _ = _mock_aiohttp_response(status=200, body=b'{"choices":[]}')

        response = await forward_request(
            _make_request({}),
            _make_state(self.BACKEND_URL),
            self._session_returning(ctx),
            tracker,
            scheduler,
        )

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.body, b'{"choices":[]}')
        held, _ = tracker.usage(self.BACKEND_URL)
        self.assertEqual(held, 0)
        scheduler.notify_slot_released.assert_called_once()

    async def test_slot_released_on_forward_error(self):
        """An exception raised on entering the response context still frees the slot."""
        backend = _make_state(self.BACKEND_URL)
        tracker = await self._tracker_holding_slot()
        scheduler = _make_scheduler()

        # Entering the context manager raises, simulating a failed upstream call.
        failing_ctx = MagicMock()
        failing_ctx.__aenter__ = AsyncMock(side_effect=Exception("network error"))
        failing_ctx.__aexit__ = AsyncMock(return_value=False)
        session = self._session_returning(failing_ctx)
        request = _make_request({})

        with self.assertRaises(Exception):
            await forward_request(request, backend, session, tracker, scheduler)

        held, _ = tracker.usage(self.BACKEND_URL)
        self.assertEqual(held, 0)
        scheduler.notify_slot_released.assert_called_once()

    async def test_non_streaming_returns_full_body(self):
        """The response body equals the bytes the mocked upstream returned."""
        payload = b'{"id":"xyz","choices":[{"text":"hello"}]}'
        tracker = await self._tracker_holding_slot()
        ctx, _ = _mock_aiohttp_response(body=payload)

        response = await forward_request(
            _make_request({}),
            _make_state(self.BACKEND_URL),
            self._session_returning(ctx),
            tracker,
            _make_scheduler(),
        )

        self.assertEqual(response.body, payload)

    async def test_status_code_preserved(self):
        """A non-200 upstream status (429) is carried onto the response."""
        tracker = await self._tracker_holding_slot()
        ctx, _ = _mock_aiohttp_response(status=429, body=b"rate limited")

        response = await forward_request(
            _make_request({}),
            _make_state(self.BACKEND_URL),
            self._session_returning(ctx),
            tracker,
            _make_scheduler(),
        )

        self.assertEqual(response.status_code, 429)

    async def test_model_id_tracked_in_active_models(self):
        """model_id passed to forward_request is reflected in slot tracker."""
        tracker = SlotTracker()
        tracker.set_capacity(self.BACKEND_URL, 2)
        # Acquire under the model id directly; ``_acquire`` takes no model id.
        async with asyncio.timeout(1.0):
            await tracker.acquire(self.BACKEND_URL, "my-model")
        ctx, _ = _mock_aiohttp_response(status=200, body=b"{}")

        await forward_request(
            _make_request({}),
            _make_state(self.BACKEND_URL),
            self._session_returning(ctx),
            tracker,
            _make_scheduler(),
            model_id="my-model",
        )

        # After forwarding completes, no model remains active on the backend.
        self.assertEqual(tracker.active_model_set(self.BACKEND_URL), frozenset())
|
|
|
|
|
|
class TestForwardRequestStreaming(unittest.IsolatedAsyncioTestCase):
    """``forward_request`` against a mocked ``text/event-stream`` upstream reply."""

    BACKEND_URL = "http://b1"

    @staticmethod
    def _event_stream_session(chunks) -> MagicMock:
        """Return a mock session whose response streams *chunks* in order.

        The response has status 200, a ``text/event-stream`` Content-Type, and
        a ``content.iter_chunked`` async generator that ignores its size
        argument and yields each item of *chunks*.
        """

        async def iter_chunked(_size):
            for piece in chunks:
                yield piece

        upstream = MagicMock()
        upstream.status = 200
        upstream.headers = {"Content-Type": "text/event-stream"}
        upstream.content = MagicMock()
        upstream.content.iter_chunked = iter_chunked

        ctx = MagicMock()
        ctx.__aenter__ = AsyncMock(return_value=upstream)
        ctx.__aexit__ = AsyncMock(return_value=False)

        session = MagicMock()
        session.request = MagicMock(return_value=ctx)
        return session

    async def _tracker_holding_slot(self) -> SlotTracker:
        """Return a tracker with capacity 1 on the backend and one slot acquired."""
        tracker = SlotTracker()
        tracker.set_capacity(self.BACKEND_URL, 1)
        async with asyncio.timeout(1.0):
            await tracker.acquire(self.BACKEND_URL)
        return tracker

    async def test_streaming_sse_passthrough(self):
        """Chunks arrive unchanged; afterwards the slot is freed and the scheduler notified once."""
        backend = _make_state(self.BACKEND_URL)
        tracker = await self._tracker_holding_slot()
        scheduler = _make_scheduler()
        sse_chunks = [
            b'data: {"choices":[{"delta":{"content":"hello"}}]}\n\n',
            b'data: [DONE]\n\n',
        ]
        session = self._event_stream_session(sse_chunks)

        response = await forward_request(
            _make_request({}), backend, session, tracker, scheduler
        )

        from fastapi.responses import StreamingResponse

        self.assertIsInstance(response, StreamingResponse)

        received = [piece async for piece in response.body_iterator]

        self.assertEqual(received, sse_chunks)
        held, _ = tracker.usage(self.BACKEND_URL)
        self.assertEqual(held, 0)
        scheduler.notify_slot_released.assert_called_once()

    async def test_streaming_slot_released_on_stream_end(self):
        """The slot stays held until the body iterator has been drained."""
        backend = _make_state(self.BACKEND_URL)
        tracker = await self._tracker_holding_slot()
        scheduler = _make_scheduler()
        session = self._event_stream_session(
            [b"data: chunk1\n\n", b"data: chunk2\n\n"]
        )

        response = await forward_request(
            _make_request({}), backend, session, tracker, scheduler
        )

        # Nothing has been consumed yet, so the slot is still counted.
        held, _ = tracker.usage(self.BACKEND_URL)
        self.assertEqual(held, 1)

        received = []
        async for piece in response.body_iterator:
            received.append(piece)
        self.assertEqual(len(received), 2)

        held, _ = tracker.usage(self.BACKEND_URL)
        self.assertEqual(held, 0)
|
|
|
|
|
|
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|