Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[skip ci] breaking changes of store and index apis #322

Draft
wants to merge 33 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
s
  • Loading branch information
ouonline committed Oct 29, 2024
commit 71f2df212fd0174e7b087bea7d698ac5e3163eac
3 changes: 2 additions & 1 deletion lazyllm/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 1,5 @@
from .registry import LazyLLMRegisterMetaClass, _get_base_cls_from_registry, Register
from .common import package, kwargs, arguments, LazyLLMCMD, timeout, final, ReadOnlyWrapper, DynamicDescriptor
from .common import package, kwargs, arguments, LazyLLMCMD, timeout, final, ReadOnlyWrapper, DynamicDescriptor, override
from .common import FlatList, Identity, ResultCollector, ArgsDict, CaseInsensitiveDict
from .common import ReprRule, make_repr, modify_repr
from .common import once_flag, call_once, once_wrapper, singleton
Expand Down Expand Up @@ -38,6 38,7 @@
'package',
'kwargs',
'arguments',
'override',

# option
'Option',
Expand Down
6 changes: 6 additions & 0 deletions lazyllm/common/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 14,12 @@
_F = typing.TypeVar("_F", bound=Callable[..., Any])
def final(f: _F) -> _F: return f

try:
from typing import override
except ImportError:
def override(func: Callable):
return func


class FlatList(list):
def absorb(self, item):
Expand Down
10 changes: 6 additions & 4 deletions lazyllm/tools/rag/doc_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 3,7 @@
from functools import wraps
from typing import Callable, Dict, List, Optional, Set, Union, Tuple
from lazyllm import LOG, config, once_wrapper
from lazyllm.common import override
from .transform import (NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser,
AdaptiveTransform, make_transform, TransformArgs)
from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME, StoreBase
Expand All @@ -20,7 21,7 @@ class _FileNodeIndex(IndexBase):
def __init__(self):
self._file_node_map = {}

# override
@override
def update(self, nodes: List[DocNode]) -> None:
for node in nodes:
if node.group != LAZY_ROOT_NAME:
Expand All @@ -29,13 30,13 @@ def update(self, nodes: List[DocNode]) -> None:
if file_name:
self._file_node_map[file_name] = node

# override
@override
def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
# group_name is ignored
left = {k: v for k, v in self._file_node_map.items() if v.uid not in uids}
self._file_node_map = left

# override
@override
def query(self, files: List[str]) -> List[DocNode]:
ret = []
for file in files:
Expand Down Expand Up @@ -329,7 330,8 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_
index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]:
self._lazy_init()

if not index_instance := self.store.get_index(type=index):
index_instance = self.store.get_index(type=index)
ouonline marked this conversation as resolved.
Show resolved Hide resolved
if not index_instance:
ouonline marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError(f"index type '{index}' is not supported currently.")

self._dynamic_create_nodes(group_name, self.store)
Expand Down
52 changes: 27 additions & 25 deletions lazyllm/tools/rag/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 7,9 @@
import numpy as np
from .component.bm25 import BM25
from lazyllm import LOG, config, ThreadPoolExecutor
from lazyllm.common import override
import pymilvus
from pymilvus.client.abstract import AnnSearchRequest, BaseRanker

# ---------------------------------------------------------------------------- #

Expand Down Expand Up @@ -85,15 87,15 @@ def wrapper(query, nodes, **kwargs):

return decorator(func) if func else decorator

# override
@override
def update(self, nodes: List[DocNode]) -> None:
pass

# override
@override
def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
pass

# override
@override
def query(
self,
query: str,
Expand Down Expand Up @@ -176,17 178,17 @@ def register_similarity(

class MilvusIndex(IndexBase):
class Field:
def __init__(self, name: str, data_type: pymilvus.DataType, index_type: str,
metric_type: str, index_params={}, dim: Optional[int] = None):
def __init__(self, name: str, data_type: pymilvus.DataType,
metric_type: str, index_type: Optional[str] = None,
index_params={}, dim: Optional[int] = None):
self.name = name
self.data_type = data_type
self.index_type = index_type
self.metric_type = metric_type
self.index_params = index_params
self.dim = dim

def __init__(self, embed: Dict[str, Callable],
group_fields: Dict[str, List[MilvusIndex.Field]],
def __init__(self, embed: Dict[str, Callable], group_fields: Dict[str, List[Field]],
uri: str, full_data_store: StoreBase):
self._embed = embed
self._full_data_store = full_data_store
Expand Down Expand Up @@ -219,15 221,15 @@ def __init__(self, embed: Dict[str, Callable],
self._client.create_collection(collection_name=group_name, schema=schema,
index_params=index_params)

# override
@override
def update(self, nodes: List[DocNode]) -> None:
parallel_do_embedding(self._embed, nodes)
for node in nodes:
data = node.embedding.copy()
data[self._primary_key] = node.uid
self._client.upsert(collection_name=node.group, data=data)

# override
@override
def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
if group_name:
self._client.delete(collection_name=group_name,
Expand All @@ -237,22 239,22 @@ def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
self._client.delete(collection_name=group_name,
filter=f'{self._primary_key} in {uids}')

# override
@override
def query(self,
query: str,
group_name: str,
embed_keys: Optional[List[str]] = None,
topk: int = 10,
req: AnnSearchRequest,
ranker: BaseRanker,
limit: int = 10,
timeout: Optional[float] = None,
**kwargs) -> List[DocNode]:
uids = set()
for embed_name in embed_keys:
embed_func = self._embed.get(embed_name)
query_embedding = embed_func(query)
results = self._client.search(collection_name=group_name, data=[query_embedding],
limit=topk, anns_field=embed_name)
if len(results) > 0:
# we have only one `data` for search() so there is only one result in `results`
for result in results[0]:
uids.update(result['id'])

return self._full_data_store.get_nodes(group_name, list(uids))
results = self._client.hybrid_search(
collection_name=group_name, reqs=[req], ranker=ranker, limit=limit,
timeout=timeout)
if len(results) != 1:
raise ValueError(f'return results size [{len(results)}] != 1')

uids = []
for record in results[0]:
uids.append(record['id'])

return self._full_data_store.get_group_nodes(group_name, uids)
85 changes: 33 additions & 52 deletions lazyllm/tools/rag/store.py
Original file line number Diff line number Diff line change
@@ -1,6 1,7 @@
from typing import Any, Dict, List, Optional
import chromadb
from lazyllm import LOG, config
from lazyllm.common import override
from chromadb.api.models.Collection import Collection
from .store_base import StoreBase
from .index_base import IndexBase
Expand All @@ -18,20 19,13 @@

class MapStore(StoreBase):
def __init__(self, node_groups: List[str]):
super().__init__()
# Dict[group_name, Dict[uuid, DocNode]]
self._group2docs: Dict[str, Dict[str, DocNode]] = {
group: {} for group in node_groups
}
self._name2index = {}

# override
def update_nodes(self, nodes: List[DocNode]) -> None:
for node in nodes:
self._group2docs[node.group][node.uid] = node

self._update_indices(self._name2index, nodes)

# override
@override
def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]:
docs = self._group2docs.get(group_name)
if not docs:
Expand All @@ -47,35 41,29 @@ def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]:
ret.append(doc)
return ret

# override
def remove_nodes(self, group_name: str, uids: List[str] = None) -> None:
if uids:
docs = self._group2docs.get(group_name)
if docs:
self._remove_from_indices(self._name2index, uids)
for uid in uids:
docs.pop(uid, None)
else:
docs = self._group2docs.pop(group_name, None)
if docs:
self._remove_from_indices(self._name2index, [doc.uid for doc in docs])

# override
@override
def is_group_active(self, name: str) -> bool:
docs = self._group2docs.get(name)
return True if docs else False

# override
@override
def all_groups(self) -> List[str]:
return self._group2docs.keys()

# override
def register_index(self, type: str, index: IndexBase) -> None:
self._name2index[type] = index
@override
def _update_nodes(self, nodes: List[DocNode]) -> None:
for node in nodes:
self._group2docs[node.group][node.uid] = node

# override
def get_index(self, type: str) -> Optional[IndexBase]:
return self._name2index.get(type)
@override
def _remove_nodes(self, group_name: str, uids: List[str] = None) -> None:
if uids:
docs = self._group2docs.get(group_name)
if docs:
for uid in uids:
docs.pop(uid, None)
else:
self._group2docs.pop(group_name, None)

def find_node_by_uid(self, uid: str) -> Optional[DocNode]:
for docs in self._group2docs.values():
Expand All @@ -90,6 78,7 @@ class ChromadbStore(StoreBase):
def __init__(
self, node_groups: List[str], embed_dim: Dict[str, int]
) -> None:
super().__init__()
self._map_store = MapStore(node_groups)
self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"])
LOG.success(f"Initialzed chromadb in path: {config['rag_persistent_path']}")
Expand All @@ -99,38 88,30 @@ def __init__(
}
self._embed_dim = embed_dim

# override
def update_nodes(self, nodes: List[DocNode]) -> None:
self._map_store.update_nodes(nodes)
self._save_nodes(nodes)

# override
@override
def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]:
return self._map_store.get_nodes(group_name, uids)

# override
def remove_nodes(self, group_name: str, uids: List[str]) -> None:
if uids:
self._delete_group_nodes(group_name, uids)
else:
self._db_client.delete_collection(name=group_name)
return self._map_store.remove_nodes(group_name, uids)

# override
@override
def is_group_active(self, name: str) -> bool:
return self._map_store.is_group_active(name)

# override
@override
def all_groups(self) -> List[str]:
return self._map_store.all_groups()

# override
def register_index(self, type: str, index: IndexBase) -> None:
self._map_store.register_index(type, index)
@override
def _update_nodes(self, nodes: List[DocNode]) -> None:
self._map_store.update_nodes(nodes)
self._save_nodes(nodes)

# override
def get_index(self, type: str) -> Optional[IndexBase]:
return self._map_store.get_index(type)
@override
def _remove_nodes(self, group_name: str, uids: List[str]) -> None:
if uids:
self._delete_group_nodes(group_name, uids)
else:
self._db_client.delete_collection(name=group_name)
return self._map_store.remove_nodes(group_name, uids)

def _load_store(self) -> None:
if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]:
Expand Down
4 changes: 2 additions & 2 deletions lazyllm/tools/rag/store_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 42,12 @@ def _remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> No
pass

@staticmethod
def _update_indices(name2index: Dict[str, BaseIndex], nodes: List[DocNode]) -> None:
def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None:
for _, index in name2index.items():
index.update(nodes)

@staticmethod
def _remove_from_indices(name2index: Dict[str, BaseIndex], uids: List[str],
def _remove_from_indices(name2index: Dict[str, IndexBase], uids: List[str],
group_name: Optional[str] = None) -> None:
for _, index in name2index.items():
index.remove(uids, group_name)
2 changes: 1 addition & 1 deletion tests/basic_tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 78,7 @@ def test_insert_dict_as_sparse_embedding(self):
node1.uid: [0, 10, 20],
node2.uid: [30, 0, 50],
}
self.store.add_nodes([node1, node2])
self.store.update_nodes([node1, node2])

results = self.store._peek_all_documents('group1')
nodes = self.store._build_nodes_from_chroma(results)
Expand Down