"""Unit tests for ``BackendConfig`` and ``ProxyConfig`` from ``llamacpp_ha.config``."""
import json
import os
import tempfile
import unittest
from unittest import mock

from llamacpp_ha.config import BackendConfig, ProxyConfig
|
|
|
|
|
|
class TestBackendConfig(unittest.TestCase):
    """Checks on ``BackendConfig`` construction: URL handling and field values."""

    def test_url_trailing_slash_stripped(self):
        """A URL given with a trailing ``/`` is stored without it."""
        backend = BackendConfig(url="http://localhost:8080/")
        self.assertEqual(backend.url, "http://localhost:8080")

    def test_defaults(self):
        """Only ``url`` supplied: optional fields are ``None`` / empty list."""
        backend = BackendConfig(url="http://localhost:8080")
        self.assertEqual(backend.model_ids, [])
        self.assertIsNone(backend.api_key)
        self.assertIsNone(backend.max_models)

    def test_explicit_model_ids(self):
        """``model_ids`` passed to the constructor are exposed unchanged."""
        wanted_ids = ["llama3", "mistral"]
        # Hand over a copy so the comparison never runs against the same object.
        backend = BackendConfig(url="http://x", model_ids=list(wanted_ids))
        self.assertEqual(backend.model_ids, wanted_ids)

    def test_max_models(self):
        """An explicit integer ``max_models`` is stored as given."""
        backend = BackendConfig(url="http://x", max_models=1)
        self.assertEqual(backend.max_models, 1)

    def test_max_models_none_is_unlimited(self):
        """Passing ``max_models=None`` explicitly leaves the field ``None``."""
        backend = BackendConfig(url="http://x", max_models=None)
        self.assertIsNone(backend.max_models)
|
|
|
|
|
|
class TestProxyConfig(unittest.TestCase):
    """Tests for ``ProxyConfig``: defaults, JSON-file loading and overrides."""

    def setUp(self):
        """Shield each test from ``LLAMACPP_HA_*`` variables set outside the suite.

        ``test_env_var_override`` shows that ``LLAMACPP_HA_PORT`` changes
        ``ProxyConfig().port``, so an ambient value would make the
        default-value expectations fail for reasons unrelated to the code.
        """
        # patch.dict snapshots os.environ on start() and restores it on stop().
        env_patch = mock.patch.dict(os.environ)
        env_patch.start()
        self.addCleanup(env_patch.stop)
        # NOTE(review): only LLAMACPP_HA_PORT is confirmed by this file; other
        # settings presumably share the prefix -- verify against the config module.
        for name in [key for key in os.environ if key.startswith("LLAMACPP_HA_")]:
            del os.environ[name]

    def _write_config(self, data: dict) -> str:
        """Write *data* as JSON to a new temporary ``.json`` file.

        The file is created with ``delete=False``, so the caller owns it and
        is responsible for removing it once the test is done.

        Args:
            data: JSON-serializable mapping to dump into the file.

        Returns:
            Path of the temporary file that was written.

        Raises:
            Exception: Whatever ``json.dump`` raises is re-raised after the
                partially written file has been closed and removed.
        """
        handle = tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        )
        try:
            # The context manager closes the handle even when dumping fails.
            with handle:
                json.dump(data, handle)
        except Exception:
            # The caller never learns the path on failure, so drop the orphan.
            os.unlink(handle.name)
            raise
        return handle.name

    def test_defaults(self):
        """A ``ProxyConfig`` built without arguments exposes the default values."""
        cfg = ProxyConfig()
        self.assertEqual(cfg.host, "0.0.0.0")
        self.assertEqual(cfg.port, 8080)
        self.assertEqual(cfg.api_keys, [])
        self.assertEqual(cfg.poll_interval, 5.0)
        self.assertEqual(cfg.slot_wait_timeout, 30.0)
        self.assertEqual(cfg.session_idle_ttl, 300.0)
        self.assertEqual(cfg.backends, [])
        self.assertIsNone(cfg.default_max_models)
        self.assertEqual(cfg.model_affinity_sched_bonus, 0)
        self.assertEqual(cfg.queue_aging_equalization, 30.0)

    def test_from_file_minimal(self):
        """A file containing only one backend URL loads into a single backend."""
        path = self._write_config(
            {"backends": [{"url": "http://localhost:8081"}]}
        )
        try:
            cfg = ProxyConfig.from_file(path)
            self.assertEqual(len(cfg.backends), 1)
            self.assertEqual(cfg.backends[0].url, "http://localhost:8081")
        finally:
            os.unlink(path)

    def test_from_file_full(self):
        """Every setting present in the JSON file is reflected on the config."""
        data = {
            "host": "127.0.0.1",
            "port": 9090,
            "api_keys": ["key1", "key2"],
            "poll_interval": 10,
            "slot_wait_timeout": 60,
            "session_idle_ttl": 600,
            "default_max_models": 2,
            "model_affinity_sched_bonus": 10,
            "queue_aging_equalization": 30.0,
            "backends": [
                {
                    "url": "http://b1",
                    "api_key": "secret",
                    "model_ids": ["m1"],
                    "max_models": 1,
                },
                {"url": "http://b2/"},
            ],
        }
        path = self._write_config(data)
        try:
            cfg = ProxyConfig.from_file(path)
            self.assertEqual(cfg.host, "127.0.0.1")
            self.assertEqual(cfg.port, 9090)
            self.assertEqual(cfg.api_keys, ["key1", "key2"])
            self.assertEqual(cfg.poll_interval, 10)
            # These two are set in the fixture above; verify them as well so a
            # loader that drops them does not go unnoticed.
            self.assertEqual(cfg.slot_wait_timeout, 60)
            self.assertEqual(cfg.session_idle_ttl, 600)
            self.assertEqual(cfg.default_max_models, 2)
            self.assertEqual(cfg.model_affinity_sched_bonus, 10)
            self.assertEqual(cfg.queue_aging_equalization, 30.0)
            self.assertEqual(cfg.backends[0].api_key, "secret")
            self.assertEqual(cfg.backends[0].model_ids, ["m1"])
            self.assertEqual(cfg.backends[0].max_models, 1)
            # The second backend was written with a trailing slash.
            self.assertEqual(cfg.backends[1].url, "http://b2")
        finally:
            os.unlink(path)

    def test_from_file_with_overrides(self):
        """A keyword passed to ``from_file`` wins over the value in the file."""
        path = self._write_config({"port": 8080, "backends": []})
        try:
            cfg = ProxyConfig.from_file(path, port=9999)
            self.assertEqual(cfg.port, 9999)
        finally:
            os.unlink(path)

    def test_env_var_override(self):
        """``LLAMACPP_HA_PORT`` in the environment sets ``port`` as an int."""
        # patch.dict puts the previous environment back when the block exits,
        # whether or not the assertion passes.
        with mock.patch.dict(os.environ, {"LLAMACPP_HA_PORT": "7777"}):
            cfg = ProxyConfig()
            self.assertEqual(cfg.port, 7777)

    def test_invalid_json_raises(self):
        """Loading a file that is not valid JSON raises instead of returning."""
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as f:
            f.write("not json {{{")
            path = f.name
        try:
            # NOTE(review): the concrete exception type raised by from_file is
            # not visible from this file; narrow this once it is confirmed.
            with self.assertRaises(Exception):
                ProxyConfig.from_file(path)
        finally:
            os.unlink(path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's test cases when the file is executed as a script.
    unittest.main()
|