80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
import unittest
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.testclient import TestClient
|
|
|
|
from llamacpp_ha.middleware import ApiKeyMiddleware
|
|
|
|
|
|
def _make_app(api_keys: list[str]) -> FastAPI:
    """Build a minimal FastAPI application wrapped in ``ApiKeyMiddleware``.

    The application serves three GET routes -- ``/test``, ``/monitor`` and
    ``/monitor/data`` -- each answering with a small fixed JSON object.

    Args:
        api_keys: Forwarded to the middleware as its ``api_keys`` option.

    Returns:
        The configured application instance.
    """

    async def test_endpoint():
        return {"ok": True}

    async def monitor():
        return {"monitor": True}

    async def monitor_data():
        return {"data": True}

    application = FastAPI()
    application.add_middleware(ApiKeyMiddleware, api_keys=api_keys)

    # Calling ``application.get(path)`` yields the same decorator the
    # ``@app.get(path)`` syntax would apply, so registration is unchanged.
    routes = (
        ("/test", test_endpoint),
        ("/monitor", monitor),
        ("/monitor/data", monitor_data),
    )
    for path, handler in routes:
        application.get(path)(handler)

    return application
|
|
|
|
|
|
class TestApiKeyMiddleware(unittest.TestCase):
    """Checks ``ApiKeyMiddleware`` behaviour through a FastAPI ``TestClient``.

    Every test builds a fresh application via ``_make_app`` with its own key
    list, sends a single GET request, and inspects the response.
    """

    @staticmethod
    def _get(api_keys, path, authorization=None):
        """Send one GET to *path* on a new app configured with *api_keys*.

        When *authorization* is given it is sent verbatim as the
        ``Authorization`` header; otherwise no extra headers are passed.
        """
        client = TestClient(_make_app(api_keys))
        if authorization is None:
            return client.get(path)
        return client.get(path, headers={"Authorization": authorization})

    def test_no_keys_configured_passes_all(self):
        # Empty key list: an unauthenticated request is expected to succeed.
        response = self._get([], "/test")
        self.assertEqual(response.status_code, 200)

    def test_valid_key_passes(self):
        # First configured key presented as a Bearer token.
        response = self._get(["key1", "key2"], "/test", "Bearer key1")
        self.assertEqual(response.status_code, 200)

    def test_valid_second_key_passes(self):
        # Any configured key should work, not only the first one.
        response = self._get(["key1", "key2"], "/test", "Bearer key2")
        self.assertEqual(response.status_code, 200)

    def test_missing_key_returns_401(self):
        # No Authorization header at all while a key is configured.
        response = self._get(["key1"], "/test")
        self.assertEqual(response.status_code, 401)

    def test_wrong_key_returns_401(self):
        # Well-formed Bearer header carrying a key that is not configured.
        response = self._get(["key1"], "/test", "Bearer wrongkey")
        self.assertEqual(response.status_code, 401)

    def test_malformed_auth_returns_401(self):
        # Correct key value but without the "Bearer " scheme prefix.
        response = self._get(["key1"], "/test", "key1")
        self.assertEqual(response.status_code, 401)

    def test_monitor_exempt(self):
        # /monitor is expected to be reachable without credentials.
        response = self._get(["key1"], "/monitor")
        self.assertEqual(response.status_code, 200)

    def test_monitor_data_exempt(self):
        # The nested /monitor/data path is expected to be exempt as well.
        response = self._get(["key1"], "/monitor/data")
        self.assertEqual(response.status_code, 200)

    def test_error_response_is_json(self):
        # A rejected request should carry a JSON body with an "error" key.
        response = self._get(["key1"], "/test")
        self.assertEqual(response.headers["content-type"], "application/json")
        payload = response.json()
        self.assertIn("error", payload)
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|