Files
llamacpp-ha/tests/test_middleware.py
2026-05-17 09:54:18 +02:00

80 lines
2.5 KiB
Python

import unittest
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from llamacpp_ha.middleware import ApiKeyMiddleware
def _make_app(api_keys: list[str]) -> FastAPI:
    """Build a small FastAPI application wrapped in ``ApiKeyMiddleware``.

    Three GET routes are registered -- ``/test``, ``/monitor`` and
    ``/monitor/data`` -- each returning a fixed JSON object.

    Args:
        api_keys: Forwarded to ``ApiKeyMiddleware`` as its ``api_keys``
            keyword argument.

    Returns:
        The configured application, ready to be handed to a ``TestClient``.
    """
    application = FastAPI()
    application.add_middleware(ApiKeyMiddleware, api_keys=api_keys)

    async def test_endpoint():
        return {"ok": True}

    async def monitor():
        return {"monitor": True}

    async def monitor_data():
        return {"data": True}

    # Calling the object returned by ``application.get(path)`` on a handler
    # is exactly what the ``@application.get(path)`` decorator form does.
    application.get("/test")(test_endpoint)
    application.get("/monitor")(monitor)
    application.get("/monitor/data")(monitor_data)

    return application
class TestApiKeyMiddleware(unittest.TestCase):
    """Exercise ``ApiKeyMiddleware`` through apps built by ``_make_app``."""

    @staticmethod
    def _request(keys: list[str], path: str = "/test", authorization=None):
        """Issue one GET against a fresh app configured with ``keys``.

        Args:
            keys: API keys handed to ``_make_app``.
            path: Route to request.
            authorization: Value for the ``Authorization`` header, or
                ``None`` to send the request without that header.

        Returns:
            The response object produced by ``TestClient.get``.
        """
        headers = None
        if authorization is not None:
            headers = {"Authorization": authorization}
        return TestClient(_make_app(keys)).get(path, headers=headers)

    def test_no_keys_configured_passes_all(self):
        """An empty key list lets an unauthenticated request through."""
        response = self._request([])
        self.assertEqual(response.status_code, 200)

    def test_valid_key_passes(self):
        """The first configured key is accepted as a bearer token."""
        response = self._request(["key1", "key2"], authorization="Bearer key1")
        self.assertEqual(response.status_code, 200)

    def test_valid_second_key_passes(self):
        """The second configured key is accepted as a bearer token."""
        response = self._request(["key1", "key2"], authorization="Bearer key2")
        self.assertEqual(response.status_code, 200)

    def test_missing_key_returns_401(self):
        """A request without an Authorization header is rejected."""
        response = self._request(["key1"])
        self.assertEqual(response.status_code, 401)

    def test_wrong_key_returns_401(self):
        """A bearer token that is not a configured key is rejected."""
        response = self._request(["key1"], authorization="Bearer wrongkey")
        self.assertEqual(response.status_code, 401)

    def test_malformed_auth_returns_401(self):
        """A valid key sent without the ``Bearer`` prefix is rejected."""
        response = self._request(["key1"], authorization="key1")
        self.assertEqual(response.status_code, 401)

    def test_monitor_exempt(self):
        """``/monitor`` answers 200 without credentials."""
        response = self._request(["key1"], path="/monitor")
        self.assertEqual(response.status_code, 200)

    def test_monitor_data_exempt(self):
        """``/monitor/data`` answers 200 without credentials."""
        response = self._request(["key1"], path="/monitor/data")
        self.assertEqual(response.status_code, 200)

    def test_error_response_is_json(self):
        """The rejection carries a JSON body that has an ``error`` key."""
        response = self._request(["key1"])
        self.assertEqual(response.headers["content-type"], "application/json")
        self.assertIn("error", response.json())
# Run the suite only when this file is executed directly, not when imported.
if __name__ == "__main__":
    unittest.main()