From 9e0cbcb4e7f1f3f95f304046d3190c6ebc4d3901 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 24 Apr 2024 09:09:24 +0200 Subject: Reformat and fix typo --- rag/retriever/rerank/cohere.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 rag/retriever/rerank/cohere.py (limited to 'rag/retriever/rerank/cohere.py') diff --git a/rag/retriever/rerank/cohere.py b/rag/retriever/rerank/cohere.py new file mode 100644 index 0000000..dac9ab5 --- /dev/null +++ b/rag/retriever/rerank/cohere.py @@ -0,0 +1,28 @@ +import os + +import cohere +from loguru import logger as log + +from rag.generator.prompt import Prompt +from rag.retriever.rerank.abstract import AbstractReranker + + +class CohereReranker(metaclass=AbstractReranker): + def __init__(self) -> None: + self.client = cohere.Client(os.environ["COHERE_API_KEY"]) + self.top_k = int(os.environ["RERANK_TOP_K"]) + + def rank(self, prompt: Prompt) -> Prompt: + if prompt.documents: + response = self.client.rerank( + model="rerank-english-v3.0", + query=prompt.query, + documents=[d.text for d in prompt.documents], + top_n=self.top_k, + ) + ranking = list(filter(lambda x: x.relevance_score > 0.5, response.results)) + log.debug( + f"Reranking gave {len(ranking)} relevant documents of {len(prompt.documents)}" + ) + prompt.documents = [prompt.documents[r.index] for r in ranking] + return prompt -- cgit v1.2.3-70-g09d2