Skip to content

Commit

Permalink
Add mention graph extraction to ConversationalGraphExtractor
Browse files Browse the repository at this point in the history
  • Loading branch information
Aethor committed Jul 18, 2024
1 parent 7fb306d commit 3cce6fb
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 8 deletions.
87 changes: 80 additions & 7 deletions renard/pipeline/graph_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,26 442,49 @@ def optional_needs(self) -> Set[str]:


class ConversationalGraphExtractor(PipelineStep):
"""A graph extractor using conversation between characters
"""A graph extractor using conversation between characters or
mentions.
.. note::
This is an early version, that only supports static graphs
for now.
Does not support dynamic networks yet.
"""

def __init__(
    self,
    graph_type: Literal["conversation", "mention"],
    conversation_dist: Optional[
        Union[int, Tuple[int, Literal["tokens", "sentences"]]]
    ] = None,
    ignore_self_mention: bool = True,
):
    """
    :param graph_type: either 'conversation' or 'mention'.
        'conversation' extracts an undirected graph with
        interactions being extracted from the conversations
        occurring between characters.  'mention' extracts a
        directed graph where interactions are character mentions
        of one another in quoted speech.
    :param conversation_dist: must be supplied if `graph_type` is
        'conversation'.  The distance between two quotations for
        them to be considered as interacting.  Either a
        ``(distance, unit)`` tuple, with unit being 'tokens' or
        'sentences', or a bare int, which is interpreted as a
        distance in tokens.
    :param ignore_self_mention: if ``True``, self mentions are
        ignored for ``graph_type == 'mention'``
    :raises ValueError: if `graph_type` is 'conversation' and
        `conversation_dist` is not supplied.
    """
    # Fail at construction time rather than deep inside extraction,
    # where a missing distance would only surface as an assertion.
    if graph_type == "conversation" and conversation_dist is None:
        raise ValueError(
            "conversation_dist must be supplied when graph_type is 'conversation'"
        )

    self.graph_type = graph_type

    # a bare int is shorthand for a distance expressed in tokens
    if isinstance(conversation_dist, int):
        conversation_dist = (conversation_dist, "tokens")
    self.conversation_dist = conversation_dist

    self.ignore_self_mention = ignore_self_mention

    super().__init__()

def _quotes_interact(
self, quote_1: Quote, quote_2: Quote, sentences: List[List[str]]
) -> bool:
assert not self.conversation_dist is None
ordered = quote_2.start >= quote_1.end
if self.conversation_dist[1] == "tokens":
return (
Expand All @@ -483,14 506,13 @@ def _quotes_interact(
else:
raise NotImplementedError

def __call__(
def _conversation_extract(
self,
sentences: List[List[str]],
quotes: List[Quote],
speakers: List[Optional[Character]],
characters: Set[Character],
**kwargs,
) -> Dict[str, Any]:
) -> nx.Graph:
G = nx.Graph()
for character in characters:
G.add_node(character)
Expand Down Expand Up @@ -520,6 542,57 @@ def __call__(
G.add_edge(speaker_1, speaker_2, weight=0)
G.edges[speaker_1, speaker_2]["weight"] = 1

return G

def _mention_extract(
    self,
    quotes: List[Quote],
    speakers: List[Optional[Character]],
    characters: Set[Character],
) -> nx.DiGraph:
    """Extract a directed graph of character mentions in quoted speech.

    Every character becomes a node.  For each quote with a known
    speaker, an edge ``speaker => character`` is added (or
    reinforced) for every character having at least one mention
    located inside the quote.

    :param quotes: the quotes to inspect
    :param speakers: the speaker of each quote, aligned with
        `quotes`.  ``None`` means that no speaker was predicted,
        in which case the quote is ignored.
    :param characters: the characters of the network
    :return: a directed graph, where the ``weight`` attribute of
        an edge is the number of quotes from the source character
        mentioning the target character
    """
    G = nx.DiGraph()
    for character in characters:
        G.add_node(character)

    for quote, speaker in zip(quotes, speakers):
        # no speaker prediction: ignore
        if speaker is None:
            continue

        # TODO: optim
        # find characters mentioned in quote and add a directed
        # edge speaker => character
        for character in characters:
            if character == speaker and self.ignore_self_mention:
                continue

            # a character counts at most once per quote, however
            # many times it is mentioned in that quote
            mentioned_in_quote = any(
                mention.start_idx >= quote.start and mention.end_idx <= quote.end
                for mention in character.mentions
            )
            if not mentioned_in_quote:
                continue

            # accumulate the weight instead of overwriting it, so
            # that repeated mentions across quotes are counted
            if G.has_edge(speaker, character):
                G.edges[speaker, character]["weight"] += 1
            else:
                G.add_edge(speaker, character, weight=1)

    return G

def __call__(
    self,
    sentences: List[List[str]],
    quotes: List[Quote],
    speakers: List[Optional[Character]],
    characters: Set[Character],
    **kwargs,
) -> Dict[str, Any]:
    """Run the extraction matching ``self.graph_type``.

    :param sentences: tokenized sentences, only forwarded to the
        'conversation' extraction
    :param quotes: the quotes to extract interactions from
    :param speakers: the speaker of each quote, aligned with
        `quotes`
    :param characters: the characters of the network
    :param kwargs: accepted and ignored
    :return: a dict with a single ``"character_network"`` key
        holding the extracted graph
    :raises ValueError: if ``self.graph_type`` is neither
        'conversation' nor 'mention'
    """
    mode = self.graph_type

    if mode == "conversation":
        network = self._conversation_extract(
            sentences, quotes, speakers, characters
        )
        return {"character_network": network}

    if mode == "mention":
        network = self._mention_extract(quotes, speakers, characters)
        return {"character_network": network}

    raise ValueError(f"unknown graph_type: {self.graph_type}")

def needs(self) -> Set[str]:
Expand Down
5 changes: 4 additions & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 84,9 @@ def test_conversational_pipeline_runs():
warn=False,
progress_report=None,
conversational=True,
graph_extractor_kwargs={"conversation_dist": (3, "sentences")},
graph_extractor_kwargs={
"graph_type": "conversation",
"conversation_dist": (3, "sentences"),
},
)
pipeline(text)

0 comments on commit 3cce6fb

Please sign in to comment.