Skip to content

Commit

Permalink
Merge pull request #14 from cadurosar/main
Browse files Browse the repository at this point in the history
Fix small bugs on evaluation
  • Loading branch information
thibault-formal authored Jun 14, 2022
2 parents db3cb52 a43abb0 commit b346a27
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/tasks/transformer_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 28,7 @@ def __init__(self, model, config, compute_stats=False, dim_voc=None, is_query=Fa
if self.compute_stats:
self.l0 = L0()

def index(self, collection_loader):
def index(self, collection_loader, id_dict=None):
doc_ids = []
if self.compute_stats:
stats = defaultdict(float)
Expand All @@ -46,6 46,8 @@ def index(self, collection_loader):
data = batch_documents[row, col]
row = row count
batch_ids = to_list(batch["id"])
if id_dict:
batch_ids = [id_dict[x] for x in batch_ids]
count = len(batch_ids)
doc_ids.extend(batch_ids)
self.sparse_index.add_batch_document(row.cpu().numpy(), col.cpu().numpy(), data.cpu().numpy(),
Expand Down
2 changes: 1 addition & 1 deletion src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 90,7 @@ def get_dataset_name(path):
return "TREC_DL_2019"
elif "trec2020" in path or "TREC_DL_2020" in path:
return "TREC_DL_2020"
elif "MSMARCO" in path:
elif "msmarco" in path:
if "train_queries" in path:
return "MSMARCO_TRAIN"
else:
Expand Down

0 comments on commit b346a27

Please sign in to comment.