Coverage for src/ragindexer/documents/ADocument.py: 96%

36 statements  

coverage.py v7.8.2, created at 2025-06-20 15:57 +0000

 1  from abc import abstractmethod, ABC
 2  from pathlib import Path
 3  from typing import List, Tuple, Iterable
 4
 5  from nltk.tokenize import sent_tokenize
 6  from sentence_transformers import SentenceTransformer
 7
 8  from ..config import config
 9  from ..models import ChunkType, EmbeddingType
10
11
12  class ADocument(ABC):

13 """ 

14 Handle documents based on their extension 

15 

16 Args: 

17 abspath: Path to the file to handle 

18 

19 """ 

20
21      def __init__(self, abspath: Path):
22          self.__abspath = abspath
23
24      def get_abs_path(self) -> Path:
25          """
26          Get the abspath to the handled file
27
28          Returns:
29              Path to the handled file
30
31          """
32          return self.__abspath
33

34      @abstractmethod
35      def iterate_raw_text(self) -> Iterable[Tuple[int, str, dict]]:
36          """
37          Abstract method that concrete subclasses implement to extract raw text from the file.
38
39          Yields:
40              A tuple with the page index, the extracted text, and the file metadata
41
42          """
43

44      def __get_chunk_text(self, text: str, chunk_size: int, chunk_overlap: int) -> List[ChunkType]:
45          """
46          Splits text into overlapping chunks of ~chunk_size characters, aligned on sentences.
47          """
48          sentences = sent_tokenize(text)
49          chunks = []
50          current_chunk = ""
51          for sent in sentences:
52              if len(current_chunk) + len(sent) + 1 <= chunk_size:
53                  current_chunk += " " + sent if current_chunk else sent
54              else:
55                  chunks.append(current_chunk)
56                  # Start new chunk: include overlap
57                  overlap_text = (
58                      current_chunk[-chunk_overlap:]
59                      if chunk_overlap < len(current_chunk)
60                      else current_chunk
61                  )
62                  current_chunk = overlap_text + " " + sent
63
64          if current_chunk:
65              chunks.append(current_chunk)
66          return chunks

67
68      def __get_embeddings(
69          self, text: str, embedding_model: SentenceTransformer
70      ) -> Tuple[List[ChunkType], List[EmbeddingType]]:
71          chunks = self.__get_chunk_text(text, config.CHUNK_SIZE, config.CHUNK_OVERLAP)
72          embeddings = embedding_model.encode(chunks, device="cpu", show_progress_bar=False).tolist()
73          return chunks, embeddings
74

75      def process(
76          self, embedding_model: SentenceTransformer
77      ) -> Iterable[Tuple[int, List[ChunkType], List[EmbeddingType], dict]]:
78          for k_page, text, file_metadata in self.iterate_raw_text():
79              file_metadata["abspath"] = self.get_abs_path()
80
81              chunks, embeddings = self.__get_embeddings(text, embedding_model)
82              while "" in chunks:    [82 ↛ 83: line 82 didn't jump to line 83 because the condition on line 82 was never true]
83                  chunks.remove("")
84
85              yield k_page, chunks, embeddings, file_metadata
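
For reference, below is a minimal sketch of how a concrete subclass might plug into this base class. The PlainTextDocument name, the README.txt path, and the all-MiniLM-L6-v2 checkpoint are illustrative assumptions, not part of the module; the sketch also assumes the package is importable as ragindexer and that the NLTK tokenizer data used by sent_tokenize is available.

# Hypothetical concrete subclass, for illustration only: it treats the whole
# file as a single "page" and yields it through iterate_raw_text().
from pathlib import Path
from typing import Iterable, Tuple

from sentence_transformers import SentenceTransformer

from ragindexer.documents.ADocument import ADocument


class PlainTextDocument(ADocument):
    def iterate_raw_text(self) -> Iterable[Tuple[int, str, dict]]:
        # Single "page": index 0, the full file content, and some basic metadata.
        text = self.get_abs_path().read_text(encoding="utf-8")
        yield 0, text, {"suffix": self.get_abs_path().suffix}


if __name__ == "__main__":
    # Assumed embedding model; any SentenceTransformer checkpoint would work here.
    model = SentenceTransformer("all-MiniLM-L6-v2")
    doc = PlainTextDocument(Path("README.txt"))
    for k_page, chunks, embeddings, metadata in doc.process(model):
        # process() yields one item per page: the chunk texts, their embeddings,
        # and the metadata dict enriched with the document's absolute path.
        print(k_page, len(chunks), len(embeddings), metadata["abspath"])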