Commit

fix: fix the unittests
nan-wang committed Jun 3, 2020
1 parent 73d17ba commit 0c18fdb
Showing 1 changed file with 14 additions and 12 deletions.
tests/test_flow.py (26 changes: 14 additions & 12 deletions)
@@ -17,21 +17,22 @@
 from tests import JinaTestCase


-def random_docs(num_docs, chunks_per_doc=5, embed_dim=None, field_name=''):
-    c_id = 0
+def random_docs(num_docs, chunks_per_doc=5, embed_dim=None, field_name='', beg_id=0):
+    c_id = beg_id
     for j in range(num_docs):
         d = jina_pb2.Document()
+        d.meta_info = b'hello world'
+        d.doc_id = beg_id + j
         for k in range(chunks_per_doc):
             c = d.chunks.add()
             if isinstance(embed_dim, int):
                 c.embedding.CopyFrom(array2pb(np.random.random([embed_dim])))
             else:
                 c.text = 'i\'m chunk %d from doc %d with field_name %s' % (c_id, j, c.field_name)
             c.chunk_id = c_id
-            c.doc_id = j
+            c.doc_id = d.doc_id
             c.field_name = field_name
             c_id += 1
-        d.meta_info = b'hello world'
         yield d


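As a quick standalone illustration of the beg_id change (a minimal sketch that uses plain dicts in place of jina_pb2.Document, so only the id bookkeeping is reproduced): starting two generators with different beg_id values assigns document ids from distinct offsets (beg_id + j), which the index calls in the hunks below presumably rely on so that consecutively indexed batches do not share doc ids when random_doc_id=False.

# Minimal sketch only: plain dicts stand in for jina_pb2.Document.
def random_doc_ids(num_docs, chunks_per_doc=5, beg_id=0):
    c_id = beg_id
    for j in range(num_docs):
        doc = {'doc_id': beg_id + j, 'chunks': []}
        for _ in range(chunks_per_doc):
            doc['chunks'].append({'chunk_id': c_id, 'doc_id': doc['doc_id']})
            c_id += 1
        yield doc

first = [d['doc_id'] for d in random_doc_ids(3, beg_id=10)]   # [10, 11, 12]
second = [d['doc_id'] for d in random_doc_ids(3, beg_id=20)]  # [20, 21, 22]
assert not set(first) & set(second)                           # the two batches share no doc ids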
@@ -270,9 +271,9 @@ def validate(rsp):
             self.assertEqual(len(rsp.docs[0].topk_results), num_docs)
         f = Flow().add(name='idx', yaml_path=yaml_path, copy_flow=False)
         with f:
-            f.index(input_fn=random_docs(num_docs, num_chunks, embed_dim=10, field_name=filter_by),
+            f.index(input_fn=random_docs(num_docs, num_chunks, embed_dim=10, field_name=filter_by, beg_id=10),
                     random_doc_id=False)
-            f.index(input_fn=random_docs(num_docs, num_chunks, embed_dim=10, field_name='summary'),
+            f.index(input_fn=random_docs(num_docs, num_chunks, embed_dim=10, field_name='summary', beg_id=20),
                     random_doc_id=False)
         fq = (Flow().add(name='idx', yaml_path=yaml_path, copy_flow=False)
               .add(name='ranker', yaml_path='MinRanker', copy_flow=False))
@@ -301,8 +302,8 @@ def validate(rsp):
         f = (Flow().add(name='enc', yaml_path=encoder_yml, copy_flow=False)
              .add(name='idx', yaml_path=indexer_yml, copy_flow=False))
         with f:
-            f.index(input_fn=random_docs(num_docs, num_chunks, field_name=filter_by), random_doc_id=False)
-            f.index(input_fn=random_docs(num_docs, num_chunks, field_name='summary'), random_doc_id=False)
+            f.index(input_fn=random_docs(num_docs, num_chunks, field_name=filter_by, beg_id=10))
+            f.index(input_fn=random_docs(num_docs, num_chunks, field_name='summary', beg_id=20))

         fq = (Flow().add(name='enc', yaml_path=encoder_yml, copy_flow=False)
               .add(name='idx', yaml_path=indexer_yml, copy_flow=False)
@@ -339,12 +340,13 @@ def validate(rsp):
               .join(needs=['title_idx', 'summary_idx']))

         with f:
-            f.index(input_fn=random_docs(num_docs, num_chunks, field_name=filter_by), random_doc_id=True)
-            f.index(input_fn=random_docs(num_docs, num_chunks, field_name=filter_by_2), random_doc_id=True)
+            f.index(input_fn=random_docs(num_docs, num_chunks, field_name=filter_by, beg_id=10), random_doc_id=True)
+            f.index(input_fn=random_docs(num_docs, num_chunks, field_name=filter_by_2, beg_id=20), random_doc_id=True)

         fq = (Flow().add(name='enc', yaml_path='OneHotTextEncoder', copy_flow=False)
               .add(name='title_idx', yaml_path=indexer_yml, copy_flow=False)
-              .add(name='summary_idx', yaml_path=indexer_yml_2, copy_flow=False)
+              .add(name='summary_idx', yaml_path=indexer_yml_2, copy_flow=False, needs='enc')
+              .add(name='join', yaml_path='_merge_topk_chunks', needs=['title_idx', 'summary_idx'])
               .add(name='ranker', yaml_path='MinRanker', copy_flow=False))

         with fq:
@@ -353,7 +355,7 @@ def validate(rsp):
                           random_doc_id=False,
                           output_fn=validate,
                           callback_on_body=True,
-                          filter_by=filter_by)
+                          filter_by=[filter_by, filter_by_2])
             except BadClient as e:
                 self.assertTrue(False, e)
             finally:
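The last two hunks also change the query-flow topology: summary_idx now declares needs='enc', so it branches off the encoder in parallel with title_idx instead of being chained after it, and an explicit _merge_topk_chunks pod joins the two result streams before MinRanker; the search call then filters by both field names. A rough sketch of that topology, assuming the Jina 0.x Flow API used in this test (the import path and the two indexer yaml paths below are illustrative assumptions):

from jina.flow import Flow

fq = (Flow()
      .add(name='enc', yaml_path='OneHotTextEncoder')
      .add(name='title_idx', yaml_path='title_indexer.yml')                   # follows 'enc'
      .add(name='summary_idx', yaml_path='summary_indexer.yml', needs='enc')  # parallel branch off 'enc'
      .add(name='join', yaml_path='_merge_topk_chunks',
           needs=['title_idx', 'summary_idx'])                                # merge both top-k streams
      .add(name='ranker', yaml_path='MinRanker'))

# at query time both fields are searched, as in the fixed test:
# fq.search(input_fn=..., filter_by=[filter_by, filter_by_2], output_fn=validate, callback_on_body=True)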
