Reflection Workflow for Structured Outputs
This notebook walks through setting up a `Workflow` that provides reliable structured outputs through retries and reflection on errors.
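This notebook works best with an open-source LLM, so we will use `Ollama`. If you don't have Ollama running yet, visit https://ollama.ac.cn to get started and download the model you want to use. (Here, we ran `ollama pull llama3.1` before running this notebook.)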
!pip install -U llama-index llama-index-llms-ollama
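Since workflows are async-first, all of this runs fine in a notebook. If you run the code in your own script, you need to use `asyncio.run()` to start an async event loop if one isn't already running.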
import asyncio


async def main():
    # <async code>: for example, `ret = await w.run(...)`
    ...


if __name__ == "__main__":
    asyncio.run(main())
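from llama_index.core.workflow import Event


# Carries the raw LLM output (and the source passage) to the validation step
class ExtractionDone(Event):
    output: str
    passage: str


# Carries the validation error and the rejected output back to the extraction step
class ValidationErrorEvent(Event):
    error: str
    wrong_output: str
    passage: str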
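Roughly speaking, a workflow decides which step to run from the event types in each step's signature: in the workflow defined later in this notebook, `extract` accepts `StartEvent | ValidationErrorEvent` and `validate` accepts `ExtractionDone`. The stdlib-only sketch below imitates that routing with plain dataclasses and a lookup table. It is an illustration under our own assumptions, not LlamaIndex's implementation; the `Sim*` classes, `HANDLERS`, and `route` are hypothetical names, chosen so they do not shadow the real events if you paste the cell into the notebook.

```python
from dataclasses import dataclass


# Hypothetical plain-dataclass mirrors of the workflow events
@dataclass
class SimStartEvent:
    passage: str


@dataclass
class SimExtractionDone:
    output: str
    passage: str


@dataclass
class SimValidationErrorEvent:
    error: str
    wrong_output: str
    passage: str


def extract(ev) -> str:
    # Like the real step, this consumes both the start event and the error event
    return f"extract({type(ev).__name__})"


def validate(ev) -> str:
    return f"validate({type(ev).__name__})"


# Hypothetical routing table: event type -> step that consumes it
HANDLERS = {
    SimStartEvent: extract,
    SimValidationErrorEvent: extract,
    SimExtractionDone: validate,
}


def route(ev) -> str:
    """Dispatch an event to the step registered for its exact type."""
    return HANDLERS[type(ev)](ev)


print(route(SimStartEvent(passage="I own a Fiat Panda.")))  # extract(SimStartEvent)
print(route(SimExtractionDone(output="{}", passage="I own a Fiat Panda.")))  # validate(SimExtractionDone)
```

Because `ValidationErrorEvent` also carries `passage`, a retry can rebuild the full prompt from the event alone, without looking up the original input elsewhere.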
Item to Extract
To prompt our model, let's define the pydantic model we want to extract.
from pydantic import BaseModel


class Car(BaseModel):
    brand: str
    model: str
    power: int


class CarCollection(BaseModel):
    cars: list[Car]
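To make clear what "validation" means for this schema, here is a rough stdlib-only approximation of the checks `CarCollection.model_validate_json` performs: the text must be valid JSON, and every car needs string `brand`/`model` fields and an integer `power`. The helper `validate_car_collection` is hypothetical and deliberately simplified; in particular it is stricter than pydantic's default lax mode, which can coerce numeric strings such as `"45"` to `int`.

```python
import json


def validate_car_collection(raw: str) -> list:
    """Rough stdlib approximation of CarCollection.model_validate_json (hypothetical).

    Raises ValueError for malformed JSON or a shape mismatch.
    """
    data = json.loads(raw)  # json.JSONDecodeError is a subclass of ValueError
    if not isinstance(data, dict) or not isinstance(data.get("cars"), list):
        raise ValueError("expected an object with a 'cars' list")
    for i, car in enumerate(data["cars"]):
        if not isinstance(car, dict):
            raise ValueError(f"cars[{i}] must be an object")
        for field in ("brand", "model"):
            if not isinstance(car.get(field), str):
                raise ValueError(f"cars[{i}].{field} must be a string")
        power = car.get("power")
        # bool is a subclass of int in Python, so reject it explicitly
        if isinstance(power, bool) or not isinstance(power, int):
            raise ValueError(f"cars[{i}].power must be an integer")
    return data["cars"]


good = '{"cars": [{"brand": "Fiat", "model": "Panda", "power": 45}]}'
print(validate_car_collection(good))  # [{'brand': 'Fiat', 'model': 'Panda', 'power': 45}]

bad = '{"cars": [{"brand": "Fiat", "model": "Panda", "power": "45Hp"}]}'
try:
    validate_car_collection(bad)
except ValueError as e:
    print(e)  # cars[0].power must be an integer
```

An LLM reply that wraps the JSON in a sentence (for example `Sure! {"cars": []}`) fails at the `json.loads` step, which is exactly the kind of error the reflection prompt below feeds back to the model.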
import json

from pydantic import ValidationError

from llama_index.core.workflow import (
    Context,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)
from llama_index.llms.ollama import Ollama

EXTRACTION_PROMPT = """
Context information is below:
---------------------
{passage}
---------------------
Given the context information and not prior knowledge, create a JSON object from the information in the context.
The JSON object must follow the JSON schema:
{schema}
"""

REFLECTION_PROMPT = """
You already created this output previously:
---------------------
{wrong_answer}
---------------------
This caused the JSON decode error: {error}
Try again, the response must contain only valid JSON code. Do not add any sentence before or after the JSON object.
Do not repeat the schema.
"""


class ReflectionWorkflow(Workflow):
    max_retries: int = 3

    @step
    async def extract(
        self, ctx: Context, ev: StartEvent | ValidationErrorEvent
    ) -> StopEvent | ExtractionDone:
        # Count attempts in the shared context and stop once the budget is used up
        current_retries = await ctx.get("retries", default=0)
        if current_retries >= self.max_retries:
            return StopEvent(result="Max retries reached")
        await ctx.set("retries", current_retries + 1)

        if isinstance(ev, StartEvent):
            # First attempt: read the passage passed to w.run(passage=...)
            passage = ev.get("passage")
            if not passage:
                return StopEvent(result="Please provide some text in input")
            reflection_prompt = ""
        else:
            # Retry: feed the rejected output and its error back to the LLM
            passage = ev.passage
            reflection_prompt = REFLECTION_PROMPT.format(
                wrong_answer=ev.wrong_output, error=ev.error
            )

        # Use the model pulled above (`ollama pull llama3.1`)
        llm = Ollama(model="llama3.1", request_timeout=30)
        prompt = EXTRACTION_PROMPT.format(
            passage=passage,
            # model_json_schema() is the pydantic v2 replacement for schema_json()
            schema=json.dumps(CarCollection.model_json_schema()),
        )
        if reflection_prompt:
            prompt += reflection_prompt

        output = await llm.acomplete(prompt)
        return ExtractionDone(output=str(output), passage=passage)

    @step
    async def validate(
        self, ev: ExtractionDone
    ) -> StopEvent | ValidationErrorEvent:
        try:
            # Raises ValidationError for both malformed JSON and schema mismatches
            CarCollection.model_validate_json(ev.output)
        except ValidationError as e:
            print("Validation failed, retrying...")
            return ValidationErrorEvent(
                error=str(e), wrong_output=ev.output, passage=ev.passage
            )
        return StopEvent(result=ev.output)
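The `extract` step always sends the extraction prompt and, on a retry, appends the reflection prompt with the rejected output and its error. The standalone sketch below reproduces just that assembly so you can inspect the final prompt without calling a model. `build_prompt` is a hypothetical helper, the templates are copied from the cell above under new names so they do not overwrite it, and the schema string is a placeholder rather than the real `CarCollection` schema.

```python
EXTRACTION_TEMPLATE = """
Context information is below:
---------------------
{passage}
---------------------
Given the context information and not prior knowledge, create a JSON object from the information in the context.
The JSON object must follow the JSON schema:
{schema}
"""

REFLECTION_TEMPLATE = """
You already created this output previously:
---------------------
{wrong_answer}
---------------------
This caused the JSON decode error: {error}
Try again, the response must contain only valid JSON code. Do not add any sentence before or after the JSON object.
Do not repeat the schema.
"""


def build_prompt(passage, schema, wrong_answer=None, error=None):
    """Assemble the prompt the way `extract` does (hypothetical helper)."""
    prompt = EXTRACTION_TEMPLATE.format(passage=passage, schema=schema)
    if wrong_answer is not None:
        # Retry: append the rejected output and its validation error
        prompt += REFLECTION_TEMPLATE.format(wrong_answer=wrong_answer, error=error)
    return prompt


sample_passage = "I own a Fiat Panda with 45Hp."
sample_schema = '{"type": "object"}'  # placeholder schema for illustration
first = build_prompt(sample_passage, sample_schema)
retry = build_prompt(
    sample_passage, sample_schema, wrong_answer='Sure! {"cars": []}', error="Invalid JSON"
)
print(retry)
```

Note that braces inside the substituted values (the JSON schema, the rejected output) are safe: `str.format` only interprets braces in the template itself, not in the values it inserts.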
That's it! Let's take a closer look at the workflow we just wrote.
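- We have one entry point, `extract` (the step that accepts a `StartEvent`).
- When `extract` finishes, it emits an `ExtractionDone` event.
- `validate` runs and checks the extraction:
  - If the output is valid, it emits a `StopEvent` and the workflow stops.
  - If not, it returns a `ValidationErrorEvent` carrying the error information.
- Any emitted `ValidationErrorEvent` triggers the loop, and `extract` runs again!
- This continues until the structured output passes validation or `max_retries` attempts have been used up.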
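The loop described above can be simulated without Ollama or LlamaIndex. This sketch assumes a hypothetical `FakeLLM` that returns canned replies in order, and a plain `for` loop stands in for the event routing; it only checks JSON syntax, whereas the real `validate` step also checks the schema. In a notebook, replace `asyncio.run(...)` with `await reflect_and_extract(...)`, since an event loop is already running there.

```python
import asyncio
import json


class FakeLLM:
    """Hypothetical stand-in for Ollama that returns canned replies in order."""

    def __init__(self, replies):
        self.replies = list(replies)
        self.prompts = []  # record every prompt so the retries can be inspected

    async def acomplete(self, prompt):
        self.prompts.append(prompt)
        return self.replies.pop(0)


async def reflect_and_extract(llm, passage, max_retries=3):
    """Extract -> validate -> reflect loop, capped at max_retries LLM calls."""
    wrong_output = error = None
    for _ in range(max_retries):
        # extract: build the prompt, appending feedback after a failed attempt
        prompt = f"Extract the cars in this text as JSON: {passage}"
        if error is not None:
            prompt += f"\nYour previous output was: {wrong_output}\nIt failed with: {error}"
        output = await llm.acomplete(prompt)
        # validate: only JSON syntax is checked in this sketch
        try:
            json.loads(output)
        except ValueError as e:
            wrong_output, error = output, str(e)
            continue  # plays the role of emitting a ValidationErrorEvent
        return output  # plays the role of emitting a StopEvent
    return "Max retries reached"


fake_llm = FakeLLM(['Here is the JSON: {"cars": []}', '{"cars": []}'])
print(asyncio.run(reflect_and_extract(fake_llm, "I own a Fiat Panda.")))  # {"cars": []}
print(len(fake_llm.prompts))  # 2
```

The first reply fails to parse, so the second prompt includes the rejected text and the decode error, mirroring how `ValidationErrorEvent` drives the real retry.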
Run the Workflow!
Note: when using loops, we need to be mindful of how long the run takes. Here, we set a timeout of 120 seconds.
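As an analogy only (this is not how LlamaIndex enforces the workflow `timeout`), the stdlib `asyncio.wait_for` shows the same idea: a coroutine that keeps looping is cancelled once its deadline passes. `never_validates` and `run_with_deadline` are hypothetical names; in a notebook, use `await run_with_deadline(0.1)` instead of `asyncio.run(...)`.

```python
import asyncio


async def never_validates():
    # Hypothetical loop whose output never passes validation
    while True:
        await asyncio.sleep(0.01)


async def run_with_deadline(timeout):
    try:
        await asyncio.wait_for(never_validates(), timeout=timeout)
        return "finished"
    except asyncio.TimeoutError:
        # wait_for cancels the inner coroutine before raising
        return "timed out"


print(asyncio.run(run_with_deadline(0.1)))  # timed out
```

In the workflow itself, `max_retries` normally ends the loop first; the timeout is a second safety net for slow or stuck model calls.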
w = ReflectionWorkflow(timeout=120, verbose=True)
# Run the workflow
ret = await w.run(
passage="I own two cars: a Fiat Panda with 45Hp and a Honda Civic with 330Hp."
)
Running step extract
Step extract produced event ExtractionDone
Running step validate
Validation failed, retrying...
Step validate produced event ValidationErrorEvent
Running step extract
Step extract produced event ExtractionDone
Running step validate
Step validate produced event StopEvent
print(ret)
{
  "cars": [
    {
      "brand": "Fiat",
      "model": "Panda",
      "power": 45
    },
    {
      "brand": "Honda",
      "model": "Civic",
      "power": 330
    }
  ]
}
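The result is a plain JSON string, because `validate` returns `StopEvent(result=ev.output)`. Since it has already passed `CarCollection.model_validate_json`, you can parse it again with that same call to get typed objects. The stdlib-only sketch below does the equivalent with `json.loads`, using a copy of the output shown above under the hypothetical name `sample_ret` so it does not overwrite `ret`.

```python
import json

# Copy of the StopEvent result printed above
sample_ret = '{ "cars": [ { "brand": "Fiat", "model": "Panda", "power": 45 }, { "brand": "Honda", "model": "Civic", "power": 330 } ] }'

cars = json.loads(sample_ret)["cars"]
for car in cars:
    print(f"{car['brand']} {car['model']}: {car['power']} hp")
# Fiat Panda: 45 hp
# Honda Civic: 330 hp
```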