[Beta] Text-to-SQL with PGVector¶
This notebook demonstrates how to perform text-to-SQL with pgvector. This allows us to jointly do both semantic search and structured querying, all within SQL!
This hypothetically enables more expressive queries than semantic search + metadata filters.
NOTE: this is a beta feature, and the interface may change. But hopefully you find it useful in the meantime!
NOTE: any text-to-SQL application should be aware that executing arbitrary SQL queries can be a security risk. It is recommended to take precautions as needed, such as using restricted roles, read-only databases, sandboxing, etc.
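To make that note concrete, below is a rough, hypothetical hardening sketch (not part of this notebook's flow): it creates a read-only Postgres role and builds a separate engine on it, so any LLM-generated SQL can only read data. The role name, password, and grants are all assumptions.
# hypothetical hardening sketch: a read-only role for LLM-generated SQL
from sqlalchemy import create_engine, text

admin_engine = create_engine("postgresql+psycopg2://localhost/postgres")
with admin_engine.connect() as conn:
    conn.execute(text("CREATE ROLE readonly_user LOGIN PASSWORD 'change-me'"))
    conn.execute(
        text("GRANT SELECT ON ALL TABLES IN SCHEMA public TO readonly_user")
    )
    conn.commit()

# an engine built on this connection string can only run read queries
readonly_engine = create_engine(
    "postgresql+psycopg2://readonly_user:change-me@localhost/postgres"
)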
Setup Data¶
Load Documents¶
Load the Lyft 2021 10k document.
%pip install llama-index-embeddings-huggingface
%pip install llama-index-readers-file
%pip install llama-index-llms-openai
from llama_index.readers.file import PDFReader
reader = PDFReader()
Download Data
!mkdir -p 'data/10k/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/lyft_2021.pdf' -O 'data/10k/lyft_2021.pdf'
docs = reader.load_data("./data/10k/lyft_2021.pdf")
from llama_index.core.node_parser import SentenceSplitter
node_parser = SentenceSplitter()
nodes = node_parser.get_nodes_from_documents(docs)
print(nodes[8].get_content(metadata_mode="all"))
Insert Data into Postgres + PGVector¶
Make sure you have all the necessary dependencies installed!
!pip install psycopg2-binary pgvector asyncpg "sqlalchemy[asyncio]" greenlet
from pgvector.sqlalchemy import Vector
from sqlalchemy import insert, create_engine, String, text, Integer
from sqlalchemy.orm import declarative_base, mapped_column
Establish Connection¶
engine = create_engine("postgresql+psycopg2://localhost/postgres")
with engine.connect() as conn:
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
conn.commit()
Define Table Schema¶
Define as a Python class. Note we store the page_label, embedding, and text.
Base = declarative_base()
class SECTextChunk(Base):
__tablename__ = "sec_text_chunk"
id = mapped_column(Integer, primary_key=True)
page_label = mapped_column(Integer)
file_name = mapped_column(String)
text = mapped_column(String)
embedding = mapped_column(Vector(384))
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
Generate embedding for each node with a sentence_transformers model¶
# get embeddings for each row
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")
for node in nodes:
text_embedding = embed_model.get_text_embedding(node.get_content())
node.embedding = text_embedding
Insert into Database¶
# insert into database
for node in nodes:
row_dict = {
"text": node.get_content(),
"embedding": node.embedding,
**node.metadata,
}
stmt = insert(SECTextChunk).values(**row_dict)
with engine.connect() as connection:
cursor = connection.execute(stmt)
connection.commit()
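As an optional sanity check (an addition here, not part of the original notebook), you can confirm the rows landed in the table before moving on:
# optional sanity check: count the rows we just inserted
with engine.connect() as connection:
    row_count = connection.execute(
        text("SELECT count(*) FROM sec_text_chunk")
    ).scalar()
print(f"inserted {row_count} text chunks")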
Define PGVectorSQLQueryEngine¶
Now that we've loaded our data into the database, we're ready to set up our query engine.
Define Prompt¶
We create a modified version of our default text-to-SQL prompt to inject awareness of the pgvector syntax. We also prompt it with some few-shot examples of how to use the syntax (<->).
NOTE: this is included by default in the PGVectorSQLQueryEngine; we include it here mostly for visibility!
from llama_index.core import PromptTemplate
text_to_sql_tmpl = """\
Given an input question, first create a syntactically correct {dialect} \
query to run, then look at the results of the query and return the answer. \
You can order the results by a relevant column to return the most \
interesting examples in the database.
Pay attention to use only the column names that you can see in the schema \
description. Be careful to not query for columns that do not exist. \
Pay attention to which column is in which table. Also, qualify column names \
with the table name when needed.
IMPORTANT NOTE: you can use specialized pgvector syntax (`<->`) to do nearest \
neighbors/semantic search to a given vector from an embeddings column in the table. \
The embeddings value for a given row typically represents the semantic meaning of that row. \
The vector represents an embedding representation \
of the question, given below. Do NOT fill in the vector values directly, but rather specify a \
`[query_vector]` placeholder. For instance, some select statement examples below \
(the name of the embeddings column is `embedding`):
SELECT * FROM items ORDER BY embedding <-> '[query_vector]' LIMIT 5;
SELECT * FROM items WHERE id != 1 ORDER BY embedding <-> (SELECT embedding FROM items WHERE id = 1) LIMIT 5;
SELECT * FROM items WHERE embedding <-> '[query_vector]' < 5;
You are required to use the following format, \
each taking one line:
Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use tables listed below.
{schema}
Question: {query_str}
SQLQuery: \
"""
text_to_sql_prompt = PromptTemplate(text_to_sql_tmpl)
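To make the (<->) syntax above concrete, here is a small illustrative sketch (not part of the original notebook) that runs a nearest-neighbor query by hand, outside the LLM loop. The question string is arbitrary, and the vector is serialized into pgvector's bracketed text format before being cast:
# illustrative only: embed a question ourselves and query with <->
query_vec = embed_model.get_query_embedding("What are Lyft's risk factors?")
vec_literal = "[" + ",".join(str(x) for x in query_vec) + "]"
with engine.connect() as connection:
    rows = connection.execute(
        text(
            "SELECT page_label, text FROM sec_text_chunk "
            "ORDER BY embedding <-> CAST(:qvec AS vector) LIMIT 3"
        ),
        {"qvec": vec_literal},
    ).fetchall()
for page_label, chunk_text in rows:
    print(page_label, chunk_text[:100])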
Setup LLM, Embedding Model, and Misc.¶
Besides the LLM and embedding model, note that we also add annotations to the table itself. This allows the LLM to better understand the column schema (e.g. by telling it what the embedding column represents), so it can do tabular querying or semantic search more effectively.
from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI
from llama_index.core.query_engine import PGVectorSQLQueryEngine
from llama_index.core import Settings
sql_database = SQLDatabase(engine, include_tables=["sec_text_chunk"])
Settings.llm = OpenAI(model="gpt-4")
Settings.embed_model = embed_model
table_desc = """\
This table represents text chunks from an SEC filing. Each row contains the following columns:
id: id of row
page_label: page number
file_name: top-level file name
text: all text chunk is here
embedding: the embeddings representing the text chunk
For most queries you should perform semantic search against the `embedding` column values, since \
that encodes the meaning of the text.
"""
context_query_kwargs = {"sec_text_chunk": table_desc}
Define Query Engine¶
query_engine = PGVectorSQLQueryEngine(
sql_database=sql_database,
text_to_sql_prompt=text_to_sql_prompt,
context_query_kwargs=context_query_kwargs,
)
Run Some Queries¶
Now we're ready to run some queries.
response = query_engine.query(
"Can you tell me about the risk factors described in page 6?",
)
print(str(response))
Page 6 discusses the impact of the COVID-19 pandemic on the business. It mentions that the pandemic has affected communities in the United States, Canada, and globally. The pandemic has led to a significant decrease in the demand for ridesharing services, which has negatively impacted the company's financial performance. The page also discusses the company's efforts to adapt to the changing environment by focusing on the delivery of essential goods and services. Additionally, it mentions the company's transportation network, which offers riders seamless, personalized, and on-demand access to a variety of mobility options.
print(response.metadata["sql_query"])
response = query_engine.query(
"Tell me more about Lyft's real estate operating leases",
)
print(str(response))
Lyft's lease arrangements include vehicle rental programs, office space, and data centers. Leases that do not meet any specific criteria are accounted for as operating leases. The lease term begins when Lyft is available to use the underlying asset and ends upon the termination of the lease. The lease term includes any periods covered by an option to extend if Lyft is reasonably certain to exercise that option. Leasehold improvements are amortized on a straight-line basis over the shorter of the term of the lease, or the useful life of the assets.
print(response.metadata["sql_query"][:300])
SELECT * FROM sec_text_chunk WHERE text LIKE '%Lyft%' AND text LIKE '%real estate%' AND text LIKE '%operating leases%' ORDER BY embedding <-> '[-0.007079003844410181, -0.04383348673582077, 0.02910166047513485, 0.02049737051129341, 0.009460929781198502, -0.017539210617542267, 0.04225028306245804, 0.0
# look at the returned result
print(response.metadata["result"])
[(157, 93, 'lyft_2021.pdf', "Leases that do not meet any of the above criteria are accounted for as operating leases.Lessor\nThe\n Company's lease arrangements include vehicle re ... (4356 characters truncated) ... realized. Leasehold improvements are amortized on a straight-line basis over the shorter of the term of the lease, or the useful life of the assets.", '[0.017818017,-0.024016099,0.0042511695,0.03114478,0.003591422,-0.0097886855,0.02455732,0.013048866,0.018157514,-0.009401044,0.031699456,0.01678178,0. ... (4472 characters truncated) ... 6,0.01127416,0.045080125,-0.017046565,-0.028544193,-0.016320521,0.01062995,-0.021007432,-0.006999497,-0.08426073,-0.014918887,0.059064835,0.03307945]')]
# structured query
response = query_engine.query(
"Tell me about the max page number in this table",
)
print(str(response))
The maximum page number in this table is 238.
print(response.metadata["sql_query"][:300])
SELECT MAX(page_label) FROM sec_text_chunk;