In [ ]
已复制!
# Patch the notebook's already-running event loop so nested async calls
# (the workflow is async-first) work inside Jupyter.
import nest_asyncio
nest_asyncio.apply()
import nest_asyncio nest_asyncio.apply()
In [ ]
已复制!
%pip install -U llama-index llama-index-tools-tavily-research
%pip install -U llama-index llama-index-tools-tavily-research
In [ ]
已复制!
import os
# The OpenAI key is read from the environment by llama-index's OpenAI LLM.
os.environ["OPENAI_API_KEY"] = "sk-proj-..."
# The Tavily key is passed explicitly to the workflow run later on.
tavily_ai_api_key = "<Your Tavily AI API Key>"
import os os.environ["OPENAI_API_KEY"] = "sk-proj-..." tavily_ai_api_key = ""
In [ ]
已复制!
# Download the Llama 2 paper as the sample corpus for ingestion.
!mkdir -p 'data/'
!wget 'https://arxiv.org/pdf/2307.09288.pdf' -O 'data/llama2.pdf'
!mkdir -p 'data/' !wget 'https://arxiv.org/pdf/2307.09288.pdf' -O 'data/llama2.pdf'
由于工作流是异步优先的,所有这些在 notebook 中都能正常运行。如果你在自己的代码中运行,且当前没有正在运行的异步事件循环,则需要使用 asyncio.run() 来启动一个,例如:
async def main():
<async code>
if __name__ == "__main__":
import asyncio
asyncio.run(main())
设计工作流¶
纠正性 RAG 包含以下步骤
- 数据摄取 — 将数据加载到索引中并设置 Tavily AI。摄取步骤将自行运行,接收启动事件并返回停止事件。
- 检索 - 根据查询检索最相关的节点。
- 相关性评估 - 使用 LLM 根据节点内容确定检索到的节点是否与查询相关。
- 相关性提取 - 提取 LLM 确定为相关的节点。
- 查询转换和 Tavily 搜索 - 如果节点不相关,则使用 LLM 转换查询以适应网络搜索。使用 Tavily 在网络上搜索基于查询的相关答案。
- 响应生成 - 根据相关节点和 Tavily 搜索中的文本构建摘要索引,并使用此索引获取原始查询的结果。
需要以下事件:
- PrepEvent — 表示索引和其他对象已准备好的事件。
- RetrieveEvent — 包含有关检索到的节点信息的事件。
- RelevanceEvalEvent — 包含相关性评估结果列表的事件。
- TextExtractEvent — 包含从相关节点提取的相关文本连接字符串的事件。
- QueryEvent — 包含相关文本和搜索文本的事件。
In [ ]
已复制!
from llama_index.core.workflow import Event
from llama_index.core.schema import NodeWithScore
class PrepEvent(Event):
    """Prep event: signals that the index, pipelines, and tools are ready for retrieval."""

    pass


class RetrieveEvent(Event):
    """Retrieve event: carries the nodes retrieved for the query."""

    retrieved_nodes: list[NodeWithScore]


class RelevanceEvalEvent(Event):
    """Relevance evaluation event: carries the per-node relevancy verdicts.

    Each entry is the LLM grader's output, normalized to lowercase
    (expected to be 'yes' or 'no'), in the same order as the retrieved nodes.
    """

    relevant_results: list[str]


class TextExtractEvent(Event):
    """Text extract event: carries the concatenated text of the relevant nodes."""

    relevant_text: str


class QueryEvent(Event):
    """Query event: carries the relevant local text plus any web-search text."""

    relevant_text: str
    search_text: str
from llama_index.core.workflow import ( Workflow, step, Context, StartEvent, StopEvent, ) from llama_index.core import ( VectorStoreIndex, Document, PromptTemplate, SummaryIndex, ) from llama_index.core.query_pipeline import QueryPipeline from llama_index.llms.openai import OpenAI from llama_index.tools.tavily_research.base import TavilyToolSpec from llama_index.core.base.base_retriever import BaseRetriever DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate( template="""作为评分员,你的任务是评估检索到的文档与用户问题的相关性。检索到的文档: ------------------- {context_str} 用户问题: -------------- {query_str} 评估标准: - 考虑文档是否包含与用户问题相关的关键词或主题。 - 评估不应过于严格;主要目标是识别并过滤掉明显不相关的检索结果。决策: - 为文档的相关性分配一个二元分数。 - 如果文档与问题相关,则使用“是”,如果不相关,则使用“否”。请在下方提供你的二元分数(“是”或“否”)来指示文档与用户问题的相关性。""" ) DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate( template="""你的任务是改进查询,以确保其对于检索相关搜索结果非常有效。 \n 分析给定的输入以掌握核心语义意图或含义。 \n 原始查询: \n ------- \n {query_str} \n ------- \n 你的目标是重新措辞或增强此查询,以提高其搜索性能。确保修订后的查询简洁并与预期的搜索目标直接一致。 \n 仅返回优化后的查询:""" ) class CorrectiveRAGWorkflow(Workflow): @step async def ingest(self, ctx: Context, ev: StartEvent) -> StopEvent | None: """摄取步骤(用于摄取文档和初始化索引)。""" documents: list[Document] | None = ev.get("documents") if documents is None: return None index = VectorStoreIndex.from_documents(documents) return StopEvent(result=index) @step async def prepare_for_retrieval( self, ctx: Context, ev: StartEvent ) -> PrepEvent | None: """准备进行检索。""" query_str: str | None = ev.get("query_str") retriever_kwargs: dict | None = ev.get("retriever_kwargs", {}) if query_str is None: return None tavily_ai_apikey: str | None = ev.get("tavily_ai_apikey") index = ev.get("index") llm = OpenAI(model="gpt-4") await ctx.set( "relevancy_pipeline", QueryPipeline(chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]), ) await ctx.set( "transform_query_pipeline", QueryPipeline(chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]), ) await ctx.set("llm", llm) await ctx.set("index", index) await ctx.set("tavily_tool", TavilyToolSpec(api_key=tavily_ai_apikey)) await 
ctx.set("query_str", query_str) await ctx.set("retriever_kwargs", retriever_kwargs) return PrepEvent() @step async def retrieve( self, ctx: Context, ev: PrepEvent ) -> RetrieveEvent | None: """检索查询的相关节点。""" query_str = await ctx.get("query_str") retriever_kwargs = await ctx.get("retriever_kwargs") if query_str is None: return None index = await ctx.get("index", default=None) tavily_tool = await ctx.get("tavily_tool", default=None) if not (index or tavily_tool): raise ValueError( "Index and tavily tool must be constructed. Run with 'documents' and 'tavily_ai_apikey' params first." ) retriever: BaseRetriever = index.as_retriever(**retriever_kwargs) result = retriever.retrieve(query_str) await ctx.set("retrieved_nodes", result) await ctx.set("query_str", query_str) return RetrieveEvent(retrieved_nodes=result) @step async def eval_relevance( self, ctx: Context, ev: RetrieveEvent ) -> RelevanceEvalEvent: """评估检索到的文档与查询的相关性。""" retrieved_nodes = ev.retrieved_nodes query_str = await ctx.get("query_str") relevancy_results = [] for node in retrieved_nodes: relevancy_pipeline = await ctx.get("relevancy_pipeline") relevancy = relevancy_pipeline.run( context_str=node.text, query_str=query_str ) relevancy_results.append(relevancy.message.content.lower().strip()) await ctx.set("relevancy_results", relevancy_results) return RelevanceEvalEvent(relevant_results=relevancy_results) @step async def extract_relevant_texts( self, ctx: Context, ev: RelevanceEvalEvent ) -> TextExtractEvent: """从检索到的文档中提取相关文本。""" retrieved_nodes = await ctx.get("retrieved_nodes") relevancy_results = ev.relevant_results relevant_texts = [ retrieved_nodes[i].text for i, result in enumerate(relevancy_results) if result == "yes" ] result = "\n".join(relevant_texts) return TextExtractEvent(relevant_text=result) @step async def transform_query_pipeline( self, ctx: Context, ev: TextExtractEvent ) -> QueryEvent: """使用 Tavily API 搜索转换后的查询。""" relevant_text = ev.relevant_text relevancy_results = await 
ctx.get("relevancy_results") query_str = await ctx.get("query_str") # 如果发现任何文档不相关,则转换查询字符串以获得更好的搜索结果。 if "no" in relevancy_results: qp = await ctx.get("transform_query_pipeline") transformed_query_str = qp.run(query_str=query_str).message.content # 使用转换后的查询字符串进行搜索并收集结果。 tavily_tool = await ctx.get("tavily_tool") search_results = tavily_tool.search( transformed_query_str, max_results=5 ) search_text = "\n".join([result.text for result in search_results]) else: search_text = "" return QueryEvent(relevant_text=relevant_text, search_text=search_text) @step async def query_result(self, ctx: Context, ev: QueryEvent) -> StopEvent: """获取带有相关文本的结果。""" relevant_text = ev.relevant_text search_text = ev.search_text query_str = await ctx.get("query_str") documents = [Document(text=relevant_text + "\n" + search_text)] index = SummaryIndex.from_documents(documents) query_engine = index.as_query_engine() result = query_engine.query(query_str) return StopEvent(result=result)
以下是纠正性 RAG 工作流的代码
In [ ]
已复制!
from llama_index.core.workflow import (
Workflow,
step,
Context,
StartEvent,
StopEvent,
)
from llama_index.core import (
VectorStoreIndex,
Document,
PromptTemplate,
SummaryIndex,
)
from llama_index.core.query_pipeline import QueryPipeline
from llama_index.llms.openai import OpenAI
from llama_index.tools.tavily_research.base import TavilyToolSpec
from llama_index.core.base.base_retriever import BaseRetriever
# Grades one retrieved document against the user question; the LLM must
# answer with a bare 'yes' or 'no' (consumed lowercase/stripped downstream).
DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate(
    template="""As a grader, your task is to evaluate the relevance of a document retrieved in response to a user's question.
Retrieved Document:
-------------------
{context_str}
User Question:
--------------
{query_str}
Evaluation Criteria:
- Consider whether the document contains keywords or topics related to the user's question.
- The evaluation should not be overly stringent; the primary objective is to identify and filter out clearly irrelevant retrievals.
Decision:
- Assign a binary score to indicate the document's relevance.
- Use 'yes' if the document is relevant to the question, or 'no' if it is not.
Please provide your binary score ('yes' or 'no') below to indicate the document's relevance to the user question."""
)
# Rewrites the original query into a form better suited for web search,
# used only when at least one retrieved document was graded irrelevant.
DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate(
    template="""Your task is to refine a query to ensure it is highly effective for retrieving relevant search results. \n
Analyze the given input to grasp the core semantic intent or meaning. \n
Original Query:
\n ------- \n
{query_str}
\n ------- \n
Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
Respond with the optimized query only:"""
)
class CorrectiveRAGWorkflow(Workflow):
    """Corrective RAG workflow.

    Two entry modes via the start event:
    - `documents=[...]`: ingest and return a VectorStoreIndex.
    - `query_str=..., index=..., tavily_ai_apikey=...`: retrieve, grade each
      node's relevance with an LLM, fall back to a Tavily web search when any
      node is irrelevant, and synthesize a final answer.
    """

    @step
    async def ingest(self, ctx: Context, ev: StartEvent) -> StopEvent | None:
        """Ingest step (for ingesting docs and initializing index).

        Fires only when the start event carries `documents`; otherwise this
        is a query run and the step is skipped.
        """
        documents: list[Document] | None = ev.get("documents")
        if documents is None:
            return None
        index = VectorStoreIndex.from_documents(documents)
        return StopEvent(result=index)

    @step
    async def prepare_for_retrieval(
        self, ctx: Context, ev: StartEvent
    ) -> PrepEvent | None:
        """Stash the LLM pipelines, index, and Tavily tool in the context.

        Fires only when the start event carries `query_str`.
        """
        query_str: str | None = ev.get("query_str")
        retriever_kwargs: dict | None = ev.get("retriever_kwargs", {})

        if query_str is None:
            # Not a query run (likely an ingestion run) — skip.
            return None

        tavily_ai_apikey: str | None = ev.get("tavily_ai_apikey")
        index = ev.get("index")

        llm = OpenAI(model="gpt-4")
        await ctx.set(
            "relevancy_pipeline",
            QueryPipeline(chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]),
        )
        await ctx.set(
            "transform_query_pipeline",
            QueryPipeline(chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]),
        )

        await ctx.set("llm", llm)
        await ctx.set("index", index)
        await ctx.set("tavily_tool", TavilyToolSpec(api_key=tavily_ai_apikey))

        await ctx.set("query_str", query_str)
        await ctx.set("retriever_kwargs", retriever_kwargs)

        return PrepEvent()

    @step
    async def retrieve(
        self, ctx: Context, ev: PrepEvent
    ) -> RetrieveEvent | None:
        """Retrieve the relevant nodes for the query."""
        query_str = await ctx.get("query_str")
        retriever_kwargs = await ctx.get("retriever_kwargs")

        if query_str is None:
            return None

        index = await ctx.get("index", default=None)
        tavily_tool = await ctx.get("tavily_tool", default=None)
        # Both are required downstream: the index for retrieval here and the
        # Tavily tool for the web-search fallback. The previous check,
        # `not (index or tavily_tool)`, only fired when BOTH were missing and
        # then crashed on `index.as_retriever` when only the index was absent.
        if index is None or tavily_tool is None:
            raise ValueError(
                "Index and tavily tool must be constructed. Run with 'documents' and 'tavily_ai_apikey' params first."
            )

        retriever: BaseRetriever = index.as_retriever(**retriever_kwargs)
        result = retriever.retrieve(query_str)
        await ctx.set("retrieved_nodes", result)
        await ctx.set("query_str", query_str)
        return RetrieveEvent(retrieved_nodes=result)

    @step
    async def eval_relevance(
        self, ctx: Context, ev: RetrieveEvent
    ) -> RelevanceEvalEvent:
        """Evaluate relevancy of each retrieved node against the query (one LLM call per node)."""
        retrieved_nodes = ev.retrieved_nodes
        query_str = await ctx.get("query_str")

        # Fetch the pipeline once rather than once per node.
        relevancy_pipeline = await ctx.get("relevancy_pipeline")
        relevancy_results = []
        for node in retrieved_nodes:
            relevancy = relevancy_pipeline.run(
                context_str=node.text, query_str=query_str
            )
            # Normalize the grader's verdict to a bare lowercase 'yes'/'no'.
            relevancy_results.append(relevancy.message.content.lower().strip())

        await ctx.set("relevancy_results", relevancy_results)
        return RelevanceEvalEvent(relevant_results=relevancy_results)

    @step
    async def extract_relevant_texts(
        self, ctx: Context, ev: RelevanceEvalEvent
    ) -> TextExtractEvent:
        """Concatenate the text of the nodes graded relevant ('yes')."""
        retrieved_nodes = await ctx.get("retrieved_nodes")
        relevancy_results = ev.relevant_results

        relevant_texts = [
            retrieved_nodes[i].text
            for i, result in enumerate(relevancy_results)
            if result == "yes"
        ]
        result = "\n".join(relevant_texts)
        return TextExtractEvent(relevant_text=result)

    @step
    async def transform_query_pipeline(
        self, ctx: Context, ev: TextExtractEvent
    ) -> QueryEvent:
        """Search the (LLM-transformed) query with the Tavily API when needed."""
        relevant_text = ev.relevant_text
        relevancy_results = await ctx.get("relevancy_results")
        query_str = await ctx.get("query_str")

        # If any document is found irrelevant, transform the query string for better search results.
        if "no" in relevancy_results:
            qp = await ctx.get("transform_query_pipeline")
            transformed_query_str = qp.run(query_str=query_str).message.content
            # Conduct a search with the transformed query string and collect the results.
            tavily_tool = await ctx.get("tavily_tool")
            search_results = tavily_tool.search(
                transformed_query_str, max_results=5
            )
            search_text = "\n".join([result.text for result in search_results])
        else:
            # All local nodes were relevant; no web augmentation needed.
            search_text = ""

        return QueryEvent(relevant_text=relevant_text, search_text=search_text)

    @step
    async def query_result(self, ctx: Context, ev: QueryEvent) -> StopEvent:
        """Synthesize the final answer over local + web evidence."""
        relevant_text = ev.relevant_text
        search_text = ev.search_text
        query_str = await ctx.get("query_str")

        # A SummaryIndex over one combined document summarizes everything.
        documents = [Document(text=relevant_text + "\n" + search_text)]
        index = SummaryIndex.from_documents(documents)
        query_engine = index.as_query_engine()
        result = query_engine.query(query_str)
        return StopEvent(result=result)
from llama_index.core.workflow import ( Workflow, step, Context, StartEvent, StopEvent, ) from llama_index.core import ( VectorStoreIndex, Document, PromptTemplate, SummaryIndex, ) from llama_index.core.query_pipeline import QueryPipeline from llama_index.llms.openai import OpenAI from llama_index.tools.tavily_research.base import TavilyToolSpec from llama_index.core.base.base_retriever import BaseRetriever DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate( template="""作为评分员,你的任务是评估针对用户问题检索到的文档的相关性。检索到的文档:------------------- {context_str}用户问题:-------------- {query_str}评估标准:- 考虑文档是否包含与用户问题相关的关键词或主题。- 评估不应过于严格;主要目标是识别并过滤掉明显不相关的检索结果。决定:- 给文档的相关性分配一个二元评分。- 如果文档与问题相关,使用“yes”,如果不相关,使用“no”。请在下方提供你的二元评分(“yes”或“no”)来表示文档与用户问题的相关性。""" ) DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate( template="""你的任务是优化查询,以确保它对于检索相关的搜索结果非常有效。\n 分析给定的输入,以理解其核心语义意图或含义。\n 原始查询:\n -------\n {query_str}\n -------\n 你的目标是重新措辞或增强此查询,以提高其搜索性能。确保修改后的查询简洁,并与预期的搜索目标直接对齐。\n 仅回复优化后的查询:""" ) class CorrectiveRAGWorkflow(Workflow): @step async def ingest(self, ctx: Context, ev: StartEvent) -> StopEvent | None: """摄取步骤(用于摄取文档和初始化索引)。""" documents: list[Document] | None = ev.get("documents") if documents is None: return None index = VectorStoreIndex.from_documents(documents) return StopEvent(result=index) @step async def prepare_for_retrieval( self, ctx: Context, ev: StartEvent ) -> PrepEvent | None: """准备检索。""" query_str: str | None = ev.get("query_str") retriever_kwargs: dict | None = ev.get("retriever_kwargs", {}) if query_str is None: return None tavily_ai_apikey: str | None = ev.get("tavily_ai_apikey") index = ev.get("index") llm = OpenAI(model="gpt-4") await ctx.set( "relevancy_pipeline", QueryPipeline(chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]), ) await ctx.set( "transform_query_pipeline", QueryPipeline(chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]), ) await ctx.set("llm", llm) await ctx.set("index", index) await ctx.set("tavily_tool", TavilyToolSpec(api_key=tavily_ai_apikey)) await 
ctx.set("query_str", query_str) await ctx.set("retriever_kwargs", retriever_kwargs) return PrepEvent() @step async def retrieve( self, ctx: Context, ev: PrepEvent ) -> RetrieveEvent | None: """检索查询相关的节点。""" query_str = await ctx.get("query_str") retriever_kwargs = await ctx.get("retriever_kwargs") if query_str is None: return None index = await ctx.get("index", default=None) tavily_tool = await ctx.get("tavily_tool", default=None) if not (index or tavily_tool): raise ValueError( "Index and tavily tool must be constructed. Run with 'documents' and 'tavily_ai_apikey' params first." ) retriever: BaseRetriever = index.as_retriever(**retriever_kwargs) result = retriever.retrieve(query_str) await ctx.set("retrieved_nodes", result) await ctx.set("query_str", query_str) return RetrieveEvent(retrieved_nodes=result) @step async def eval_relevance( self, ctx: Context, ev: RetrieveEvent ) -> RelevanceEvalEvent: """评估检索到的文档与查询的相关性。""" retrieved_nodes = ev.retrieved_nodes query_str = await ctx.get("query_str") relevancy_results = [] for node in retrieved_nodes: relevancy_pipeline = await ctx.get("relevancy_pipeline") relevancy = relevancy_pipeline.run( context_str=node.text, query_str=query_str ) relevancy_results.append(relevancy.message.content.lower().strip()) await ctx.set("relevancy_results", relevancy_results) return RelevanceEvalEvent(relevant_results=relevancy_results) @step async def extract_relevant_texts( self, ctx: Context, ev: RelevanceEvalEvent ) -> TextExtractEvent: """从检索到的文档中提取相关的文本。""" retrieved_nodes = await ctx.get("retrieved_nodes") relevancy_results = ev.relevant_results relevant_texts = [ retrieved_nodes[i].text for i, result in enumerate(relevancy_results) if result == "yes" ] result = "\n".join(relevant_texts) return TextExtractEvent(relevant_text=result) @step async def transform_query_pipeline( self, ctx: Context, ev: TextExtractEvent ) -> QueryEvent: """使用 Tavily API 搜索转换后的查询。""" relevant_text = ev.relevant_text relevancy_results = await 
ctx.get("relevancy_results") query_str = await ctx.get("query_str") # 如果发现任何文档不相关,则转换查询字符串以获得更好的搜索结果。 if "no" in relevancy_results: qp = await ctx.get("transform_query_pipeline") transformed_query_str = qp.run(query_str=query_str).message.content # 使用转换后的查询字符串进行搜索并收集结果。 tavily_tool = await ctx.get("tavily_tool") search_results = tavily_tool.search( transformed_query_str, max_results=5 ) search_text = "\n".join([result.text for result in search_results]) else: search_text = "" return QueryEvent(relevant_text=relevant_text, search_text=search_text) @step async def query_result(self, ctx: Context, ev: QueryEvent) -> StopEvent: """获取带有相关文本的结果。""" relevant_text = ev.relevant_text search_text = ev.search_text query_str = await ctx.get("query_str") documents = [Document(text=relevant_text + "\n" + search_text)] index = SummaryIndex.from_documents(documents) query_engine = index.as_query_engine() result = query_engine.query(query_str) return StopEvent(result=result)
运行工作流¶
In [ ]
已复制!
from llama_index.core import SimpleDirectoryReader
# Ingestion run: passing `documents` triggers the `ingest` step, and the
# workflow result is the built VectorStoreIndex.
documents = SimpleDirectoryReader("./data").load_data()
workflow = CorrectiveRAGWorkflow()
index = await workflow.run(documents=documents)
from llama_index.core import SimpleDirectoryReader documents = SimpleDirectoryReader("./data").load_data() workflow = CorrectiveRAGWorkflow() index = await workflow.run(documents=documents)
In [ ]
已复制!
from IPython.display import Markdown, display
# Query run: `query_str` triggers the retrieval path; the index built above
# and the Tavily key are passed through the start event.
response = await workflow.run(
    query_str="How was Llama2 pretrained?",
    index=index,
    tavily_ai_apikey=tavily_ai_api_key,
)
display(Markdown(str(response)))
from IPython.display import Markdown, display response = await workflow.run( query_str="How was Llama2 pretrained?", index=index, tavily_ai_apikey=tavily_ai_api_key, ) display(Markdown(str(response)))
Llama 2 使用优化的自回归 Transformer 进行预训练,并进行了一些修改以增强性能。这些修改包括更强大的数据清洗、更新的数据混合、在总 tokens 数量增加 40% 的数据上进行训练、上下文长度加倍,以及使用分组查询注意力 (GQA) 来提高大型模型的推理可扩展性。
In [ ]
已复制!
# A question the local Llama 2 index cannot answer — exercises the Tavily
# web-search fallback. NOTE(review): no `index`/`tavily_ai_apikey` is passed
# here, so this presumably relies on state from the previous run — confirm
# against the workflow's context lifetime.
response = await workflow.run(
    query_str="What is the functionality of latest ChatGPT memory."
)
display(Markdown(str(response)))
response = await workflow.run( query_str="What is the functionality of latest ChatGPT memory." ) display(Markdown(str(response)))
最新 ChatGPT 记忆的功能是自主记住它认为与对话相关的信息。此功能旨在避免用户重复信息,并使未来的对话更有帮助。用户可以控制聊天机器人的记忆,并根据需要访问和管理这些记忆。