PQ4 fast scan benchmarks (facebookresearch#1555)
Summary:
Code and scripts for Faiss benchmarks around the fast-scan codes.
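
A hypothetical end-to-end invocation, for orientation only (the dataset name, factory string, and search parameters below are illustrative assumptions, not taken from this commit):

    python bench_all_ivf.py --db deep1M --indexkey "IVF65536,PQ32x4fs" \
        --maxtrain 0 --inter --searchparams nprobe=16 nprobe=64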

Pull Request resolved: facebookresearch#1555

Test Plan: buck test //faiss/tests/:test_refine

Reviewed By: wickedfoo

Differential Revision: D25546505

Pulled By: mdouze

fbshipit-source-id: 902486b7f47e36221a2671d124df8c114f25db58
mdouze authored and facebook-github-bot committed Dec 16, 2020
1 parent 90c891b commit c5975cd
Showing 47 changed files with 2,463 additions and 853 deletions.
221 changes: 163 additions & 58 deletions benchs/bench_all_ivf/bench_all_ivf.py
100755 → 100644
@@ -1,5 +1,3 @@
#!/usr/bin/env python2

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
@@ -8,12 +6,15 @@
import os
import sys
import time
import pdb
import numpy as np
import faiss
import argparse
import datasets
from datasets import sanitize



######################################################
# Command-line parsing
######################################################
@@ -34,8 +35,8 @@ def aa(*args, **kwargs):
group = parser.add_argument_group('index construction')

aa('--indexkey', default='HNSW32', help='index_factory type')
aa('--efConstruction', default=200, type=int,
help='HNSW construction factor')
aa('--by_residual', default=-1, type=int,
help="set if index should use residuals (default=unchanged)")
aa('--M0', default=-1, type=int, help='size of base level')
aa('--maxtrain', default=256 * 256, type=int,
help='maximum number of training points (0 to set automatically)')
@@ -54,6 +55,8 @@ def aa(*args, **kwargs):
group = parser.add_argument_group('searching')

aa('--k', default=100, type=int, help='nb of nearest neighbors')
aa('--inter', default=False, action='store_true',
help='use intersection measure instead of 1-recall as metric')
aa('--searchthreads', default=-1, type=int,
help='nb of threads to use at search time')
aa('--searchparams', nargs='+', default=['autotune'],
@@ -64,7 +67,7 @@
help='set max value for autotune variables format "var:val" (exclusive)')
aa('--autotune_range', default=[], nargs='*',
help='set complete autotune range, format "var:val1,val2,..."')
aa('--min_test_duration', default=0, type=float,
aa('--min_test_duration', default=3.0, type=float,
help='run test at least for so long to avoid jitter')

args = parser.parse_args()
@@ -79,64 +82,126 @@ def aa(*args, **kwargs):
# Load dataset
######################################################

xt, xb, xq, gt = datasets.load_data(
ds = datasets.load_dataset(
dataset=args.db, compute_gt=args.compute_gt)


print("dataset sizes: train %s base %s query %s GT %s" % (
xt.shape, xb.shape, xq.shape, gt.shape))
print(ds)

nq, d = xq.shape
nb, d = xb.shape
nq, d = ds.nq, ds.d
nb, d = ds.nb, ds.d


######################################################
# Make index
######################################################

def unwind_index_ivf(index):
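# peel IndexPreTransform / IndexRefine wrappers off the index and return
# (inner IndexIVF or None, VectorTransform or None), so the rest of the
# script can set IVF-specific fields directly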
if isinstance(index, faiss.IndexPreTransform):
assert index.chain.size() == 1
vt = index.chain.at(0)
index_ivf, vt2 = unwind_index_ivf(faiss.downcast_index(index.index))
assert vt2 is None
return index_ivf, vt
if hasattr(faiss, "IndexRefine") and isinstance(index, faiss.IndexRefine):
return unwind_index_ivf(faiss.downcast_index(index.base_index))
if isinstance(index, faiss.IndexIVF):
return index, None
else:
return None, None


if args.indexfile and os.path.exists(args.indexfile):

print("reading", args.indexfile)
index = faiss.read_index(args.indexfile)

if isinstance(index, faiss.IndexPreTransform):
index_ivf = faiss.downcast_index(index.index)
else:
index_ivf = index
assert isinstance(index_ivf, faiss.IndexIVF)
index_ivf, vec_transform = unwind_index_ivf(index)
if vec_transform is None:
vec_transform = lambda x: x
assert isinstance(index_ivf, faiss.IndexIVF)

else:

print("build index, key=", args.indexkey)

index = faiss.index_factory(d, args.indexkey)
index = faiss.index_factory(
d, args.indexkey, faiss.METRIC_L2 if ds.metric == "L2" else
faiss.METRIC_INNER_PRODUCT
)

if isinstance(index, faiss.IndexPreTransform):
index_ivf = faiss.downcast_index(index.index)
vec_transform = index.chain.at(0).apply_py
index_ivf, vec_transform = unwind_index_ivf(index)
if vec_transform is None:
vec_transform = lambda x: x
else:
vec_transform = faiss.downcast_VectorTransform(vec_transform)

if args.by_residual != -1:
by_residual = args.by_residual == 1
print("setting by_residual = ", by_residual)
index_ivf.by_residual # check if field exists
index_ivf.by_residual = by_residual


if index_ivf:
print("Update add-time parameters")
# adjust default parameters used at add time for quantizers
# because otherwise the assignment is inaccurate
quantizer = faiss.downcast_index(index_ivf.quantizer)
if isinstance(quantizer, faiss.IndexRefine):
print(" update quantizer k_factor=", quantizer.k_factor, end=" -> ")
quantizer.k_factor = 32 if index_ivf.nlist < 1e6 else 64
print(quantizer.k_factor)
base_index = faiss.downcast_index(quantizer.base_index)
if isinstance(base_index, faiss.IndexIVF):
print(" update quantizer nprobe=", base_index.nprobe, end=" -> ")
base_index.nprobe = (
16 if base_index.nlist < 1e5 else
32 if base_index.nlist < 4e6 else
64)
print(base_index.nprobe)
elif isinstance(quantizer, faiss.IndexHNSW):
print(" update quantizer efSearch=", quantizer.hnsw.efSearch, end=" -> ")
quantizer.hnsw.efSearch = 40 if index_ivf.nlist < 4e6 else 64
print(quantizer.hnsw.efSearch)

if index_ivf:
index_ivf.verbose = True
index_ivf.quantizer.verbose = True
index_ivf.cp.verbose = True
else:
index_ivf = index
vec_transform = lambda x:x
assert isinstance(index_ivf, faiss.IndexIVF)
index_ivf.verbose = True
index_ivf.quantizer.verbose = True
index_ivf.cp.verbose = True
index.verbose = True

maxtrain = args.maxtrain
if maxtrain == 0:
if 'IMI' in args.indexkey:
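# 256 * 2**(log2(nlist)/2) == 256 * sqrt(nlist): scale the training set
# with the square root of the number of IMI cells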
maxtrain = int(256 * 2 ** (np.log2(index_ivf.nlist) / 2))
else:
elif index_ivf:
maxtrain = 50 * index_ivf.nlist
else:
# just guess...
maxtrain = 256 * 100
maxtrain = max(maxtrain, 256 * 100)
print("setting maxtrain to %d" % maxtrain)
args.maxtrain = maxtrain

xt2 = sanitize(xt[:args.maxtrain])
assert np.all(np.isfinite(xt2))
try:
xt2 = ds.get_train(maxtrain=maxtrain)
except NotImplementedError:
print("No training set: training on database")
xt2 = ds.get_database()[:maxtrain]

print("train, size", xt2.shape)
assert np.all(np.isfinite(xt2))

if (isinstance(vec_transform, faiss.OPQMatrix) and
isinstance(index_ivf, faiss.IndexIVFPQFastScan)):
print(" Forcing OPQ training PQ to PQ4")
ref_pq = index_ivf.pq
training_pq = faiss.ProductQuantizer(
ref_pq.d, ref_pq.M, ref_pq.nbits
)
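# train OPQ against a scratch PQ with the same geometry (d, M, nbits=4);
# reading vec_transform.pq first checks that the field exists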
vec_transform.pq
vec_transform.pq = training_pq


if args.get_centroids_from == '':

@@ -147,7 +212,8 @@

if args.train_on_gpu:
print("add a training index on GPU")
train_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d))
train_index = faiss.index_cpu_to_all_gpus(
faiss.IndexFlatL2(index_ivf.d))
index_ivf.clustering_index = train_index

else:
@@ -158,13 +224,15 @@
centroids = centroids.reshape(-1, d)
print(" centroid table shape", centroids.shape)

if isinstance(index, faiss.IndexPreTransform):
if isinstance(vec_transform, faiss.VectorTransform):
print(" training vector transform")
assert index.chain.size() == 1
vt = index.chain.at(0)
vt.train(xt2)
vec_transform.train(xt2)
print(" transform centroids")
centroids = vt.apply_py(centroids)
centroids = vec_transform.apply_py(centroids)

if not index_ivf.quantizer.is_trained:
print(" training quantizer")
index_ivf.quantizer.train(centroids)

print(" add centroids to quantizer")
index_ivf.quantizer.add(centroids)
@@ -177,12 +245,16 @@
print("adding")
t0 = time.time()
if args.add_bs == -1:
index.add(sanitize(xb))
index.add(sanitize(ds.get_database()))
else:
for i0 in range(0, nb, args.add_bs):
i1 = min(nb, i0 + args.add_bs)
print(" adding %d:%d / %d" % (i0, i1, nb))
index.add(sanitize(xb[i0:i1]))
i0 = 0
for xblock in ds.database_iterator(bs=args.add_bs):
i1 = i0 + len(xblock)
print(" adding %d:%d / %d [%.3f s, RSS %d kiB] " % (
i0, i1, ds.nb, time.time() - t0,
faiss.get_mem_usage_kb()))
index.add(xblock)
i0 = i1

print(" add in %.3f s" % (time.time() - t0))
if args.indexfile:
@@ -211,39 +283,65 @@
# Index is ready
#############################################################

xq = sanitize(xq)
xq = sanitize(ds.get_queries())
gt = ds.get_groundtruth(k=args.k)
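# the assert's second argument is evaluated only on failure, so a
# mismatched ground-truth depth drops straight into the debugger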
assert gt.shape[1] == args.k, pdb.set_trace()

if args.searchthreads != -1:
print("Setting nb of threads to", args.searchthreads)
faiss.omp_set_num_threads(args.searchthreads)


ps = faiss.ParameterSpace()
ps.initialize(index)


parametersets = args.searchparams

header = '%-40s R@1 R@10 R@100 time(ms/q) nb distances #runs' % "parameters"


def eval_setting(index, xq, gt, min_time):
if args.inter:
header = (
'%-40s inter@%3d time(ms/q) nb distances #runs' %
("parameters", args.k)
)
else:

header = (
'%-40s R@1 R@10 R@100 time(ms/q) nb distances #runs' %
"parameters"
)

def compute_inter(a, b):
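# average fraction of the ground-truth ids that also appear in the
# returned ids, e.g. 75 of 100 ids recovered for every query -> 0.75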
nq, rank = a.shape
ninter = sum(
np.intersect1d(a[i, :rank], b[i, :rank]).size
for i in range(nq)
)
return ninter / a.size



def eval_setting(index, xq, gt, k, inter, min_time):
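# run one search setting and print its accuracy (recall or intersection),
# time per query, and average number of distance computations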
nq = xq.shape[0]
ivf_stats = faiss.cvar.indexIVF_stats
ivf_stats.reset()
nrun = 0
t0 = time.time()
while True:
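# repeat the search until at least min_time has elapsed, so that fast
# settings are not dominated by timing jitter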
D, I = index.search(xq, 100)
D, I = index.search(xq, k)
nrun += 1
t1 = time.time()
if t1 - t0 > min_time:
break
ms_per_query = ((t1 - t0) * 1000.0 / nq / nrun)
for rank in 1, 10, 100:
n_ok = (I[:, :rank] == gt[:, :1]).sum()
print("%.4f" % (n_ok / float(nq)), end=' ')
print(" %8.3f " % ms_per_query, end=' ')
if inter:
rank = k
inter_measure = compute_inter(gt[:, :rank], I[:, :rank])
print("%.4f" % inter_measure, end=' ')
else:
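# 1-recall@rank: fraction of queries whose true nearest neighbor
# gt[:, :1] appears among the top-`rank` results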
for rank in 1, 10, 100:
n_ok = (I[:, :rank] == gt[:, :1]).sum()
print("%.4f" % (n_ok / float(nq)), end=' ')
print(" %9.5f " % ms_per_query, end=' ')
print("d " % (ivf_stats.ndis / nrun), end=' ')
print(nrun)

@@ -269,15 +367,20 @@ def eval_setting(index, xq, gt, min_time):
pr = ps.add_range(k)
faiss.copy_array_to_vector(vals, pr.values)

# setup the Criterion object: optimize for 1-R@1
crit = faiss.OneRecallAtRCriterion(nq, 1)
# setup the Criterion object
if args.inter:
print("Optimize for intersection @ ", args.k)
crit = faiss.IntersectionCriterion(nq, args.k)
else:
print("Optimize for 1-recall @ 1")
crit = faiss.OneRecallAtRCriterion(nq, 1)

# by default, the criterion will request only 1 NN
crit.nnn = 100
crit.nnn = args.k
crit.set_groundtruth(None, gt.astype('int64'))

# then we let Faiss find the optimal parameters by itself
print("exploring operating points")
print("exploring operating points, %d threads" % faiss.omp_get_max_threads());
ps.display()

t0 = time.time()
@@ -286,17 +389,19 @@ def eval_setting(index, xq, gt, min_time):

op.display()

print("Re-running evaluation on selected OPs")
print(header)
opv = op.optimal_pts
maxw = max(max(len(opv.at(i).key) for i in range(opv.size())), 40)
for i in range(opv.size()):
opt = opv.at(i)

ps.set_index_parameters(index, opt.key)

print("%-40s " % opt.key, end=' ')
print(opt.key.ljust(maxw), end=' ')
sys.stdout.flush()

eval_setting(index, xq, gt, args.min_test_duration)
eval_setting(index, xq, gt, args.k, args.inter, args.min_test_duration)

else:
print(header)
@@ -305,4 +410,4 @@ def eval_setting(index, xq, gt, min_time):
sys.stdout.flush()
ps.set_index_parameters(index, param)

eval_setting(index, xq, gt, args.min_test_duration)
eval_setting(index, xq, gt, args.k, args.inter, args.min_test_duration)