-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_constrained_tokenset.py
executable file
·120 lines (102 loc) · 4.09 KB
/
gen_constrained_tokenset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python3
##
## Copyright (c) 2016, Alliance for Open Media. All rights reserved.
##
## This source code is subject to the terms of the BSD 2 Clause License and
## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
## was not distributed with this source code in the LICENSE file, you can
## obtain it at www.aomedia.org/license/software. If the Alliance for Open
## Media Patent License 1.0 was not distributed with this source code in the
## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
##
"""Generate the probability model for the constrained token set.
Model obtained from a 2-sided zero-centered distribution derived
from a Pareto distribution. The cdf of the distribution is:
cdf(x) = 0.5 + 0.5 * sgn(x) * [1 - {alpha/(alpha + |x|)} ^ beta]
For a given beta and a given probability of the 1-node, the alpha
is first solved, and then the {alpha, beta} pair is used to generate
the probabilities for the rest of the nodes.
"""
import heapq
import sys
import numpy as np
import scipy.optimize
import scipy.stats
def cdf_spareto(x, xm, beta):
    """Evaluate the CDF of the two-sided, zero-centered Pareto-derived model.

    cdf(x) = 0.5 + 0.5 * sgn(x) * [1 - (xm / (xm + |x|)) ** beta]

    Args:
      x: point(s) at which to evaluate; a scalar or a numpy array.
      xm: scale parameter (the ``alpha`` of the module docstring).
      beta: shape parameter.

    Returns:
      The CDF value(s). Exactly 0.5 at x == 0 (np.sign(0) == 0); for positive
      xm and beta the result lies in [0, 1] and cdf(x) + cdf(-x) == 1.
    """
    # One-sided mass of |X| <= |x| for the Pareto-derived tail.
    tail = 1 - (xm / (np.abs(x) + xm))**beta
    # Mirror the one-sided mass around the median 0.5 according to sign(x).
    return 0.5 + 0.5 * np.sign(x) * tail
def get_spareto(p, beta):
    """Build the 11-bin token probability model for a given 1-node probability.

    First solves for the scale ``alpha`` such that
    P(0.5 < x < 1.5) / P(x > 0.5) == p, which by the symmetry of the
    distribution is the probability of magnitude 1 given a non-zero token.
    Then the resulting {alpha, beta} distribution is integrated over each
    token bin.

    Args:
      p: target conditional probability of the 1-node.
      beta: shape parameter of the distribution.

    Returns:
      A length-11 numpy array with the probability mass of |x| in the token
      bins 0, 1, 2, 3, 4, 5-6, 7-10, 11-18, 19-34, 35-66 and 67+.
      The entries telescope to a total of 1.
    """
    def squared_error(alpha):
        # Squared distance between the model's conditional 1-node
        # probability and the requested one.
        below = cdf_spareto(0.5, alpha, beta)
        above = cdf_spareto(1.5, alpha, beta)
        return ((above - below) / (1 - below) - p)**2

    alpha = scipy.optimize.fminbound(squared_error, 1e-12, 10000, xtol=1e-12)

    # Upper half-integer edge of each bin except the last, open-ended one.
    edges = (0.5, 1.5, 2.5, 3.5, 4.5, 6.5, 10.5, 18.5, 34.5, 66.5)
    probs = np.zeros(11)
    # Zero bin: mass in (-0.5, 0.5).
    probs[0] = 2 * (cdf_spareto(edges[0], alpha, beta) - 0.5)
    # Inner bins: one-sided mass between consecutive edges, doubled for symmetry.
    for k in range(1, len(edges)):
        probs[k] = 2 * (cdf_spareto(edges[k], alpha, beta) -
                        cdf_spareto(edges[k - 1], alpha, beta))
    # Tail bin: everything beyond the last edge on either side.
    probs[-1] = 2 * (1. - cdf_spareto(edges[-1], alpha, beta))
    return probs
def quantize_probs(p, save_first_bin, bits):
    """Quantize probability precisely.

    Quantize probabilities minimizing dH (Kullback-Leibler divergence)
    approximated by: sum (p_i-q_i)^2/p_i.
    Each p_i * 2**bits is first rounded to the nearest integer (clipped so
    every symbol gets at least 1). Individual symbols are then nudged by one
    step at a time, always taking the cheapest adjustment under the
    approximation above, until the total is exactly 2**bits.
    References:
    https://en.wikipedia.org/wiki/Kullback–Leibler_divergence
    https://github.com/JarekDuda/AsymmetricNumeralSystemsToolkit

    Args:
      p: 1-D numpy array of symbol probabilities.
      save_first_bin: if true, symbol 0 is excluded from the correction pass
        and keeps its rounded/clipped value.
      bits: precision; the quantized values sum to 2**bits.

    Returns:
      A float numpy array of integer-valued counts summing to 2**bits.

    Raises:
      ValueError: if no admissible adjustment remains before the total
        reaches 2**bits.
    """
    num_sym = p.size
    p = np.clip(p, 1e-16, 1)  # keep 1/p finite for empty bins
    L = 2**bits
    pL = p * L
    ip = 1. / p  # inverse probability
    # At least 1 per symbol; cap so the remaining symbols can still get 1 each.
    q = np.clip(np.round(pL), 1, L + 1 - num_sym)
    quant_err = (pL - q)**2 * ip
    sgn = np.sign(L - q.sum())  # direction of correction
    if sgn == 0:  # rounding already hit the target exactly
        return q

    heap = []  # heap of adjustment results (adjustment err, index) per symbol

    def push_adjustment(i):
        # Queue the error change of moving q[i] one step in the sgn
        # direction, provided the result stays strictly inside (0, L).
        q_adj = q[i] + sgn
        if 0 < q_adj < L:
            adj_err = (pL[i] - q_adj)**2 * ip[i] - quant_err[i]
            heapq.heappush(heap, (adj_err, i))

    for i in range(1 if save_first_bin else 0, num_sym):
        push_adjustment(i)
    while q.sum() != L:
        if not heap:
            raise ValueError(
                'cannot quantize %d symbols to sum to %d' % (num_sym, L))
        # Apply the lowest-error adjustment. adj_err is a delta relative to
        # the current error, so it must be accumulated, not assigned.
        adj_err, i = heapq.heappop(heap)
        quant_err[i] += adj_err
        q[i] += sgn
        # Calculate the cost of adjusting this symbol again.
        push_adjustment(i)
    return q
def get_quantized_spareto(p, beta, bits, first_token):
    """Return the quantized token probabilities for a given 1-node probability.

    Args:
      p: target probability of the 1-node, as passed to get_spareto.
      beta: shape parameter of the distribution.
      bits: precision; the returned counts sum to 2**bits.
      first_token: controls which leading bins are modelled. The zero bin is
        always dropped. When first_token > 1, the one bin is dropped as well.
        When first_token == 1, the first remaining bin is excluded from the
        rounding correction in quantize_probs.

    Returns:
      1-D integer numpy array of quantized counts.
    """
    parray = get_spareto(p, beta)
    # Drop the zero bin and renormalize the remaining bins to sum to 1.
    parray = parray[1:] / (1 - parray[0])
    # CONFIG_NEW_TOKENSET
    if first_token > 1:
        # Also drop the one bin and renormalize again.
        parray = parray[1:] / (1 - parray[0])
    qarray = quantize_probs(parray, first_token == 1, bits)
    # np.int was removed in NumPy 1.24; it was only an alias for builtin int.
    return qarray.astype(int)
def main(bits=15, first_token=1):
    """Print one C-initializer row of quantized token probabilities per q/256.

    Rows cover 1-node probabilities q/256 for q in 1..255. Each row is
    checked to sum to exactly 2**bits before it is printed.
    """
    beta = 8
    total = 2**bits
    for step in range(1, 256):
        row = get_quantized_spareto(step / 256., beta, bits, first_token)
        assert row.sum() == total
        print('{', ', '.join('%d' % count for count in row), '},')
if __name__ == '__main__':
    # Optional positional arguments: bits, then first_token. Any further
    # arguments are ignored; missing ones fall back to main()'s defaults.
    main(*(int(arg) for arg in sys.argv[1:3]))