This repository has been archived by the owner on Nov 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
predict.py
81 lines (59 loc) · 2.35 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import logging
from pathlib import Path
import click
from flair.data import Sentence
from flair.datasets import ColumnCorpus
from flair.models import SequenceTagger
# Convert IOBES to IOB for the CoNLL evaluation script.
def iobes_to_iob(tag):
    """Convert a single IOBES tag to its IOB equivalent.

    ``S-X`` (single-token entity) becomes ``B-X`` and ``E-X`` (end of an
    entity) becomes ``I-X``; ``B-``, ``I-``, ``O`` and any other tag are
    returned unchanged. Only the leading two-character prefix is rewritten,
    so entity labels that happen to contain ``"S-"`` or ``"E-"`` elsewhere
    are left intact.

    >>> iobes_to_iob("S-PER")
    'B-PER'
    >>> iobes_to_iob("E-LOC")
    'I-LOC'
    """
    prefix_map = {"S-": "B-", "E-": "I-"}
    new_prefix = prefix_map.get(tag[:2])
    if new_prefix is None:
        return tag
    return new_prefix + tag[2:]
# File-name prefix per supported dataset. Each dataset's splits are read
# from ./data as "<prefix>.train.bio", "<prefix>.dev.bio", "<prefix>.test.bio".
_DATASET_PREFIXES = {
    "lft": "enp_DE.lft.mr.tok",
    "onb": "enp_DE.onb.mr.tok",
}


def _load_corpus(dataset):
    """Build the two-column (text, ner) ColumnCorpus for *dataset*.

    Raises:
        ValueError: if *dataset* is not a key of ``_DATASET_PREFIXES``.
    """
    try:
        prefix = _DATASET_PREFIXES[dataset]
    except KeyError:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(_DATASET_PREFIXES)}"
        ) from None
    columns = {0: "text", 1: "ner"}
    return ColumnCorpus(
        Path("./data"),
        columns,
        train_file=f"./{prefix}.train.bio",
        dev_file=f"./{prefix}.dev.bio",
        test_file=f"./{prefix}.test.bio",
        tag_to_bioes="ner",
    )


@click.command()
@click.option(
    "--dataset",
    type=click.Choice(sorted(_DATASET_PREFIXES)),
    required=True,
    help="Define dataset",
)
@click.option(
    "--split",
    default="test",
    type=click.Choice(["dev", "test"]),
    help="Defines dataset split (dev or test)",
)
@click.option(
    "--model",
    type=click.Path(exists=True),
    required=True,
    help="Path to a trained flair SequenceTagger model",
)
def parse_arguments(dataset, split, model):
    """Tag a dataset split with a trained model and print CoNLL-style lines.

    For every sentence in the selected split, prints one line per token in
    the form ``<token> <gold tag> <predicted tag>``, with both tags converted
    from IOBES to IOB. A blank line follows each sentence.
    """
    # Restrict flair's logger to errors only.
    logging.getLogger("flair").setLevel(level="ERROR")

    corpus: ColumnCorpus = _load_corpus(dataset)
    tagger: SequenceTagger = SequenceTagger.load(model)

    dataset_split = corpus.test if split == "test" else corpus.dev

    for test_sentence in dataset_split:
        tokens = test_sentence.tokens
        # Capture gold tags before predicting. The tagged sentence below
        # reuses these same Token objects, and the prediction is read back
        # from their "ner" tag.
        gold_tags = [token.get_tag("ner").value for token in tokens]

        tagged_sentence = Sentence()
        tagged_sentence.tokens = tokens

        # Tag sentence with model
        tagger.predict(tagged_sentence)
        predicted_tags = [token.get_tag("ner").value for token in tagged_sentence.tokens]

        assert len(gold_tags) == len(predicted_tags)

        for token, gold_tag, predicted_tag in zip(tokens, gold_tags, predicted_tags):
            print(f"{token.text} {iobes_to_iob(gold_tag)} {iobes_to_iob(predicted_tag)}")
        print()
if __name__ == "__main__":
    # Run the click command as a script: Command.main() (which calling the
    # command object delegates to) reads the options from sys.argv.
    parse_arguments.main()