Skip to content

Commit

Permalink
Refactor Dataset API (infiniflow#2783)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Refactor Dataset API

### Type of change

- [x] Refactoring

---------

Co-authored-by: liuhua <[email protected]>
  • Loading branch information
Feiue and liuhua authored Oct 11, 2024
1 parent a2f9c03 commit cbd7cd7
Show file tree
Hide file tree
Showing 11 changed files with 449 additions and 393 deletions.
2 changes: 1 addition & 1 deletion api/apps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 83,7 @@ def register_page(page_path):
sys.modules[module_name] = page
spec.loader.exec_module(page)
page_name = getattr(page, 'page_name', page_name)
url_prefix = f'/api/{API_VERSION}/{page_name}' if "/sdk/" in path else f'/{API_VERSION}/{page_name}'
url_prefix = f'/api/{API_VERSION}' if "/sdk/" in path else f'/{API_VERSION}/{page_name}'

app.register_blueprint(page.manager, url_prefix=url_prefix)
return url_prefix
Expand Down
299 changes: 128 additions & 171 deletions api/apps/sdk/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,143 25,146 @@
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, token_required, get_data_error_result
from api.utils.api_utils import get_result, token_required,get_error_data_result


@manager.route('/save', methods=['POST'])
@manager.route('/dataset', methods=['POST'])
@token_required
def save(tenant_id):
def create(tenant_id):
req = request.json
e, t = TenantService.get_by_id(tenant_id)
if "id" not in req:
if "tenant_id" in req or "embedding_model" in req:
return get_data_error_result(
retmsg="Tenant_id or embedding_model must not be provided")
if "name" not in req:
return get_data_error_result(
retmsg="Name is not empty!")
req['id'] = get_uuid()
req["name"] = req["name"].strip()
if req["name"] == "":
return get_data_error_result(
retmsg="Name is not empty string!")
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result(
retmsg="Duplicated knowledgebase name in creating dataset.")
req["tenant_id"] = req['created_by'] = tenant_id
req['embedding_model'] = t.embd_id
key_mapping = {
"chunk_num": "chunk_count",
"doc_num": "document_count",
"parser_id": "parse_method",
"embd_id": "embedding_model"
}
mapped_keys = {new_key: req[old_key] for new_key, old_key in key_mapping.items() if old_key in req}
req.update(mapped_keys)
if not KnowledgebaseService.save(**req):
return get_data_error_result(retmsg="Create dataset error.(Database error)")
renamed_data = {}
e, k = KnowledgebaseService.get_by_id(req["id"])
for key, value in k.to_dict().items():
new_key = key_mapping.get(key, key)
renamed_data[new_key] = value
return get_json_result(data=renamed_data)
else:
invalid_keys = {"embd_id", "chunk_num", "doc_num", "parser_id"}
if any(key in req for key in invalid_keys):
return get_data_error_result(retmsg="The input parameters are invalid.")

if "tenant_id" in req:
if req["tenant_id"] != tenant_id:
return get_data_error_result(
retmsg="Can't change tenant_id.")

if "embedding_model" in req:
if req["embedding_model"] != t.embd_id:
return get_data_error_result(
retmsg="Can't change embedding_model.")
req.pop("embedding_model")

if not KnowledgebaseService.query(
created_by=tenant_id, id=req["id"]):
return get_json_result(
data=False, retmsg='You do not own the dataset.',
retcode=RetCode.OPERATING_ERROR)

if not req["id"]:
return get_data_error_result(
retmsg="id can not be empty.")
e, kb = KnowledgebaseService.get_by_id(req["id"])

if "chunk_count" in req:
if req["chunk_count"] != kb.chunk_num:
return get_data_error_result(
retmsg="Can't change chunk_count.")
req.pop("chunk_count")

if "document_count" in req:
if req['document_count'] != kb.doc_num:
return get_data_error_result(
retmsg="Can't change document_count.")
req.pop("document_count")

if "parse_method" in req:
if kb.chunk_num != 0 and req['parse_method'] != kb.parser_id:
return get_data_error_result(
retmsg="If chunk count is not 0, parse method is not changable.")
req['parser_id'] = req.pop('parse_method')
if "name" in req:
req["name"] = req["name"].strip()
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(
retmsg="Duplicated knowledgebase name in updating dataset.")

del req["id"]
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_data_error_result(retmsg="Update dataset error.(Database error)")
return get_json_result(data=True)

if "tenant_id" in req or "embedding_model" in req:
return get_error_data_result(
retmsg="Tenant_id or embedding_model must not be provided")
chunk_count=req.get("chunk_count")
document_count=req.get("document_count")
if chunk_count or document_count:
return get_error_data_result(retmsg="chunk_count or document_count must be 0 or not be provided")
if "name" not in req:
return get_error_data_result(
retmsg="Name is not empty!")
req['id'] = get_uuid()
req["name"] = req["name"].strip()
if req["name"] == "":
return get_error_data_result(
retmsg="Name is not empty string!")
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_data_result(
retmsg="Duplicated knowledgebase name in creating dataset.")
req["tenant_id"] = req['created_by'] = tenant_id
req['embedding_model'] = t.embd_id
key_mapping = {
"chunk_num": "chunk_count",
"doc_num": "document_count",
"parser_id": "parse_method",
"embd_id": "embedding_model"
}
mapped_keys = {new_key: req[old_key] for new_key, old_key in key_mapping.items() if old_key in req}
req.update(mapped_keys)
if not KnowledgebaseService.save(**req):
return get_error_data_result(retmsg="Create dataset error.(Database error)")
renamed_data = {}
e, k = KnowledgebaseService.get_by_id(req["id"])
for key, value in k.to_dict().items():
new_key = key_mapping.get(key, key)
renamed_data[new_key] = value
return get_result(data=renamed_data)

@manager.route('/delete', methods=['DELETE'])
@manager.route('/dataset', methods=['DELETE'])
@token_required
def delete(tenant_id):
req = request.args
if "id" not in req:
return get_data_error_result(
retmsg="id is required")
kbs = KnowledgebaseService.query(
created_by=tenant_id, id=req["id"])
if not kbs:
return get_json_result(
data=False, retmsg='You do not own the dataset',
retcode=RetCode.OPERATING_ERROR)

for doc in DocumentService.query(kb_id=req["id"]):
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
return get_data_error_result(
retmsg="Remove document error.(Database error)")
f2d = File2DocumentService.get_by_document_id(doc.id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)

if not KnowledgebaseService.delete_by_id(req["id"]):
return get_data_error_result(
retmsg="Delete dataset error.(Database serror)")
return get_json_result(data=True)


@manager.route('/list', methods=['GET'])
req = request.json
names=req.get("names")
ids = req.get("ids")
if not ids and not names:
return get_error_data_result(
retmsg="ids or names is required")
id_list=[]
if names:
for name in names:
kbs=KnowledgebaseService.query(name=name,tenant_id=tenant_id)
if not kbs:
return get_error_data_result(retmsg=f"You don't own the dataset {name}")
id_list.append(kbs[0].id)
if ids:
for id in ids:
kbs=KnowledgebaseService.query(id=id,tenant_id=tenant_id)
if not kbs:
return get_error_data_result(retmsg=f"You don't own the dataset {id}")
id_list.extend(ids)
for id in id_list:
for doc in DocumentService.query(kb_id=id):
if not DocumentService.remove_document(doc, tenant_id):
return get_error_data_result(
retmsg="Remove document error.(Database error)")
f2d = File2DocumentService.get_by_document_id(doc.id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)
if not KnowledgebaseService.delete_by_id(id):
return get_error_data_result(
retmsg="Delete dataset error.(Database serror)")
return get_result(retcode=RetCode.SUCCESS)

@manager.route('/dataset/<dataset_id>', methods=['PUT'])
@token_required
def update(tenant_id,dataset_id):
if not KnowledgebaseService.query(id=dataset_id,tenant_id=tenant_id):
return get_error_data_result(retmsg="You don't own the dataset")
req = request.json
e, t = TenantService.get_by_id(tenant_id)
invalid_keys = {"id", "embd_id", "chunk_num", "doc_num", "parser_id"}
if any(key in req for key in invalid_keys):
return get_error_data_result(retmsg="The input parameters are invalid.")
if "tenant_id" in req:
if req["tenant_id"] != tenant_id:
return get_error_data_result(
retmsg="Can't change tenant_id.")
if "embedding_model" in req:
if req["embedding_model"] != t.embd_id:
return get_error_data_result(
retmsg="Can't change embedding_model.")
req.pop("embedding_model")
e, kb = KnowledgebaseService.get_by_id(dataset_id)
if "chunk_count" in req:
if req["chunk_count"] != kb.chunk_num:
return get_error_data_result(
retmsg="Can't change chunk_count.")
req.pop("chunk_count")
if "document_count" in req:
if req['document_count'] != kb.doc_num:
return get_error_data_result(
retmsg="Can't change document_count.")
req.pop("document_count")
if "parse_method" in req:
if kb.chunk_num != 0 and req['parse_method'] != kb.parser_id:
return get_error_data_result(
retmsg="If chunk count is not 0, parse method is not changable.")
req['parser_id'] = req.pop('parse_method')
if "name" in req:
req["name"] = req["name"].strip()
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
status=StatusEnum.VALID.value)) > 0:
return get_error_data_result(
retmsg="Duplicated knowledgebase name in updating dataset.")
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(retmsg="Update dataset error.(Database error)")
return get_result(retcode=RetCode.SUCCESS)

@manager.route('/dataset', methods=['GET'])
@token_required
def list_datasets(tenant_id):
def list(tenant_id):
id = request.args.get("id")
name = request.args.get("name")
kbs = KnowledgebaseService.query(id=id,name=name,status=1)
if not kbs:
return get_error_data_result(retmsg="The dataset doesn't exist")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 1024))
orderby = request.args.get("orderby", "create_time")
desc = bool(request.args.get("desc", True))
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
kbs = KnowledgebaseService.get_by_tenant_ids(
[m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc)
kbs = KnowledgebaseService.get_list(
[m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc, id, name)
renamed_list = []
for kb in kbs:
key_mapping = {
Expand All @@ -175,50 178,4 @@ def list_datasets(tenant_id):
new_key = key_mapping.get(key, key)
renamed_data[new_key] = value
renamed_list.append(renamed_data)
return get_json_result(data=renamed_list)


@manager.route('/detail', methods=['GET'])
@token_required
def detail(tenant_id):
req = request.args
key_mapping = {
"chunk_num": "chunk_count",
"doc_num": "document_count",
"parser_id": "parse_method",
"embd_id": "embedding_model"
}
renamed_data = {}
if "id" in req:
id = req["id"]
kb = KnowledgebaseService.query(created_by=tenant_id, id=req["id"])
if not kb:
return get_json_result(
data=False, retmsg='You do not own the dataset.',
retcode=RetCode.OPERATING_ERROR)
if "name" in req:
name = req["name"]
if kb[0].name != name:
return get_json_result(
data=False, retmsg='You do not own the dataset.',
retcode=RetCode.OPERATING_ERROR)
e, k = KnowledgebaseService.get_by_id(id)
for key, value in k.to_dict().items():
new_key = key_mapping.get(key, key)
renamed_data[new_key] = value
return get_json_result(data=renamed_data)
else:
if "name" in req:
name = req["name"]
e, k = KnowledgebaseService.get_by_name(kb_name=name, tenant_id=tenant_id)
if not e:
return get_json_result(
data=False, retmsg='You do not own the dataset.',
retcode=RetCode.OPERATING_ERROR)
for key, value in k.to_dict().items():
new_key = key_mapping.get(key, key)
renamed_data[new_key] = value
return get_json_result(data=renamed_data)
else:
return get_data_error_result(
retmsg="At least one of `id` or `name` must be provided.")
return get_result(data=renamed_list)
24 changes: 24 additions & 0 deletions api/db/services/knowledgebase_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 142,27 @@ def get_by_name(cls, kb_name, tenant_id):
@DB.connection_context()
def get_all_ids(cls):
return [m["id"] for m in cls.model.select(cls.model.id).dicts()]

@classmethod
@DB.connection_context()
def get_list(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc, id , name):
kbs = cls.model.select()
if id:
kbs = kbs.where(cls.model.id == id)
if name:
kbs = kbs.where(cls.model.name == name)
kbs = kbs.where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
)
if desc:
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else:
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())

kbs = kbs.paginate(page_number, items_per_page)

return list(kbs.dicts())
Loading

0 comments on commit cbd7cd7

Please sign in to comment.