SQL Query Engine with LlamaIndex + DuckDB
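This guide showcases the core LlamaIndex SQL capabilities with DuckDB.
We walk through some of the core LlamaIndex data structures, including the `NLSQLTableQueryEngine` and the `SQLTableRetrieverQueryEngine`.
NOTE: Any text-to-SQL application should be aware that executing arbitrary SQL queries can be a security risk. It is recommended to take precautions as needed, such as using restricted roles, read-only databases, or sandboxing.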
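One way to act on the "restricted role" advice is to run generated SQL only through a connection that refuses anything except reads. The sketch below illustrates the idea with Python's standard-library `sqlite3` module and its `set_authorizer` hook, chosen only because it runs without DuckDB installed. `read_only_authorizer` and `run_generated_sql` are hypothetical helper names, the `city_stats` rows mirror the example data used later, and this is not a safeguard that LlamaIndex installs for you.

```python
import sqlite3

# Authorizer action codes a plain read-only SELECT needs; everything else is denied.
READ_ONLY_ACTIONS = {
    sqlite3.SQLITE_SELECT,    # the SELECT statement itself
    sqlite3.SQLITE_READ,      # reading a column of a table
    sqlite3.SQLITE_FUNCTION,  # calling a function such as MAX() or COUNT()
}


def read_only_authorizer(action, arg1, arg2, db_name, trigger_or_view):
    """Hypothetical authorizer: allow reads, deny writes, DDL, PRAGMA, ATTACH, etc."""
    return sqlite3.SQLITE_OK if action in READ_ONLY_ACTIONS else sqlite3.SQLITE_DENY


def run_generated_sql(conn, sql):
    """Hypothetical helper: run (possibly LLM-generated) SQL.

    A statement the authorizer denies raises sqlite3.DatabaseError.
    """
    return conn.execute(sql).fetchall()


# Create and populate the table *before* installing the authorizer.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name VARCHAR(16) PRIMARY KEY, population INTEGER, country VARCHAR(16) NOT NULL)"
)
conn.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [("Toronto", 2930000, "Canada"), ("Tokyo", 13960000, "Japan")],
)
conn.commit()
conn.set_authorizer(read_only_authorizer)

print(run_generated_sql(conn, "SELECT MAX(population) FROM city_stats"))
try:
    run_generated_sql(conn, "DROP TABLE city_stats")
except sqlite3.DatabaseError as exc:
    print("blocked:", exc)
```

With DuckDB itself, opening a database file via `duckdb.connect(path, read_only=True)` is a coarser-grained alternative; an authorizer hook like this one decides per operation (column read, function call, DROP, PRAGMA, and so on).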
If you're opening this Notebook on Colab, you will probably need to install LlamaIndex 🦙.
%pip install llama-index-readers-wikipedia
!pip install llama-index
!pip install duckdb duckdb-engine
import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
from llama_index.core import SQLDatabase, SimpleDirectoryReader, Document
from llama_index.readers.wikipedia import WikipediaReader
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from IPython.display import Markdown, display
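Basic Text-to-SQL with our `NLSQLTableQueryEngine`
In this initial example, we walk through populating a SQL database with some test datapoints and querying it with our text-to-SQL capabilities.
Create Database Schema + Test Data
We use sqlalchemy, a popular SQL database toolkit, to connect to DuckDB and create an empty `city_stats` table. We then populate it with some test data.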
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    column,
)
engine = create_engine("duckdb:///:memory:")
# uncomment to make this work with MotherDuck
# engine = create_engine("duckdb:///md:llama-index")
metadata_obj = MetaData()
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)
# print tables
metadata_obj.tables.keys()
Out [ ]
dict_keys(['city_stats'])
We insert some test data into the `city_stats` table.
from sqlalchemy import insert
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {
        "city_name": "Chicago",
        "population": 2679000,
        "country": "United States",
    },
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
for row in rows:
    stmt = insert(city_stats_table).values(**row)
    # each row is inserted and committed in its own transaction
    with engine.begin() as connection:
        cursor = connection.execute(stmt)
with engine.connect() as connection:
    cursor = connection.exec_driver_sql("SELECT * FROM city_stats")
    print(cursor.fetchall())
[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]
Create SQLDatabase Object
We first define our `SQLDatabase` abstraction (a light wrapper around SQLAlchemy).
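`SQLDatabase` is the object the query engines consult for table information; the `Table desc str` log line later in this guide shows the kind of schema text that ends up in the prompt. As a rough standard-library illustration of producing such a description, the sketch below reads column and foreign-key metadata through SQLite's `pragma_table_info` and `pragma_foreign_key_list` table-valued functions and formats it like that log line. `describe_table` is a hypothetical helper, not `SQLDatabase`'s implementation, and SQLite reports declared type names verbatim (a column declared `VARCHAR(16)` is reported as `VARCHAR(16)`).

```python
import sqlite3


def describe_table(conn, table_name):
    """Hypothetical helper: one-line description of a table's columns and foreign keys."""
    # pragma_table_info() yields one row per column, including its declared type.
    columns = conn.execute(
        "SELECT name, type FROM pragma_table_info(?) ORDER BY cid", (table_name,)
    ).fetchall()
    col_str = ", ".join(f"{name} ({col_type})" for name, col_type in columns)
    # pragma_foreign_key_list() columns include "from", "table" and "to".
    fks = conn.execute(
        'SELECT "from", "table", "to" FROM pragma_foreign_key_list(?)', (table_name,)
    ).fetchall()
    fk_str = ", ".join(f"{src} -> {ref_table}.{ref_col}" for src, ref_table, ref_col in fks)
    return f"Table '{table_name}' has columns: {col_str} and foreign keys: {fk_str}."


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats (city_name VARCHAR, population INTEGER, country VARCHAR NOT NULL)"
)
print(describe_table(conn, "city_stats"))
```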
from llama_index.core import SQLDatabase
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
/Users/jerryliu/Programming/gpt_index/.venv/lib/python3.10/site-packages/duckdb_engine/__init__.py:162: DuckDBEngineWarning: duckdb-engine doesn't yet support reflection on indices
  warnings.warn(
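Query Index
Here we demonstrate the capabilities of `NLSQLTableQueryEngine`, which performs text-to-SQL.
- We construct an `NLSQLTableQueryEngine` and pass in our SQL database object.
- We run queries against the query engine.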
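Conceptually, a text-to-SQL engine of this kind builds a prompt from the table descriptions and the question, asks the LLM for a SQL query, executes that query, and then turns the rows into an answer. The sketch below reproduces that loop with the standard-library `sqlite3` module and a stub in place of the LLM. `PROMPT_TEMPLATE`, `fake_llm`, and `text_to_sql_query` are hypothetical: the stub always returns the SQL that appears in `response.metadata` further down, and the returned dictionary reuses that metadata's `result` and `sql_query` keys. It is not `NLSQLTableQueryEngine`'s actual code.

```python
import sqlite3

# Hypothetical prompt template; the real engine's default prompt differs.
PROMPT_TEMPLATE = (
    "Given an input question, write a SQL query that answers it.\n"
    "Only use the tables described below.\n"
    "{schema}\n"
    "Question: {question}\n"
    "SQLQuery: "
)


def fake_llm(prompt):
    """Hypothetical stand-in for an LLM call; always returns the same query."""
    return "SELECT city_name, population \nFROM city_stats \nORDER BY population DESC \nLIMIT 1;"


def text_to_sql_query(conn, schema, question, llm=fake_llm):
    # 1. Put the table description(s) and the question into one prompt.
    prompt = PROMPT_TEMPLATE.format(schema=schema, question=question)
    # 2. Ask the model for SQL.
    sql_query = llm(prompt).strip()
    # 3. Execute it; in practice, do this on a restricted or read-only connection.
    result = conn.execute(sql_query).fetchall()
    # 4. A real engine would make a second LLM call to phrase `result` as prose;
    #    this sketch returns the raw pieces, keyed like response.metadata.
    return {"result": result, "sql_query": sql_query}


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE city_stats (city_name VARCHAR, population INTEGER, country VARCHAR)")
conn.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [
        ("Toronto", 2930000, "Canada"),
        ("Tokyo", 13960000, "Japan"),
        ("Chicago", 2679000, "United States"),
        ("Seoul", 9776000, "South Korea"),
    ],
)
schema = (
    "Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), "
    "country (VARCHAR) and foreign keys: ."
)
metadata = text_to_sql_query(conn, schema, "Which city has the highest population?")
print(metadata)
```

In a real setup `llm` would call a model, and step 3 is where the precautions from the note at the top of this guide matter most.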
query_engine = NLSQLTableQueryEngine(sql_database)
response = query_engine.query("Which city has the highest population?")
INFO:llama_index.indices.struct_store.sql_query:> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR) and foreign keys: .
> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR) and foreign keys: .
/Users/jerryliu/Programming/gpt_index/.venv/lib/python3.10/site-packages/langchain/sql_database.py:238: UserWarning: This method is deprecated - please use `get_usable_table_names`.
  warnings.warn(
INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 332 tokens
> [query] Total LLM token usage: 332 tokens
INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens
> [query] Total embedding token usage: 0 tokens
str(response)
Out [ ]
' Tokyo has the highest population, with 13,960,000 people.'
response.metadata
Out [ ]
{'result': [('Tokyo', 13960000)], 'sql_query': 'SELECT city_name, population \nFROM city_stats \nORDER BY population DESC \nLIMIT 1;'}
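Advanced Text-to-SQL with our `SQLTableRetrieverQueryEngine`
In this section, we tackle the setting where the database contains a large number of tables, so putting every table schema into the prompt could overflow the text-to-SQL prompt.
We first index the schemas with our `ObjectIndex`, and then use our `SQLTableRetrieverQueryEngine` abstraction on top of it.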
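Retrieving schemas instead of listing them all keeps the prompt small: roughly, each table gets a short description, the descriptions are embedded, and only the `similarity_top_k` closest ones are passed on to the text-to-SQL step. The sketch below illustrates just the ranking step, with a plain bag-of-words cosine similarity swapped in for the embedding model that `VectorStoreIndex` would use. `bag_of_words`, `cosine`, `retrieve_tables`, and the toy descriptions (shaped like the 100 dummy tables created below) are all hypothetical.

```python
import math
import re
from collections import Counter


def bag_of_words(text):
    """Hypothetical stand-in for an embedding: lower-cased word counts."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a, b):
    """Cosine similarity between two word-count vectors (0.0 if either is empty)."""
    dot = sum(count * b[word] for word, count in a.items())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve_tables(question, table_descriptions, similarity_top_k=1):
    """Hypothetical retriever: names of the top-k tables most similar to the question."""
    query_vec = bag_of_words(question)
    ranked = sorted(
        table_descriptions,
        key=lambda name: cosine(query_vec, bag_of_words(table_descriptions[name])),
        reverse=True,
    )
    return ranked[:similarity_top_k]


# Toy descriptions: city_stats plus 100 dummy tables shaped like the ones created below.
descriptions = {
    "city_stats": "Table 'city_stats' has columns: city_name, population, country"
}
for i in range(100):
    descriptions[f"tmp_table_{i}"] = (
        f"Table 'tmp_table_{i}' has columns: "
        f"tmp_field_{i}_1, tmp_field_{i}_2, tmp_field_{i}_3"
    )

print(retrieve_tables("Which city has the highest population?", descriptions))
```

Word overlap is only a stand-in: an embedding model can also match synonyms (for example "inhabitants" versus "population") that this sketch would miss. A `similarity_top_k` of 1 suffices for a single-table question like the one below, while questions that need joins would likely need a larger value.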
engine = create_engine("duckdb:///:memory:")
# uncomment to make this work with MotherDuck
# engine = create_engine("duckdb:///md:llama-index")
metadata_obj = MetaData()
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
all_table_names = ["city_stats"]
# create a ton of dummy tables
n = 100
for i in range(n):
    tmp_table_name = f"tmp_table_{i}"
    tmp_table = Table(
        tmp_table_name,
        metadata_obj,
        Column(f"tmp_field_{i}_1", String(16), primary_key=True),
        Column(f"tmp_field_{i}_2", Integer),
        Column(f"tmp_field_{i}_3", String(16), nullable=False),
    )
    all_table_names.append(f"tmp_table_{i}")
# Table() registers each table on metadata_obj, so create_all() creates all 101 tables
metadata_obj.create_all(engine)
# insert dummy data
from sqlalchemy import insert
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {
        "city_name": "Chicago",
        "population": 2679000,
        "country": "United States",
    },
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
for row in rows:
    stmt = insert(city_stats_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
Construct Object Index
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex
table_node_mapping = SQLTableNodeMapping(sql_database)
# one schema object per table: city_stats plus the 100 dummy tables
table_schema_objs = []
for table_name in all_table_names:
    table_schema_objs.append(SQLTableSchema(table_name=table_name))
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    index_cls=VectorStoreIndex,
)
INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens
> [build_index_from_nodes] Total LLM token usage: 0 tokens
INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 6343 tokens
> [build_index_from_nodes] Total embedding token usage: 6343 tokens
Query Index with `SQLTableRetrieverQueryEngine`
query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=1),
)
response = query_engine.query("Which city has the highest population?")
INFO:llama_index.token_counter.token_counter:> [retrieve] Total LLM token usage: 0 tokens
> [retrieve] Total LLM token usage: 0 tokens
INFO:llama_index.token_counter.token_counter:> [retrieve] Total embedding token usage: 7 tokens
> [retrieve] Total embedding token usage: 7 tokens
INFO:llama_index.indices.struct_store.sql_query:> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR) and foreign keys: .
> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR), population (INTEGER), country (VARCHAR) and foreign keys: .
INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 337 tokens
> [query] Total LLM token usage: 337 tokens
INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 0 tokens
> [query] Total embedding token usage: 0 tokens
response
Out [ ]
Response(response=' The city with the highest population is Tokyo, with a population of 13,960,000.', source_nodes=[], metadata={'result': [('Tokyo', 13960000)], 'sql_query': 'SELECT city_name, population \nFROM city_stats \nORDER BY population DESC \nLIMIT 1;'})