summaryrefslogtreecommitdiff
path: root/rag/llm/encoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'rag/llm/encoder.py')
-rw-r--r--rag/llm/encoder.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
index a686aaf..94d5559 100644
--- a/rag/llm/encoder.py
+++ b/rag/llm/encoder.py
@@ -1,13 +1,13 @@
import os
-from typing import List
+from typing import List, Optional
from uuid import uuid4
-import numpy as np
import ollama
from langchain_core.documents import Document
+from loguru import logger as log
from qdrant_client.http.models import StrictFloat
-from rag.db.embeddings import Point
+from rag.db.vector import Point
class Encoder:
@@ -18,7 +18,8 @@ class Encoder:
def __encode(self, prompt: str) -> List[StrictFloat]:
return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
- def encode_document(self, chunks: List[Document]) -> np.ndarray:
+ def encode_document(self, chunks: List[Document]) -> List[Point]:
+ log.debug("Encoding document...")
return [
Point(
id=str(uuid4()),
@@ -28,6 +29,7 @@ class Encoder:
for chunk in chunks
]
- def query(self, query: str) -> np.ndarray:
+ def encode_query(self, query: str) -> List[StrictFloat]:
+ log.debug(f"Encoding query: {query}")
query = self.query_prompt + query
return self.__encode(query)