🏔️ 使用工作流程进行回溯提示,结合 Argilla 实现 RAG¶
本教程将展示如何使用 LlamaIndex 工作流程结合 Argilla 实现 RAG 的回溯提示。
这种提示方法基于“退后一步:通过大型语言模型中的抽象引发推理”。这篇论文建议通过让模型退后一步,以更抽象的方式推理上下文,来改进响应。通过这种方式,原始查询被抽象化,并用于检索相关信息。然后,该上下文与原始上下文和查询一起用于生成最终响应。
Argilla 是一个面向 AI 工程师和领域专家的协作工具,用于构建高质量的数据集。通过使用 Argilla,您可以分析和提高数据质量,通过将人工反馈纳入循环来改善模型性能。该集成将自动在 Argilla 中记录查询、响应、检索到的上下文及其得分,以及完整的跟踪(包括 spans 和 events)和相关元数据。默认情况下,您将能够对响应进行评分、提供反馈并评估检索到的上下文,从而确保准确性并防止任何差异。
它包括以下步骤:
- 为 LlamaIndex 设置 Argilla 处理程序。
- 设计回溯工作流程。
- 使用 LlamaIndex 运行回溯工作流程并将响应自动记录到 Argilla。
%pip install "argilla-llama-index>=2.1.0"
导入所需的库
from llama_index.core import (
Settings,
SimpleDirectoryReader,
VectorStoreIndex,
)
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.schema import NodeWithScore
from llama_index.core.workflow import (
Context,
StartEvent,
StopEvent,
Workflow,
step,
)
from llama_index.core import get_response_synthesizer
from llama_index.core.workflow import Event
from llama_index.utils.workflow import draw_all_possible_flows
from llama_index.llms.openai import OpenAI
from argilla_llama_index import ArgillaHandler
我们需要设置 OpenAI API 密钥。运行使用 GPT 模型进行查询时需要 OpenAI API 密钥。
import os
os.environ["OPENAI_API_KEY"] = "sk-..."
设置 Argilla 的 LlamaIndex 处理程序¶
为了在您的 LlamaIndex 工作流程中轻松将数据记录到 Argilla,您只需初始化 Argilla 处理程序并将其附加到 Llama Index dispatcher,以处理 spans 和 events。这确保了使用 LlamaIndex 获得的预测将自动与有用的元数据一起记录到 Argilla 实例中。
- dataset_name:数据集的名称。如果数据集不存在,将使用指定的名称创建。否则,将更新数据集。
- api_url:连接到 Argilla 实例的 URL。
- api_key:用于向 Argilla 实例进行身份验证的 API 密钥。
- number_of_retrievals:要记录的检索到的文档数量。默认为 0。
- workspace_name:记录数据的工作空间的名称。默认为第一个可用的工作空间。
argilla_handler = ArgillaHandler(
dataset_name="workflow_llama_index",
api_url="http://localhost:6900",
api_key="argilla.apikey",
number_of_retrievals=2,
)
root_dispatcher = get_dispatcher()
root_dispatcher.add_span_handler(argilla_handler)
root_dispatcher.add_event_handler(argilla_handler)
定义回溯工作流程¶
首先,我们需要定义回溯工作流程中将使用的两个事件。StepBackEvent 将接收回溯查询,而 RetrieverEvent 将在检索后接收与原始查询和回溯查询相关的节点。
class StepBackEvent(Event):
"""Get the step-back query"""
step_back_query: str
class RetrieverEvent(Event):
"""Result of running the retrievals"""
nodes_original: list[NodeWithScore]
nodes_step_back: list[NodeWithScore]
接下来,我们将根据原始论文定义提示,以获取回溯查询,然后获取最终响应。
STEP_BACK_TEMPLATE = """
You are an expert at world knowledge. Your task is to step back and
paraphrase a question to a more generic step-back question, which is
easier to answer. Here are a few examples:
Original Question: Which position did Knox Cunningham hold from May 1955 to Apr 1956?
Stepback Question: Which positions have Knox Cunningham held in his career?
Original Question: Who was the spouse of Anna Karina from 1968 to 1974?
Stepback Question: Who were the spouses of Anna Karina?
Original Question: what is the biggest hotel in las vegas nv as of November 28, 1993
Stepback Question: what is the size of the hotels in las vegas nv as of November 28, 1993?
Original Question: {original_query}
Stepback Question:
"""
GENERATE_ANSWER_TEMPLATE = """
You are an expert of world knowledge. I am going to ask you a question.
Your response should be comprehensive and not contradicted with the
following context if they are relevant. Otherwise, ignore them if they are
not relevant.
{context_original}
{context_step_back}
Original Question: {query}
Answer:
"""
现在,我们将定义回溯工作流程。在这种情况下,工作流程将是线性的。首先,我们将提示 LLM 对原始查询进行抽象(回溯提示)。然后,我们将检索与原始查询和回溯查询相关的节点。最后,我们将提示 LLM 生成最终响应。
class RAGWorkflow(Workflow):
@step
async def step_back(
self, ctx: Context, ev: StartEvent
) -> StepBackEvent | None:
"""Generate the step-back query."""
query = ev.get("query")
index = ev.get("index")
if not query:
return None
if not index:
return None
llm = Settings.llm
step_back_query = llm.complete(
prompt=STEP_BACK_TEMPLATE.format(original_query=query),
formatted=True,
)
await ctx.set("query", query)
await ctx.set("index", index)
return StepBackEvent(step_back_query=str(step_back_query))
@step
async def retrieve(
self, ctx: Context, ev: StepBackEvent
) -> RetrieverEvent | None:
"Retrieve the relevant nodes for the original and step-back queries."
query = await ctx.get("query", default=None)
index = await ctx.get("index", default=None)
await ctx.set("step_back_query", ev.step_back_query)
retriever = index.as_retriever(similarity_top_k=2)
nodes_step_back = await retriever.aretrieve(ev.step_back_query)
nodes_original = await retriever.aretrieve(query)
return RetrieverEvent(
nodes_original=nodes_original, nodes_step_back=nodes_step_back
)
@step
async def synthesize(self, ctx: Context, ev: RetrieverEvent) -> StopEvent:
"""Return a response using the contextualized prompt and retrieved nodes."""
nodes_original = ev.nodes_original
nodes_step_back = ev.nodes_step_back
context_original = max(
nodes_original, key=lambda node: node.get_score()
).get_text()
context_step_back = max(
nodes_step_back, key=lambda node: node.get_score()
).get_text()
query = await ctx.get("query", default=None)
formatted_query = GENERATE_ANSWER_TEMPLATE.format(
context_original=context_original,
context_step_back=context_step_back,
query=query,
)
response_synthesizer = get_response_synthesizer(
response_mode=ResponseMode.COMPACT
)
response = response_synthesizer.synthesize(
formatted_query, nodes=ev.nodes_original
)
return StopEvent(result=response)
draw_all_possible_flows(RAGWorkflow, filename="step_back_workflow.html")
运行回溯工作流程¶
我们将使用从LlamaIndex 文档获取的示例 .txt 文件。
# Retrieve the data if needed
!mkdir -p ../../data
!curl https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt -o ../../data/paul_graham_essay.txt
现在,让我们使用这份文档创建一个 LlamaIndex 索引。由于原始查询和回溯查询的最高评分上下文将包含在最终提示中,我们将减小块大小并使用 SentenceSplitter。
# LLM settings
Settings.llm = OpenAI(model="gpt-3.5-turbo", temperature=0.8)
# Load the data and create the index
transformations = [
SentenceSplitter(chunk_size=256, chunk_overlap=75),
]
documents = SimpleDirectoryReader("../../data").load_data()
index = VectorStoreIndex.from_documents(
documents=documents,
transformations=transformations,
)
现在,让我们运行回溯工作流程并进行查询。
w = RAGWorkflow()
result = await w.run(query="What's Paul's work", index=index)
result
下一步¶
标注数据后,您可以从 Argilla 中检索数据。通过将人工反馈集成到流程中,我们保证了数据质量,使其可用于微调您的模型。此外,为了保持模型性能并防止数据漂移,您可以预留一部分数据进行持续评估。