import asyncio
from typing import Dict, List, Optional, Tuple, cast

from llama_index.core.async_utils import run_async_tasks
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.llms.utils import LLMType, resolve_llm
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.mixin import PromptDictType
from llama_index.core.schema import IndexNode, NodeWithScore, QueryBundle
from llama_index.core.settings import Settings

# NOTE: the imports above assume the llama_index.core package layout.
# FUSION_MODES and QUERY_GEN_PROMPT are defined alongside this class in its
# home module and are assumed to be in scope here.


class QueryFusionRetriever(BaseRetriever):
    def __init__(
        self,
        retrievers: List[BaseRetriever],
        llm: Optional[LLMType] = None,
        query_gen_prompt: Optional[str] = None,
        mode: FUSION_MODES = FUSION_MODES.SIMPLE,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        num_queries: int = 4,
        use_async: bool = True,
        verbose: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        retriever_weights: Optional[List[float]] = None,
    ) -> None:
        self.num_queries = num_queries
        self.query_gen_prompt = query_gen_prompt or QUERY_GEN_PROMPT
        self.similarity_top_k = similarity_top_k
        self.mode = mode
        self.use_async = use_async
        self._retrievers = retrievers
        if retriever_weights is None:
            self._retriever_weights = [1.0 / len(retrievers)] * len(retrievers)
        else:
            # Sum of retriever_weights must be 1, so normalize.
            total_weight = sum(retriever_weights)
            self._retriever_weights = [w / total_weight for w in retriever_weights]
        self._llm = (
            resolve_llm(llm, callback_manager=callback_manager)
            if llm
            else Settings.llm
        )
        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"query_gen_prompt": PromptTemplate(self.query_gen_prompt)}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "query_gen_prompt" in prompts:
            self.query_gen_prompt = cast(
                PromptTemplate, prompts["query_gen_prompt"]
            ).template

    def _get_queries(self, original_query: str) -> List[QueryBundle]:
        prompt_str = self.query_gen_prompt.format(
            num_queries=self.num_queries - 1,
            query=original_query,
        )
        response = self._llm.complete(prompt_str)

        # assume the LLM properly put each query on a newline
        queries = response.text.split("\n")
        queries = [q.strip() for q in queries if q.strip()]
        if self._verbose:
            queries_str = "\n".join(queries)
            print(f"Generated queries:\n{queries_str}")

        # The LLM often returns more queries than we asked for, so trim the list.
        return [QueryBundle(q) for q in queries[: self.num_queries - 1]]

    def _reciprocal_rerank_fusion(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """
        Apply reciprocal rank fusion.

        The original paper uses k=60 for best results:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
        """
        k = 60.0  # `k` is a parameter used to control the impact of outlier rankings.
        fused_scores = {}
        hash_to_node = {}

        # compute reciprocal rank scores
        for nodes_with_scores in results.values():
            for rank, node_with_score in enumerate(
                sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True)
            ):
                hash = node_with_score.node.hash
                hash_to_node[hash] = node_with_score
                if hash not in fused_scores:
                    fused_scores[hash] = 0.0
                fused_scores[hash] += 1.0 / (rank + k)

        # sort results
        reranked_results = dict(
            sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
        )

        # adjust node scores
        reranked_nodes: List[NodeWithScore] = []
        for hash, score in reranked_results.items():
            reranked_nodes.append(hash_to_node[hash])
            reranked_nodes[-1].score = score

        return reranked_nodes

    def _relative_score_fusion(
        self,
        results: Dict[Tuple[str, int], List[NodeWithScore]],
        dist_based: Optional[bool] = False,
    ) -> List[NodeWithScore]:
        """Apply relative score fusion."""
        # MinMax scale scores of each result set (highest value becomes 1,
        # lowest becomes 0), then scale by the weight of the retriever
        min_max_scores = {}
        for query_tuple, nodes_with_scores in results.items():
            if not nodes_with_scores:
                min_max_scores[query_tuple] = (0.0, 0.0)
                continue
            scores = [
                node_with_score.score or 0.0 for node_with_score in nodes_with_scores
            ]
            if dist_based:
                # Set min and max based on mean and std dev
                mean_score = sum(scores) / len(scores)
                std_dev = (
                    sum((x - mean_score) ** 2 for x in scores) / len(scores)
                ) ** 0.5
                min_score = mean_score - 3 * std_dev
                max_score = mean_score + 3 * std_dev
            else:
                min_score = min(scores)
                max_score = max(scores)
            min_max_scores[query_tuple] = (min_score, max_score)

        for query_tuple, nodes_with_scores in results.items():
            for node_with_score in nodes_with_scores:
                min_score, max_score = min_max_scores[query_tuple]
                # Scale the score to be between 0 and 1
                if max_score == min_score:
                    node_with_score.score = 1.0 if max_score > 0 else 0.0
                else:
                    node_with_score.score = (
                        (node_with_score.score or 0.0) - min_score
                    ) / (max_score - min_score)
                # Scale by the weight of the retriever
                retriever_idx = query_tuple[1]
                existing_score = node_with_score.score or 0.0
                node_with_score.score = (
                    existing_score * self._retriever_weights[retriever_idx]
                )
                # Divide by the number of queries
                node_with_score.score /= self.num_queries

        # Use a dict to de-duplicate nodes
        all_nodes: Dict[str, NodeWithScore] = {}

        # Sum scores for each node
        for nodes_with_scores in results.values():
            for node_with_score in nodes_with_scores:
                hash = node_with_score.node.hash
                if hash in all_nodes:
                    cur_score = all_nodes[hash].score or 0.0
                    all_nodes[hash].score = cur_score + (node_with_score.score or 0.0)
                else:
                    all_nodes[hash] = node_with_score

        return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)

    def _simple_fusion(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Apply simple fusion."""
        # Use a dict to de-duplicate nodes
        all_nodes: Dict[str, NodeWithScore] = {}
        for nodes_with_scores in results.values():
            for node_with_score in nodes_with_scores:
                hash = node_with_score.node.hash
                if hash in all_nodes:
                    max_score = max(
                        node_with_score.score or 0.0, all_nodes[hash].score or 0.0
                    )
                    all_nodes[hash].score = max_score
                else:
                    all_nodes[hash] = node_with_score

        return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)

    def _run_nested_async_queries(
        self, queries: List[QueryBundle]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        tasks, task_queries = [], []
        for query in queries:
            for i, retriever in enumerate(self._retrievers):
                tasks.append(retriever.aretrieve(query))
                task_queries.append((query.query_str, i))

        task_results = run_async_tasks(tasks)

        results = {}
        for query_tuple, query_result in zip(task_queries, task_results):
            results[query_tuple] = query_result

        return results

    async def _run_async_queries(
        self, queries: List[QueryBundle]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        tasks, task_queries = [], []
        for query in queries:
            for i, retriever in enumerate(self._retrievers):
                tasks.append(retriever.aretrieve(query))
                task_queries.append((query.query_str, i))

        task_results = await asyncio.gather(*tasks)

        results = {}
        for query_tuple, query_result in zip(task_queries, task_results):
            results[query_tuple] = query_result

        return results

    def _run_sync_queries(
        self, queries: List[QueryBundle]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        results = {}
        for query in queries:
            for i, retriever in enumerate(self._retrievers):
                results[(query.query_str, i)] = retriever.retrieve(query)

        return results

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        queries: List[QueryBundle] = [query_bundle]
        if self.num_queries > 1:
            queries.extend(self._get_queries(query_bundle.query_str))

        if self.use_async:
            results = self._run_nested_async_queries(queries)
        else:
            results = self._run_sync_queries(queries)

        if self.mode == FUSION_MODES.RECIPROCAL_RANK:
            return self._reciprocal_rerank_fusion(results)[: self.similarity_top_k]
        elif self.mode == FUSION_MODES.RELATIVE_SCORE:
            return self._relative_score_fusion(results)[: self.similarity_top_k]
        elif self.mode == FUSION_MODES.DIST_BASED_SCORE:
            return self._relative_score_fusion(results, dist_based=True)[
                : self.similarity_top_k
            ]
        elif self.mode == FUSION_MODES.SIMPLE:
            return self._simple_fusion(results)[: self.similarity_top_k]
        else:
            raise ValueError(f"Invalid fusion mode: {self.mode}")

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        queries: List[QueryBundle] = [query_bundle]
        if self.num_queries > 1:
            queries.extend(self._get_queries(query_bundle.query_str))

        results = await self._run_async_queries(queries)

        if self.mode == FUSION_MODES.RECIPROCAL_RANK:
            return self._reciprocal_rerank_fusion(results)[: self.similarity_top_k]
        elif self.mode == FUSION_MODES.RELATIVE_SCORE:
            return self._relative_score_fusion(results)[: self.similarity_top_k]
        elif self.mode == FUSION_MODES.DIST_BASED_SCORE:
            return self._relative_score_fusion(results, dist_based=True)[
                : self.similarity_top_k
            ]
        elif self.mode == FUSION_MODES.SIMPLE:
            return self._simple_fusion(results)[: self.similarity_top_k]
        else:
            raise ValueError(f"Invalid fusion mode: {self.mode}")
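

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `vector_retriever` and `bm25_retriever`
# below are hypothetical stand-ins for retrievers the caller has already
# built; the constructor call mirrors the signature defined above.
#
#   fusion_retriever = QueryFusionRetriever(
#       retrievers=[vector_retriever, bm25_retriever],
#       retriever_weights=[0.6, 0.4],  # normalized to sum to 1 in __init__
#       mode=FUSION_MODES.RELATIVE_SCORE,
#       similarity_top_k=5,
#       num_queries=4,  # 1 original + 3 generated; set to 1 to skip generation
#   )
#   nodes = fusion_retriever.retrieve("What is reciprocal rank fusion?")
#
# The demo below is a minimal, self-contained sketch of the reciprocal-rank
# math in `_reciprocal_rerank_fusion`, applied to two toy ranked lists of
# made-up document IDs.
if __name__ == "__main__":

    def toy_rrf(
        rankings: List[List[str]], k: float = 60.0
    ) -> List[Tuple[str, float]]:
        scores: Dict[str, float] = {}
        for ranked_ids in rankings:
            for rank, doc_id in enumerate(ranked_ids):
                # Same formula as above: each appearance adds 1 / (rank + k).
                scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (rank + k)
        return sorted(scores.items(), key=lambda x: x[1], reverse=True)

    # "b" sits near the top of both lists, so it fuses highest:
    # b > a > c > d
    print(toy_rrf([["a", "b", "c"], ["b", "c", "a", "d"]]))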