From 3f447bff69c20109474c455f1ad52bd547ab49e9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 9 Apr 2024 00:41:55 +0200 Subject: Update --- rag/cli.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) (limited to 'rag/cli.py') diff --git a/rag/cli.py b/rag/cli.py index d5651c1..003718d 100644 --- a/rag/cli.py +++ b/rag/cli.py @@ -1,27 +1,33 @@ from pathlib import Path +from dotenv import load_dotenv +from rag.generator import get_generator, MODELS +from rag.retriever.retriever import Retriever +from rag.generator.prompt import Prompt -try: - from rag.rag import RAG -except ModuleNotFoundError: - from rag import RAG if __name__ == "__main__": - rag = RAG() + load_dotenv() + retriever = Retriever() + + print("\n\nRetrieval Augmented Generation\n") + model = input(f"Enter model ({MODELS}):") while True: - print("Retrieval Augmented Generation") - choice = input("1. add pdf from path\n2. Enter a query\n") + choice = input("1. Add pdf from path\n2. Enter a query\n") match choice: case "1": path = input("Enter the path to the pdf: ") path = Path(path) - rag.add_pdf_from_path(path) + retriever.add_pdf(path=path) case "2": query = input("Enter your query: ") if query: - result = rag.retrive(query) + generator = get_generator(model) + documents = retriever.retrieve(query) + prompt = Prompt(query, documents) print("Answer: \n") - print(result.answer + "\n") + for chunk in generator.generate(prompt): + print(chunk, end="", flush=True) case _: print("Invalid option!") -- cgit v1.2.3-70-g09d2