Skip to content

Commit

Permalink
Tests on knn
Browse files Browse the repository at this point in the history
  • Loading branch information
travisjungroth committed Dec 19, 2022
1 parent 565bc82 commit 19d3adb
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
30 changes: 30 additions & 0 deletions mtree4.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def limit(self):
heapq.heappop(self.pq)
return float("inf") if len(self.items) < self.k else -self.pq[0][0]

def sorted(self) -> list[Item]:
    """Return the entries of ``self.pq`` that are also in ``self.items``.

    ``self.pq`` is sorted in ascending order and every entry that fails the
    ``in self.items`` membership test is dropped.
    """
    # NOTE(review): ``limit`` reads ``-self.pq[0][0]``, so heap entries look
    # like sequences led by a negated distance. If that is the case, then
    # (a) the loop variable below is a whole heap entry rather than a bare
    # item, so the membership test may never match, and (b) ascending order
    # on negated distances is farthest-first. Confirm against ``add`` and
    # ``discard``, which are not visible here.
    return [item for item in sorted(self.pq) if item in self.items]


class DistanceFunction:
def __init__(self, fn: Callable = default_distance) -> None:
Expand Down Expand Up @@ -130,6 +133,27 @@ def __repr__(self):
def __iter__(self) -> Iterable[Value]:
    """Iterate by delegating to ``self.root``; yields whatever the root yields."""
    yield from self.root

def knn(self, value: Value, k: int) -> list[Value]:
    """Return up to ``k`` nearest neighbours of ``value``.

    Args:
        value: The query object; distances are measured from it.
        k: Maximum number of neighbours to return. Must be non-negative.

    Returns:
        ``[]`` when ``k`` is 0. When ``k >= self.length``, every element of
        the tree sorted by ``self.distance_function(value, element)``.
        Otherwise, whatever ``LimitedSet.sorted()`` returns after the
        traversal below.

    Raises:
        ValueError: If ``k`` is negative.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let a negative ``k`` reach the traversal.
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if not k:
        return []
    if k >= self.length:
        # Asking for at least everything: a plain sort is sufficient.
        return sorted(self, key=partial(self.distance_function, value))

    results = LimitedSet(k)
    pq = PriorityQueue()
    pq.push(self.root.min_distance(value), self.root)
    while pq:
        node: ParentNode
        _, node = pq.pop()
        # The expanded node is superseded by its children added below.
        results.discard(node)
        for child_node in node.children:
            # TODO: prune subtrees that cannot improve the result; every
            # ParentNode is currently expanded. Candidate conditions that
            # were left disabled:
            #   abs(node.distance(value) - node.distance(child_node.router))
            #       - child_node.radius <= results.limit()
            #   child_node.min_distance(value) <= results.limit()
            if isinstance(child_node, ParentNode):
                pq.push(child_node.min_distance(value), child_node)
            results.add(child_node.max_distance(value), child_node)
    # NOTE(review): nodes (not values) are added to ``results`` above, while
    # the annotation promises ``list[Value]`` - confirm what
    # ``LimitedSet.sorted()`` actually returns.
    return results.sorted()


class Node(Generic[Value]):
def __init__(self, tree: MTree[Value], router=Value) -> None:
Expand All @@ -144,6 +168,12 @@ def distance(self, item: Union[Node, Value]) -> Distance:
return self.distance_function(self.router, item.router) item.radius
return self.distance_function(self.router, item)

def min_distance(self, value: Value) -> Distance:
    """Return the router-to-value distance reduced by this node's radius, never below 0."""
    gap = self.distance_function(self.router, value) - self.radius
    return gap if gap > 0 else 0

def max_distance(self, value: Value) -> Distance:
    """Return the router-to-value distance plus this node's radius."""
    return self.distance_function(self.router, value) + self.radius


class ValueNode(Node[Value]):
def __repr__(self):
Expand Down
12 changes: 12 additions & 0 deletions tests_prop.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial
from heapq import nsmallest
from typing import Any, Iterable

from hypothesis import given, strategies as st
Expand Down Expand Up @@ -63,3 +65,13 @@ def test_sufficient_radius(values, cap):
for node in get_nodes(tree.root, ParentNode):
for value in node:
assert node.distance(value) <= node.radius


class TestKnn:
    """Property-based checks for ``MTree.knn``."""

    @given(st.sets(st.text()), st.text(), st.integers(0))
    def test_x(self, values, needle, k):
        """knn returns the k closest stored values, ordered nearest-first."""
        tree = MTree(values)
        res = tree.knn(needle, k)
        needle_distance = partial(tree.distance_function, needle)
        # Results must come back nearest-first.
        assert res == sorted(res, key=needle_distance)
        # Every result must be one of the values the tree was built from.
        assert all(value in values for value in res)
        # Brute-force reference answer. Distances are compared rather than
        # the elements themselves because ties may be broken differently.
        expected = nsmallest(k, values, key=needle_distance)
        assert [needle_distance(v) for v in res] == [
            needle_distance(v) for v in expected
        ]

0 comments on commit 19d3adb

Please sign in to comment.