Files
llamacpp-ha/tests/test_config.py
2026-05-17 09:54:18 +02:00

131 lines
4.3 KiB
Python

import json
import os
import tempfile
import unittest
from llamacpp_ha.config import BackendConfig, ProxyConfig
class TestBackendConfig(unittest.TestCase):
    """Checks on ``BackendConfig`` objects built directly from keyword arguments."""

    def test_url_trailing_slash_stripped(self):
        # A URL given with a trailing "/" must be stored without it.
        backend = BackendConfig(url="http://localhost:8080/")
        self.assertEqual(backend.url, "http://localhost:8080")

    def test_defaults(self):
        # With only a URL supplied, the optional fields are None / empty.
        backend = BackendConfig(url="http://localhost:8080")
        self.assertIsNone(backend.api_key)
        self.assertEqual(backend.model_ids, [])
        self.assertIsNone(backend.max_models)

    def test_explicit_model_ids(self):
        # An explicit model list is kept as given, in the same order.
        backend = BackendConfig(url="http://x", model_ids=["llama3", "mistral"])
        self.assertEqual(backend.model_ids, ["llama3", "mistral"])

    def test_max_models(self):
        # An integer limit round-trips unchanged.
        backend = BackendConfig(url="http://x", max_models=1)
        self.assertEqual(backend.max_models, 1)

    def test_max_models_none_is_unlimited(self):
        # Passing None explicitly is accepted and stored as None
        # (the test name treats None as "no limit").
        backend = BackendConfig(url="http://x", max_models=None)
        self.assertIsNone(backend.max_models)
class TestProxyConfig(unittest.TestCase):
    """Tests for ``ProxyConfig``: defaults, JSON file loading, and overrides."""

    def _write_config(self, data: dict) -> str:
        """Write *data* as JSON to a new temporary ``.json`` file.

        The file is created with ``delete=False`` so it still exists after
        being closed; the caller is responsible for removing it.

        Args:
            data: JSON-serializable mapping to store in the file.

        Returns:
            The path of the file that was written.
        """
        # The context manager closes the handle even if json.dump raises,
        # matching the pattern used in test_invalid_json_raises.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as f:
            json.dump(data, f)
        return f.name

    def test_defaults(self):
        # NOTE(review): test_env_var_override shows that ProxyConfig() reads
        # LLAMACPP_HA_PORT, so this test presumably needs that variable to be
        # unset in the environment -- confirm.
        cfg = ProxyConfig()
        self.assertEqual(cfg.host, "0.0.0.0")
        self.assertEqual(cfg.port, 8080)
        self.assertEqual(cfg.api_keys, [])
        self.assertEqual(cfg.poll_interval, 5.0)
        self.assertEqual(cfg.slot_wait_timeout, 30.0)
        self.assertEqual(cfg.session_idle_ttl, 300.0)
        self.assertEqual(cfg.backends, [])
        self.assertIsNone(cfg.default_max_models)
        self.assertEqual(cfg.max_queue_skip, 0)

    def test_from_file_minimal(self):
        # A file containing only a backend list is enough to load a config.
        path = self._write_config(
            {"backends": [{"url": "http://localhost:8081"}]}
        )
        try:
            cfg = ProxyConfig.from_file(path)
            self.assertEqual(len(cfg.backends), 1)
            self.assertEqual(cfg.backends[0].url, "http://localhost:8081")
        finally:
            os.unlink(path)

    def test_from_file_full(self):
        # Every top-level key written to the file is checked on the result.
        data = {
            "host": "127.0.0.1",
            "port": 9090,
            "api_keys": ["key1", "key2"],
            "poll_interval": 10,
            "slot_wait_timeout": 60,
            "session_idle_ttl": 600,
            "default_max_models": 2,
            "max_queue_skip": 5,
            "backends": [
                {
                    "url": "http://b1",
                    "api_key": "secret",
                    "model_ids": ["m1"],
                    "max_models": 1,
                },
                {"url": "http://b2/"},
            ],
        }
        path = self._write_config(data)
        try:
            cfg = ProxyConfig.from_file(path)
            self.assertEqual(cfg.host, "127.0.0.1")
            self.assertEqual(cfg.port, 9090)
            self.assertEqual(cfg.api_keys, ["key1", "key2"])
            self.assertEqual(cfg.poll_interval, 10)
            # These two keys were in the file but previously went unverified.
            self.assertEqual(cfg.slot_wait_timeout, 60)
            self.assertEqual(cfg.session_idle_ttl, 600)
            self.assertEqual(cfg.default_max_models, 2)
            self.assertEqual(cfg.max_queue_skip, 5)
            self.assertEqual(cfg.backends[0].api_key, "secret")
            self.assertEqual(cfg.backends[0].model_ids, ["m1"])
            self.assertEqual(cfg.backends[0].max_models, 1)
            # The second backend's trailing slash is expected to be removed.
            self.assertEqual(cfg.backends[1].url, "http://b2")
        finally:
            os.unlink(path)

    def test_from_file_with_overrides(self):
        # A keyword passed to from_file must win over the value in the file.
        path = self._write_config({"port": 8080, "backends": []})
        try:
            cfg = ProxyConfig.from_file(path, port=9999)
            self.assertEqual(cfg.port, 9999)
        finally:
            os.unlink(path)

    def test_env_var_override(self):
        # Remember any pre-existing value so the environment is restored
        # exactly, whether or not the variable was set before the test.
        old = os.environ.get("LLAMACPP_HA_PORT")
        try:
            os.environ["LLAMACPP_HA_PORT"] = "7777"
            cfg = ProxyConfig()
            self.assertEqual(cfg.port, 7777)
        finally:
            if old is None:
                os.environ.pop("LLAMACPP_HA_PORT", None)
            else:
                os.environ["LLAMACPP_HA_PORT"] = old

    def test_invalid_json_raises(self):
        # The file content is deliberately not valid JSON.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as f:
            f.write("not json {{{")
            path = f.name
        try:
            # NOTE(review): Exception is very broad; narrow it once the
            # concrete error type raised by from_file is confirmed.
            with self.assertRaises(Exception):
                ProxyConfig.from_file(path)
        finally:
            os.unlink(path)
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()