Source code for demessaging.backend.utils

# SPDX-FileCopyrightText: 2019-2025 Helmholtz Centre Potsdam GFZ German Research Centre for Geosciences
# SPDX-FileCopyrightText: 2020-2021 Helmholtz-Zentrum Geesthacht GmbH
# SPDX-FileCopyrightText: 2021-2025 Helmholtz-Zentrum hereon GmbH
#
# SPDX-License-Identifier: Apache-2.0

"""Utility functions for the backend framework."""
from __future__ import annotations

import asyncio
import inspect
import re
import threading
import unicodedata
import warnings
from itertools import chain, starmap
from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple, Type, Union

from deprogressapi import BaseReport
from pydantic import Field  # pylint: disable=no-name-in-module
from pydantic.functional_serializers import PlainSerializer
from typing_extensions import Annotated

try:
    from typing import Literal, get_args, get_origin
except ImportError:
    from typing_extensions import Literal, get_args, get_origin  # type: ignore


if TYPE_CHECKING:
    import docstring_parser
    import isort.identify

    from demessaging.config import ClassConfig, FunctionConfig


def get_kws(sig, obj) -> Dict[str, Any]:
    """Collect the attribute values of an object for a signature.

    Parameters
    ----------
    sig
        The signature (anything that exposes a ``parameters`` mapping, such
        as :class:`inspect.Signature`) whose parameter names are looked up.
    obj
        The object, e.g. a base model, that provides one attribute per
        parameter name.

    Returns
    -------
    dict
        A mapping from parameter name to the attribute of `obj` with the
        same name. A parameter called ``self`` is left out.
    """
    kws: Dict[str, Any] = {}
    for param_name in sig.parameters:
        if param_name == "self":
            # bound-method placeholder, not a real keyword
            continue
        kws[param_name] = getattr(obj, param_name)
    return kws
def get_fields(
    name: str,
    sig: inspect.Signature,
    docstring: docstring_parser.Docstring,
    config: Union[FunctionConfig, ClassConfig],
) -> Dict[str, Tuple[Any, Any]]:
    """Build the model field definitions for a callable.

    Parameters
    ----------
    name: str
        The name of the function or class
    sig: inspect.Signature
        The signature of the callable
    docstring: docstring_parser.Docstring
        The parser that analyzed the docstring
    config: FunctionConfig or ClassConfig
        The configuration for the callable. Its ``reporter_args`` mapping is
        updated in place for every parameter whose annotation is recognized
        as a progress report.

    Returns
    -------
    dict
        A mapping from field name to field parameters to be used in
        :func:`pydantic.create_model`.
    """
    fields: Dict[str, Tuple[Any, Any]] = {}
    # every model carries the name of the callable as a constant field
    fields["func_name"] = (
        Literal[name],  # type: ignore
        Field(description=f"The name of the function. Must be {name!r}"),
    )

    for key, param in sig.parameters.items():
        if key == "self":
            continue

        has_default = param.default is not param.empty
        declared = param.annotation

        # --- keyword arguments for the pydantic Field -----------------------
        field_kws: Dict[str, Any] = {}
        if has_default:
            field_kws["default"] = param.default
        # the first documented parameter with a matching name provides the
        # description
        for documented in docstring.params:
            if documented.arg_name == key:
                field_kws["description"] = documented.description
                break
        if key in config.field_params:
            # explicit configuration wins over what has been derived above
            field_kws.update(config.field_params[key])

        # --- annotation of the field ----------------------------------------
        overridden = key in config.annotations
        if overridden:
            annotation = config.annotations[key]
        elif declared is not param.empty:
            annotation = declared
        else:
            warnings.warn(
                f"Missing signature for {key}, so no validation will "
                "be made for this parameter!",
                RuntimeWarning,
            )
            annotation = Any

        if key in config.serializers:
            serializer = PlainSerializer(
                config.serializers[key], return_type=str, when_used="json"
            )
            if not overridden and config.validators.get(key):
                # a validator is configured for this parameter and the
                # annotation is not overridden: fall back to ``Any``
                annotation = Any
            annotation = Annotated[annotation, serializer]

        # test for dasf-progress-api reports and add them to the config
        if declared is not param.empty and _is_progress_report(declared):
            if has_default:
                config.reporter_args[key] = param.default
            else:
                config.reporter_args[key] = declared()
            field_kws["json_schema_extra"] = {"is_reporter": True}

        fields[key] = (annotation, Field(**field_kws))  # type: ignore

    return fields
def _is_progress_report(cls_: Type) -> bool: if get_origin(cls_): # we do have a Union-type return any( inspect.isclass(c) and issubclass(c, BaseReport) for c in get_args(cls_) ) elif inspect.isclass(cls_): return issubclass(cls_, BaseReport) else: return False
[docs] def get_desc(docstring: docstring_parser.Docstring) -> str: """Get the description of an object. Parameters ---------- docstring: docstring_parser.Docstring The parser that analyzed the docstring. Returns ------- str The description of the callable. """ desc = "" if docstring.short_description: desc += docstring.short_description if docstring.long_description: if docstring.blank_after_short_description: desc += "\n\n" else: desc += "\n" desc += docstring.long_description return desc.strip()
def camelize(w: str) -> str:
    """Turn the first letter of a word into upper case.

    The remainder of the word is kept as it is. An empty word is returned
    unchanged.
    """
    if not w:
        return w
    head, tail = w[:1], w[1:]
    return head.upper() + tail
def snake_to_camel(*words: str) -> str:
    """Join snake_case words into a single CamelCase word.

    Every word is split at its underscores, each resulting part is
    camelized with :func:`camelize` and all parts are concatenated.
    """
    camelized = [
        camelize(part) for word in words for part in word.split("_")
    ]
    return "".join(camelized)
def slugify(value: str, allow_unicode: bool = False) -> str:
    """Turn a value into a slug.

    The value is converted to a string and normalized (to ASCII unless
    `allow_unicode` is True). It is then lower-cased, characters that are
    not alphanumerics, underscores, whitespace or hyphens are removed, and
    every run of whitespace and dashes is replaced by a single dash.
    Leading and trailing dashes and underscores are stripped.

    Notes
    -----
    taken from
    https://github.com/django/django/blob/3cadeea077a98367a4ed344d645df0aff243de91/django/utils/text.py
    """
    text = str(value)
    if not allow_unicode:
        # decompose accented characters and drop everything non-ASCII
        decomposed = unicodedata.normalize("NFKD", text)
        text = decomposed.encode("ascii", "ignore").decode("ascii")
    else:
        text = unicodedata.normalize("NFKC", text)
    cleaned = re.sub(r"[^\w\s-]", "", text.lower())
    dashed = re.sub(r"[-\s]+", "-", cleaned)
    return dashed.strip("-_")
class AsyncIoThread(threading.Thread):
    """A thread that runs an async function.

    The coroutine returned by ``func(*args, **kwargs)`` is executed with
    :func:`asyncio.run` inside the thread. See :func:`run_async` for the
    implementation.

    Parameters
    ----------
    func: Callable
        The async function to call.
    args: Tuple
        Positional arguments for `func`.
    kwargs: Dict
        Keyword arguments for `func`.

    Attributes
    ----------
    result: Any
        The return value of the coroutine. If the coroutine raised an
        exception, reading this attribute re-raises that exception in the
        calling thread. Reading it before the thread has produced a result
        raises an :class:`AttributeError`.
    """

    def __init__(self, func: Callable, args: Tuple, kwargs: Dict):
        self.__func = func
        self.__args = args
        self.__kwargs = kwargs
        # exception raised by the coroutine, handed over to the thread that
        # reads ``result``
        self.__error = None
        super().__init__()

    @property
    def result(self) -> Any:
        """The outcome of the coroutine (re-raises its exception, if any)."""
        if self.__error is not None:
            raise self.__error
        return self.__result

    @result.setter
    def result(self, value: Any) -> None:
        self.__result = value

    def run(self):
        """Run the coroutine and store its return value or its exception."""
        try:
            self.result = asyncio.run(
                self.__func(*self.__args, **self.__kwargs)
            )
        except BaseException as error:  # pylint: disable=broad-except
            # Without this, the error would only be printed by the threading
            # excepthook and the caller would fail with an unrelated
            # AttributeError on ``result``. Keep it so that the caller sees
            # the real error.
            self.__error = error
def run_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
    """Run an async function and wait for the result.

    This function works within standard python scripts, and during a
    running jupyter session.

    Parameters
    ----------
    func: Callable
        The async function to run.
    ``*args, **kwargs``
        Arguments that are passed on to `func`.

    Returns
    -------
    Any
        Whatever the coroutine of `func` returns.
    """
    # check if we have a running loop (which is the case for a jupyter
    # notebook)
    running_loop: Any
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if not running_loop or not running_loop.is_running():
        # standard python script
        return asyncio.run(func(*args, **kwargs))

    # jupyter notebook: asyncio.run cannot be used while a loop is running
    # in this thread, so execute the coroutine in a separate thread
    worker = AsyncIoThread(func, args, kwargs)
    worker.start()
    worker.join()  # wait for the thread to finish
    return worker.result
class ImportMixin:
    """Mixin class for :class:`isort.identify.Import`.

    A response to https://github.com/PyCQA/isort/issues/1641.
    """

    def statement(self: isort.identify.Import) -> str:  # type: ignore
        """Render the import as a line of source code.

        Returns
        -------
        str
            ``from <module> (c)import <attribute>`` if the import has an
            attribute, otherwise ``(c)import <module>``, followed by
            ``as <alias>`` if an alias is set.
        """
        keyword = "cimport" if self.cimport else "import"
        base = (
            f"from {self.module} {keyword} {self.attribute}"
            if self.attribute
            else f"{keyword} {self.module}"
        )
        suffix = f" as {self.alias}" if self.alias else ""
        return base + suffix
def get_module_imports(mod: Any) -> str:
    """Get all the imports from a module

    Parameters
    ----------
    mod: module
        The module to use

    Returns
    -------
    str
        The import statements found in the source code of `mod`, one per
        line. An empty string if ``isort`` cannot be imported.
    """
    try:
        from isort.api import find_imports_in_code
        from isort.identify import Import as BaseImport
    except ImportError:  # also covers ModuleNotFoundError (a subclass)
        return ""

    source = inspect.getsource(mod)

    class Import(ImportMixin, BaseImport):
        pass

    # We could use the :meth:`~isort.identify.Import.statement` method here,
    # but this would not work always (see
    # https://github.com/PyCQA/isort/issues/1641)
    statements = (
        Import(*found).statement() for found in find_imports_in_code(source)
    )
    return "\n".join(statements)