From 94494560146b8c91c0d8e8c175c2cf68eeff7657 Mon Sep 17 00:00:00 2001 From: Krishna Penukonda Date: Mon, 24 Aug 2020 09:29:40 +0800 Subject: [PATCH] Added pipeline components to global registry --- lowpolypy/point_generators.py | 4 +++ lowpolypy/polygonizers.py | 3 ++ lowpolypy/shaders.py | 4 +++ lowpolypy/utils/__init__.py | 55 +++++++++++++++++++++++++++++++++++ 4 files changed, 66 insertions(+) create mode 100644 lowpolypy/utils/__init__.py diff --git a/lowpolypy/point_generators.py b/lowpolypy/point_generators.py index d8ef8c0..6e4ff0d 100644 --- a/lowpolypy/point_generators.py +++ b/lowpolypy/point_generators.py @@ -7,6 +7,8 @@ import shapely from shapely.geometry import box, Point, MultiPoint, asMultiPoint +from .utils import registry + class PointGenerator(metaclass=ABCMeta): """ @@ -71,6 +73,7 @@ def remove_duplicates(points, tolerance): return points +@registry.register("PointGenerator", "RandomPoints") class RandomPoints(PointGenerator): def __init__(self, num_points=100): super().__init__() @@ -82,6 +85,7 @@ def forward(self, image): return points +@registry.register("PointGenerator", "ConvPoints") class ConvPoints(PointGenerator): def __init__(self, num_points=1000, num_filler_points=50, weight_filler_points=True): super().__init__() diff --git a/lowpolypy/polygonizers.py b/lowpolypy/polygonizers.py index 898ad4c..3e75fe6 100644 --- a/lowpolypy/polygonizers.py +++ b/lowpolypy/polygonizers.py @@ -3,6 +3,8 @@ from shapely.ops import triangulate from shapely.geometry import MultiPoint +from .utils import registry + class Polygonizer(metaclass=ABCMeta): """ @@ -24,6 +26,7 @@ def simplify(polygons): return polygons +@registry.register("Polygonizer", "DelaunayTriangulator") class DelaunayTriangulator(Polygonizer): def __init__(self): super().__init__() diff --git a/lowpolypy/shaders.py b/lowpolypy/shaders.py index a0ac0b0..487beca 100644 --- a/lowpolypy/shaders.py +++ b/lowpolypy/shaders.py @@ -4,6 +4,8 @@ from PIL import Image from abc import ABCMeta, abstractmethod +from .utils import registry + class Shader(metaclass=ABCMeta): """ @@ -18,6 +20,7 @@ def __call__(self, image, points, polygons, *args, **kwargs): return self.forward(image, points, polygons, *args, **kwargs) +@registry.register("Shader", "MeanShader") class MeanShader(Shader): def __init__(self): super().__init__() @@ -35,6 +38,7 @@ def forward(self, image, points, polygons): return Image.fromarray(shaded) +@registry.register("Shader", "KmeansShader") class KmeansShader(Shader): def __init__(self, num_clusters=3, num_attempts=3): super().__init__() diff --git a/lowpolypy/utils/__init__.py b/lowpolypy/utils/__init__.py new file mode 100644 index 0000000..c39552a --- /dev/null +++ b/lowpolypy/utils/__init__.py @@ -0,0 +1,55 @@ +from typing import Type + + +class registry: + registries = {} + + @classmethod + def register(cls, registry_name: str, key: str): + """ + Decorator for adding an entry to a registry. + + Args: + registry_name: Name of the registry to add the entry to + key: Name to file the entry under + value: The value of the entry + """ + + def inner(obj): + registry = cls.registries.setdefault(registry_name, {}) + registry[key] = obj + return obj + + return inner + + @classmethod + def get(cls, registry_name: str, key: str, allow_passthrough=True): + """ + Get an element from a registry + + Args: + registry_name: Name of the registry. + key: Entry key in the specified registry to retrieve. + allow_passthrough: If True, then if `key` is not a key in the specified registry but is present as a value in the registry, `key` is returned. + """ + try: + registry = cls.registries[registry_name] + except KeyError as e: + raise KeyError(f"No such registry: '{registry_name}'") from e + try: + return registry[key] + except KeyError as e: + if key in registry.values() and allow_passthrough: + return key + raise KeyError( + f"Couldn't find '{key}' in registry '{registry_name}'" + ) from e + + # TODO: Implement custom registration the right way + # @classmethod + # def register_model(cls, name: str, model_class: Type): + # return cls.register("model", name)(model_class) + # + # @classmethod + # def register_dataset(cls, name: str, dataset_class: Type): + # return cls.register("dataset", name)(dataset_class)