
# Auto prev next

Node postprocessor module.

## AutoPrevNextNodePostprocessor

Bases: `BaseNodePostprocessor`

Previous/Next node post-processor.

Allows users to fetch additional nodes from the document store, based on the prev/next relationships of the nodes.

NOTE: the difference with `PrevNextNodePostprocessor` is that this infers the forward/backward direction.

NOTE: this is a beta feature.
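For orientation, here is a minimal usage sketch (not part of the reference itself); it assumes documents loaded from a local `./data` directory and an LLM configured via `Settings`, and the `num_nodes=3` value is illustrative:

```python
# Minimal usage sketch, assuming documents in ./data and an LLM configured
# via Settings. The postprocessor follows prev/next relationships stored in
# the index's docstore.
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.postprocessor import AutoPrevNextNodePostprocessor

documents = SimpleDirectoryReader("./data").load_data()
index = VectorStoreIndex.from_documents(documents)

postprocessor = AutoPrevNextNodePostprocessor(
    docstore=index.docstore,  # holds the prev/next node relationships
    num_nodes=3,  # illustrative: fetch up to 3 neighbors in the inferred direction
    verbose=True,
)

query_engine = index.as_query_engine(node_postprocessors=[postprocessor])
response = query_engine.query(
    "What did the author do after his time at Y Combinator?"
)
print(response)
```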

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `docstore` | `BaseDocumentStore` | The document store. | *required* |
| `num_nodes` | `int` | The number of nodes to return (default: 1). | `1` |
| `infer_prev_next_tmpl` | `str` | The template to use for inference. Required fields are `{context_str}` and `{query_str}`. | `'The current context information is provided. \nA question is also provided. \nYou are a retrieval agent deciding whether to search the document store for additional prior context or future context. \nGiven the context and question, return PREVIOUS or NEXT or NONE. \nExamples: \n\nContext: Describes the author's experience at Y Combinator. Question: What did the author do after his time at Y Combinator? \nAnswer: NEXT \n\nContext: Describes the author's experience at Y Combinator. Question: What did the author do before his time at Y Combinator? \nAnswer: PREVIOUS \n\nContext: Describes the author's experience at Y Combinator. Question: What did the author do at Y Combinator? \nAnswer: NONE \n\nContext: {context_str}\nQuestion: {query_str}\nAnswer: '` |
| `llm` | `Annotated[LLM, SerializeAsAny] \| None` | | `None` |
| `refine_prev_next_tmpl` | `str` | | `'The current context information is provided. \nA question is also provided. \nAn existing answer is also provided. \nYou are a retrieval agent deciding whether to search the document store for additional prior context or future context. \nGiven the context, question, and previous answer, return PREVIOUS or NEXT or NONE. \nExamples: \n\nContext: {context_msg}\nQuestion: {query_str}\nExisting Answer: {existing_answer}\nAnswer: '` |
| `verbose` | `bool` | | `False` |
| `response_mode` | `ResponseMode` | | `<ResponseMode.COMPACT: 'compact'>` |
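Since `infer_prev_next_tmpl` only requires the `{context_str}` and `{query_str}` fields, the prompt can be swapped out. A hedged sketch follows; the wording below is illustrative, not a documented default:

```python
# Illustrative override of the inference prompt. Only {context_str} and
# {query_str} are required; the reply must still contain PREVIOUS, NEXT,
# or NONE, since the postprocessor parses those keywords from the output.
custom_tmpl = (
    "Context: {context_str}\n"
    "Question: {query_str}\n"
    "Should additional earlier context (PREVIOUS), later context (NEXT), "
    "or no additional context (NONE) be retrieved? Answer in one word: "
)

postprocessor = AutoPrevNextNodePostprocessor(
    docstore=index.docstore,  # `index` as in the sketch above
    infer_prev_next_tmpl=custom_tmpl,
)
```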
Source code in `llama-index-core/llama_index/core/postprocessor/node.py`:
```python
class AutoPrevNextNodePostprocessor(BaseNodePostprocessor):
    """
    Previous/Next Node post-processor.

    Allows users to fetch additional nodes from the document store,
    based on the prev/next relationships of the nodes.

    NOTE: difference with PrevNextPostprocessor is that
    this infers forward/backwards direction.

    NOTE: this is a beta feature.

    Args:
        docstore (BaseDocumentStore): The document store.
        num_nodes (int): The number of nodes to return (default: 1)
        infer_prev_next_tmpl (str): The template to use for inference.
            Required fields are {context_str} and {query_str}.

    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    docstore: BaseDocumentStore
    llm: Optional[SerializeAsAny[LLM]] = None
    num_nodes: int = Field(default=1)
    infer_prev_next_tmpl: str = Field(default=DEFAULT_INFER_PREV_NEXT_TMPL)
    refine_prev_next_tmpl: str = Field(default=DEFAULT_REFINE_INFER_PREV_NEXT_TMPL)
    verbose: bool = Field(default=False)
    response_mode: ResponseMode = Field(default=ResponseMode.COMPACT)

    @classmethod
    def class_name(cls) -> str:
        return "AutoPrevNextNodePostprocessor"

    def _parse_prediction(self, raw_pred: str) -> str:
        """Parse prediction."""
        pred = raw_pred.strip().lower()
        if "previous" in pred:
            return "previous"
        elif "next" in pred:
            return "next"
        elif "none" in pred:
            return "none"
        raise ValueError(f"Invalid prediction: {raw_pred}")

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Postprocess nodes."""
        llm = self.llm or Settings.llm

        if query_bundle is None:
            raise ValueError("Missing query bundle.")

        infer_prev_next_prompt = PromptTemplate(
            self.infer_prev_next_tmpl,
        )
        refine_infer_prev_next_prompt = PromptTemplate(self.refine_prev_next_tmpl)

        all_nodes: Dict[str, NodeWithScore] = {}
        for node in nodes:
            all_nodes[node.node.node_id] = node
            # use response builder instead of llm directly
            # to be more robust to handling long context
            response_builder = get_response_synthesizer(
                llm=llm,
                text_qa_template=infer_prev_next_prompt,
                refine_template=refine_infer_prev_next_prompt,
                response_mode=self.response_mode,
            )
            raw_pred = response_builder.get_response(
                text_chunks=[node.node.get_content()],
                query_str=query_bundle.query_str,
            )
            raw_pred = cast(str, raw_pred)
            mode = self._parse_prediction(raw_pred)

            logger.debug(f"> Postprocessor Predicted mode: {mode}")
            if self.verbose:
                print(f"> Postprocessor Predicted mode: {mode}")

            if mode == "next":
                all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore))
            elif mode == "previous":
                all_nodes.update(
                    get_backward_nodes(node, self.num_nodes, self.docstore)
                )
            elif mode == "none":
                pass
            else:
                raise ValueError(f"Invalid mode: {mode}")

        sorted_nodes = sorted(all_nodes.values(), key=lambda x: x.node.node_id)
        return list(sorted_nodes)
```
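The postprocessor can also be applied directly to retrieved nodes rather than through a query engine; a short sketch reusing the `index` and `postprocessor` from the examples above:

```python
# Sketch: calling the postprocessor directly on retriever output.
from llama_index.core.schema import QueryBundle

retriever = index.as_retriever(similarity_top_k=2)
query = "What did the author do before his time at Y Combinator?"
nodes = retriever.retrieve(query)

# If the LLM answers PREVIOUS, backward neighbors are pulled from the docstore.
expanded = postprocessor.postprocess_nodes(nodes, query_bundle=QueryBundle(query))
print(f"{len(nodes)} retrieved -> {len(expanded)} after expansion")
```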