diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-06 00:18:57 +0200 |
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2024-04-06 00:18:57 +0200 |
| commit | a1603d4c6c29f414304fc379074eb81b5b00c5d0 (patch) | |
| tree | 2ebad5348fe62148db405a4637eb49274f7c9766 /rag/db/embeddings.py | |
| parent | 093553777355e6d1d6c2dc9b0326909bf9859cba (diff) | |
Add logging in dbs
Diffstat (limited to 'rag/db/embeddings.py')
| -rw-r--r-- | rag/db/embeddings.py | 47 |
1 files changed, 0 insertions, 47 deletions
diff --git a/rag/db/embeddings.py b/rag/db/embeddings.py deleted file mode 100644 index 4f61dcc..0000000 --- a/rag/db/embeddings.py +++ /dev/null @@ -1,47 +0,0 @@ -import os -from dataclasses import dataclass -from typing import Dict, List -from uuid import uuid4 - -from qdrant_client import QdrantClient -from qdrant_client.http.models import StrictFloat -from qdrant_client.models import Distance, VectorParams, PointStruct - - -@dataclass -class Point: - id: str - vector: List[StrictFloat] - payload: Dict[str, str] - - -class Embeddings: - def __init__(self): - self.dim = int(os.environ["EMBEDDING_DIM"]) - self.collection_name = os.environ["QDRANT_COLLECTION_NAME"] - self.client = QdrantClient(url=os.environ["QDRANT_URL"]) - self.client.delete_collection( - collection_name=self.collection_name, - ) - self.client.create_collection( - collection_name=self.collection_name, - vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE), - ) - - def add(self, points: List[Point]): - print(len(points)) - self.client.upload_points( - collection_name=self.collection_name, - points=[ - PointStruct(id=point.id, vector=point.vector, payload=point.payload) - for point in points - ], - parallel=4, - max_retries=3, - ) - - def search(self, query: List[float], limit: int = 4): - hits = self.client.search( - collection_name=self.collection_name, query_vector=query, limit=limit - ) - return hits |