jack 3 months ago
parent
commit
bf409edbfa
75 changed files with 3044 additions and 1205 deletions
  1. cli/main.py (+2 -2)
  2. compose.full.yaml (+0 -1)
  3. compose.full_with_replicas.yaml (+0 -1)
  4. core/agent/rag.py (+2 -8)
  5. core/base/abstractions/__init__.py (+2 -12)
  6. core/base/logger/__init__.py (+0 -3)
  7. core/base/providers/auth.py (+2 -4)
  8. core/base/providers/base.py (+1 -1)
  9. core/base/providers/crypto.py (+1 -1)
  10. core/base/providers/database.py (+6 -45)
  11. core/base/providers/email.py (+9 -6)
  12. core/base/providers/ingestion.py (+14 -39)
  13. core/configs/full_azure.toml (+0 -3)
  14. core/configs/full_local_llm.toml (+0 -6)
  15. core/configs/local_llm.toml (+0 -6)
  16. core/configs/r2r_azure.toml (+0 -3)
  17. core/configs/r2r_azure_with_test_limits.toml (+0 -3)
  18. core/database/chunks.py (+3 -4)
  19. core/database/collections.py (+120 -4)
  20. core/database/conversations.py (+201 -1)
  21. core/database/documents.py (+125 -10)
  22. core/database/files.py (+62 -3)
  23. core/database/filters.py (+4 -5)
  24. core/database/graphs.py (+354 -114)
  25. core/database/limits.py (+33 -1)
  26. core/database/postgres.py (+15 -26)
  27. core/database/users.py (+112 -2)
  28. core/database/vecs/adapter/base.py (+3 -3)
  29. core/main/abstractions.py (+0 -4)
  30. core/main/api/v3/base_router.py (+2 -2)
  31. core/main/api/v3/collections_router.py (+109 -1)
  32. core/main/api/v3/conversations_router.py (+214 -0)
  33. core/main/api/v3/documents_router.py (+439 -6)
  34. core/main/api/v3/graph_router.py (+347 -86)
  35. core/main/api/v3/indices_router.py (+2 -2)
  36. core/main/api/v3/users_router.py (+109 -1)
  37. core/main/assembly/factory.py (+31 -51)
  38. core/main/orchestration/hatchet/kg_workflow.py (+0 -121)
  39. core/main/orchestration/simple/kg_workflow.py (+0 -26)
  40. core/main/services/auth_service.py (+2 -2)
  41. core/main/services/graph_service.py (+14 -88)
  42. core/main/services/management_service.py (+173 -7)
  43. core/parsers/structured/csv_parser.py (+3 -3)
  44. core/pipes/__init__.py (+0 -4)
  45. core/pipes/abstractions/search_pipe.py (+2 -2)
  46. core/pipes/ingestion/embedding_pipe.py (+2 -2)
  47. core/pipes/kg/community_summary.py (+1 -1)
  48. core/pipes/kg/extraction.py (+2 -2)
  49. core/providers/auth/r2r_auth.py (+13 -6)
  50. core/providers/auth/supabase.py (+19 -0)
  51. core/providers/ingestion/unstructured/base.py (+3 -3)
  52. core/telemetry/events.py (+0 -69)
  53. sdk/async_client.py (+1 -17)
  54. sdk/models.py (+0 -4)
  55. sdk/sync_client.py (+0 -48)
  56. sdk/v3/chunks.py (+3 -3)
  57. sdk/v3/collections.py (+5 -5)
  58. sdk/v3/conversations.py (+108 -6)
  59. sdk/v3/documents.py (+212 -41)
  60. sdk/v3/graphs.py (+2 -2)
  61. sdk/v3/indices.py (+3 -29)
  62. sdk/v3/prompts.py (+6 -2)
  63. sdk/v3/retrieval.py (+6 -6)
  64. sdk/v3/users.py (+13 -8)
  65. shared/abstractions/__init__.py (+2 -9)
  66. shared/abstractions/base.py (+9 -4)
  67. shared/abstractions/document.py (+1 -1)
  68. shared/abstractions/graph.py (+6 -0)
  69. shared/abstractions/ingestion.py (+1 -1)
  70. shared/abstractions/kg.py (+0 -84)
  71. shared/abstractions/search.py (+1 -19)
  72. shared/utils/base_utils.py (+0 -1)
  73. shared/utils/splitter/text.py (+77 -80)
  74. tests/unit/test_collections.py (+4 -2)
  75. tests/unit/test_graphs.py (+26 -27)

+ 2 - 2
cli/main.py

@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict
+from typing import Any

 import asyncclick as click
 from rich.console import Console
@@ -73,7 +73,7 @@ def _ensure_config_dir_exists() -> None:
     CONFIG_DIR.mkdir(parents=True, exist_ok=True)


-def save_config(config_data: Dict[str, Any]) -> None:
+def save_config(config_data: dict[str, Any]) -> None:
     """
     Persist the given config data to ~/.r2r/config.json.
     """

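Note: throughout this commit, `typing.Dict`/`typing.Union` annotations are replaced by the built-in `dict[...]` generics and `X | Y` unions (PEP 585/604, Python 3.9+/3.10+). A minimal sketch of the updated `save_config` signature; the body shown here is illustrative, only the signature and the `~/.r2r/config.json` target come from the diff:

```python
import json
from pathlib import Path
from typing import Any

CONFIG_DIR = Path.home() / ".r2r"  # assumed definition; the real module defines its own


def save_config(config_data: dict[str, Any]) -> None:
    """Persist the given config data to ~/.r2r/config.json (illustrative body)."""
    CONFIG_DIR.mkdir(parents=True, exist_ok=True)  # mirrors _ensure_config_dir_exists above
    (CONFIG_DIR / "config.json").write_text(json.dumps(config_data, indent=2))
```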
+ 0 - 1
compose.full.yaml

@@ -307,7 +307,6 @@ services:
       - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres}
       - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432}
       - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-postgres}
-      - R2R_POSTGRES_PROJECT_NAME=${R2R_POSTGRES_PROJECT_NAME:-r2r_default}
       - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
       - R2R_POSTGRES_STATEMENT_CACHE_SIZE=${R2R_POSTGRES_STATEMENT_CACHE_SIZE:-100}
 
 

+ 0 - 1
compose.full_with_replicas.yaml

@@ -305,7 +305,6 @@ services:
       - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres}
       - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432}
       - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-postgres}
-      - R2R_POSTGRES_PROJECT_NAME=${R2R_POSTGRES_PROJECT_NAME:-r2r_default}
       - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
       - R2R_POSTGRES_STATEMENT_CACHE_SIZE=${R2R_POSTGRES_STATEMENT_CACHE_SIZE:-100}
 
 

+ 2 - 8
core/agent/rag.py

@@ -1,5 +1,3 @@
-from typing import Union
-
 from core.agent import R2RAgent, R2RStreamingAgent
 from core.base import (
     format_search_results_for_llm,
@@ -126,9 +124,7 @@ class R2RRAGAgent(RAGAgentMixin, R2RAgent):
     def __init__(
         self,
         database_provider: DatabaseProvider,
-        llm_provider: Union[
-            LiteLLMCompletionProvider, OpenAICompletionProvider
-        ],
+        llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
         search_pipeline: SearchPipeline,
         config: AgentConfig,
     ):
@@ -144,9 +140,7 @@ class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent):
     def __init__(
         self,
         database_provider: DatabaseProvider,
-        llm_provider: Union[
-            LiteLLMCompletionProvider, OpenAICompletionProvider
-        ],
+        llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
         search_pipeline: SearchPipeline,
         config: AgentConfig,
     ):

+ 2 - 12
core/base/abstractions/__init__.py

@@ -25,20 +25,16 @@ from shared.abstractions.graph import (
     Graph,
     KGExtraction,
     Relationship,
+    StoreType,
 )
 from shared.abstractions.ingestion import (
     ChunkEnrichmentSettings,
     ChunkEnrichmentStrategy,
 )
 from shared.abstractions.kg import (
-    GraphBuildSettings,
     GraphCommunitySettings,
-    GraphEntitySettings,
-    GraphRelationshipSettings,
     KGCreationSettings,
     KGEnrichmentSettings,
-    KGEntityDeduplicationSettings,
-    KGEntityDeduplicationType,
     KGRunType,
 )
 from shared.abstractions.llm import (
@@ -59,7 +55,6 @@ from shared.abstractions.search import (
     HybridSearchSettings,
     KGCommunityResult,
     KGEntityResult,
-    KGGlobalResult,
     KGRelationshipResult,
     KGSearchResultType,
     SearchMode,
@@ -110,6 +105,7 @@ __all__ = [
     # Graph abstractions
     "Entity",
     "Community",
+    "StoreType",
     "KGExtraction",
     "Relationship",
     # Index abstractions
@@ -130,7 +126,6 @@ __all__ = [
     "KGEntityResult",
     "KGRelationshipResult",
     "KGCommunityResult",
-    "KGGlobalResult",
     "GraphSearchSettings",
     "ChunkSearchSettings",
     "ChunkSearchResult",
@@ -141,12 +136,7 @@ __all__ = [
     # KG abstractions
     "KGCreationSettings",
     "KGEnrichmentSettings",
-    "KGEntityDeduplicationSettings",
-    "GraphBuildSettings",
-    "GraphEntitySettings",
-    "GraphRelationshipSettings",
     "GraphCommunitySettings",
-    "KGEntityDeduplicationType",
     "KGRunType",
     # User abstractions
     "Token",

+ 0 - 3
core/base/logger/__init__.py

@@ -1,9 +1,6 @@
-from .base import RunInfoLog
 from .run_manager import RunManager, manage_run

 __all__ = [
-    # Basic types
-    "RunInfoLog",
     # Run Manager
     "RunManager",
     "manage_run",

+ 2 - 4
core/base/providers/auth.py

@@ -66,10 +66,8 @@ class AuthProvider(Provider, ABC):
         self.database_provider = database_provider
         self.email_provider = email_provider
         super().__init__(config)
-        self.config: AuthConfig = config  # for type hinting
-        self.database_provider: "PostgresDatabaseProvider" = (
-            database_provider  # for type hinting
-        )
+        self.config: AuthConfig = config
+        self.database_provider: "PostgresDatabaseProvider" = database_provider

     async def _get_default_admin_user(self) -> User:
         return await self.database_provider.users_handler.get_user_by_email(

+ 1 - 1
core/base/providers/base.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Optional, Sequence, Type
+from typing import Any, Optional, Type

 from pydantic import BaseModel
 
 

+ 1 - 1
core/base/providers/crypto.py

@@ -10,7 +10,7 @@ class CryptoConfig(ProviderConfig):
 
 
     @property
     def supported_providers(self) -> list[str]:
-        return ["bcrypt", "nacl"]  # Add other crypto providers as needed
+        return ["bcrypt", "nacl"]

     def validate_config(self) -> None:
         if self.provider not in self.supported_providers:

+ 6 - 45
core/base/providers/database.py

@@ -1,56 +1,20 @@
-import logging
-from abc import abstractmethod
-from datetime import datetime
-from io import BytesIO
-from typing import BinaryIO, Optional, Tuple
-from uuid import UUID
-
-from pydantic import BaseModel
-
-from core.base.abstractions import (
-    ChunkSearchResult,
-    Community,
-    DocumentResponse,
-    Entity,
-    IndexArgsHNSW,
-    IndexArgsIVFFlat,
-    IndexMeasure,
-    IndexMethod,
-    KGCreationSettings,
-    KGEnrichmentSettings,
-    KGEntityDeduplicationSettings,
-    Message,
-    Relationship,
-    SearchSettings,
-    User,
-    VectorEntry,
-    VectorTableName,
-)
-from core.base.api.models import CollectionResponse, GraphResponse
-
-from .base import Provider, ProviderConfig
-
-"""Base classes for knowledge graph providers."""
+"""Base classes for database providers."""
 
 
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Optional, Sequence, Tuple, Type
+from typing import Any, Optional, Sequence
 from uuid import UUID

 from pydantic import BaseModel

-from ..abstractions import (
-    Community,
-    Entity,
+from core.base.abstractions import (
     GraphSearchSettings,
     KGCreationSettings,
     KGEnrichmentSettings,
-    KGEntityDeduplicationSettings,
-    KGExtraction,
-    R2RSerializable,
-    Relationship,
 )

+from .base import Provider, ProviderConfig
+
 logger = logging.getLogger()


@@ -168,14 +132,11 @@ class DatabaseConfig(ProviderConfig):
     collection_summary_task_prompt: str = "default_collection_summary"
     enable_fts: bool = False

-    # KG settings
+    # Graph settings
     batch_size: Optional[int] = 1
     kg_store_path: Optional[str] = None
     graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings()
     graph_creation_settings: KGCreationSettings = KGCreationSettings()
-    graph_entity_deduplication_settings: KGEntityDeduplicationSettings = (
-        KGEntityDeduplicationSettings()
-    )
     graph_search_settings: GraphSearchSettings = GraphSearchSettings()

     # Rate limits

+ 9 - 6
core/base/providers/email.py

@@ -29,11 +29,14 @@ class EmailConfig(ProviderConfig):
         ]  # Could add more providers like AWS SES, SendGrid etc.

     def validate_config(self) -> None:
-        if self.provider == "sendgrid":
-            if not (self.sendgrid_api_key or os.getenv("SENDGRID_API_KEY")):
-                raise ValueError(
-                    "SendGrid API key is required when using SendGrid provider"
-                )
+        if (
+            self.provider == "sendgrid"
+            and not self.sendgrid_api_key
+            and not os.getenv("SENDGRID_API_KEY")
+        ):
+            raise ValueError(
+                "SendGrid API key is required when using SendGrid provider"
+            )


 logger = logging.getLogger(__name__)
@@ -46,7 +49,7 @@ class EmailProvider(Provider, ABC):
                 "EmailProvider must be initialized with an EmailConfig"
             )
         super().__init__(config)
-        self.config: EmailConfig = config  # for type hinting
+        self.config: EmailConfig = config

     @abstractmethod
     async def send_email(

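Note: the `validate_config` change above flattens a nested `if` into a single guard. Behaviour should be unchanged, since `not (key or env)` is equivalent to `not key and not env`; the error is still raised only when the SendGrid provider is selected and no API key is found on the config or in the environment. A standalone sketch of the check (the function name is illustrative):

```python
import os
from typing import Optional


def require_sendgrid_key(provider: str, sendgrid_api_key: Optional[str]) -> None:
    # Raise only for the SendGrid provider when no key is configured anywhere.
    if (
        provider == "sendgrid"
        and not sendgrid_api_key
        and not os.getenv("SENDGRID_API_KEY")
    ):
        raise ValueError(
            "SendGrid API key is required when using SendGrid provider"
        )
```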
+ 14 - 39
core/base/providers/ingestion.py

@@ -16,6 +16,19 @@ if TYPE_CHECKING:
     from core.database import PostgresDatabaseProvider


+class ChunkingStrategy(str, Enum):
+    RECURSIVE = "recursive"
+    CHARACTER = "character"
+    BASIC = "basic"
+    BY_TITLE = "by_title"
+
+
+class IngestionMode(str, Enum):
+    hi_res = "hi-res"
+    fast = "fast"
+    custom = "custom"
+
+
 class IngestionConfig(ProviderConfig):
     _defaults: ClassVar[dict] = {
         "app": AppConfig(),
@@ -44,7 +57,7 @@ class IngestionConfig(ProviderConfig):
     excluded_parsers: list[str] = Field(
         default_factory=lambda: IngestionConfig._defaults["excluded_parsers"]
     )
-    chunking_strategy: str = Field(
+    chunking_strategy: str | ChunkingStrategy = Field(
         default_factory=lambda: IngestionConfig._defaults["chunking_strategy"]
     )
     chunk_enrichment_settings: ChunkEnrichmentSettings = Field(
@@ -131,31 +144,6 @@ class IngestionConfig(ProviderConfig):
         else:
             return cls(app=app)
 
 
-    @classmethod
-    def get_default(cls, mode: str, app) -> "IngestionConfig":
-        """Return default ingestion configuration for a given mode."""
-        if mode == "hi-res":
-            # More thorough parsing, no skipping summaries, possibly larger `chunks_for_document_summary`.
-            return cls(app=app, parser_overrides={"pdf": "zerox"})
-        # elif mode == "fast":
-        #     # Skip summaries and other enrichment steps for speed.
-        #     return cls(
-        #         app=app,
-        #     )
-        else:
-            # For `custom` or any unrecognized mode, return a base config
-            return cls(app=app)
-
-    @classmethod
-    def set_default(cls, **kwargs):
-        for key, value in kwargs.items():
-            if key in cls._defaults:
-                cls._defaults[key] = value
-            else:
-                raise AttributeError(
-                    f"No default attribute '{key}' in GenerationConfig"
-                )
-
     class Config:
         populate_by_name = True
         json_schema_extra = {
@@ -193,16 +181,3 @@ class IngestionProvider(Provider, ABC):
         self.config: IngestionConfig = config
         self.llm_provider = llm_provider
         self.database_provider: "PostgresDatabaseProvider" = database_provider
-
-
-class ChunkingStrategy(str, Enum):
-    RECURSIVE = "recursive"
-    CHARACTER = "character"
-    BASIC = "basic"
-    BY_TITLE = "by_title"
-
-
-class IngestionMode(str, Enum):
-    hi_res = "hi-res"
-    fast = "fast"
-    custom = "custom"

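Note: `ChunkingStrategy` and `IngestionMode` move from the bottom of the module to above `IngestionConfig` because the new annotation `chunking_strategy: str | ChunkingStrategy` is evaluated while the class body executes, so the enum must already be defined (the module does not appear to use `from __future__ import annotations`). A minimal sketch of that ordering constraint, using a stand-in model rather than the real `IngestionConfig`:

```python
from enum import Enum

from pydantic import BaseModel


class ChunkingStrategy(str, Enum):
    RECURSIVE = "recursive"
    CHARACTER = "character"
    BASIC = "basic"
    BY_TITLE = "by_title"


class ExampleIngestionConfig(BaseModel):
    # If ChunkingStrategy were defined below this class, evaluating the
    # annotation here would raise NameError at import time.
    chunking_strategy: str | ChunkingStrategy = ChunkingStrategy.RECURSIVE
```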
+ 0 - 3
core/configs/full_azure.toml

@@ -15,9 +15,6 @@ concurrent_request_limit = 128
     clustering_mode = "remote"
     generation_config = { model = "azure/gpt-4o-mini" }

-  [database.graph_entity_deduplication_settings]
-    generation_config = { model = "azure/gpt-4o-mini" }
-
   [database.graph_enrichment_settings]
     generation_config = { model = "azure/gpt-4o-mini" }
 
 

+ 0 - 6
core/configs/full_local_llm.toml

@@ -31,12 +31,6 @@ provider = "postgres"
     max_description_input_length = 65536
     generation_config = { model = "ollama/llama3.1" } # and other params, model used for relationshipt extraction

-  [database.graph_entity_deduplication_settings]
-    graph_entity_deduplication_type = "by_name"
-    graph_entity_deduplication_prompt = "graphrag_entity_deduplication"
-    max_description_input_length = 65536
-    generation_config = { model = "ollama/llama3.1" } # and other params, model used for deduplication
-
   [database.graph_enrichment_settings]
     community_reports_prompt = "graphrag_community_reports"
     max_summary_input_length = 65536

+ 0 - 6
core/configs/local_llm.toml

@@ -37,12 +37,6 @@ provider = "postgres"
     max_description_input_length = 65536
     generation_config = { model = "ollama/llama3.1" } # and other params, model used for relationshipt extraction

-  [database.graph_entity_deduplication_settings]
-    graph_entity_deduplication_type = "by_name"
-    graph_entity_deduplication_prompt = "graphrag_entity_deduplication"
-    max_description_input_length = 65536
-    generation_config = { model = "ollama/llama3.1" } # and other params, model used for deduplication
-
   [database.graph_enrichment_settings]
     community_reports_prompt = "graphrag_community_reports"
     max_summary_input_length = 65536

+ 0 - 3
core/configs/r2r_azure.toml

@@ -13,9 +13,6 @@ batch_size = 256
   [database.graph_creation_settings]
     generation_config = { model = "azure/gpt-4o-mini" }

-  [database.graph_entity_deduplication_settings]
-    generation_config = { model = "azure/gpt-4o-mini" }
-
   [database.graph_enrichment_settings]
     generation_config = { model = "azure/gpt-4o-mini" }
 
 

+ 0 - 3
core/configs/r2r_azure_with_test_limits.toml

@@ -16,9 +16,6 @@ batch_size = 256
   [database.graph_creation_settings]
     generation_config = { model = "azure/gpt-4o-mini" }

-  [database.graph_entity_deduplication_settings]
-    generation_config = { model = "azure/gpt-4o-mini" }
-
   [database.graph_enrichment_settings]
     generation_config = { model = "azure/gpt-4o-mini" }
 
 

+ 3 - 4
core/database/chunks.py

@@ -612,10 +612,9 @@ class PostgresChunksHandler(Handler):
         SET collection_ids = array_append(collection_ids, $1)
         WHERE document_id = $2 AND NOT ($1 = ANY(collection_ids));
         """
-        result = await self.connection_manager.execute_query(
+        return await self.connection_manager.execute_query(
             query, (str(collection_id), str(document_id))
         )
-        return result

     async def remove_document_from_collection_vector(
         self, document_id: UUID, collection_id: UUID
@@ -883,7 +882,7 @@ class PostgresChunksHandler(Handler):

         where_clause = " AND ".join(where_clauses) if where_clauses else ""
         if where_clause:
-            where_clause = "AND " + where_clause
+            where_clause = f"AND {where_clause}"

         query = f"""
         WITH index_info AS (
@@ -1223,7 +1222,7 @@ class PostgresChunksHandler(Handler):
                     ) as body_rank
                 FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
                 WHERE $1 != ''
-                {f"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if settings.search_over_body else ""}
+                {"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if settings.search_over_body else ""}
                 GROUP BY document_id
             ),
             -- Combined scores with document metadata

+ 120 - 4
core/database/collections.py

@@ -1,6 +1,8 @@
+import csv
 import json
 import logging
-from typing import Any, Optional
+import tempfile
+from typing import IO, Any, Optional
 from uuid import UUID, uuid4

 from asyncpg.exceptions import UniqueViolationError
@@ -117,6 +119,11 @@ class PostgresCollectionsHandler(Handler):
                 message="Collection with this ID already exists",
                 message="Collection with this ID already exists",
                 status_code=409,
                 status_code=409,
             )
             )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while creating the collection: {e}",
+            ) from e
 
 
     async def update_collection(
     async def update_collection(
         self,
         self,
@@ -189,7 +196,7 @@ class PostgresCollectionsHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail=f"An error occurred while updating the collection: {e}",
-            )
+            ) from e

     async def delete_collection_relational(self, collection_id: UUID) -> None:
         # Remove collection_id from users
@@ -361,7 +368,7 @@ class PostgresCollectionsHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail=f"An error occurred while fetching collections: {e}",
-            )
+            ) from e

     async def assign_document_to_collection_relational(
         self,
@@ -435,7 +442,7 @@ class PostgresCollectionsHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail=f"An error '{e}' occurred while assigning the document to the collection",
-            )
+            ) from e

     async def remove_document_from_collection_relational(
         self, document_id: UUID, collection_id: UUID
@@ -468,3 +475,112 @@ class PostgresCollectionsHandler(Handler):
                 status_code=404,
                 message="Document not found in the specified collection",
             )
+
+    async def export_to_csv(
+        self,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        """
+        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
+        """
+        valid_columns = {
+            "id",
+            "owner_id",
+            "name",
+            "description",
+            "graph_sync_status",
+            "graph_cluster_status",
+            "created_at",
+            "updated_at",
+            "user_count",
+            "document_count",
+        }
+
+        if not columns:
+            columns = list(valid_columns)
+        elif invalid_cols := set(columns) - valid_columns:
+            raise ValueError(f"Invalid columns: {invalid_cols}")
+
+        select_stmt = f"""
+            SELECT
+                id::text,
+                owner_id::text,
+                name,
+                description,
+                graph_sync_status,
+                graph_cluster_status,
+                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
+                user_count,
+                document_count
+            FROM {self._get_table_name(self.TABLE_NAME)}
+        """
+
+        params = []
+        if filters:
+            conditions = []
+            param_index = 1
+
+            for field, value in filters.items():
+                if field not in valid_columns:
+                    continue
+
+                if isinstance(value, dict):
+                    for op, val in value.items():
+                        if op == "$eq":
+                            conditions.append(f"{field} = ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$gt":
+                            conditions.append(f"{field} > ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$lt":
+                            conditions.append(f"{field} < ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                else:
+                    # Direct equality
+                    conditions.append(f"{field} = ${param_index}")
+                    params.append(value)
+                    param_index += 1
+
+            if conditions:
+                select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+        temp_file = None
+        try:
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w", delete=True, suffix=".csv"
+            )
+            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                async with conn.transaction():
+                    cursor = await conn.cursor(select_stmt, *params)
+
+                    if include_header:
+                        writer.writerow(columns)
+
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            writer.writerow(row)
+
+            temp_file.flush()
+            return temp_file.name, temp_file
+
+        except Exception as e:
+            if temp_file:
+                temp_file.close()
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to export data: {str(e)}",
+            ) from e

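Note: the new `export_to_csv` above (and its near-copies for conversations, messages, documents and graph entities later in this commit) whitelists column names, binds filter values as `$n` parameters, and streams rows through a server-side cursor into a `NamedTemporaryFile` in 1000-row batches. A hedged usage sketch from a hypothetical FastAPI caller; the route wiring is illustrative, only the `export_to_csv` signature and return tuple come from the diff:

```python
from fastapi import BackgroundTasks
from fastapi.responses import FileResponse


async def download_collections_csv(
    collections_handler,  # assumed: a wired-up PostgresCollectionsHandler
    background_tasks: BackgroundTasks,
) -> FileResponse:
    csv_path, csv_file = await collections_handler.export_to_csv(
        columns=["id", "name", "document_count"],
        filters={"user_count": {"$gt": 0}},
        include_header=True,
    )
    # The temp file is created with delete=True, so it must stay open until the
    # response has been sent; close it afterwards to release (and delete) it.
    background_tasks.add_task(csv_file.close)
    return FileResponse(csv_path, media_type="text/csv", filename="collections.csv")
```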
+ 201 - 1
core/database/conversations.py

@@ -1,5 +1,7 @@
+import csv
 import json
-from typing import Any, Optional
+import tempfile
+from typing import IO, Any, Optional
 from uuid import UUID, uuid4

 from fastapi import HTTPException
@@ -452,3 +454,201 @@ class PostgresConversationsHandler(Handler):
         await self.connection_manager.execute_query(
             del_conv_query, [conversation_id]
         )
+
+    async def export_conversations_to_csv(
+        self,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        """
+        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
+        """
+        valid_columns = {
+            "id",
+            "user_id",
+            "created_at",
+            "name",
+        }
+
+        if not columns:
+            columns = list(valid_columns)
+        elif invalid_cols := set(columns) - valid_columns:
+            raise ValueError(f"Invalid columns: {invalid_cols}")
+
+        select_stmt = f"""
+            SELECT
+                id::text,
+                user_id::text,
+                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+                name
+            FROM {self._get_table_name("conversations")}
+        """
+
+        conditions = []
+        params: list[Any] = []
+        param_index = 1
+
+        if filters:
+            for field, value in filters.items():
+                if field not in valid_columns:
+                    continue
+
+                if isinstance(value, dict):
+                    for op, val in value.items():
+                        if op == "$eq":
+                            conditions.append(f"{field} = ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$gt":
+                            conditions.append(f"{field} > ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$lt":
+                            conditions.append(f"{field} < ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                else:
+                    # Direct equality
+                    conditions.append(f"{field} = ${param_index}")
+                    params.append(value)
+                    param_index += 1
+
+        if conditions:
+            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+        temp_file = None
+        try:
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w", delete=True, suffix=".csv"
+            )
+            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                async with conn.transaction():
+                    cursor = await conn.cursor(select_stmt, *params)
+
+                    if include_header:
+                        writer.writerow(columns)
+
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            writer.writerow(row)
+
+            temp_file.flush()
+            return temp_file.name, temp_file
+
+        except Exception as e:
+            if temp_file:
+                temp_file.close()
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to export data: {str(e)}",
+            ) from e
+
+    async def export_messages_to_csv(
+        self,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        """
+        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
+        """
+        valid_columns = {
+            "id",
+            "conversation_id",
+            "parent_id",
+            "content",
+            "metadata",
+            "created_at",
+        }
+
+        if not columns:
+            columns = list(valid_columns)
+        elif invalid_cols := set(columns) - valid_columns:
+            raise ValueError(f"Invalid columns: {invalid_cols}")
+
+        select_stmt = f"""
+            SELECT
+                id::text,
+                conversation_id::text,
+                parent_id::text,
+                content::text,
+                metadata::text,
+                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at
+            FROM {self._get_table_name("messages")}
+        """
+
+        conditions = []
+        params: list[Any] = []
+        param_index = 1
+
+        if filters:
+            for field, value in filters.items():
+                if field not in valid_columns:
+                    continue
+
+                if isinstance(value, dict):
+                    for op, val in value.items():
+                        if op == "$eq":
+                            conditions.append(f"{field} = ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$gt":
+                            conditions.append(f"{field} > ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$lt":
+                            conditions.append(f"{field} < ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                else:
+                    # Direct equality
+                    conditions.append(f"{field} = ${param_index}")
+                    params.append(value)
+                    param_index += 1
+
+        if conditions:
+            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+        temp_file = None
+        try:
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w", delete=True, suffix=".csv"
+            )
+            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                async with conn.transaction():
+                    cursor = await conn.cursor(select_stmt, *params)
+
+                    if include_header:
+                        writer.writerow(columns)
+
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            writer.writerow(row)
+
+            temp_file.flush()
+            return temp_file.name, temp_file
+
+        except Exception as e:
+            if temp_file:
+                temp_file.close()
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to export data: {str(e)}",
+            ) from e

+ 125 - 10
core/database/documents.py

@@ -1,8 +1,10 @@
 import asyncio
 import copy
+import csv
 import json
 import logging
-from typing import Any, Optional
+import tempfile
+from typing import IO, Any, Optional
 from uuid import UUID

 import asyncpg
@@ -20,7 +22,7 @@ from core.base import (
 )

 from .base import PostgresConnectionManager
-from .filters import apply_filters  # Add this near other imports
+from .filters import apply_filters

 logger = logging.getLogger()
 
 
@@ -247,7 +249,7 @@ class PostgresDocumentsHandler(Handler):
         Get the IDs from a given table.

         Args:
-            status (Union[str, list[str]]): The status or list of statuses to retrieve.
+            status (str | list[str]): The status or list of statuses to retrieve.
             table_name (str): The table name.
             status_type (str): The type of status to retrieve.
         """
@@ -299,9 +301,7 @@ class PostgresDocumentsHandler(Handler):
             return IngestionStatus
         elif status_type == "extraction_status":
             return KGExtractionStatus
-        elif status_type == "graph_cluster_status":
-            return KGEnrichmentStatus
-        elif status_type == "graph_sync_status":
+        elif status_type in {"graph_cluster_status", "graph_sync_status"}:
             return KGEnrichmentStatus
         else:
             raise R2RException(
@@ -315,7 +315,7 @@ class PostgresDocumentsHandler(Handler):
         Get the workflow status for a given document or list of documents.

         Args:
-            id (Union[UUID, list[UUID]]): The document ID or list of document IDs.
+            id (UUID | list[UUID]): The document ID or list of document IDs.
             status_type (str): The type of status to retrieve.

         Returns:
@@ -341,7 +341,7 @@ class PostgresDocumentsHandler(Handler):
         Set the workflow status for a given document or list of documents.

         Args:
-            id (Union[UUID, list[UUID]]): The document ID or list of document IDs.
+            id (UUID | list[UUID]): The document ID or list of document IDs.
             status_type (str): The type of status to set.
             status (str): The status to set.
         """
@@ -368,7 +368,7 @@ class PostgresDocumentsHandler(Handler):
         Args:
             ids_key (str): The key to retrieve the IDs.
             status_type (str): The type of status to retrieve.
-            status (Union[str, list[str]]): The status or list of statuses to retrieve.
+            status (str | list[str]): The status or list of statuses to retrieve.
         """

         if isinstance(status, str):
@@ -501,7 +501,7 @@ class PostgresDocumentsHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail="Database query failed",
-            )
+            ) from e

     async def semantic_document_search(
         self, query_embedding: list[float], search_settings: SearchSettings
@@ -792,3 +792,118 @@ class PostgresDocumentsHandler(Handler):
             )
         else:
             return await self.full_text_document_search(query_text, settings)
+
+    async def export_to_csv(
+        self,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        """
+        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
+        """
+        valid_columns = {
+            "id",
+            "collection_ids",
+            "owner_id",
+            "type",
+            "metadata",
+            "title",
+            "summary",
+            "version",
+            "size_in_bytes",
+            "ingestion_status",
+            "extraction_status",
+            "created_at",
+            "updated_at",
+        }
+
+        if not columns:
+            columns = list(valid_columns)
+        elif invalid_cols := set(columns) - valid_columns:
+            raise ValueError(f"Invalid columns: {invalid_cols}")
+
+        select_stmt = f"""
+            SELECT
+                id::text,
+                collection_ids::text,
+                owner_id::text,
+                type::text,
+                metadata::text AS metadata,
+                title,
+                summary,
+                version,
+                size_in_bytes,
+                ingestion_status,
+                extraction_status,
+                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+            FROM {self._get_table_name(self.TABLE_NAME)}
+        """
+
+        conditions = []
+        params: list[Any] = []
+        param_index = 1
+
+        if filters:
+            for field, value in filters.items():
+                if field not in valid_columns:
+                    continue
+
+                if isinstance(value, dict):
+                    for op, val in value.items():
+                        if op == "$eq":
+                            conditions.append(f"{field} = ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$gt":
+                            conditions.append(f"{field} > ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$lt":
+                            conditions.append(f"{field} < ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                else:
+                    # Direct equality
+                    conditions.append(f"{field} = ${param_index}")
+                    params.append(value)
+                    param_index += 1
+
+        if conditions:
+            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+        temp_file = None
+        try:
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w", delete=True, suffix=".csv"
+            )
+            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                async with conn.transaction():
+                    cursor = await conn.cursor(select_stmt, *params)
+
+                    if include_header:
+                        writer.writerow(columns)
+
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            writer.writerow(row)
+
+            temp_file.flush()
+            return temp_file.name, temp_file
+
+        except Exception as e:
+            if temp_file:
+                temp_file.close()
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to export data: {str(e)}",
+            ) from e

+ 62 - 3
core/database/files.py

@@ -1,7 +1,10 @@
 import io
 import logging
-from typing import BinaryIO, Optional, Union
+from datetime import datetime
+from io import BytesIO
+from typing import BinaryIO, Optional
 from uuid import UUID
+from zipfile import ZipFile

 import asyncpg
 from fastapi import HTTPException
@@ -119,7 +122,7 @@ class PostgresFilesHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail=f"Failed to write to large object: {e}",
-            )
+            ) from e

     async def retrieve_file(
         self, document_id: UUID
@@ -150,6 +153,62 @@ class PostgresFilesHandler(Handler):
             file_content = await self._read_lobject(conn, oid)
             return file_name, io.BytesIO(file_content), size
 
 
+    async def retrieve_files_as_zip(
+        self,
+        document_ids: Optional[list[UUID]] = None,
+        start_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
+    ) -> tuple[str, BinaryIO, int]:
+        """Retrieve multiple files and return them as a zip file."""
+
+        query = f"""
+        SELECT document_id, name, oid, size
+        FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+        WHERE 1=1
+        """
+        params: list = []
+
+        if document_ids:
+            query += f" AND document_id = ANY(${len(params) + 1})"
+            params.append([str(doc_id) for doc_id in document_ids])
+
+        if start_date:
+            query += f" AND created_at >= ${len(params) + 1}"
+            params.append(start_date)
+
+        if end_date:
+            query += f" AND created_at <= ${len(params) + 1}"
+            params.append(end_date)
+
+        query += " ORDER BY created_at DESC"
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+        if not results:
+            raise R2RException(
+                status_code=404,
+                message="No files found matching the specified criteria",
+            )
+
+        zip_buffer = BytesIO()
+        total_size = 0
+
+        async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+            with ZipFile(zip_buffer, "w") as zip_file:
+                for record in results:
+                    file_content = await self._read_lobject(
+                        conn, record["oid"]
+                    )
+
+                    zip_file.writestr(record["name"], file_content)
+                    total_size += record["size"]
+
+        zip_buffer.seek(0)
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        zip_filename = f"files_export_{timestamp}.zip"
+
+        return zip_filename, zip_buffer, zip_buffer.getbuffer().nbytes
+
     async def _read_lobject(self, conn, oid: int) -> bytes:
         """Read content from a large object."""
         file_data = io.BytesIO()
@@ -233,7 +292,7 @@ class PostgresFilesHandler(Handler):
     ) -> list[dict]:
         """Get an overview of stored files."""
         conditions = []
-        params: list[Union[str, list[str], int]] = []
+        params: list[str | list[str] | int] = []
         query = f"""
         SELECT document_id, name, oid, size, type, created_at, updated_at
         FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}

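Note: `retrieve_files_as_zip` assembles the archive fully in memory (`BytesIO` plus `ZipFile`), so exports are bounded by available RAM rather than streamed to disk. A hedged usage sketch; the route is hypothetical, only the handler method and its return tuple come from the diff:

```python
from datetime import datetime, timedelta

from fastapi import Response


async def download_recent_files(files_handler) -> Response:  # assumed: a PostgresFilesHandler
    zip_name, zip_buffer, zip_size = await files_handler.retrieve_files_as_zip(
        start_date=datetime.now() - timedelta(days=7),
    )
    return Response(
        content=zip_buffer.getvalue(),
        media_type="application/zip",
        headers={
            "Content-Disposition": f'attachment; filename="{zip_name}"',
            "Content-Length": str(zip_size),
        },
    )
```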
+ 4 - 5
core/database/filters.py

@@ -1,6 +1,5 @@
 import json
-from typing import Any, Optional, Tuple, Union
-from uuid import UUID
+from typing import Any, Optional, Tuple

 COLUMN_VARS = [
     "id",
@@ -46,7 +45,7 @@ class FilterCondition:
 class FilterExpression:
     def __init__(self, logical_op: Optional[str] = None):
         self.logical_op = logical_op
-        self.conditions: list[Union[FilterCondition, "FilterExpression"]] = []
+        self.conditions: list[FilterCondition | "FilterExpression"] = []


 class FilterParser:
@@ -410,13 +409,13 @@ class SQLFilterBuilder:

 def apply_filters(
     filters: dict, params: list[Any], mode: str = "where_clause"
-) -> str:
+) -> tuple[str, list[Any]]:
     """
     Apply filters with consistent WHERE clause handling
     """

     if not filters:
-        return ""
+        return "", params

     parser = FilterParser()
     expr = parser.parse(filters)

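Note: `apply_filters` now returns a `(where_clause, params)` tuple even when no filters are given, instead of a bare string, so every caller has to unpack both values. A minimal sketch of the calling convention under that assumption (the query itself is illustrative):

```python
from typing import Any

from core.database.filters import apply_filters  # imported as `.filters` inside the package


def build_document_query(filters: dict) -> tuple[str, list[Any]]:
    params: list[Any] = []
    where_clause, params = apply_filters(filters, params, mode="where_clause")
    query = "SELECT id FROM documents"
    if where_clause:
        # assuming mode="where_clause" yields a clause that already starts with WHERE
        query = f"{query} {where_clause}"
    return query, params
```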
+ 354 - 114
core/database/graphs.py

@@ -1,12 +1,13 @@
 import asyncio
 import contextlib
+import csv
 import datetime
 import json
 import logging
 import os
+import tempfile
 import time
-from enum import Enum
-from typing import Any, AsyncGenerator, Optional, Tuple
+from typing import IO, Any, AsyncGenerator, Optional, Tuple
 from uuid import UUID

 import asyncpg
@@ -20,10 +21,10 @@ from core.base.abstractions import (
     Graph,
     KGCreationSettings,
     KGEnrichmentSettings,
-    KGEntityDeduplicationSettings,
     KGExtractionStatus,
     R2RException,
     Relationship,
+    StoreType,
     VectorQuantizationType,
 )
 from core.base.api.models import GraphResponse
@@ -37,12 +38,6 @@ from core.base.utils import (
 from .base import PostgresConnectionManager
 from .collections import PostgresCollectionsHandler

-
-class StoreType(str, Enum):
-    GRAPHS = "graphs"
-    DOCUMENTS = "documents"
-
-
 logger = logging.getLogger()
 
 
 
 
@@ -59,9 +54,7 @@ class PostgresEntitiesHandler(Handler):
 
 
     def _get_entity_table_for_store(self, store_type: StoreType) -> str:
         """Get the appropriate table name for the store type."""
-        if isinstance(store_type, StoreType):
-            store_type = store_type.value
-        return f"{store_type}_entities"
+        return f"{store_type.value}_entities"

     def _get_parent_constraint(self, store_type: StoreType) -> str:
         """Get the appropriate foreign key constraint for the store type."""
@@ -376,6 +369,115 @@ class PostgresEntitiesHandler(Handler):
                     404,
                 )
 
 
+    async def export_to_csv(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        """
+        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
+        """
+        valid_columns = {
+            "id",
+            "name",
+            "category",
+            "description",
+            "parent_id",
+            "chunk_ids",
+            "metadata",
+            "created_at",
+            "updated_at",
+        }
+
+        if not columns:
+            columns = list(valid_columns)
+        elif invalid_cols := set(columns) - valid_columns:
+            raise ValueError(f"Invalid columns: {invalid_cols}")
+
+        select_stmt = f"""
+            SELECT
+                id::text,
+                name,
+                category,
+                description,
+                parent_id::text,
+                chunk_ids::text,
+                metadata::text,
+                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+            FROM {self._get_table_name(self._get_entity_table_for_store(store_type))}
+        """
+
+        conditions = ["parent_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if filters:
+            for field, value in filters.items():
+                if field not in valid_columns:
+                    continue
+
+                if isinstance(value, dict):
+                    for op, val in value.items():
+                        if op == "$eq":
+                            conditions.append(f"{field} = ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$gt":
+                            conditions.append(f"{field} > ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$lt":
+                            conditions.append(f"{field} < ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                else:
+                    # Direct equality
+                    conditions.append(f"{field} = ${param_index}")
+                    params.append(value)
+                    param_index += 1
+
+        if conditions:
+            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+        temp_file = None
+        try:
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w", delete=True, suffix=".csv"
+            )
+            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                async with conn.transaction():
+                    cursor = await conn.cursor(select_stmt, *params)
+
+                    if include_header:
+                        writer.writerow(columns)
+
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            writer.writerow(row)
+
+            temp_file.flush()
+            return temp_file.name, temp_file
+
+        except Exception as e:
+            if temp_file:
+                temp_file.close()
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to export data: {str(e)}",
+            ) from e
+
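
A hedged usage sketch for the new export helper (the handler variable and parent id are assumptions; only the method above is from the commit). Filters accept direct equality or $eq/$gt/$lt comparisons, and the caller owns closing the returned temp file:

    # Sketch only - `entities` stands in for an initialized PostgresEntitiesHandler.
    csv_path, temp_file = await entities.export_to_csv(
        parent_id=graph_id,                       # UUID of the owning graph or document
        store_type=StoreType.GRAPHS,
        columns=["id", "name", "category"],
        filters={"category": {"$eq": "PERSON"}},
        include_header=True,
    )
    try:
        print(open(csv_path).read())
    finally:
        temp_file.close()  # NamedTemporaryFile(delete=True) removes the file on close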

 class PostgresRelationshipsHandler(Handler):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -390,9 +492,7 @@ class PostgresRelationshipsHandler(Handler):

     def _get_relationship_table_for_store(self, store_type: StoreType) -> str:
         """Get the appropriate table name for the store type."""
-        if isinstance(store_type, StoreType):
-            store_type = store_type.value
-        return f"{store_type}_relationships"
+        return f"{store_type.value}_relationships"

     def _get_parent_constraint(self, store_type: StoreType) -> str:
         """Get the appropriate foreign key constraint for the store type."""
@@ -774,6 +874,123 @@ class PostgresRelationshipsHandler(Handler):
                     404,
                 )

+    async def export_to_csv(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        """
+        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
+        """
+        valid_columns = {
+            "id",
+            "subject",
+            "predicate",
+            "object",
+            "description",
+            "subject_id",
+            "object_id",
+            "weight",
+            "chunk_ids",
+            "parent_id",
+            "metadata",
+            "created_at",
+            "updated_at",
+        }
+
+        if not columns:
+            columns = list(valid_columns)
+        elif invalid_cols := set(columns) - valid_columns:
+            raise ValueError(f"Invalid columns: {invalid_cols}")
+
+        select_stmt = f"""
+            SELECT
+                id::text,
+                subject,
+                predicate,
+                object,
+                description,
+                subject_id::text,
+                object_id::text,
+                weight,
+                chunk_ids::text,
+                parent_id::text,
+                metadata::text,
+                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+            FROM {self._get_table_name(self._get_relationship_table_for_store(store_type))}
+        """
+
+        conditions = ["parent_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if filters:
+            for field, value in filters.items():
+                if field not in valid_columns:
+                    continue
+
+                if isinstance(value, dict):
+                    for op, val in value.items():
+                        if op == "$eq":
+                            conditions.append(f"{field} = ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$gt":
+                            conditions.append(f"{field} > ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$lt":
+                            conditions.append(f"{field} < ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                else:
+                    # Direct equality
+                    conditions.append(f"{field} = ${param_index}")
+                    params.append(value)
+                    param_index += 1
+
+        if conditions:
+            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+        temp_file = None
+        try:
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w", delete=True, suffix=".csv"
+            )
+            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                async with conn.transaction():
+                    cursor = await conn.cursor(select_stmt, *params)
+
+                    if include_header:
+                        writer.writerow(columns)
+
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            writer.writerow(row)
+
+            temp_file.flush()
+            return temp_file.name, temp_file
+
+        except Exception as e:
+            if temp_file:
+                temp_file.close()
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to export data: {str(e)}",
+            ) from e
+

 class PostgresCommunitiesHandler(Handler):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -946,7 +1163,7 @@ class PostgresCommunitiesHandler(Handler):
     async def delete(
         self,
         parent_id: UUID,
-        community_id: UUID = None,
+        community_id: UUID,
     ) -> None:
         table_name = "graphs_communities"

@@ -964,7 +1181,7 @@ class PostgresCommunitiesHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail=f"An error occurred while deleting the community: {e}",
-            )
+            ) from e

     async def delete_all_communities(
         self,
@@ -986,7 +1203,7 @@ class PostgresCommunitiesHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail=f"An error occurred while deleting communities: {e}",
-            )
+            ) from e

     async def get(
         self,
@@ -1059,6 +1276,123 @@ class PostgresCommunitiesHandler(Handler):

         return communities, count

+    async def export_to_csv(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        """
+        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
+        """
+        valid_columns = {
+            "id",
+            "collection_id",
+            "community_id",
+            "level",
+            "name",
+            "summary",
+            "findings",
+            "rating",
+            "rating_explanation",
+            "created_at",
+            "updated_at",
+            "metadata",
+        }
+
+        if not columns:
+            columns = list(valid_columns)
+        elif invalid_cols := set(columns) - valid_columns:
+            raise ValueError(f"Invalid columns: {invalid_cols}")
+
+        table_name = "graphs_communities"
+
+        select_stmt = f"""
+            SELECT
+                id::text,
+                collection_id::text,
+                community_id::text,
+                level,
+                name,
+                summary,
+                findings::text,
+                rating,
+                rating_explanation,
+                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
+                metadata::text
+            FROM {self._get_table_name(table_name)}
+        """
+
+        conditions = ["collection_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if filters:
+            for field, value in filters.items():
+                if field not in valid_columns:
+                    continue
+
+                if isinstance(value, dict):
+                    for op, val in value.items():
+                        if op == "$eq":
+                            conditions.append(f"{field} = ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$gt":
+                            conditions.append(f"{field} > ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$lt":
+                            conditions.append(f"{field} < ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                else:
+                    # Direct equality
+                    conditions.append(f"{field} = ${param_index}")
+                    params.append(value)
+                    param_index += 1
+
+        if conditions:
+            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+        temp_file = None
+        try:
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w", delete=True, suffix=".csv"
+            )
+            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                async with conn.transaction():
+                    cursor = await conn.cursor(select_stmt, *params)
+
+                    if include_header:
+                        writer.writerow(columns)
+
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            writer.writerow(row)
+
+            temp_file.flush()
+            return temp_file.name, temp_file
+
+        except Exception as e:
+            if temp_file:
+                temp_file.close()
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to export data: {str(e)}",
+            ) from e
+

 class PostgresGraphsHandler(Handler):
     """Handler for Knowledge Graph METHODS in PostgreSQL."""
@@ -1572,72 +1906,6 @@ class PostgresGraphsHandler(Handler):
             + _get_str_estimation_output(estimated_time),
         }

-    async def get_deduplication_estimate(
-        self,
-        collection_id: UUID,
-        kg_deduplication_settings: KGEntityDeduplicationSettings,
-    ):
-        """Get the estimated cost and time for deduplicating entities in a KG."""
-        try:
-            query = f"""
-                SELECT name, count(name)
-                FROM {self._get_table_name("entity")}
-                WHERE document_id = ANY(
-                    SELECT document_id FROM {self._get_table_name("documents")}
-                    WHERE $1 = ANY(collection_ids)
-                )
-                GROUP BY name
-                HAVING count(name) >= 5
-            """
-            entities = await self.connection_manager.fetch_query(
-                query, [collection_id]
-            )
-            num_entities = len(entities)
-
-            estimated_llm_calls = (num_entities, num_entities)
-            tokens_in_millions = (
-                estimated_llm_calls[0] * 1000 / 1000000,
-                estimated_llm_calls[1] * 5000 / 1000000,
-            )
-            cost_per_million = llm_cost_per_million_tokens(
-                kg_deduplication_settings.generation_config.model
-            )
-            estimated_cost = (
-                tokens_in_millions[0] * cost_per_million,
-                tokens_in_millions[1] * cost_per_million,
-            )
-            estimated_time = (
-                tokens_in_millions[0] * 10 / 60,
-                tokens_in_millions[1] * 10 / 60,
-            )
-
-            return {
-                "message": "Ran Deduplication Estimate (not the actual run). Note that these are estimated ranges.",
-                "num_entities": num_entities,
-                "estimated_llm_calls": _get_str_estimation_output(
-                    estimated_llm_calls
-                ),
-                "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(
-                    tokens_in_millions
-                ),
-                "estimated_cost_in_usd": _get_str_estimation_output(
-                    estimated_cost
-                ),
-                "estimated_total_time_in_minutes": _get_str_estimation_output(
-                    estimated_time
-                ),
-            }
-        except UndefinedTableError:
-            raise R2RException(
-                "Entity embedding table not found. Please run `extract-triples` first.",
-                404,
-            )
-        except Exception as e:
-            logger.error(f"Error in get_deduplication_estimate: {str(e)}")
-            raise HTTPException(
-                500, "Error fetching deduplication estimate."
-            ) from e
-
     async def get_entities(
         self,
         parent_id: UUID,
@@ -1987,7 +2255,6 @@ class PostgresGraphsHandler(Handler):
             QUERY, [tuple(non_null_attrs.values())]
         )

-    # async def delete(self, collection_id: UUID, cascade: bool = False) -> None:
     async def delete(self, collection_id: UUID) -> None:
         graphs = await self.get(graph_id=collection_id, offset=0, limit=-1)

@@ -2009,33 +2276,6 @@ class PostgresGraphsHandler(Handler):
             DELETE FROM {self._get_table_name("graphs")} WHERE collection_id = $1
         """

-        # if cascade:
-        #     documents = []
-        #     document_response = (
-        #         await self.collections_handler.documents_in_collection(
-        #             offset=0,
-        #             limit=100,
-        #             collection_id=collection_id,
-        #         )
-        #     )["results"]
-        #     documents.extend(document_response)
-        #     document_ids = [doc.id for doc in documents]
-        #     for document_id in document_ids:
-        #         self.entities.delete(
-        #             parent_id=document_id, store_type=StoreType.DOCUMENTS
-        #         )
-        #         self.relationships.delete(
-        #             parent_id=document_id, store_type=StoreType.DOCUMENTS
-        #         )
-
-        #     # setting the extraction status to PENDING for the documents in this collection.
-        #     QUERY = f"""
-        #         UPDATE {self._get_table_name("documents")} SET extraction_status = $1 WHERE $2::uuid = ANY(collection_ids)
-        #     """
-        #     await self.connection_manager.execute_query(
-        #         QUERY, [KGExtractionStatus.PENDING, collection_id]
-        #     )
-
     async def perform_graph_clustering(
         self,
         collection_id: UUID,
@@ -2224,13 +2464,13 @@ class PostgresGraphsHandler(Handler):
                 relationship_ids_cache.setdefault(relationship.subject, [])
                 if relationship.id is not None:
                     relationship_ids_cache[relationship.subject].append(
-                        relationship.id
+                        int(relationship.id)
                     )
             if relationship.object is not None:
                 relationship_ids_cache.setdefault(relationship.object, [])
                 if relationship.id is not None:
                     relationship_ids_cache[relationship.object].append(
-                        relationship.id
+                        int(relationship.id)
                     )

         return relationship_ids_cache
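
For readers skimming the hunk above, the cache being built maps entity names to the integer ids of relationships that mention them as subject or object; a sketch of the resulting shape (values invented):

    # Illustrative shape only.
    relationship_ids_cache: dict[str, list[int]] = {
        "Alice": [101, 102],
        "Acme Corp": [102],
    }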

+ 33 - 1
core/database/limits.py

@@ -4,7 +4,7 @@ from typing import Optional
 from uuid import UUID

 from core.base import Handler
-from shared.abstractions import User  # your domain user model
+from shared.abstractions import User

 from ..base.providers.database import DatabaseConfig, LimitSettings
 from .base import PostgresConnectionManager
@@ -87,6 +87,38 @@ class PostgresLimitsHandler(Handler):
         )
         return await self._count_requests(user_id, None, start_of_month)

+        return await self._count_requests(
+            user_id, route=None, since=start_of_month
+        )
+
+    def _determine_limits_for(
+        self, user_id: UUID, route: str
+    ) -> LimitSettings:
+        # Start with base limits
+        limits = self.config.limits
+
+        # Route-specific limits - directly override if present
+        if route_limits := self.config.route_limits.get(route):
+            # Only override non-None values from route_limits
+            if route_limits.global_per_min is not None:
+                limits.global_per_min = route_limits.global_per_min
+            if route_limits.route_per_min is not None:
+                limits.route_per_min = route_limits.route_per_min
+            if route_limits.monthly_limit is not None:
+                limits.monthly_limit = route_limits.monthly_limit
+
+        # User-specific limits - directly override if present
+        if user_limits := self.config.user_limits.get(user_id):
+            # Only override non-None values from user_limits
+            if user_limits.global_per_min is not None:
+                limits.global_per_min = user_limits.global_per_min
+            if user_limits.route_per_min is not None:
+                limits.route_per_min = user_limits.route_per_min
+            if user_limits.monthly_limit is not None:
+                limits.monthly_limit = user_limits.monthly_limit
+
+        return limits
+
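
The override order in _determine_limits_for is base config, then route-specific, then user-specific, with later non-None values winning. A minimal sketch of that precedence (field names come from the diff; the numbers and constructor usage are assumed):

    # Sketch only.
    base = LimitSettings(global_per_min=60, route_per_min=30, monthly_limit=10_000)
    route_override = LimitSettings(route_per_min=5)        # only non-None fields override
    user_override = LimitSettings(monthly_limit=100_000)   # applied last, wins over both
    # effective: global_per_min=60, route_per_min=5, monthly_limit=100_000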
     async def check_limits(self, user: User, route: str):
         """
         Perform rate limit checks for a user on a specific route.

+ 15 - 26
core/database/postgres.py

@@ -1,7 +1,6 @@
 # TODO: Clean this up and make it more congruent across the vector database and the relational database.
 import logging
 import os
-import warnings
 from typing import TYPE_CHECKING, Any, Optional

 from ..base.abstractions import VectorQuantizationType
@@ -28,18 +27,11 @@ from .tokens import PostgresTokensHandler
 from .users import PostgresUserHandler

 if TYPE_CHECKING:
-    from ..providers.crypto import NaClCryptoProvider
-
-logger = logging.getLogger()
+    from ..providers.crypto import BCryptCryptoProvider, NaClCryptoProvider

+    CryptoProviderType = BCryptCryptoProvider | NaClCryptoProvider

-def get_env_var(new_var, old_var, config_value):
-    value = config_value or os.getenv(new_var) or os.getenv(old_var)
-    if os.getenv(old_var) and not os.getenv(new_var):
-        warnings.warn(
-            f"{old_var} is deprecated and support for it will be removed in release 3.5.0. Use {new_var} instead."
-        )
-    return value
+logger = logging.getLogger()


 class PostgresDatabaseProvider(DatabaseProvider):
@@ -57,7 +49,7 @@ class PostgresDatabaseProvider(DatabaseProvider):
     dimension: int
     conn: Optional[Any]

-    crypto_provider: "NaClCryptoProvider"
+    crypto_provider: "CryptoProviderType"
     postgres_configuration_settings: PostgresConfigurationSettings
     default_collection_name: str
     default_collection_description: str
@@ -81,7 +73,7 @@ class PostgresDatabaseProvider(DatabaseProvider):
         self,
         config: DatabaseConfig,
         dimension: int,
-        crypto_provider: "NaClCryptoProvider",
+        crypto_provider: "BCryptCryptoProvider | NaClCryptoProvider",
         quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
         *args,
         **kwargs,
@@ -89,29 +81,26 @@ class PostgresDatabaseProvider(DatabaseProvider):
         super().__init__(config)

         env_vars = [
-            ("user", "R2R_POSTGRES_USER", "POSTGRES_USER"),
-            ("password", "R2R_POSTGRES_PASSWORD", "POSTGRES_PASSWORD"),
-            ("host", "R2R_POSTGRES_HOST", "POSTGRES_HOST"),
-            ("port", "R2R_POSTGRES_PORT", "POSTGRES_PORT"),
-            ("db_name", "R2R_POSTGRES_DBNAME", "POSTGRES_DBNAME"),
+            ("user", "R2R_POSTGRES_USER"),
+            ("password", "R2R_POSTGRES_PASSWORD"),
+            ("host", "R2R_POSTGRES_HOST"),
+            ("port", "R2R_POSTGRES_PORT"),
+            ("db_name", "R2R_POSTGRES_DBNAME"),
         ]

-        for attr, new_var, old_var in env_vars:
-            if value := get_env_var(new_var, old_var, getattr(config, attr)):
+        for attr, env_var in env_vars:
+            if value := (getattr(config, attr) or os.getenv(env_var)):
                 setattr(self, attr, value)
             else:
                 raise ValueError(
-                    f"Error, please set a valid {new_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`."
+                    f"Error, please set a valid {env_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`."
                 )

         self.port = int(self.port)

         self.project_name = (
-            get_env_var(
-                "R2R_PROJECT_NAME",
-                "R2R_POSTGRES_PROJECT_NAME",  # Remove this after deprecation
-                config.app.project_name,
-            )
+            config.app.project_name
+            or os.getenv("R2R_PROJECT_NAME")
             or "r2r_default"
         )
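
With the deprecated POSTGRES_* fallbacks gone, each connection setting now resolves as config value first, then the R2R_-prefixed environment variable, and raises if neither is set; project_name additionally falls back to "r2r_default". Roughly (a sketch, not the provider code):

    # Sketch only.
    user = config.user or os.getenv("R2R_POSTGRES_USER")
    if not user:
        raise ValueError("set R2R_POSTGRES_USER or database.user in r2r.toml")
    project_name = config.app.project_name or os.getenv("R2R_PROJECT_NAME") or "r2r_default"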

+ 112 - 2
core/database/users.py

@@ -1,6 +1,8 @@
+import csv
 import json
+import tempfile
 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import IO, Optional
 from uuid import UUID

 from fastapi import HTTPException
@@ -372,7 +374,7 @@ class PostgresUserHandler(Handler):
             query, [new_hashed_password, id]
         )

-    async def get_all_users(self) -> List[User]:
+    async def get_all_users(self) -> list[User]:
         """Get all users with minimal information."""
         query, params = (
             QueryBuilder(self._get_table_name(self.TABLE_NAME))
@@ -897,3 +899,111 @@ class PostgresUserHandler(Handler):
         if result is None:
             raise R2RException(status_code=404, message="API key not found")
         return True
+
+    async def export_to_csv(
+        self,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        """
+        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
+        """
+        valid_columns = {
+            "id",
+            "email",
+            "is_superuser",
+            "is_active",
+            "is_verified",
+            "name",
+            "bio",
+            "collection_ids",
+            "created_at",
+            "updated_at",
+        }
+
+        if not columns:
+            columns = list(valid_columns)
+        elif invalid_cols := set(columns) - valid_columns:
+            raise ValueError(f"Invalid columns: {invalid_cols}")
+
+        select_stmt = f"""
+            SELECT
+                id::text,
+                email,
+                is_superuser,
+                is_active,
+                is_verified,
+                name,
+                bio,
+                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
+            FROM {self._get_table_name(self.TABLE_NAME)}
+        """
+
+        params = []
+        if filters:
+            conditions = []
+            param_index = 1
+
+            for field, value in filters.items():
+                if field not in valid_columns:
+                    continue
+
+                if isinstance(value, dict):
+                    for op, val in value.items():
+                        if op == "$eq":
+                            conditions.append(f"{field} = ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$gt":
+                            conditions.append(f"{field} > ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                        elif op == "$lt":
+                            conditions.append(f"{field} < ${param_index}")
+                            params.append(val)
+                            param_index += 1
+                else:
+                    # Direct equality
+                    conditions.append(f"{field} = ${param_index}")
+                    params.append(value)
+                    param_index += 1
+
+            if conditions:
+                select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+        select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+        temp_file = None
+        try:
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w", delete=True, suffix=".csv"
+            )
+            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                async with conn.transaction():
+                    cursor = await conn.cursor(select_stmt, *params)
+
+                    if include_header:
+                        writer.writerow(columns)
+
+                    chunk_size = 1000
+                    while True:
+                        rows = await cursor.fetch(chunk_size)
+                        if not rows:
+                            break
+                        for row in rows:
+                            writer.writerow(row)
+
+            temp_file.flush()
+            return temp_file.name, temp_file
+
+        except Exception as e:
+            if temp_file:
+                temp_file.close()
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to export data: {str(e)}",
+            ) from e
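
The same $gt/$lt comparison filters work here, which is convenient for date ranges; a hedged sketch (the handler variable is an assumption):

    # Sketch only - `users` stands in for an initialized PostgresUserHandler.
    csv_path, temp_file = await users.export_to_csv(
        columns=["id", "email", "created_at"],
        filters={"created_at": {"$gt": "2024-01-01", "$lt": "2025-01-01"}},
        include_header=True,
    )
    temp_file.close()  # caller owns cleanup; closing deletes the temporary file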

+ 3 - 3
core/database/vecs/adapter/base.py

@@ -8,14 +8,14 @@ All public classes, enums, and functions are re-exported by `vecs.adapters` modu

 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Generator, Iterable, Optional, Tuple, Union
+from typing import Any, Generator, Iterable, Optional, Tuple
 from uuid import UUID

 from vecs.exc import ArgError

-MetadataValues = Union[str, int, float, bool, list[str]]
+MetadataValues = str | int | float | bool | list[str]
 Metadata = dict[str, MetadataValues]
-Numeric = Union[int, float, complex]
+Numeric = int | float | complex

 Record = Tuple[
     UUID,

+ 0 - 4
core/main/abstractions.py

@@ -10,8 +10,6 @@ from core.pipes import (
     EmbeddingPipe,
     GraphClusteringPipe,
     GraphCommunitySummaryPipe,
-    GraphDeduplicationPipe,
-    GraphDeduplicationSummaryPipe,
     GraphDescriptionPipe,
     GraphExtractionPipe,
     GraphSearchSearchPipe,
@@ -76,8 +74,6 @@ class R2RPipes(BaseModel):
     graph_storage_pipe: GraphStoragePipe
     graph_description_pipe: GraphDescriptionPipe
     graph_clustering_pipe: GraphClusteringPipe
-    graph_deduplication_pipe: GraphDeduplicationPipe
-    graph_deduplication_summary_pipe: GraphDeduplicationSummaryPipe
     graph_community_summary_pipe: GraphCommunitySummaryPipe
     rag_pipe: RAGPipe
     streaming_rag_pipe: StreamingRAGPipe

+ 2 - 2
core/main/api/v3/base_router.py

@@ -4,7 +4,7 @@ from abc import abstractmethod
 from typing import Callable, Optional

 from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
-from fastapi.responses import StreamingResponse
+from fastapi.responses import FileResponse, StreamingResponse

 from core.base import R2RException, manage_run

@@ -64,7 +64,7 @@ class BaseRouterV3:
                     else:
                         results, outer_kwargs = func_result, {}

-                    if isinstance(results, StreamingResponse):
+                    if isinstance(results, (StreamingResponse, FileResponse)):
                         return results
                     return {"results": results, **outer_kwargs}


+ 109 - 1
core/main/api/v3/collections_router.py

@@ -1,10 +1,12 @@
 import logging
 import textwrap
-import time
+from tempfile import NamedTemporaryFile
 from typing import Optional
 from uuid import UUID

 from fastapi import Body, Depends, Path, Query
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse

 from core.base import KGCreationSettings, KGRunType, R2RException
 from core.base.api.models import (
@@ -177,6 +179,112 @@ class CollectionsRouter(BaseRouterV3):
             )
             return collection

+        @self.router.post(
+            "/collections/export",
+            summary="Export collections to CSV",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # when using auth, do client.login(...)
+
+                            response = client.collections.export(
+                                output_path="export.csv",
+                                columns=["id", "name", "created_at"],
+                                include_header=True,
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "JavaScript",
+                        "source": textwrap.dedent(
+                            """
+                            const { r2rClient } = require("r2r-js");
+
+                            const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+                                await client.collections.export({
+                                    outputPath: "export.csv",
+                                    columns: ["id", "name", "created_at"],
+                                    includeHeader: true,
+                                });
+                            }
+
+                            main();
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "CLI",
+                        "source": textwrap.dedent(
+                            """
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X POST "http://127.0.0.1:7272/v3/collections/export" \
+                            -H "Authorization: Bearer YOUR_API_KEY" \
+                            -H "Content-Type: application/json" \
+                            -H "Accept: text/csv" \
+                            -d '{ "columns": ["id", "name", "created_at"], "include_header": true }' \
+                            --output export.csv
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def export_collections(
+            background_tasks: BackgroundTasks,
+            columns: Optional[list[str]] = Body(
+                None, description="Specific columns to export"
+            ),
+            filters: Optional[dict] = Body(
+                None, description="Filters to apply to the export"
+            ),
+            include_header: Optional[bool] = Body(
+                True, description="Whether to include column headers"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> FileResponse:
+            """
+            Export collections as a CSV file.
+            """
+
+            if not auth_user.is_superuser:
+                raise R2RException(
+                    "Only a superuser can export data.",
+                    403,
+                )
+
+            csv_file_path, temp_file = (
+                await self.services.management.export_collections(
+                    columns=columns,
+                    filters=filters,
+                    include_header=include_header,
+                )
+            )
+
+            background_tasks.add_task(temp_file.close)
+
+            return FileResponse(
+                path=csv_file_path,
+                media_type="text/csv",
+                filename="collections_export.csv",
+            )
+
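
A note on the cleanup pattern shared by these export endpoints: the service returns an open NamedTemporaryFile created with delete=True, FileResponse streams it by path, and closing is deferred to a background task so the file disappears only after the response is sent. In isolation (a sketch under those assumptions, not the project's code):

    # Sketch only.
    import tempfile
    from fastapi import BackgroundTasks
    from fastapi.responses import FileResponse

    def serve_csv(background_tasks: BackgroundTasks) -> FileResponse:
        tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=True)
        tmp.write("id,name\n1,example\n")
        tmp.flush()
        background_tasks.add_task(tmp.close)  # runs after the response body is sent
        return FileResponse(tmp.name, media_type="text/csv", filename="export.csv")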
         @self.router.get(
             "/collections",
             summary="List collections",

+ 214 - 0
core/main/api/v3/conversations_router.py

@@ -4,6 +4,8 @@ from typing import Optional
 from uuid import UUID

 from fastapi import Body, Depends, Path, Query
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse

 from core.base import Message, R2RException
 from core.base.api.models import (
@@ -206,6 +208,218 @@ class ConversationsRouter(BaseRouterV3):
                 "total_entries": conversations_response["total_entries"]
             }

+        @self.router.post(
+            "/conversations/export",
+            summary="Export conversations to CSV",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # when using auth, do client.login(...)
+
+                            response = client.conversations.export(
+                                output_path="export.csv",
+                                columns=["id", "created_at"],
+                                include_header=True,
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "JavaScript",
+                        "source": textwrap.dedent(
+                            """
+                            const { r2rClient } = require("r2r-js");
+
+                            const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+                                await client.conversations.export({
+                                    outputPath: "export.csv",
+                                    columns: ["id", "created_at"],
+                                    includeHeader: true,
+                                });
+                            }
+
+                            main();
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "CLI",
+                        "source": textwrap.dedent(
+                            """
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X POST "http://127.0.0.1:7272/v3/conversations/export" \
+                            -H "Authorization: Bearer YOUR_API_KEY" \
+                            -H "Content-Type: application/json" \
+                            -H "Accept: text/csv" \
+                            -d '{ "columns": ["id", "created_at"], "include_header": true }' \
+                            --output export.csv
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def export_conversations(
+            background_tasks: BackgroundTasks,
+            columns: Optional[list[str]] = Body(
+                None, description="Specific columns to export"
+            ),
+            filters: Optional[dict] = Body(
+                None, description="Filters to apply to the export"
+            ),
+            include_header: Optional[bool] = Body(
+                True, description="Whether to include column headers"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> FileResponse:
+            """
+            Export conversations as a downloadable CSV file.
+            """
+
+            if not auth_user.is_superuser:
+                raise R2RException(
+                    "Only a superuser can export data.",
+                    403,
+                )
+
+            csv_file_path, temp_file = (
+                await self.services.management.export_conversations(
+                    columns=columns,
+                    filters=filters,
+                    include_header=include_header,
+                )
+            )
+
+            background_tasks.add_task(temp_file.close)
+
+            return FileResponse(
+                path=csv_file_path,
+                media_type="text/csv",
+                filename="conversations_export.csv",
+            )
+
+        @self.router.post(
+            "/conversations/export_messages",
+            summary="Export messages to CSV",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # when using auth, do client.login(...)
+
+                            response = client.conversations.export_messages(
+                                output_path="export.csv",
+                                columns=["id", "created_at"],
+                                include_header=True,
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "JavaScript",
+                        "source": textwrap.dedent(
+                            """
+                            const { r2rClient } = require("r2r-js");
+
+                            const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+                                await client.conversations.exportMessages({
+                                    outputPath: "export.csv",
+                                    columns: ["id", "created_at"],
+                                    includeHeader: true,
+                                });
+                            }
+
+                            main();
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "CLI",
+                        "source": textwrap.dedent(
+                            """
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X POST "http://127.0.0.1:7272/v3/conversations/export_messages" \
+                            -H "Authorization: Bearer YOUR_API_KEY" \
+                            -H "Content-Type: application/json" \
+                            -H "Accept: text/csv" \
+                            -d '{ "columns": ["id", "created_at"], "include_header": true }' \
+                            --output export.csv
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def export_messages(
+            background_tasks: BackgroundTasks,
+            columns: Optional[list[str]] = Body(
+                None, description="Specific columns to export"
+            ),
+            filters: Optional[dict] = Body(
+                None, description="Filters to apply to the export"
+            ),
+            include_header: Optional[bool] = Body(
+                True, description="Whether to include column headers"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> FileResponse:
+            """
+            Export conversation messages as a downloadable CSV file.
+            """
+
+            if not auth_user.is_superuser:
+                raise R2RException(
+                    "Only a superuser can export data.",
+                    403,
+                )
+
+            csv_file_path, temp_file = (
+                await self.services.management.export_messages(
+                    columns=columns,
+                    filters=filters,
+                    include_header=include_header,
+                )
+            )
+
+            background_tasks.add_task(temp_file.close)
+
+            return FileResponse(
+                path=csv_file_path,
+                media_type="text/csv",
+                filename="messages_export.csv",
+            )
+
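
None of the embedded samples exercise the optional filters body; per the database handlers earlier in this commit, filters accept direct equality or $eq/$gt/$lt comparisons keyed by exportable column names. A hedged example payload (column names taken from the samples above):

    # Sketch only - request body for POST /v3/conversations/export_messages.
    payload = {
        "columns": ["id", "created_at"],
        "filters": {"created_at": {"$gt": "2024-01-01 00:00:00"}},
        "include_header": True,
    }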
         @self.router.get(
             "/conversations/{id}",
             summary="Get conversation details",

+ 439 - 6
core/main/api/v3/documents_router.py

@@ -3,12 +3,14 @@ import json
 import logging
 import mimetypes
 import textwrap
+from datetime import datetime
 from io import BytesIO
 from typing import Any, Optional
 from uuid import UUID

 from fastapi import Body, Depends, File, Form, Path, Query, UploadFile
-from fastapi.responses import StreamingResponse
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse, StreamingResponse
 from pydantic import Json

 from core.base import (
@@ -23,7 +25,7 @@ from core.base import (
     generate_id,
     select_search_filters,
 )
-from core.base.abstractions import KGCreationSettings, KGRunType
+from core.base.abstractions import KGCreationSettings, KGRunType, StoreType
 from core.base.api.models import (
     GenericBooleanResponse,
     WrappedBooleanResponse,
@@ -444,7 +446,7 @@ class DocumentsRouter(BaseRouterV3):
                         )
                     )
                     raw_message["document_id"] = str(document_id)
-                    return raw_message
+                    return raw_message  # type: ignore

                 else:
                     logger.info(
@@ -546,6 +548,211 @@ class DocumentsRouter(BaseRouterV3):
                     "task_id": None,
                 }

+        @self.router.post(
+            "/documents/export",
+            summary="Export documents to CSV",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # when using auth, do client.login(...)
+
+                            response = client.documents.export(
+                                output_path="export.csv",
+                                columns=["id", "title", "created_at"],
+                                include_header=True,
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "JavaScript",
+                        "source": textwrap.dedent(
+                            """
+                            const { r2rClient } = require("r2r-js");
+
+                            const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+                                await client.documents.export({
+                                    outputPath: "export.csv",
+                                    columns: ["id", "title", "created_at"],
+                                    includeHeader: true,
+                                });
+                            }
+
+                            main();
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "CLI",
+                        "source": textwrap.dedent(
+                            """
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X POST "http://127.0.0.1:7272/v3/documents/export" \
+                            -H "Authorization: Bearer YOUR_API_KEY" \
+                            -H "Content-Type: application/json" \
+                            -H "Accept: text/csv" \
+                            -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+                            --output export.csv
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def export_documents(
+            background_tasks: BackgroundTasks,
+            columns: Optional[list[str]] = Body(
+                None, description="Specific columns to export"
+            ),
+            filters: Optional[dict] = Body(
+                None, description="Filters to apply to the export"
+            ),
+            include_header: Optional[bool] = Body(
+                True, description="Whether to include column headers"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> FileResponse:
+            """
+            Export documents as a downloadable CSV file.
+            """
+
+            if not auth_user.is_superuser:
+                raise R2RException(
+                    "Only a superuser can export data.",
+                    403,
+                )
+
+            csv_file_path, temp_file = (
+                await self.services.management.export_documents(
+                    columns=columns,
+                    filters=filters,
+                    include_header=include_header,
+                )
+            )
+
+            background_tasks.add_task(temp_file.close)
+
+            return FileResponse(
+                path=csv_file_path,
+                media_type="text/csv",
+                filename="documents_export.csv",
+            )
+
+        @self.router.get(
+            "/documents/download_zip",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            response_class=StreamingResponse,
+            summary="Export multiple documents as zip",
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            client.documents.download_zip(
+                                document_ids=["uuid1", "uuid2"],
+                                start_date="2024-01-01",
+                                end_date="2024-12-31"
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X GET "https://api.example.com/v3/documents/download_zip?document_ids=uuid1,uuid2&start_date=2024-01-01&end_date=2024-12-31" \\
+                            -H "Authorization: Bearer YOUR_API_KEY"
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def export_files(
+            document_ids: Optional[list[UUID]] = Query(
+                None,
+                description="List of document IDs to include in the export. If not provided, all accessible documents will be included.",
+            ),
+            start_date: Optional[datetime] = Query(
+                None,
+                description="Filter documents created on or after this date.",
+            ),
+            end_date: Optional[datetime] = Query(
+                None,
+                description="Filter documents created before this date.",
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> StreamingResponse:
+            """
+            Export multiple documents as a zip file. Documents can be filtered by IDs and/or date range.
+
+            The endpoint allows downloading:
+            - Specific documents by providing their IDs
+            - Documents within a date range
+            - All accessible documents if no filters are provided
+
+            Files are streamed as a zip archive to handle potentially large downloads efficiently.
+            """
+            if not auth_user.is_superuser:
+                # For non-superusers, verify access to requested documents
+                if document_ids:
+                    documents_overview = (
+                        await self.services.management.documents_overview(
+                            user_ids=[auth_user.id],
+                            document_ids=document_ids,
+                            offset=0,
+                            limit=len(document_ids),
+                        )
+                    )
+                    if len(documents_overview["results"]) != len(document_ids):
+                        raise R2RException(
+                            status_code=403,
+                            message="You don't have access to one or more requested documents.",
+                        )
+                if not document_ids:
+                    raise R2RException(
+                        status_code=403,
+                        message="Non-superusers must provide document IDs to export.",
+                    )
+
+            zip_name, zip_content, zip_size = (
+                await self.services.management.export_files(
+                    document_ids=document_ids,
+                    start_date=start_date,
+                    end_date=end_date,
+                )
+            )
+
+            async def stream_file():
+                yield zip_content.getvalue()
+
+            return StreamingResponse(
+                stream_file(),
+                media_type="application/zip",
+                headers={
+                    "Content-Disposition": f'attachment; filename="{zip_name}"',
+                    "Content-Length": str(zip_size),
+                },
+            )
+
         @self.router.get(
             "/documents",
             dependencies=[Depends(self.rate_limit_dependency)],
@@ -1377,7 +1584,7 @@ class DocumentsRouter(BaseRouterV3):
                     "user": auth_user.json(),
                     "user": auth_user.json(),
                 }
                 }
 
 
-                return await self.providers.orchestration.run_workflow(
+                return await self.providers.orchestration.run_workflow(  # type: ignore
                     "extract-triples", {"request": workflow_input}, {}
                     "extract-triples", {"request": workflow_input}, {}
                 )
                 )
             else:
             else:
@@ -1482,7 +1689,7 @@ class DocumentsRouter(BaseRouterV3):
                 count,
             ) = await self.providers.database.graphs_handler.entities.get(
                 parent_id=id,
-                store_type="documents",
+                store_type=StoreType.DOCUMENTS,
                 offset=offset,
                 limit=limit,
                 include_embeddings=include_embeddings,
@@ -1490,6 +1697,119 @@ class DocumentsRouter(BaseRouterV3):

             return entities, {"total_entries": count}  # type: ignore

+        @self.router.post(
+            "/documents/{id}/entities/export",
+            summary="Export document entities to CSV",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # when using auth, do client.login(...)
+
+                            response = client.documents.export_entities(
+                                id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+                                output_path="export.csv",
+                                columns=["id", "title", "created_at"],
+                                include_header=True,
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "JavaScript",
+                        "source": textwrap.dedent(
+                            """
+                            const { r2rClient } = require("r2r-js");
+
+                            const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+                                await client.documents.exportEntities({
+                                    id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+                                    outputPath: "export.csv",
+                                    columns: ["id", "title", "created_at"],
+                                    includeHeader: true,
+                                });
+                            }
+
+                            main();
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "CLI",
+                        "source": textwrap.dedent(
+                            """
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X POST "http://127.0.0.1:7272/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/entities/export" \
+                            -H "Authorization: Bearer YOUR_API_KEY" \
+                            -H "Content-Type: application/json" \
+                            -H "Accept: text/csv" \
+                            -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+                            --output export.csv
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def export_entities(
+            background_tasks: BackgroundTasks,
+            id: UUID = Path(
+                ...,
+                description="The ID of the document to export entities from.",
+            ),
+            columns: Optional[list[str]] = Body(
+                None, description="Specific columns to export"
+            ),
+            filters: Optional[dict] = Body(
+                None, description="Filters to apply to the export"
+            ),
+            include_header: Optional[bool] = Body(
+                True, description="Whether to include column headers"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> FileResponse:
+            """
+            Export document entities as a downloadable CSV file.
+            """
+
+            if not auth_user.is_superuser:
+                raise R2RException(
+                    "Only a superuser can export data.",
+                    403,
+                )
+
+            csv_file_path, temp_file = (
+                await self.services.management.export_document_entities(
+                    id=id,
+                    columns=columns,
+                    filters=filters,
+                    include_header=include_header,
+                )
+            )
+
+            background_tasks.add_task(temp_file.close)
+
+            return FileResponse(
+                path=csv_file_path,
+                media_type="text/csv",
+                filename="documents_export.csv",
+            )
+
         @self.router.get(
             "/documents/{id}/relationships",
             dependencies=[Depends(self.rate_limit_dependency)],
@@ -1624,7 +1944,7 @@ class DocumentsRouter(BaseRouterV3):
                 count,
             ) = await self.providers.database.graphs_handler.relationships.get(
                 parent_id=id,
-                store_type="documents",
+                store_type=StoreType.DOCUMENTS,
                 entity_names=entity_names,
                 relationship_types=relationship_types,
                 offset=offset,
@@ -1633,6 +1953,119 @@ class DocumentsRouter(BaseRouterV3):

             return relationships, {"total_entries": count}  # type: ignore

+        @self.router.post(
+            "/documents/{id}/relationships/export",
+            summary="Export document relationships to CSV",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # when using auth, do client.login(...)
+
+                            response = client.documents.export_relationships(
+                                id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+                                output_path="export.csv",
+                                columns=["id", "title", "created_at"],
+                                include_header=True,
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "JavaScript",
+                        "source": textwrap.dedent(
+                            """
+                            const { r2rClient } = require("r2r-js");
+
+                            const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+                                await client.documents.exportRelationships({
+                                    id: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+                                    outputPath: "export.csv",
+                                    columns: ["id", "title", "created_at"],
+                                    includeHeader: true,
+                                });
+                            }
+
+                            main();
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "CLI",
+                        "source": textwrap.dedent(
+                            """
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X POST "http://127.0.0.1:7272/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/relationships/export" \
+                            -H "Authorization: Bearer YOUR_API_KEY" \
+                            -H "Content-Type: application/json" \
+                            -H "Accept: text/csv" \
+                            -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+                            --output export.csv
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def export_relationships(
+            background_tasks: BackgroundTasks,
+            id: UUID = Path(
+                ...,
+                description="The ID of the document to export relationships from.",
+            ),
+            columns: Optional[list[str]] = Body(
+                None, description="Specific columns to export"
+            ),
+            filters: Optional[dict] = Body(
+                None, description="Filters to apply to the export"
+            ),
+            include_header: Optional[bool] = Body(
+                True, description="Whether to include column headers"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> FileResponse:
+            """
+            Export document relationships as a downloadable CSV file.
+            """
+
+            if not auth_user.is_superuser:
+                raise R2RException(
+                    "Only a superuser can export data.",
+                    403,
+                )
+
+            csv_file_path, temp_file = (
+                await self.services.management.export_document_relationships(
+                    id=id,
+                    columns=columns,
+                    filters=filters,
+                    include_header=include_header,
+                )
+            )
+
+            background_tasks.add_task(temp_file.close)
+
+            return FileResponse(
+                path=csv_file_path,
+                media_type="text/csv",
+                filename="documents_export.csv",
+            )
+
         @self.router.post(
             "/documents/search",
             dependencies=[Depends(self.rate_limit_dependency)],

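The x-codeSamples above use the R2R SDKs; for reference, the two new document export routes can also be exercised over plain HTTP. A minimal sketch using requests, assuming a local deployment on port 7272, a superuser API key, and a placeholder document UUID:

    import requests

    BASE = "http://localhost:7272/v3"
    HEADERS = {"Authorization": "Bearer YOUR_API_KEY"}  # placeholder token

    # CSV export of the documents table (superuser only)
    resp = requests.post(
        f"{BASE}/documents/export",
        headers={**HEADERS, "Accept": "text/csv"},
        json={"columns": ["id", "title", "created_at"], "include_header": True},
    )
    with open("export.csv", "wb") as f:
        f.write(resp.content)

    # Streamed zip download of selected documents, filtered by ID and date range
    resp = requests.get(
        f"{BASE}/documents/download_zip",
        headers=HEADERS,
        params={
            "document_ids": ["b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"],  # placeholder
            "start_date": "2024-01-01T00:00:00",
            "end_date": "2024-12-31T00:00:00",
        },
        stream=True,
    )
    with open("documents.zip", "wb") as f:
        for chunk in resp.iter_content(chunk_size=8192):
            f.write(chunk)
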
+ 347 - 86
core/main/api/v3/graph_router.py

@@ -4,9 +4,11 @@ from typing import Optional
 from uuid import UUID

 from fastapi import Body, Depends, Path, Query
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse

 from core.base import KGEnrichmentStatus, R2RException, Workflow
-from core.base.abstractions import KGRunType
+from core.base.abstractions import KGRunType, StoreType
 from core.base.api.models import (
     GenericBooleanResponse,
     WrappedBooleanResponse,
@@ -48,9 +50,6 @@ class GraphRouter(BaseRouterV3):
             workflow_messages["build-communities"] = (
             workflow_messages["build-communities"] = (
                 "Graph enrichment task queued successfully."
                 "Graph enrichment task queued successfully."
             )
             )
-            workflow_messages["entity-deduplication"] = (
-                "KG Entity Deduplication task queued successfully."
-            )
         else:
         else:
             workflow_messages["extract-triples"] = (
             workflow_messages["extract-triples"] = (
                 "Document entities and relationships extracted successfully."
                 "Document entities and relationships extracted successfully."
@@ -58,9 +57,6 @@ class GraphRouter(BaseRouterV3):
             workflow_messages["build-communities"] = (
             workflow_messages["build-communities"] = (
                 "Graph communities created successfully."
                 "Graph communities created successfully."
             )
             )
-            workflow_messages["entity-deduplication"] = (
-                "KG Entity Deduplication completed successfully."
-            )
 
 
         self.providers.orchestration.register_workflows(
         self.providers.orchestration.register_workflows(
             Workflow.KG,
             Workflow.KG,
@@ -68,80 +64,6 @@ class GraphRouter(BaseRouterV3):
             workflow_messages,
         )

-    async def _deduplicate_entities(
-        self,
-        collection_id: UUID,
-        settings,
-        run_type: Optional[KGRunType] = KGRunType.ESTIMATE,
-        run_with_orchestration: bool = True,
-        auth_user=None,
-    ):
-        """Deduplicates entities in the knowledge graph using LLM-based analysis.
-
-        The deduplication process:
-        1. Groups potentially duplicate entities by name/type
-        2. Uses LLM analysis to determine if entities refer to same thing
-        3. Merges duplicate entities while preserving relationships
-        4. Updates all references to use canonical entity IDs
-
-        Args:
-            id (UUID): Graph containing the entities
-            settings (dict, optional): Deduplication settings including:
-                - graph_entity_deduplication_type (str): Deduplication method (e.g. "by_name")
-                - graph_entity_deduplication_prompt (str): Custom prompt for analysis
-                - max_description_input_length (int): Max chars for entity descriptions
-                - generation_config (dict): LLM generation parameters
-            run_type (KGRunType): Whether to estimate cost or run deduplication
-            run_with_orchestration (bool): Whether to run async with task queue
-            auth_user: Authenticated user making request
-
-        Returns:
-            Result containing:
-                message (str): Status message
-                task_id (UUID): Async task ID if run with orchestration
-
-        Raises:
-            R2RException: If user unauthorized or deduplication fails
-        """
-        if not auth_user.is_superuser:
-            raise R2RException(
-                "Only superusers can deduplicate a graphs entities", 403
-            )
-
-        server_settings = (
-            self.providers.database.config.graph_entity_deduplication_settings
-        )
-        if settings:
-            server_settings = update_settings_from_dict(
-                server_settings, settings
-            )
-
-        # Return cost estimate if requested
-        if run_type == KGRunType.ESTIMATE:
-            return await self.services.graph.get_deduplication_estimate(
-                collection_id, server_settings
-            )
-
-        workflow_input = {
-            "graph_id": str(collection_id),
-            "graph_entity_deduplication_settings": server_settings.model_dump_json(),
-            "user": auth_user.model_dump_json(),
-        }
-
-        if run_with_orchestration:
-            return await self.providers.orchestration.run_workflow(  # type: ignore
-                "entity-deduplication", {"request": workflow_input}, {}
-            )
-        else:
-            from core.main.orchestration import simple_kg_factory
-
-            simple_kg = simple_kg_factory(self.services.graph)
-            await simple_kg["entity-deduplication"](workflow_input)
-            return {  # type: ignore
-                "message": "Entity deduplication completed successfully.",
-                "task_id": None,
-            }
-
     async def _get_collection_id(
         self, collection_id: Optional[UUID], auth_user
     ) -> UUID:
@@ -307,7 +229,7 @@ class GraphRouter(BaseRouterV3):
                 offset=0,
                 limit=1,
             )
-            return list_graphs_response["results"][0]
+            return list_graphs_response["results"][0]  # type: ignore

         @self.router.post(
             "/graphs/{collection_id}/communities/build",
@@ -645,6 +567,119 @@ class GraphRouter(BaseRouterV3):
                 "total_entries": count,
                 "total_entries": count,
             }
             }
 
 
+        @self.router.post(
+            "/graphs/{collection_id}/entities/export",
+            summary="Export graph entities to CSV",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # when using auth, do client.login(...)
+
+                            response = client.graphs.export_entities(
+                                collection_id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+                                output_path="export.csv",
+                                columns=["id", "title", "created_at"],
+                                include_header=True,
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "JavaScript",
+                        "source": textwrap.dedent(
+                            """
+                            const { r2rClient } = require("r2r-js");
+
+                            const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+                                await client.graphs.exportEntities({
+                                    collectionId: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+                                    outputPath: "export.csv",
+                                    columns: ["id", "title", "created_at"],
+                                    includeHeader: true,
+                                });
+                            }
+
+                            main();
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "CLI",
+                        "source": textwrap.dedent(
+                            """
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X POST "http://127.0.0.1:7272/v3/graphs/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/entities/export" \
+                            -H "Authorization: Bearer YOUR_API_KEY" \
+                            -H "Content-Type: application/json" \
+                            -H "Accept: text/csv" \
+                            -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+                            --output export.csv
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def export_entities(
+            background_tasks: BackgroundTasks,
+            collection_id: UUID = Path(
+                ...,
+                description="The ID of the collection to export entities from.",
+            ),
+            columns: Optional[list[str]] = Body(
+                None, description="Specific columns to export"
+            ),
+            filters: Optional[dict] = Body(
+                None, description="Filters to apply to the export"
+            ),
+            include_header: Optional[bool] = Body(
+                True, description="Whether to include column headers"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> FileResponse:
+            """
+            Export graph entities as a downloadable CSV file.
+            """
+
+            if not auth_user.is_superuser:
+                raise R2RException(
+                    "Only a superuser can export data.",
+                    403,
+                )
+
+            csv_file_path, temp_file = (
+                await self.services.management.export_graph_entities(
+                    id=collection_id,
+                    columns=columns,
+                    filters=filters,
+                    include_header=include_header,
+                )
+            )
+
+            background_tasks.add_task(temp_file.close)
+
+            return FileResponse(
+                path=csv_file_path,
+                media_type="text/csv",
+                filename="documents_export.csv",
+            )
+
         @self.router.post(
             "/graphs/{collection_id}/entities",
             dependencies=[Depends(self.rate_limit_dependency)],
@@ -754,6 +789,119 @@ class GraphRouter(BaseRouterV3):
                 parent_id=collection_id,
             )

+        @self.router.post(
+            "/graphs/{collection_id}/relationships/export",
+            summary="Export graph relationships to CSV",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # when using auth, do client.login(...)
+
+                            response = client.graphs.export_relationships(
+                                collection_id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+                                output_path="export.csv",
+                                columns=["id", "title", "created_at"],
+                                include_header=True,
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "JavaScript",
+                        "source": textwrap.dedent(
+                            """
+                            const { r2rClient } = require("r2r-js");
+
+                            const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+                                await client.graphs.exportRelationships({
+                                    collectionId: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+                                    outputPath: "export.csv",
+                                    columns: ["id", "title", "created_at"],
+                                    includeHeader: true,
+                                });
+                            }
+
+                            main();
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "CLI",
+                        "source": textwrap.dedent(
+                            """
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X POST "http://127.0.0.1:7272/v3/graphs/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/relationships/export" \
+                            -H "Authorization: Bearer YOUR_API_KEY" \
+                            -H "Content-Type: application/json" \
+                            -H "Accept: text/csv" \
+                            -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+                            --output export.csv
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def export_relationships(
+            background_tasks: BackgroundTasks,
+            collection_id: UUID = Path(
+                ...,
+                description="The ID of the collection to export relationships from.",
+            ),
+            columns: Optional[list[str]] = Body(
+                None, description="Specific columns to export"
+            ),
+            filters: Optional[dict] = Body(
+                None, description="Filters to apply to the export"
+            ),
+            include_header: Optional[bool] = Body(
+                True, description="Whether to include column headers"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> FileResponse:
+            """
+            Export graph relationships as a downloadable CSV file.
+            """
+
+            if not auth_user.is_superuser:
+                raise R2RException(
+                    "Only a superuser can export data.",
+                    403,
+                )
+
+            csv_file_path, temp_file = (
+                await self.services.management.export_graph_relationships(
+                    id=collection_id,
+                    columns=columns,
+                    filters=filters,
+                    include_header=include_header,
+                )
+            )
+
+            background_tasks.add_task(temp_file.close)
+
+            return FileResponse(
+                path=csv_file_path,
+                media_type="text/csv",
+                filename="documents_export.csv",
+            )
+
         @self.router.get(
             "/graphs/{collection_id}/entities/{entity_id}",
             dependencies=[Depends(self.rate_limit_dependency)],
@@ -821,7 +969,7 @@ class GraphRouter(BaseRouterV3):

             result = await self.providers.database.graphs_handler.entities.get(
                 parent_id=collection_id,
-                store_type="graphs",
+                store_type=StoreType.GRAPHS,
                 offset=0,
                 limit=1,
                 entity_ids=[entity_id],
@@ -1110,7 +1258,7 @@ class GraphRouter(BaseRouterV3):
             results = (
                 await self.providers.database.graphs_handler.relationships.get(
                     parent_id=collection_id,
-                    store_type="graphs",
+                    store_type=StoreType.GRAPHS,
                     offset=0,
                     limit=1,
                     relationship_ids=[relationship_id],
@@ -1526,7 +1674,7 @@ class GraphRouter(BaseRouterV3):
                 await self.providers.database.graphs_handler.communities.get(
                     parent_id=collection_id,
                     community_ids=[community_id],
-                    store_type="graphs",
+                    store_type=StoreType.GRAPHS,
                     offset=0,
                     limit=1,
                 )
@@ -1615,6 +1763,119 @@ class GraphRouter(BaseRouterV3):
             )
             return GenericBooleanResponse(success=True)  # type: ignore

+        @self.router.post(
+            "/graphs/{collection_id}/communities/export",
+            summary="Export document communities to CSV",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # when using auth, do client.login(...)
+
+                            response = client.graphs.export_communities(
+                                collection_id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+                                output_path="export.csv",
+                                columns=["id", "title", "created_at"],
+                                include_header=True,
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "JavaScript",
+                        "source": textwrap.dedent(
+                            """
+                            const { r2rClient } = require("r2r-js");
+
+                            const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+                                await client.graphs.exportCommunities({
+                                    collectionId: "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
+                                    outputPath: "export.csv",
+                                    columns: ["id", "title", "created_at"],
+                                    includeHeader: true,
+                                });
+                            }
+
+                            main();
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "CLI",
+                        "source": textwrap.dedent(
+                            """
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X POST "http://127.0.0.1:7272/v3/graphs/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/communities/export" \
+                            -H "Authorization: Bearer YOUR_API_KEY" \
+                            -H "Content-Type: application/json" \
+                            -H "Accept: text/csv" \
+                            -d '{ "columns": ["id", "title", "created_at"], "include_header": true }' \
+                            --output export.csv
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def export_communities(
+            background_tasks: BackgroundTasks,
+            collection_id: UUID = Path(
+                ...,
+                description="The ID of the collection to export communities from.",
+            ),
+            columns: Optional[list[str]] = Body(
+                None, description="Specific columns to export"
+            ),
+            filters: Optional[dict] = Body(
+                None, description="Filters to apply to the export"
+            ),
+            include_header: Optional[bool] = Body(
+                True, description="Whether to include column headers"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> FileResponse:
+            """
+            Export graph communities as a downloadable CSV file.
+            """
+
+            if not auth_user.is_superuser:
+                raise R2RException(
+                    "Only a superuser can export data.",
+                    403,
+                )
+
+            csv_file_path, temp_file = (
+                await self.services.management.export_graph_communities(
+                    id=collection_id,
+                    columns=columns,
+                    filters=filters,
+                    include_header=include_header,
+                )
+            )
+
+            background_tasks.add_task(temp_file.close)
+
+            return FileResponse(
+                path=csv_file_path,
+                media_type="text/csv",
+                filename="documents_export.csv",
+            )
+
         @self.router.post(
             "/graphs/{collection_id}/communities/{community_id}",
             dependencies=[Depends(self.rate_limit_dependency)],
@@ -1839,7 +2100,7 @@ class GraphRouter(BaseRouterV3):
                 entities = (
                     await self.providers.database.graphs_handler.entities.get(
                         parent_id=document.id,
-                        store_type="documents",
+                        store_type=StoreType.DOCUMENTS,
                         offset=0,
                         limit=100,
                     )

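Beyond the new export routes, the recurring change in this file and in documents_router.py is that graphs_handler calls now pass a StoreType member instead of a bare string. The enum itself is imported from core.base.abstractions and its definition is not shown in this diff; a plausible minimal shape, assuming a str-backed enum with the two members used at these call sites:

    from enum import Enum

    class StoreType(str, Enum):
        # assumed values, inferred from the strings they replace in this diff
        DOCUMENTS = "documents"
        GRAPHS = "graphs"

    # A str-backed enum keeps existing string comparisons and SQL parameter
    # binding working, while a typo like StoreType.DOCUMENT fails loudly with
    # an AttributeError instead of silently matching nothing.
    assert StoreType.DOCUMENTS == "documents"
    assert StoreType.GRAPHS.value == "graphs"
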
+ 2 - 2
core/main/api/v3/indices_router.py

@@ -230,7 +230,7 @@ class IndicesRouter(BaseRouterV3):
                 },
             )

-            return result
+            return result  # type: ignore

         @self.router.get(
             "/indices",
@@ -599,7 +599,7 @@ class IndicesRouter(BaseRouterV3):
                 f"Deleting vector index {index_name} from table {table_name}"
                 f"Deleting vector index {index_name} from table {table_name}"
             )
             )
 
 
-            return await self.providers.orchestration.run_workflow(
+            return await self.providers.orchestration.run_workflow(  # type: ignore
                 "delete-vector-index",
                 "delete-vector-index",
                 {
                 {
                     "request": {
                     "request": {

+ 109 - 1
core/main/api/v3/users_router.py

@@ -1,8 +1,10 @@
 import textwrap
-from typing import Optional, Union
+from typing import Optional
 from uuid import UUID

 from fastapi import Body, Depends, Path, Query
+from fastapi.background import BackgroundTasks
+from fastapi.responses import FileResponse
 from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
 from pydantic import EmailStr

@@ -139,6 +141,112 @@ class UsersRouter(BaseRouterV3):

             return registration_response

+        @self.router.post(
+            "/users/export",
+            summary="Export users to CSV",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # when using auth, do client.login(...)
+
+                            response = client.users.export(
+                                output_path="export.csv",
+                                columns=["id", "name", "created_at"],
+                                include_header=True,
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "JavaScript",
+                        "source": textwrap.dedent(
+                            """
+                            const { r2rClient } = require("r2r-js");
+
+                            const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+                                await client.users.export({
+                                    outputPath: "export.csv",
+                                    columns: ["id", "name", "created_at"],
+                                    includeHeader: true,
+                                });
+                            }
+
+                            main();
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "CLI",
+                        "source": textwrap.dedent(
+                            """
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X POST "http://127.0.0.1:7272/v3/users/export" \
+                            -H "Authorization: Bearer YOUR_API_KEY" \
+                            -H "Content-Type: application/json" \
+                            -H "Accept: text/csv" \
+                            -d '{ "columns": ["id", "name", "created_at"], "include_header": true }' \
+                            --output export.csv
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def export_users(
+            background_tasks: BackgroundTasks,
+            columns: Optional[list[str]] = Body(
+                None, description="Specific columns to export"
+            ),
+            filters: Optional[dict] = Body(
+                None, description="Filters to apply to the export"
+            ),
+            include_header: Optional[bool] = Body(
+                True, description="Whether to include column headers"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> FileResponse:
+            """
+            Export users as a CSV file.
+            """
+
+            if not auth_user.is_superuser:
+                raise R2RException(
+                    "Only a superuser can export data.",
+                    403,
+                )
+
+            csv_file_path, temp_file = (
+                await self.services.management.export_users(
+                    columns=columns,
+                    filters=filters,
+                    include_header=include_header,
+                )
+            )
+
+            background_tasks.add_task(temp_file.close)
+
+            return FileResponse(
+                path=csv_file_path,
+                media_type="text/csv",
+                filename="users_export.csv",
+            )
+
         # TODO: deprecated, remove in next release
         @self.router.post(
             "/users/register",

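The export handlers added across the documents, graph, and users routers all share one response shape: the management service writes the CSV to a temporary file, the endpoint returns a FileResponse, and the temp-file handle is closed by a BackgroundTasks callback once the response has been sent. A minimal standalone sketch of that pattern (illustrative names only, not the R2R implementation):

    import csv
    import tempfile

    from fastapi import BackgroundTasks, FastAPI
    from fastapi.responses import FileResponse

    app = FastAPI()

    @app.post("/example/export")
    async def example_export(background_tasks: BackgroundTasks) -> FileResponse:
        # delete=False keeps the file on disk after the handler returns; the
        # handle is closed only after the response body has been streamed.
        temp_file = tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", newline="", delete=False
        )
        writer = csv.writer(temp_file)
        writer.writerow(["id", "title", "created_at"])
        writer.writerow(["1", "example", "2024-01-01"])
        temp_file.flush()

        # runs after the response is sent, releasing the file handle
        background_tasks.add_task(temp_file.close)

        return FileResponse(
            path=temp_file.name,
            media_type="text/csv",
            filename="example_export.csv",
        )
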
+ 31 - 51
core/main/assembly/factory.py

@@ -17,7 +17,23 @@ from core.base import (
     OrchestrationConfig,
 )
 from core.pipelines import RAGPipeline, SearchPipeline
-from core.pipes import GeneratorPipe, MultiSearchPipe, SearchPipe
+from core.pipes import (
+    EmbeddingPipe,
+    GeneratorPipe,
+    GraphClusteringPipe,
+    GraphCommunitySummaryPipe,
+    GraphDescriptionPipe,
+    GraphExtractionPipe,
+    GraphSearchSearchPipe,
+    GraphStoragePipe,
+    MultiSearchPipe,
+    ParsingPipe,
+    RAGPipe,
+    SearchPipe,
+    StreamingRAGPipe,
+    VectorSearchPipe,
+    VectorStoragePipe,
+)
 from core.providers.email.sendgrid import SendGridEmailProvider

 from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
@@ -366,20 +382,20 @@ class R2RPipeFactory:

     def create_pipes(
         self,
-        parsing_pipe_override: Optional[AsyncPipe] = None,
-        embedding_pipe_override: Optional[AsyncPipe] = None,
-        graph_extraction_pipe_override: Optional[AsyncPipe] = None,
-        graph_storage_pipe_override: Optional[AsyncPipe] = None,
-        graph_search_pipe_override: Optional[AsyncPipe] = None,
-        vector_storage_pipe_override: Optional[AsyncPipe] = None,
-        vector_search_pipe_override: Optional[AsyncPipe] = None,
-        rag_pipe_override: Optional[AsyncPipe] = None,
-        streaming_rag_pipe_override: Optional[AsyncPipe] = None,
-        graph_description_pipe: Optional[AsyncPipe] = None,
-        graph_clustering_pipe: Optional[AsyncPipe] = None,
-        graph_deduplication_pipe: Optional[AsyncPipe] = None,
-        graph_deduplication_summary_pipe: Optional[AsyncPipe] = None,
-        graph_community_summary_pipe: Optional[AsyncPipe] = None,
+        parsing_pipe_override: Optional[ParsingPipe] = None,
+        embedding_pipe_override: Optional[EmbeddingPipe] = None,
+        graph_extraction_pipe_override: Optional[GraphExtractionPipe] = None,
+        graph_storage_pipe_override: Optional[GraphStoragePipe] = None,
+        graph_search_pipe_override: Optional[GraphSearchSearchPipe] = None,
+        vector_storage_pipe_override: Optional[VectorStoragePipe] = None,
+        vector_search_pipe_override: Optional[VectorSearchPipe] = None,
+        rag_pipe_override: Optional[RAGPipe] = None,
+        streaming_rag_pipe_override: Optional[StreamingRAGPipe] = None,
+        graph_description_pipe: Optional[GraphDescriptionPipe] = None,
+        graph_clustering_pipe: Optional[GraphClusteringPipe] = None,
+        graph_community_summary_pipe: Optional[
+            GraphCommunitySummaryPipe
+        ] = None,
         *args,
         **kwargs,
     ) -> R2RPipes:
@@ -410,10 +426,6 @@ class R2RPipeFactory:
             or self.create_graph_description_pipe(*args, **kwargs),
             graph_clustering_pipe=graph_clustering_pipe
             or self.create_graph_clustering_pipe(*args, **kwargs),
-            graph_deduplication_pipe=graph_deduplication_pipe
-            or self.create_graph_deduplication_pipe(*args, **kwargs),
-            graph_deduplication_summary_pipe=graph_deduplication_summary_pipe
-            or self.create_graph_deduplication_summary_pipe(*args, **kwargs),
             graph_community_summary_pipe=graph_community_summary_pipe
             or self.create_graph_community_summary_pipe(*args, **kwargs),
         )
@@ -598,16 +610,6 @@ class R2RPipeFactory:
             config=AsyncPipe.PipeConfig(name="graph_clustering_pipe"),
             config=AsyncPipe.PipeConfig(name="graph_clustering_pipe"),
         )
         )
 
 
-    def create_kg_deduplication_summary_pipe(self, *args, **kwargs) -> Any:
-        from core.pipes import GraphDeduplicationSummaryPipe
-
-        return GraphDeduplicationSummaryPipe(
-            database_provider=self.providers.database,
-            llm_provider=self.providers.llm,
-            embedding_provider=self.providers.embedding,
-            config=AsyncPipe.PipeConfig(name="kg_deduplication_summary_pipe"),
-        )
-
     def create_graph_community_summary_pipe(self, *args, **kwargs) -> Any:
         from core.pipes import GraphCommunitySummaryPipe

@@ -618,28 +620,6 @@ class R2RPipeFactory:
             config=AsyncPipe.PipeConfig(name="graph_community_summary_pipe"),
             config=AsyncPipe.PipeConfig(name="graph_community_summary_pipe"),
         )
         )
 
 
-    def create_graph_deduplication_pipe(self, *args, **kwargs) -> Any:
-        from core.pipes import GraphDeduplicationPipe
-
-        return GraphDeduplicationPipe(
-            database_provider=self.providers.database,
-            llm_provider=self.providers.llm,
-            embedding_provider=self.providers.embedding,
-            config=AsyncPipe.PipeConfig(name="graph_deduplication_pipe"),
-        )
-
-    def create_graph_deduplication_summary_pipe(self, *args, **kwargs) -> Any:
-        from core.pipes import GraphDeduplicationSummaryPipe
-
-        return GraphDeduplicationSummaryPipe(
-            database_provider=self.providers.database,
-            llm_provider=self.providers.llm,
-            embedding_provider=self.providers.embedding,
-            config=AsyncPipe.PipeConfig(
-                name="graph_deduplication_summary_pipe"
-            ),
-        )
-

 class R2RPipelineFactory:
     def __init__(

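create_pipes now declares each override with its concrete pipe class instead of the generic AsyncPipe base type. A toy illustration of what the narrower annotations buy, using stub stand-ins rather than the real classes from core.pipes:

    from typing import Optional

    # Stub stand-ins for illustration only; the real classes live in core.pipes.
    class AsyncPipe: ...
    class RAGPipe(AsyncPipe): ...
    class EmbeddingPipe(AsyncPipe): ...

    def create_pipes_loose(rag_pipe_override: Optional[AsyncPipe] = None) -> None:
        """Old-style signature: any AsyncPipe is accepted, so a mismatched
        override passes the type checker and only fails (if at all) at runtime."""

    def create_pipes_strict(rag_pipe_override: Optional[RAGPipe] = None) -> None:
        """New-style signature: only a RAGPipe (or subclass) is accepted."""

    create_pipes_loose(EmbeddingPipe())   # type-checks, but is almost certainly wrong
    create_pipes_strict(EmbeddingPipe())  # flagged by mypy: expected Optional[RAGPipe]
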
+ 0 - 121
core/main/orchestration/hatchet/kg_workflow.py

@@ -97,25 +97,6 @@ def hatchet_kg_factory(
                 except:
                     pass

-            if key == "graph_entity_deduplication_settings":
-                try:
-                    input_data[key] = json.loads(value)
-                except:
-                    pass
-
-                if isinstance(input_data[key]["generation_config"], str):
-                    input_data[key]["generation_config"] = json.loads(
-                        input_data[key]["generation_config"]
-                    )
-
-                input_data[key]["generation_config"] = GenerationConfig(
-                    **input_data[key]["generation_config"]
-                )
-
-                logger.info(
-                    f"KG Entity Deduplication Settings: {input_data[key]}"
-                )
-
             if key == "generation_config":
             if key == "generation_config":
                 input_data[key] = GenerationConfig(**input_data[key])
                 input_data[key] = GenerationConfig(**input_data[key])
         return input_data
         return input_data
@@ -383,104 +364,6 @@ def hatchet_kg_factory(
                     f"Failed to update document status for {document_id}: {e}"
                     f"Failed to update document status for {document_id}: {e}"
                 )
                 )
 
 
-    @orchestration_provider.workflow(
-        name="entity-deduplication", timeout="360m"
-    )
-    class EntityDeduplicationWorkflow:
-        def __init__(self, kg_service: GraphService):
-            self.kg_service = kg_service
-
-        @orchestration_provider.step(retries=0, timeout="360m")
-        async def kg_entity_deduplication_setup(
-            self, context: Context
-        ) -> dict:
-            input_data = get_input_data_dict(
-                context.workflow_input()["request"]
-            )
-
-            graph_id = input_data["graph_id"]
-
-            logger.info(
-                f"Running KG Entity Deduplication for collection {graph_id}"
-            )
-            logger.info(f"Input data: {input_data}")
-            logger.info(
-                f"KG Entity Deduplication Settings: {input_data['graph_entity_deduplication_settings']}"
-            )
-
-            number_of_distinct_entities = (
-                await self.kg_service.kg_entity_deduplication(
-                    graph_id=graph_id,
-                    **input_data["graph_entity_deduplication_settings"],
-                )
-            )[0]["num_entities"]
-
-            input_data["graph_entity_deduplication_settings"][
-                "generation_config"
-            ] = input_data["graph_entity_deduplication_settings"][
-                "generation_config"
-            ].model_dump_json()
-
-            # run 100 entities in one workflow
-            total_workflows = math.ceil(number_of_distinct_entities / 100)
-            workflows = []
-            for i in range(total_workflows):
-                offset = i * 100
-                workflows.append(
-                    context.aio.spawn_workflow(
-                        "kg-entity-deduplication-summary",
-                        {
-                            "request": {
-                                "graph_id": graph_id,
-                                "offset": offset,
-                                "limit": 100,
-                                "graph_entity_deduplication_settings": json.dumps(
-                                    input_data[
-                                        "graph_entity_deduplication_settings"
-                                    ]
-                                ),
-                            }
-                        },
-                        key=f"{i}/{total_workflows}_entity_deduplication_part",
-                    )
-                )
-
-            await asyncio.gather(*workflows)
-            return {
-                "result": f"successfully queued kg entity deduplication for collection {graph_id} with {number_of_distinct_entities} distinct entities"
-            }
-
-    @orchestration_provider.workflow(
-        name="kg-entity-deduplication-summary", timeout="360m"
-    )
-    class EntityDeduplicationSummaryWorkflow:
-        def __init__(self, kg_service: GraphService):
-            self.kg_service = kg_service
-
-        @orchestration_provider.step(retries=0, timeout="360m")
-        async def kg_entity_deduplication_summary(
-            self, context: Context
-        ) -> dict:
-            logger.info(
-                f"Running KG Entity Deduplication Summary for input data: {context.workflow_input()['request']}"
-            )
-
-            input_data = get_input_data_dict(
-                context.workflow_input()["request"]
-            )
-            graph_id = input_data["graph_id"]
-
-            await self.kg_service.kg_entity_deduplication_summary(
-                graph_id=graph_id,
-                offset=input_data["offset"],
-                limit=input_data["limit"],
-                **input_data["graph_entity_deduplication_settings"],
-            )
-
-            return {
-                "result": f"successfully queued kg entity deduplication summary for collection {graph_id}"
-            }
-
     @orchestration_provider.workflow(name="build-communities", timeout="360m")
     class EnrichGraphWorkflow:
         def __init__(self, kg_service: GraphService):
@@ -676,8 +559,4 @@ def hatchet_kg_factory(
         "extract-triples": CreateGraphWorkflow(service),
         "extract-triples": CreateGraphWorkflow(service),
         "build-communities": EnrichGraphWorkflow(service),
         "build-communities": EnrichGraphWorkflow(service),
         "kg-community-summary": KGCommunitySummaryWorkflow(service),
         "kg-community-summary": KGCommunitySummaryWorkflow(service),
-        "kg-entity-deduplication": EntityDeduplicationWorkflow(service),
-        "kg-entity-deduplication-summary": EntityDeduplicationSummaryWorkflow(
-            service
-        ),
     }

+ 0 - 26
core/main/orchestration/simple/kg_workflow.py

@@ -174,34 +174,8 @@ def simple_kg_factory(service: GraphService):
             **input_data["graph_enrichment_settings"],
         )

-    async def entity_deduplication_workflow(input_data):
-        # TODO: We should determine how we want to handle the input here and syncronize it across all simple orchestration methods
-        if isinstance(input_data["graph_entity_deduplication_settings"], str):
-            input_data["graph_entity_deduplication_settings"] = json.loads(
-                input_data["graph_entity_deduplication_settings"]
-            )
-
-        collection_id = input_data.get("collection_id", None)
-        graph_id = input_data.get("graph_id", None)
-
-        number_of_distinct_entities = (
-            await service.kg_entity_deduplication(
-                collection_id=collection_id,
-                graph_id=graph_id,
-                **input_data["graph_entity_deduplication_settings"],
-            )
-        )[0]["num_entities"]
-
-        await service.kg_entity_deduplication_summary(
-            collection_id=collection_id,
-            offset=0,
-            limit=number_of_distinct_entities,
-            **input_data["graph_entity_deduplication_settings"],
-        )
-
     return {
         "extract-triples": extract_triples,
         "build-communities": enrich_graph,
         "kg-community-summary": kg_community_summary,
-        "entity-deduplication": entity_deduplication_workflow,
     }

+ 2 - 2
core/main/services/auth_service.py

@@ -283,7 +283,7 @@ class AuthService(Service):
         """
         """
         return await self.providers.auth.create_user_api_key(user_id)
         return await self.providers.auth.create_user_api_key(user_id)
 
 
-    async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> dict:
+    async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
         """
         """
         Delete the API key for the user.
         Delete the API key for the user.
 
 
@@ -292,7 +292,7 @@ class AuthService(Service):
             key_id (str): The ID of the API key

         Returns:
-            dict: Contains the message
+            bool: True if the API key was deleted successfully
         """
         """
         return await self.providers.auth.delete_user_api_key(
         return await self.providers.auth.delete_user_api_key(
             user_id=user_id, key_id=key_id
             user_id=user_id, key_id=key_id

+ 14 - 88
core/main/services/graph_service.py

@@ -21,10 +21,9 @@ from core.base.abstractions import (
     KGCreationSettings,
     KGEnrichmentSettings,
     KGEnrichmentStatus,
-    KGEntityDeduplicationSettings,
-    KGEntityDeduplicationType,
     R2RException,
     Relationship,
+    StoreType,
 )
 from core.base.api.models import GraphResponse
 from core.telemetry.telemetry_decorator import telemetry_event
@@ -144,7 +143,7 @@ class GraphService(Service):
         return await self.providers.database.graphs_handler.entities.create(
             name=name,
             parent_id=parent_id,
-            store_type="graphs",  # type: ignore
+            store_type=StoreType.GRAPHS,
             category=category,
             description=description,
             description_embedding=description_embedding,
@@ -168,7 +167,7 @@ class GraphService(Service):

         return await self.providers.database.graphs_handler.entities.update(
             entity_id=entity_id,
-            store_type="graphs",  # type: ignore
+            store_type=StoreType.GRAPHS,
             name=name,
             description=description,
             description_embedding=description_embedding,
@@ -185,7 +184,7 @@ class GraphService(Service):
         return await self.providers.database.graphs_handler.entities.delete(
             parent_id=parent_id,
             entity_ids=[entity_id],
-            store_type="graphs",  # type: ignore
+            store_type=StoreType.GRAPHS,
         )

     @telemetry_event("get_entities")
@@ -238,7 +237,7 @@ class GraphService(Service):
                 description_embedding=description_embedding,
                 weight=weight,
                 metadata=metadata,
-                store_type="graphs",  # type: ignore
+                store_type=StoreType.GRAPHS,
             )
         )

@@ -252,7 +251,7 @@ class GraphService(Service):
             await self.providers.database.graphs_handler.relationships.delete(
                 parent_id=parent_id,
                 relationship_ids=[relationship_id],
-                store_type="graphs",  # type: ignore
+                store_type=StoreType.GRAPHS,
             )
         )

@@ -287,7 +286,7 @@ class GraphService(Service):
                 description_embedding=description_embedding,
                 weight=weight,
                 metadata=metadata,
-                store_type="graphs",  # type: ignore
+                store_type=StoreType.GRAPHS,
             )
         )

@@ -302,7 +301,7 @@ class GraphService(Service):
     ):
         return await self.providers.database.graphs_handler.relationships.get(
             parent_id=parent_id,
-            store_type="graphs",  # type: ignore
+            store_type=StoreType.GRAPHS,
             offset=offset,
             limit=limit,
             relationship_ids=relationship_ids,
@@ -324,7 +323,7 @@ class GraphService(Service):
         )
         return await self.providers.database.graphs_handler.communities.create(
             parent_id=parent_id,
-            store_type="graphs",  # type: ignore
+            store_type=StoreType.GRAPHS,
             name=name,
             summary=summary,
             description_embedding=description_embedding,
@@ -351,7 +350,7 @@ class GraphService(Service):

         return await self.providers.database.graphs_handler.communities.update(
             community_id=community_id,
-            store_type="graphs",  # type: ignore
+            store_type=StoreType.GRAPHS,
             name=name,
             summary=summary,
             summary_embedding=summary_embedding,
@@ -380,7 +379,7 @@ class GraphService(Service):
     ):
         return await self.providers.database.graphs_handler.communities.get(
             parent_id=collection_id,
-            store_type="graphs",  # type: ignore
+            store_type=StoreType.GRAPHS,
             offset=offset,
             limit=limit,
         )
@@ -622,21 +621,17 @@ class GraphService(Service):
     async def delete_graph(
         self,
         collection_id: UUID,
-        cascade: bool,
-        **kwargs,
     ):
-        return await self.delete(collection_id=collection_id, cascade=cascade)
+        return await self.delete(collection_id=collection_id)

     @telemetry_event("delete")
     async def delete(
         self,
         collection_id: UUID,
-        cascade: bool,
         **kwargs,
     ):
         return await self.providers.database.graphs_handler.delete(
             collection_id=collection_id,
-            cascade=cascade,
         )

     @telemetry_event("get_creation_estimate")
@@ -674,75 +669,6 @@ class GraphService(Service):
             graph_enrichment_settings=graph_enrichment_settings,
         )

-    @telemetry_event("get_deduplication_estimate")
-    async def get_deduplication_estimate(
-        self,
-        collection_id: UUID,
-        kg_deduplication_settings: KGEntityDeduplicationSettings,
-        **kwargs,
-    ):
-        return await self.providers.database.graphs_handler.get_deduplication_estimate(
-            collection_id=collection_id,
-            kg_deduplication_settings=kg_deduplication_settings,
-        )
-
-    @telemetry_event("kg_entity_deduplication")
-    async def kg_entity_deduplication(
-        self,
-        collection_id: UUID,
-        graph_id: UUID,
-        graph_entity_deduplication_type: KGEntityDeduplicationType,
-        graph_entity_deduplication_prompt: str,
-        generation_config: GenerationConfig,
-        **kwargs,
-    ):
-        deduplication_results = await self.pipes.graph_deduplication_pipe.run(
-            input=self.pipes.graph_deduplication_pipe.Input(
-                message={
-                    "collection_id": collection_id,
-                    "graph_id": graph_id,
-                    "graph_entity_deduplication_type": graph_entity_deduplication_type,
-                    "graph_entity_deduplication_prompt": graph_entity_deduplication_prompt,
-                    "generation_config": generation_config,
-                    **kwargs,
-                }
-            ),
-            state=None,
-            run_manager=self.run_manager,
-        )
-        return await _collect_results(deduplication_results)
-
-    @telemetry_event("kg_entity_deduplication_summary")
-    async def kg_entity_deduplication_summary(
-        self,
-        collection_id: UUID,
-        offset: int,
-        limit: int,
-        graph_entity_deduplication_type: KGEntityDeduplicationType,
-        graph_entity_deduplication_prompt: str,
-        generation_config: GenerationConfig,
-        **kwargs,
-    ):
-        logger.info(
-            f"Running kg_entity_deduplication_summary for collection {collection_id} with settings {kwargs}"
-        )
-        deduplication_summary_results = await self.pipes.graph_deduplication_summary_pipe.run(
-            input=self.pipes.graph_deduplication_summary_pipe.Input(
-                message={
-                    "collection_id": collection_id,
-                    "offset": offset,
-                    "limit": limit,
-                    "graph_entity_deduplication_type": graph_entity_deduplication_type,
-                    "graph_entity_deduplication_prompt": graph_entity_deduplication_prompt,
-                    "generation_config": generation_config,
-                }
-            ),
-            state=None,
-            run_manager=self.run_manager,
-        )
-
-        return await _collect_results(deduplication_summary_results)
-
     async def kg_extraction(  # type: ignore
         self,
         document_id: UUID,
@@ -1045,7 +971,7 @@ class GraphService(Service):
                 result = await self.providers.database.graphs_handler.entities.create(
                     name=entity.name,
                     parent_id=entity.parent_id,
-                    store_type="documents",  # type: ignore
+                    store_type=StoreType.DOCUMENTS,
                     category=entity.category,
                     description=entity.description,
                     description_embedding=entity.description_embedding,
@@ -1067,5 +993,5 @@ class GraphService(Service):
                         description_embedding=relationship.description_embedding,
                         weight=relationship.weight,
                         metadata=relationship.metadata,
-                        store_type="documents",  # type: ignore
+                        store_type=StoreType.DOCUMENTS,
                     )
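
Note on the store_type changes throughout this file: the raw string literals (previously silenced with `# type: ignore`) are replaced by an enum. A minimal sketch of the assumed shape of StoreType follows; the real definition is imported from core.base.abstractions and may differ.

from enum import Enum

class StoreType(str, Enum):
    # Illustrative stand-in; the actual enum ships with core.base.abstractions
    GRAPHS = "graphs"
    DOCUMENTS = "documents"

# Because the enum derives from str with the same values, handlers that
# previously compared against the raw strings keep working unchanged.
assert StoreType.GRAPHS == "graphs" and StoreType.DOCUMENTS == "documents"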

+ 173 - 7
core/main/services/management_service.py

@@ -1,7 +1,8 @@
 import logging
 import os
 from collections import defaultdict
-from typing import Any, BinaryIO, Optional, Tuple
+from datetime import datetime
+from typing import IO, Any, BinaryIO, Optional, Tuple
 from uuid import UUID

 import toml
@@ -16,6 +17,7 @@ from core.base import (
     Prompt,
     R2RException,
     RunManager,
+    StoreType,
     User,
 )
 from core.telemetry.telemetry_decorator import telemetry_event
@@ -189,10 +191,12 @@ class ManagementService(Service):
         for doc_id in docs_to_delete:
             # Delete related entities & relationships if needed:
             await self.providers.database.graphs_handler.entities.delete(
-                parent_id=doc_id, store_type="documents"
+                parent_id=doc_id,
+                store_type=StoreType.DOCUMENTS,
             )
             await self.providers.database.graphs_handler.relationships.delete(
-                parent_id=doc_id, store_type="documents"
+                parent_id=doc_id,
+                store_type=StoreType.DOCUMENTS,
             )

             # Finally, delete the document from documents_overview:
@@ -218,6 +222,166 @@ class ManagementService(Service):
             return result
         return None

+    @telemetry_event("ExportFiles")
+    async def export_files(
+        self,
+        document_ids: Optional[list[UUID]] = None,
+        start_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
+    ) -> tuple[str, BinaryIO, int]:
+        return (
+            await self.providers.database.files_handler.retrieve_files_as_zip(
+                document_ids=document_ids,
+                start_date=start_date,
+                end_date=end_date,
+            )
+        )
+
+    @telemetry_event("ExportCollections")
+    async def export_collections(
+        self,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        return await self.providers.database.collections_handler.export_to_csv(
+            columns=columns,
+            filters=filters,
+            include_header=include_header,
+        )
+
+    @telemetry_event("ExportDocuments")
+    async def export_documents(
+        self,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        return await self.providers.database.documents_handler.export_to_csv(
+            columns=columns,
+            filters=filters,
+            include_header=include_header,
+        )
+
+    @telemetry_event("ExportDocumentEntities")
+    async def export_document_entities(
+        self,
+        id: UUID,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        return await self.providers.database.graphs_handler.entities.export_to_csv(
+            parent_id=id,
+            store_type=StoreType.DOCUMENTS,
+            columns=columns,
+            filters=filters,
+            include_header=include_header,
+        )
+
+    @telemetry_event("ExportDocumentRelationships")
+    async def export_document_relationships(
+        self,
+        id: UUID,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        return await self.providers.database.graphs_handler.relationships.export_to_csv(
+            parent_id=id,
+            store_type=StoreType.DOCUMENTS,
+            columns=columns,
+            filters=filters,
+            include_header=include_header,
+        )
+
+    @telemetry_event("ExportConversations")
+    async def export_conversations(
+        self,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        return await self.providers.database.conversations_handler.export_conversations_to_csv(
+            columns=columns,
+            filters=filters,
+            include_header=include_header,
+        )
+
+    @telemetry_event("ExportGraphEntities")
+    async def export_graph_entities(
+        self,
+        id: UUID,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        return await self.providers.database.graphs_handler.entities.export_to_csv(
+            parent_id=id,
+            store_type=StoreType.GRAPHS,
+            columns=columns,
+            filters=filters,
+            include_header=include_header,
+        )
+
+    @telemetry_event("ExportGraphRelationships")
+    async def export_graph_relationships(
+        self,
+        id: UUID,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        return await self.providers.database.graphs_handler.relationships.export_to_csv(
+            parent_id=id,
+            store_type=StoreType.GRAPHS,
+            columns=columns,
+            filters=filters,
+            include_header=include_header,
+        )
+
+    @telemetry_event("ExportGraphCommunities")
+    async def export_graph_communities(
+        self,
+        id: UUID,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        return await self.providers.database.graphs_handler.communities.export_to_csv(
+            parent_id=id,
+            store_type=StoreType.GRAPHS,
+            columns=columns,
+            filters=filters,
+            include_header=include_header,
+        )
+
+    @telemetry_event("ExportMessages")
+    async def export_messages(
+        self,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        return await self.providers.database.conversations_handler.export_messages_to_csv(
+            columns=columns,
+            filters=filters,
+            include_header=include_header,
+        )
+
+    @telemetry_event("ExportUsers")
+    async def export_users(
+        self,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> tuple[str, IO]:
+        return await self.providers.database.users_handler.export_to_csv(
+            columns=columns,
+            filters=filters,
+            include_header=include_header,
+        )
+
     @telemetry_event("DocumentsOverview")
     @telemetry_event("DocumentsOverview")
     async def documents_overview(
     async def documents_overview(
         self,
         self,
@@ -538,7 +702,9 @@ class ManagementService(Service):
             return {
                 "message": (
                     await self.providers.database.prompts_handler.get_cached_prompt(
-                        prompt_name, inputs, prompt_override
+                        prompt_name=prompt_name,
+                        inputs=inputs,
+                        prompt_override=prompt_override,
                     )
                 )
             }
@@ -674,11 +840,11 @@ class ManagementService(Service):
             filter_user_ids=user_ids,
         )

-    async def get_user_max_documents(self, user_id: UUID) -> int:
+    async def get_user_max_documents(self, user_id: UUID) -> int | None:
         return self.config.app.default_max_documents_per_user

-    async def get_user_max_chunks(self, user_id: UUID) -> int:
+    async def get_user_max_chunks(self, user_id: UUID) -> int | None:
         return self.config.app.default_max_chunks_per_user

-    async def get_user_max_collections(self, user_id: UUID) -> int:
+    async def get_user_max_collections(self, user_id: UUID) -> int | None:
         return self.config.app.default_max_collections_per_user
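
A brief usage sketch for the new export helpers added above. It assumes an already-initialized ManagementService instance named `service`; the column names and file path are illustrative, not taken from this commit.

import shutil

async def dump_documents_csv(service, path: str = "documents.csv") -> None:
    # export_documents returns a (suggested_filename, csv_file_object) pair
    _, csv_io = await service.export_documents(
        columns=["id", "title", "created_at"],  # illustrative column names
        include_header=True,
    )
    csv_io.seek(0)  # rewind in case the handler leaves the cursor at the end
    with open(path, "w", newline="") as out:
        shutil.copyfileobj(csv_io, out)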

+ 3 - 3
core/parsers/structured/csv_parser.py

@@ -1,5 +1,5 @@
 # type: ignore
-from typing import IO, AsyncGenerator, Optional, Union
+from typing import IO, AsyncGenerator, Optional

 from core.base.parsers.base_parser import AsyncParser
 from core.base.providers import (
@@ -29,7 +29,7 @@ class CSVParser(AsyncParser[str | bytes]):
         self.StringIO = StringIO

     async def ingest(
-        self, data: Union[str, bytes], *args, **kwargs
+        self, data: str | bytes, *args, **kwargs
     ) -> AsyncGenerator[str, None]:
         """Ingest CSV data and yield text from each row."""
         if isinstance(data, bytes):
@@ -72,7 +72,7 @@ class CSVParserAdvanced(AsyncParser[str | bytes]):

     async def ingest(
         self,
-        data: Union[str, bytes],
+        data: str | bytes,
         num_col_times_num_rows: int = 100,
         *args,
         **kwargs,

+ 0 - 4
core/pipes/__init__.py

@@ -5,8 +5,6 @@ from .ingestion.parsing_pipe import ParsingPipe
 from .ingestion.vector_storage_pipe import VectorStoragePipe
 from .kg.clustering import GraphClusteringPipe
 from .kg.community_summary import GraphCommunitySummaryPipe
-from .kg.deduplication import GraphDeduplicationPipe
-from .kg.deduplication_summary import GraphDeduplicationSummaryPipe
 from .kg.description import GraphDescriptionPipe
 from .kg.extraction import GraphExtractionPipe
 from .kg.storage import GraphStoragePipe
@@ -36,6 +34,4 @@ __all__ = [
     "MultiSearchPipe",
     "MultiSearchPipe",
     "GraphCommunitySummaryPipe",
     "GraphCommunitySummaryPipe",
     "RoutingSearchPipe",
     "RoutingSearchPipe",
-    "GraphDeduplicationPipe",
-    "GraphDeduplicationSummaryPipe",
 ]

+ 2 - 2
core/pipes/abstractions/search_pipe.py

@@ -1,6 +1,6 @@
 import logging
 from abc import abstractmethod
-from typing import Any, AsyncGenerator, Optional, Union
+from typing import Any, AsyncGenerator
 from uuid import UUID

 from core.base import AsyncPipe, AsyncState, ChunkSearchResult
@@ -15,7 +15,7 @@ class SearchPipe(AsyncPipe[ChunkSearchResult]):
         limit: int = 10

     class Input(AsyncPipe.Input):
-        message: Union[AsyncGenerator[str, None], str]
+        message: AsyncGenerator[str, None] | str

     def __init__(
         self,

+ 2 - 2
core/pipes/ingestion/embedding_pipe.py

@@ -1,6 +1,6 @@
 import asyncio
 import logging
-from typing import Any, AsyncGenerator, Optional, Union
+from typing import Any, AsyncGenerator

 from core.base import (
     AsyncState,
@@ -113,7 +113,7 @@ class EmbeddingPipe(AsyncPipe[VectorEntry]):

     async def _process_extraction(
         self, extraction: DocumentChunk
-    ) -> Union[VectorEntry, R2RDocumentProcessingError]:
+    ) -> VectorEntry | R2RDocumentProcessingError:
         try:
             if isinstance(extraction.data, bytes):
                 raise ValueError(

+ 1 - 1
core/pipes/kg/community_summary.py

@@ -300,7 +300,7 @@ class GraphCommunitySummaryPipe(AsyncPipe):
         )

         # Organize clusters
-        clusters: dict[Any] = {}
+        clusters: dict[Any, Any] = {}
         for item in community_clusters:
             cluster_id = (
                 item["cluster"]

+ 2 - 2
core/pipes/kg/extraction.py

@@ -3,7 +3,7 @@ import json
 import logging
 import re
 import time
-from typing import Any, AsyncGenerator, Optional, Union
+from typing import Any, AsyncGenerator, Optional

 from core.base import (
     AsyncState,
@@ -211,7 +211,7 @@ class GraphExtractionPipe(AsyncPipe[dict]):
         run_id: Any,
         *args: Any,
         **kwargs: Any,
-    ) -> AsyncGenerator[Union[KGExtraction, R2RDocumentProcessingError], None]:
+    ) -> AsyncGenerator[KGExtraction | R2RDocumentProcessingError, None]:
         start_time = time.time()

         document_id = input.message["document_id"]

+ 13 - 6
core/providers/auth/r2r_auth.py

@@ -106,17 +106,24 @@ class R2RAuthProvider(AuthProvider):
                 status_code=401, message="Invalid or expired token"
                 status_code=401, message="Invalid or expired token"
             )
             )
 
 
-        email: str = payload.get("sub")
-        token_type: str = payload.get("token_type")
-        exp: float = payload.get("exp")
+        email = payload.get("sub")
+        token_type = payload.get("token_type")
+        exp = payload.get("exp")
+
         if email is None or token_type is None or exp is None:
             raise R2RException(status_code=401, message="Invalid token claims")

-        exp_datetime = datetime.fromtimestamp(exp, tz=timezone.utc)
+        email_str: str = email
+        token_type_str: str = token_type
+        exp_float: float = exp
+
+        exp_datetime = datetime.fromtimestamp(exp_float, tz=timezone.utc)
         if exp_datetime < datetime.now(timezone.utc):
             raise R2RException(status_code=401, message="Token has expired")

-        return TokenData(email=email, token_type=token_type, exp=exp_datetime)
+        return TokenData(
+            email=email_str, token_type=token_type_str, exp=exp_datetime
+        )

     async def authenticate_api_key(self, api_key: str) -> Optional[User]:
         """
@@ -479,7 +486,7 @@ class R2RAuthProvider(AuthProvider):
             user_id=user_id
         )

-    async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+    async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> dict:
         return await self.database_provider.users_handler.delete_api_key(
             user_id=user_id,
             key_id=key_id,

+ 19 - 0
core/providers/auth/supabase.py

@@ -1,5 +1,7 @@
 import logging
 import os
+from typing import Optional
+from uuid import UUID

 from fastapi import Depends, HTTPException
 from fastapi.security import OAuth2PasswordBearer
@@ -216,3 +218,20 @@ class SupabaseAuthProvider(AuthProvider):

     async def send_reset_email(self, email: str) -> dict[str, str]:
         raise NotImplementedError("send_reset_email is not used with Supabase")
+
+    async def create_user_api_key(
+        self, user_id: UUID, name: Optional[str] = None
+    ) -> dict[str, str]:
+        raise NotImplementedError(
+            "API key management is not supported with Supabase authentication"
+        )
+
+    async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
+        raise NotImplementedError(
+            "API key management is not supported with Supabase authentication"
+        )
+
+    async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> dict:
+        raise NotImplementedError(
+            "API key management is not supported with Supabase authentication"
+        )

+ 3 - 3
core/providers/ingestion/unstructured/base.py

@@ -63,7 +63,7 @@ class UnstructuredIngestionConfig(IngestionConfig):
     split_pdf_page: Optional[bool] = None
     starting_page_number: Optional[int] = None
     strategy: Optional[str] = None
-    chunking_strategy: Optional[ChunkingStrategy] = None
+    chunking_strategy: Optional[str | ChunkingStrategy] = None
     unique_element_ids: Optional[bool] = None
     xml_keep_tags: Optional[bool] = None

@@ -99,8 +99,8 @@ class UnstructuredIngestionProvider(IngestionProvider):
     EXTRA_PARSERS = {
         DocumentType.CSV: {"advanced": parsers.CSVParserAdvanced},  # type: ignore
         DocumentType.PDF: {
-            "unstructured": parsers.PDFParserUnstructured,
-            "zerox": parsers.VLMPDFParser,
+            "unstructured": parsers.PDFParserUnstructured,  # type: ignore
+            "zerox": parsers.VLMPDFParser,  # type: ignore
         },
         DocumentType.XLSX: {"advanced": parsers.XLSXParserAdvanced},  # type: ignore
     }

+ 0 - 69
core/telemetry/events.py

@@ -9,11 +9,6 @@ class BaseTelemetryEvent:
         self.event_id = str(uuid.uuid4())


-class DailyActiveUserEvent(BaseTelemetryEvent):
-    def __init__(self, user_id: str):
-        super().__init__("DailyActiveUser", {"user_id": user_id})
-
-
 class FeatureUsageEvent(BaseTelemetryEvent):
     def __init__(
         self,
@@ -48,67 +43,3 @@ class ErrorEvent(BaseTelemetryEvent):
                 "properties": properties or {},
                 "properties": properties or {},
             },
             },
         )
         )
-
-
-class RequestLatencyEvent(BaseTelemetryEvent):
-    def __init__(
-        self,
-        endpoint: str,
-        latency: float,
-        properties: Optional[dict[str, Any]] = None,
-    ):
-        super().__init__(
-            "RequestLatency",
-            {
-                "endpoint": endpoint,
-                "latency": latency,
-                "properties": properties or {},
-            },
-        )
-
-
-class GeographicDistributionEvent(BaseTelemetryEvent):
-    def __init__(
-        self,
-        user_id: str,
-        country: str,
-        properties: Optional[dict[str, Any]] = None,
-    ):
-        super().__init__(
-            "GeographicDistribution",
-            {
-                "user_id": user_id,
-                "country": country,
-                "properties": properties or {},
-            },
-        )
-
-
-class SessionDurationEvent(BaseTelemetryEvent):
-    def __init__(
-        self,
-        user_id: str,
-        duration: float,
-        properties: Optional[dict[str, Any]] = None,
-    ):
-        super().__init__(
-            "SessionDuration",
-            {
-                "user_id": user_id,
-                "duration": duration,
-                "properties": properties or {},
-            },
-        )
-
-
-class UserPathEvent(BaseTelemetryEvent):
-    def __init__(
-        self,
-        user_id: str,
-        path: str,
-        properties: Optional[dict[str, Any]] = None,
-    ):
-        super().__init__(
-            "UserPath",
-            {"user_id": user_id, "path": path, "properties": properties or {}},
-        )

+ 1 - 17
sdk/async_client.py

@@ -7,14 +7,6 @@ import httpx
 from shared.abstractions import R2RException

 from .base.base_client import BaseClient
-from .v2 import (
-    AuthMixins,
-    IngestionMixins,
-    KGMixins,
-    ManagementMixins,
-    RetrievalMixins,
-    ServerMixins,
-)
 from .v3 import (
     ChunksSDK,
     CollectionsSDK,
@@ -29,15 +21,7 @@ from .v3 import (
 )


-class R2RAsyncClient(
-    BaseClient,
-    AuthMixins,
-    IngestionMixins,
-    KGMixins,
-    ManagementMixins,
-    RetrievalMixins,
-    ServerMixins,
-):
+class R2RAsyncClient(BaseClient):
     """
     """
     Asynchronous client for interacting with the R2R API.
     Asynchronous client for interacting with the R2R API.
     """
     """

+ 0 - 4
sdk/models.py

@@ -8,9 +8,7 @@ from shared.abstractions import (
     KGCommunityResult,
     KGCreationSettings,
     KGEnrichmentSettings,
-    KGEntityDeduplicationSettings,
     KGEntityResult,
-    KGGlobalResult,
     KGRelationshipResult,
     KGRunType,
     KGSearchResultType,
@@ -33,7 +31,6 @@ __all__ = [
     "KGCreationSettings",
     "KGCreationSettings",
     "KGEnrichmentSettings",
     "KGEnrichmentSettings",
     "KGEntityResult",
     "KGEntityResult",
-    "KGGlobalResult",
     "KGRelationshipResult",
     "KGRelationshipResult",
     "KGRunType",
     "KGRunType",
     "GraphSearchResult",
     "GraphSearchResult",
@@ -48,7 +45,6 @@ __all__ = [
     "SearchSettings",
     "SearchSettings",
     "select_search_filters",
     "select_search_filters",
     "SearchMode",
     "SearchMode",
-    "KGEntityDeduplicationSettings",
     "RAGResponse",
     "RAGResponse",
     "CombinedSearchResponse",
     "CombinedSearchResponse",
     "User",
     "User",

+ 0 - 48
sdk/sync_client.py

@@ -5,14 +5,6 @@ import inspect
 from typing import Any, Callable, Coroutine, TypeVar

 from .async_client import R2RAsyncClient
-from .v2 import (
-    SyncAuthMixins,
-    SyncIngestionMixins,
-    SyncKGMixins,
-    SyncManagementMixins,
-    SyncRetrievalMixins,
-    SyncServerMixins,
-)

 T = TypeVar("T")

@@ -29,52 +21,12 @@ class R2RClient(R2RAsyncClient):
         # Only wrap v3 methods since they're already working
         self._wrap_v3_methods()

-        # Override v2 methods with sync versions
-        self._override_v2_methods()
-
     def _make_sync_request(self, *args, **kwargs):
         """Sync version of _make_request for v2 methods"""
         return self._loop.run_until_complete(
             self._async_make_request(*args, **kwargs)
         )

-    def _override_v2_methods(self):
-        """
-        Replace async v2 methods with sync versions
-        This is really ugly, but it's the only way to make it work once we
-        remove v2, we can just resort to the metaclass approach that is in utils
-        """
-        sync_mixins = {
-            SyncAuthMixins: ["auth_methods"],
-            SyncIngestionMixins: ["ingestion_methods"],
-            SyncKGMixins: ["kg_methods"],
-            SyncManagementMixins: ["management_methods"],
-            SyncRetrievalMixins: ["retrieval_methods"],
-            SyncServerMixins: ["server_methods"],
-        }
-
-        for sync_class in sync_mixins:
-            for name, method in sync_class.__dict__.items():
-                if not name.startswith("_") and inspect.isfunction(method):
-                    # Create a wrapper that uses sync _make_request
-                    def wrap_method(m):
-                        def wrapped(self, *args, **kwargs):
-                            # Temporarily swap _make_request
-                            original_make_request = self._make_request
-                            self._make_request = self._make_sync_request
-                            try:
-                                return m(self, *args, **kwargs)
-                            finally:
-                                # Restore original _make_request
-                                self._make_request = original_make_request
-
-                        return wrapped
-
-                    bound_method = wrap_method(method).__get__(
-                        self, self.__class__
-                    )
-                    setattr(self, name, bound_method)
-
     def _wrap_v3_methods(self) -> None:
         """Wraps only v3 SDK object methods"""
         sdk_objects = [

+ 3 - 3
sdk/v3/chunks.py

@@ -1,5 +1,5 @@
 import json
-from typing import Optional
+from typing import Any, Optional
 from uuid import UUID

 from shared.api.models.base import WrappedBooleanResponse
@@ -102,7 +102,7 @@ class ChunksSDK:
         Delete a specific chunk.

         Args:
-            id (Union[str, UUID]): ID of chunk to delete
+            id (str | UUID): ID of chunk to delete
         """
         """
         return await self.client._make_request(
         return await self.client._make_request(
             "DELETE",
             "DELETE",
@@ -168,7 +168,7 @@ class ChunksSDK:
         if search_settings and not isinstance(search_settings, dict):
             search_settings = search_settings.model_dump()

-        data = {
+        data: dict[str, Any] = {
             "query": query,
             "query": query,
             "search_settings": search_settings,
             "search_settings": search_settings,
         }
         }

+ 5 - 5
sdk/v3/collections.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional
 from uuid import UUID

 from shared.api.models.base import (
@@ -32,7 +32,7 @@ class CollectionsSDK:
         Returns:
             dict: Created collection information
         """
-        data = {"name": name, "description": description}
+        data: dict[str, Any] = {"name": name, "description": description}
         return await self.client._make_request(
             "POST",
             "collections",
@@ -104,7 +104,7 @@ class CollectionsSDK:
         Returns:
             dict: Updated collection information
         """
-        data = {}
+        data: dict[str, Any] = {}
         if name is not None:
             data["name"] = name
         if description is not None:
@@ -304,7 +304,7 @@ class CollectionsSDK:
             "run_with_orchestration": run_with_orchestration
             "run_with_orchestration": run_with_orchestration
         }
         }
 
 
-        data = {}
+        data: dict[str, Any] = {}
         if settings is not None:
             data["settings"] = settings

@@ -312,6 +312,6 @@ class CollectionsSDK:
             "POST",
             "POST",
             f"collections/{str(id)}/extract",
             f"collections/{str(id)}/extract",
             params=params,
             params=params,
-            json=data if data else None,
+            json=data or None,
             version="v3",
             version="v3",
         )
         )

+ 108 - 6
sdk/v3/conversations.py

@@ -1,6 +1,10 @@
+from builtins import list as _list
+from pathlib import Path
 from typing import Any, Optional
 from uuid import UUID

+import aiofiles
+
 from shared.api.models.base import WrappedBooleanResponse
 from shared.api.models.management.responses import (
     WrappedConversationMessagesResponse,
@@ -24,7 +28,9 @@ class ConversationsSDK:
         Returns:
             dict: Created conversation information
         """
-        data = {"name": name} if name else None
+        data: dict[str, Any] = {}
+        if name:
+            data["name"] = name

         return await self.client._make_request(
             "POST",
@@ -43,7 +49,7 @@ class ConversationsSDK:
         List conversations with pagination and sorting options.

         Args:
-            ids (Optional[list[Union[str, UUID]]]): List of conversation IDs to retrieve
+            ids (Optional[list[str | UUID]]): List of conversation IDs to retrieve
             offset (int, optional): Specifies the number of objects to skip. Defaults to 0.
             limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.

@@ -72,7 +78,7 @@ class ConversationsSDK:
         Get detailed information about a specific conversation.

         Args:
-            id (Union[str, UUID]): The ID of the conversation to retrieve
+            id (str | UUID): The ID of the conversation to retrieve

         Returns:
             dict: Detailed conversation information
@@ -92,7 +98,7 @@ class ConversationsSDK:
         Update an existing conversation.

         Args:
-            id (Union[str, UUID]): The ID of the conversation to update
+            id (str | UUID): The ID of the conversation to update
             name (str): The new name of the conversation

         Returns:
@@ -117,7 +123,7 @@ class ConversationsSDK:
         Delete a conversation.

         Args:
-            id (Union[str, UUID]): The ID of the conversation to delete
+            id (str | UUID): The ID of the conversation to delete

         Returns:
             bool: True if deletion was successful
@@ -140,7 +146,7 @@ class ConversationsSDK:
         Add a new message to a conversation.

         Args:
-            id (Union[str, UUID]): The ID of the conversation to add the message to
+            id (str | UUID): The ID of the conversation to add the message to
             content (str): The content of the message
             role (str): The role of the message (e.g., "user" or "assistant")
             parent_id (Optional[str]): The ID of the parent message
@@ -193,3 +199,99 @@ class ConversationsSDK:
             json=data,
             version="v3",
         )
+
+    async def export(
+        self,
+        output_path: str | Path,
+        columns: Optional[_list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> None:
+        """
+        Export conversations to a CSV file, streaming the results directly to disk.
+
+        Args:
+            output_path (str | Path): Local path where the CSV file should be saved
+            columns (Optional[list[str]]): Specific columns to export. If None, exports default columns
+            filters (Optional[dict]): Optional filters to apply when selecting conversations
+            include_header (bool): Whether to include column headers in the CSV (default: True)
+        """
+        # Convert path to string if it's a Path object
+        output_path = (
+            str(output_path) if isinstance(output_path, Path) else output_path
+        )
+
+        # Prepare request data
+        data: dict[str, Any] = {"include_header": include_header}
+        if columns:
+            data["columns"] = columns
+        if filters:
+            data["filters"] = filters
+
+        # Stream response directly to file
+        async with aiofiles.open(output_path, "wb") as f:
+            async with self.client.session.post(
+                f"{self.client.base_url}/v3/conversations/export",
+                json=data,
+                headers={
+                    "Accept": "text/csv",
+                    **self.client._get_auth_headers(),
+                },
+            ) as response:
+                if response.status != 200:
+                    raise ValueError(
+                        f"Export failed with status {response.status}",
+                        response,
+                    )
+
+                async for chunk in response.content.iter_chunks():
+                    if chunk:
+                        await f.write(chunk[0])
+
+    async def export_messages(
+        self,
+        output_path: str | Path,
+        columns: Optional[_list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> None:
+        """
+        Export messages to a CSV file, streaming the results directly to disk.
+
+        Args:
+            output_path (str | Path): Local path where the CSV file should be saved
+            columns (Optional[list[str]]): Specific columns to export. If None, exports default columns
+            filters (Optional[dict]): Optional filters to apply when selecting messages
+            include_header (bool): Whether to include column headers in the CSV (default: True)
+        """
+        # Convert path to string if it's a Path object
+        output_path = (
+            str(output_path) if isinstance(output_path, Path) else output_path
+        )
+
+        # Prepare request data
+        data: dict[str, Any] = {"include_header": include_header}
+        if columns:
+            data["columns"] = columns
+        if filters:
+            data["filters"] = filters
+
+        # Stream response directly to file
+        async with aiofiles.open(output_path, "wb") as f:
+            async with self.client.session.post(
+                f"{self.client.base_url}/v3/conversations/export_messages",
+                json=data,
+                headers={
+                    "Accept": "text/csv",
+                    **self.client._get_auth_headers(),
+                },
+            ) as response:
+                if response.status != 200:
+                    raise ValueError(
+                        f"Export failed with status {response.status}",
+                        response,
+                    )
+
+                async for chunk in response.content.iter_chunks():
+                    if chunk:
+                        await f.write(chunk[0])
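
A short usage sketch for the streaming export methods added above. The import path, base URL, column names, and filter shape are assumptions, not taken from this commit.

import asyncio

from r2r import R2RAsyncClient  # import path assumed

async def main() -> None:
    client = R2RAsyncClient("http://localhost:7272")  # base URL assumed
    # Stream every conversation to a local CSV, keeping the header row
    await client.conversations.export("conversations.csv", include_header=True)
    # Export only selected message columns (names are illustrative)
    await client.conversations.export_messages(
        "messages.csv",
        columns=["id", "conversation_id", "content"],
    )

asyncio.run(main())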

+ 212 - 41
sdk/v3/documents.py

@@ -1,8 +1,12 @@
 import json
+from datetime import datetime
 from io import BytesIO
+from pathlib import Path
 from typing import Any, Optional
 from uuid import UUID

+import aiofiles
+
 from shared.api.models.base import WrappedBooleanResponse
 from shared.api.models.ingestion.responses import WrappedIngestionResponse
 from shared.api.models.management.responses import (
@@ -41,8 +45,8 @@ class DocumentsSDK:
         Args:
             file_path (Optional[str]): The file to upload, if any
             content (Optional[str]): Optional text content to upload, if no file path is provided
-            id (Optional[Union[str, UUID]]): Optional ID to assign to the document
-            collection_ids (Optional[list[Union[str, UUID]]]): Collection IDs to associate with the document. If none are provided, the document will be assigned to the user's default collection.
+            id (Optional[str | UUID]): Optional ID to assign to the document
+            collection_ids (Optional[list[str | UUID]]): Collection IDs to associate with the document. If none are provided, the document will be assigned to the user's default collection.
             metadata (Optional[dict]): Optional metadata to assign to the document
             metadata (Optional[dict]): Optional metadata to assign to the document
             ingestion_config (Optional[dict]): Optional ingestion configuration to use
             ingestion_config (Optional[dict]): Optional ingestion configuration to use
             run_with_orchestration (Optional[bool]): Whether to run with orchestration
             run_with_orchestration (Optional[bool]): Whether to run with orchestration
@@ -60,17 +64,23 @@ class DocumentsSDK:
                 "Only one of `file_path`, `raw_text` or `chunks` may be provided"
                 "Only one of `file_path`, `raw_text` or `chunks` may be provided"
             )
             )
 
 
-        data = {}
+        data: dict[str, Any] = {}
         files = None
         files = None
 
 
         if id:
         if id:
-            data["id"] = str(id)  # json.dumps(str(id))
+            data["id"] = str(id)
         if metadata:
         if metadata:
             data["metadata"] = json.dumps(metadata)
             data["metadata"] = json.dumps(metadata)
         if ingestion_config:
         if ingestion_config:
-            if not isinstance(ingestion_config, dict):
-                ingestion_config = ingestion_config.model_dump()
-            ingestion_config["app"] = {}
+            if isinstance(ingestion_config, IngestionMode):
+                ingestion_config = {"mode": ingestion_config.value}
+            app_config: dict[str, Any] = (
+                {}
+                if isinstance(ingestion_config, dict)
+                else ingestion_config["app"]
+            )
+            ingestion_config = dict(ingestion_config)
+            ingestion_config["app"] = app_config
             data["ingestion_config"] = json.dumps(ingestion_config)
             data["ingestion_config"] = json.dumps(ingestion_config)
         if collection_ids:
         if collection_ids:
             collection_ids = [str(collection_id) for collection_id in collection_ids]  # type: ignore
             collection_ids = [str(collection_id) for collection_id in collection_ids]  # type: ignore
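To illustrate the two shapes the rewritten ingestion_config branch accepts, a hedged sketch; the enum member, dict key, and import path are assumptions rather than values taken from this diff, and both calls assume an async context with client already constructed.

    from r2r import IngestionMode  # assumed import path

    async def ingest_examples(client) -> None:
        # Enum form: converted to {"mode": "<value>"} and then given an empty "app" section.
        await client.documents.create(
            file_path="report.pdf",
            ingestion_config=IngestionMode.fast,  # member name is illustrative
        )
        # Dict form: forwarded as-is, with an empty "app" section injected before serialization.
        await client.documents.create(
            file_path="report.pdf",
            ingestion_config={"chunk_size": 512},  # key is illustrative
        )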
@@ -125,7 +135,7 @@ class DocumentsSDK:
         Get a specific document by ID.
         Get a specific document by ID.
 
 
         Args:
         Args:
-            id (Union[str, UUID]): ID of document to retrieve
+            id (str | UUID): ID of document to retrieve
 
 
         Returns:
         Returns:
             dict: Document information
             dict: Document information
@@ -136,7 +146,6 @@ class DocumentsSDK:
             version="v3",
             version="v3",
         )
         )
 
 
-    # you could do something like:
     async def download(
     async def download(
         self,
         self,
         id: str | UUID,
         id: str | UUID,
@@ -145,12 +154,196 @@ class DocumentsSDK:
             "GET",
             "GET",
             f"documents/{str(id)}/download",
             f"documents/{str(id)}/download",
             version="v3",
             version="v3",
-            # No json parsing here, if possible
         )
         )
         if not isinstance(response, BytesIO):
         if not isinstance(response, BytesIO):
             raise ValueError("Expected BytesIO response")
             raise ValueError("Expected BytesIO response")
         return response
         return response
 
 
+    async def download_zip(
+        self,
+        document_ids: Optional[list[str | UUID]] = None,
+        start_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
+        output_path: Optional[str | Path] = None,
+    ) -> BytesIO | None:
+        """
+        Download multiple documents as a zip file.
+        """
+        params: dict[str, Any] = {}
+        if document_ids:
+            params["document_ids"] = [str(doc_id) for doc_id in document_ids]
+        if start_date:
+            params["start_date"] = start_date.isoformat()
+        if end_date:
+            params["end_date"] = end_date.isoformat()
+
+        response = await self.client._make_request(
+            "GET",
+            "documents/download_zip",
+            params=params,
+            version="v3",
+        )
+
+        if not isinstance(response, BytesIO):
+            raise ValueError("Expected BytesIO response")
+
+        if output_path:
+            output_path = (
+                Path(output_path)
+                if isinstance(output_path, str)
+                else output_path
+            )
+            async with aiofiles.open(output_path, "wb") as f:
+                await f.write(response.getvalue())
+            return None
+
+        return response
+
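A short sketch of the new bulk download, assuming an async context with client already available; the date range and document ID are placeholders.

    from datetime import datetime

    async def download_examples(client) -> None:
        # Write every matching document into a zip archive on disk (returns None).
        await client.documents.download_zip(
            start_date=datetime(2024, 1, 1),
            end_date=datetime(2024, 12, 31),
            output_path="documents_2024.zip",
        )
        # Without output_path the archive comes back as an in-memory BytesIO.
        archive = await client.documents.download_zip(
            document_ids=["00000000-0000-0000-0000-000000000000"],  # placeholder ID
        )
        print(len(archive.getvalue()), "bytes")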
+    async def export(
+        self,
+        output_path: str | Path,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> None:
+        """
+        Export documents to a CSV file, streaming the results directly to disk.
+
+        Args:
+            output_path (str | Path): Local path where the CSV file should be saved
+            columns (Optional[list[str]]): Specific columns to export. If None, exports default columns
+            filters (Optional[dict]): Optional filters to apply when selecting documents
+            include_header (bool): Whether to include column headers in the CSV (default: True)
+        """
+        # Convert path to string if it's a Path object
+        output_path = (
+            str(output_path) if isinstance(output_path, Path) else output_path
+        )
+
+        data: dict[str, Any] = {"include_header": include_header}
+        if columns:
+            data["columns"] = columns
+        if filters:
+            data["filters"] = filters
+
+        # Stream response directly to file
+        async with aiofiles.open(output_path, "wb") as f:
+            async with self.client.session.post(
+                f"{self.client.base_url}/v3/documents/export",
+                json=data,
+                headers={
+                    "Accept": "text/csv",
+                    **self.client._get_auth_headers(),
+                },
+            ) as response:
+                if response.status != 200:
+                    raise ValueError(
+                        f"Export failed with status {response.status}",
+                        response,
+                    )
+
+                async for chunk in response.content.iter_chunks():
+                    if chunk:
+                        await f.write(chunk[0])
+
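Usage sketch for the document-level CSV export, assuming an async context; the columns and filter key are illustrative.

    async def export_documents_csv(client) -> None:
        # Stream selected document rows into a local CSV file.
        await client.documents.export(
            output_path="documents.csv",
            columns=["id", "title", "created_at"],  # illustrative column names
            filters={"document_type": "pdf"},       # illustrative filter key
        )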
+    async def export_entities(
+        self,
+        id: str | UUID,
+        output_path: str | Path,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> None:
+        """
+        Export a document's entities to a CSV file, streaming the results directly to disk.
+
+        Args:
+            id (str | UUID): ID of the document whose entities should be exported
+            output_path (str | Path): Local path where the CSV file should be saved
+            columns (Optional[list[str]]): Specific columns to export. If None, exports default columns
+            filters (Optional[dict]): Optional filters to apply when selecting entities
+            include_header (bool): Whether to include column headers in the CSV (default: True)
+        """
+        # Convert path to string if it's a Path object
+        output_path = (
+            str(output_path) if isinstance(output_path, Path) else output_path
+        )
+
+        # Prepare request data
+        data: dict[str, Any] = {"include_header": include_header}
+        if columns:
+            data["columns"] = columns
+        if filters:
+            data["filters"] = filters
+
+        # Stream response directly to file
+        async with aiofiles.open(output_path, "wb") as f:
+            async with self.client.session.post(
+                f"{self.client.base_url}/v3/documents/{str(id)}/entities/export",
+                json=data,
+                headers={
+                    "Accept": "text/csv",
+                    **self.client._get_auth_headers(),
+                },
+            ) as response:
+                if response.status != 200:
+                    raise ValueError(
+                        f"Export failed with status {response.status}",
+                        response,
+                    )
+
+                async for chunk in response.content.iter_chunks():
+                    if chunk:
+                        await f.write(chunk[0])
+
+    async def export_relationships(
+        self,
+        id: str | UUID,
+        output_path: str | Path,
+        columns: Optional[list[str]] = None,
+        filters: Optional[dict] = None,
+        include_header: bool = True,
+    ) -> None:
+        """
+        Export document relationships to a CSV file, streaming the results directly to disk.
+
+        Args:
+            id (str | UUID): ID of the document whose relationships should be exported
+            output_path (str | Path): Local path where the CSV file should be saved
+            columns (Optional[list[str]]): Specific columns to export. If None, exports default columns
+            filters (Optional[dict]): Optional filters to apply when selecting relationships
+            include_header (bool): Whether to include column headers in the CSV (default: True)
+        """
+        # Convert path to string if it's a Path object
+        output_path = (
+            str(output_path) if isinstance(output_path, Path) else output_path
+        )
+
+        # Prepare request data
+        data: dict[str, Any] = {"include_header": include_header}
+        if columns:
+            data["columns"] = columns
+        if filters:
+            data["filters"] = filters
+
+        # Stream response directly to file
+        async with aiofiles.open(output_path, "wb") as f:
+            async with self.client.session.post(
+                f"{self.client.base_url}/v3/documents/{str(id)}/relationships/export",
+                json=data,
+                headers={
+                    "Accept": "text/csv",
+                    **self.client._get_auth_headers(),
+                },
+            ) as response:
+                if response.status != 200:
+                    raise ValueError(
+                        f"Export failed with status {response.status}",
+                        response,
+                    )
+
+                async for chunk in response.content.iter_chunks():
+                    if chunk:
+                        await f.write(chunk[0])
+
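The two graph exports are used the same way; a brief sketch assuming an async context, with illustrative column names.

    async def export_graph_csv(client, document_id) -> None:
        # Entities and relationships extracted from one document, each written to CSV.
        await client.documents.export_entities(
            id=document_id,
            output_path="entities.csv",
            columns=["name", "category", "description"],  # illustrative column names
        )
        await client.documents.export_relationships(
            id=document_id,
            output_path="relationships.csv",
        )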
     async def delete(
     async def delete(
         self,
         self,
         id: str | UUID,
         id: str | UUID,
@@ -159,7 +352,7 @@ class DocumentsSDK:
         Delete a specific document.
         Delete a specific document.
 
 
         Args:
         Args:
-            id (Union[str, UUID]): ID of document to delete
+            id (str | UUID): ID of document to delete
         """
         """
         return await self.client._make_request(
         return await self.client._make_request(
             "DELETE",
             "DELETE",
@@ -178,7 +371,7 @@ class DocumentsSDK:
         Get chunks for a specific document.
         Get chunks for a specific document.
 
 
         Args:
         Args:
-            id (Union[str, UUID]): ID of document to retrieve chunks for
+            id (str | UUID): ID of document to retrieve chunks for
             include_vectors (Optional[bool]): Whether to include vector embeddings in the response
             include_vectors (Optional[bool]): Whether to include vector embeddings in the response
             offset (int, optional): Specifies the number of objects to skip. Defaults to 0.
             offset (int, optional): Specifies the number of objects to skip. Defaults to 0.
             limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.
             limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.
@@ -209,7 +402,7 @@ class DocumentsSDK:
         List collections for a specific document.
         List collections for a specific document.
 
 
         Args:
         Args:
-            id (Union[str, UUID]): ID of document to retrieve collections for
+            id (str | UUID): ID of document to retrieve collections for
             offset (int, optional): Specifies the number of objects to skip. Defaults to 0.
             offset (int, optional): Specifies the number of objects to skip. Defaults to 0.
             limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.
             limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.
 
 
@@ -259,7 +452,7 @@ class DocumentsSDK:
         Extract entities and relationships from a document.
         Extract entities and relationships from a document.
 
 
         Args:
         Args:
-            id (Union[str, UUID]): ID of document to extract from
+            id (str | UUID): ID of document to extract from
             run_type (Optional[str]): Whether to return an estimate or run extraction
             run_type (Optional[str]): Whether to return an estimate or run extraction
             settings (Optional[dict]): Settings for extraction process
             settings (Optional[dict]): Settings for extraction process
             run_with_orchestration (Optional[bool]): Whether to run with orchestration
             run_with_orchestration (Optional[bool]): Whether to run with orchestration
@@ -267,7 +460,7 @@ class DocumentsSDK:
         Returns:
         Returns:
             dict: Extraction results or cost estimate
             dict: Extraction results or cost estimate
         """
         """
-        data = {}
+        data: dict[str, Any] = {}
         if run_type:
         if run_type:
             data["run_type"] = run_type
             data["run_type"] = run_type
         if settings:
         if settings:
@@ -293,7 +486,7 @@ class DocumentsSDK:
         List entities extracted from a document.
         List entities extracted from a document.
 
 
         Args:
         Args:
-            id (Union[str, UUID]): ID of document to get entities from
+            id (str | UUID): ID of document to get entities from
             offset (Optional[int]): Number of items to skip
             offset (Optional[int]): Number of items to skip
             limit (Optional[int]): Max number of items to return
             limit (Optional[int]): Max number of items to return
             include_embeddings (Optional[bool]): Whether to include embeddings
             include_embeddings (Optional[bool]): Whether to include embeddings
@@ -325,7 +518,7 @@ class DocumentsSDK:
         List relationships extracted from a document.
         List relationships extracted from a document.
 
 
         Args:
         Args:
-            id (Union[str, UUID]): ID of document to get relationships from
+            id (str | UUID): ID of document to get relationships from
             offset (Optional[int]): Number of items to skip
             offset (Optional[int]): Number of items to skip
             limit (Optional[int]): Max number of items to return
             limit (Optional[int]): Max number of items to return
             entity_names (Optional[list[str]]): Filter by entity names
             entity_names (Optional[list[str]]): Filter by entity names
@@ -350,28 +543,6 @@ class DocumentsSDK:
             version="v3",
             version="v3",
         )
         )
 
 
-    # async def extract(
-    #     self,
-    #     id: str | UUID,
-    #     run_type: Optional[str] = None,
-    #     run_with_orchestration: Optional[bool] = True,
-    # ):
-    #     data = {}
-
-    #     if run_type:
-    #         data["run_type"] = run_type
-    #     if run_with_orchestration is not None:
-    #         data["run_with_orchestration"] = str(run_with_orchestration)
-
-    #     return await self.client._make_request(
-    #         "POST",
-    #         f"documents/{str(id)}/extract",
-    #         params=data,
-    #         version="v3",
-    #     )
-
-    # Be sure to put at bottom of the page...
-
     async def list(
     async def list(
         self,
         self,
         ids: Optional[list[str | UUID]] = None,
         ids: Optional[list[str | UUID]] = None,
@@ -382,7 +553,7 @@ class DocumentsSDK:
         List documents with pagination.
         List documents with pagination.
 
 
         Args:
         Args:
-            ids (Optional[list[Union[str, UUID]]]): Optional list of document IDs to filter by
+            ids (Optional[list[str | UUID]]): Optional list of document IDs to filter by
             offset (int, optional): Specifies the number of objects to skip. Defaults to 0.
             offset (int, optional): Specifies the number of objects to skip. Defaults to 0.
             limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.
             limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.
 
 
@@ -424,7 +595,7 @@ class DocumentsSDK:
 
 
         if search_settings and not isinstance(search_settings, dict):
         if search_settings and not isinstance(search_settings, dict):
             search_settings = search_settings.model_dump()
             search_settings = search_settings.model_dump()
-        data = {
+        data: dict[str, Any] = {
             "query": query,
             "query": query,
             "search_settings": search_settings,
             "search_settings": search_settings,
         }
         }

+ 2 - 2
sdk/v3/graphs.py

@@ -108,7 +108,7 @@ class GraphsSDK:
         Returns:
         Returns:
             dict: Updated graph information
             dict: Updated graph information
         """
         """
-        data = {}
+        data: dict[str, Any] = {}
         if name is not None:
         if name is not None:
             data["name"] = name
             data["name"] = name
         if description is not None:
         if description is not None:
@@ -290,7 +290,7 @@ class GraphsSDK:
         Returns:
         Returns:
             dict: Success message
             dict: Success message
         """
         """
-        data = {
+        data: dict[str, Any] = {
             "run_type": run_type,
             "run_type": run_type,
             "run_with_orchestration": run_with_orchestration,
             "run_with_orchestration": run_with_orchestration,
         }
         }

+ 3 - 29
sdk/v3/indices.py

@@ -1,5 +1,5 @@
 import json
 import json
-from typing import Optional
+from typing import Any, Optional
 
 
 from shared.api.models.base import WrappedGenericMessageResponse
 from shared.api.models.base import WrappedGenericMessageResponse
 from shared.api.models.ingestion.responses import (
 from shared.api.models.ingestion.responses import (
@@ -20,13 +20,13 @@ class IndicesSDK:
         Create a new vector similarity search index in the database.
         Create a new vector similarity search index in the database.
 
 
         Args:
         Args:
-            config (Union[dict, IndexConfig]): Configuration for the vector index.
+            config (dict | IndexConfig): Configuration for the vector index.
             run_with_orchestration (Optional[bool]): Whether to run index creation as an orchestrated task.
             run_with_orchestration (Optional[bool]): Whether to run index creation as an orchestrated task.
         """
         """
         if not isinstance(config, dict):
         if not isinstance(config, dict):
             config = config.model_dump()
             config = config.model_dump()
 
 
-        data = {
+        data: dict[str, Any] = {
             "config": config,
             "config": config,
             "run_with_orchestration": run_with_orchestration,
             "run_with_orchestration": run_with_orchestration,
         }
         }
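A hedged sketch of index creation with a plain dict config; the keys and values below are assumptions about what IndexConfig accepts and are not taken from this diff.

    async def create_index_example(client) -> None:
        await client.indices.create(
            config={
                "index_method": "hnsw",              # assumed key/value
                "index_measure": "cosine_distance",  # assumed key/value
                "index_name": "chunks_hnsw",         # assumed key
            },
            run_with_orchestration=False,
        )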
@@ -88,32 +88,6 @@ class IndicesSDK:
             version="v3",
             version="v3",
         )
         )
 
 
-    # async def update_index(
-    #     self,
-    #     id: Union[str, UUID],
-    #     config: dict,  # Union[dict, IndexConfig],
-    #     run_with_orchestration: Optional[bool] = True,
-    # ) -> dict:
-    #     """
-    #     Update an existing index's configuration.
-
-    #     Args:
-    #         id (Union[str, UUID]): The ID of the index to update.
-    #         config (Union[dict, IndexConfig]): The new configuration for the index.
-    #         run_with_orchestration (Optional[bool]): Whether to run the update as an orchestrated task.
-
-    #     Returns:
-    #         WrappedUpdateIndexResponse: The response containing the updated index details.
-    #     """
-    #     if not isinstance(config, dict):
-    #         config = config.model_dump()
-
-    #     data = {
-    #         "config": config,
-    #         "run_with_orchestration": run_with_orchestration,
-    #     }
-    #     return await self.client._make_request("POST", f"indices/{id}", json=data)  # type: ignore
-
     async def delete(
     async def delete(
         self,
         self,
         index_name: str,
         index_name: str,

+ 6 - 2
sdk/v3/prompts.py

@@ -1,5 +1,5 @@
 import json
 import json
-from typing import Optional
+from typing import Any, Optional
 
 
 from shared.api.models.base import (
 from shared.api.models.base import (
     WrappedBooleanResponse,
     WrappedBooleanResponse,
@@ -27,7 +27,11 @@ class PromptsSDK:
         Returns:
         Returns:
             dict: Created prompt information
             dict: Created prompt information
         """
         """
-        data = {"name": name, "template": template, "input_types": input_types}
+        data: dict[str, Any] = {
+            "name": name,
+            "template": template,
+            "input_types": input_types,
+        }
         return await self.client._make_request(
         return await self.client._make_request(
             "POST",
             "POST",
             "prompts",
             "prompts",

+ 6 - 6
sdk/v3/retrieval.py

@@ -1,4 +1,4 @@
-from typing import AsyncGenerator, Optional
+from typing import Any, AsyncGenerator, Optional
 
 
 from ..models import (
 from ..models import (
     CombinedSearchResponse,
     CombinedSearchResponse,
@@ -41,7 +41,7 @@ class RetrievalSDK:
         if search_settings and not isinstance(search_settings, dict):
         if search_settings and not isinstance(search_settings, dict):
             search_settings = search_settings.model_dump()
             search_settings = search_settings.model_dump()
 
 
-        data = {
+        data: dict[str, Any] = {
             "query": query,
             "query": query,
             "search_settings": search_settings,
             "search_settings": search_settings,
         }
         }
@@ -68,7 +68,7 @@ class RetrievalSDK:
         if generation_config and not isinstance(generation_config, dict):
         if generation_config and not isinstance(generation_config, dict):
             generation_config = generation_config.model_dump()
             generation_config = generation_config.model_dump()
 
 
-        data = {
+        data: dict[str, Any] = {
             "messages": [msg.model_dump() for msg in cast_messages],
             "messages": [msg.model_dump() for msg in cast_messages],
             "generation_config": generation_config,
             "generation_config": generation_config,
         }
         }
@@ -83,7 +83,7 @@ class RetrievalSDK:
         self,
         self,
         text: str,
         text: str,
     ):
     ):
-        data = {
+        data: dict[str, Any] = {
             "text": text,
             "text": text,
         }
         }
 
 
@@ -123,7 +123,7 @@ class RetrievalSDK:
         if search_settings and not isinstance(search_settings, dict):
         if search_settings and not isinstance(search_settings, dict):
             search_settings = search_settings.model_dump()
             search_settings = search_settings.model_dump()
 
 
-        data = {
+        data: dict[str, Any] = {
             "query": query,
             "query": query,
             "rag_generation_config": rag_generation_config,
             "rag_generation_config": rag_generation_config,
             "search_settings": search_settings,
             "search_settings": search_settings,
@@ -179,7 +179,7 @@ class RetrievalSDK:
         if search_settings and not isinstance(search_settings, dict):
         if search_settings and not isinstance(search_settings, dict):
             search_settings = search_settings.model_dump()
             search_settings = search_settings.model_dump()
 
 
-        data = {
+        data: dict[str, Any] = {
             "rag_generation_config": rag_generation_config or {},
             "rag_generation_config": rag_generation_config or {},
             "search_settings": search_settings,
             "search_settings": search_settings,
             "task_prompt_override": task_prompt_override,
             "task_prompt_override": task_prompt_override,

+ 13 - 8
sdk/v3/users.py

@@ -1,7 +1,6 @@
 from __future__ import annotations  # for Python 3.10+
 from __future__ import annotations  # for Python 3.10+
-import json
 
 
-from typing import Optional
+from typing import Any, Optional
 from uuid import UUID
 from uuid import UUID
 
 
 from shared.api.models.auth.responses import WrappedTokenResponse
 from shared.api.models.auth.responses import WrappedTokenResponse
@@ -72,7 +71,7 @@ class UsersSDK:
         Returns:
         Returns:
             UserResponse: New user information
             UserResponse: New user information
         """
         """
-        data = {"email": email, "password": password}
+        data: dict[str, Any] = {"email": email, "password": password}
         return await self.client._make_request(
         return await self.client._make_request(
             "POST",
             "POST",
             "users/register",
             "users/register",
@@ -94,7 +93,7 @@ class UsersSDK:
         Returns:
         Returns:
             dict: Deletion result
             dict: Deletion result
         """
         """
-        data = {"password": password}
+        data: dict[str, Any] = {"password": password}
         response = await self.client._make_request(
         response = await self.client._make_request(
             "DELETE",
             "DELETE",
             f"users/{str(id)}",
             f"users/{str(id)}",
@@ -118,7 +117,10 @@ class UsersSDK:
         Returns:
         Returns:
             dict: Verification result
             dict: Verification result
         """
         """
-        data = {"email": email, "verification_code": verification_code}
+        data: dict[str, Any] = {
+            "email": email,
+            "verification_code": verification_code,
+        }
         return await self.client._make_request(
         return await self.client._make_request(
             "POST",
             "POST",
             "users/verify-email",
             "users/verify-email",
@@ -141,7 +143,7 @@ class UsersSDK:
             raise ValueError(
             raise ValueError(
                 "Cannot log in after setting an API key, please unset your R2R_API_KEY variable or call client.set_api_key(None)"
                 "Cannot log in after setting an API key, please unset your R2R_API_KEY variable or call client.set_api_key(None)"
             )
             )
-        data = {"username": email, "password": password}
+        data: dict[str, Any] = {"username": email, "password": password}
         response = await self.client._make_request(
         response = await self.client._make_request(
             "POST",
             "POST",
             "users/login",
             "users/login",
@@ -227,7 +229,7 @@ class UsersSDK:
         Returns:
         Returns:
             dict: Change password result
             dict: Change password result
         """
         """
-        data = {
+        data: dict[str, Any] = {
             "current_password": current_password,
             "current_password": current_password,
             "new_password": new_password,
             "new_password": new_password,
         }
         }
@@ -270,7 +272,10 @@ class UsersSDK:
         Returns:
         Returns:
             dict: Password reset result
             dict: Password reset result
         """
         """
-        data = {"reset_token": reset_token, "new_password": new_password}
+        data: dict[str, Any] = {
+            "reset_token": reset_token,
+            "new_password": new_password,
+        }
         return await self.client._make_request(
         return await self.client._make_request(
             "POST",
             "POST",
             "users/reset-password",
             "users/reset-password",

+ 2 - 9
shared/abstractions/__init__.py

@@ -18,15 +18,11 @@ from .exception import (
     R2RDocumentProcessingError,
     R2RDocumentProcessingError,
     R2RException,
     R2RException,
 )
 )
-from .graph import Community, Entity, KGExtraction, Relationship
+from .graph import Community, Entity, KGExtraction, Relationship, StoreType
 from .kg import (
 from .kg import (
-    GraphBuildSettings,
     GraphCommunitySettings,
     GraphCommunitySettings,
-    GraphEntitySettings,
-    GraphRelationshipSettings,
     KGCreationSettings,
     KGCreationSettings,
     KGEnrichmentSettings,
     KGEnrichmentSettings,
-    KGEntityDeduplicationSettings,
     KGRunType,
     KGRunType,
 )
 )
 from .llm import (
 from .llm import (
@@ -47,7 +43,6 @@ from .search import (
     HybridSearchSettings,
     HybridSearchSettings,
     KGCommunityResult,
     KGCommunityResult,
     KGEntityResult,
     KGEntityResult,
-    KGGlobalResult,
     KGRelationshipResult,
     KGRelationshipResult,
     KGSearchResultType,
     KGSearchResultType,
     SearchMode,
     SearchMode,
@@ -99,6 +94,7 @@ __all__ = [
     "Community",
     "Community",
     "KGExtraction",
     "KGExtraction",
     "Relationship",
     "Relationship",
+    "StoreType",
     # LLM abstractions
     # LLM abstractions
     "GenerationConfig",
     "GenerationConfig",
     "LLMChatCompletion",
     "LLMChatCompletion",
@@ -114,7 +110,6 @@ __all__ = [
     "KGEntityResult",
     "KGEntityResult",
     "KGRelationshipResult",
     "KGRelationshipResult",
     "KGCommunityResult",
     "KGCommunityResult",
-    "KGGlobalResult",
     "GraphSearchSettings",
     "GraphSearchSettings",
     "ChunkSearchSettings",
     "ChunkSearchSettings",
     "ChunkSearchResult",
     "ChunkSearchResult",
@@ -127,8 +122,6 @@ __all__ = [
     "KGEnrichmentSettings",
     "KGEnrichmentSettings",
     "KGExtraction",
     "KGExtraction",
     "KGRunType",
     "KGRunType",
-    "GraphEntitySettings",
-    "GraphRelationshipSettings",
     "GraphCommunitySettings",
     "GraphCommunitySettings",
     # User abstractions
     # User abstractions
     "Token",
     "Token",

+ 9 - 4
shared/abstractions/base.py

@@ -2,7 +2,7 @@ import asyncio
 import json
 import json
 from datetime import datetime
 from datetime import datetime
 from enum import Enum
 from enum import Enum
-from typing import Any, Type, TypeVar, Union
+from typing import Any, Type, TypeVar
 from uuid import UUID
 from uuid import UUID
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
@@ -12,10 +12,15 @@ T = TypeVar("T", bound="R2RSerializable")
 
 
 class R2RSerializable(BaseModel):
 class R2RSerializable(BaseModel):
     @classmethod
     @classmethod
-    def from_dict(cls: Type[T], data: Union[dict[str, Any], str]) -> T:
+    def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T:
         if isinstance(data, str):
         if isinstance(data, str):
-            data = json.loads(data)
-        return cls(**data)
+            try:
+                data_dict = json.loads(data)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON string: {e}") from e
+        else:
+            data_dict = data
+        return cls(**data_dict)
 
 
     def to_dict(self) -> dict[str, Any]:
     def to_dict(self) -> dict[str, Any]:
         data = self.model_dump(exclude_unset=True)
         data = self.model_dump(exclude_unset=True)
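A behaviour sketch for the hardened from_dict; Example is a hypothetical subclass used only for illustration.

    class Example(R2RSerializable):
        name: str

    Example.from_dict('{"name": "graph"}')   # JSON string is parsed first
    Example.from_dict({"name": "graph"})     # dicts pass straight through
    try:
        Example.from_dict("not json")
    except ValueError as err:                # invalid JSON now surfaces as ValueError
        print(err)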

+ 1 - 1
shared/abstractions/document.py

@@ -282,7 +282,7 @@ class ChunkEnrichmentSettings(R2RSerializable):
     )
     )
     strategies: list[ChunkEnrichmentStrategy] = Field(
     strategies: list[ChunkEnrichmentStrategy] = Field(
         default=[],
         default=[],
-        description="The strategies to use for chunk enrichment. Union of chunks obtained from each strategy is used as context.",
+        description="The strategies to use for chunk enrichment. List of chunks obtained from each strategy is used as context.",
     )
     )
     forward_chunks: int = Field(
     forward_chunks: int = Field(
         default=3,
         default=3,

+ 6 - 0
shared/abstractions/graph.py

@@ -1,6 +1,7 @@
 import json
 import json
 from dataclasses import dataclass
 from dataclasses import dataclass
 from datetime import datetime
 from datetime import datetime
+from enum import Enum
 from typing import Any, Optional
 from typing import Any, Optional
 from uuid import UUID
 from uuid import UUID
 
 
@@ -134,3 +135,8 @@ class Graph(R2RSerializable):
 
 
     def __init__(self, **kwargs):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         super().__init__(**kwargs)
+
+
+class StoreType(str, Enum):
+    GRAPHS = "graphs"
+    DOCUMENTS = "documents"
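StoreType keys graph objects to the store they live in; a small sketch with a hypothetical helper (the table-name pattern is an assumption).

    def entity_table(store: StoreType) -> str:
        # Hypothetical helper: derive a table-name suffix from the store.
        return f"{store.value}_entities"

    entity_table(StoreType.GRAPHS)      # -> "graphs_entities"
    entity_table(StoreType.DOCUMENTS)   # -> "documents_entities"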

+ 1 - 1
shared/abstractions/ingestion.py

@@ -27,7 +27,7 @@ class ChunkEnrichmentSettings(R2RSerializable):
     )
     )
     strategies: list[ChunkEnrichmentStrategy] = Field(
     strategies: list[ChunkEnrichmentStrategy] = Field(
         default=[],
         default=[],
-        description="The strategies to use for chunk enrichment. Union of chunks obtained from each strategy is used as context.",
+        description="The strategies to use for chunk enrichment.",
     )
     )
     forward_chunks: int = Field(
     forward_chunks: int = Field(
         default=3,
         default=3,

+ 0 - 84
shared/abstractions/kg.py

@@ -19,17 +19,6 @@ class KGRunType(str, Enum):
 GraphRunType = KGRunType
 GraphRunType = KGRunType
 
 
 
 
-class KGEntityDeduplicationType(str, Enum):
-    """Type of KG entity deduplication."""
-
-    BY_NAME = "by_name"
-    BY_DESCRIPTION = "by_description"
-    BY_LLM = "by_llm"
-
-    def __str__(self):
-        return self.value
-
-
 class KGCreationSettings(R2RSerializable):
 class KGCreationSettings(R2RSerializable):
     """Settings for knowledge graph creation."""
     """Settings for knowledge graph creation."""
 
 
@@ -81,30 +70,6 @@ class KGCreationSettings(R2RSerializable):
     )
     )
 
 
 
 
-class KGEntityDeduplicationSettings(R2RSerializable):
-    """Settings for knowledge graph entity deduplication."""
-
-    graph_entity_deduplication_type: KGEntityDeduplicationType = Field(
-        default=KGEntityDeduplicationType.BY_NAME,
-        description="The type of entity deduplication to use.",
-    )
-
-    max_description_input_length: int = Field(
-        default=65536,
-        description="The maximum length of the description for a node in the graph.",
-    )
-
-    graph_entity_deduplication_prompt: str = Field(
-        default="graphrag_entity_deduplication",
-        description="The prompt to use for knowledge graph entity deduplication.",
-    )
-
-    generation_config: GenerationConfig = Field(
-        default_factory=GenerationConfig,
-        description="Configuration for text generation during graph entity deduplication.",
-    )
-
-
 class KGEnrichmentSettings(R2RSerializable):
 class KGEnrichmentSettings(R2RSerializable):
     """Settings for knowledge graph enrichment."""
     """Settings for knowledge graph enrichment."""
 
 
@@ -135,36 +100,6 @@ class KGEnrichmentSettings(R2RSerializable):
     )
     )
 
 
 
 
-class GraphEntitySettings(R2RSerializable):
-    """Settings for knowledge graph entity creation."""
-
-    graph_entity_deduplication_type: KGEntityDeduplicationType = Field(
-        default=KGEntityDeduplicationType.BY_NAME,
-        description="The type of entity deduplication to use.",
-    )
-
-    max_description_input_length: int = Field(
-        default=65536,
-        description="The maximum length of the description for a node in the graph.",
-    )
-
-    graph_entity_deduplication_prompt: str = Field(
-        default="graphrag_entity_deduplication",
-        description="The prompt to use for knowledge graph entity deduplication.",
-    )
-
-    generation_config: GenerationConfig = Field(
-        default_factory=GenerationConfig,
-        description="Configuration for text generation during graph entity deduplication.",
-    )
-
-
-class GraphRelationshipSettings(R2RSerializable):
-    """Settings for knowledge graph relationship creation."""
-
-    pass
-
-
 class GraphCommunitySettings(R2RSerializable):
 class GraphCommunitySettings(R2RSerializable):
     """Settings for knowledge graph community enrichment."""
     """Settings for knowledge graph community enrichment."""
 
 
@@ -192,22 +127,3 @@ class GraphCommunitySettings(R2RSerializable):
         default_factory=dict,
         default_factory=dict,
         description="Parameters for the Leiden algorithm.",
         description="Parameters for the Leiden algorithm.",
     )
     )
-
-
-class GraphBuildSettings(R2RSerializable):
-    """Settings for knowledge graph build."""
-
-    entity_settings: GraphEntitySettings = Field(
-        default=GraphEntitySettings(),
-        description="Settings for knowledge graph entity creation.",
-    )
-
-    relationship_settings: GraphRelationshipSettings = Field(
-        default=GraphRelationshipSettings(),
-        description="Settings for knowledge graph relationship creation.",
-    )
-
-    community_settings: GraphCommunitySettings = Field(
-        default=GraphCommunitySettings(),
-        description="Settings for knowledge graph community enrichment.",
-    )

+ 1 - 19
shared/abstractions/search.py

@@ -113,26 +113,8 @@ class KGCommunityResult(R2RSerializable):
         }
         }
 
 
 
 
-class KGGlobalResult(R2RSerializable):
-    name: str
-    description: str
-    metadata: Optional[dict[str, Any]] = None
-
-    class Config:
-        json_schema_extra = {
-            "name": "Global Result Name",
-            "description": "Global Result Description",
-            "metadata": {},
-        }
-
-
 class GraphSearchResult(R2RSerializable):
 class GraphSearchResult(R2RSerializable):
-    content: (
-        KGEntityResult
-        | KGRelationshipResult
-        | KGCommunityResult
-        | KGGlobalResult
-    )
+    content: KGEntityResult | KGRelationshipResult | KGCommunityResult
     result_type: Optional[KGSearchResultType] = None
     result_type: Optional[KGSearchResultType] = None
     chunk_ids: Optional[list[UUID]] = None
     chunk_ids: Optional[list[UUID]] = None
     metadata: dict[str, Any] = {}
     metadata: dict[str, Any] = {}

+ 0 - 1
shared/utils/base_utils.py

@@ -17,7 +17,6 @@ from ..abstractions.search import (
     AggregateSearchResult,
     AggregateSearchResult,
     KGCommunityResult,
     KGCommunityResult,
     KGEntityResult,
     KGEntityResult,
-    KGGlobalResult,
     KGRelationshipResult,
     KGRelationshipResult,
 )
 )
 from ..abstractions.vector import VectorQuantizationType
 from ..abstractions.vector import VectorQuantizationType

+ 77 - 80
shared/utils/splitter/text.py

@@ -37,9 +37,7 @@ from typing import (
     Any,
     Any,
     Callable,
     Callable,
     Collection,
     Collection,
-    Dict,
     Iterable,
     Iterable,
-    List,
     Literal,
     Literal,
     Optional,
     Optional,
     Sequence,
     Sequence,
@@ -47,7 +45,6 @@ from typing import (
     Type,
     Type,
     TypedDict,
     TypedDict,
     TypeVar,
     TypeVar,
-    Union,
     cast,
     cast,
 )
 )
 
 
@@ -64,16 +61,16 @@ class BaseSerialized(TypedDict):
     """Base class for serialized objects."""
     """Base class for serialized objects."""
 
 
     lc: int
     lc: int
-    id: List[str]
+    id: list[str]
     name: NotRequired[str]
     name: NotRequired[str]
-    graph: NotRequired[Dict[str, Any]]
+    graph: NotRequired[dict[str, Any]]
 
 
 
 
 class SerializedConstructor(BaseSerialized):
 class SerializedConstructor(BaseSerialized):
     """Serialized constructor."""
     """Serialized constructor."""
 
 
     type: Literal["constructor"]
     type: Literal["constructor"]
-    kwargs: Dict[str, Any]
+    kwargs: dict[str, Any]
 
 
 
 
 class SerializedSecret(BaseSerialized):
 class SerializedSecret(BaseSerialized):
@@ -115,7 +112,7 @@ class Serializable(BaseModel, ABC):
         return False
         return False
 
 
     @classmethod
     @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
         """Get the namespace of the langchain object.
         """Get the namespace of the langchain object.
 
 
         For example, if the class is `langchain.llms.openai.OpenAI`, then the
         For example, if the class is `langchain.llms.openai.OpenAI`, then the
@@ -124,16 +121,16 @@ class Serializable(BaseModel, ABC):
         return cls.__module__.split(".")
         return cls.__module__.split(".")
 
 
     @property
     @property
-    def lc_secrets(self) -> Dict[str, str]:
+    def lc_secrets(self) -> dict[str, str]:
         """A map of constructor argument names to secret ids.
         """A map of constructor argument names to secret ids.
 
 
         For example,
         For example,
             {"openai_api_key": "OPENAI_API_KEY"}
             {"openai_api_key": "OPENAI_API_KEY"}
         """
         """
-        return dict()
+        return {}
 
 
     @property
     @property
-    def lc_attributes(self) -> Dict:
+    def lc_attributes(self) -> dict:
         """List of attribute names that should be included in the serialized kwargs.
         """List of attribute names that should be included in the serialized kwargs.
 
 
         These attributes must be accepted by the constructor.
         These attributes must be accepted by the constructor.
@@ -141,7 +138,7 @@ class Serializable(BaseModel, ABC):
         return {}
         return {}
 
 
     @classmethod
     @classmethod
-    def lc_id(cls) -> List[str]:
+    def lc_id(cls) -> list[str]:
         """A unique identifier for this class for serialization purposes.
         """A unique identifier for this class for serialization purposes.
 
 
         The unique identifier is a list of strings that describes the path
         The unique identifier is a list of strings that describes the path
@@ -159,7 +156,7 @@ class Serializable(BaseModel, ABC):
             if (k not in self.__fields__ or try_neq_default(v, k, self))
             if (k not in self.__fields__ or try_neq_default(v, k, self))
         ]
         ]
 
 
-    _lc_kwargs = PrivateAttr(default_factory=dict)
+    _lc_kwargs: dict[str, Any] = PrivateAttr(default_factory=dict)
 
 
     def __init__(self, **kwargs: Any) -> None:
     def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         super().__init__(**kwargs)
@@ -167,7 +164,7 @@ class Serializable(BaseModel, ABC):
 
 
     def to_json(
     def to_json(
         self,
         self,
-    ) -> Union[SerializedConstructor, SerializedNotImplemented]:
+    ) -> SerializedConstructor | SerializedNotImplemented:
         if not self.is_lc_serializable():
         if not self.is_lc_serializable():
             return self.to_json_not_implemented()
             return self.to_json_not_implemented()
 
 
@@ -238,8 +235,8 @@ class Serializable(BaseModel, ABC):
 
 
 
 
 def _replace_secrets(
 def _replace_secrets(
-    root: Dict[Any, Any], secrets_map: Dict[str, str]
-) -> Dict[Any, Any]:
+    root: dict[Any, Any], secrets_map: dict[str, str]
+) -> dict[Any, Any]:
     result = root.copy()
     result = root.copy()
     for path, secret_id in secrets_map.items():
     for path, secret_id in secrets_map.items():
         [*parts, last] = path.split(".")
         [*parts, last] = path.split(".")
@@ -267,7 +264,7 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
     Returns:
     Returns:
         SerializedNotImplemented
         SerializedNotImplemented
     """
     """
-    _id: List[str] = []
+    _id: list[str] = []
     try:
     try:
         if hasattr(obj, "__name__"):
         if hasattr(obj, "__name__"):
             _id = [*obj.__module__.split("."), obj.__name__]
             _id = [*obj.__module__.split("."), obj.__name__]
@@ -313,7 +310,7 @@ class SplitterDocument(Serializable):
         return True
         return True
 
 
     @classmethod
     @classmethod
-    def get_lc_namespace(cls) -> List[str]:
+    def get_lc_namespace(cls) -> list[str]:
         """Get the namespace of the langchain object."""
         """Get the namespace of the langchain object."""
         return ["langchain", "schema", "document"]
         return ["langchain", "schema", "document"]
 
 
@@ -406,7 +403,7 @@ def _make_spacy_pipe_for_splitting(
 
 
 def _split_text_with_regex(
 def _split_text_with_regex(
     text: str, separator: str, keep_separator: bool
     text: str, separator: str, keep_separator: bool
-) -> List[str]:
+) -> list[str]:
     # Now that we have the separator, split the text
     # Now that we have the separator, split the text
     if separator:
     if separator:
         if keep_separator:
         if keep_separator:
@@ -461,12 +458,12 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         self._strip_whitespace = strip_whitespace
         self._strip_whitespace = strip_whitespace
 
 
     @abstractmethod
     @abstractmethod
-    def split_text(self, text: str) -> List[str]:
+    def split_text(self, text: str) -> list[str]:
         """Split text into multiple components."""
         """Split text into multiple components."""
 
 
     def create_documents(
     def create_documents(
-        self, texts: List[str], metadatas: Optional[List[dict]] = None
-    ) -> List[SplitterDocument]:
+        self, texts: list[str], metadatas: Optional[list[dict]] = None
+    ) -> list[SplitterDocument]:
         """Create documents from a list of texts."""
         """Create documents from a list of texts."""
         _metadatas = metadatas or [{}] * len(texts)
         _metadatas = metadatas or [{}] * len(texts)
         documents = []
         documents = []
@@ -488,7 +485,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
 
 
     def split_documents(
     def split_documents(
         self, documents: Iterable[SplitterDocument]
         self, documents: Iterable[SplitterDocument]
-    ) -> List[SplitterDocument]:
+    ) -> list[SplitterDocument]:
         """Split documents."""
         """Split documents."""
         texts, metadatas = [], []
         texts, metadatas = [], []
         for doc in documents:
         for doc in documents:
@@ -496,7 +493,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             metadatas.append(doc.metadata)
             metadatas.append(doc.metadata)
         return self.create_documents(texts, metadatas=metadatas)
         return self.create_documents(texts, metadatas=metadatas)
 
 
-    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
+    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
         text = separator.join(docs)
         text = separator.join(docs)
         if self._strip_whitespace:
         if self._strip_whitespace:
             text = text.strip()
             text = text.strip()
@@ -507,13 +504,13 @@ class TextSplitter(BaseDocumentTransformer, ABC):
 
 
     def _merge_splits(
     def _merge_splits(
         self, splits: Iterable[str], separator: str
         self, splits: Iterable[str], separator: str
-    ) -> List[str]:
+    ) -> list[str]:
         # We now want to combine these smaller pieces into medium size
         # We now want to combine these smaller pieces into medium size
         # chunks to send to the LLM.
         # chunks to send to the LLM.
         separator_len = self._length_function(separator)
         separator_len = self._length_function(separator)
 
 
         docs = []
         docs = []
-        current_doc: List[str] = []
+        current_doc: list[str] = []
         total = 0
         total = 0
         for d in splits:
         for d in splits:
             _len = self._length_function(d)
             _len = self._length_function(d)
@@ -579,8 +576,8 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         cls: Type[TS],
         cls: Type[TS],
         encoding_name: str = "gpt2",
         encoding_name: str = "gpt2",
         model: Optional[str] = None,
         model: Optional[str] = None,
-        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
-        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        allowed_special: Literal["all"] | AbstractSet[str] = set(),
+        disallowed_special: Literal["all"] | Collection[str] = "all",
         **kwargs: Any,
         **kwargs: Any,
     ) -> TS:
     ) -> TS:
         """Text splitter that uses tiktoken encoder to count length."""
         """Text splitter that uses tiktoken encoder to count length."""
@@ -641,7 +638,7 @@ class CharacterTextSplitter(TextSplitter):
         self._separator = separator
         self._separator = separator
         self._is_separator_regex = is_separator_regex
         self._is_separator_regex = is_separator_regex
 
 
-    def split_text(self, text: str) -> List[str]:
+    def split_text(self, text: str) -> list[str]:
         """Split incoming text and return chunks."""
         """Split incoming text and return chunks."""
         # First we naively split the large input into a bunch of smaller ones.
         # First we naively split the large input into a bunch of smaller ones.
         separator = (
         separator = (
@@ -657,7 +654,7 @@ class CharacterTextSplitter(TextSplitter):
 class LineType(TypedDict):
 class LineType(TypedDict):
     """Line type as typed dict."""
     """Line type as typed dict."""
 
 
-    metadata: Dict[str, str]
+    metadata: dict[str, str]
     content: str
     content: str
 
 
 
 
@@ -674,7 +671,7 @@ class MarkdownHeaderTextSplitter:
 
 
     def __init__(
     def __init__(
         self,
         self,
-        headers_to_split_on: List[Tuple[str, str]],
+        headers_to_split_on: list[Tuple[str, str]],
         return_each_line: bool = False,
         return_each_line: bool = False,
         strip_headers: bool = True,
         strip_headers: bool = True,
     ):
     ):
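A quick sketch of splitting on markdown headings; the header labels are arbitrary metadata keys chosen for the example.

    splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=[("#", "h1"), ("##", "h2")],
        strip_headers=False,
    )
    docs = splitter.split_text("# Intro\nSome text.\n## Details\nMore text.")
    # Each SplitterDocument keeps the active headers in its metadata,
    # e.g. {"h1": "Intro", "h2": "Details"} for the last chunk.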
@@ -696,13 +693,13 @@ class MarkdownHeaderTextSplitter:
         self.strip_headers = strip_headers
         self.strip_headers = strip_headers
 
 
     def aggregate_lines_to_chunks(
     def aggregate_lines_to_chunks(
-        self, lines: List[LineType]
-    ) -> List[SplitterDocument]:
+        self, lines: list[LineType]
+    ) -> list[SplitterDocument]:
         """Combine lines with common metadata into chunks
         """Combine lines with common metadata into chunks
         Args:
         Args:
             lines: Line of text / associated header metadata
             lines: Line of text / associated header metadata
         """
         """
-        aggregated_chunks: List[LineType] = []
+        aggregated_chunks: list[LineType] = []
 
 
         for line in lines:
         for line in lines:
             if (
             if (
@@ -742,7 +739,7 @@ class MarkdownHeaderTextSplitter:
             for chunk in aggregated_chunks
             for chunk in aggregated_chunks
         ]
         ]
 
 
-    def split_text(self, text: str) -> List[SplitterDocument]:
+    def split_text(self, text: str) -> list[SplitterDocument]:
         """Split markdown file
         """Split markdown file
         Args:
         Args:
             text: Markdown file"""
             text: Markdown file"""
@@ -750,14 +747,14 @@ class MarkdownHeaderTextSplitter:
         # Split the input text by newline character ("\n").
         # Split the input text by newline character ("\n").
         lines = text.split("\n")
         lines = text.split("\n")
         # Final output
         # Final output
-        lines_with_metadata: List[LineType] = []
+        lines_with_metadata: list[LineType] = []
         # Content and metadata of the chunk currently being processed
         # Content and metadata of the chunk currently being processed
-        current_content: List[str] = []
-        current_metadata: Dict[str, str] = {}
+        current_content: list[str] = []
+        current_metadata: dict[str, str] = {}
         # Keep track of the nested header structure
         # Keep track of the nested header structure
-        # header_stack: List[Dict[str, Union[int, str]]] = []
-        header_stack: List[HeaderType] = []
-        initial_metadata: Dict[str, str] = {}
+        # header_stack: list[dict[str, int | str]] = []
+        header_stack: list[HeaderType] = []
+        initial_metadata: dict[str, str] = {}
 
 
         in_code_block = False
         in_code_block = False
         opening_fence = ""
         opening_fence = ""
@@ -879,7 +876,7 @@ class ElementType(TypedDict):
     url: str
     url: str
     xpath: str
     xpath: str
     content: str
     content: str
-    metadata: Dict[str, str]
+    metadata: dict[str, str]
 
 
 
 
 class HTMLHeaderTextSplitter:
 class HTMLHeaderTextSplitter:
@@ -890,7 +887,7 @@ class HTMLHeaderTextSplitter:
 
 
     def __init__(
     def __init__(
         self,
         self,
-        headers_to_split_on: List[Tuple[str, str]],
+        headers_to_split_on: list[Tuple[str, str]],
         return_each_element: bool = False,
         return_each_element: bool = False,
     ):
     ):
         """Create a new HTMLHeaderTextSplitter.
         """Create a new HTMLHeaderTextSplitter.
@@ -906,14 +903,14 @@ class HTMLHeaderTextSplitter:
         self.headers_to_split_on = sorted(headers_to_split_on)
         self.headers_to_split_on = sorted(headers_to_split_on)
 
 
     def aggregate_elements_to_chunks(
     def aggregate_elements_to_chunks(
-        self, elements: List[ElementType]
-    ) -> List[SplitterDocument]:
+        self, elements: list[ElementType]
+    ) -> list[SplitterDocument]:
         """Combine elements with common metadata into chunks
         """Combine elements with common metadata into chunks
 
 
         Args:
         Args:
             elements: HTML element content with associated identifying info and metadata
             elements: HTML element content with associated identifying info and metadata
         """
         """
-        aggregated_chunks: List[ElementType] = []
+        aggregated_chunks: list[ElementType] = []
 
 
         for element in elements:
         for element in elements:
             if (
             if (
@@ -935,7 +932,7 @@ class HTMLHeaderTextSplitter:
             for chunk in aggregated_chunks
             for chunk in aggregated_chunks
         ]
         ]
 
 
-    def split_text_from_url(self, url: str) -> List[SplitterDocument]:
+    def split_text_from_url(self, url: str) -> list[SplitterDocument]:
         """Split HTML from web URL
         """Split HTML from web URL
 
 
         Args:
         Args:
@@ -944,7 +941,7 @@ class HTMLHeaderTextSplitter:
         r = requests.get(url)
         r = requests.get(url)
         return self.split_text_from_file(BytesIO(r.content))
         return self.split_text_from_file(BytesIO(r.content))
 
 
-    def split_text(self, text: str) -> List[SplitterDocument]:
+    def split_text(self, text: str) -> list[SplitterDocument]:
         """Split HTML text string
         """Split HTML text string
 
 
         Args:
         Args:
@@ -952,7 +949,7 @@ class HTMLHeaderTextSplitter:
         """
         """
         return self.split_text_from_file(StringIO(text))
         return self.split_text_from_file(StringIO(text))
 
 
-    def split_text_from_file(self, file: Any) -> List[SplitterDocument]:
+    def split_text_from_file(self, file: Any) -> list[SplitterDocument]:
         """Split HTML file
         """Split HTML file
 
 
         Args:
         Args:
@@ -1048,15 +1045,15 @@ class Tokenizer:
     """Overlap in tokens between chunks"""
     tokens_per_chunk: int
     """Maximum number of tokens per chunk"""
-    decode: Callable[[List[int]], str]
+    decode: Callable[[list[int]], str]
     """ Function to decode a list of token ids to a string"""
-    encode: Callable[[str], List[int]]
+    encode: Callable[[str], list[int]]
     """ Function to encode a string to a list of token ids"""


-def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
+def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
     """Split incoming text and return chunks using tokenizer."""
-    splits: List[str] = []
+    splits: list[str] = []
     input_ids = tokenizer.encode(text)
     start_idx = 0
     cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
@@ -1078,8 +1075,8 @@ class TokenTextSplitter(TextSplitter):
         self,
         encoding_name: str = "gpt2",
         model: Optional[str] = None,
-        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
-        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        allowed_special: Literal["all"] | AbstractSet[str] = set(),
+        disallowed_special: Literal["all"] | Collection[str] = "all",
         **kwargs: Any,
     ) -> None:
         """Create a new TextSplitter."""
@@ -1101,8 +1098,8 @@ class TokenTextSplitter(TextSplitter):
         self._allowed_special = allowed_special
         self._disallowed_special = disallowed_special

-    def split_text(self, text: str) -> List[str]:
-        def _encode(_text: str) -> List[int]:
+    def split_text(self, text: str) -> list[str]:
+        def _encode(_text: str) -> list[int]:
             return self._tokenizer.encode(
                 _text,
                 allowed_special=self._allowed_special,
@@ -1164,8 +1161,8 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
                 f" > maximum token limit."
             )

-    def split_text(self, text: str) -> List[str]:
-        def encode_strip_start_and_stop_token_ids(text: str) -> List[int]:
+    def split_text(self, text: str) -> list[str]:
+        def encode_strip_start_and_stop_token_ids(text: str) -> list[int]:
             return self._encode(text)[1:-1]

         tokenizer = Tokenizer(
@@ -1182,7 +1179,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):

     _max_length_equal_32_bit_integer: int = 2**32

-    def _encode(self, text: str) -> List[int]:
+    def _encode(self, text: str) -> list[int]:
         token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
             text,
             max_length=self._max_length_equal_32_bit_integer,
@@ -1228,7 +1225,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):

     def __init__(
         self,
-        separators: Optional[List[str]] = None,
+        separators: Optional[list[str]] = None,
         keep_separator: bool = True,
         is_separator_regex: bool = False,
         chunk_size: int = 4000,
@@ -1247,7 +1244,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
         self.chunk_size = chunk_size
         self.chunk_overlap = chunk_overlap

-    def _split_text(self, text: str, separators: List[str]) -> List[str]:
+    def _split_text(self, text: str, separators: list[str]) -> list[str]:
         """Split incoming text and return chunks."""
         final_chunks = []
         # Get appropriate separator to use
@@ -1289,7 +1286,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
             final_chunks.extend(merged_text)
         return final_chunks

-    def split_text(self, text: str) -> List[str]:
+    def split_text(self, text: str) -> list[str]:
         return self._split_text(text, self._separators)

     @classmethod
@@ -1300,7 +1297,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
         return cls(separators=separators, is_separator_regex=True, **kwargs)

     @staticmethod
-    def get_separators_for_language(language: Language) -> List[str]:
+    def get_separators_for_language(language: Language) -> list[str]:
         if language == Language.CPP:
             return [
                 # Split along class definitions
@@ -1781,7 +1778,7 @@ class NLTKTextSplitter(TextSplitter):
         self._separator = separator
         self._language = language

-    def split_text(self, text: str) -> List[str]:
+    def split_text(self, text: str) -> list[str]:
         """Split incoming text and return chunks."""
         # First we naively split the large input into a bunch of smaller ones.
         splits = self._tokenizer(text, language=self._language)
@@ -1812,7 +1809,7 @@ class SpacyTextSplitter(TextSplitter):
         )
         self._separator = separator

-    def split_text(self, text: str) -> List[str]:
+    def split_text(self, text: str) -> list[str]:
         """Split incoming text and return chunks."""
         splits = (s.text for s in self._tokenizer(text).sents)
         return self._merge_splits(splits, self._separator)
@@ -1843,7 +1840,7 @@ class KonlpyTextSplitter(TextSplitter):
             )
         self.kkma = Kkma()

-    def split_text(self, text: str) -> List[str]:
+    def split_text(self, text: str) -> list[str]:
         """Split incoming text and return chunks."""
         splits = self.kkma.sentences(text)
         return self._merge_splits(splits, self._separator)
@@ -1890,12 +1887,12 @@ class RecursiveJsonSplitter:
         )

     @staticmethod
-    def _json_size(data: Dict) -> int:
+    def _json_size(data: dict) -> int:
         """Calculate the size of the serialized JSON object."""
         return len(json.dumps(data))

     @staticmethod
-    def _set_nested_dict(d: Dict, path: List[str], value: Any) -> None:
+    def _set_nested_dict(d: dict, path: list[str], value: Any) -> None:
         """Set a value in a nested dictionary based on the given path."""
         for key in path[:-1]:
             d = d.setdefault(key, {})
@@ -1919,10 +1916,10 @@ class RecursiveJsonSplitter:

     def _json_split(
         self,
-        data: Dict[str, Any],
-        current_path: List[str] = [],
-        chunks: List[Dict] = [{}],
-    ) -> List[Dict]:
+        data: dict[str, Any],
+        current_path: list[str] = [],
+        chunks: list[dict] = [{}],
+    ) -> list[dict]:
         """
         Split json into maximum size dictionaries while preserving structure.
         """
@@ -1950,9 +1947,9 @@ class RecursiveJsonSplitter:

     def split_json(
         self,
-        json_data: Dict[str, Any],
+        json_data: dict[str, Any],
         convert_lists: bool = False,
-    ) -> List[Dict]:
+    ) -> list[dict]:
         """Splits JSON into a list of JSON chunks"""

         if convert_lists:
@@ -1968,8 +1965,8 @@ class RecursiveJsonSplitter:
         return chunks

     def split_text(
-        self, json_data: Dict[str, Any], convert_lists: bool = False
-    ) -> List[str]:
+        self, json_data: dict[str, Any], convert_lists: bool = False
+    ) -> list[str]:
         """Splits JSON into a list of JSON formatted strings"""

         chunks = self.split_json(
@@ -1981,11 +1978,11 @@ class RecursiveJsonSplitter:

     def create_documents(
         self,
-        texts: List[Dict],
+        texts: list[dict],
         convert_lists: bool = False,
-        metadatas: Optional[List[dict]] = None,
-    ) -> List[SplitterDocument]:
-        """Create documents from a list of json objects (Dict)."""
+        metadatas: Optional[list[dict]] = None,
+    ) -> list[SplitterDocument]:
+        """Create documents from a list of json objects (dict)."""
         _metadatas = metadatas or [{}] * len(texts)
         documents = []
         for i, text in enumerate(texts):
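Editorial note: the `shared/utils/splitter/text.py` hunks above are a mechanical typing cleanup. `typing.List` and `typing.Dict` annotations become the PEP 585 builtin generics `list[...]` and `dict[...]`, and the two `Union[Literal["all"], ...]` parameters become PEP 604 `|` unions, while `Optional[...]` and `Tuple[...]` are left untouched in several signatures. The sketch below is illustrative only (not code from this repository) and assumes Python 3.10+, or `from __future__ import annotations` on older interpreters; the function name and defaults are made up for the example.

# Illustrative sketch of the annotation style this commit migrates to
# (PEP 585 builtin generics + PEP 604 unions). Names below are hypothetical.
from typing import AbstractSet, Any, Literal


def split_text(
    text: str,
    separators: list[str] | None = None,  # was Optional[List[str]]
    # kept only to mirror the Union -> "|" change in TokenTextSplitter.__init__
    allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
) -> list[str]:
    """Split on the first separator that occurs in the text."""
    for sep in separators or ["\n\n", "\n", " "]:
        if sep in text:
            return [part for part in text.split(sep) if part]
    return [text]


chunks: dict[str, Any] = {"demo": split_text("alpha beta  gamma")}
print(chunks)  # {'demo': ['alpha', 'beta', 'gamma']}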

+ 4 - 2
tests/unit/test_collections.py

@@ -1,8 +1,10 @@
-import pytest
 import uuid
 from uuid import UUID
-from core.base.api.models import CollectionResponse
+
+import pytest
+
 from core.base import R2RException
+from core.base.api.models import CollectionResponse


 @pytest.mark.asyncio
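Editorial note: the only change in `tests/unit/test_collections.py` is import ordering, regrouped into standard library, third-party, and first-party blocks separated by blank lines (PEP 8 / isort convention). The snippet below is an optional illustration, not part of this commit; it assumes the isort package is installed and treats `core` as the first-party root.

# Optional illustration: isort reproduces the grouping used in the reordered header.
import isort

messy = (
    "import pytest\n"
    "import uuid\n"
    "from uuid import UUID\n"
    "from core.base import R2RException\n"
)
# Groups come out as: stdlib (uuid), third-party (pytest), first-party (core.*).
print(isort.code(messy, known_first_party=["core"]))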

+ 26 - 27
tests/unit/test_graphs.py

@@ -1,9 +1,8 @@
-import pytest
 import uuid
-from uuid import UUID
-
 from enum import Enum
-from core.base.abstractions import Entity, Relationship, Community
+
+import pytest
+
 from core.base.api.models import GraphResponse
 
 
@@ -35,7 +34,7 @@ async def test_add_entities_and_relationships(graphs_handler):
     # Add an entity
     entity = await graphs_handler.entities.create(
         parent_id=graph_id,
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
         name="TestEntity",
         category="Person",
         description="A test entity",
@@ -45,7 +44,7 @@ async def test_add_entities_and_relationships(graphs_handler):
     # Add another entity
     entity2 = await graphs_handler.entities.create(
         parent_id=graph_id,
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
         name="AnotherEntity",
         category="Place",
         description="A test place",
@@ -59,7 +58,7 @@ async def test_add_entities_and_relationships(graphs_handler):
         object="AnotherEntity",
         object_id=entity2.id,
         parent_id=graph_id,
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
         description="Entity lives in AnotherEntity",
     )
     assert rel.predicate == "lives_in"
@@ -92,12 +91,12 @@ async def test_delete_entities_and_relationships(graphs_handler):
     # Add entities
     e1 = await graphs_handler.entities.create(
         parent_id=graph_id,
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
         name="DeleteMe",
     )
     e2 = await graphs_handler.entities.create(
         parent_id=graph_id,
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
         name="DeleteMeToo",
     )

@@ -109,14 +108,14 @@ async def test_delete_entities_and_relationships(graphs_handler):
         object="DeleteMeToo",
         object_id=e2.id,
         parent_id=graph_id,
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
     )

     # Delete one entity
     await graphs_handler.entities.delete(
         parent_id=graph_id,
         entity_ids=[e1.id],
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
     )
     ents, count = await graphs_handler.get_entities(
         parent_id=graph_id, offset=0, limit=10
@@ -128,7 +127,7 @@ async def test_delete_entities_and_relationships(graphs_handler):
     await graphs_handler.relationships.delete(
         parent_id=graph_id,
         relationship_ids=[rel.id],
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
     )
     rels, rel_count = await graphs_handler.get_relationships(
         parent_id=graph_id, offset=0, limit=10
@@ -142,7 +141,7 @@ async def test_communities(graphs_handler):
     coll_id = uuid.uuid4()
     await graphs_handler.communities.create(
         parent_id=coll_id,
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
         name="CommunityOne",
         summary="Test community",
         findings=["finding1", "finding2"],
@@ -153,7 +152,7 @@ async def test_communities(graphs_handler):

     comms, count = await graphs_handler.communities.get(
         parent_id=coll_id,
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
         offset=0,
         limit=10,
     )
@@ -284,7 +283,7 @@ async def test_bulk_entities(graphs_handler):
     for ent in entities_to_add:
         await graphs_handler.entities.create(
             parent_id=graph_id,
-            store_type=StoreType.GRAPHS.value,
+            store_type=StoreType.GRAPHS,
             name=ent["name"],
             category=ent["category"],
             description=ent["description"],
@@ -309,13 +308,13 @@ async def test_relationship_filtering(graphs_handler):

     # Add entities
     e1 = await graphs_handler.entities.create(
-        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="Node1"
+        parent_id=graph_id, store_type=StoreType.GRAPHS, name="Node1"
     )
     e2 = await graphs_handler.entities.create(
-        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="Node2"
+        parent_id=graph_id, store_type=StoreType.GRAPHS, name="Node2"
     )
     e3 = await graphs_handler.entities.create(
-        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="Node3"
+        parent_id=graph_id, store_type=StoreType.GRAPHS, name="Node3"
     )

     # Add different relationships
@@ -326,7 +325,7 @@ async def test_relationship_filtering(graphs_handler):
         object="Node2",
         object_id=e2.id,
         parent_id=graph_id,
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
     )

     await graphs_handler.relationships.create(
@@ -336,7 +335,7 @@ async def test_relationship_filtering(graphs_handler):
         object="Node3",
         object_id=e3.id,
         parent_id=graph_id,
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
     )

     # Get all relationships
@@ -366,15 +365,15 @@ async def test_delete_all_entities(graphs_handler):

     # Add some entities
     await graphs_handler.entities.create(
-        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="E1"
+        parent_id=graph_id, store_type=StoreType.GRAPHS, name="E1"
     )
     await graphs_handler.entities.create(
-        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="E2"
+        parent_id=graph_id, store_type=StoreType.GRAPHS, name="E2"
     )

     # Delete all entities without specifying IDs
     await graphs_handler.entities.delete(
-        parent_id=graph_id, store_type=StoreType.GRAPHS.value
+        parent_id=graph_id, store_type=StoreType.GRAPHS
     )
     ents, count = await graphs_handler.get_entities(
         parent_id=graph_id, offset=0, limit=10
@@ -392,10 +391,10 @@ async def test_delete_all_relationships(graphs_handler):

     # Add two entities and a relationship
     e1 = await graphs_handler.entities.create(
-        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="E1"
+        parent_id=graph_id, store_type=StoreType.GRAPHS, name="E1"
     )
     e2 = await graphs_handler.entities.create(
-        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="E2"
+        parent_id=graph_id, store_type=StoreType.GRAPHS, name="E2"
     )
     await graphs_handler.relationships.create(
         subject="E1",
@@ -404,12 +403,12 @@ async def test_delete_all_relationships(graphs_handler):
         object="E2",
         object_id=e2.id,
         parent_id=graph_id,
-        store_type=StoreType.GRAPHS.value,
+        store_type=StoreType.GRAPHS,
     )

     # Delete all relationships
     await graphs_handler.relationships.delete(
-        parent_id=graph_id, store_type=StoreType.GRAPHS.value
+        parent_id=graph_id, store_type=StoreType.GRAPHS
     )
     rels, rel_count = await graphs_handler.get_relationships(
         parent_id=graph_id, offset=0, limit=10
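Editorial note: apart from the same import regrouping and the removal of the now-unused `UUID`, `Entity`, `Relationship`, and `Community` imports, every change in `tests/unit/test_graphs.py` replaces `store_type=StoreType.GRAPHS.value` with the enum member itself, i.e. the handlers are now exercised with `StoreType` rather than its raw string value. The sketch below shows that calling convention in isolation; the `StoreType` members and the `resolve_store_table` helper are hypothetical stand-ins, not the repository's actual API.

# Hypothetical sketch (not the actual graphs_handler implementation): accept a
# StoreType member directly while still tolerating the old raw-string form.
from enum import Enum


class StoreType(str, Enum):  # member names/values assumed for illustration
    GRAPHS = "graphs"
    DOCUMENTS = "documents"


def resolve_store_table(store_type: StoreType | str) -> str:
    """Normalize either StoreType.GRAPHS or "graphs" to a table name."""
    store = StoreType(store_type)  # Enum(value) also accepts an existing member
    return f"{store.value}_entities"


assert resolve_store_table(StoreType.GRAPHS) == "graphs_entities"
assert resolve_store_table("graphs") == "graphs_entities"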