# SPDX-FileCopyrightText: 2019-2025 Helmholtz Centre Potsdam GFZ German Research Centre for Geosciences
# SPDX-FileCopyrightText: 2020-2021 Helmholtz-Zentrum Geesthacht GmbH
# SPDX-FileCopyrightText: 2021-2025 Helmholtz-Zentrum hereon GmbH
#
# SPDX-License-Identifier: Apache-2.0
"""Transform a python class into a corresponding pydantic model.
The :class:`BackendClass` model in this module generates subclasses based upon
a python class (similarly as the
:class:`~demessaging.backend.function.BackendFunction` does it for functions).
"""
from __future__ import annotations
import inspect
from textwrap import dedent
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import docstring_parser
from pydantic import Field # pylint: disable=no-name-in-module
from pydantic import BaseModel, create_model
from pydantic.json_schema import JsonSchemaValue
from demessaging.backend import utils
from demessaging.backend.function import (
BackendFunction,
BackendFunctionConfig,
FunctionAPIModel,
ReturnModel,
)
from demessaging.config import ClassConfig
from demessaging.utils import append_parameter_docs, merge_config
[docs]
@append_parameter_docs
class BackendClassConfig(ClassConfig):
    """Configuration class for a backend module class."""

    # Method name -> generated function model. BackendClass.create_model
    # fills this mapping; the first model registered for a name is kept.
    models: Dict[str, Type[BackendFunction]] = Field(
        default_factory=dict,
        description=(
            "Mapping of method name to the function model for the "
            "methods of this class"
        ),
    )
    # The class wrapped by this configuration.
    Class: Type[object] = Field(
        description="The class that corresponds to this config."
    )
    # Name of the pydantic model class generated for ``Class``.
    class_name: str = Field(description="Name of the model class")

    @property
    def method_configs(self) -> List[BackendFunctionConfig]:
        """The ``backend_config`` of every model in :attr:`models`.

        The list follows the insertion order of :attr:`models`.
        """
        return [
            method_model.backend_config
            for method_model in self.models.values()
        ]

    def update_from_cls(self) -> None:
        """Populate unset descriptive attributes from :attr:`Class`.

        ``name``, ``doc``, ``init_doc`` and ``signature`` (inherited from
        :class:`~demessaging.config.ClassConfig`) are only filled in when
        they are currently falsy. Explicitly configured values are therefore
        preserved.
        """
        target = self.Class
        if not self.name:
            self.name = target.__name__ or ""
        if not self.doc:
            self.doc = dedent(inspect.getdoc(target) or "")
        if not self.init_doc:
            self.init_doc = dedent(inspect.getdoc(target.__init__) or "")
        if not self.signature:
            self.signature = inspect.signature(target.__init__)
[docs]
class ClassAPIModel(BaseModel):
    """A class in the API suitable for RPC via DASF"""

    # NOTE: pydantic uses the model docstring and the Field descriptions in
    # the generated JSON schema, so the texts here are part of the API output.

    # Identifier of the class in the RPC (BackendClass.get_api_info passes
    # ``backend_config.name`` here).
    name: str = Field(
        description="The name of the class that is used as identifier in the RPC."
    )
    # JSON schema of the class constructor; BackendClass.get_api_info fills
    # it from the ``model_json_schema()`` of a constructor-only model.
    rpc_schema: JsonSchemaValue = Field(
        description="The JSON Schema for the constructor of the class."
    )
    # One FunctionAPIModel per exposed method.
    methods: List[FunctionAPIModel] = Field(
        description="The list of methods that this class provides."
    )
[docs]
@append_parameter_docs
class BackendClass(BaseModel):
    """A basis for class models

    Do not directly instantiate from this class, rather use the
    :meth:`create_model` method.
    """

    # Set on every subclass generated by :meth:`create_model`.
    backend_config: ClassVar[BackendClassConfig]

    @property
    def return_model(self) -> Type[BaseModel]:
        """The return model of the member function."""
        return self.function.return_model

    if TYPE_CHECKING:
        # added properties for subclasses generated by create_model
        function: BackendFunction

    def __call__(self) -> ReturnModel:
        """Instantiate the wrapped class and call the selected method.

        The constructor arguments are taken from the fields of this model.
        The method arguments are taken from the fields of the method model
        stored in ``function``.

        Returns
        -------
        ReturnModel
            The method's return value, validated by the ``return_model`` of
            the selected method model.
        """
        kws = utils.get_kws(self.backend_config.signature, self)
        function: BackendFunction = self.function  # type: ignore
        func_kws = utils.get_kws(function.backend_config.signature, function)
        ini: Any = self.backend_config.Class(**kws)  # type: ignore
        ret = getattr(ini, function.func_name)(**func_kws)
        # validate the raw return value with the method's return model
        return function.return_model.model_validate(ret)  # type: ignore[return-value]

    @classmethod
    def get_constructor_fields(
        cls, Class, config: ClassConfig, class_name: Optional[str]
    ) -> Tuple[Dict[str, Any], BackendClassConfig]:
        """Compute the model fields for the constructor of `Class`.

        Parameters
        ----------
        Class: type
            The class whose ``__init__`` signature and (class and
            ``__init__``) docstrings are used.
        config: ClassConfig
            The base configuration. It is deep-copied into a new
            :class:`BackendClassConfig`.
        class_name: str, optional
            Name of the generated model class. If empty, it is derived from
            ``Class.__name__`` via ``utils.snake_to_camel("Class", ...)``.

        Returns
        -------
        dict
            The field definitions. The entry that ``utils.get_fields`` stores
            under ``func_name`` is renamed to ``class_name``.
        BackendClassConfig
            The new configuration for `Class`.
        """
        sig = inspect.signature(Class.__init__)
        docstring = docstring_parser.parse(Class.__doc__)
        init_docstring = docstring_parser.parse(Class.__init__.__doc__)
        # ``Docstring.params`` is computed from ``meta`` on every access, so
        # extending ``meta`` alone merges the ``__init__`` parameter docs.
        docstring.meta.extend(init_docstring.meta)
        name = Class.__name__
        if not class_name:
            class_name = utils.snake_to_camel("Class", name)
        config = BackendClassConfig(
            Class=Class,
            class_name=class_name,
            **config.model_copy(deep=True).model_dump(),
        )
        fields = utils.get_fields(name, sig, docstring, config)
        fields["class_name"] = fields.pop("func_name")
        return fields, config

    @classmethod
    def _resolve_methods(
        cls,
        Class,
        config: ClassConfig,
        methods: Optional[List[Union[Type[BackendFunction], Callable, str]]],
    ) -> List[Union[Type[BackendFunction], Callable, str]]:
        """Select the methods to expose.

        The selection uses explicit `methods`, then ``config.methods``, then
        all public functions of `Class`.
        """
        if not methods and config.methods:
            methods = list(config.methods)
        if not methods:
            names_members = inspect.getmembers(
                Class, predicate=inspect.isfunction
            )
            methods = [
                member_name
                for member_name, _ in names_members
                if not member_name.startswith("_")
            ]
        if not methods:
            raise ValueError("No methods of the class have been specified!")
        return methods

    @classmethod
    def _get_method_model(
        cls,
        Class,
        class_name: str,
        method: Union[Type[BackendFunction], Callable, str],
    ) -> Tuple[str, Type[BackendFunction]]:
        """Return the name and the function model for one `methods` entry."""
        if inspect.isclass(method) and issubclass(method, BackendFunction):  # type: ignore # noqa: E501
            # already a generated function model: use it as is
            method_name: str = method.backend_config.name  # type: ignore
            return method_name, cast(Type[BackendFunction], method)
        if callable(method):
            method_name = cast(str, method.__name__)
            func = method
        else:
            # a method name: look it up on the class
            method_name = method
            func = getattr(Class, method_name)
        FuncModel = BackendFunction.create_model(
            cast(Callable, func),
            class_name=utils.snake_to_camel("Meth", class_name, method_name),
        )
        return method_name, FuncModel

    @classmethod
    def create_model(
        cls,
        Class,
        config: Optional[ClassConfig] = None,
        methods: Optional[
            List[Union[Type[BackendFunction], Callable, str]]
        ] = None,
        class_name: Optional[str] = None,
        **kwargs: Any,
    ) -> Type[BackendClass]:
        """Generate a pydantic model from a class.

        Parameters
        ----------
        Class: type
            A class
        config: ClassConfig, optional
            The configuration to use. If given, this overrides the
            ``__pulsar_config__`` of the given `Class`
        methods: list of methods, optional
            A list of methods, method names or model classes generated with
            :meth:`BackendFunction.create_model`. This overrides the methods
            in `config` or the ``__pulsar_config__`` attribute of `Class`.
            If neither specifies methods, all public functions of `Class`
            are used.
        class_name: str, optional
            The name for the generated subclass of :class:`pydantic.BaseModel`.
            If not given, it is derived from the name of `Class`
        ``**kwargs``
            Any other parameter for the :func:`pydantic.create_model` function

        Returns
        -------
        Subclass of BackendClass
            The newly generated model that represents this class.

        Raises
        ------
        ValueError
            If `Class` has an init parameter called ``function``, or if no
            methods could be determined.
        """
        if config is None:
            config = getattr(Class, "__pulsar_config__", ClassConfig())
        config = cast(ClassConfig, config)
        fields, config = cls.get_constructor_fields(Class, config, class_name)
        class_name = cast(str, config.class_name)
        name = Class.__name__

        # ``function`` is reserved for the method-selection field below
        if "function" in fields:
            raise ValueError(
                f"`function` must not be an init parameter for {name}!"
            )

        methods = cls._resolve_methods(Class, config, methods)

        for method in methods:
            method_name, FuncModel = cls._get_method_model(
                Class, class_name, method
            )
            # the first model registered for a method name wins
            if method_name not in config.models:
                config.models[method_name] = FuncModel
        config.methods = list(config.models)

        # the ``function`` field accepts any of the method models
        models = list(config.models.values())
        function_types = models[0]
        for model in models[1:]:
            function_types = Union[function_types, model]  # type: ignore
        fields["function"] = (
            function_types,
            Field(description="The method to call."),
        )
        kwargs.update(fields)

        desc = utils.get_desc(docstring_parser.parse(Class.__doc__))
        Model: Type[BackendClass] = create_model(  # type: ignore
            class_name,
            __validators__=config.validators,
            __module__=Class.__module__,
            __base__=cls,
            **kwargs,  # type: ignore
        )
        Model.backend_config = config
        config.Class = Class
        config.update_from_cls()
        Model.__doc__ = desc or ""
        return Model

    @classmethod
    def _get_constructor_model(cls) -> Type[BaseModel]:
        """Create a pydantic model class that describes only ``__init__``.

        :meth:`get_api_info` uses it to build the ``rpc_schema``.
        """
        backend_config = cls.backend_config
        fields, _ = cls.get_constructor_fields(
            backend_config.Class,
            ClassConfig(
                **backend_config.model_dump(
                    exclude={"Class", "models", "class_name"}
                )
            ),
            backend_config.class_name,
        )
        fields["__doc__"] = cls.__doc__
        return create_model(
            backend_config.class_name,
            __module__=backend_config.Class.__module__,
            **fields,
        )

    @classmethod
    def get_api_info(cls) -> ClassAPIModel:
        """Get the API info on the class (constructor schema and methods)."""
        constructor_model = cls._get_constructor_model()
        return ClassAPIModel(
            name=cls.backend_config.name,
            rpc_schema=constructor_model.model_json_schema(),
            methods=[
                method_model.get_api_info()
                for method_model in cls.backend_config.models.values()
            ],
        )

    @classmethod
    def model_json_schema(cls, *args, **kwargs) -> Dict[str, Any]:
        """Generate the JSON schema of this model.

        If ``backend_config.json_schema_extra`` is set, it is merged into the
        schema via ``merge_config``.
        """
        ret = super().model_json_schema(*args, **kwargs)
        if cls.backend_config.json_schema_extra:
            ret = merge_config(ret, cls.backend_config.json_schema_extra)
        return ret
# Rebuild ClassConfig at import time so that postponed (string) annotations
# can be resolved now that this module's imports (BackendFunction etc.) exist.
# NOTE(review): presumably ClassConfig references such names as forward
# refs -- confirm in demessaging.config.
# Keep this call at module level: by default pydantic resolves forward
# references using the caller's namespace.
try:
    # pydantic v2 API
    ClassConfig.model_rebuild()
except AttributeError:
    # pydantic v1 fallback (no model_rebuild there).
    # NOTE(review): the rest of this module uses v2-only APIs, so this
    # branch is likely dead.
    ClassConfig.update_forward_refs()