跳到内容

Sql

NLSQLRetriever #

基类: BaseRetriever, PromptMixin

文本到 SQL 检索器。

通过文本进行检索。

参数

名称 类型 描述 默认值
sql_database SQLDatabase

SQL 数据库。

必需
text_to_sql_prompt BasePromptTemplate

用于文本到 SQL 的提示模板。默认为 DEFAULT_TEXT_TO_SQL_PROMPT。

None
context_query_kwargs dict

从表名到上下文查询的映射。默认为 None。

None
tables Union[List[str], List[Table]]

表名列表或 Table 对象列表。

None
table_retriever ObjectRetriever[SQLTableSchema]

用于 SQLTableSchema 对象的对象检索器。默认为 None。

None
rows_retrievers Dict[str, BaseRetriever]

表名与其行的检索器之间的映射。默认为 None。

None
context_str_prefix str

上下文字符串的前缀。默认为 None。

None
return_raw bool

是否返回 SQL 结果的纯文本转储,或解析为节点。

True
handle_sql_errors bool

是否处理 SQL 错误。默认为 True。

True
sql_only bool

是否只获取 SQL 语句而不获取 SQL 查询结果。默认为 False。

False
llm Optional[LLM]

要使用的语言模型。

None
源码位于 llama-index-core/llama_index/core/indices/struct_store/sql_retriever.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
class NLSQLRetriever(BaseRetriever, PromptMixin):
    """
    Text-to-SQL Retriever.

    Asks the LLM to translate a natural-language query into SQL, using the
    description of the candidate tables as prompt context, and then either
    returns the predicted SQL itself (``sql_only=True``) or runs it through
    an internal ``SQLRetriever``.

    Args:
        sql_database (SQLDatabase): SQL database.
        text_to_sql_prompt (Optional[BasePromptTemplate]): Prompt template for
            text-to-sql. Defaults to DEFAULT_TEXT_TO_SQL_PROMPT.
        context_query_kwargs (Optional[dict]): Mapping from table name to an
            extra context string for that table. Only consulted when
            ``table_retriever`` is not given. Defaults to None.
        tables (Optional[Union[List[str], List[Table]]]): List of table names
            or Table objects. When omitted (and no ``table_retriever`` is
            given) the usable table names reported by ``sql_database`` are
            used. Defaults to None.
        table_retriever (Optional[ObjectRetriever[SQLTableSchema]]): Object
            retriever for SQLTableSchema objects. Takes precedence over
            ``tables``. Defaults to None.
        rows_retrievers (Optional[dict[str, BaseRetriever]]): Mapping between
            table name and a retriever over that table's rows; retrieved rows
            are appended to the table context. Defaults to None.
        context_str_prefix (Optional[str]): Prefix for context string.
            Defaults to None.
        sql_parser_mode (SQLParserMode): Which parser turns the LLM response
            into a SQL string. Defaults to SQLParserMode.DEFAULT.
        llm (Optional[LLM]): Language model to use. Defaults to Settings.llm.
        embed_model (Optional[BaseEmbedding]): Embedding model handed to the
            PGVECTOR parser. Defaults to Settings.embed_model.
        return_raw (bool): Whether to return plain-text dump of SQL results,
            or parsed into Nodes. Defaults to True.
        handle_sql_errors (bool): Whether to turn errors raised while running
            the SQL into an error node instead of propagating them.
            Defaults to True.
        sql_only (bool): Whether to get only sql and not the sql query result.
            Defaults to False.
        callback_manager (Optional[CallbackManager]): Callback manager.
            Defaults to Settings.callback_manager.
        verbose (bool): Whether to print intermediate table context and the
            predicted SQL. Defaults to False.

    """

    def __init__(
        self,
        sql_database: SQLDatabase,
        text_to_sql_prompt: Optional[BasePromptTemplate] = None,
        context_query_kwargs: Optional[dict] = None,
        tables: Optional[Union[List[str], List[Table]]] = None,
        table_retriever: Optional[ObjectRetriever[SQLTableSchema]] = None,
        rows_retrievers: Optional[dict[str, BaseRetriever]] = None,
        context_str_prefix: Optional[str] = None,
        sql_parser_mode: SQLParserMode = SQLParserMode.DEFAULT,
        llm: Optional[LLM] = None,
        embed_model: Optional[BaseEmbedding] = None,
        return_raw: bool = True,
        handle_sql_errors: bool = True,
        sql_only: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        self._sql_retriever = SQLRetriever(sql_database, return_raw=return_raw)
        self._sql_database = sql_database
        self._get_tables = self._load_get_tables_fn(
            sql_database, tables, context_query_kwargs, table_retriever
        )
        self._context_str_prefix = context_str_prefix
        self._llm = llm or Settings.llm
        self._text_to_sql_prompt = text_to_sql_prompt or DEFAULT_TEXT_TO_SQL_PROMPT
        self._sql_parser_mode = sql_parser_mode

        embed_model = embed_model or Settings.embed_model
        self._sql_parser = self._load_sql_parser(sql_parser_mode, embed_model)
        self._handle_sql_errors = handle_sql_errors
        self._sql_only = sql_only
        self._verbose = verbose

        # To retrieve relevant rows from each retrieved table
        self._rows_retrievers = rows_retrievers
        super().__init__(
            callback_manager=callback_manager or Settings.callback_manager,
            verbose=verbose,
        )

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {
            "text_to_sql_prompt": self._text_to_sql_prompt,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "text_to_sql_prompt" in prompts:
            self._text_to_sql_prompt = prompts["text_to_sql_prompt"]

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt modules."""
        return {}

    def _load_sql_parser(
        self, sql_parser_mode: SQLParserMode, embed_model: BaseEmbedding
    ) -> BaseSQLParser:
        """
        Load SQL parser.

        Raises:
            ValueError: If ``sql_parser_mode`` is not a known mode.

        """
        if sql_parser_mode == SQLParserMode.DEFAULT:
            return DefaultSQLParser()
        elif sql_parser_mode == SQLParserMode.PGVECTOR:
            return PGVectorSQLParser(embed_model=embed_model)
        else:
            raise ValueError(f"Unknown SQL parser mode: {sql_parser_mode}")

    def _load_get_tables_fn(
        self,
        sql_database: SQLDatabase,
        tables: Optional[Union[List[str], List[Table]]] = None,
        context_query_kwargs: Optional[dict] = None,
        table_retriever: Optional[ObjectRetriever[SQLTableSchema]] = None,
    ) -> Callable[[str], List[SQLTableSchema]]:
        """
        Load get_tables function.

        Returns a callable mapping a query string to the table schemas to
        describe in the prompt. With a ``table_retriever`` the schemas are
        looked up per query; otherwise a fixed list is built once here.
        """
        # A retriever wins over any static table list.
        if table_retriever is not None:
            return lambda query_str: cast(Any, table_retriever).retrieve(query_str)

        context_query_kwargs = context_query_kwargs or {}
        if tables is not None:
            table_names: List[str] = [
                t.name if isinstance(t, Table) else t for t in tables
            ]
        else:
            table_names = list(sql_database.get_usable_table_names())

        table_schemas = [
            SQLTableSchema(table_name=t, context_str=context_query_kwargs.get(t))
            for t in table_names
        ]
        return lambda _: table_schemas

    @staticmethod
    def _to_query_bundle(str_or_query_bundle: QueryType) -> QueryBundle:
        """Wrap a plain query string in a QueryBundle; pass bundles through."""
        if isinstance(str_or_query_bundle, str):
            return QueryBundle(str_or_query_bundle)
        return str_or_query_bundle

    def _report_table_desc(self, table_desc_str: str) -> None:
        """Log (and print when verbose) the table context sent to the LLM."""
        logger.info("> Table desc str: %s", table_desc_str)
        if self._verbose:
            print(f"> Table desc str: {table_desc_str}")

    def _report_sql_query(self, sql_query_str: str) -> None:
        """Log (and print when verbose) the SQL predicted by the LLM."""
        # assume that it's a valid SQL query
        logger.debug("> Predicted SQL query: %s", sql_query_str)
        if self._verbose:
            print(f"> Predicted SQL query: {sql_query_str}")

    @staticmethod
    def _sql_only_result(
        sql_query_str: str,
    ) -> Tuple[List[NodeWithScore], Dict[str, Any]]:
        """Build the nodes/metadata pair returned when only the SQL is wanted."""
        sql_only_node = TextNode(text=str(sql_query_str))
        return [NodeWithScore(node=sql_only_node)], {"result": sql_query_str}

    @staticmethod
    def _sql_error_result(
        error: Exception,
    ) -> Tuple[List[NodeWithScore], Dict[str, Any]]:
        """Build the nodes/metadata pair that reports a SQL execution error."""
        err_node = TextNode(text=f"Error: {error!s}")
        return [NodeWithScore(node=err_node)], {}

    def retrieve_with_metadata(
        self, str_or_query_bundle: QueryType
    ) -> Tuple[List[NodeWithScore], Dict]:
        """
        Retrieve with metadata.

        Args:
            str_or_query_bundle (QueryType): Natural-language query, either as
                a string or as a QueryBundle.

        Returns:
            The retrieved nodes and a metadata dict that always contains
            ``"sql_query"`` plus whatever the execution step reported.

        Raises:
            Exception: Whatever the SQL retriever raised, when
                ``handle_sql_errors`` is False.

        """
        query_bundle = self._to_query_bundle(str_or_query_bundle)
        table_desc_str = self._get_table_context(query_bundle)
        self._report_table_desc(table_desc_str)

        response_str = self._llm.predict(
            self._text_to_sql_prompt,
            query_str=query_bundle.query_str,
            schema=table_desc_str,
            dialect=self._sql_database.dialect,
        )

        sql_query_str = self._sql_parser.parse_response_to_sql(
            response_str, query_bundle
        )
        self._report_sql_query(sql_query_str)

        if self._sql_only:
            retrieved_nodes, metadata = self._sql_only_result(sql_query_str)
        else:
            try:
                retrieved_nodes, metadata = self._sql_retriever.retrieve_with_metadata(
                    sql_query_str
                )
            # Exception, not BaseException: KeyboardInterrupt / SystemExit
            # must never be converted into an "Error: ..." node.
            except Exception as e:
                if not self._handle_sql_errors:
                    raise
                retrieved_nodes, metadata = self._sql_error_result(e)

        return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

    async def aretrieve_with_metadata(
        self, str_or_query_bundle: QueryType
    ) -> Tuple[List[NodeWithScore], Dict]:
        """
        Async retrieve with metadata.

        Mirrors ``retrieve_with_metadata`` but awaits the LLM prediction and
        the SQL retriever. Building the table context is still a synchronous
        call.

        Args:
            str_or_query_bundle (QueryType): Natural-language query, either as
                a string or as a QueryBundle.

        Returns:
            The retrieved nodes and a metadata dict that always contains
            ``"sql_query"`` plus whatever the execution step reported.

        Raises:
            Exception: Whatever the SQL retriever raised, when
                ``handle_sql_errors`` is False.

        """
        query_bundle = self._to_query_bundle(str_or_query_bundle)
        table_desc_str = self._get_table_context(query_bundle)
        self._report_table_desc(table_desc_str)

        response_str = await self._llm.apredict(
            self._text_to_sql_prompt,
            query_str=query_bundle.query_str,
            schema=table_desc_str,
            dialect=self._sql_database.dialect,
        )

        sql_query_str = self._sql_parser.parse_response_to_sql(
            response_str, query_bundle
        )
        self._report_sql_query(sql_query_str)

        if self._sql_only:
            retrieved_nodes, metadata = self._sql_only_result(sql_query_str)
        else:
            try:
                (
                    retrieved_nodes,
                    metadata,
                ) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str)
            # Exception, not BaseException: asyncio.CancelledError derives
            # from BaseException and has to propagate so the task can cancel.
            except Exception as e:
                if not self._handle_sql_errors:
                    raise
                retrieved_nodes, metadata = self._sql_error_result(e)

        return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        retrieved_nodes, _ = self.retrieve_with_metadata(query_bundle)
        return retrieved_nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async retrieve nodes given query."""
        retrieved_nodes, _ = await self.aretrieve_with_metadata(query_bundle)
        return retrieved_nodes

    def _get_table_context(self, query_bundle: QueryBundle) -> str:
        """
        Get table context string.

        For every candidate table, combines the database's table info, the
        optional table description, and (when ``rows_retrievers`` was given)
        the rows retrieved for this query. Per-table sections are joined with
        a blank line.
        """
        table_schema_objs = self._get_tables(query_bundle.query_str)
        context_strs = []
        for table_schema_obj in table_schema_objs:
            # first append table info + additional context
            table_info = self._sql_database.get_single_table_info(
                table_schema_obj.table_name
            )
            if table_schema_obj.context_str:
                table_info += (
                    " The table description is: " + table_schema_obj.context_str
                )

            # also lookup vector index to return relevant table rows
            # if rows_retrievers was not passed, no rows will be returned
            if self._rows_retrievers is not None:
                rows_retriever = self._rows_retrievers[table_schema_obj.table_name]
                relevant_nodes = rows_retriever.retrieve(query_bundle.query_str)
                if len(relevant_nodes) > 0:
                    table_info += (
                        "\nHere are some relevant example rows (values in the same order as columns above)\n"
                        + "".join(
                            str(node.get_content()) + "\n" for node in relevant_nodes
                        )
                    )

            if self._verbose:
                print(f"> Table Info: {table_info}")
            context_strs.append(table_info)

        return "\n\n".join(context_strs)

retrieve_with_metadata #

retrieve_with_metadata(str_or_query_bundle: QueryType) -> Tuple[List[NodeWithScore], Dict]

使用元数据检索。

源码位于 llama-index-core/llama_index/core/indices/struct_store/sql_retriever.py
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
def retrieve_with_metadata(
    self, str_or_query_bundle: QueryType
) -> Tuple[List[NodeWithScore], Dict]:
    """
    Retrieve with metadata.

    Builds the table context, asks the LLM for a SQL query, and then either
    returns that SQL as a node (``sql_only``) or runs it through the
    underlying SQL retriever.

    Args:
        str_or_query_bundle (QueryType): Natural-language query, either as a
            string or as a QueryBundle.

    Returns:
        The retrieved nodes and a metadata dict that always contains
        ``"sql_query"`` plus whatever the execution step reported.

    Raises:
        Exception: Whatever the SQL retriever raised, when
            ``handle_sql_errors`` is False.

    """
    if isinstance(str_or_query_bundle, str):
        query_bundle = QueryBundle(str_or_query_bundle)
    else:
        query_bundle = str_or_query_bundle
    table_desc_str = self._get_table_context(query_bundle)
    logger.info("> Table desc str: %s", table_desc_str)
    if self._verbose:
        print(f"> Table desc str: {table_desc_str}")

    response_str = self._llm.predict(
        self._text_to_sql_prompt,
        query_str=query_bundle.query_str,
        schema=table_desc_str,
        dialect=self._sql_database.dialect,
    )

    sql_query_str = self._sql_parser.parse_response_to_sql(
        response_str, query_bundle
    )
    # assume that it's a valid SQL query
    logger.debug("> Predicted SQL query: %s", sql_query_str)
    if self._verbose:
        print(f"> Predicted SQL query: {sql_query_str}")

    if self._sql_only:
        sql_only_node = TextNode(text=str(sql_query_str))
        retrieved_nodes = [NodeWithScore(node=sql_only_node)]
        metadata = {"result": sql_query_str}
    else:
        try:
            retrieved_nodes, metadata = self._sql_retriever.retrieve_with_metadata(
                sql_query_str
            )
        # Exception, not BaseException: KeyboardInterrupt / SystemExit must
        # never be converted into an "Error: ..." node.
        except Exception as e:
            # if handle_sql_errors is True, then return error message
            if not self._handle_sql_errors:
                raise
            err_node = TextNode(text=f"Error: {e!s}")
            retrieved_nodes = [NodeWithScore(node=err_node)]
            metadata = {}

    return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

aretrieve_with_metadata async #

aretrieve_with_metadata(str_or_query_bundle: QueryType) -> Tuple[List[NodeWithScore], Dict]

使用元数据进行异步检索。

源码位于 llama-index-core/llama_index/core/indices/struct_store/sql_retriever.py
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
async def aretrieve_with_metadata(
    self, str_or_query_bundle: QueryType
) -> Tuple[List[NodeWithScore], Dict]:
    """
    Async retrieve with metadata.

    Awaits the LLM prediction and the SQL retriever; building the table
    context is still a synchronous call. In ``sql_only`` mode the metadata
    carries the predicted SQL under ``"result"``, matching the synchronous
    ``retrieve_with_metadata``.

    Args:
        str_or_query_bundle (QueryType): Natural-language query, either as a
            string or as a QueryBundle.

    Returns:
        The retrieved nodes and a metadata dict that always contains
        ``"sql_query"`` plus whatever the execution step reported.

    Raises:
        Exception: Whatever the SQL retriever raised, when
            ``handle_sql_errors`` is False.

    """
    if isinstance(str_or_query_bundle, str):
        query_bundle = QueryBundle(str_or_query_bundle)
    else:
        query_bundle = str_or_query_bundle
    table_desc_str = self._get_table_context(query_bundle)
    logger.info("> Table desc str: %s", table_desc_str)
    if self._verbose:
        print(f"> Table desc str: {table_desc_str}")

    response_str = await self._llm.apredict(
        self._text_to_sql_prompt,
        query_str=query_bundle.query_str,
        schema=table_desc_str,
        dialect=self._sql_database.dialect,
    )

    sql_query_str = self._sql_parser.parse_response_to_sql(
        response_str, query_bundle
    )
    # assume that it's a valid SQL query
    logger.debug("> Predicted SQL query: %s", sql_query_str)
    if self._verbose:
        print(f"> Predicted SQL query: {sql_query_str}")

    if self._sql_only:
        sql_only_node = TextNode(text=str(sql_query_str))
        retrieved_nodes = [NodeWithScore(node=sql_only_node)]
        metadata: Dict[str, Any] = {"result": sql_query_str}
    else:
        try:
            (
                retrieved_nodes,
                metadata,
            ) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str)
        # Exception, not BaseException: asyncio.CancelledError derives from
        # BaseException and has to propagate so the task can be cancelled.
        except Exception as e:
            # if handle_sql_errors is True, then return error message
            if not self._handle_sql_errors:
                raise
            err_node = TextNode(text=f"Error: {e!s}")
            retrieved_nodes = [NodeWithScore(node=err_node)]
            metadata = {}
    return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

SQLParserMode #

基类: str, Enum

SQL 解析器模式。

源码位于 llama-index-core/llama_index/core/indices/struct_store/sql_retriever.py
118
119
120
121
122
class SQLParserMode(str, Enum):
    """
    Selects which parser turns an LLM response into a SQL string.

    Members are also plain ``str`` values, so they compare equal to their
    string form.
    """

    # Standard SQL response parsing.
    DEFAULT = "default"
    # Parsing for pgvector-style queries.
    PGVECTOR = "pgvector"

SQLRetriever #

基类: BaseRetriever

SQL 检索器。

通过原始 SQL 语句进行检索。

参数

名称 类型 描述 默认值
sql_database SQLDatabase

SQL 数据库。

必需
return_raw bool

是否返回原始结果或格式化结果。默认为 True。

True
源码位于 llama-index-core/llama_index/core/indices/struct_store/sql_retriever.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class SQLRetriever(BaseRetriever):
    """
    Retriever that executes raw SQL statements.

    The query string is handed to ``sql_database.run_sql`` as-is. The result
    is returned either as one text node holding the raw response or as one
    node per result row.

    Args:
        sql_database (SQLDatabase): SQL database.
        return_raw (bool): Whether to return raw results or format results.
            Defaults to True.
        callback_manager (Optional[CallbackManager]): Passed to the base
            retriever. Defaults to None.

    """

    def __init__(
        self,
        sql_database: SQLDatabase,
        return_raw: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        """Store the database handle and output mode, then init the base."""
        self._sql_database = sql_database
        self._return_raw = return_raw
        super().__init__(callback_manager)

    def _format_node_results(
        self, results: List[List[Any]], col_keys: List[str]
    ) -> List[NodeWithScore]:
        """Turn each result row into a node whose metadata maps column -> value."""
        # NOTE: the text field is intentionally left blank for now; each row
        # lives only in the node metadata.
        return [
            NodeWithScore(
                node=TextNode(
                    text="",
                    metadata=dict(zip(col_keys, row)),
                )
            )
            for row in results
        ]

    def retrieve_with_metadata(
        self, str_or_query_bundle: QueryType
    ) -> Tuple[List[NodeWithScore], Dict]:
        """
        Run the SQL query and return nodes plus the database metadata.

        Args:
            str_or_query_bundle (QueryType): SQL statement, either as a string
                or as a QueryBundle.

        Returns:
            The result nodes and the metadata dict returned by ``run_sql``.

        """
        query_bundle = (
            QueryBundle(str_or_query_bundle)
            if isinstance(str_or_query_bundle, str)
            else str_or_query_bundle
        )
        sql_text = query_bundle.query_str
        raw_response_str, metadata = self._sql_database.run_sql(sql_text)

        if not self._return_raw:
            # formatted output: one node per row
            rows = metadata["result"]
            columns = metadata["col_keys"]
            return self._format_node_results(rows, columns), metadata

        # raw output: a single node carrying the textual response; its
        # metadata keys are excluded from both embedding and LLM views
        raw_node = TextNode(
            text=raw_response_str,
            metadata={
                "sql_query": sql_text,
                "result": metadata["result"],
                "col_keys": metadata["col_keys"],
            },
            excluded_embed_metadata_keys=["sql_query", "result", "col_keys"],
            excluded_llm_metadata_keys=["sql_query", "result", "col_keys"],
        )
        return [NodeWithScore(node=raw_node)], metadata

    async def aretrieve_with_metadata(
        self, str_or_query_bundle: QueryType
    ) -> Tuple[List[NodeWithScore], Dict]:
        """Async entry point; delegates directly to ``retrieve_with_metadata``."""
        return self.retrieve_with_metadata(str_or_query_bundle)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Return only the nodes from ``retrieve_with_metadata``."""
        nodes, _unused_metadata = self.retrieve_with_metadata(query_bundle)
        return nodes

retrieve_with_metadata #

retrieve_with_metadata(str_or_query_bundle: QueryType) -> Tuple[List[NodeWithScore], Dict]

使用元数据检索。

源码位于 llama-index-core/llama_index/core/indices/struct_store/sql_retriever.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def retrieve_with_metadata(
    self, str_or_query_bundle: QueryType
) -> Tuple[List[NodeWithScore], Dict]:
    """
    Run the SQL query and return nodes plus the database metadata.

    Args:
        str_or_query_bundle (QueryType): SQL statement, either as a string or
            as a QueryBundle.

    Returns:
        The result nodes and the metadata dict returned by ``run_sql``.

    """
    query_bundle = (
        QueryBundle(str_or_query_bundle)
        if isinstance(str_or_query_bundle, str)
        else str_or_query_bundle
    )
    sql_text = query_bundle.query_str
    raw_response_str, metadata = self._sql_database.run_sql(sql_text)

    if not self._return_raw:
        # formatted output: one node per row
        rows = metadata["result"]
        columns = metadata["col_keys"]
        return self._format_node_results(rows, columns), metadata

    # raw output: a single node carrying the textual response; its metadata
    # keys are excluded from both embedding and LLM views
    raw_node = TextNode(
        text=raw_response_str,
        metadata={
            "sql_query": sql_text,
            "result": metadata["result"],
            "col_keys": metadata["col_keys"],
        },
        excluded_embed_metadata_keys=["sql_query", "result", "col_keys"],
        excluded_llm_metadata_keys=["sql_query", "result", "col_keys"],
    )
    return [NodeWithScore(node=raw_node)], metadata