Files
pyrestresource/test/test_rest_resource_plugins.py
2023-10-31 23:58:51 +00:00

217 lines
8.0 KiB
Python

from __future__ import annotations
import unittest
from unittest.mock import patch
from os import chdir
from pathlib import Path
from typing import Optional, Annotated
from pydantic import Field
from uuid import UUID, uuid4
from time import time
import json
# Debug output emitted at import time: prints the module name and the package
# this test file was loaded under.
# NOTE(review): presumably kept to diagnose how the `src.pyrestresource`
# import below is resolved by the test runner -- confirm before removing.
print(__name__)
print(__package__)
from src.pyrestresource import (
register_rest_rootpoint,
RestResourceBase,
rsrc_verb,
RestRequestParams_GET,
RestRequestParams_POST,
RestRequestParams_Dict_GET,
RestRequestParams_PUT,
T_SupportedRESTFields,
ResourcePlugin_field_default,
ResourcePlugin_RestResourceBase_default,
)
from pprint import pprint
# Absolute, symlink-free directory that contains this test module.
testdir_path = Path(__file__).parent.resolve()
# Run everything from the directory above the test folder. The parent of an
# already-resolved path is itself resolved, so no second resolve() is needed.
chdir(testdir_path.parent)
# To allow mocking, all the tested classes are defined inside a function.
def init_classes():
    """Define the plugin classes and REST resources used by the test cases.

    The classes are created locally, the root resource is registered through
    ``register_rest_rootpoint``, and ``Info_get``, ``Info_put`` and ``RootApp``
    are then published in this module's globals so the tests can refer to
    them by name.
    """

    # Field plugin: the GET handler always answers "1.5.6".
    # NOTE(review): the return annotation says Info_get but a plain str is
    # returned -- left untouched here.
    class ResourcePlugin_version_get(ResourcePlugin_field_default):
        def handle_field_get(
            self,
            resource: Info_get,
            params: RestRequestParams_GET,
        ) -> Info_get:
            return "1.5.6"

    # Field plugin: the PUT handler always answers "42".
    class ResourcePlugin_version_put(ResourcePlugin_field_default):
        def handle_field_put(
            self,
            resource: Info_put,
            params: RestRequestParams_PUT,
        ) -> Info_put:
            return "42"

    # Resource plugin: the GET handler always answers a fixed Info_get.
    class ResourcePlugin_Info(ResourcePlugin_RestResourceBase_default):
        def handle_resource_get(
            self,
            resource: Info_get,
            params: RestRequestParams_GET,
        ) -> Info_get:
            return Info_get(version="65.45", api_version="98.321")

    class Info_get(RestResourceBase):
        # Plugin injected through the annotation, on a simple (str) field.
        version: Annotated[str, Field(plugin=ResourcePlugin_version_get)]
        api_version: str

    class Info_put(RestResourceBase):
        # Plugin injected through the annotation, on a simple (str) field.
        version: Annotated[str, Field(plugin=ResourcePlugin_version_put)]
        api_version: str

    @register_rest_rootpoint
    class RootApp(RestResourceBase):
        # Plugin injected through the Field value, on a RestResourceBase field.
        info: Info_get = Field(
            default=Info_get(version="0.0.1", api_version="0.0.2"),
            plugin=ResourcePlugin_Info,
        )
        # No resource-level plugin here; only the field plugin of Info_put.
        info_put: Info_put = Field(
            default=Info_put(version="0.0.1", api_version="0.0.2"),
        )
        # No resource-level plugin here; only the field plugin of Info_get.
        info2: Info_get = Field(
            default=Info_get(version="0.0.2", api_version="0.0.3"),
        )

    # Publish the classes in the module globals so they can be used later on.
    # => this is only for unit-testing purposes and is not needed in real use.
    globals().update(
        {
            Info_get.__name__: Info_get,
            Info_put.__name__: Info_put,
            RootApp.__name__: RootApp,
        }
    )
def init_bad_plugin1():
    """Build a resource tree whose field plugin only has a GET handler.

    ``test_defect_plugin_field`` expects calling this function to raise
    ``RuntimeError``.
    """

    # Plugin that does not derive from ResourcePlugin_field_default and only
    # implements handle_field_get (there is no PUT handler).
    class ResourcePlugin_TestResource:
        def handle_field_get(
            self,
            resource: TestResource,
            params: RestRequestParams_GET,
        ) -> TestResource:
            return resource

    class TestResource(RestResourceBase):
        tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)]

    @register_rest_rootpoint
    class RootApp2(RestResourceBase):
        test: TestResource = Field(
            default=TestResource(tetvaluestr="testvalue"),
        )

    # Instantiate the root resource as the final step of the scenario.
    RootApp2()
def init_bad_plugin2():
    """Build a resource tree whose field plugin only has a PUT handler.

    ``test_defect_plugin_field`` expects calling this function to raise
    ``RuntimeError``.
    """

    # Plugin that does not derive from ResourcePlugin_field_default and only
    # implements handle_field_put (there is no GET handler).
    class ResourcePlugin_TestResource:
        def handle_field_put(
            self,
            resource: TestResource,
            params: RestRequestParams_PUT,
        ) -> TestResource:
            return resource

    class TestResource(RestResourceBase):
        tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)]

    @register_rest_rootpoint
    class RootApp2(RestResourceBase):
        test: TestResource = Field(
            default=TestResource(tetvaluestr="testvalue"),
        )

    # Instantiate the root resource as the final step of the scenario.
    RootApp2()
def init_bad_plugin3():
    """Build a resource tree that attaches the wrong kind of plugin to a field.

    ``test_defect_plugin_field`` expects calling this function to raise
    ``RuntimeError``.
    """

    # Wrong plugin: a subclass of the resource-level default plugin, attached
    # below to a plain ``str`` field.
    class ResourcePlugin_TestResource(ResourcePlugin_RestResourceBase_default):
        pass

    class TestResource(RestResourceBase):
        tetvaluestr: Annotated[str, Field(plugin=ResourcePlugin_TestResource)]

    @register_rest_rootpoint
    class RootApp2(RestResourceBase):
        test: TestResource = Field(
            default=TestResource(tetvaluestr="testvalue"),
        )

    # Instantiate the root resource as the final step of the scenario.
    RootApp2()
class Test_RestAPI_Plugin_PUT(unittest.TestCase):
    """PUT requests against ``/info_put``, whose ``version`` field has a PUT plugin."""

    def setUp(self) -> None:
        # Fresh classes and a fresh root resource for every test.
        project_root = testdir_path.parent.resolve()
        chdir(project_root)
        init_classes()
        self.testapp = RootApp()

    def test_put_field_version_fieldplugin(self):
        app = self.testapp
        # Write the field directly; the payload is "1.5.6".
        app.process_request("/info_put/version", rsrc_verb.PUT, '"1.5.6"')

        whole_resource = app.process_request("/info_put", rsrc_verb.GET)
        print(whole_resource)

        version_only = app.process_request("/info_put/version", rsrc_verb.GET)
        print(version_only)
        # The value read back is expected to be the plugin's "42", not the
        # submitted payload.
        self.assertEqual(version_only, '"42"')

    def test_put_field_version_resourceplugin(self):
        app = self.testapp
        # Write the whole resource; only ``version`` has a field plugin.
        payload = '{"version": "1.5.6", "api_version": "98.321"}'
        app.process_request("/info_put", rsrc_verb.PUT, payload)

        stored = app.process_request("/info_put", rsrc_verb.GET)
        self.assertEqual(stored, '{"version": "42", "api_version": "98.321"}')
class Test_RestAPI_Plugin_GET(unittest.TestCase):
    """GET requests against resources whose values are supplied by plugins.

    Every path/expected-value pair is checked inside its own ``subTest`` so a
    failure names the offending path and the remaining pairs are still run.
    """

    # Expected serialisation of ``/info`` (resource plugin + field plugin).
    _INFO_JSON = '{"version": "1.5.6", "api_version": "98.321"}'
    # Expected serialisation of ``/info2`` (field plugin only).
    _INFO2_JSON = '{"version": "1.5.6", "api_version": "0.0.3"}'

    def setUp(self) -> None:
        # Fresh classes and a fresh root resource for every test.
        chdir(testdir_path.parent.resolve())
        init_classes()
        self.testapp = RootApp()

    def _assert_get(self, expectations):
        """GET each ``(path, expected)`` pair in order and compare the answer."""
        for path, expected in expectations:
            with self.subTest(path=path):
                answer = self.testapp.process_request(path, rsrc_verb.GET)
                self.assertEqual(answer, expected)

    def test_get_root(self):
        # The root resource is expected to serialise to an empty JSON object.
        self._assert_get([("/", "{}")])

    def test_get_version(self):
        self._assert_get(
            [
                ("/info", self._INFO_JSON),
                ("/info2", self._INFO2_JSON),
            ]
        )

    def test_get_version__trailing_slash(self):
        # Trailing slashes are expected to give the same answer as the bare path.
        self._assert_get(
            [
                ("/info/", self._INFO_JSON),
                ("/info//", self._INFO_JSON),
                ("/info///", self._INFO_JSON),
                ("/info2/", self._INFO2_JSON),
                ("/info2//", self._INFO2_JSON),
                ("/info2///", self._INFO2_JSON),
            ]
        )

    def test_get_version__multiple_slash(self):
        # Repeated leading slashes are expected to give the same answer too.
        self._assert_get(
            [
                ("//info", self._INFO_JSON),
                ("///info", self._INFO_JSON),
                ("//info2", self._INFO2_JSON),
                ("///info2", self._INFO2_JSON),
            ]
        )

    def test_get_version__nested_value(self):
        # Individual fields are addressed as ``/<resource>/<field>``.
        self._assert_get(
            [
                ("/info/api_version", '"98.321"'),
                ("/info/version", '"1.5.6"'),
                ("/info2/api_version", '"0.0.3"'),
                ("/info2/version", '"1.5.6"'),
            ]
        )

    def test_defect_plugin_field(self):
        # Each defective plugin setup must be rejected with RuntimeError; the
        # scenarios are independent, so each one gets its own subTest.
        for broken_init in (init_bad_plugin1, init_bad_plugin2, init_bad_plugin3):
            with self.subTest(scenario=broken_init.__name__):
                with self.assertRaises(RuntimeError):
                    broken_init()