From d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 8 Apr 2024 22:28:47 +0200 Subject: Wip refactor --- rag/llm/encoder.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'rag/llm/encoder.py') diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py index 95f3c6a..a59b1b4 100644 --- a/rag/llm/encoder.py +++ b/rag/llm/encoder.py @@ -1,5 +1,6 @@ import os -from typing import Iterator, List +from pathlib import Path +from typing import List, Dict from uuid import uuid4 import ollama @@ -13,6 +14,7 @@ try: except ModuleNotFoundError: from db.vector import Point + class Encoder: def __init__(self) -> None: self.model = os.environ["ENCODER_MODEL"] @@ -21,13 +23,20 @@ class Encoder: def __encode(self, prompt: str) -> List[StrictFloat]: return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"]) - def encode_document(self, chunks: Iterator[Document]) -> List[Point]: + def __get_source(self, metadata: Dict[str, str]) -> str: + source = metadata["source"] + return Path(source).name + + def encode_document(self, chunks: List[Document]) -> List[Point]: log.debug("Encoding document...") return [ Point( id=uuid4().hex, vector=self.__encode(chunk.page_content), - payload={"text": chunk.page_content}, + payload={ + "text": chunk.page_content, + "source": self.__get_source(chunk.metadata), + }, ) for chunk in chunks ] -- cgit v1.2.3-70-g09d2