"""
Core KARA algorithm implementation.
"""
import hashlib
import heapq
import json
import sys
import warnings
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Callable, Generic, Optional, TypeVar
from .chunkers import BaseDocumentChunker
T = TypeVar("T")


@dataclass
class ChunkData(Generic[T]):
    """Represents a chunk with its content and metadata.

    Attributes:
        content: Rendered chunk content. By default this is the concatenated
            string when every split is a ``str``, otherwise a list copy of the
            splits; a custom renderer may return anything.
        splits: The units making up this chunk, in order.
        hash: Hex MD5 digest of the serialized splits, used to recognise
            identical chunks across collection updates.
        document_id: Index of the source document, if known.
    """

    content: Any
    splits: list[T]
    hash: str
    document_id: Optional[int] = None

    @classmethod
    def from_splits(
        cls,
        splits: Sequence[T],
        document_id: Optional[int] = None,
        serializer: Optional[Callable[[Sequence[T]], bytes]] = None,
        renderer: Optional[Callable[[Sequence[T]], Any]] = None,
    ) -> "ChunkData[T]":
        """Create ChunkData from splits.

        Args:
            splits: Ordered units that make up the chunk.
            document_id: Optional index of the source document.
            serializer: Converts ``splits`` to bytes for hashing. When omitted,
                all-``str`` splits are concatenated and UTF-8 encoded; any other
                splits are dumped as compact, ASCII-only JSON.
            renderer: Builds ``content`` from ``splits``. When omitted, all-``str``
                splits are concatenated; otherwise ``content`` is a list copy.

        Returns:
            A new ChunkData whose ``hash`` is the hex MD5 of the serialized splits.

        Raises:
            TypeError: If the default JSON serializer meets a split that is not
                JSON-serializable.
        """
        # The "all text" scan is only needed when a default is used; do it once.
        needs_default = renderer is None or serializer is None
        all_text = needs_default and all(isinstance(unit, str) for unit in splits)

        content: Any
        if renderer is not None:
            content = renderer(splits)
        elif all_text:
            content = "".join(splits)  # type: ignore
        else:
            content = list(splits)

        if serializer is not None:
            serialized = serializer(splits)
        elif all_text:
            serialized = "".join(splits).encode("utf-8")  # type: ignore
        else:
            serialized = json.dumps(
                list(splits), separators=(",", ":"), ensure_ascii=True
            ).encode("utf-8")

        # MD5 is a content fingerprint here, not a security primitive. Declaring
        # that keeps it usable on FIPS-restricted builds; the digest is unchanged.
        hash_value = hashlib.md5(serialized, usedforsecurity=False).hexdigest()
        # ``splits`` gets its own list so it never aliases a list-valued ``content``.
        return cls(content=content, splits=list(splits), hash=hash_value, document_id=document_id)
@dataclass
class ChunkedDocument(Generic[T]):
    """Represents the current state of the document collection.

    Attributes:
        chunks: All chunks of the collection, in order. Each chunk may carry the
            ``document_id`` of the document it was produced from.
    """

    chunks: list[ChunkData[T]]

    def get_chunk_hashes(self) -> set[str]:
        """Get the distinct chunk hashes in the collection."""
        return {chunk.hash for chunk in self.chunks}

    def get_chunks_by_document(self, document_id: int) -> list[ChunkData[T]]:
        """Get all chunks whose ``document_id`` equals *document_id*, in collection order."""
        return [chunk for chunk in self.chunks if chunk.document_id == document_id]

    def get_document_ids(self) -> set[int]:
        """Get all unique document IDs in the collection; chunks without an ID are skipped."""
        return {chunk.document_id for chunk in self.chunks if chunk.document_id is not None}

    def get_chunk_contents(self) -> list[Any]:
        """Get all chunk contents, in collection order."""
        return [chunk.content for chunk in self.chunks]

    @classmethod
    def from_chunks(
        cls, chunks: list[Any], chunker: BaseDocumentChunker[T], document_id: Optional[int] = None
    ) -> "ChunkedDocument[T]":
        """Create a :class:`ChunkedDocument` from pre-split chunks.

        A ``UserWarning`` is always emitted, because chunks produced with a
        different separator configuration than ``chunker`` may not be matched
        correctly by later updates.

        Args:
            chunks: list of pre-split chunks to include; each is converted to
                units with ``chunker.normalize_chunk``.
            chunker: Chunker supplying ``normalize_chunk``, ``unit_length``,
                ``chunk_size``, ``serialize_units`` and ``render_units``.
            document_id: Optional document identifier assigned to every chunk.

        Returns:
            ChunkedDocument with one chunk per input chunk, in input order.

        Raises:
            ValueError: If a chunk's total unit length exceeds ``chunker.chunk_size``.
        """
        warnings.warn(
            (
                "If the separator list is not the same as that used for separating previous "
                "chunks, the algorithm might fail."
            ),
            UserWarning,
            stacklevel=2,
        )
        result: list[ChunkData[T]] = []
        for chunk in chunks:
            splits = chunker.normalize_chunk(chunk)
            chunk_length = sum(chunker.unit_length(unit) for unit in splits)
            if chunk_length > chunker.chunk_size:
                raise ValueError(
                    "Chunk length exceeds the maximum chunk size defined in the chunker. "
                    f"Chunk length: {chunk_length}, Max chunk size: {chunker.chunk_size}"
                )
            result.append(
                ChunkData.from_splits(
                    splits,
                    document_id,
                    serializer=chunker.serialize_units,
                    renderer=chunker.render_units,
                )
            )
        return cls(chunks=result)
@dataclass
class UpdateResult(Generic[T]):
    """Result of a KARA update operation.

    Attributes:
        num_added: New chunks whose hash had no remaining match in the old collection.
        num_reused: New chunks matched by hash to a chunk of the old collection.
        num_deleted: Old chunks not matched by any new chunk.
        new_chunked_doc: The resulting collection, if one was produced.
    """

    num_added: int = 0
    num_reused: int = 0
    num_deleted: int = 0
    new_chunked_doc: Optional["ChunkedDocument[T]"] = None

    def __add__(self, other: "UpdateResult[T]") -> "UpdateResult[T]":
        """Add the counters of two UpdateResult objects.

        ``new_chunked_doc`` is NOT carried over: the sum always has ``None`` there.
        Returns ``NotImplemented`` for non-UpdateResult operands so Python raises
        the usual ``TypeError``.
        """
        if not isinstance(other, UpdateResult):
            return NotImplemented
        return UpdateResult(
            num_added=self.num_added + other.num_added,
            num_reused=self.num_reused + other.num_reused,
            num_deleted=self.num_deleted + other.num_deleted,
        )

    @property
    def total_operations(self) -> int:
        """Number of chunk additions plus deletions (reused chunks are not counted)."""
        return self.num_added + self.num_deleted

    @property
    def efficiency_ratio(self) -> float:
        """Fraction of the new collection's chunks that were reused.

        Computed as ``num_reused / len(new_chunked_doc.chunks)``. Returns 0.0 when
        there is no new collection or it has no chunks.
        """
        total_chunks = len(self.new_chunked_doc.chunks) if self.new_chunked_doc else 0
        return self.num_reused / total_chunks if total_chunks > 0 else 0.0
class KARAUpdater(Generic[T]):
    """
    Knowledge-Aware Re-embedding Algorithm updater.

    Efficiently updates document collections by minimizing embedding operations
    through intelligent reuse of existing chunks. Each document is re-chunked by
    a cheapest-path search over candidate chunks. A candidate whose hash already
    exists in the current collection costs less than a brand-new one, and every
    candidate carries a penalty for being under-filled.
    """

    def __init__(
        self,
        chunker: BaseDocumentChunker[T],
    ):
        """
        Initialize the KARA updater.

        Args:
            chunker: Document chunker for breaking documents into optimal chunks.
                Its ``chunk_size`` is cached as ``max_chunk_size``.
        """
        self.chunker: BaseDocumentChunker[T] = chunker
        self.max_chunk_size: int = chunker.chunk_size

    def create_collection(self, documents: list[str]) -> UpdateResult[T]:
        """
        Create a new document collection from documents.

        Every produced chunk is counted as added; nothing is reused.

        Args:
            documents: list of document texts; a document's position in the
                list is used as its ``document_id``.

        Returns:
            UpdateResult with ``num_added`` equal to the number of chunks and
            ``new_chunked_doc`` holding them (empty when ``documents`` is empty).
        """
        all_chunks: list[ChunkData[T]] = [
            self._make_chunk(self.chunker.normalize_chunk(chunk), doc_id)
            for doc_id, document in enumerate(documents)
            for chunk in self.chunker.create_chunks(document)
        ]
        return UpdateResult(
            num_added=len(all_chunks),
            new_chunked_doc=ChunkedDocument[T](chunks=all_chunks),
        )

    def update_collection(
        self, current_collection: ChunkedDocument[T], documents: list[str]
    ) -> UpdateResult[T]:
        """
        Update the document collection with new documents.

        Args:
            current_collection: Current document collection state.
            documents: list of updated document texts; position is used as
                ``document_id``.

        Returns:
            UpdateResult with added/reused/deleted counts (computed as multisets
            of chunk hashes, so duplicate chunks are counted individually) and
            the new collection.

        Raises:
            ValueError: If any single split is longer than ``max_chunk_size``.
        """
        if not documents:
            return UpdateResult(
                num_deleted=len(current_collection.chunks),
                new_chunked_doc=ChunkedDocument[T](chunks=[]),
            )

        old_chunk_counts = self._count_hashes(current_collection.chunks)
        # Loop-invariant: every document is scored against the same old inventory.
        old_chunk_hashes = set(old_chunk_counts)

        # Process each document separately and combine results
        all_new_chunks: list[ChunkData[T]] = []
        for doc_id, document in enumerate(documents):
            new_splits = self.chunker._split_to_units(document)
            doc_result = self._update_chunks_for_document(
                current_collection, new_splits, doc_id, old_chunk_hashes
            )
            # Invariant of _update_chunks_for_document; also narrows the type.
            assert doc_result.new_chunked_doc is not None
            all_new_chunks.extend(doc_result.new_chunked_doc.chunks)

        num_added, num_reused, num_deleted = self._tally(
            old_chunk_counts, self._count_hashes(all_new_chunks)
        )
        return UpdateResult(
            num_added=num_added,
            num_reused=num_reused,
            num_deleted=num_deleted,
            new_chunked_doc=ChunkedDocument[T](chunks=all_new_chunks),
        )

    def _update_chunks_for_document(
        self,
        current_collection: ChunkedDocument[T],
        new_splits: list[T],
        document_id: int,
        old_chunk_hashes: set[str],
    ) -> UpdateResult[T]:
        """
        Re-chunk a single document using the KARA algorithm.

        Only ``new_chunked_doc`` is filled in; the counters stay at 0 and are
        computed by the caller over the whole collection.

        Args:
            current_collection: Current collection state (not read here; kept
                for interface compatibility).
            new_splits: New splits to process for this document.
            document_id: ID assigned to every chunk produced.
            old_chunk_hashes: Hashes of existing chunks; reusing one is cheaper.

        Returns:
            UpdateResult whose ``new_chunked_doc`` holds the chosen chunks in order.

        Raises:
            ValueError: If a single split is longer than ``max_chunk_size``.
        """
        n = len(new_splits)
        if n == 0:
            return UpdateResult(
                num_deleted=0,  # Will be calculated at the end
                new_chunked_doc=ChunkedDocument[T](chunks=[]),
            )

        edges = self._build_chunk_graph(new_splits, old_chunk_hashes)
        best_edge = self._cheapest_path(edges, n)

        # Walk the winning edges back from the end node; collect in reverse and
        # flip once at the end instead of inserting at the front (O(n^2)).
        new_chunks: list[ChunkData[T]] = []
        node = n
        while node > 0:
            edge = best_edge[node]
            if edge is None:
                break
            start, end = edge
            new_chunks.append(self._make_chunk(list(new_splits[start:end]), document_id))
            node = start
        new_chunks.reverse()

        result: UpdateResult[T] = UpdateResult()
        result.new_chunked_doc = ChunkedDocument[T](chunks=new_chunks)
        return result

    def _build_chunk_graph(
        self, new_splits: list[T], old_chunk_hashes: set[str]
    ) -> list[list[tuple[int, float, int]]]:
        """Build the candidate-chunk DAG.

        ``edges[start]`` holds ``(next_node, cost, end)`` for every chunk
        ``new_splits[start:end]`` that fits in ``max_chunk_size``. Storing ``end``
        instead of a copy of the splits keeps memory linear in the edge count.
        """
        n = len(new_splits)
        max_chunk_size = self.max_chunk_size
        max_chunk_size_float = float(max_chunk_size)
        overlap_units = self.chunker.overlap
        unit_length = self.chunker.unit_length
        serialize_units = self.chunker.serialize_units

        edges: list[list[tuple[int, float, int]]] = [[] for _ in range(n + 1)]
        for start in range(n):
            current_length = 0
            chunk_splits: list[T] = []
            for end in range(start + 1, n + 1):
                split = new_splits[end - 1]
                split_length = unit_length(split)
                # A single split cannot exceed the max chunk size
                # TODO: handle the edge case in which all splits are larger than max_chunk_size
                if split_length > max_chunk_size:
                    raise ValueError(
                        f"Split length {split_length} exceeds max chunk size "
                        f"{max_chunk_size}."
                    )
                chunk_splits.append(split)
                current_length += split_length
                if current_length > max_chunk_size:
                    break
                # Same serializer + MD5 as ChunkData.from_splits, so hashes are comparable.
                chunk_hash = hashlib.md5(
                    serialize_units(chunk_splits), usedforsecurity=False
                ).hexdigest()
                fill_rate = current_length / max_chunk_size_float
                penalty = (1 - fill_rate) ** 2
                # Reusing an existing chunk only pays the fill penalty.
                cost = penalty if chunk_hash in old_chunk_hashes else 1.0 + penalty
                # The next chunk re-starts ``overlap_units`` splits back, but always advances.
                next_node = n if end == n else max(start + 1, end - overlap_units)
                edges[start].append((next_node, cost, end))
        return edges

    @staticmethod
    def _cheapest_path(
        edges: list[list[tuple[int, float, int]]], n: int
    ) -> list[Optional[tuple[int, int]]]:
        """Run Dijkstra from node 0, minimising (cost, edge count) lexicographically.

        Returns, per node, ``(start, end)`` of the winning incoming edge, or None.
        """
        min_cost = [float("inf")] * (n + 1)
        min_num_edges = [sys.maxsize] * (n + 1)
        min_cost[0] = 0
        min_num_edges[0] = 0
        best_edge: list[Optional[tuple[int, int]]] = [None] * (n + 1)
        heap: list[tuple[float, int, int]] = [(0, 0, 0)]  # (cost, edge_count, node)
        while heap:
            cost_u, edges_count_u, u = heapq.heappop(heap)
            # Stale heap entry: a better (cost, edge count) is already recorded.
            if cost_u > min_cost[u] or (cost_u == min_cost[u] and edges_count_u > min_num_edges[u]):
                continue
            for v, edge_cost, end in edges[u]:
                new_cost = min_cost[u] + edge_cost
                new_num_edges = min_num_edges[u] + 1
                if new_cost < min_cost[v] or (
                    new_cost == min_cost[v] and new_num_edges < min_num_edges[v]
                ):
                    min_cost[v] = new_cost
                    min_num_edges[v] = new_num_edges
                    best_edge[v] = (u, end)
                    heapq.heappush(heap, (new_cost, new_num_edges, v))
        return best_edge

    def _make_chunk(self, splits: Sequence[T], document_id: Optional[int]) -> ChunkData[T]:
        """Wrap *splits* in a ChunkData using the chunker's serializer and renderer."""
        return ChunkData.from_splits(
            splits,
            document_id,
            serializer=self.chunker.serialize_units,
            renderer=self.chunker.render_units,
        )

    @staticmethod
    def _count_hashes(chunks: list[ChunkData[Any]]) -> dict[str, int]:
        """Count how many chunks carry each hash (a multiset of hashes)."""
        counts: dict[str, int] = {}
        for chunk in chunks:
            counts[chunk.hash] = counts.get(chunk.hash, 0) + 1
        return counts

    @staticmethod
    def _tally(
        old_counts: dict[str, int], used_counts: dict[str, int]
    ) -> tuple[int, int, int]:
        """Return ``(num_added, num_reused, num_deleted)`` for two hash multisets."""
        num_added = 0
        num_reused = 0
        for chunk_hash, count in used_counts.items():
            reused = min(count, old_counts.get(chunk_hash, 0))
            num_reused += reused
            num_added += count - reused
        # Old copies of a hash beyond those reused are deleted.
        num_deleted = sum(
            max(0, count - used_counts.get(chunk_hash, 0))
            for chunk_hash, count in old_counts.items()
        )
        return num_added, num_reused, num_deleted

    def _update_chunks(
        self, current_collection: ChunkedDocument[Any], new_splits: list[Any]
    ) -> UpdateResult[Any]:
        """
        Update chunks using the KARA algorithm for backward compatibility.

        This method handles single document updates (``document_id`` 0).

        Args:
            current_collection: Current document collection state.
            new_splits: New splits to process.

        Returns:
            UpdateResult with new chunks and added/reused/deleted statistics.
        """
        old_chunk_counts = self._count_hashes(current_collection.chunks)
        # Use the new multi-document method with document_id = 0
        doc_result = self._update_chunks_for_document(
            current_collection, new_splits, 0, set(old_chunk_counts)
        )
        assert doc_result.new_chunked_doc is not None
        (
            doc_result.num_added,
            doc_result.num_reused,
            doc_result.num_deleted,
        ) = self._tally(old_chunk_counts, self._count_hashes(doc_result.new_chunked_doc.chunks))
        return doc_result