Coverage for src/ragindexer/QdrantIndexer.py: 80%

64 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-06-20 15:57 +0000

1import hashlib 

2from pathlib import Path 

3import time 

4from typing import Optional, List, Sequence, Union 

5import uuid 

6 

7from qdrant_client.conversions import common_types as types 

8from qdrant_client import QdrantClient 

9from qdrant_client.models import ( 

10 VectorParams, 

11 Distance, 

12 PointStruct, 

13 PointIdsList, 

14) 

15import requests 

16 

17from . import logger 

18from .config import config 

19from .models import ChunkType, EmbeddingType 

20 

21 

22# === Qdrant helper === 

class QdrantIndexer:
    """Qdrant client that handles database operations based on the configuration.

    Every operation targets the collection ``config.COLLECTION_NAME`` on the server
    at ``config.QDRANT_URL``; that collection is created on construction if missing.

    Args:
        vector_size: Size of the embedding vectors stored in the collection

    """

    def __init__(self, vector_size: int):
        self.__client = QdrantClient(url=config.QDRANT_URL, api_key=config.QDRANT_API_KEY)
        self.vector_size = vector_size
        self.__create_collection_if_missing()

    def create_snapshot(self, output: Path | None = None) -> Path:
        """Create a snapshot of the collection on the server and download it locally.

        Args:
            output: Either the destination file (when its suffix is ``.snapshot``) or a
                directory in which the snapshot is saved under its server-side name.
                Defaults to the current working directory.

        Returns:
            Path of the downloaded snapshot file.

        Raises:
            RuntimeError: If the server did not return a snapshot description.
            requests.HTTPError: If downloading the snapshot fails.

        """
        from urllib.parse import quote

        snap_desc = self.__client.create_snapshot(collection_name=config.COLLECTION_NAME)
        if snap_desc is None:
            raise RuntimeError(
                f"Qdrant did not return a snapshot description for '{config.COLLECTION_NAME}'"
            )

        output = Path.cwd() if output is None else Path(output)
        if output.suffix == ".snapshot":
            snap_path = output
        else:
            snap_path = output / snap_desc.name

        # Snapshot download endpoint: GET /collections/{collection}/snapshots/{snapshot}.
        # (Requesting the bare server URL returns server info, not the snapshot.)
        url = (
            f"{config.QDRANT_URL.rstrip('/')}/collections/"
            f"{quote(config.COLLECTION_NAME, safe='')}/snapshots/"
            f"{quote(snap_desc.name, safe='')}"
        )
        headers = {"api-key": config.QDRANT_API_KEY} if config.QDRANT_API_KEY else {}

        # Stream the body so a large snapshot is never held in memory all at once,
        # and fail loudly instead of returning a path to a file that was never written.
        with requests.get(url, headers=headers, stream=True) as response:
            response.raise_for_status()
            with open(snap_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)

        return snap_path

    def info(self) -> types.CollectionInfo:
        """Return Qdrant's description of the configured collection."""
        return self.__client.get_collection(collection_name=config.COLLECTION_NAME)

    def empty_collection(self):
        """Remove every point by dropping the collection and creating it again, empty."""
        self.__client.delete_collection(collection_name=config.COLLECTION_NAME)
        self.__create_collection_if_missing()

    def search(
        self,
        query_vector: Optional[
            Union[
                Sequence[float],
                tuple[str, list[float]],
                types.NamedVector,
                types.NamedSparseVector,
                types.NumpyArray,
            ]
        ] = None,
        limit: Optional[int] = 10,
        query_filter: Optional[types.Filter] = None,
    ):
        """Search for a vector in the database.

        See https://qdrant.tech/documentation/concepts/search/
        and https://qdrant.tech/documentation/concepts/filtering/ for more details

        Args:
            query_vector: Search for the vectors closest to this one. If None, a dummy
                all-zero vector is used, which allows listing points (e.g. their IDs)
            limit: Maximum number of results to return
            query_filter:
                - Exclude vectors that don't satisfy the given conditions.
                - If `None` - search among all vectors

        Returns:
            List of the closest points found, with their similarity scores and payloads.

        """
        # NOTE(review): with COSINE distance an all-zero vector has no direction, so the
        # order of the returned points is presumably arbitrary. That is acceptable when
        # only IDs/payloads are wanted; confirm the server accepts such a query.
        query_vect = [0.0] * self.vector_size if query_vector is None else query_vector

        response = self.__client.query_points(
            collection_name=config.COLLECTION_NAME,
            query=query_vect,
            limit=limit,
            query_filter=query_filter,
            with_payload=True,
        )
        return response.points

    def __create_collection_if_missing(self):
        """Create the collection named by the COLLECTION_NAME setting if it does not exist yet."""
        if self.__client.collection_exists(collection_name=config.COLLECTION_NAME):
            return

        logger.info(f"Creating Qdrant collection : '{config.COLLECTION_NAME}'...")
        # create_collection rather than the deprecated recreate_collection: if the
        # collection appeared concurrently, this call fails instead of silently
        # dropping the existing collection and its data.
        self.__client.create_collection(
            collection_name=config.COLLECTION_NAME,
            vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE),
            on_disk_payload=True,
        )
        logger.info("... Done")

    def delete(self, ids: List[str]):
        """Delete the selected points from the collection.

        Args:
            ids: IDs of the points to delete. An empty list is a no-op.

        """
        if not ids:
            return
        self.__client.delete(
            collection_name=config.COLLECTION_NAME,
            points_selector=PointIdsList(points=ids),
        )

    def record_embeddings(
        self,
        k_page: int,
        chunks: List[ChunkType],
        embeddings: List[EmbeddingType],
        file_metadata: dict,
    ):
        """
        Update or insert the chunks of one page of a file into the collection.

        Point IDs are derived deterministically from (file path, page, chunk index),
        so recording the same page again overwrites its previous points.

        Args:
            k_page: Index of the page the chunks come from
            chunks: List of chunks to record
            embeddings: The corresponding list of vectors to record, one per chunk
            file_metadata: Original file's information; must contain "abspath" and
                may contain "ocr_used"

        Raises:
            ValueError: If `chunks` and `embeddings` do not have the same length.

        """
        # zip() would silently drop the unmatched tail, losing data without notice.
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"Got {len(chunks)} chunks but {len(embeddings)} embeddings for page {k_page}"
            )

        filepath = file_metadata["abspath"]
        source = str(filepath)
        ocr_used = file_metadata.get("ocr_used", False)

        points: list[PointStruct] = []
        for idx, (chunk, emb) in enumerate(zip(chunks, embeddings)):
            # MD5 of "path::page::chunk index", reinterpreted as a UUID: Qdrant point IDs
            # must be unsigned integers or UUIDs.
            point_hash = hashlib.md5(f"{filepath}::{k_page}::{idx}".encode("utf-8")).hexdigest()
            pid = str(uuid.UUID(hex=point_hash))
            payload = {
                "source": source,
                "chunk_index": idx,
                "text": chunk,
                "page": k_page,
                "ocr_used": ocr_used,
            }
            points.append(PointStruct(id=pid, vector=emb, payload=payload))

        # Upsert into Qdrant
        if points:
            self.__client.upsert(collection_name=config.COLLECTION_NAME, points=points)
            # NOTE(review): presumably throttles successive upserts; confirm it is needed.
            time.sleep(0.1)