summaryrefslogtreecommitdiff
path: root/rag/cli.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-09 00:41:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-09 00:41:55 +0200
commit3f447bff69c20109474c455f1ad52bd547ab49e9 (patch)
tree66695ca0e3423e2973c5b24ec1ae7096019b5dd0 /rag/cli.py
parentc05eae81f9aaa0a764203446ec54d7dd7cbeb66f (diff)
Update
Diffstat (limited to 'rag/cli.py')
-rw-r--r--rag/cli.py26
1 files changed, 16 insertions, 10 deletions
diff --git a/rag/cli.py b/rag/cli.py
index d5651c1..003718d 100644
--- a/rag/cli.py
+++ b/rag/cli.py
@@ -1,27 +1,33 @@
from pathlib import Path
+from dotenv import load_dotenv
+from rag.generator import get_generator, MODELS
+from rag.retriever.retriever import Retriever
+from rag.generator.prompt import Prompt
-try:
- from rag.rag import RAG
-except ModuleNotFoundError:
- from rag import RAG
if __name__ == "__main__":
- rag = RAG()
+ load_dotenv()
+ retriever = Retriever()
+
+ print("\n\nRetrieval Augmented Generation\n")
+ model = input(f"Enter model ({MODELS}):")
while True:
- print("Retrieval Augmented Generation")
- choice = input("1. add pdf from path\n2. Enter a query\n")
+ choice = input("1. Add pdf from path\n2. Enter a query\n")
match choice:
case "1":
path = input("Enter the path to the pdf: ")
path = Path(path)
- rag.add_pdf_from_path(path)
+ retriever.add_pdf(path=path)
case "2":
query = input("Enter your query: ")
if query:
- result = rag.retrive(query)
+ generator = get_generator(model)
+ documents = retriever.retrieve(query)
+ prompt = Prompt(query, documents)
print("Answer: \n")
- print(result.answer + "\n")
+ for chunk in generator.generate(prompt):
+ print(chunk, end="", flush=True)
case _:
print("Invalid option!")