Coverage for src/ragindexer/QdrantIndexer.py: 80%
64 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-20 15:57 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-20 15:57 +0000
1import hashlib
2from pathlib import Path
3import time
4from typing import Optional, List, Sequence, Union
5import uuid
7from qdrant_client.conversions import common_types as types
8from qdrant_client import QdrantClient
9from qdrant_client.models import (
10 VectorParams,
11 Distance,
12 PointStruct,
13 PointIdsList,
14)
15import requests
17from . import logger
18from .config import config
19from .models import ChunkType, EmbeddingType
22# === Qdrant helper ===
class QdrantIndexer:
    """Qdrant client that handles database operations based on the configuration

    The server URL, API key and collection name are read from the project
    ``config`` object (``QDRANT_URL``, ``QDRANT_API_KEY``, ``COLLECTION_NAME``).
    The collection is created at construction time if it does not exist yet.

    Args:
        vector_size: Size of the embedding vectors

    """

    def __init__(self, vector_size: int):
        self.__client = QdrantClient(url=config.QDRANT_URL, api_key=config.QDRANT_API_KEY)
        self.vector_size = vector_size
        self.__create_collection_if_missing()

    def create_snapshot(self, output: Path | None = None) -> Path:
        """Create a snapshot of the collection and download it to a local file

        Args:
            output: Where to store the snapshot. A path ending in ``.snapshot``
                is used as the destination file; any other path is treated as a
                directory and the server-side snapshot name is appended to it.
                If `None`, the snapshot is written to the current working
                directory under its server-side name.

        Returns:
            Path of the local snapshot file. If the download request does not
            answer with HTTP 200, an error is logged and the returned path may
            not exist.

        Raises:
            RuntimeError: If the server did not return a snapshot description

        """
        snap_desc = self.__client.create_snapshot(collection_name=config.COLLECTION_NAME)
        if snap_desc is None:
            raise RuntimeError(
                f"Qdrant did not return a snapshot description for '{config.COLLECTION_NAME}'"
            )

        # Resolve the local destination; `output` is optional, so it must not be
        # dereferenced before the None check.
        if output is None:
            snap_path = Path.cwd() / snap_desc.name
        else:
            output = Path(output)
            if output.suffix == ".snapshot":
                snap_path = output
            else:
                snap_path = output / snap_desc.name

        # The snapshot is served by the REST endpoint
        # /collections/{collection}/snapshots/{snapshot}; requesting the bare
        # server URL would only return the service banner.
        base_url = str(config.QDRANT_URL).rstrip("/")
        url = f"{base_url}/collections/{config.COLLECTION_NAME}/snapshots/{snap_desc.name}"
        headers = {"api-key": config.QDRANT_API_KEY} if config.QDRANT_API_KEY else {}

        # stream=True keeps the snapshot out of memory; the context manager
        # releases the connection once the download is finished.
        with requests.get(url, headers=headers, stream=True) as response:
            if response.status_code == 200:
                with open(snap_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=65536):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
            else:
                logger.error(
                    f"Snapshot download failed (HTTP {response.status_code}) for '{url}'"
                )

        return snap_path

    def info(self) -> types.CollectionInfo:
        """Return the description of the configured collection, as reported by the server"""
        return self.__client.get_collection(collection_name=config.COLLECTION_NAME)

    def empty_collection(self):
        """Remove every point by deleting the collection and creating it again"""
        self.__client.delete_collection(collection_name=config.COLLECTION_NAME)
        self.__create_collection_if_missing()

    def search(
        self,
        query_vector: Optional[
            Union[
                Sequence[float],
                tuple[str, list[float]],
                types.NamedVector,
                types.NamedSparseVector,
                types.NumpyArray,
            ]
        ] = None,
        limit: Optional[int] = 10,
        query_filter: Optional[types.Filter] = None,
    ):
        """Search a vector in the database

        See https://qdrant.tech/documentation/concepts/search/
        and https://qdrant.tech/documentation/concepts/filtering/ for more details

        Args:
            query_vector: Search for vectors closest to this. If None, allows listing ids
            limit: How many results to return
            query_filter:
                - Exclude vectors which don't fit the given conditions.
                - If `None` - search among all vectors

        Returns:
            List of found close points with similarity scores.

        """
        if query_vector is None:
            query_vect = [0.0] * self.vector_size  # dummy vector; we only want IDs
        else:
            query_vect = query_vector

        hits = self.__client.query_points(
            collection_name=config.COLLECTION_NAME,
            query=query_vect,
            limit=limit,
            query_filter=query_filter,
            with_payload=True,
        ).points
        return hits

    def __create_collection_if_missing(self):
        """Creates the collection provided in the COLLECTION_NAME environment variable, if not already created"""
        existing = [c.name for c in self.__client.get_collections().collections]
        if config.COLLECTION_NAME not in existing:
            logger.info(f"Creating Qdrant collection : '{config.COLLECTION_NAME}'...")
            # The collection is known to be absent here, so a plain create is
            # enough; recreate_collection is deprecated in qdrant-client.
            self.__client.create_collection(
                collection_name=config.COLLECTION_NAME,
                vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
                on_disk_payload=True,
            )
            logger.info("... Done")

    def delete(self, ids: List[str]):
        """Deletes selected points from collection

        Args:
            ids: Selects points based on list of IDs. Nothing is sent to the
                server when the list is empty.

        """
        if ids:
            pil = PointIdsList(points=ids)
            self.__client.delete(collection_name=config.COLLECTION_NAME, points_selector=pil)

    def record_embeddings(
        self,
        k_page: int,
        chunks: List[ChunkType],
        embeddings: List[EmbeddingType],
        file_metadata: dict,
    ):
        """
        Update or insert new chunks into the collection.

        Point IDs are derived from the file path, the page and the chunk index,
        so recording the same chunk again overwrites the previous point.

        Args:
            k_page: Page the chunks belong to; stored in the `page` payload
                field and used in the point ID
            chunks: List of chunks to record
            embeddings: The corresponding list of vectors to record. Chunks and
                embeddings are paired positionally; extra items in the longer
                list are ignored.
            file_metadata: Original file's information. Must contain `abspath`;
                `ocr_used` is optional and defaults to False.

        """
        filepath = file_metadata["abspath"]

        points: list[PointStruct] = []
        # Use MD5 of path + page + chunk index as unique point ID
        for idx, (chunk, emb) in enumerate(zip(chunks, embeddings)):
            file_hash = hashlib.md5(f"{filepath}::{k_page}::{idx}".encode("utf-8")).hexdigest()
            # The 128-bit digest is formatted as a UUID string for use as point ID
            pid = str(uuid.UUID(int=int(file_hash, 16)))
            payload = {
                "source": str(filepath),
                "chunk_index": idx,
                "text": chunk,
                "page": k_page,
                "ocr_used": file_metadata.get("ocr_used", False),
            }
            points.append(PointStruct(id=pid, vector=emb, payload=payload))

        # Upsert into Qdrant
        if points:
            self.__client.upsert(collection_name=config.COLLECTION_NAME, points=points)
            # Short pause after each upsert
            time.sleep(0.1)