推断、检索、重排包。
源代码位于 llama-index-packs/llama-index-packs-infer-retrieve-rerank/llama_index/packs/infer_retrieve_rerank/base.py
get_modules
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
class InferRetrieveRerankPack(BaseLlamaPack):
    """Infer Retrieve Rerank pack.

    At construction time every label is wrapped in a ``TextNode``, embedded
    with ``OpenAIEmbedding`` through an ``IngestionPipeline`` and stored in a
    ``VectorStoreIndex``. ``run`` then hands each input, together with a
    retriever over that index, to ``infer_retrieve_rerank``.
    """

    def __init__(
        self,
        labels: List[str],
        llm: Optional[LLM] = None,
        pred_context: str = "",
        reranker_top_n: int = 3,
        infer_prompt: Optional[PromptTemplate] = None,
        rerank_prompt: Optional[PromptTemplate] = None,
        verbose: bool = False,
    ) -> None:
        """Build the label index and store the pack configuration.

        Args:
            labels: Label strings; each one becomes a node in the label index.
            llm: LLM forwarded to ``infer_retrieve_rerank``. Defaults to
                ``OpenAI(model="gpt-3.5-turbo-16k")`` when omitted.
            pred_context: String forwarded unchanged to
                ``infer_retrieve_rerank`` on every call.
            reranker_top_n: Forwarded as ``reranker_top_n`` to
                ``infer_retrieve_rerank``.
            infer_prompt: Prompt forwarded to ``infer_retrieve_rerank``;
                falls back to ``INFER_PROMPT_TMPL``.
            rerank_prompt: Prompt forwarded to ``infer_retrieve_rerank``;
                falls back to ``RERANK_PROMPT_TMPL``.
            verbose: When true, ``run`` prints progress and index
                construction shows a progress indicator.
        """
        # NOTE: we use 16k model by default to fit longer contexts
        self.llm = llm or OpenAI(model="gpt-3.5-turbo-16k")

        # Embed the labels once up front so retrieval only has to embed queries.
        label_nodes = [TextNode(text=label) for label in labels]
        pipeline = IngestionPipeline(transformations=[OpenAIEmbedding()])
        label_nodes_w_embed = pipeline.run(documents=label_nodes)
        index = VectorStoreIndex(label_nodes_w_embed, show_progress=verbose)
        self.label_retriever = index.as_retriever(similarity_top_k=2)

        self.pred_context = pred_context
        self.reranker_top_n = reranker_top_n
        self.verbose = verbose
        self.infer_prompt = infer_prompt or INFER_PROMPT_TMPL
        self.rerank_prompt = rerank_prompt or RERANK_PROMPT_TMPL

    def get_modules(self) -> Dict[str, Any]:
        """Get modules.

        Returns:
            A dict with the pack's LLM under ``"llm"`` and its label
            retriever under ``"label_retriever"``.
        """
        return {
            "llm": self.llm,
            "label_retriever": self.label_retriever,
        }

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Run the pipeline over a batch of inputs.

        Args:
            *args: If ``inputs`` is not supplied as a keyword, the first
                positional argument is used as the inputs.
            **kwargs: ``inputs`` is the iterable of inputs to process. Each
                item is sliced with ``[:300]`` when ``verbose`` is on, so
                items are presumably strings.

        Returns:
            A list holding, for each input in order, whatever
            ``infer_retrieve_rerank`` returned for it. Empty when no inputs
            were given.
        """
        # Accept both run(inputs=[...]) and run([...]); previously a
        # positional call was silently ignored and produced [].
        if "inputs" in kwargs:
            inputs = kwargs["inputs"]
        elif args:
            inputs = args[0]
        else:
            inputs = []

        pred_reactions = []
        for idx, text in enumerate(inputs):
            if self.verbose:
                print(f"\n\n> Generating predictions for input {idx}: {text[:300]}")
            cur_pred_reactions = infer_retrieve_rerank(
                text,
                self.label_retriever,
                self.llm,
                self.pred_context,
                self.infer_prompt,
                self.rerank_prompt,
                reranker_top_n=self.reranker_top_n,
            )
            if self.verbose:
                print(f"> Generated predictions: {cur_pred_reactions}")
            pred_reactions.append(cur_pred_reactions)
        return pred_reactions
|
获取模块。
get_modules() -> Dict[str, Any]
run
get_modules
def get_modules(self) -> Dict[str, Any]:
    """Expose the pack's components by name.

    Returns:
        A dict with the pack's LLM under ``"llm"`` and its label retriever
        under ``"label_retriever"``.
    """
    modules: Dict[str, Any] = {}
    modules["llm"] = self.llm
    modules["label_retriever"] = self.label_retriever
    return modules
|
运行管道。
run(*args: Any, **kwargs: Any) -> Any
回到顶部
get_modules
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def run(self, *args: Any, **kwargs: Any) -> Any:
    """Run the pipeline over a batch of inputs.

    Args:
        *args: If ``inputs`` is not supplied as a keyword, the first
            positional argument is used as the inputs.
        **kwargs: ``inputs`` is the iterable of inputs to process. Each item
            is sliced with ``[:300]`` when ``verbose`` is on, so items are
            presumably strings.

    Returns:
        A list holding, for each input in order, whatever
        ``infer_retrieve_rerank`` returned for it. Empty when no inputs were
        given.
    """
    # Accept both run(inputs=[...]) and run([...]); previously a positional
    # call was silently ignored and produced [].
    if "inputs" in kwargs:
        inputs = kwargs["inputs"]
    elif args:
        inputs = args[0]
    else:
        inputs = []

    pred_reactions = []
    for idx, text in enumerate(inputs):
        if self.verbose:
            print(f"\n\n> Generating predictions for input {idx}: {text[:300]}")
        cur_pred_reactions = infer_retrieve_rerank(
            text,
            self.label_retriever,
            self.llm,
            self.pred_context,
            self.infer_prompt,
            self.rerank_prompt,
            reranker_top_n=self.reranker_top_n,
        )
        if self.verbose:
            print(f"> Generated predictions: {cur_pred_reactions}")
        pred_reactions.append(cur_pred_reactions)
    return pred_reactions
|