Retrieval

Evaluation modules.

BaseRetrievalEvaluator #

Bases: `BaseModel`

Base Retrieval Evaluator class.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metrics` | `List[BaseRetrievalMetric]` | List of metrics to evaluate | *required* |

Source code in `llama-index-core/llama_index/core/evaluation/retrieval/base.py`:

```python
class BaseRetrievalEvaluator(BaseModel):
    """Base Retrieval Evaluator class."""

    model_config = ConfigDict(arbitrary_types_allowed=True)
    metrics: List[BaseRetrievalMetric] = Field(
        ..., description="List of metrics to evaluate"
    )

    @classmethod
    def from_metric_names(
        cls, metric_names: List[str], **kwargs: Any
    ) -> "BaseRetrievalEvaluator":
        """
        Create evaluator from metric names.

        Args:
            metric_names (List[str]): List of metric names
            **kwargs: Additional arguments for the evaluator

        """
        metric_types = resolve_metrics(metric_names)
        return cls(metrics=[metric() for metric in metric_types], **kwargs)

    @abstractmethod
    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Get retrieved ids and texts."""
        raise NotImplementedError

    def evaluate(
        self,
        query: str,
        expected_ids: List[str],
        expected_texts: Optional[List[str]] = None,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        **kwargs: Any,
    ) -> RetrievalEvalResult:
        """
        Run evaluation results with query string and expected ids.

        Args:
            query (str): Query string
            expected_ids (List[str]): Expected ids

        Returns:
            RetrievalEvalResult: Evaluation result

        """
        return asyncio_run(
            self.aevaluate(
                query=query,
                expected_ids=expected_ids,
                expected_texts=expected_texts,
                mode=mode,
                **kwargs,
            )
        )

    # @abstractmethod
    async def aevaluate(
        self,
        query: str,
        expected_ids: List[str],
        expected_texts: Optional[List[str]] = None,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        **kwargs: Any,
    ) -> RetrievalEvalResult:
        """
        Run evaluation with query string, retrieved contexts,
        and generated response string.

        Subclasses can override this method to provide custom evaluation logic and
        take in additional arguments.
        """
        retrieved_ids, retrieved_texts = await self._aget_retrieved_ids_and_texts(
            query, mode
        )
        metric_dict = {}
        for metric in self.metrics:
            eval_result = metric.compute(
                query, expected_ids, retrieved_ids, expected_texts, retrieved_texts
            )
            metric_dict[metric.metric_name] = eval_result

        return RetrievalEvalResult(
            query=query,
            expected_ids=expected_ids,
            expected_texts=expected_texts,
            retrieved_ids=retrieved_ids,
            retrieved_texts=retrieved_texts,
            mode=mode,
            metric_dict=metric_dict,
        )

    async def aevaluate_dataset(
        self,
        dataset: EmbeddingQAFinetuneDataset,
        workers: int = 2,
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[RetrievalEvalResult]:
        """Run evaluation with dataset."""
        semaphore = asyncio.Semaphore(workers)

        async def eval_worker(
            query: str, expected_ids: List[str], mode: RetrievalEvalMode
        ) -> RetrievalEvalResult:
            async with semaphore:
                return await self.aevaluate(query, expected_ids=expected_ids, mode=mode)

        response_jobs = []
        mode = RetrievalEvalMode.from_str(dataset.mode)
        for query_id, query in dataset.queries.items():
            expected_ids = dataset.relevant_docs[query_id]
            response_jobs.append(eval_worker(query, expected_ids, mode))
        if show_progress:
            from tqdm.asyncio import tqdm_asyncio

            eval_results = await tqdm_asyncio.gather(*response_jobs)
        else:
            eval_results = await asyncio.gather(*response_jobs)

        return eval_results
```
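
To use the base class directly, a subclass only needs to implement `_aget_retrieved_ids_and_texts`. Below is a minimal illustrative sketch, not part of the library: `KeywordEvaluator` and its `corpus` field are hypothetical names, and "retrieval" is a naive substring match over an in-memory corpus. The imports follow the module path shown above.

```python
from typing import Dict, List, Tuple

from llama_index.core.evaluation.retrieval.base import (
    BaseRetrievalEvaluator,
    RetrievalEvalMode,
)


class KeywordEvaluator(BaseRetrievalEvaluator):
    """Hypothetical evaluator that 'retrieves' by naive substring match."""

    corpus: Dict[str, str]  # hypothetical field: node_id -> node text

    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        # Return every corpus entry whose text contains the query term.
        hits = {
            node_id: text
            for node_id, text in self.corpus.items()
            if query.lower() in text.lower()
        }
        return list(hits.keys()), list(hits.values())
```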

from_metric_names classmethod #

```python
from_metric_names(metric_names: List[str], **kwargs: Any) -> BaseRetrievalEvaluator
```

Create an evaluator from metric names.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metric_names` | `List[str]` | List of metric names | *required* |
| `**kwargs` | `Any` | Additional arguments for the evaluator | `{}` |

Source code in `llama-index-core/llama_index/core/evaluation/retrieval/base.py`:

```python
@classmethod
def from_metric_names(
    cls, metric_names: List[str], **kwargs: Any
) -> "BaseRetrievalEvaluator":
    """
    Create evaluator from metric names.

    Args:
        metric_names (List[str]): List of metric names
        **kwargs: Additional arguments for the evaluator

    """
    metric_types = resolve_metrics(metric_names)
    return cls(metrics=[metric() for metric in metric_types], **kwargs)
```
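
Continuing the hypothetical `KeywordEvaluator` sketch above: `from_metric_names` resolves each name via `resolve_metrics`, instantiates the metric classes, and forwards `**kwargs` (here, the hypothetical `corpus` field) to the evaluator constructor. `"hit_rate"` and `"mrr"` are metric names registered in recent llama-index-core versions.

```python
evaluator = KeywordEvaluator.from_metric_names(
    ["hit_rate", "mrr"],  # resolved to metric classes via resolve_metrics()
    corpus={
        "n1": "Llamas live in the Andes.",
        "n2": "Pandas eat bamboo.",
    },
)
```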

evaluate #

```python
evaluate(query: str, expected_ids: List[str], expected_texts: Optional[List[str]] = None, mode: RetrievalEvalMode = TEXT, **kwargs: Any) -> RetrievalEvalResult
```

Run evaluation with a query string and expected IDs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `query` | `str` | Query string | *required* |
| `expected_ids` | `List[str]` | Expected IDs | *required* |

Returns:

| Type | Description |
| --- | --- |
| `RetrievalEvalResult` | Evaluation result |

Source code in `llama-index-core/llama_index/core/evaluation/retrieval/base.py`:

```python
def evaluate(
    self,
    query: str,
    expected_ids: List[str],
    expected_texts: Optional[List[str]] = None,
    mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
    **kwargs: Any,
) -> RetrievalEvalResult:
    """
    Run evaluation results with query string and expected ids.

    Args:
        query (str): Query string
        expected_ids (List[str]): Expected ids

    Returns:
        RetrievalEvalResult: Evaluation result

    """
    return asyncio_run(
        self.aevaluate(
            query=query,
            expected_ids=expected_ids,
            expected_texts=expected_texts,
            mode=mode,
            **kwargs,
        )
    )
```
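
`evaluate` simply wraps `aevaluate` in `asyncio_run`, so it can be called from plain synchronous code. A sketch continuing the hypothetical example above:

```python
result = evaluator.evaluate(
    query="llamas",
    expected_ids=["n1"],  # the node ids we expect to be retrieved
)
print(result.metric_vals_dict)  # e.g. {"hit_rate": 1.0, "mrr": 1.0}
```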

aevaluate async #

```python
aevaluate(query: str, expected_ids: List[str], expected_texts: Optional[List[str]] = None, mode: RetrievalEvalMode = TEXT, **kwargs: Any) -> RetrievalEvalResult
```

Run evaluation with a query string, retrieved contexts, and a generated response string.

Subclasses can override this method to provide custom evaluation logic and take in additional arguments.

Source code in `llama-index-core/llama_index/core/evaluation/retrieval/base.py`:

```python
async def aevaluate(
    self,
    query: str,
    expected_ids: List[str],
    expected_texts: Optional[List[str]] = None,
    mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
    **kwargs: Any,
) -> RetrievalEvalResult:
    """
    Run evaluation with query string, retrieved contexts,
    and generated response string.

    Subclasses can override this method to provide custom evaluation logic and
    take in additional arguments.
    """
    retrieved_ids, retrieved_texts = await self._aget_retrieved_ids_and_texts(
        query, mode
    )
    metric_dict = {}
    for metric in self.metrics:
        eval_result = metric.compute(
            query, expected_ids, retrieved_ids, expected_texts, retrieved_texts
        )
        metric_dict[metric.metric_name] = eval_result

    return RetrievalEvalResult(
        query=query,
        expected_ids=expected_ids,
        expected_texts=expected_texts,
        retrieved_ids=retrieved_ids,
        retrieved_texts=retrieved_texts,
        mode=mode,
        metric_dict=metric_dict,
    )
```
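
Inside an already-running event loop (a notebook or an async service), await `aevaluate` directly rather than going through the synchronous wrapper. A sketch reusing the hypothetical evaluator from above:

```python
async def run_eval() -> None:
    # Await the coroutine directly; no asyncio_run wrapper is needed here.
    result = await evaluator.aevaluate(query="llamas", expected_ids=["n1"])
    print(result.metric_dict["hit_rate"].score)
```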

aevaluate_dataset async #

```python
aevaluate_dataset(dataset: EmbeddingQAFinetuneDataset, workers: int = 2, show_progress: bool = False, **kwargs: Any) -> List[RetrievalEvalResult]
```

Run evaluation with a dataset.

Source code in `llama-index-core/llama_index/core/evaluation/retrieval/base.py`:

```python
async def aevaluate_dataset(
    self,
    dataset: EmbeddingQAFinetuneDataset,
    workers: int = 2,
    show_progress: bool = False,
    **kwargs: Any,
) -> List[RetrievalEvalResult]:
    """Run evaluation with dataset."""
    semaphore = asyncio.Semaphore(workers)

    async def eval_worker(
        query: str, expected_ids: List[str], mode: RetrievalEvalMode
    ) -> RetrievalEvalResult:
        async with semaphore:
            return await self.aevaluate(query, expected_ids=expected_ids, mode=mode)

    response_jobs = []
    mode = RetrievalEvalMode.from_str(dataset.mode)
    for query_id, query in dataset.queries.items():
        expected_ids = dataset.relevant_docs[query_id]
        response_jobs.append(eval_worker(query, expected_ids, mode))
    if show_progress:
        from tqdm.asyncio import tqdm_asyncio

        eval_results = await tqdm_asyncio.gather(*response_jobs)
    else:
        eval_results = await asyncio.gather(*response_jobs)

    return eval_results
```
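
A sketch of a dataset-level run, assuming the `EmbeddingQAFinetuneDataset` import path used by recent llama-index-core. The two-query dataset below is hypothetical and reuses the `KeywordEvaluator` from earlier; `workers` bounds concurrency through the internal `asyncio.Semaphore`.

```python
import asyncio

from llama_index.core.llama_dataset.legacy.embedding import (
    EmbeddingQAFinetuneDataset,
)

dataset = EmbeddingQAFinetuneDataset(
    queries={"q1": "llamas", "q2": "pandas"},
    corpus={"n1": "Llamas live in the Andes.", "n2": "Pandas eat bamboo."},
    relevant_docs={"q1": ["n1"], "q2": ["n2"]},
)

eval_results = asyncio.run(
    evaluator.aevaluate_dataset(dataset, workers=4, show_progress=True)
)
```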

RetrieverEvaluator #

Bases: `BaseRetrievalEvaluator`

Retriever evaluator.

This module will evaluate a retriever using a set of metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metrics` | `List[BaseRetrievalMetric]` | Sequence of metrics to evaluate | *required* |
| `retriever` | `BaseRetriever` | Retriever to evaluate | *required* |
| `node_postprocessors` | `Optional[List[BaseNodePostprocessor]]` | Post-processors to apply after retrieval | `None` |

Source code in `llama-index-core/llama_index/core/evaluation/retrieval/evaluator.py`:

```python
class RetrieverEvaluator(BaseRetrievalEvaluator):
    """
    Retriever evaluator.

    This module will evaluate a retriever using a set of metrics.

    Args:
        metrics (List[BaseRetrievalMetric]): Sequence of metrics to evaluate
        retriever: Retriever to evaluate.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): Post-processor to apply after retrieval.


    """

    retriever: BaseRetriever = Field(..., description="Retriever to evaluate")
    node_postprocessors: Optional[List[SerializeAsAny[BaseNodePostprocessor]]] = Field(
        default=None, description="Optional post-processor"
    )

    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Get retrieved ids and texts, potentially applying a post-processor."""
        retrieved_nodes = await self.retriever.aretrieve(query)

        if self.node_postprocessors:
            for node_postprocessor in self.node_postprocessors:
                retrieved_nodes = node_postprocessor.postprocess_nodes(
                    retrieved_nodes, query_str=query
                )

        return (
            [node.node.node_id for node in retrieved_nodes],
            [node.text for node in retrieved_nodes],
        )
```
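
An end-to-end sketch with a real retriever. It assumes an embedding model is configured (`VectorStoreIndex` defaults to OpenAI embeddings, so an API key would be needed); `SimilarityPostprocessor` is one stock post-processor that drops low-scoring nodes before the metrics are computed.

```python
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.evaluation import RetrieverEvaluator
from llama_index.core.postprocessor import SimilarityPostprocessor

index = VectorStoreIndex.from_documents(
    [Document(text="Llamas live in the Andes mountains of South America.")]
)

evaluator = RetrieverEvaluator.from_metric_names(
    ["hit_rate", "mrr"],
    retriever=index.as_retriever(similarity_top_k=2),
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.2)],
)
```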

RetrievalEvalResult #

Bases: `BaseModel`

Retrieval eval result.

NOTE: this abstraction might change in the future.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `query` | `str` | Query string | *required* |
| `expected_ids` | `List[str]` | Expected IDs | *required* |
| `expected_texts` | `Optional[List[str]]` | Expected texts associated with nodes provided in `expected_ids` | `None` |
| `retrieved_ids` | `List[str]` | Retrieved IDs | *required* |
| `retrieved_texts` | `List[str]` | Retrieved texts | *required* |
| `mode` | `RetrievalEvalMode` | text or image | `RetrievalEvalMode.TEXT` |
| `metric_dict` | `Dict[str, RetrievalMetricResult]` | Metric dictionary for the evaluation | *required* |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `query` | `str` | Query string |
| `expected_ids` | `List[str]` | Expected IDs |
| `retrieved_ids` | `List[str]` | Retrieved IDs |
| `metric_dict` | `Dict[str, BaseRetrievalMetric]` | Metric dictionary for the evaluation |

Source code in `llama-index-core/llama_index/core/evaluation/retrieval/base.py`:

```python
class RetrievalEvalResult(BaseModel):
    """
    Retrieval eval result.

    NOTE: this abstraction might change in the future.

    Attributes:
        query (str): Query string
        expected_ids (List[str]): Expected ids
        retrieved_ids (List[str]): Retrieved ids
        metric_dict (Dict[str, BaseRetrievalMetric]): \
            Metric dictionary for the evaluation

    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    query: str = Field(..., description="Query string")
    expected_ids: List[str] = Field(..., description="Expected ids")
    expected_texts: Optional[List[str]] = Field(
        default=None,
        description="Expected texts associated with nodes provided in `expected_ids`",
    )
    retrieved_ids: List[str] = Field(..., description="Retrieved ids")
    retrieved_texts: List[str] = Field(..., description="Retrieved texts")
    mode: "RetrievalEvalMode" = Field(
        default=RetrievalEvalMode.TEXT, description="text or image"
    )
    metric_dict: Dict[str, RetrievalMetricResult] = Field(
        ..., description="Metric dictionary for the evaluation"
    )

    @property
    def metric_vals_dict(self) -> Dict[str, float]:
        """Dictionary of metric values."""
        return {k: v.score for k, v in self.metric_dict.items()}

    def __str__(self) -> str:
        """String representation."""
        return f"Query: {self.query}\n" f"Metrics: {self.metric_vals_dict!s}\n"

metric_vals_dict property #

```python
metric_vals_dict: Dict[str, float]
```

Dictionary of metric values.
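
Because `metric_vals_dict` flattens each result to plain floats, aggregating a dataset run takes only a small helper. A hypothetical sketch (`mean_metrics` is not a library function):

```python
from collections import defaultdict
from typing import Dict, List

from llama_index.core.evaluation.retrieval.base import RetrievalEvalResult


def mean_metrics(eval_results: List[RetrievalEvalResult]) -> Dict[str, float]:
    """Average each metric's score across evaluation results."""
    if not eval_results:
        return {}
    totals: Dict[str, float] = defaultdict(float)
    for result in eval_results:
        for name, score in result.metric_vals_dict.items():
            totals[name] += score
    return {name: total / len(eval_results) for name, total in totals.items()}
```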