# Files: pyrestresource/test/test_rest_resource_plugins.py
# Snapshot: 2023-11-06 01:11:43 +00:00 (175 lines, 7.0 KiB, Python)

from __future__ import annotations
import unittest
from os import chdir
from pathlib import Path
from typing import Annotated
from src.pyrestresource import (
RestField,
register_rest_rootpoint,
RestResourceBase,
rsrc_verb,
RestRequestParams_GET,
RestRequestParams_POST,
RestRequestParams_Dict_GET,
RestRequestParams_PUT,
T_SupportedRESTFields,
ResourcePlugin_field_default,
ResourcePlugin_RestResourceBase_default,
RestResourcePluginException_InvalidPluginSignature,
)
# Absolute path of the directory that contains this test module.
testdir_path = Path(__file__).parent.resolve()
# Switch the process working directory to the parent of the test directory
# (presumably the project root, so relative paths used by the code under test
# resolve - TODO confirm). This runs after the `src.pyrestresource` import
# above, so it does not affect how that import is resolved.
# NOTE(review): this executes at import/collection time and changes the cwd
# for the whole process; every TestCase.setUp in this module repeats it.
chdir(testdir_path.parent.resolve())
# To allow mocking, all the tested classes are defined inside a function.
# Every TestCase.setUp calls init_classes() again, so each test works with
# freshly created class objects.
def init_classes():
    """Declare the plugins, resources and root point used by the plugin tests.

    Defines two field plugins for ``version`` (one for GET, one for PUT), one
    resource plugin for GET on ``info``, the ``Info_get`` / ``Info_put``
    resources and the ``RootApp`` root point. It then publishes ``Info_get``,
    ``Info_put`` and ``RootApp`` as module globals.
    """

    # Field plugin: a GET on the decorated field always yields "1.5.6",
    # whatever value is stored (the tests expect "1.5.6" for both /info and
    # /info2, despite their different defaults).
    # NOTE(review): annotated `-> Info_get` but returns a str; left untouched
    # because the framework may validate plugin signatures - confirm.
    class ResourcePlugin_version_get(ResourcePlugin_field_default):
        def handle_field_get(self, resource: Info_get, params: RestRequestParams_GET) -> Info_get:
            return "1.5.6"

    # Field plugin: a PUT on the decorated field returns "42". The PUT tests
    # expect the field to read back as "42" rather than the request body.
    class ResourcePlugin_version_put(ResourcePlugin_field_default):
        def handle_field_put(self, resource: Info_put, params: RestRequestParams_PUT) -> Info_put:
            return "42"

    # Resource plugin: a GET on the whole resource returns a new Info_get
    # instead of the stored default. The tests still see version "1.5.6" on
    # /info, so the field plugin on `version` is applied as well.
    class ResourcePlugin_Info(ResourcePlugin_RestResourceBase_default):
        def handle_resource_get(self, resource: Info_get, params: RestRequestParams_GET) -> Info_get:
            return Info_get(version="65.45", api_version="98.321")

    # Resource read through GET; `version` carries the GET field plugin.
    class Info_get(RestResourceBase):
        # test plugin injection within annotation
        # + test plugin on a simple field
        version: Annotated[str, RestField(plugin=ResourcePlugin_version_get)]
        api_version: str

    # Resource written through PUT; `version` carries the PUT field plugin.
    class Info_put(RestResourceBase):
        # test plugin injection within annotation
        # + test plugin on a simple field
        version: Annotated[str, RestField(plugin=ResourcePlugin_version_put)]
        api_version: str

    # Root point of the test API; each attribute name is a top-level path
    # (/info, /info_put, /info2).
    @register_rest_rootpoint
    class RootApp(RestResourceBase):
        # test plugin injection within Field value
        # + test plugin on a RestResourceBase field
        info: Info_get = RestField(
            default=Info_get(version="0.0.1", api_version="0.0.2"),
            plugin=ResourcePlugin_Info,
        )
        # No resource plugin here: only the field plugin on `version` applies.
        info_put: Info_put = RestField(
            default=Info_put(version="0.0.1", api_version="0.0.2"),
        )
        # Same resource type as `info` but without a resource plugin.
        info2: Info_get = RestField(default=Info_get(version="0.0.2", api_version="0.0.3"))

    # Publish the classes as module globals so code outside this function can
    # use them (the test cases instantiate `RootApp`). Presumably this also
    # lets the lazy string annotations naming Info_get / Info_put on the plugin
    # methods be resolved against module globals - TODO confirm.
    # => this is only needed for unit testing, not in real use
    globals()[Info_get.__name__] = Info_get
    globals()[Info_put.__name__] = Info_put
    globals()[RootApp.__name__] = RootApp
def init_bad_plugin1():
    """Declare a resource whose field plugin does not have a plugin base class.

    The test suite expects this call to raise
    ``RestResourcePluginException_InvalidPluginSignature``. Which statement
    actually raises is decided by the framework and is not visible here. It
    could be resolving the annotation, building the ``TestResource`` default,
    the ``register_rest_rootpoint`` decorator, or the final ``RootApp2()``
    call.
    """
    # plugin not inheriting from the right base type (it derives only from
    # `object`, not from a ResourcePlugin_* base)
    class ResourcePlugin_TestResource:
        ...

    class TestResource(RestResourceBase):
        # NOTE(review): "tetvaluestr" looks like a typo for "testvaluestr";
        # kept because field names are part of the resource's REST paths.
        tetvaluestr: Annotated[str, RestField(plugin=ResourcePlugin_TestResource)]

    @register_rest_rootpoint
    class RootApp2(RestResourceBase):
        test: TestResource = RestField(default=TestResource(tetvaluestr="testvalue"))

    # Instantiate the root point so the invalid plugin is surfaced here if the
    # class definitions above did not already raise.
    RootApp2()
class Test_RestAPI_Plugin_PUT(unittest.TestCase):
    """PUT requests on ``/info_put``, whose ``version`` field has a PUT plugin."""

    def setUp(self) -> None:
        # Same working directory, freshly declared classes and a new app per test.
        project_root = testdir_path.parent.resolve()
        chdir(project_root)
        init_classes()
        self.testapp = RootApp()

    def _send_put(self, path, body):
        """Issue a PUT of *body* on *path*; the response is deliberately ignored."""
        self.testapp.process_request(path, rsrc_verb.PUT, body)

    def _fetch(self, path):
        """Issue a GET on *path* and return the serialized result."""
        response = self.testapp.process_request(path, rsrc_verb.GET)
        return response.get_result()

    def test_put_field_version_fieldplugin(self):
        # The value sent ("1.5.6") is replaced by the field plugin's "42",
        # while api_version keeps its default.
        self._send_put("/info_put/version", '"1.5.6"')
        self.assertEqual(self._fetch("/info_put"), '{"version": "42", "api_version": "0.0.2"}')
        self.assertEqual(self._fetch("/info_put/version"), '"42"')

    def test_put_field_version_resourceplugin(self):
        # PUT of the whole resource: api_version is taken from the body, but
        # version still goes through the field plugin.
        payload = '{"version": "1.5.6", "api_version": "98.321"}'
        self._send_put("/info_put", payload)
        self.assertEqual(self._fetch("/info_put"), '{"version": "42", "api_version": "98.321"}')
class Test_RestAPI_Plugin_GET(unittest.TestCase):
    """GET requests served through the field and resource plugins of ``RootApp``."""

    # `/info` has a resource plugin plus the field plugin on `version`;
    # `/info2` only has the field plugin, so `api_version` keeps its default.
    INFO_JSON = '{"version": "1.5.6", "api_version": "98.321"}'
    INFO2_JSON = '{"version": "1.5.6", "api_version": "0.0.3"}'

    def setUp(self) -> None:
        chdir(testdir_path.parent.resolve())
        # Re-declare the tested classes so every test starts from a fresh app.
        init_classes()
        self.testapp = RootApp()

    def _assert_get(self, path: str, expected: str) -> None:
        """GET *path* and compare the serialized result with *expected*.

        Each check runs in its own ``subTest``, so a failure reports the path
        that failed and the remaining paths of the same test are still checked.
        """
        with self.subTest(path=path):
            result = self.testapp.process_request(path, rsrc_verb.GET)
            self.assertEqual(result.get_result(), expected)

    def test_get_root(self):
        self._assert_get("/", "{}")

    def test_get_version(self):
        self._assert_get("/info", self.INFO_JSON)
        self._assert_get("/info2", self.INFO2_JSON)

    def test_get_version__trailing_slash(self):
        # Any number of trailing slashes is expected to reach the same resource.
        for route, expected in (("/info", self.INFO_JSON), ("/info2", self.INFO2_JSON)):
            for slashes in ("/", "//", "///"):
                self._assert_get(route + slashes, expected)

    def test_get_version__multiple_slash(self):
        # Repeated leading slashes are expected to reach the same resource.
        for route, expected in (("info", self.INFO_JSON), ("info2", self.INFO2_JSON)):
            for slashes in ("//", "///"):
                self._assert_get(slashes + route, expected)

    def test_get_version__nested_value(self):
        cases = (
            ("/info/api_version", '"98.321"'),
            ("/info/version", '"1.5.6"'),
            ("/info2/api_version", '"0.0.3"'),
            ("/info2/version", '"1.5.6"'),
        )
        for path, expected in cases:
            self._assert_get(path, expected)

    def test_defect_plugin_field(self):
        """A plugin class without a plugin base class must be rejected."""
        with self.assertRaises(RestResourcePluginException_InvalidPluginSignature):
            init_bad_plugin1()