diff --git a/mtree4.py b/mtree4.py index a91f81b..3df4fb1 100644 --- a/mtree4.py +++ b/mtree4.py @@ -75,6 +75,9 @@ def limit(self): heapq.heappop(self.pq) return float("inf") if len(self.items) < self.k else -self.pq[0][0] + def sorted(self) -> list[Item]: + return [item for item in sorted(self.pq) if item in self.items] + class DistanceFunction: def __init__(self, fn: Callable = default_distance) -> None: @@ -130,6 +133,27 @@ def __repr__(self): def __iter__(self) -> Iterable[Value]: yield from self.root + def knn(self, value: Value, k: int) -> list[Value]: + assert k >= 0 + if not k: + return [] + if k >= self.length: + return sorted(self, key=partial(self.distance_function, value)) + results = LimitedSet(k) + pq = PriorityQueue() + pq.push(self.root.min_distance(value), self.root) + while pq: + node: ParentNode + min_distance_q, node = pq.pop() + results.discard(node) + for child_node in node.children: + # if abs(node.distance(value) - node.distance(child_node.router)) - child_node.radius <= results.limit(): + # if child_node.min_distance(value) <= results.limit(): + if isinstance(child_node, ParentNode): + pq.push(child_node.min_distance(value), child_node) + results.add(child_node.max_distance(value), child_node) + return results.sorted() + class Node(Generic[Value]): def __init__(self, tree: MTree[Value], router=Value) -> None: @@ -144,6 +168,12 @@ def distance(self, item: Union[Node, Value]) -> Distance: return self.distance_function(self.router, item.router) + item.radius return self.distance_function(self.router, item) + def min_distance(self, value: Value) -> Distance: + return max(0, self.distance_function(self.router, value) - self.radius) + + def max_distance(self, value: Value) -> Distance: + return self.distance_function(self.router, value) + self.radius + class ValueNode(Node[Value]): def __repr__(self): diff --git a/tests_prop.py b/tests_prop.py index f919f1d..ad4b93c 100644 --- a/tests_prop.py +++ b/tests_prop.py @@ -1,3 +1,5 @@ +from functools import partial +from heapq import nsmallest from typing import Any, Iterable from hypothesis import given, strategies as st @@ -63,3 +65,13 @@ def test_sufficient_radius(values, cap): for node in get_nodes(tree.root, ParentNode): for value in node: assert node.distance(value) <= node.radius + + +class TestKnn: + @given(st.sets(st.text()), st.text(), st.integers(0)) + def test_x(self, values, needle, k): + tree = MTree(values) + res = tree.knn(needle, k) + needle_distance = partial(tree.distance_function, needle) + assert res == sorted(res, key=needle_distance) + x = nsmallest(k, values, key=needle_distance)