summary | refs | log | tree | commit | diff
path: root/rag/llm/encoder.py
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-06 01:22:13 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2024-04-06 01:22:13 +0200
commit: 3ba6eca92a339e28ffce14adf46d2fb71e6f4958 (patch)
tree: d03d23c080c51382e4c987f52932253a0f814136 /rag/llm/encoder.py
parent: 13ac875b2269756045834d7a64e7b35acb9ce0b4 (diff)
Refactor
Diffstat (limited to 'rag/llm/encoder.py')
-rw-r--r-- rag/llm/encoder.py 12
1 file changed, 7 insertions, 5 deletions
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py
index a686aaf..94d5559 100644
--- a/rag/llm/encoder.py
+++ b/rag/llm/encoder.py
@@ -1,13 +1,13 @@
import os
-from typing import List
+from typing import List, Optional
from uuid import uuid4
-import numpy as np
import ollama
from langchain_core.documents import Document
+from loguru import logger as log
from qdrant_client.http.models import StrictFloat
-from rag.db.embeddings import Point
+from rag.db.vector import Point
class Encoder:
@@ -18,7 +18,8 @@ class Encoder:
def __encode(self, prompt: str) -> List[StrictFloat]:
return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"])
- def encode_document(self, chunks: List[Document]) -> np.ndarray:
+ def encode_document(self, chunks: List[Document]) -> List[Point]:
+ log.debug("Encoding document...")
return [
Point(
id=str(uuid4()),
@@ -28,6 +29,7 @@ class Encoder:
for chunk in chunks
]
- def query(self, query: str) -> np.ndarray:
+ def encode_query(self, query: str) -> List[StrictFloat]:
+ log.debug(f"Encoding query: {query}")
query = self.query_prompt + query
return self.__encode(query)