半结构化图像检索¶
在本 Notebook 中,我们将展示如何对图像执行半结构化检索。
给定一组图像,我们可以使用 Gemini Pro Vision 从中推断出结构化输出。
然后我们可以将这些结构化输出索引到向量数据库中。接着,我们结合"语义搜索 + 元数据过滤"与自动检索(auto-retrieval)的能力:这使我们能够对这些数据同时进行结构化查询和语义查询!
(另一种方法是将这些数据放入 SQL 数据库,允许进行文本转 SQL。这些技术非常相关)。
In []
已复制!
%pip install llama-index-multi-modal-llms-gemini
%pip install llama-index-vector-stores-qdrant
%pip install llama-index-embeddings-gemini
%pip install llama-index-llms-gemini
%pip install llama-index-multi-modal-llms-gemini %pip install llama-index-vector-stores-qdrant %pip install llama-index-embeddings-gemini %pip install llama-index-llms-gemini
In []
已复制!
!pip install llama-index 'google-generativeai>=0.3.0' matplotlib qdrant_client
!pip install llama-index 'google-generativeai>=0.3.0' matplotlib qdrant_client
设置¶
获取 Google API Key¶
In []
已复制!
import os
GOOGLE_API_KEY = "" # add your GOOGLE API key here
# Mirror the key into the process environment under the same name.
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY
import os GOOGLE_API_KEY = "" # 在此处添加您的 GOOGLE API 密钥 os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY
获取图像文件¶
图像下载完成后,我们可以获取文件名列表。
In []
已复制!
from pathlib import Path
import random
from typing import Optional
from pathlib import Path import random from typing import Optional
In []
已复制!
def get_image_files(
    dir_path, sample: Optional[int] = 10, shuffle: bool = False
):
    """Collect the ``*.jpg`` files located directly inside ``dir_path``.

    Args:
        dir_path: Directory to scan (string or ``Path``). The search is
            not recursive.
        sample: Maximum number of paths to return. A falsy value
            (``None`` or ``0``) returns every match.
        shuffle: When true, randomise the order before sampling. When
            false, paths come back in sorted order so repeated runs select
            the same files.

    Returns:
        A list of ``Path`` objects (empty when nothing matches).
    """
    # Sorting gives a stable baseline order; glob order is unspecified.
    image_paths = sorted(Path(dir_path).glob("*.jpg"))
    # Shuffle only on request; an unconditional shuffle would make the
    # ``shuffle`` flag meaningless.
    if shuffle:
        random.shuffle(image_paths)
    if sample:
        return image_paths[:sample]
    return image_paths
def get_image_files( dir_path, sample: Optional[int] = 10, shuffle: bool = False ): dir_path = Path(dir_path) image_paths = [] for image_path in dir_path.glob("*.jpg"): image_paths.append(image_path) random.shuffle(image_paths) if sample: return image_paths[:sample] else: return image_paths
In []
已复制!
# Take up to 100 receipt images from the SROIE2019 test split directory.
image_files = get_image_files("SROIE2019/test/img", sample=100)
image_files = get_image_files("SROIE2019/test/img", sample=100)
使用 Gemini 提取结构化输出¶
在这里,我们使用 Gemini 提取结构化输出。
- 定义一个 ReceiptInfo pydantic 类,用于描述我们想要提取的结构化输出。提取的字段包括 company、date、address、total、currency 以及 summary。
- 定义一个 pydantic_gemini 函数,该函数将输入的图像文档转换为结构化响应。
定义一个 ReceiptInfo pydantic 类¶
In []
已复制!
from pydantic import BaseModel, Field
class ReceiptInfo(BaseModel):
    # Structured record extracted from a single receipt image.
    # NOTE: documented with comments rather than a class docstring, because
    # pydantic would expose a docstring as the model's schema description.
    # Every field is declared with an Ellipsis default, i.e. it is required.

    company: str = Field(default=..., description="Company name")
    date: str = Field(
        default=..., description="Date field in DD/MM/YYYY format"
    )
    address: str = Field(default=..., description="Address")
    total: float = Field(default=..., description="total amount")
    currency: str = Field(
        default=...,
        description="Currency of the country (in abbreviations)",
    )
    # Free-text summary; this is the value later used as the node text.
    summary: str = Field(
        default=...,
        description=(
            "Extracted text summary of the receipt, including items "
            "purchased, the type of store, the location, and any other "
            "notable salient features (what does the purchase seem to be "
            "for?)."
        ),
    )
from pydantic import BaseModel, Field class ReceiptInfo(BaseModel): company: str = Field(..., description="公司名称") date: str = Field(..., description="日期字段,格式为 DD/MM/YYYY") address: str = Field(..., description="地址") total: float = Field(..., description="总金额") currency: str = Field( ..., description="国家货币(缩写)" ) summary: str = Field( ..., description="提取的收据文本摘要,包括购买的物品、商店类型、位置以及其他任何显著特点(购买似乎是用于什么?)。", )
定义一个 pydantic_gemini 函数¶
In []
已复制!
from llama_index.multi_modal_llms.gemini import GeminiMultiModal
from llama_index.core.program import MultiModalLLMCompletionProgram
from llama_index.core.output_parsers import PydanticOutputParser
# Prompt handed to the multi-modal completion program. The trailing
# backslashes are line continuations, so the template is one line of text.
prompt_template_str = """\
Can you summarize the image and return a response \
with the following JSON format: \
"""
async def pydantic_gemini(output_class, image_documents, prompt_template_str):
    """Run a Gemini vision completion program parsed into ``output_class``.

    Args:
        output_class: Class handed to ``PydanticOutputParser``.
        image_documents: Image documents passed to the program.
        prompt_template_str: Prompt template text for the program.

    Returns:
        The awaited result of the program's ``acall()``.
    """
    vision_model = GeminiMultiModal(
        api_key=GOOGLE_API_KEY, model_name="models/gemini-pro-vision"
    )
    parser = PydanticOutputParser(output_class)
    program = MultiModalLLMCompletionProgram.from_defaults(
        output_parser=parser,
        image_documents=image_documents,
        prompt_template_str=prompt_template_str,
        multi_modal_llm=vision_model,
        verbose=True,
    )
    return await program.acall()
from llama_index.multi_modal_llms.gemini import GeminiMultiModal from llama_index.core.program import MultiModalLLMCompletionProgram from llama_index.core.output_parsers import PydanticOutputParser prompt_template_str = """\ 能否总结图像并以以下 JSON 格式返回响应:\ \ """ async def pydantic_gemini(output_class, image_documents, prompt_template_str): gemini_llm = GeminiMultiModal( api_key=GOOGLE_API_KEY, model_name="models/gemini-pro-vision" ) llm_program = MultiModalLLMCompletionProgram.from_defaults( output_parser=PydanticOutputParser(output_class), image_documents=image_documents, prompt_template_str=prompt_template_str, multi_modal_llm=gemini_llm, verbose=True, ) response = await llm_program.acall() return response
处理图像¶
In []
已复制!
from llama_index.core import SimpleDirectoryReader
from llama_index.core.async_utils import run_jobs
async def aprocess_image_file(image_file):
    """Extract a ``ReceiptInfo`` from one image file.

    Args:
        image_file: Path of the image to load.

    Returns:
        The awaited result of ``pydantic_gemini`` for that image.
    """
    print(f"Image file: {image_file}")
    # The reader is given exactly one input path.
    reader = SimpleDirectoryReader(input_files=[image_file])
    return await pydantic_gemini(
        ReceiptInfo, reader.load_data(), prompt_template_str
    )
async def aprocess_image_files(image_files):
    """Run ``aprocess_image_file`` over every image file via ``run_jobs``.

    Args:
        image_files: Iterable of image paths.

    Returns:
        The awaited result of ``run_jobs`` for the scheduled coroutines.
    """
    # Coroutine objects are only created here; they are awaited by run_jobs,
    # which is called with workers=5.
    jobs = [aprocess_image_file(image_file) for image_file in image_files]
    return await run_jobs(jobs, show_progress=True, workers=5)
from llama_index.core import SimpleDirectoryReader from llama_index.core.async_utils import run_jobs async def aprocess_image_file(image_file): # 应该加载一个文件 print(f"图像文件: {image_file}") img_docs = SimpleDirectoryReader(input_files=[image_file]).load_data() output = await pydantic_gemini(ReceiptInfo, img_docs, prompt_template_str) return output async def aprocess_image_files(image_files): """处理图像文件的元数据。""" new_docs = [] tasks = [] for image_file in image_files: task = aprocess_image_file(image_file) tasks.append(task) outputs = await run_jobs(tasks, show_progress=True, workers=5) return outputs
In []
已复制!
# Top-level ``await``: this statement relies on the notebook's async cell
# support and is not valid in a plain module.
outputs = await aprocess_image_files(image_files)
outputs = await aprocess_image_files(image_files)
In []
已复制!
# Inspect one extracted record (index 4) as the cell output.
outputs[4]
outputs[4]
Out[ ]
ReceiptInfo(company='KEDAI BUKU NEW ACHIEVERS', date='15/09/2017', address='NO. 12 & 14, JALAN HIJAUAN JINANG 27/54 TAMAN ALAM MEGAH, SEKSYEN 27 40400 SHAH ALAM, SELANGOR D. E.', total=48.0, currency='MYR', summary='Purchase of books and school supplies at a bookstore.')
将结构化表示转换为 TextNode 对象¶
Node 对象是 LlamaIndex 向量存储中被索引的核心单元。我们定义一个简单的转换函数,将 ReceiptInfo 对象映射为 TextNode 对象。
In []
已复制!
from llama_index.core.schema import TextNode
from typing import List
def get_nodes_from_objs(
    objs: List[ReceiptInfo], image_files: List[str]
) -> List[TextNode]:
    """Build one ``TextNode`` per (image file, receipt) pair.

    Args:
        objs: Extracted receipts; ``objs[i]`` is paired with
            ``image_files[i]``.
        image_files: Image paths; each is stored as ``str(path)`` in the
            node metadata.

    Returns:
        A list of nodes whose text is the receipt summary and whose metadata
        carries the remaining receipt fields plus ``image_file``.
        Pairing uses ``zip``, so surplus items in the longer input are
        ignored.
    """
    return [
        TextNode(
            text=obj.summary,
            metadata={
                "company": obj.company,
                "date": obj.date,
                "address": obj.address,
                "total": obj.total,
                "currency": obj.currency,
                "image_file": str(image_file),
            },
            # The file path is listed in both exclusion lists below.
            excluded_embed_metadata_keys=["image_file"],
            excluded_llm_metadata_keys=["image_file"],
        )
        for image_file, obj in zip(image_files, objs)
    ]
from llama_index.core.schema import TextNode from typing import List def get_nodes_from_objs( objs: List[ReceiptInfo], image_files: List[str] ) -> TextNode: """从对象获取节点。""" nodes = [] for image_file, obj in zip(image_files, objs): node = TextNode( text=obj.summary, metadata={ "company": obj.company, "date": obj.date, "address": obj.address, "total": obj.total, "currency": obj.currency, "image_file": str(image_file), }, excluded_embed_metadata_keys=["image_file"], excluded_llm_metadata_keys=["image_file"], ) nodes.append(node) return nodes
In []
已复制!
# Pair each extracted receipt with the image file it came from.
nodes = get_nodes_from_objs(outputs, image_files)
nodes = get_nodes_from_objs(outputs, image_files)
In []
已复制!
# Show the first node's content using metadata_mode="all".
print(nodes[0].get_content(metadata_mode="all"))
print(nodes[0].get_content(metadata_mode="all"))
company: UNIHAIKKA INTERNATIONAL SDN BHD date: 13/09/2018 address: 12, Jalan Tampoi 7/4, Kawasan Perindustrian Tampoi, 81200 Johor Bahru, Johor total: 8.85 currency: MYR image_file: SROIE2019/test/img/X51007846371.jpg The receipt is from a restaurant called Bar Wang Rice. The total amount is 8.85 MYR. The items purchased include chicken, vegetables, and a drink.
在向量存储中索引这些节点¶
In []
已复制!
import qdrant_client
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.llms.gemini import Gemini
from llama_index.core import Settings
# Create a local Qdrant vector store
client = qdrant_client.QdrantClient(path="qdrant_gemini")
vector_store = QdrantVectorStore(client=client, collection_name="collection")

# global settings
Settings.embed_model = GeminiEmbedding(
    model_name="models/embedding-001", api_key=GOOGLE_API_KEY
)
# Assign the LLM instance itself: a trailing comma after the call would
# wrap it in a one-element tuple instead.
Settings.llm = Gemini(api_key=GOOGLE_API_KEY)

# Build the index over the receipt nodes, backed by the Qdrant store.
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex(
    nodes=nodes,
    storage_context=storage_context,
)
import qdrant_client from llama_index.vector_stores.qdrant import QdrantVectorStore from llama_index.core import StorageContext from llama_index.core import VectorStoreIndex from llama_index.embeddings.gemini import GeminiEmbedding from llama_index.llms.gemini import Gemini from llama_index.core import Settings # 创建一个本地 Qdrant 向量存储 client = qdrant_client.QdrantClient(path="qdrant_gemini") vector_store = QdrantVectorStore(client=client, collection_name="collection") # 全局设置 Settings.embed_model = GeminiEmbedding( model_name="models/embedding-001", api_key=GOOGLE_API_KEY ) Settings.llm = (Gemini(api_key=GOOGLE_API_KEY),) storage_context = StorageContext.from_defaults(vector_store=vector_store) index = VectorStoreIndex( nodes=nodes, storage_context=storage_context, )
In []
已复制!
from llama_index.core.vector_stores import MetadataInfo, VectorStoreInfo
# Describes the indexed content and the metadata fields of each node.
# Each row below is (name, type, description).
vector_store_info = VectorStoreInfo(
    content_info="Receipts",
    metadata_info=[
        MetadataInfo(name=field_name, description=field_desc, type=field_type)
        for field_name, field_type, field_desc in (
            ("company", "string", "The name of the store"),
            ("address", "string", "The address of the store"),
            (
                "date",
                "string",
                "The date of the purchase (in DD/MM/YYYY format)",
            ),
            ("total", "float", "The final amount"),
            (
                "currency",
                "string",
                "The currency of the country the purchase was made (abbreviation)",
            ),
        )
    ],
)
from llama_index.core.vector_stores import MetadataInfo, VectorStoreInfo vector_store_info = VectorStoreInfo( content_info="收据", metadata_info=[ MetadataInfo( name="company", description="商店名称", type="string", ), MetadataInfo( name="address", description="商店地址", type="string", ), MetadataInfo( name="date", description="购买日期(格式为 DD/MM/YYYY)", type="string", ), MetadataInfo( name="total", description="最终金额", type="float", ), MetadataInfo( name="currency", description="购买发生国家的货币(缩写)", type="string", ), ], )
In []
已复制!
from llama_index.core.retrievers import VectorIndexAutoRetriever
# Auto-retriever over the receipt index, configured with the metadata schema.
retriever = VectorIndexAutoRetriever(
    index,
    verbose=True,
    similarity_top_k=2,
    # if only metadata filters are specified, this is the limit
    empty_query_top_k=10,
    vector_store_info=vector_store_info,
)
from llama_index.core.retrievers import VectorIndexAutoRetriever retriever = VectorIndexAutoRetriever( index, vector_store_info=vector_store_info, similarity_top_k=2, empty_query_top_k=10, # 如果仅指定了元数据过滤器,这是限制 verbose=True, )
In []
已复制!
# from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
from IPython.display import Image
def display_response(nodes: List[TextNode]):
    """Print each node's content with metadata and render its image.

    Args:
        nodes: Objects exposing ``get_content`` and a ``metadata`` mapping
            that contains an ``"image_file"`` path.
    """
    # Imported explicitly so the function does not depend on the ``display``
    # name that notebook front ends make available implicitly.
    from IPython.display import display

    for node in nodes:
        print(node.get_content(metadata_mode="all"))
        display(Image(filename=node.metadata["image_file"], width=200))
# from PIL import Image import requests from io import BytesIO import matplotlib.pyplot as plt from IPython.display import Image def display_response(nodes: List[TextNode]): """显示响应。""" for node in nodes: print(node.get_content(metadata_mode="all")) # img = Image.open(open(node.metadata["image_file"], 'rb')) display(Image(filename=node.metadata["image_file"], width=200))
运行一些查询¶
让我们尝试不同类型的查询!
In []
已复制!
# Query combining free text with a numeric condition on ``total``.
# NOTE(review): this rebinds the module-level ``nodes`` name that earlier
# held the indexed nodes.
nodes = retriever.retrieve(
"Tell me about some restaurant orders of noodles with total < 25"
)
display_response(nodes)
nodes = retriever.retrieve( "告诉我一些总金额小于 25 的面条餐厅订单" ) display_response(nodes)
Using query str: restaurant orders of noodles Using filters: [('total', '<', 25)] company: Restoran Wan Sheng date: 23-03-2018 address: No. 2, Jalan Temenggung 19/9, Seksyen 9, Bandar Mahkota Cheras, 43200 Cheras, Selangor total: 6.7 currency: MYR image_file: SROIE2019/test/img/X51005711443.jpg Teh (B), Cham (B), Bunga Kekwa, Take Away
company: UNIHAIKKA INTERNATIONAL SDN BHD date: 19/06/2018 address: 12, Jalan Tampoi 7/4, Kawasan Perindustrian Tampoi 81200 Johor Bahru, Johor total: 8.45 currency: MYR image_file: SROIE2019/test/img/X51007846392.jpg The receipt is from a restaurant called Bar Wang Rice. The total amount is 8.45 MYR. The items purchased include 1 plate of fried noodles, 1 plate of chicken, and 1 plate of vegetables.
In []
已复制!
# Free-text query with no explicit metadata condition in the question.
nodes = retriever.retrieve("Tell me about some grocery purchases")
display_response(nodes)
nodes = retriever.retrieve("告诉我一些杂货购买信息") display_response(nodes)
Using query str: grocery purchases Using filters: [] company: GARDENIA BAKERIES (KL) SDN BHD date: 24/09/2017 address: LOT 3, JALAN PELABUR 23/1, 40300 SHAH ALAM, SELANGOR total: 38.55 currency: RM image_file: SROIE2019/test/img/X51006556829.jpg Purchase of groceries from a supermarket.
company: Segi Cash & Carry Sdn. Bhd date: 02/02/2017 address: PT17920, SEKSYEN U9, 40150 SHAH ALAM, SELANGOR DARUL EHSAN total: 27.0 currency: RM image_file: SROIE2019/test/img/X51006335818.jpg Purchase of groceries at Segi Cash & Carry Sdn. Bhd. on 02/02/2017. The total amount of the purchase is RM27.