class KeywordExtractor(BaseExtractor):
    """Keyword extractor. Node-level extractor. Extracts
    `excerpt_keywords` metadata field.

    Args:
        llm (Optional[LLM]): LLM used to generate keywords. When not given,
            ``llm_predictor`` is used, and failing that ``Settings.llm``.
        llm_predictor (Optional[LLMPredictorType]): deprecated alias for
            ``llm``; prefer passing ``llm``.
        keywords (int): number of keywords to extract; must be >= 1.
        prompt_template (str): template for keyword extraction. It is
            formatted with ``keywords`` and ``context_str`` variables.
        num_workers (int): passed to ``run_jobs`` as ``workers`` when
            extracting keywords for a batch of nodes.
    """

    llm: LLMPredictorType = Field(description="The LLM to use for generation.")
    keywords: int = Field(
        default=5, description="The number of keywords to extract.", gt=0
    )
    prompt_template: str = Field(
        default=DEFAULT_KEYWORD_EXTRACT_TEMPLATE,
        description="Prompt template to use when generating keywords.",
    )

    def __init__(
        self,
        llm: Optional[LLM] = None,
        # TODO: llm_predictor arg is deprecated
        llm_predictor: Optional[LLMPredictorType] = None,
        keywords: int = 5,
        prompt_template: str = DEFAULT_KEYWORD_EXTRACT_TEMPLATE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        **kwargs: Any,
    ) -> None:
        """Init params.

        Raises:
            ValueError: if ``keywords`` is less than 1.
        """
        # Explicit check so callers get a clear message naming the actual
        # argument, rather than relying only on the Field(gt=0) constraint.
        if keywords < 1:
            raise ValueError("keywords must be >= 1")
        super().__init__(
            # Resolution order: explicit llm, then the deprecated
            # llm_predictor, then the globally configured Settings.llm.
            llm=llm or llm_predictor or Settings.llm,
            keywords=keywords,
            prompt_template=prompt_template,
            num_workers=num_workers,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "KeywordExtractor"

    async def _aextract_keywords_from_node(self, node: BaseNode) -> Dict[str, str]:
        """Extract keywords from a node and return its metadata dict.

        Returns an empty dict when ``is_text_node_only`` is set and the node
        is not a ``TextNode``; otherwise returns
        ``{"excerpt_keywords": <stripped LLM output>}``.
        """
        if self.is_text_node_only and not isinstance(node, TextNode):
            return {}

        context_str = node.get_content(metadata_mode=self.metadata_mode)
        keywords = await self.llm.apredict(
            PromptTemplate(template=self.prompt_template),
            keywords=self.keywords,
            context_str=context_str,
        )

        return {"excerpt_keywords": keywords.strip()}

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Extract keyword metadata for each node.

        One coroutine is created per node and executed via ``run_jobs``,
        honoring ``show_progress`` and ``num_workers``. The returned list is
        whatever ``run_jobs`` produces from those jobs (presumably one
        metadata dict per input node, in order — verify against run_jobs).
        """
        keyword_jobs = [self._aextract_keywords_from_node(node) for node in nodes]
        metadata_list: List[Dict] = await run_jobs(
            keyword_jobs, show_progress=self.show_progress, workers=self.num_workers
        )
        return metadata_list