"""Unit tests for plugin injection in ``src.pyrestresource``.

Covers field plugins declared inside ``Annotated[..., RestField(plugin=...)]``, a resource plugin
declared through ``RestField(default=..., plugin=...)``, tolerance of extra slashes in GET paths,
and rejection of a plugin class that does not inherit from a plugin base class.
"""

from __future__ import annotations

import unittest
from os import chdir
from pathlib import Path
from typing import Annotated

from src.pyrestresource import (
    RestField,
    register_rest_rootpoint,
    RestResourceBase,
    rsrc_verb,
    RestRequestParams_GET,
    RestRequestParams_POST,
    RestRequestParams_Dict_GET,
    RestRequestParams_PUT,
    T_SupportedRESTFields,
    ResourcePlugin_field_default,
    ResourcePlugin_RestResourceBase_default,
    RestResourcePluginException_InvalidPluginSignature,
)

testdir_path = Path(__file__).parent.resolve()
# Work from the project root (the parent of this test directory); every test's setUp re-applies it.
chdir(testdir_path.parent.resolve())


# To allow mocking, all the tested classes are defined inside a function and rebuilt for every test.
def init_classes():
    """Define the plugin and resource classes under test and publish the resources as module globals.

    ``Info_get``, ``Info_put`` and ``RootApp`` are copied into ``globals()`` so that code outside
    this function (e.g. the test cases' ``setUp``, which instantiates ``RootApp``) can reach them
    by name.
    """

    class ResourcePlugin_version_get(ResourcePlugin_field_default):
        # Field plugin: a GET of the field reports "1.5.6" whatever value is stored.
        # NOTE(review): the return annotation says Info_get but a str is returned; left untouched
        # in case the library inspects plugin signatures -- confirm before changing.
        def handle_field_get(self, resource: Info_get, params: RestRequestParams_GET) -> Info_get:
            return "1.5.6"

    class ResourcePlugin_version_put(ResourcePlugin_field_default):
        # Field plugin: after a PUT touching the field, GETs report "42" whatever value was sent.
        # NOTE(review): same annotation/return mismatch as above (str returned, Info_put declared).
        def handle_field_put(self, resource: Info_put, params: RestRequestParams_PUT) -> Info_put:
            return "42"

    class ResourcePlugin_Info(ResourcePlugin_RestResourceBase_default):
        # Resource plugin: a GET on the resource serves the values built here instead of the
        # stored default (the tests show "version" is still overridden by Info_get's field plugin).
        def handle_resource_get(self, resource: Info_get, params: RestRequestParams_GET) -> Info_get:
            return Info_get(version="65.45", api_version="98.321")

    class Info_get(RestResourceBase):
        # test plugin injection within annotation
        # + test plugin on a simple field
        version: Annotated[str, RestField(plugin=ResourcePlugin_version_get)]
        api_version: str

    class Info_put(RestResourceBase):
        # test plugin injection within annotation
        # + test plugin on a simple field
        version: Annotated[str, RestField(plugin=ResourcePlugin_version_put)]
        api_version: str

    @register_rest_rootpoint
    class RootApp(RestResourceBase):
        # info: test plugin injection within Field value
        #       + test plugin on a RestResourceBase field
        info: Info_get = RestField(
            default=Info_get(version="0.0.1", api_version="0.0.2"),
            plugin=ResourcePlugin_Info,
        )
        # info_put: no resource plugin; only its field plugin is involved.
        info_put: Info_put = RestField(
            default=Info_put(version="0.0.1", api_version="0.0.2"),
        )
        # info2: same type as info, but without a resource plugin.
        info2: Info_get = RestField(default=Info_get(version="0.0.2", api_version="0.0.3"))

    # These lines add the classes to the module globals so they can be used later on.
    # => this is only needed for unit-testing purposes and is not required in real use.
    globals()[Info_get.__name__] = Info_get
    globals()[Info_put.__name__] = Info_put
    globals()[RootApp.__name__] = RootApp


def init_bad_plugin1():
    """Build and instantiate a root app whose field plugin has the wrong base type.

    ``test_defect_plugin_field`` expects this to raise
    ``RestResourcePluginException_InvalidPluginSignature``. Whether the library raises during class
    creation or during ``RootApp2()`` is not visible here; the test only checks that it is raised.
    """

    # plugin not inheriting from the right base type
    class ResourcePlugin_TestResource:
        ...

    class TestResource(RestResourceBase):
        tetvaluestr: Annotated[str, RestField(plugin=ResourcePlugin_TestResource)]

    @register_rest_rootpoint
    class RootApp2(RestResourceBase):
        test: TestResource = RestField(default=TestResource(tetvaluestr="testvalue"))

    RootApp2()


class _RootAppTestMixin:
    """Shared fixture: rebuild the test classes and expose a fresh ``RootApp`` as ``self.testapp``.

    Mixed into ``unittest.TestCase`` subclasses; it is not a TestCase itself so it is never
    collected on its own.
    """

    def setUp(self) -> None:
        super().setUp()
        # Restore the caller's working directory afterwards so the chdir below does not leak
        # into tests that run later in the same process.
        self.addCleanup(chdir, Path.cwd())
        chdir(testdir_path.parent.resolve())
        init_classes()
        self.testapp = RootApp()

    def _get(self, path: str):
        """Issue a GET on *path* and return the serialized result."""
        return self.testapp.process_request(path, rsrc_verb.GET).get_result()

    def _assert_gets(self, expected_by_path: dict[str, str]) -> None:
        """Assert, path by path (each in its own subTest), that a GET returns the paired result."""
        for path, expected in expected_by_path.items():
            with self.subTest(path=path):
                self.assertEqual(self._get(path), expected)


class Test_RestAPI_Plugin_PUT(_RootAppTestMixin, unittest.TestCase):
    def test_put_field_version_fieldplugin(self):
        # PUT on the field itself: the field plugin makes later GETs report "42", not "1.5.6".
        self.testapp.process_request("/info_put/version", rsrc_verb.PUT, '"1.5.6"')
        self.assertEqual(self._get("/info_put"), '{"version": "42", "api_version": "0.0.2"}')
        self.assertEqual(self._get("/info_put/version"), '"42"')

    def test_put_field_version_resourceplugin(self):
        # PUT on the whole resource: the field plugin still rewrites "version", while
        # "api_version" is reported exactly as sent.
        self.testapp.process_request(
            "/info_put", rsrc_verb.PUT, '{"version": "1.5.6", "api_version": "98.321"}'
        )
        self.assertEqual(self._get("/info_put"), '{"version": "42", "api_version": "98.321"}')


# Expected GET serializations. On both resources the field plugin forces "version" to "1.5.6";
# /info gets api_version "98.321" from its resource plugin, /info2 keeps its default "0.0.3".
_INFO_JSON = '{"version": "1.5.6", "api_version": "98.321"}'
_INFO2_JSON = '{"version": "1.5.6", "api_version": "0.0.3"}'


class Test_RestAPI_Plugin_GET(_RootAppTestMixin, unittest.TestCase):
    def test_get_root(self):
        self.assertEqual(self._get("/"), "{}")

    def test_get_version(self):
        self._assert_gets({"/info": _INFO_JSON, "/info2": _INFO2_JSON})

    def test_get_version__trailing_slash(self):
        self._assert_gets(
            {
                "/info/": _INFO_JSON,
                "/info//": _INFO_JSON,
                "/info///": _INFO_JSON,
                "/info2/": _INFO2_JSON,
                "/info2//": _INFO2_JSON,
                "/info2///": _INFO2_JSON,
            }
        )

    def test_get_version__multiple_slash(self):
        self._assert_gets(
            {
                "//info": _INFO_JSON,
                "///info": _INFO_JSON,
                "//info2": _INFO2_JSON,
                "///info2": _INFO2_JSON,
            }
        )

    def test_get_version__nested_value(self):
        self._assert_gets(
            {
                "/info/api_version": '"98.321"',
                "/info/version": '"1.5.6"',
                "/info2/api_version": '"0.0.3"',
                "/info2/version": '"1.5.6"',
            }
        )

    def test_defect_plugin_field(self):
        with self.assertRaises(RestResourcePluginException_InvalidPluginSignature):
            init_bad_plugin1()


if __name__ == "__main__":
    unittest.main()