summaryrefslogtreecommitdiff
path: root/rag/db/embeddings.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-06 00:18:57 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2024-04-06 00:18:57 +0200
commita1603d4c6c29f414304fc379074eb81b5b00c5d0 (patch)
tree2ebad5348fe62148db405a4637eb49274f7c9766 /rag/db/embeddings.py
parent093553777355e6d1d6c2dc9b0326909bf9859cba (diff)
Add logging in dbs
Diffstat (limited to 'rag/db/embeddings.py')
-rw-r--r--rag/db/embeddings.py47
1 files changed, 0 insertions, 47 deletions
diff --git a/rag/db/embeddings.py b/rag/db/embeddings.py
deleted file mode 100644
index 4f61dcc..0000000
--- a/rag/db/embeddings.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import os
-from dataclasses import dataclass
-from typing import Dict, List
-from uuid import uuid4
-
-from qdrant_client import QdrantClient
-from qdrant_client.http.models import StrictFloat
-from qdrant_client.models import Distance, VectorParams, PointStruct
-
-
-@dataclass
-class Point:
- id: str
- vector: List[StrictFloat]
- payload: Dict[str, str]
-
-
-class Embeddings:
- def __init__(self):
- self.dim = int(os.environ["EMBEDDING_DIM"])
- self.collection_name = os.environ["QDRANT_COLLECTION_NAME"]
- self.client = QdrantClient(url=os.environ["QDRANT_URL"])
- self.client.delete_collection(
- collection_name=self.collection_name,
- )
- self.client.create_collection(
- collection_name=self.collection_name,
- vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
- )
-
- def add(self, points: List[Point]):
- print(len(points))
- self.client.upload_points(
- collection_name=self.collection_name,
- points=[
- PointStruct(id=point.id, vector=point.vector, payload=point.payload)
- for point in points
- ],
- parallel=4,
- max_retries=3,
- )
-
- def search(self, query: List[float], limit: int = 4):
- hits = self.client.search(
- collection_name=self.collection_name, query_vector=query, limit=limit
- )
- return hits