From d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 8 Apr 2024 22:28:47 +0200 Subject: Wip refactor --- rag/rag.py | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) (limited to 'rag/rag.py') diff --git a/rag/rag.py b/rag/rag.py index cd4537e..93f9fd7 100644 --- a/rag/rag.py +++ b/rag/rag.py @@ -5,20 +5,22 @@ from typing import List from dotenv import load_dotenv from loguru import logger as log -from qdrant_client.models import StrictFloat + try: - from rag.db.vector import VectorDB + from rag.db.vector import VectorDB, Document from rag.db.document import DocumentDB from rag.llm.encoder import Encoder - from rag.llm.generator import Generator, Prompt + from rag.llm.ollama_generator import OllamaGenerator, Prompt + from rag.llm.cohere_generator import CohereGenerator from rag.parser.pdf import PDFParser except ModuleNotFoundError: - from db.vector import VectorDB + from db.vector import VectorDB, Document from db.document import DocumentDB from llm.encoder import Encoder - from llm.generator import Generator, Prompt + from llm.ollama_generator import OllamaGenerator, Prompt + from llm.cohere_generator import CohereGenerator from parser.pdf import PDFParser @@ -34,7 +36,7 @@ class RAG: # FIXME: load this somewhere else? load_dotenv() self.pdf_parser = PDFParser() - self.generator = Generator() + self.generator = CohereGenerator() self.encoder = Encoder() self.vector_db = VectorDB() self.doc_db = DocumentDB() @@ -43,23 +45,19 @@ class RAG: blob = self.pdf_parser.from_path(path) self.add_pdf_from_blob(blob) - def add_pdf_from_blob(self, blob: BytesIO): + def add_pdf_from_blob(self, blob: BytesIO, source: str): if self.doc_db.add(blob): log.debug("Adding pdf to vector database...") - chunks = self.pdf_parser.from_data(blob) + document = self.pdf_parser.from_data(blob) + chunks = self.pdf_parser.chunk(document, source) points = self.encoder.encode_document(chunks) self.vector_db.add(points) else: log.debug("Document already exists!") - def __context(self, query_emb: List[StrictFloat], limit: int) -> str: - hits = self.vector_db.search(query_emb, limit) - log.debug(f"Got {len(hits)} hits in the vector db with limit={limit}") - return [h.payload["text"] for h in hits] - - def retrive(self, query: str, limit: int = 5) -> Response: + def search(self, query: str, limit: int = 5) -> List[Document]: query_emb = self.encoder.encode_query(query) - context = self.__context(query_emb, limit) - prompt = Prompt(query, "\n".join(context)) - answer = self.generator.generate(prompt)["response"] - return Response(query, context, answer) + return self.vector_db.search(query_emb, limit) + + def retrieve(self, prompt: Prompt): + yield from self.generator.generate(prompt) -- cgit v1.2.3-70-g09d2