Skip to content

Commit

Permalink
add --eos-factor for beam search to alleviate the problem of too short transcripts with LM fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
freewym committed Oct 10, 2019
1 parent 18f6774 commit be72415
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/asr_librispeech/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ if [ ${stage} -le 8 ]; then
decode_affix=
if $lm_shallow_fusion; then
path="$path:$lmdir/$lm_checkpoint"
opts="$opts --lm-weight 0.4 --coverage-weight 0.015"
opts="$opts --lm-weight 0.4 --coverage-weight 0.0 --eos-factor 1.5"
decode_affix=shallow_fusion
fi
for dataset in $test_set; do
Expand Down
2 changes: 1 addition & 1 deletion examples/asr_wsj/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ if [ ${stage} -le 9 ]; then
decode_affix=shallow_fusion
else
path="$path:$wordlmdir/$lm_checkpoint"
opts="$opts --word-dict $wordlmdict --lm-weight 0.8 --oov-penalty 1e-8 --coverage-weight 0.01"
opts="$opts --word-dict $wordlmdict --lm-weight 0.8 --oov-penalty 1e-8 --coverage-weight 0.005 --eos-factor 1.5"
decode_affix=shallow_fusion_wordlm
fi
fi
Expand Down
10 changes: 9 additions & 1 deletion fairseq/sequence_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def __init__(
diverse_beam_strength=0.5,
match_source_len=False,
no_repeat_ngram_size=0,
coverage_weight=0.01,
coverage_weight=0.0,
eos_factor=None,
):
"""Generates translations of a given source sentence.
Expand Down Expand Up @@ -87,10 +88,12 @@ def __init__(
self.match_source_len = match_source_len
self.no_repeat_ngram_size = no_repeat_ngram_size
self.coverage_weight = coverage_weight
self.eos_factor = eos_factor

assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling'
assert temperature > 0, '--temperature must be greater than 0'
assert eos_factor is None or eos_factor >= 1.0, '--eos-factor must be >= 1.0 if set'

if sampling:
self.search = search.Sampling(tgt_dict, sampling_topk, sampling_topp)
Expand Down Expand Up @@ -318,6 +321,11 @@ def get_hypo():
lprobs[:, self.eos + 1:] = -math.inf
elif step < self.min_len:
lprobs[:, self.eos] = -math.inf
elif self.eos_factor is not None:
# only consider EOS if its score is no less than a specified
# factor of the best candidate score
disallow_eos_mask = lprobs[:, self.eos] < self.eos_factor * lprobs.max(dim=1)[0]
lprobs[disallow_eos_mask, self.eos] = -math.inf

# handle prefix tokens (possibly with different lengths)
if prefix_tokens is not None and step < prefix_tokens.size(1):
Expand Down
1 change: 1 addition & 0 deletions fairseq/tasks/speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def build_generator(self, args):
match_source_len=getattr(args, 'match_source_len', False),
no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
coverage_weight=getattr(args, 'coverage_weight', 0.0),
eos_factor=getattr(args, 'eos_factor', None),
)

def build_dataset_for_inference(self, src_tokens, src_lengths):
Expand Down
3 changes: 3 additions & 0 deletions speech_recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ def cli_main():
help='coverage weight in log-prob space, mostly to '
'reduce deletion errors while using the pretrained '
'external LM for decoding')
parser.add_argument('--eos-factor', default=None, type=float, metavar='F',
help='only consider emitting EOS if its score is no less '
'than the specified factor of the best candidate score')
parser.add_argument('--lm-weight', default=0.0, type=float, metavar='W',
help='LM weight in log-prob space, assuming the pretrained '
'external LM is specified as the second one in --path')
Expand Down

0 comments on commit be72415

Please sign in to comment.