diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-06 01:22:13 +0200 |
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-06 01:22:13 +0200 |
| commit | 3ba6eca92a339e28ffce14adf46d2fb71e6f4958 (patch) | |
| tree | d03d23c080c51382e4c987f52932253a0f814136 /rag/llm/encoder.py | |
| parent | 13ac875b2269756045834d7a64e7b35acb9ce0b4 (diff) | |
Refactor
Diffstat (limited to 'rag/llm/encoder.py')
| -rw-r--r-- | rag/llm/encoder.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/rag/llm/encoder.py b/rag/llm/encoder.py index a686aaf..94d5559 100644 --- a/rag/llm/encoder.py +++ b/rag/llm/encoder.py @@ -1,13 +1,13 @@ import os -from typing import List +from typing import List, Optional from uuid import uuid4 -import numpy as np import ollama from langchain_core.documents import Document +from loguru import logger as log from qdrant_client.http.models import StrictFloat -from rag.db.embeddings import Point +from rag.db.vector import Point class Encoder: @@ -18,7 +18,8 @@ class Encoder: def __encode(self, prompt: str) -> List[StrictFloat]: return list(ollama.embeddings(model=self.model, prompt=prompt)["embedding"]) - def encode_document(self, chunks: List[Document]) -> np.ndarray: + def encode_document(self, chunks: List[Document]) -> List[Point]: + log.debug("Encoding document...") return [ Point( id=str(uuid4()), @@ -28,6 +29,7 @@ class Encoder: for chunk in chunks ] - def query(self, query: str) -> np.ndarray: + def encode_query(self, query: str) -> List[StrictFloat]: + log.debug(f"Encoding query: {query}") query = self.query_prompt + query return self.__encode(query) |