jack · 6 months ago · commit 3e8b6e1e6d
100 changed files with 5991 additions and 2035 deletions
  1. +0 -32 cli/commands/system.py
  2. +0 -1 cli/main.py
  3. +0 -1 cli/utils/docker_utils.py
  4. +124 -0 compose.yaml
  5. +24 -9 core/__init__.py
  6. +1 -2 core/base/__init__.py
  7. +12 -9 core/base/api/models/__init__.py
  8. +1 -2 core/base/logger/__init__.py
  9. +0 -12 core/base/logger/base.py
  10. +2 -6 core/base/logger/run_manager.py
  11. +1 -2 core/base/pipes/base_pipe.py
  12. +2 -0 core/base/providers/__init__.py
  13. +89 -16 core/base/providers/auth.py
  14. +3 -0 core/base/providers/base.py
  15. +83 -2 core/base/providers/crypto.py
  16. +49 -0 core/base/providers/database.py
  17. +8 -3 core/base/providers/ingestion.py
  18. +2 -0 core/base/utils/__init__.py
  19. +2 -3 core/configs/local_llm.toml
  20. +61 -0 core/configs/r2r_azure_with_test_limits.toml
  21. +39 -206 core/database/chunks.py
  22. +1 -1 core/database/collections.py
  23. +189 -111 core/database/conversations.py
  24. +10 -148 core/database/documents.py
  25. +433 -0 core/database/filters.py
  26. +87 -181 core/database/graphs.py
  27. +79 -150 core/database/limits.py
  28. +4 -5 core/database/postgres.py
  29. +126 -5 core/database/users.py
  30. BIN core/examples/supported_file_types/bmp.bmp
  31. +11 -0 core/examples/supported_file_types/csv.csv
  32. BIN core/examples/supported_file_types/doc.doc
  33. BIN core/examples/supported_file_types/docx.docx
  34. +61 -0 core/examples/supported_file_types/eml.eml
  35. BIN core/examples/supported_file_types/epub.epub
  36. BIN core/examples/supported_file_types/heic.heic
  37. +69 -0 core/examples/supported_file_types/html.html
  38. BIN core/examples/supported_file_types/jpeg.jpeg
  39. BIN core/examples/supported_file_types/jpg.jpg
  40. +58 -0 core/examples/supported_file_types/json.json
  41. +310 -0 core/examples/supported_file_types/md.md
  42. BIN core/examples/supported_file_types/msg.msg
  43. BIN core/examples/supported_file_types/odt.odt
  44. +153 -0 core/examples/supported_file_types/org.org
  45. +50 -0 core/examples/supported_file_types/p7s.p7s
  46. BIN core/examples/supported_file_types/pdf.pdf
  47. BIN core/examples/supported_file_types/png.png
  48. BIN core/examples/supported_file_types/ppt.ppt
  49. BIN core/examples/supported_file_types/pptx.pptx
  50. +86 -0 core/examples/supported_file_types/rst.rst
  51. +5 -0 core/examples/supported_file_types/rtf.rtf
  52. BIN core/examples/supported_file_types/tiff.tiff
  53. +11 -0 core/examples/supported_file_types/tsv.tsv
  54. +21 -0 core/examples/supported_file_types/txt.txt
  55. BIN core/examples/supported_file_types/xls.xls
  56. BIN core/examples/supported_file_types/xlsx.xlsx
  57. +17 -0 core/examples/supported_file_types/xml.xml
  58. +1 -1 core/main/__init__.py
  59. +49 -15 core/main/abstractions.py
  60. +28 -22 core/main/api/v3/base_router.py
  61. +25 -28 core/main/api/v3/chunks_router.py
  62. +72 -82 core/main/api/v3/collections_router.py
  63. +148 -40 core/main/api/v3/conversations_router.py
  64. +215 -140 core/main/api/v3/documents_router.py
  65. +93 -73 core/main/api/v3/graph_router.py
  66. +16 -19 core/main/api/v3/indices_router.py
  67. +13 -16 core/main/api/v3/logs_router.py
  68. +22 -24 core/main/api/v3/prompts_router.py
  69. +29 -39 core/main/api/v3/retrieval_router.py
  70. +11 -97 core/main/api/v3/system_router.py
  71. +271 -50 core/main/api/v3/users_router.py
  72. +21 -38 core/main/assembly/builder.py
  73. +111 -111 core/main/assembly/factory.py
  74. +16 -16 core/main/config.py
  75. +8 -8 core/main/orchestration/hatchet/kg_workflow.py
  76. +2 -2 core/main/orchestration/simple/kg_workflow.py
  77. +2 -2 core/main/services/__init__.py
  78. +46 -8 core/main/services/auth_service.py
  79. +1081 -0 core/main/services/graph_service.py
  80. +169 -238 core/main/services/management_service.py
  81. +7 -13 core/main/services/retrieval_service.py
  82. +14 -0 core/parsers/__init__.py
  83. +12 -5 core/parsers/media/__init__.py
  84. +0 -1 core/parsers/media/audio_parser.py
  85. +74 -0 core/parsers/media/bmp_parser.py
  86. +115 -0 core/parsers/media/doc_parser.py
  87. +75 -25 core/parsers/media/img_parser.py
  88. +65 -0 core/parsers/media/odt_parser.py
  89. +9 -4 core/parsers/media/pdf_parser.py
  90. +64 -11 core/parsers/media/ppt_parser.py
  91. +43 -0 core/parsers/media/pptx_parser.py
  92. +52 -0 core/parsers/media/rtf_parser.py
  93. +18 -0 core/parsers/structured/__init__.py
  94. +63 -0 core/parsers/structured/eml_parser.py
  95. +128 -0 core/parsers/structured/epub_parser.py
  96. +75 -0 core/parsers/structured/msg_parser.py
  97. +79 -0 core/parsers/structured/org_parser.py
  98. +184 -0 core/parsers/structured/p7s_parser.py
  99. +65 -0 core/parsers/structured/rst_parser.py
  100. +116 -0 core/parsers/structured/tiff_parser.py

+ 0 - 32
cli/commands/system.py

@@ -38,38 +38,6 @@ async def health(ctx):
     click.echo(json.dumps(response, indent=2))
 
 
-@system.command()
-@click.option("--run-type-filter", help="Filter for log types")
-@click.option(
-    "--offset", default=None, help="Pagination offset. Default is None."
-)
-@click.option(
-    "--limit", default=None, help="Pagination limit. Defaults to 100."
-)
-@pass_context
-async def logs(ctx, run_type_filter, offset, limit):
-    """Retrieve logs with optional type filter."""
-    client: R2RAsyncClient = ctx.obj
-    with timer():
-        response = await client.system.logs(
-            run_type_filter=run_type_filter,
-            offset=offset,
-            limit=limit,
-        )
-
-    for log in response["results"]:
-        click.echo(f"Run ID: {log['run_id']}")
-        click.echo(f"Run Type: {log['run_type']}")
-        click.echo(f"Timestamp: {log['timestamp']}")
-        click.echo(f"User ID: {log['user_id']}")
-        click.echo("Entries:")
-        for entry in log["entries"]:
-            click.echo(f"  - {entry['key']}: {entry['value'][:100]}")
-        click.echo("---")
-
-    click.echo(f"Total runs: {len(response['results'])}")
-
-
 @system.command()
 @pass_context
 async def settings(ctx):

+ 0 - 1
cli/main.py

@@ -48,7 +48,6 @@ def main():
         pass
     except Exception as e:
         # Handle other exceptions if needed
-        print("CLI error: An error occurred")
        raise e
    finally:
        # Ensure all events are flushed before exiting

+ 0 - 1
cli/utils/docker_utils.py

@@ -373,7 +373,6 @@ def build_docker_command(
         else:
             base_command = f"docker compose -f {compose_files['full_scale']}"
 
-    print("base_command = ", base_command)
     base_command += (
         f" --project-name {project_name or ('r2r-full' if full else 'r2r')}"
    )

+ 124 - 0
compose.yaml

@@ -0,0 +1,124 @@
+networks:
+  r2r-network:
+    driver: bridge
+    attachable: true
+    labels:
+      - "com.docker.compose.recreate=always"
+
+volumes:
+  postgres_data:
+    name: ${VOLUME_POSTGRES_DATA:-postgres_data}
+
+services:
+  postgres:
+    image: pgvector/pgvector:pg16
+    profiles: [postgres]
+    environment:
+      - POSTGRES_USER=${R2R_POSTGRES_USER:-${POSTGRES_USER:-postgres}} # Eventually get rid of POSTGRES_USER, but for now keep it for backwards compatibility
+      - POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-${POSTGRES_PASSWORD:-postgres}} # Eventually get rid of POSTGRES_PASSWORD, but for now keep it for backwards compatibility
+      - POSTGRES_HOST=${R2R_POSTGRES_HOST:-${POSTGRES_HOST:-postgres}} # Eventually get rid of POSTGRES_HOST, but for now keep it for backwards compatibility
+      - POSTGRES_PORT=${R2R_POSTGRES_PORT:-${POSTGRES_PORT:-5432}} # Eventually get rid of POSTGRES_PORT, but for now keep it for backwards compatibility
+      - POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-${POSTGRES_MAX_CONNECTIONS:-1024}} # Eventually get rid of POSTGRES_MAX_CONNECTIONS, but for now keep it for backwards compatibility
+      - PGPORT=${R2R_POSTGRES_PORT:-5432}
+    volumes:
+      - postgres_data:/var/lib/postgresql/data
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_POSTGRES_PORT:-5432}:${R2R_POSTGRES_PORT:-5432}"
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -U ${R2R_POSTGRES_USER:-postgres}"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    command: >
+      postgres
+      -c max_connections=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+
+  r2r:
+    image: ${R2R_IMAGE:-ragtoriches/prod:latest}
+    build:
+      context: .
+      args:
+        PORT: ${R2R_PORT:-7272}
+        R2R_PORT: ${R2R_PORT:-7272}
+        HOST: ${R2R_HOST:-0.0.0.0}
+        R2R_HOST: ${R2R_HOST:-0.0.0.0}
+    ports:
+      - "${R2R_PORT:-7272}:${R2R_PORT:-7272}"
+    environment:
+      - PYTHONUNBUFFERED=1
+      - R2R_PORT=${R2R_PORT:-7272}
+      - R2R_HOST=${R2R_HOST:-0.0.0.0}
+
+      # R2R
+      - R2R_CONFIG_NAME=${R2R_CONFIG_NAME:-}
+      - R2R_CONFIG_PATH=${R2R_CONFIG_PATH:-}
+      - R2R_PROJECT_NAME=${R2R_PROJECT_NAME:-r2r_default}
+
+      # Postgres
+      - R2R_POSTGRES_USER=${R2R_POSTGRES_USER:-postgres}
+      - R2R_POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres}
+      - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres}
+      - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432}
+      - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-postgres}
+      - R2R_POSTGRES_PROJECT_NAME=${R2R_POSTGRES_PROJECT_NAME:-r2r_default}
+      - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+      - R2R_POSTGRES_STATEMENT_CACHE_SIZE=${R2R_POSTGRES_STATEMENT_CACHE_SIZE:-100}
+
+      # OpenAI
+      - OPENAI_API_KEY=${OPENAI_API_KEY:-}
+      - OPENAI_API_BASE=${OPENAI_API_BASE:-}
+
+      # Anthropic
+      - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
+
+      # Azure
+      - AZURE_API_KEY=${AZURE_API_KEY:-}
+      - AZURE_API_BASE=${AZURE_API_BASE:-}
+      - AZURE_API_VERSION=${AZURE_API_VERSION:-}
+
+      # Google Vertex AI
+      - GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS:-}
+      - VERTEX_PROJECT=${VERTEX_PROJECT:-}
+      - VERTEX_LOCATION=${VERTEX_LOCATION:-}
+
+      # AWS Bedrock
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
+      - AWS_REGION_NAME=${AWS_REGION_NAME:-}
+
+      # Groq
+      - GROQ_API_KEY=${GROQ_API_KEY:-}
+
+      # Cohere
+      - COHERE_API_KEY=${COHERE_API_KEY:-}
+
+      # Anyscale
+      - ANYSCALE_API_KEY=${ANYSCALE_API_KEY:-}
+
+      # Ollama
+      - OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434}
+
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:${R2R_PORT:-7272}/v3/health"]
+      interval: 6s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    volumes:
+      - ${R2R_CONFIG_PATH:-/}:${R2R_CONFIG_PATH:-/app/config}
+    extra_hosts:
+      - host.docker.internal:host-gateway
+
+  r2r-dashboard:
+    image: emrgntcmplxty/r2r-dashboard:latest
+    environment:
+      - NEXT_PUBLIC_R2R_DEPLOYMENT_URL=${R2R_DEPLOYMENT_URL:-http://localhost:7272}
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_DASHBOARD_PORT:-7273}:3000"

+ 24 - 9
core/__init__.py

@@ -104,8 +104,6 @@ __all__ = [
     "TokenResponse",
     "User",
     ## LOGGING
-    # Basic types
-    "RunType",
     # Run Manager
     "RunManager",
     "manage_run",
@@ -133,6 +131,7 @@ __all__ = [
     "EmailConfig",
     "EmailProvider",
     # Database providers
+    "LimitSettings",
     "DatabaseConfig",
     "DatabaseProvider",
     # Embedding provider
@@ -174,20 +173,34 @@ __all__ = [
     "IngestionService",
     "ManagementService",
     "RetrievalService",
-    "KgService",
+    "GraphService",
     ## PARSERS
     # Media parsers
     "AudioParser",
+    "BMPParser",
+    "DOCParser",
     "DOCXParser",
     "ImageParser",
+    "ODTParser",
     "VLMPDFParser",
     "BasicPDFParser",
     "PDFParserUnstructured",
     "PPTParser",
+    "PPTXParser",
+    "RTFParser",
     # Structured parsers
     "CSVParser",
     "CSVParserAdvanced",
+    "EMLParser",
+    "EPUBParser",
     "JSONParser",
+    "MSGParser",
+    "ORGParser",
+    "P7SParser",
+    "RSTParser",
+    "TIFFParser",
+    "TSVParser",
+    "XLSParser",
     "XLSXParser",
     "XLSXParserAdvanced",
     # Text parsers
@@ -200,22 +213,24 @@ __all__ = [
     ## PIPES
     "SearchPipe",
     "EmbeddingPipe",
-    "KGExtractionPipe",
+    "GraphExtractionPipe",
     "ParsingPipe",
     "QueryTransformPipe",
-    "SearchRAGPipe",
-    "StreamingSearchRAGPipe",
+    "RAGPipe",
+    "StreamingRAGPipe",
     "VectorSearchPipe",
     "VectorStoragePipe",
-    "KGStoragePipe",
+    "GraphStoragePipe",
     "MultiSearchPipe",
     ## PROVIDERS
     # Auth
     "SupabaseAuthProvider",
     "R2RAuthProvider",
     # Crypto
-    "BCryptProvider",
-    "BCryptConfig",
+    "BCryptCryptoProvider",
+    "BcryptCryptoConfig",
+    "NaClCryptoConfig",
+    "NaClCryptoProvider",
     # Database
     "PostgresDatabaseProvider",
     # Embeddings

+ 1 - 2
core/base/__init__.py

@@ -77,8 +77,6 @@ __all__ = [
     "TokenResponse",
     "User",
     ## LOGGING
-    # Basic types
-    "RunType",
     # Run Manager
     "RunManager",
     "manage_run",
@@ -106,6 +104,7 @@ __all__ = [
     "EmailConfig",
     "EmailProvider",
     # Database providers
+    "LimitSettings",
     "DatabaseConfig",
     "DatabaseProvider",
     "Handler",

+ 12 - 9
core/base/api/models/__init__.py

@@ -10,15 +10,7 @@ from shared.api.models.base import (
     WrappedBooleanResponse,
     WrappedGenericMessageResponse,
 )
-from shared.api.models.ingestion.responses import (
-    IngestionResponse,
-    UpdateResponse,
-    WrappedIngestionResponse,
-    WrappedListVectorIndicesResponse,
-    WrappedMetadataUpdateResponse,
-    WrappedUpdateResponse,
-)
-from shared.api.models.kg.responses import (  # TODO: Need to review anything above this
+from shared.api.models.graph.responses import (  # TODO: Need to review anything above this
     Community,
     Entity,
     GraphResponse,
@@ -32,6 +24,14 @@ from shared.api.models.kg.responses import (  # TODO: Need to review anything ab
     WrappedRelationshipResponse,
     WrappedRelationshipsResponse,
 )
+from shared.api.models.ingestion.responses import (
+    IngestionResponse,
+    UpdateResponse,
+    WrappedIngestionResponse,
+    WrappedListVectorIndicesResponse,
+    WrappedMetadataUpdateResponse,
+    WrappedUpdateResponse,
+)
 from shared.api.models.management.responses import (  # Document Responses; Prompt Responses; Chunk Responses; Conversation Responses; User Responses; TODO: anything below this hasn't been reviewed
     AnalyticsResponse,
     ChunkResponse,
@@ -43,6 +43,8 @@ from shared.api.models.management.responses import (  # Document Responses; Prom
     SettingsResponse,
     User,
     WrappedAnalyticsResponse,
+    WrappedAPIKeyResponse,
+    WrappedAPIKeysResponse,
     WrappedChunkResponse,
     WrappedChunksResponse,
     WrappedCollectionResponse,
@@ -138,6 +140,7 @@ __all__ = [
     "User",
     "WrappedUserResponse",
     "WrappedUsersResponse",
+    "WrappedAPIKeyResponse",
     # Base Responses
     "PaginatedR2RResult",
     "R2RResults",

+ 1 - 2
core/base/logger/__init__.py

@@ -1,9 +1,8 @@
-from .base import RunInfoLog, RunType
+from .base import RunInfoLog
 from .run_manager import RunManager, manage_run
 
 __all__ = [
     # Basic types
-    "RunType",
     "RunInfoLog",
     # Run Manager
     "RunManager",

+ 0 - 12
core/base/logger/base.py

@@ -16,17 +16,5 @@ logger = logging.getLogger()
 
 class RunInfoLog(BaseModel):
     run_id: UUID
-    run_type: str
     timestamp: datetime
     user_id: UUID
-
-
-class RunType(str, Enum):
-    """Enumeration of the different types of runs."""
-
-    RETRIEVAL = "RETRIEVAL"
-    MANAGEMENT = "MANAGEMENT"
-    INGESTION = "INGESTION"
-    AUTH = "AUTH"
-    UNSPECIFIED = "UNSPECIFIED"
-    KG = "KG"
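With run_type gone, RunInfoLog is down to three fields. A minimal construction sketch (the UUIDs and timestamp are illustrative values):

```python
from datetime import datetime, timezone
from uuid import uuid4

from core.base.logger.base import RunInfoLog

# run_type no longer exists on the model; only these three fields remain.
info = RunInfoLog(
    run_id=uuid4(),
    timestamp=datetime.now(timezone.utc),
    user_id=uuid4(),
)
```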

+ 2 - 6
core/base/logger/run_manager.py

@@ -5,7 +5,6 @@ from typing import Optional
 from uuid import UUID
 
 from core.base.api.models import User
-from core.base.logger.base import RunType
 from core.base.utils import generate_id
 
 run_id_var = contextvars.ContextVar("run_id", default=generate_id())
@@ -15,12 +14,11 @@ class RunManager:
     def __init__(self):
         self.run_info: dict[UUID, dict] = {}
 
-    async def set_run_info(self, run_type: str, run_id: Optional[UUID] = None):
+    async def set_run_info(self, run_id: Optional[UUID] = None):
         run_id = run_id or run_id_var.get()
         if run_id is None:
             run_id = generate_id()
             token = run_id_var.set(run_id)
-            self.run_info[run_id] = {"run_type": run_type}
         else:
             token = run_id_var.set(run_id)
         return run_id, token
@@ -31,7 +29,6 @@ class RunManager:
 
     async def log_run_info(
         self,
-        run_type: RunType,
         user: User,
     ):
         if asyncio.iscoroutine(user):
@@ -47,10 +44,9 @@ class RunManager:
 @asynccontextmanager
 async def manage_run(
     run_manager: RunManager,
-    run_type: RunType = RunType.UNSPECIFIED,
     run_id: Optional[UUID] = None,
 ):
-    run_id, token = await run_manager.set_run_info(run_type, run_id)
+    run_id, token = await run_manager.set_run_info(run_id)
     try:
         yield run_id
     finally:
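The manage_run context manager now takes only an optional run_id. A minimal usage sketch under the new signature:

```python
import asyncio

from core.base.logger.run_manager import RunManager, manage_run

async def main() -> None:
    run_manager = RunManager()
    # No RunType argument anymore: only an optional run_id can be supplied.
    async with manage_run(run_manager) as run_id:
        print(f"running under {run_id}")

asyncio.run(main())
```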

+ 1 - 2
core/base/pipes/base_pipe.py

@@ -7,7 +7,6 @@ from uuid import UUID
 
 from pydantic import BaseModel
 
-from core.base.logger.base import RunType
 from core.base.logger.run_manager import RunManager, manage_run
 
 logger = logging.getLogger()
@@ -108,7 +107,7 @@ class AsyncPipe(Generic[T]):
         state = state or AsyncState()
 
         async def wrapped_run() -> AsyncGenerator[Any, None]:
-            async with manage_run(run_manager, RunType.UNSPECIFIED) as run_id:  # type: ignore
+            async with manage_run(run_manager) as run_id:  # type: ignore
                 async for result in self._run_logic(  # type: ignore
                     input, state, run_id, *args, **kwargs  # type: ignore
                 ):

+ 2 - 0
core/base/providers/__init__.py

@@ -6,6 +6,7 @@ from .database import (
     DatabaseConnectionManager,
     DatabaseProvider,
     Handler,
+    LimitSettings,
     PostgresConfigurationSettings,
 )
 from .email import EmailConfig, EmailProvider
@@ -41,6 +42,7 @@ __all__ = [
     # Database providers
     "DatabaseConnectionManager",
     "DatabaseConfig",
+    "LimitSettings",
     "PostgresConfigurationSettings",
     "DatabaseProvider",
     "Handler",

+ 89 - 16
core/base/providers/auth.py

@@ -3,7 +3,11 @@ from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Optional
 
 from fastapi import Security
-from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from fastapi.security import (
+    APIKeyHeader,
+    HTTPAuthorizationCredentials,
+    HTTPBearer,
+)
 
 from ..abstractions import R2RException, Token, TokenData
 from ..api.models import User
@@ -18,6 +22,8 @@ logger = logging.getLogger()
 if TYPE_CHECKING:
     from core.database import PostgresDatabaseProvider
 
+api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
+
 
 class AuthConfig(ProviderConfig):
     secret_key: Optional[str] = None
@@ -110,25 +116,92 @@ class AuthProvider(Provider, ABC):
     ) -> dict[str, Token]:
         pass
 
-    async def auth_wrapper(
-        self, auth: Optional[HTTPAuthorizationCredentials] = Security(security)
-    ) -> User:
-        if not self.config.require_authentication and auth is None:
-            return await self._get_default_admin_user()
-
-        if auth is None:
+    def auth_wrapper(
+        self,
+        public: bool = False,
+    ):
+        async def _auth_wrapper(
+            auth: Optional[HTTPAuthorizationCredentials] = Security(
+                self.security
+            ),
+            api_key: Optional[str] = Security(api_key_header),
+        ) -> User:
+            # If authentication is not required and no credentials are provided, return the default admin user
+            if (
+                ((not self.config.require_authentication) or public)
+                and auth is None
+                and api_key is None
+            ):
+                return await self._get_default_admin_user()
+            if not auth and not api_key:
+                raise R2RException(
+                    message="No credentials provided",
+                    status_code=401,
+                )
+            if auth and api_key:
+                raise R2RException(
+                    message="Cannot have both Bearer token and API key",
+                    status_code=400,
+                )
+            # 1. Try JWT if `auth` is present (Bearer token)
+            if auth is not None:
+                credentials = auth.credentials
+                try:
+                    token_data = await self.decode_token(credentials)
+                    user = await self.database_provider.users_handler.get_user_by_email(
+                        token_data.email
+                    )
+                    if user is not None:
+                        return user
+                except R2RException:
+                    # JWT decoding failed for logical reasons (invalid token)
+                    pass
+                except Exception as e:
+                    # JWT decoding failed unexpectedly, log and continue
+                    logger.debug(f"JWT verification failed: {e}")
+
+                # 2. If JWT failed, try API key from Bearer token
+                # Expected format: key_id.raw_api_key
+                if "." in credentials:
+                    key_id, raw_api_key = credentials.split(".", 1)
+                    api_key_record = await self.database_provider.users_handler.get_api_key_record(
+                        key_id
+                    )
+                    if api_key_record is not None:
+                        hashed_key = api_key_record["hashed_key"]
+                        if self.crypto_provider.verify_api_key(
+                            raw_api_key, hashed_key
+                        ):
+                            user = await self.database_provider.users_handler.get_user_by_id(
+                                api_key_record["user_id"]
+                            )
+                            if user is not None and user.is_active:
+                                return user
+
+            # 3. If no Bearer token worked, try the X-API-Key header
+            if api_key is not None and "." in api_key:
+                key_id, raw_api_key = api_key.split(".", 1)
+                api_key_record = await self.database_provider.users_handler.get_api_key_record(
+                    key_id
+                )
+                if api_key_record is not None:
+                    hashed_key = api_key_record["hashed_key"]
+                    if self.crypto_provider.verify_api_key(
+                        raw_api_key, hashed_key
+                    ):
+                        user = await self.database_provider.users_handler.get_user_by_id(
+                            api_key_record["user_id"]
+                        )
+                        if user is not None and user.is_active:
+                            return user
+
+            # If we reach here, both JWT and API key auth failed
             raise R2RException(
-                message="Authentication required.",
+                message="Invalid token or API key",
                 status_code=401,
             )
 
-        try:
-            return await self.user(auth.credentials)
-        except Exception as e:
-            raise R2RException(
-                message=f"Error '{e}' occurred during authentication.",
-                status_code=404,
-            )
+        return _auth_wrapper
 
     @abstractmethod
     async def change_password(
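For callers, the rewritten wrapper means an API key in the key_id.raw_api_key format can be sent either as a Bearer token or in the X-API-Key header, but not both, since supplying both now yields a 400. A client-side sketch (base URL and key values are placeholders):

```python
import urllib.request

BASE_URL = "http://localhost:7272"   # assumed deployment URL
API_KEY = "<key_id>.<raw_api_key>"   # the dotted format the wrapper splits on

# Option 1: the API key travels as a Bearer token.
bearer = urllib.request.Request(
    f"{BASE_URL}/v3/health",
    headers={"Authorization": f"Bearer {API_KEY}"},
)

# Option 2: the dedicated X-API-Key header read by APIKeyHeader.
keyed = urllib.request.Request(
    f"{BASE_URL}/v3/health",
    headers={"X-API-Key": API_KEY},
)

# Send one or the other -- never both in the same request.
with urllib.request.urlopen(keyed, timeout=5) as resp:
    print(resp.status)
```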

+ 3 - 0
core/base/providers/base.py

@@ -8,6 +8,9 @@ from ..abstractions import R2RSerializable
 
 class AppConfig(R2RSerializable):
     project_name: Optional[str] = None
+    default_max_documents_per_user: Optional[int] = 100
+    default_max_chunks_per_user: Optional[int] = 100_000
+    default_max_collections_per_user: Optional[int] = 10
 
     @classmethod
     def create(cls, *args, **kwargs):
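A quick sketch of the new per-user quota knobs and their shipped defaults (constructing AppConfig directly with keyword arguments is an assumption about the R2RSerializable base):

```python
from core.base.providers.base import AppConfig

config = AppConfig(project_name="r2r_default")
assert config.default_max_documents_per_user == 100
assert config.default_max_chunks_per_user == 100_000
assert config.default_max_collections_per_user == 10
```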

+ 83 - 2
core/base/providers/crypto.py

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
-from typing import Optional
+from datetime import datetime
+from typing import Optional, Tuple
 
 from .base import Provider, ProviderConfig
 
@@ -9,7 +10,7 @@ class CryptoConfig(ProviderConfig):
 
     @property
     def supported_providers(self) -> list[str]:
-        return ["bcrypt"]  # Add other crypto providers as needed
+        return ["bcrypt", "nacl"]  # Add other crypto providers as needed
 
     def validate_config(self) -> None:
         if self.provider not in self.supported_providers:
@@ -26,14 +27,94 @@ class CryptoProvider(Provider, ABC):
 
     @abstractmethod
     def get_password_hash(self, password: str) -> str:
+        """Hash a plaintext password using a secure password hashing algorithm (e.g., Argon2i)."""
         pass
 
     @abstractmethod
     def verify_password(
         self, plain_password: str, hashed_password: str
     ) -> bool:
+        """Verify that a plaintext password matches the given hashed password."""
         pass
 
     @abstractmethod
     def generate_verification_code(self, length: int = 32) -> str:
+        """Generate a random code for email verification or reset tokens."""
+        pass
+
+    @abstractmethod
+    def generate_signing_keypair(self) -> Tuple[str, str, str]:
+        """
+        Generate a new Ed25519 signing keypair for request signing.
+
+        Returns:
+            A tuple of (key_id, private_key, public_key).
+            - key_id: A unique identifier for this keypair.
+            - private_key: Base64 encoded Ed25519 private key.
+            - public_key: Base64 encoded Ed25519 public key.
+        """
+        pass
+
+    @abstractmethod
+    def sign_request(self, private_key: str, data: str) -> str:
+        """Sign request data with an Ed25519 private key, returning the signature."""
+        pass
+
+    @abstractmethod
+    def verify_request_signature(
+        self, public_key: str, signature: str, data: str
+    ) -> bool:
+        """Verify a request signature using the corresponding Ed25519 public key."""
+        pass
+
+    @abstractmethod
+    def generate_api_key(self) -> Tuple[str, str]:
+        """
+        Generate a new API key for a user.
+
+        Returns:
+            A tuple (key_id, raw_api_key):
+            - key_id: A unique identifier for the API key.
+            - raw_api_key: The plaintext API key to provide to the user.
+        """
+        pass
+
+    @abstractmethod
+    def hash_api_key(self, raw_api_key: str) -> str:
+        """
+        Hash a raw API key for secure storage in the database.
+        Use strong parameters suitable for long-term secrets.
+        """
+        pass
+
+    @abstractmethod
+    def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
+        """Verify that a provided API key matches the stored hashed version."""
+        pass
+
+    @abstractmethod
+    def generate_secure_token(self, data: dict, expiry: datetime) -> str:
+        """
+        Generate a secure, signed token (e.g., JWT) embedding claims.
+
+        Args:
+            data: The claims to include in the token.
+            expiry: A datetime at which the token expires.
+
+        Returns:
+            A JWT string signed with a secret key.
+        """
+        pass
+
+    @abstractmethod
+    def verify_secure_token(self, token: str) -> Optional[dict]:
+        """
+        Verify a secure token (e.g., JWT).
+
+        Args:
+            token: The token string to verify.
+
+        Returns:
+            The token payload if valid, otherwise None.
+        """
         pass
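The new abstract methods anticipate the "nacl" provider added to supported_providers. A sketch of how sign_request/verify_request_signature could be satisfied with PyNaCl's Ed25519 primitives — an illustration, not the provider's actual implementation (the key_id scheme, for instance, is not shown here):

```python
import base64

from nacl.exceptions import BadSignatureError
from nacl.signing import SigningKey, VerifyKey

# Generate a keypair and base64-encode both halves, as the docstrings describe.
signing_key = SigningKey.generate()
private_b64 = base64.b64encode(bytes(signing_key)).decode()
public_b64 = base64.b64encode(bytes(signing_key.verify_key)).decode()

def sign_request(private_key: str, data: str) -> str:
    sk = SigningKey(base64.b64decode(private_key))
    return base64.b64encode(sk.sign(data.encode()).signature).decode()

def verify_request_signature(public_key: str, signature: str, data: str) -> bool:
    vk = VerifyKey(base64.b64decode(public_key))
    try:
        vk.verify(data.encode(), base64.b64decode(signature))
        return True
    except BadSignatureError:
        return False

sig = sign_request(private_b64, "GET /v3/health")
assert verify_request_signature(public_b64, sig, "GET /v3/health")
```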

+ 49 - 0
core/base/providers/database.py

@@ -134,6 +134,21 @@ class PostgresConfigurationSettings(BaseModel):
     work_mem: Optional[int] = 4096
 
 
+class LimitSettings(BaseModel):
+    global_per_min: Optional[int] = None
+    route_per_min: Optional[int] = None
+    monthly_limit: Optional[int] = None
+
+    def merge_with_defaults(
+        self, defaults: "LimitSettings"
+    ) -> "LimitSettings":
+        return LimitSettings(
+            global_per_min=self.global_per_min or defaults.global_per_min,
+            route_per_min=self.route_per_min or defaults.route_per_min,
+            monthly_limit=self.monthly_limit or defaults.monthly_limit,
+        )
+
+
 class DatabaseConfig(ProviderConfig):
     """A base database configuration class"""
 
@@ -163,6 +178,13 @@ class DatabaseConfig(ProviderConfig):
     )
     graph_search_settings: GraphSearchSettings = GraphSearchSettings()
 
+    # Rate limits
+    limits: LimitSettings = LimitSettings(
+        global_per_min=60, route_per_min=20, monthly_limit=10000
+    )
+    route_limits: dict[str, LimitSettings] = {}
+    user_limits: dict[UUID, LimitSettings] = {}
+
     def __post_init__(self):
         self.validate_config()
         # Capture additional fields
@@ -177,6 +199,33 @@ class DatabaseConfig(ProviderConfig):
     def supported_providers(self) -> list[str]:
         return ["postgres"]
 
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
+        instance = super().from_dict(
+            data
+        )  # or some logic to create the base instance
+
+        limits_data = data.get("limits", {})
+        default_limits = LimitSettings(
+            global_per_min=limits_data.get("global_per_min", 60),
+            route_per_min=limits_data.get("route_per_min", 20),
+            monthly_limit=limits_data.get("monthly_limit", 10000),
+        )
+
+        instance.limits = default_limits
+
+        route_limits_data = limits_data.get("routes", {})
+        for route_str, route_cfg in route_limits_data.items():
+            instance.route_limits[route_str] = LimitSettings(**route_cfg)
+
+        # user_limits parsing if needed:
+        # user_limits_data = limits_data.get("users", {})
+        # for user_str, user_cfg in user_limits_data.items():
+        #     user_id = UUID(user_str)
+        #     instance.user_limits[user_id] = LimitSettings(**user_cfg)
+
+        return instance
+
 
 class DatabaseProvider(Provider):
     connection_manager: DatabaseConnectionManager
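merge_with_defaults uses or-style fallback, so any field left as None inherits the default. A usage sketch grounded directly in the class above:

```python
from core.base.providers.database import LimitSettings

defaults = LimitSettings(global_per_min=60, route_per_min=20, monthly_limit=10000)
route_override = LimitSettings(route_per_min=5)  # e.g. for /v3/retrieval/search

effective = route_override.merge_with_defaults(defaults)
assert effective.route_per_min == 5        # overridden
assert effective.global_per_min == 60      # falls back to the default
assert effective.monthly_limit == 10000    # falls back to the default
```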

+ 8 - 3
core/base/providers/ingestion.py

@@ -1,9 +1,9 @@
 import logging
 from abc import ABC
 from enum import Enum
-from typing import TYPE_CHECKING, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar
 
-from pydantic import BaseModel, Field
+from pydantic import Field
 
 from core.base.abstractions import ChunkEnrichmentSettings
 
@@ -21,6 +21,7 @@ class IngestionConfig(ProviderConfig):
         "app": AppConfig(),
         "provider": "r2r",
         "excluded_parsers": ["mp4"],
+        "chunking_strategy": "recursive",
         "chunk_enrichment_settings": ChunkEnrichmentSettings(),
         "extra_parsers": {},
         "audio_transcription_model": "openai/whisper-1",
@@ -43,12 +44,15 @@ class IngestionConfig(ProviderConfig):
     excluded_parsers: list[str] = Field(
         default_factory=lambda: IngestionConfig._defaults["excluded_parsers"]
     )
+    chunking_strategy: str = Field(
+        default_factory=lambda: IngestionConfig._defaults["chunking_strategy"]
+    )
     chunk_enrichment_settings: ChunkEnrichmentSettings = Field(
         default_factory=lambda: IngestionConfig._defaults[
             "chunk_enrichment_settings"
         ]
     )
-    extra_parsers: dict[str, str] = Field(
+    extra_parsers: dict[str, Any] = Field(
         default_factory=lambda: IngestionConfig._defaults["extra_parsers"]
     )
     audio_transcription_model: str = Field(
@@ -157,6 +161,7 @@ class IngestionConfig(ProviderConfig):
         json_schema_extra = {
             "provider": "r2r",
             "excluded_parsers": ["mp4"],
+            "chunking_strategy": "recursive",
             "chunk_enrichment_settings": ChunkEnrichmentSettings().dict(),
             "extra_parsers": {},
             "audio_transcription_model": "openai/whisper-1",

+ 2 - 0
core/base/utils/__init__.py

@@ -4,6 +4,7 @@ from shared.utils import (
     _decorate_vector_type,
     _get_str_estimation_output,
     decrement_version,
+    deep_update,
     format_search_results_for_llm,
     format_search_results_for_stream,
     generate_default_prompt_id,
@@ -36,6 +37,7 @@ __all__ = [
     "TextSplitter",
     "llm_cost_per_million_tokens",
     "validate_uuid",
+    "deep_update",
     "_decorate_vector_type",
     "_get_str_estimation_output",
 ]
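deep_update is newly re-exported here; its body lives in shared.utils and is not part of this diff. The following re-implementation is only an illustration of the recursive-merge behavior a helper with this name conventionally provides (an assumption, not the actual source):

```python
def deep_update(base: dict, updates: dict) -> dict:
    """Recursively merge `updates` into `base`, descending into nested dicts."""
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)
        else:
            base[key] = value
    return base

merged = deep_update({"a": {"x": 1}}, {"a": {"y": 2}})
assert merged == {"a": {"x": 1, "y": 2}}
```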

+ 2 - 3
core/configs/local_llm.toml

@@ -60,9 +60,8 @@ provider = "simple"
 [ingestion]
 vision_img_model = "ollama/llama3.2-vision"
 vision_pdf_model = "ollama/llama3.2-vision"
+chunks_for_document_summary = 16
+document_summary_model = "ollama/llama3.1"
 
   [ingestion.extra_parsers]
     pdf = "zerox"
-
-chunks_for_document_summary = 16
-document_summary_model = "ollama/llama3.1"

+ 61 - 0
core/configs/r2r_azure_with_test_limits.toml

@@ -0,0 +1,61 @@
+# A config which overrides all instances of `openai` with `azure` in the `r2r.toml` config
+[agent]
+  [agent.generation_config]
+  model = "azure/gpt-4o"
+
+[completion]
+  [completion.generation_config]
+  model = "azure/gpt-4o"
+
+
+
+[database]
+# KG settings
+batch_size = 256
+
+  [database.graph_creation_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+  [database.graph_entity_deduplication_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+  [database.graph_enrichment_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+  [database.graph_search_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+  [database.limits]
+  global_per_min = 10  # Small enough to test quickly
+  monthly_limit = 20  # Small enough to test in one run
+
+  [database.route_limits]
+  "/v3/retrieval/search" = { route_per_min = 5, monthly_limit = 10 }
+
+  [database.user_limits."47e53676-b478-5b3f-a409-234ca2164de5"]
+  global_per_min = 2
+  route_per_min = 1
+
+
+[embedding]
+provider = "litellm"
+base_model = "openai/text-embedding-3-small" # continue with `openai` for embeddings, due to server rate limit on azure
+base_dimension = 512
+
+[file]
+provider = "postgres"
+
+[ingestion]
+provider = "r2r"
+chunking_strategy = "recursive"
+chunk_size = 1_024
+chunk_overlap = 512
+excluded_parsers = ["mp4"]
+
+audio_transcription_model="azure/whisper-1"
+document_summary_model = "azure/gpt-4o-mini"
+vision_img_model = "azure/gpt-4o"
+vision_pdf_model = "azure/gpt-4o"
+
+  [ingestion.chunk_enrichment_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
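The [database.limits] and per-route tables above exercise the rate-limit plumbing added in core/base/providers/database.py. A sketch of the dict shape that the new DatabaseConfig.from_dict expects (whether the base from_dict requires further keys is an assumption; note also that the TOML above nests route limits under [database.route_limits], while from_dict looks for a "routes" key under "limits" — how the TOML loader bridges the two isn't shown in this excerpt):

```python
from core.base.providers.database import DatabaseConfig

data = {
    "provider": "postgres",
    "limits": {
        "global_per_min": 10,
        "monthly_limit": 20,
        "routes": {
            "/v3/retrieval/search": {"route_per_min": 5, "monthly_limit": 10},
        },
    },
}

config = DatabaseConfig.from_dict(data)
print(config.limits.global_per_min)                               # 10
print(config.route_limits["/v3/retrieval/search"].route_per_min)  # 5
```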

+ 39 - 206
core/database/chunks.py

@@ -23,6 +23,7 @@ from core.base import (
 )
 
 from .base import PostgresConnectionManager
+from .filters import apply_filters
 from .vecs.exc import ArgError, FilterError
 
 logger = logging.getLogger()
@@ -83,13 +84,6 @@ class PostgresChunksHandler(Handler):
 class PostgresChunksHandler(Handler):
     TABLE_NAME = VectorTableName.CHUNKS
 
-    COLUMN_VARS = [
-        "id",
-        "document_id",
-        "owner_id",
-        "collection_ids",
-    ]
-
     def __init__(
         self,
         project_name: str,
         project_name: str,
         project_name: str,
@@ -298,6 +292,7 @@ class PostgresChunksHandler(Handler):
         ]
 
         params: list[str | int | bytes] = []
+
         # For binary vectors (INT1), implement two-stage search
         if self.quantization_type == VectorQuantizationType.INT1:
             # Convert query vector to binary format
@@ -306,6 +301,7 @@ class PostgresChunksHandler(Handler):
             extended_limit = (
                 search_settings.limit * 20
             )  # Get 20x candidates for re-ranking
+
             if (
                 imeasure_obj == IndexMeasure.hamming_distance
                 or imeasure_obj == IndexMeasure.jaccard_distance
@@ -331,10 +327,9 @@ class PostgresChunksHandler(Handler):
             params.append(stage1_param)
 
             if search_settings.filters:
-                where_clause = self._build_filters(
-                    search_settings.filters, params
+                where_clause, params = apply_filters(
+                    search_settings.filters, params, mode="where_clause"
                 )
-                where_clause = f"WHERE {where_clause}"
 
             # First stage: Get candidates using binary search
             query = f"""
@@ -371,7 +366,7 @@ class PostgresChunksHandler(Handler):
             )
 
         else:
-            # Standard float vector handling - unchanged from original
+            # Standard float vector handling
             distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector({self.dimension})"
             query_param = str(query_vector)
 
@@ -385,10 +380,12 @@ class PostgresChunksHandler(Handler):
             params.append(query_param)
 
             if search_settings.filters:
-                where_clause = self._build_filters(
-                    search_settings.filters, params
+                where_clause, new_params = apply_filters(
+                    search_settings.filters,
+                    params,
+                    mode="where_clause",  # Get just conditions without WHERE
                 )
-                where_clause = f"WHERE {where_clause}"
+                params = new_params
 
             query = f"""
             SELECT {select_clause}
@@ -427,36 +424,36 @@ class PostgresChunksHandler(Handler):
         self, query_text: str, search_settings: SearchSettings
     ) -> list[ChunkSearchResult]:
 
-        where_clauses = []
+        conditions = []
         params: list[str | int | bytes] = [query_text]
 
+        conditions.append("fts @@ websearch_to_tsquery('english', $1)")
+
         if search_settings.filters:
-            filters_clause = self._build_filters(
-                search_settings.filters, params
+            filter_condition, params = apply_filters(
+                search_settings.filters, params, mode="condition_only"
             )
-            where_clauses.append(filters_clause)
+            if filter_condition:
+                conditions.append(filter_condition)
 
-        if where_clauses:
-            where_clause = (
-                "WHERE "
-                + " AND ".join(where_clauses)
-                + " AND fts @@ websearch_to_tsquery('english', $1)"
-            )
-        else:
-            where_clause = "WHERE fts @@ websearch_to_tsquery('english', $1)"
+        where_clause = "WHERE " + " AND ".join(conditions)
 
         query = f"""
             SELECT
-                id, document_id, owner_id, collection_ids, text, metadata,
+                id,
+                document_id,
+                owner_id,
+                collection_ids,
+                text,
+                metadata,
                 ts_rank(fts, websearch_to_tsquery('english', $1), 32) as rank
             FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
             {where_clause}
-        """
-
-        query += f"""
             ORDER BY rank DESC
-            OFFSET ${len(params)+1} LIMIT ${len(params)+2}
+            OFFSET ${len(params)+1}
+            LIMIT ${len(params)+2}
         """
+
         params.extend(
             [
                 search_settings.offset,
@@ -586,7 +583,9 @@ class PostgresChunksHandler(Handler):
         self, filters: dict[str, Any]
     ) -> dict[str, dict[str, str]]:
         params: list[str | int | bytes] = []
-        where_clause = self._build_filters(filters, params)
+        where_clause, params = apply_filters(
+            filters, params, mode="condition_only"
+        )
 
         query = f"""
         DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
@@ -856,157 +855,6 @@ class PostgresChunksHandler(Handler):
             raise Exception(f"Failed to create index: {e}")
         return None
 
-    def _build_filters(
-        self, filters: dict, parameters: list[str | int | bytes]
-    ) -> str:
-
-        def parse_condition(key: str, value: Any) -> str:  # type: ignore
-            # nonlocal parameters
-            if key in self.COLUMN_VARS:
-                # Handle column-based filters
-                if isinstance(value, dict):
-                    op, clause = next(iter(value.items()))
-                    if op == "$eq":
-                        parameters.append(clause)
-                        return f"{key} = ${len(parameters)}"
-                    elif op == "$ne":
-                        parameters.append(clause)
-                        return f"{key} != ${len(parameters)}"
-                    elif op == "$in":
-                        parameters.append(clause)
-                        return f"{key} = ANY(${len(parameters)})"
-                    elif op == "$nin":
-                        parameters.append(clause)
-                        return f"{key} != ALL(${len(parameters)})"
-                    elif op == "$overlap":
-                        parameters.append(clause)
-                        return f"{key} && ${len(parameters)}"
-                    elif op == "$contains":
-                        parameters.append(clause)
-                        return f"{key} @> ${len(parameters)}"
-                    elif op == "$any":
-                        if key == "collection_ids":
-                            parameters.append(f"%{clause}%")
-                            return f"array_to_string({key}, ',') LIKE ${len(parameters)}"
-                        parameters.append(clause)
-                        return f"${len(parameters)} = ANY({key})"
-                    else:
-                        raise FilterError(
-                            f"Unsupported operator for column {key}: {op}"
-                        )
-                else:
-                    # Handle direct equality
-                    parameters.append(value)
-                    return f"{key} = ${len(parameters)}"
-            else:
-                # Handle JSON-based filters
-                json_col = "metadata"
-                if key.startswith("metadata."):
-                    key = key.split("metadata.")[1]
-                if isinstance(value, dict):
-                    op, clause = next(iter(value.items()))
-                    if op not in (
-                        "$eq",
-                        "$ne",
-                        "$lt",
-                        "$lte",
-                        "$gt",
-                        "$gte",
-                        "$in",
-                        "$contains",
-                    ):
-                        raise FilterError("unknown operator")
-
-                    if op == "$eq":
-                        parameters.append(json.dumps(clause))
-                        return (
-                            f"{json_col}->'{key}' = ${len(parameters)}::jsonb"
-                        )
-                    elif op == "$ne":
-                        parameters.append(json.dumps(clause))
-                        return (
-                            f"{json_col}->'{key}' != ${len(parameters)}::jsonb"
-                        )
-                    elif op == "$lt":
-                        parameters.append(json.dumps(clause))
-                        return f"({json_col}->'{key}')::float < (${len(parameters)}::jsonb)::float"
-                    elif op == "$lte":
-                        parameters.append(json.dumps(clause))
-                        return f"({json_col}->'{key}')::float <= (${len(parameters)}::jsonb)::float"
-                    elif op == "$gt":
-                        parameters.append(json.dumps(clause))
-                        return f"({json_col}->'{key}')::float > (${len(parameters)}::jsonb)::float"
-                    elif op == "$gte":
-                        parameters.append(json.dumps(clause))
-                        return f"({json_col}->'{key}')::float >= (${len(parameters)}::jsonb)::float"
-                    elif op == "$in":
-                        # Ensure clause is a list
-                        if not isinstance(clause, list):
-                            raise FilterError(
-                                "argument to $in filter must be a list"
-                            )
-                        # Append the Python list as a parameter; many drivers can convert Python lists to arrays
-                        parameters.append(clause)
-                        # Cast the parameter to a text array type
-                        return f"(metadata->>'{key}')::text = ANY(${len(parameters)}::text[])"
-
-                    # elif op == "$in":
-                    #     if not isinstance(clause, list):
-                    #         raise FilterError(
-                    #             "argument to $in filter must be a list"
-                    #         )
-                    #     parameters.append(json.dumps(clause))
-                    #     return f"{json_col}->'{key}' = ANY(SELECT jsonb_array_elements(${len(parameters)}::jsonb))"
-                    elif op == "$contains":
-                        if isinstance(clause, (int, float, str)):
-                            clause = [clause]
-                        # Now clause is guaranteed to be a list or array-like structure.
-                        parameters.append(json.dumps(clause))
-                        return (
-                            f"{json_col}->'{key}' @> ${len(parameters)}::jsonb"
-                        )
-
-                        # if not isinstance(clause, (int, str, float, list)):
-                        #     raise FilterError(
-                        #         "argument to $contains filter must be a scalar or array"
-                        #     )
-                        # parameters.append(json.dumps(clause))
-                        # return (
-                        #     f"{json_col}->'{key}' @> ${len(parameters)}::jsonb"
-                        # )
-
-        def parse_filter(filter_dict: dict) -> str:
-            filter_conditions = []
-            for key, value in filter_dict.items():
-                if key == "$and":
-                    and_conditions = [
-                        parse_filter(f) for f in value if f
-                    ]  # Skip empty dictionaries
-                    if and_conditions:
-                        filter_conditions.append(
-                            f"({' AND '.join(and_conditions)})"
-                        )
-                elif key == "$or":
-                    or_conditions = [
-                        parse_filter(f) for f in value if f
-                    ]  # Skip empty dictionaries
-                    if or_conditions:
-                        filter_conditions.append(
-                            f"({' OR '.join(or_conditions)})"
-                        )
-                else:
-                    filter_conditions.append(parse_condition(key, value))
-
-            # Check if there is only a single condition
-            if len(filter_conditions) == 1:
-                return filter_conditions[0]
-            else:
-                return " AND ".join(filter_conditions)
-
-        where_clause = parse_filter(filters)
-
-        return where_clause
-
     async def list_indices(
         self,
         offset: int,
@@ -1254,46 +1102,31 @@ class PostgresChunksHandler(Handler):
                 - total_entries: Total number of chunks matching the filters
                 - page_info: Pagination information
         """
-        # Validate sort parameters
-        valid_sort_columns = {
-            "created_at": "metadata->>'created_at'",
-            "updated_at": "metadata->>'updated_at'",
-            "chunk_order": "metadata->>'chunk_order'",
-            "text": "text",
-        }
-
-        # Build the select clause
         vector_select = ", vec" if include_vectors else ""
         select_clause = f"""
             id, document_id, owner_id, collection_ids,
             text, metadata{vector_select}, COUNT(*) OVER() AS total
         """

-        # Build the where clause if filters are provided
-        where_clause = ""
         params: list[str | int | bytes] = []
+        where_clause = ""
         if filters:
-            where_clause = self._build_filters(filters, params)
-            where_clause = f"WHERE {where_clause}"
+            where_clause, params = apply_filters(
+                filters, params, mode="where_clause"
+            )

-        # Construct the final query
         query = f"""
         query = f"""
         SELECT {select_clause}
         SELECT {select_clause}
         FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
         FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
         {where_clause}
         {where_clause}
-        LIMIT $%s
-        OFFSET $%s
+        LIMIT ${len(params) + 1}
+        OFFSET ${len(params) + 2}
         """
         """
 
 
-        # Add pagination parameters
         params.extend([limit, offset])
-        param_indices = list(range(1, len(params) + 1))
-        formatted_query = query % tuple(param_indices)

         # Execute the query
-        results = await self.connection_manager.fetch_query(
-            formatted_query, params
-        )
+        results = await self.connection_manager.fetch_query(query, params)

         # Process results
         chunks = []
@@ -1422,7 +1255,7 @@ class PostgresChunksHandler(Handler):

         # Add any additional filters
         if settings.filters:
-            filter_clause = self._build_filters(settings.filters, params)
+            filter_clause, params = apply_filters(settings.filters, params)
             where_clauses.append(filter_clause)

         if where_clauses:

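A note on the hunk above: because `apply_filters` appends its bind values to `params` in place, the next free asyncpg placeholder is always `$len(params)+1`, which is what lets the old `$%s` / `query % tuple(param_indices)` round-trip be deleted. A minimal standalone sketch of that invariant (the hand-written `where_clause` stands in for `apply_filters` output):

```python
# Sketch only: illustrates the placeholder-numbering invariant used above.
params: list = []

# Pretend apply_filters produced this clause and appended one bind value.
where_clause = "WHERE metadata->>'lang' = $1"
params.append("en")

limit, offset = 10, 20
query = f"""
SELECT id, text FROM chunks
{where_clause}
LIMIT ${len(params) + 1}
OFFSET ${len(params) + 2}
"""
params.extend([limit, offset])

# query now reads "... LIMIT $2 OFFSET $3" and params == ["en", 10, 20],
# so the $n indices and the parameter list stay aligned by construction.
print(query, params)
```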
+ 1 - 1
core/database/collections.py

@@ -7,8 +7,8 @@ from asyncpg.exceptions import UniqueViolationError
 from fastapi import HTTPException

 from core.base import (
-    Handler,
     DatabaseConfig,
+    Handler,
     KGExtractionStatus,
     R2RException,
     generate_default_user_collection_id,

+ 189 - 111
core/database/conversations.py

@@ -1,7 +1,9 @@
 import json
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 from uuid import UUID, uuid4

+from fastapi import HTTPException
+
 from core.base import Handler, Message, R2RException
 from shared.api.models.management.responses import (
     ConversationResponse,
@@ -19,10 +21,6 @@ class PostgresConversationsHandler(Handler):
         self.connection_manager = connection_manager

     async def create_tables(self):
-        # Ensure the uuid_generate_v4() extension is available
-        # Depending on your environment, you may need a separate call:
-        # await self.connection_manager.execute_query("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";")
-
         create_conversations_query = f"""
         CREATE TABLE IF NOT EXISTS {self._get_table_name("conversations")} (
             id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
@@ -48,103 +46,90 @@ class PostgresConversationsHandler(Handler):
         await self.connection_manager.execute_query(create_messages_query)

     async def create_conversation(
-        self, user_id: Optional[UUID] = None, name: Optional[str] = None
+        self,
+        user_id: Optional[UUID] = None,
+        name: Optional[str] = None,
     ) -> ConversationResponse:
         query = f"""
             INSERT INTO {self._get_table_name("conversations")} (user_id, name)
             VALUES ($1, $2)
             RETURNING id, extract(epoch from created_at) as created_at_epoch
         """
-        result = await self.connection_manager.fetchrow_query(
-            query, [user_id, name]
-        )
-
-        if not result:
-            raise R2RException(
-                status_code=500, message="Failed to create conversation."
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query, [user_id, name]
             )

-        return ConversationResponse(
-            id=str(result["id"]),
-            created_at=result["created_at_epoch"],
-        )
-
-    async def verify_conversation_access(
-        self, conversation_id: UUID, user_id: UUID
-    ) -> bool:
-        query = f"""
-            SELECT 1 FROM {self._get_table_name("conversations")}
-            WHERE id = $1 AND (user_id IS NULL OR user_id = $2)
-        """
-        row = await self.connection_manager.fetchrow_query(
-            query, [conversation_id, user_id]
-        )
-        return row is not None
+            return ConversationResponse(
+                id=result["id"],
+                created_at=result["created_at_epoch"],
+                user_id=user_id or None,
+                name=name or None,
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to create conversation: {str(e)}",
+            ) from e

     async def get_conversations_overview(
         self,
         offset: int,
         limit: int,
-        user_ids: Optional[UUID | List[UUID]] = None,
-        conversation_ids: Optional[List[UUID]] = None,
-    ) -> Dict[str, Any]:
-        # Construct conditions
+        filter_user_ids: Optional[list[UUID]] = None,
+        conversation_ids: Optional[list[UUID]] = None,
+    ) -> dict[str, Any]:
         conditions = []
-        params = []
+        params: list = []
         param_index = 1

-        if user_ids is not None:
-            if isinstance(user_ids, UUID):
-                conditions.append(f"user_id = ${param_index}")
-                params.append(user_ids)
-                param_index += 1
-            else:
-                # user_ids is a list of UUIDs
-                placeholders = ", ".join(
-                    f"${i+param_index}" for i in range(len(user_ids))
+        if filter_user_ids:
+            conditions.append(
+                f"""
+                c.user_id IN (
+                    SELECT id
+                    FROM {self.project_name}.users
+                    WHERE id = ANY(${param_index})
                 )
-                conditions.append(
-                    f"user_id = ANY(ARRAY[{placeholders}]::uuid[])"
-                )
-                params.extend(user_ids)
-                param_index += len(user_ids)
-
-        if conversation_ids:
-            placeholders = ", ".join(
-                f"${i+param_index}" for i in range(len(conversation_ids))
+            """
             )
-            conditions.append(f"id = ANY(ARRAY[{placeholders}]::uuid[])")
-            params.extend(conversation_ids)
-            param_index += len(conversation_ids)
-
-        where_clause = ""
-        if conditions:
-            where_clause = "WHERE " + " AND ".join(conditions)
+            params.append(filter_user_ids)
+            param_index += 1

-        limit_clause = ""
-        if limit != -1:
-            limit_clause = f"LIMIT ${param_index}"
-            params.append(limit)
+        if conversation_ids:
+            conditions.append(f"c.id = ANY(${param_index})")
+            params.append(conversation_ids)
             param_index += 1

-        offset_clause = f"OFFSET ${param_index}"
-        params.append(offset)
+        where_clause = (
+            "WHERE " + " AND ".join(conditions) if conditions else ""
+        )
 
 
         query = f"""
         query = f"""
             WITH conversation_overview AS (
             WITH conversation_overview AS (
-                SELECT id, extract(epoch from created_at) as created_at_epoch, user_id, name
-                FROM {self._get_table_name("conversations")}
+                SELECT c.id,
+                    extract(epoch from c.created_at) as created_at_epoch,
+                    c.user_id,
+                    c.name
+                FROM {self._get_table_name("conversations")} c
                 {where_clause}
             ),
             counted_overview AS (
                 SELECT *,
-                       COUNT(*) OVER() AS total_entries
+                    COUNT(*) OVER() AS total_entries
                 FROM conversation_overview
             )
             SELECT * FROM counted_overview
             ORDER BY created_at_epoch DESC
-            {limit_clause} {offset_clause}
+            OFFSET ${param_index}
         """
         """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            query += f" LIMIT ${param_index}"
+            params.append(limit)
+
         results = await self.connection_manager.fetch_query(query, params)

         if not results:
@@ -224,56 +209,70 @@ class PostgresConversationsHandler(Handler):
                 status_code=500, message="Failed to insert message."
             )

-        return MessageResponse(id=str(message_id), message=content)
+        return MessageResponse(id=message_id, message=content)

     async def edit_message(
         self,
         message_id: UUID,
-        new_content: str,
-        additional_metadata: dict = {},
-    ) -> Dict[str, Any]:
+        new_content: str | None = None,
+        additional_metadata: dict | None = None,
+    ) -> dict[str, Any]:
         # Get the original message
         query = f"""
-            SELECT conversation_id, parent_id, content, metadata
+            SELECT conversation_id, parent_id, content, metadata, created_at
             FROM {self._get_table_name("messages")}
             WHERE id = $1
         """
         row = await self.connection_manager.fetchrow_query(query, [message_id])
         if not row:
             raise R2RException(
-                status_code=404, message=f"Message {message_id} not found."
+                status_code=404,
+                message=f"Message {message_id} not found.",
             )

         old_content = json.loads(row["content"])
         old_metadata = json.loads(row["metadata"])

-        # Update the content
-        old_message = Message(**old_content)
-        edited_message = Message(
-            role=old_message.role,
-            content=new_content,
-            name=old_message.name,
-            function_call=old_message.function_call,
-            tool_calls=old_message.tool_calls,
-        )
-
-        # Merge metadata and mark edited
-        new_metadata = {**old_metadata, **additional_metadata, "edited": True}
+        if new_content is not None:
+            old_message = Message(**old_content)
+            edited_message = Message(
+                role=old_message.role,
+                content=new_content,
+                name=old_message.name,
+                function_call=old_message.function_call,
+                tool_calls=old_message.tool_calls,
+            )
+            content_to_save = edited_message.model_dump()
+        else:
+            content_to_save = old_content
+
+        additional_metadata = additional_metadata or {}
+
+        new_metadata = {
+            **old_metadata,
+            **additional_metadata,
+            "edited": (
+                True
+                if new_content is not None
+                else old_metadata.get("edited", False)
+            ),
+        }

-        # Instead of branching, we'll simply replace the message content and metadata:
-        # NOTE: If you prefer versioning or forking behavior, you'd add a new message.
-        # For simplicity, we just edit the existing message.
+        # Update message without changing the timestamp
         update_query = f"""
             UPDATE {self._get_table_name("messages")}
-            SET content = $1::jsonb, metadata = $2::jsonb, created_at = NOW()
-            WHERE id = $3
+            SET content = $1::jsonb,
+                metadata = $2::jsonb,
+                created_at = $3
+            WHERE id = $4
             RETURNING id
         """
         updated = await self.connection_manager.fetchrow_query(
             update_query,
             [
-                json.dumps(edited_message.model_dump()),
+                json.dumps(content_to_save),
                 json.dumps(new_metadata),
+                row["created_at"],
                 message_id,
             ],
         )
@@ -284,7 +283,11 @@ class PostgresConversationsHandler(Handler):

         return {
             "id": str(message_id),
-            "message": edited_message,
+            "message": (
+                Message(**content_to_save)
+                if isinstance(content_to_save, dict)
+                else content_to_save
+            ),
             "metadata": new_metadata,
             "metadata": new_metadata,
         }
         }
 
 
@@ -302,7 +305,7 @@ class PostgresConversationsHandler(Handler):
                 status_code=404, message=f"Message {message_id} not found."
             )

-        current_metadata = row["metadata"] or {}
+        current_metadata = json.loads(row["metadata"]) or {}
         updated_metadata = {**current_metadata, **metadata}

         update_query = f"""
@@ -311,17 +314,37 @@ class PostgresConversationsHandler(Handler):
             WHERE id = $2
         """
         await self.connection_manager.execute_query(
-            update_query, [updated_metadata, message_id]
+            update_query, [json.dumps(updated_metadata), message_id]
         )

     async def get_conversation(
-        self, conversation_id: UUID
-    ) -> List[MessageResponse]:
-        # Check conversation
-        conv_query = f"SELECT extract(epoch from created_at) AS created_at_epoch FROM {self._get_table_name('conversations')} WHERE id = $1"
-        conv_row = await self.connection_manager.fetchrow_query(
-            conv_query, [conversation_id]
-        )
+        self,
+        conversation_id: UUID,
+        filter_user_ids: Optional[list[UUID]] = None,
+    ) -> list[MessageResponse]:
+        conditions = ["c.id = $1"]
+        params: list = [conversation_id]
+
+        if filter_user_ids:
+            param_index = 2
+            conditions.append(
+                f"""
+                c.user_id IN (
+                    SELECT id
+                    FROM {self.project_name}.users
+                    WHERE id = ANY(${param_index})
+                )
+            """
+            )
+            params.append(filter_user_ids)
+
+        query = f"""
+            SELECT c.id, extract(epoch from c.created_at) AS created_at_epoch
+            FROM {self._get_table_name('conversations')} c
+            WHERE {' AND '.join(conditions)}
+        """
+
+        conv_row = await self.connection_manager.fetchrow_query(query, params)
         if not conv_row:
             raise R2RException(
                 status_code=404,
@@ -329,8 +352,6 @@ class PostgresConversationsHandler(Handler):
             )

         # Retrieve messages in chronological order
-        # We'll recursively gather messages based on parent_id = NULL as root.
-        # Since no branching, we simply order by created_at.
         msg_query = f"""
         msg_query = f"""
             SELECT id, content, metadata
             SELECT id, content, metadata
             FROM {self._get_table_name("messages")}
             FROM {self._get_table_name("messages")}
@@ -341,21 +362,78 @@ class PostgresConversationsHandler(Handler):
             msg_query, [conversation_id]
         )

-        print("results = ", results)
         return [
             MessageResponse(
-                id=str(row["id"]),
+                id=row["id"],
                 message=Message(**json.loads(row["content"])),
                 metadata=json.loads(row["metadata"]),
             )
             for row in results
         ]

-    async def delete_conversation(self, conversation_id: UUID):
-        # Check if conversation exists
-        conv_query = f"SELECT 1 FROM {self._get_table_name('conversations')} WHERE id = $1"
+    async def update_conversation(
+        self, conversation_id: UUID, name: str
+    ) -> ConversationResponse:
+        try:
+            # Check if conversation exists
+            conv_query = f"SELECT 1 FROM {self._get_table_name('conversations')} WHERE id = $1"
+            conv_row = await self.connection_manager.fetchrow_query(
+                conv_query, [conversation_id]
+            )
+            if not conv_row:
+                raise R2RException(
+                    status_code=404,
+                    message=f"Conversation {conversation_id} not found.",
+                )
+
+            update_query = f"""
+            UPDATE {self._get_table_name('conversations')}
+            SET name = $1 WHERE id = $2
+            RETURNING user_id, extract(epoch from created_at) as created_at_epoch
+            """
+            updated_row = await self.connection_manager.fetchrow_query(
+                update_query, [name, conversation_id]
+            )
+            return ConversationResponse(
+                id=conversation_id,
+                created_at=updated_row["created_at_epoch"],
+                user_id=updated_row["user_id"] or None,
+                name=name,
+            )
+        except R2RException:
+            raise
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to update conversation: {str(e)}",
+            ) from e
+
+    async def delete_conversation(
+        self,
+        conversation_id: UUID,
+        filter_user_ids: Optional[list[UUID]] = None,
+    ) -> None:
+        conditions = ["c.id = $1"]
+        params: list = [conversation_id]
+
+        if filter_user_ids:
+            param_index = 2
+            conditions.append(
+                f"""
+                c.user_id IN (
+                    SELECT id
+                    FROM {self.project_name}.users
+                    WHERE id = ANY(${param_index})
+                )
+            """
+            )
+            params.append(filter_user_ids)
+
+        conv_query = f"""
+            SELECT 1
+            FROM {self._get_table_name('conversations')} c
+            WHERE {' AND '.join(conditions)}
+        """
         conv_row = await self.connection_manager.fetchrow_query(
-            conv_query, [conversation_id]
+            conv_query, params
         )
         if not conv_row:
             raise R2RException(

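The recurring `filter_user_ids` pattern above scopes conversation reads and deletes to a set of permitted users. A rough standalone sketch of how the condition list and the parameter list grow together (simplified to a plain `ANY` check; the commit itself routes through a `{project_name}.users` subquery):

```python
# Sketch only: condition/parameter assembly for user-scoped access checks.
from uuid import UUID, uuid4

def build_access_query(conversation_id: UUID, filter_user_ids: list[UUID] | None):
    conditions = ["c.id = $1"]
    params: list = [conversation_id]
    if filter_user_ids:
        # Simplified: the commit checks c.user_id against a users subquery.
        conditions.append(f"c.user_id = ANY(${len(params) + 1})")
        params.append(filter_user_ids)
    query = f"""
        SELECT c.id
        FROM conversations c
        WHERE {' AND '.join(conditions)}
    """
    return query, params

query, params = build_access_query(uuid4(), [uuid4(), uuid4()])
print(query, params)  # WHERE c.id = $1 AND c.user_id = ANY($2)
```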
+ 10 - 148
core/database/documents.py

@@ -9,9 +9,9 @@ import asyncpg
 from fastapi import HTTPException

 from core.base import (
-    Handler,
     DocumentResponse,
     DocumentType,
+    Handler,
     IngestionStatus,
     KGEnrichmentStatus,
     KGExtractionStatus,
@@ -20,18 +20,13 @@ from core.base import (
 )

 from .base import PostgresConnectionManager
+from .filters import apply_filters

 logger = logging.getLogger()


 class PostgresDocumentsHandler(Handler):
     TABLE_NAME = "documents"
-    COLUMN_VARS = [
-        "extraction_id",
-        "id",
-        "owner_id",
-        "collection_ids",
-    ]

     def __init__(
         self,
@@ -517,12 +512,12 @@ class PostgresDocumentsHandler(Handler):
         where_clauses = ["summary_embedding IS NOT NULL"]
         where_clauses = ["summary_embedding IS NOT NULL"]
         params: list[str | int | bytes] = [str(query_embedding)]
         params: list[str | int | bytes] = [str(query_embedding)]
 
 
-        # Handle filters
         if search_settings.filters:
-            filter_clause = self._build_filters(
-                search_settings.filters, params
+            filter_condition, params = apply_filters(
+                search_settings.filters, params, mode="condition_only"
             )
-            where_clauses.append(filter_clause)
+            if filter_condition:
+                where_clauses.append(filter_condition)

         where_clause = " AND ".join(where_clauses)

@@ -599,12 +594,12 @@ class PostgresDocumentsHandler(Handler):
         where_clauses = ["raw_tsvector @@ websearch_to_tsquery('english', $1)"]
         where_clauses = ["raw_tsvector @@ websearch_to_tsquery('english', $1)"]
         params: list[str | int | bytes] = [query_text]
         params: list[str | int | bytes] = [query_text]
 
 
-        # Handle filters
         if search_settings.filters:
-            filter_clause = self._build_filters(
-                search_settings.filters, params
+            filter_condition, params = apply_filters(
+                search_settings.filters, params, mode="condition_only"
             )
-            where_clauses.append(filter_clause)
+            if filter_condition:
+                where_clauses.append(filter_condition)

         where_clause = " AND ".join(where_clauses)

@@ -798,136 +793,3 @@ class PostgresDocumentsHandler(Handler):
             )
         else:
             return await self.full_text_document_search(query_text, settings)
-
-    # TODO - Remove copy pasta, consolidate
-    def _build_filters(
-        self, filters: dict, parameters: list[str | int | bytes]
-    ) -> str:
-
-        def parse_condition(key: str, value: Any) -> str:  # type: ignore
-            # nonlocal parameters
-            if key in self.COLUMN_VARS:
-                # Handle column-based filters
-                if isinstance(value, dict):
-                    op, clause = next(iter(value.items()))
-                    if op == "$eq":
-                        parameters.append(clause)
-                        return f"{key} = ${len(parameters)}"
-                    elif op == "$ne":
-                        parameters.append(clause)
-                        return f"{key} != ${len(parameters)}"
-                    elif op == "$in":
-                        parameters.append(clause)
-                        return f"{key} = ANY(${len(parameters)})"
-                    elif op == "$nin":
-                        parameters.append(clause)
-                        return f"{key} != ALL(${len(parameters)})"
-                    elif op == "$overlap":
-                        parameters.append(clause)
-                        return f"{key} && ${len(parameters)}"
-                    elif op == "$contains":
-                        parameters.append(clause)
-                        return f"{key} @> ${len(parameters)}"
-                    elif op == "$any":
-                        if key == "collection_ids":
-                            parameters.append(f"%{clause}%")
-                            return f"array_to_string({key}, ',') LIKE ${len(parameters)}"
-                        parameters.append(clause)
-                        return f"${len(parameters)} = ANY({key})"
-                    else:
-                        raise ValueError(
-                            f"Unsupported operator for column {key}: {op}"
-                        )
-                else:
-                    # Handle direct equality
-                    parameters.append(value)
-                    return f"{key} = ${len(parameters)}"
-            else:
-                # Handle JSON-based filters
-                json_col = "metadata"
-                if key.startswith("metadata."):
-                    key = key.split("metadata.")[1]
-                if isinstance(value, dict):
-                    op, clause = next(iter(value.items()))
-                    if op not in (
-                        "$eq",
-                        "$ne",
-                        "$lt",
-                        "$lte",
-                        "$gt",
-                        "$gte",
-                        "$in",
-                        "$contains",
-                    ):
-                        raise ValueError("unknown operator")
-
-                    if op == "$eq":
-                        parameters.append(json.dumps(clause))
-                        return (
-                            f"{json_col}->'{key}' = ${len(parameters)}::jsonb"
-                        )
-                    elif op == "$ne":
-                        parameters.append(json.dumps(clause))
-                        return (
-                            f"{json_col}->'{key}' != ${len(parameters)}::jsonb"
-                        )
-                    elif op == "$lt":
-                        parameters.append(json.dumps(clause))
-                        return f"({json_col}->'{key}')::float < (${len(parameters)}::jsonb)::float"
-                    elif op == "$lte":
-                        parameters.append(json.dumps(clause))
-                        return f"({json_col}->'{key}')::float <= (${len(parameters)}::jsonb)::float"
-                    elif op == "$gt":
-                        parameters.append(json.dumps(clause))
-                        return f"({json_col}->'{key}')::float > (${len(parameters)}::jsonb)::float"
-                    elif op == "$gte":
-                        parameters.append(json.dumps(clause))
-                        return f"({json_col}->'{key}')::float >= (${len(parameters)}::jsonb)::float"
-                    elif op == "$in":
-                        if not isinstance(clause, list):
-                            raise ValueError(
-                                "argument to $in filter must be a list"
-                            )
-                        parameters.append(json.dumps(clause))
-                        return f"{json_col}->'{key}' = ANY(SELECT jsonb_array_elements(${len(parameters)}::jsonb))"
-                    elif op == "$contains":
-                        if not isinstance(clause, (int, str, float, list)):
-                            raise ValueError(
-                                "argument to $contains filter must be a scalar or array"
-                            )
-                        parameters.append(json.dumps(clause))
-                        return (
-                            f"{json_col}->'{key}' @> ${len(parameters)}::jsonb"
-                        )
-
-        def parse_filter(filter_dict: dict) -> str:
-            filter_conditions = []
-            for key, value in filter_dict.items():
-                if key == "$and":
-                    and_conditions = [
-                        parse_filter(f) for f in value if f
-                    ]  # Skip empty dictionaries
-                    if and_conditions:
-                        filter_conditions.append(
-                            f"({' AND '.join(and_conditions)})"
-                        )
-                elif key == "$or":
-                    or_conditions = [
-                        parse_filter(f) for f in value if f
-                    ]  # Skip empty dictionaries
-                    if or_conditions:
-                        filter_conditions.append(
-                            f"({' OR '.join(or_conditions)})"
-                        )
-                else:
-                    filter_conditions.append(parse_condition(key, value))
-
-            # Check if there is only a single condition
-            if len(filter_conditions) == 1:
-                return filter_conditions[0]
-            else:
-                return " AND ".join(filter_conditions)
-
-        where_clause = parse_filter(filters)
-
-        return where_clause

+ 433 - 0
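With the copy-pasted `_build_filters` deleted from both handlers, filtering is centralized in the new `core/database/filters.py` (next section). A hedged sketch of the two call modes used above, assuming the module is importable as the import hunk shows:

```python
# Sketch only: the two apply_filters modes used by the handlers.
from core.database.filters import apply_filters

# "where_clause" mode returns a fragment already prefixed with WHERE,
# as used by list_chunks above.
params: list = []
clause, params = apply_filters(
    {"document_id": {"$eq": "11111111-1111-1111-1111-111111111111"}},
    params,
    mode="where_clause",
)
# clause == "WHERE document_id = $1"

# "condition_only" mode returns a bare predicate for AND-joining with
# handler-specific clauses, as in the document search paths above.
where_clauses = ["summary_embedding IS NOT NULL"]
cond, params2 = apply_filters(
    {"metadata.topic": {"$eq": "ai"}}, [], mode="condition_only"
)
if cond:
    where_clauses.append(cond)
print(" AND ".join(where_clauses))
# -> summary_embedding IS NOT NULL AND metadata->>'topic' = $1
```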
core/database/filters.py

@@ -0,0 +1,433 @@
+import json
+from typing import Any, Optional, Tuple, Union
+from uuid import UUID
+
+COLUMN_VARS = [
+    "id",
+    "document_id",
+    "owner_id",
+    "collection_ids",
+]
+
+
+class FilterError(Exception):
+    pass
+
+
+class FilterOperator:
+    EQ = "$eq"
+    NE = "$ne"
+    LT = "$lt"
+    LTE = "$lte"
+    GT = "$gt"
+    GTE = "$gte"
+    IN = "$in"
+    NIN = "$nin"
+    LIKE = "$like"
+    ILIKE = "$ilike"
+    CONTAINS = "$contains"
+    AND = "$and"
+    OR = "$or"
+    OVERLAP = "$overlap"
+
+    SCALAR_OPS = {EQ, NE, LT, LTE, GT, GTE, LIKE, ILIKE}
+    ARRAY_OPS = {IN, NIN, OVERLAP}
+    JSON_OPS = {CONTAINS}
+    LOGICAL_OPS = {AND, OR}
+
+
+class FilterCondition:
+    def __init__(self, field: str, operator: str, value: Any):
+        self.field = field
+        self.operator = operator
+        self.value = value
+
+
+class FilterExpression:
+    def __init__(self, logical_op: Optional[str] = None):
+        self.logical_op = logical_op
+        self.conditions: list[Union[FilterCondition, "FilterExpression"]] = []
+
+
+class FilterParser:
+    def __init__(
+        self,
+        top_level_columns: Optional[list[str]] = None,
+        json_column: str = "metadata",
+    ):
+        if top_level_columns is None:
+            self.top_level_columns = set(COLUMN_VARS)
+        else:
+            self.top_level_columns = set(top_level_columns)
+        self.json_column = json_column
+
+    def parse(self, filters: dict) -> FilterExpression:
+        if not filters:
+            raise FilterError("Empty filters are not allowed")
+        return self._parse_logical(filters)
+
+    def _parse_logical(self, dct: dict) -> FilterExpression:
+        keys = list(dct.keys())
+        expr = FilterExpression()
+        if len(keys) == 1 and keys[0] in (
+            FilterOperator.AND,
+            FilterOperator.OR,
+        ):
+            expr.logical_op = keys[0]
+            if not isinstance(dct[keys[0]], list):
+                raise FilterError(f"{keys[0]} value must be a list")
+            for item in dct[keys[0]]:
+                if isinstance(item, dict):
+                    if self._is_logical_block(item):
+                        expr.conditions.append(self._parse_logical(item))
+                    else:
+                        expr.conditions.append(
+                            self._parse_condition_dict(item)
+                        )
+                else:
+                    raise FilterError("Invalid filter format")
+        else:
+            expr.logical_op = FilterOperator.AND
+            expr.conditions.append(self._parse_condition_dict(dct))
+
+        return expr
+
+    def _is_logical_block(self, dct: dict) -> bool:
+        if len(dct.keys()) == 1:
+            k = next(iter(dct.keys()))
+            if k in FilterOperator.LOGICAL_OPS:
+                return True
+        return False
+
+    def _parse_condition_dict(self, dct: dict) -> FilterExpression:
+        expr = FilterExpression(logical_op=FilterOperator.AND)
+        for field, cond in dct.items():
+            if not isinstance(cond, dict):
+                # direct equality
+                expr.conditions.append(
+                    FilterCondition(field, FilterOperator.EQ, cond)
+                )
+            else:
+                if len(cond) != 1:
+                    raise FilterError(
+                        f"Condition for field {field} must have exactly one operator"
+                    )
+                op, val = next(iter(cond.items()))
+                self._validate_operator(op)
+                expr.conditions.append(FilterCondition(field, op, val))
+        return expr
+
+    def _validate_operator(self, op: str):
+        allowed = (
+            FilterOperator.SCALAR_OPS
+            | FilterOperator.ARRAY_OPS
+            | FilterOperator.JSON_OPS
+            | FilterOperator.LOGICAL_OPS
+        )
+        if op not in allowed:
+            raise FilterError(f"Unsupported operator: {op}")
+
+
+class SQLFilterBuilder:
+    def __init__(
+        self,
+        params: list[Any],
+        top_level_columns: Optional[list[str]] = None,
+        json_column: str = "metadata",
+        mode: str = "where_clause",
+    ):
+        if top_level_columns is None:
+            self.top_level_columns = set(COLUMN_VARS)
+        else:
+            self.top_level_columns = set(top_level_columns)
+        self.json_column = json_column
+        self.params: list[Any] = (
+            params  # params are mutated during construction
+        )
+        self.mode = mode
+
+    def build(self, expr: FilterExpression) -> Tuple[str, list[Any]]:
+        where_clause = self._build_expression(expr)
+        if self.mode == "where_clause":
+            return f"WHERE {where_clause}", self.params
+
+        return where_clause, self.params
+
+    def _build_expression(self, expr: FilterExpression) -> str:
+        parts = []
+        for c in expr.conditions:
+            if isinstance(c, FilterCondition):
+                parts.append(self._build_condition(c))
+            else:
+                nested_sql = self._build_expression(c)
+                parts.append(f"({nested_sql})")
+
+        if expr.logical_op == FilterOperator.AND:
+            return " AND ".join(parts)
+        elif expr.logical_op == FilterOperator.OR:
+            return " OR ".join(parts)
+        else:
+            return " AND ".join(parts)
+
+    @staticmethod
+    def _psql_quote_literal(value: str) -> str:
+        """
+        Safely quote a string literal for PostgreSQL to prevent SQL injection.
+        This is a simple implementation - in production, you should use proper parameterization
+        or your database driver's quoting functions.
+        """
+        return "'" + value.replace("'", "''") + "'"
+
+    def _build_condition(self, cond: FilterCondition) -> str:
+        field_is_metadata = cond.field not in self.top_level_columns
+        key = cond.field
+        op = cond.operator
+        val = cond.value
+
+        # Handle special logic for collection_id
+        if key == "collection_id":
+            return self._build_collection_id_condition(op, val)
+
+        if field_is_metadata:
+            return self._build_metadata_condition(key, op, val)
+        else:
+            return self._build_column_condition(key, op, val)
+
+    def _build_collection_id_condition(self, op: str, val: Any) -> str:
+        param_idx = len(self.params) + 1
+
+        # Handle operations
+        if op == "$eq":
+            # Expect a single UUID, ensure val is a string
+            if not isinstance(val, str):
+                raise FilterError(
+                    "$eq for collection_id expects a single UUID string"
+                )
+            self.params.append(val)
+            # Check if val is in the collection_ids array
+            return f"${param_idx}::uuid = ANY(collection_ids)"
+
+        elif op == "$ne":
+            # Not equal means val is not in collection_ids
+            if not isinstance(val, str):
+                raise FilterError(
+                    "$ne for collection_id expects a single UUID string"
+                )
+            self.params.append(val)
+            return f"NOT (${param_idx}::uuid = ANY(collection_ids))"
+
+        elif op == "$in":
+            # Expect a list of UUIDs, any of which may match
+            if not isinstance(val, list):
+                raise FilterError(
+                    "$in for collection_id expects a list of UUID strings"
+                )
+            self.params.append(val)
+            # Use overlap to check if any of the given IDs are in collection_ids
+            return f"collection_ids && ${param_idx}::uuid[]"
+
+        elif op == "$nin":
+            # None of the given UUIDs should be in collection_ids
+            if not isinstance(val, list):
+                raise FilterError(
+                    "$nin for collection_id expects a list of UUID strings"
+                )
+            self.params.append(val)
+            # Negate overlap condition
+            return f"NOT (collection_ids && ${param_idx}::uuid[])"
+
+        elif op == "$contains":
+            # If someone tries "$contains" with a single collection_id, we can check if collection_ids fully contain it
+            # Usually $contains might mean we want to see if collection_ids contain a certain element.
+            # That's basically $eq logic. For a single value:
+            if isinstance(val, str):
+                self.params.append([val])  # Array of one element
+                return f"collection_ids @> ${param_idx}::uuid[]"
+            elif isinstance(val, list):
+                self.params.append(val)
+                return f"collection_ids @> ${param_idx}::uuid[]"
+            else:
+                raise FilterError(
+                    "$contains for collection_id expects a UUID or list of UUIDs"
+                )
+
+        else:
+            raise FilterError(f"Unsupported operator {op} for collection_id")
+
+    def _build_column_condition(self, col: str, op: str, val: Any) -> str:
+        param_idx = len(self.params) + 1
+        if op == "$eq":
+            self.params.append(val)
+            return f"{col} = ${param_idx}"
+        elif op == "$ne":
+            self.params.append(val)
+            return f"{col} != ${param_idx}"
+        elif op == "$in":
+            if not isinstance(val, list):
+                raise FilterError("argument to $in filter must be a list")
+            self.params.append(val)
+            return f"{col} = ANY(${param_idx})"
+        elif op == "$nin":
+            if not isinstance(val, list):
+                raise FilterError("argument to $nin filter must be a list")
+            self.params.append(val)
+            return f"{col} != ALL(${param_idx})"
+        elif op == "$overlap":
+            self.params.append(val)
+            return f"{col} && ${param_idx}"
+        elif op == "$contains":
+            self.params.append(val)
+            return f"{col} @> ${param_idx}"
+        elif op == "$any":
+            # If col == "collection_ids" handle special case
+            if col == "collection_ids":
+                self.params.append(f"%{val}%")
+                return f"array_to_string({col}, ',') LIKE ${param_idx}"
+            else:
+                self.params.append(val)
+                return f"${param_idx} = ANY({col})"
+        elif op in ("$lt", "$lte", "$gt", "$gte"):
+            self.params.append(val)
+            return f"{col} {self._map_op(op)} ${param_idx}"
+        else:
+            raise FilterError(f"Unsupported operator for column {col}: {op}")
+
+    def _build_metadata_condition(self, key: str, op: str, val: Any) -> str:
+        param_idx = len(self.params) + 1
+        json_col = self.json_column
+
+        # Strip "metadata." prefix if present
+        if key.startswith("metadata."):
+            key = key[len("metadata.") :]
+
+        # Split on '.' to handle nested keys
+        parts = key.split(".")
+
+        # Use text extraction for scalar values, but not for arrays
+        use_text_extraction = op in (
+            "$lt",
+            "$lte",
+            "$gt",
+            "$gte",
+            "$eq",
+            "$ne",
+        )
+        if op == "$in" or op == "$contains" or isinstance(val, (list, dict)):
+            use_text_extraction = False
+
+        # Build the JSON path expression
+        if len(parts) == 1:
+            # Single part key
+            if use_text_extraction:
+                path_expr = f"{json_col}->>'{parts[0]}'"
+            else:
+                path_expr = f"{json_col}->'{parts[0]}'"
+        else:
+            # Multiple segments
+            inner_parts = parts[:-1]
+            last_part = parts[-1]
+            # Build chain for the inner parts
+            path_expr = json_col
+            for p in inner_parts:
+                path_expr += f"->'{p}'"
+            # Last part
+            if use_text_extraction:
+                path_expr += f"->>'{last_part}'"
+            else:
+                path_expr += f"->'{last_part}'"
+
+        # Convert numeric values to strings for text comparison
+        def prepare_value(v):
+            if isinstance(v, (int, float)):
+                return str(v)
+            return v
+
+        # Now apply the operator logic
+        if op == "$eq":
+            if use_text_extraction:
+                self.params.append(prepare_value(val))
+                return f"{path_expr} = ${param_idx}"
+            else:
+                self.params.append(json.dumps(val))
+                return f"{path_expr} = ${param_idx}::jsonb"
+        elif op == "$ne":
+            if use_text_extraction:
+                self.params.append(prepare_value(val))
+                return f"{path_expr} != ${param_idx}"
+            else:
+                self.params.append(json.dumps(val))
+                return f"{path_expr} != ${param_idx}::jsonb"
+        elif op == "$lt":
+            self.params.append(prepare_value(val))
+            return f"({path_expr})::numeric < ${param_idx}::numeric"
+        elif op == "$lte":
+            self.params.append(prepare_value(val))
+            return f"({path_expr})::numeric <= ${param_idx}::numeric"
+        elif op == "$gt":
+            self.params.append(prepare_value(val))
+            return f"({path_expr})::numeric > ${param_idx}::numeric"
+        elif op == "$gte":
+            self.params.append(prepare_value(val))
+            return f"({path_expr})::numeric >= ${param_idx}::numeric"
+        elif op == "$in":
+            if not isinstance(val, list):
+                raise FilterError("argument to $in filter must be a list")
+
+            # For regular scalar values, use ANY with text extraction
+            if use_text_extraction:
+                str_vals = [
+                    str(v) if isinstance(v, (int, float)) else v for v in val
+                ]
+                self.params.append(str_vals)
+                return f"{path_expr} = ANY(${param_idx}::text[])"
+
+            # For JSON arrays, use containment checks
+            conditions = []
+            for i, v in enumerate(val):
+                self.params.append(json.dumps(v))
+                conditions.append(f"{path_expr} @> ${param_idx + i}::jsonb")
+            return f"({' OR '.join(conditions)})"
+
+        elif op == "$contains":
+            if isinstance(val, (str, int, float, bool)):
+                val = [val]
+            self.params.append(json.dumps(val))
+            return f"{path_expr} @> ${param_idx}::jsonb"
+        else:
+            raise FilterError(f"Unsupported operator for metadata field {op}")
+
+    def _map_op(self, op: str) -> str:
+        mapping = {
+            FilterOperator.EQ: "=",
+            FilterOperator.NE: "!=",
+            FilterOperator.LT: "<",
+            FilterOperator.LTE: "<=",
+            FilterOperator.GT: ">",
+            FilterOperator.GTE: ">=",
+        }
+        return mapping.get(op, op)
+
+
+def apply_filters(
+    filters: dict, params: list[Any], mode: str = "where_clause"
+) -> Tuple[str, list[Any]]:
+    """
+    Apply filters with consistent WHERE clause handling; returns the SQL
+    fragment and the (possibly extended) parameter list.
+    """
+
+    if not filters:
+        return "", params
+
+    parser = FilterParser()
+    expr = parser.parse(filters)
+    builder = SQLFilterBuilder(params=params, mode=mode)
+    filter_clause, new_params = builder.build(expr)
+
+    if mode == "where_clause":
+        return filter_clause, new_params  # Already includes WHERE
+    elif mode == "condition_only":
+        return filter_clause, new_params
+    elif mode == "append_only":
+        return f"AND {filter_clause}", new_params
+    else:
+        raise ValueError(f"Unknown filter mode: {mode}")

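For reference, a hedged end-to-end sketch of the DSL this file implements; the SQL in the comment is the expected builder output (roughly) for asyncpg-style `$n` parameters:

```python
# Sketch only: parse -> build round trip for a compound filter.
from core.database.filters import apply_filters

filters = {
    "$and": [
        # collection_id is special-cased onto the collection_ids array
        {"collection_id": {"$in": ["11111111-1111-1111-1111-111111111111"]}},
        # metadata keys compare via text extraction, cast to numeric here
        {"metadata.score": {"$gte": 0.5}},
    ]
}
params: list = []
clause, params = apply_filters(filters, params, mode="condition_only")
# Expected shape (roughly):
#   (collection_ids && $1::uuid[])
#   AND ((metadata->>'score')::numeric >= $2::numeric)
# with params == [["11111111-1111-1111-1111-111111111111"], "0.5"]
print(clause, params)
```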
+ 87 - 181
core/database/graphs.py

@@ -1,11 +1,12 @@
 import asyncio
+import contextlib
 import datetime
 import json
 import logging
 import os
 import time
 from enum import Enum
-from typing import Any, AsyncGenerator, Optional, Tuple, Union
+from typing import Any, AsyncGenerator, Optional, Tuple
 from uuid import UUID

 import asyncpg
@@ -19,7 +20,6 @@ from core.base.abstractions import (
     Graph,
     KGCreationSettings,
     KGEnrichmentSettings,
-    KGEnrichmentStatus,
     KGEntityDeduplicationSettings,
     KGExtractionStatus,
     R2RException,
@@ -128,10 +128,8 @@ class PostgresEntitiesHandler(Handler):
         table_name = self._get_entity_table_for_store(store_type)

         if isinstance(metadata, str):
-            try:
+            with contextlib.suppress(json.JSONDecodeError):
                 metadata = json.loads(metadata)
-            except json.JSONDecodeError:
-                pass

         if isinstance(description_embedding, list):
             description_embedding = str(description_embedding)
@@ -238,12 +236,10 @@ class PostgresEntitiesHandler(Handler):

             # Process metadata if it exists and is a string
             if isinstance(entity_dict["metadata"], str):
-                try:
+                with contextlib.suppress(json.JSONDecodeError):
                     entity_dict["metadata"] = json.loads(
                     entity_dict["metadata"] = json.loads(
                         entity_dict["metadata"]
                         entity_dict["metadata"]
                     )
                     )
-                except json.JSONDecodeError:
-                    pass

             entities.append(Entity(**entity_dict))

@@ -266,10 +262,8 @@ class PostgresEntitiesHandler(Handler):
         param_index = 1

         if isinstance(metadata, str):
-            try:
+            with contextlib.suppress(json.JSONDecodeError):
                 metadata = json.loads(metadata)
-            except json.JSONDecodeError:
-                pass

         if name is not None:
             update_fields.append(f"name = ${param_index}")
@@ -327,7 +321,7 @@ class PostgresEntitiesHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail=f"An error occurred while updating the entity: {e}",
-            )
+            ) from e

     async def delete(
         self,
@@ -478,10 +472,8 @@ class PostgresRelationshipsHandler(Handler):
         table_name = self._get_relationship_table_for_store(store_type)

         if isinstance(metadata, str):
-            try:
+            with contextlib.suppress(json.JSONDecodeError):
                 metadata = json.loads(metadata)
-            except json.JSONDecodeError:
-                pass

         if isinstance(description_embedding, list):
             description_embedding = str(description_embedding)
@@ -621,12 +613,10 @@ class PostgresRelationshipsHandler(Handler):
             if include_metadata and isinstance(
                 relationship_dict["metadata"], str
             ):
-                try:
+                with contextlib.suppress(json.JSONDecodeError):
                     relationship_dict["metadata"] = json.loads(
                     relationship_dict["metadata"] = json.loads(
                         relationship_dict["metadata"]
                         relationship_dict["metadata"]
                     )
                     )
-                except json.JSONDecodeError:
-                    pass
             elif not include_metadata:
                 relationship_dict.pop("metadata", None)
             relationships.append(Relationship(**relationship_dict))
@@ -654,10 +644,8 @@ class PostgresRelationshipsHandler(Handler):
         param_index = 1

         if isinstance(metadata, str):
-            try:
+            with contextlib.suppress(json.JSONDecodeError):
                 metadata = json.loads(metadata)
-            except json.JSONDecodeError:
-                pass

         if subject is not None:
             update_fields.append(f"subject = ${param_index}")
@@ -735,7 +723,7 @@ class PostgresRelationshipsHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail=f"An error occurred while updating the relationship: {e}",
-            )
+            ) from e

     async def delete(
         self,
@@ -832,7 +820,6 @@ class PostgresCommunitiesHandler(Handler):
         rating_explanation: Optional[str],
         description_embedding: Optional[list[float] | str] = None,
     ) -> Community:
-        # Do we ever want to get communities from document store?
         table_name = "graphs_communities"
         table_name = "graphs_communities"
 
 
         if isinstance(description_embedding, list):
         if isinstance(description_embedding, list):
@@ -876,7 +863,7 @@ class PostgresCommunitiesHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail=f"An error occurred while creating the community: {e}",
-            )
+            ) from e

     async def update(
         self,
@@ -956,45 +943,51 @@ class PostgresCommunitiesHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail=f"An error occurred while updating the community: {e}",
-            )
+            ) from e

     async def delete(
         self,
         parent_id: UUID,
-        community_id: UUID,
+        community_id: Optional[UUID] = None,
     ) -> None:
         table_name = "graphs_communities"

+        params = [community_id, parent_id]
+
+        # Delete the community
         query = f"""
             DELETE FROM {self._get_table_name(table_name)}
             WHERE id = $1 AND collection_id = $2
         """

-        params = [community_id, parent_id]
-
         try:
-            results = await self.connection_manager.execute_query(
-                query, params
-            )
+            await self.connection_manager.execute_query(query, params)
         except Exception as e:
             raise HTTPException(
                 status_code=500,
                 detail=f"An error occurred while deleting the community: {e}",
             )
-        params = [
-            community_id,
-            parent_id,
-        ]
+    async def delete_all_communities(
+        self,
+        parent_id: UUID,
+    ) -> None:
+        table_name = "graphs_communities"
+
+        params = [parent_id]
+
+        # Delete all communities for the parent_id
+        query = f"""
+            DELETE FROM {self._get_table_name(table_name)}
+            WHERE collection_id = $1
+        """

         try:
-            results = await self.connection_manager.execute_query(
-                query, params
-            )
+            await self.connection_manager.execute_query(query, params)
         except Exception as e:
             raise HTTPException(
                 status_code=500,
-                detail=f"An error occurred while deleting the community: {e}",
+                detail=f"An error occurred while deleting communities: {e}",
             )

     async def get(
@@ -1176,44 +1169,15 @@ class PostgresGraphsHandler(Handler):
         """
         Completely reset a graph and all associated data.
         """
-        try:
-            entity_delete_query = f"""
-                DELETE FROM {self._get_table_name("graphs_entities")}
-                WHERE parent_id = $1
-            """
-            await self.connection_manager.execute_query(
-                entity_delete_query, [parent_id]
-            )
-
-            # Delete all graph relationships
-            relationship_delete_query = f"""
-                DELETE FROM {self._get_table_name("graphs_relationships")}
-                WHERE parent_id = $1
-            """
-            await self.connection_manager.execute_query(
-                relationship_delete_query, [parent_id]
-            )
-            # Delete all graph relationships
-            community_delete_query = f"""
-                DELETE FROM {self._get_table_name("graphs_communities")}
-                WHERE collection_id = $1
-            """
-            await self.connection_manager.execute_query(
-                community_delete_query, [parent_id]
-            )
-
-            # Delete all graph communities and community info
-            query = f"""
-                DELETE FROM {self._get_table_name("graphs_communities")}
-                WHERE collection_id = $1
-            """
-
-            await self.connection_manager.execute_query(query, [parent_id])
-
-        except Exception as e:
-            logger.error(f"Error deleting graph {parent_id}: {str(e)}")
-            raise R2RException(f"Failed to delete graph: {str(e)}", 500)
+        await self.entities.delete(
+            parent_id=parent_id, store_type=StoreType.GRAPHS
+        )
+        await self.relationships.delete(
+            parent_id=parent_id, store_type=StoreType.GRAPHS
+        )
+        await self.communities.delete_all_communities(parent_id=parent_id)
+        return

     async def list_graphs(
         self,
@@ -1289,7 +1253,7 @@ class PostgresGraphsHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail=f"An error occurred while fetching graphs: {e}",
-            )
+            ) from e

     async def get(
         self, offset: int, limit: int, graph_id: Optional[UUID] = None
@@ -1445,7 +1409,7 @@ class PostgresGraphsHandler(Handler):
             raise HTTPException(
                 status_code=500,
                 detail=f"An error occurred while updating the graph: {e}",
-            )
+            ) from e

     async def get_creation_estimate(
         self,
@@ -1676,7 +1640,9 @@ class PostgresGraphsHandler(Handler):
             )
         except Exception as e:
             logger.error(f"Error in get_deduplication_estimate: {str(e)}")
-            raise HTTPException(500, "Error fetching deduplication estimate.")
+            raise HTTPException(
+                500, "Error fetching deduplication estimate."
+            ) from e

     async def get_entities(
         self,
@@ -1754,12 +1720,10 @@ class PostgresGraphsHandler(Handler):
         for row in rows:
             entity_dict = dict(row)
             if isinstance(entity_dict["metadata"], str):
-                try:
+                with contextlib.suppress(json.JSONDecodeError):
                     entity_dict["metadata"] = json.loads(
                         entity_dict["metadata"]
                     )
-                except json.JSONDecodeError:
-                    pass

             entities.append(Entity(**entity_dict))

@@ -1840,12 +1804,10 @@ class PostgresGraphsHandler(Handler):
         for row in rows:
             relationship_dict = dict(row)
             if isinstance(relationship_dict["metadata"], str):
-                try:
+                with contextlib.suppress(json.JSONDecodeError):
                     relationship_dict["metadata"] = json.loads(
                         relationship_dict["metadata"]
                     )
-                except json.JSONDecodeError:
-                    pass

             relationships.append(Relationship(**relationship_dict))

@@ -1889,29 +1851,6 @@ class PostgresGraphsHandler(Handler):
             conflict_columns=conflict_columns,
         )

-    async def delete_node_via_document_id(
-        self, document_id: UUID, collection_id: UUID
-    ) -> None:
-        # don't delete if status is PROCESSING.
-        QUERY = f"""
-            SELECT graph_cluster_status FROM {self._get_table_name("collections")} WHERE id = $1
-        """
-        status = (
-            await self.connection_manager.fetch_query(QUERY, [collection_id])
-        )[0]["graph_cluster_status"]
-        if status == KGExtractionStatus.PROCESSING.value:
-            return
-
-        # Execute separate DELETE queries
-        delete_queries = [
-            f"""DELETE FROM {self._get_table_name("documents_relationships")} WHERE parent_id = $1""",
-            f"""DELETE FROM {self._get_table_name("documents_entities")} WHERE parent_id = $1""",
-        ]
-
-        for query in delete_queries:
-            await self.connection_manager.execute_query(query, [document_id])
-        return None
-
     async def get_all_relationships(
         self,
         collection_id: UUID | None,
@@ -2056,64 +1995,17 @@ class PostgresGraphsHandler(Handler):
             QUERY, [tuple(non_null_attrs.values())]
         )

-    async def delete_graph_for_collection(
-        self, collection_id: UUID, cascade: bool = False
-    ) -> None:
-
-        # don't delete if status is PROCESSING.
-        QUERY = f"""
-            SELECT graph_cluster_status FROM {self._get_table_name("collections")} WHERE id = $1
-        """
-        status = (
-            await self.connection_manager.fetch_query(QUERY, [collection_id])
-        )[0]["graph_cluster_status"]
-        if status == KGExtractionStatus.PROCESSING.value:
-            return
-
-        # remove all relationships for these documents.
-        DELETE_QUERIES = [
-            f"DELETE FROM {self._get_table_name('graphs_communities')} WHERE collection_id = $1;",
-        ]
-
-        # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
-        document_ids_response = (
-            await self.collections_handler.documents_in_collection(
-                offset=0,
-                limit=100,
-                collection_id=collection_id,
-            )
-        )
-
-        # This type ignore is due to insufficient typing of the documents_in_collection method
-        document_ids = [doc.id for doc in document_ids_response["results"]]  # type: ignore
+    # async def delete(self, collection_id: UUID, cascade: bool = False) -> None:
+    async def delete(self, collection_id: UUID) -> None:
-        # TODO: make these queries more efficient. Pass the document_ids as params.
-        if cascade:
-            DELETE_QUERIES += [
-                f"DELETE FROM {self._get_table_name('graphs_relationships')} WHERE document_id = ANY($1::uuid[]);",
-                f"DELETE FROM {self._get_table_name('graphs_entities')} WHERE document_id = ANY($1::uuid[]);",
-                f"DELETE FROM {self._get_table_name('graphs_entities')} WHERE collection_id = $1;",
-            ]
+        graphs = await self.get(graph_id=collection_id, offset=0, limit=-1)
-            # setting the kg_creation_status to PENDING for this collection.
-            QUERY = f"""
-                UPDATE {self._get_table_name("documents")} SET extraction_status = $1 WHERE $2::uuid = ANY(collection_ids)
-            """
-            await self.connection_manager.execute_query(
-                QUERY, [KGExtractionStatus.PENDING, collection_id]
+        if len(graphs["results"]) == 0:
+            raise R2RException(
+                message=f"Graph not found for collection {collection_id}",
+                status_code=404,
             )
-
-        if document_ids:
-            for query in DELETE_QUERIES:
-                if "community" in query or "graphs_entities" in query:
-                    await self.connection_manager.execute_query(
-                        query, [collection_id]
-                    )
-                else:
-                    await self.connection_manager.execute_query(
-                        query, [document_ids]
-                    )
-
+        await self.reset(collection_id)
         # set status to PENDING for this collection.
         QUERY = f"""
             UPDATE {self._get_table_name("collections")} SET graph_cluster_status = $1 WHERE id = $2
@@ -2121,6 +2013,37 @@ class PostgresGraphsHandler(Handler):
         await self.connection_manager.execute_query(
             QUERY, [KGExtractionStatus.PENDING, collection_id]
         )
+        # Delete the graph
+        QUERY = f"""
+            DELETE FROM {self._get_table_name("graphs")} WHERE collection_id = $1
+        """
+
+        # if cascade:
+        #     documents = []
+        #     document_response = (
+        #         await self.collections_handler.documents_in_collection(
+        #             offset=0,
+        #             limit=100,
+        #             collection_id=collection_id,
+        #         )
+        #     )["results"]
+        #     documents.extend(document_response)
+        #     document_ids = [doc.id for doc in documents]
+        #     for document_id in document_ids:
+        #         self.entities.delete(
+        #             parent_id=document_id, store_type=StoreType.DOCUMENTS
+        #         )
+        #         self.relationships.delete(
+        #             parent_id=document_id, store_type=StoreType.DOCUMENTS
+        #         )
+
+        #     # setting the extraction status to PENDING for the documents in this collection.
+        #     QUERY = f"""
+        #         UPDATE {self._get_table_name("documents")} SET extraction_status = $1 WHERE $2::uuid = ANY(collection_ids)
+        #     """
+        #     await self.connection_manager.execute_query(
+        #         QUERY, [KGExtractionStatus.PENDING, collection_id]
+        #     )

     async def perform_graph_clustering(
         self,
@@ -2426,7 +2349,7 @@ class PostgresGraphsHandler(Handler):
         property_names_str = ", ".join(property_names)

         # Build the WHERE clause from filters
-        params: list[Union[str, int, bytes]] = [
+        params: list[str | int | bytes] = [
             json.dumps(query_embedding),
             limit,
         ]
@@ -2583,23 +2506,6 @@ class PostgresGraphsHandler(Handler):

         return parse_filter(filter_dict)

-    # async def _create_graph_and_cluster(
-    #     self, relationships: list[Relationship], leiden_params: dict[str, Any]
-    # ) -> Any:
-
-    #     G = self.nx.Graph()
-    #     for relationship in relationships:
-    #         G.add_edge(
-    #             relationship.subject,
-    #             relationship.object,
-    #             weight=relationship.weight,
-    #             id=relationship.id,
-    #         )
-
-    #     logger.info(f"Graph has {len(G.nodes)} nodes and {len(G.edges)} edges")
-
-    #     return await self._compute_leiden_communities(G, leiden_params)
-
     async def _compute_leiden_communities(
         self,
         graph: Any,

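Two idioms recur throughout the graphs.py changes above: `contextlib.suppress(json.JSONDecodeError)` for best-effort metadata decoding, and `raise ... from e` so the re-raised `HTTPException` keeps the original error as its `__cause__`. A minimal standalone sketch of both (the function names and sample payloads are illustrative, not from the codebase):

    import contextlib
    import json

    def load_metadata(value):
        # Best-effort decode: if the string is not valid JSON, keep it
        # as-is rather than letting json.JSONDecodeError propagate.
        if isinstance(value, str):
            with contextlib.suppress(json.JSONDecodeError):
                value = json.loads(value)
        return value

    def decode_or_fail(payload: str) -> dict:
        try:
            return json.loads(payload)
        except json.JSONDecodeError as e:
            # "from e" chains the original exception, preserving its
            # traceback the same way the HTTPException re-raises above do.
            raise ValueError(f"An error occurred while decoding: {e}") from e

    assert load_metadata('{"a": 1}') == {"a": 1}
    assert load_metadata("not json") == "not json"
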
+ 79 - 150
core/database/limits.py

@@ -3,11 +3,12 @@ from datetime import datetime, timedelta, timezone
 from typing import Optional
 from uuid import UUID

-from core.base import Handler, R2RException
+from core.base import Handler
+from ..base.providers.database import DatabaseConfig, LimitSettings
 from .base import PostgresConnectionManager

-logger = logging.getLogger()
+logger = logging.getLogger(__name__)


 class PostgresLimitsHandler(Handler):
@@ -17,10 +18,13 @@ class PostgresLimitsHandler(Handler):
         self,
         project_name: str,
         connection_manager: PostgresConnectionManager,
-        route_limits: dict,
+        config: DatabaseConfig,
     ):
         super().__init__(project_name, connection_manager)
-        self.route_limits = route_limits
+        self.config = config
+        logger.debug(
+            f"Initialized PostgresLimitsHandler with project: {project_name}"
+        )

     async def create_tables(self):
         query = f"""
@@ -30,6 +34,7 @@ class PostgresLimitsHandler(Handler):
             route TEXT NOT NULL
         );
         """
+        logger.debug("Creating request_log table if not exists")
         await self.connection_manager.execute_query(query)

     async def _count_requests(
@@ -44,6 +49,7 @@ class PostgresLimitsHandler(Handler):
               AND time >= $3
             """
             params = [user_id, route, since]
+            logger.debug(f"Counting requests for route {route}")
         else:
             query = f"""
             SELECT COUNT(*)::int
@@ -52,55 +58,97 @@ class PostgresLimitsHandler(Handler):
               AND time >= $2
             """
             params = [user_id, since]
+            logger.debug("Counting all requests")

         result = await self.connection_manager.fetchrow_query(query, params)
-        return result["count"] if result else 0
+        count = result["count"] if result else 0
+
+        return count

     async def _count_monthly_requests(self, user_id: UUID) -> int:
         now = datetime.now(timezone.utc)
         start_of_month = now.replace(
             day=1, hour=0, minute=0, second=0, microsecond=0
         )
-        return await self._count_requests(
+
+        count = await self._count_requests(
             user_id, route=None, since=start_of_month
         )
+        return count
+
+    def _determine_limits_for(
+        self, user_id: UUID, route: str
+    ) -> LimitSettings:
+        # Start with base limits
+        limits = self.config.limits
+
+        # Route-specific limits - directly override if present
+        route_limits = self.config.route_limits.get(route)
+        if route_limits:
+            # Only override non-None values from route_limits
+            if route_limits.global_per_min is not None:
+                limits.global_per_min = route_limits.global_per_min
+            if route_limits.route_per_min is not None:
+                limits.route_per_min = route_limits.route_per_min
+            if route_limits.monthly_limit is not None:
+                limits.monthly_limit = route_limits.monthly_limit
+
+        # User-specific limits - directly override if present
+        user_limits = self.config.user_limits.get(user_id)
+        if user_limits:
+            # Only override non-None values from user_limits
+            if user_limits.global_per_min is not None:
+                limits.global_per_min = user_limits.global_per_min
+            if user_limits.route_per_min is not None:
+                limits.route_per_min = user_limits.route_per_min
+            if user_limits.monthly_limit is not None:
+                limits.monthly_limit = user_limits.monthly_limit
+
+        return limits
 

     async def check_limits(self, user_id: UUID, route: str):
-            route,
-            {
-                "global_per_min": 60,
-                "route_per_min": 30,
-                "monthly_limit": 10000,
-            },
-        )
+        # Determine final applicable limits
+        limits = self._determine_limits_for(user_id, route)
+        if not limits:
+            limits = self.config.default_limits
 
-        route_per_min = limits["route_per_min"]
-        monthly_limit = limits["monthly_limit"]
+        global_per_min = limits.global_per_min
+        route_per_min = limits.route_per_min
+        monthly_limit = limits.monthly_limit
 

         now = datetime.now(timezone.utc)
         one_min_ago = now - timedelta(minutes=1)

         # Global per-minute check
-        print("min req count = ", user_req_count)
-        if user_req_count >= global_per_min:
-            raise ValueError("Global per-minute rate limit exceeded")
+        if global_per_min is not None:
+            user_req_count = await self._count_requests(
+                user_id, None, one_min_ago
+            )
+            if user_req_count > global_per_min:
+                logger.warning(
+                    f"Global per-minute limit exceeded for user_id={user_id}, route={route}"
+                )
+                raise ValueError("Global per-minute rate limit exceeded")

         # Per-route per-minute check
-        route_req_count = await self._count_requests(
-            user_id, route, one_min_ago
-        )
-        if route_req_count >= route_per_min:
-            raise ValueError("Per-route per-minute rate limit exceeded")
+        if route_per_min is not None:
+            route_req_count = await self._count_requests(
+                user_id, route, one_min_ago
+            )
+            if route_req_count > route_per_min:
+                logger.warning(
+                    f"Per-route per-minute limit exceeded for user_id={user_id}, route={route}"
+                )
+                raise ValueError("Per-route per-minute rate limit exceeded")

         # Monthly limit check
-        monthly_count = await self._count_monthly_requests(user_id)
-        print("monthly_count = ", monthly_count)
-
-        if monthly_count >= monthly_limit:
-            raise ValueError("Monthly rate limit exceeded")
+        if monthly_limit is not None:
+            monthly_count = await self._count_monthly_requests(user_id)
+            if monthly_count > monthly_limit:
+                logger.warning(
+                    f"Monthly limit exceeded for user_id={user_id}, route={route}"
+                )
+                raise ValueError("Monthly rate limit exceeded")

     async def log_request(self, user_id: UUID, route: str):
         query = f"""
@@ -108,122 +156,3 @@ class PostgresLimitsHandler(Handler):
         VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
         """
         await self.connection_manager.execute_query(query, [user_id, route])
-
-
-# import logging
-# from datetime import datetime, timedelta
-# from typing import Optional
-# from uuid import UUID
-
-# from core.base import Handler, R2RException
-
-# from .base import PostgresConnectionManager
-
-# logger = logging.getLogger()
-
-
-# class PostgresLimitsHandler(Handler):
-#     TABLE_NAME = "request_log"
-
-#     def __init__(
-#         self,
-#         project_name: str,
-#         connection_manager: PostgresConnectionManager,
-#         route_limits: dict,
-#     ):
-#         super().__init__(project_name, connection_manager)
-#         self.route_limits = route_limits
-
-#     async def create_tables(self):
-#         """
-#         Create the request_log table if it doesn't exist.
-#         """
-#         query = f"""
-#         CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
-#             time TIMESTAMPTZ NOT NULL,
-#             user_id UUID NOT NULL,
-#             route TEXT NOT NULL
-#         );
-#         """
-#         await self.connection_manager.execute_query(query)
-
-#     async def _count_requests(
-#         self, user_id: UUID, route: Optional[str], since: datetime
-#     ) -> int:
-#         if route:
-#             query = f"""
-#             SELECT COUNT(*)::int
-#             FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
-#             WHERE user_id = $1
-#               AND route = $2
-#               AND time >= $3
-#             """
-#             params = [user_id, route, since]
-#         else:
-#             query = f"""
-#             SELECT COUNT(*)::int
-#             FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
-#             WHERE user_id = $1
-#               AND time >= $2
-#             """
-#             params = [user_id, since]
-
-#         result = await self.connection_manager.fetchrow_query(query, params)
-#         return result["count"] if result else 0
-
-#     async def _count_monthly_requests(self, user_id: UUID) -> int:
-#         now = datetime.utcnow()
-#         start_of_month = now.replace(
-#             day=1, hour=0, minute=0, second=0, microsecond=0
-#         )
-#         return await self._count_requests(
-#             user_id, route=None, since=start_of_month
-#         )
-
-#     async def check_limits(self, user_id: UUID, route: str):
-#         """
-#         Check if the user can proceed with the request, using route-specific limits.
-#         Raises ValueError if the user exceeded any limit.
-#         """
-#         limits = self.route_limits.get(
-#             route,
-#             {
-#                 "global_per_min": 60,  # default global per min
-#                 "route_per_min": 20,  # default route per min
-#                 "monthly_limit": 10000,  # default monthly limit
-#             },
-#         )
-
-#         global_per_min = limits["global_per_min"]
-#         route_per_min = limits["route_per_min"]
-#         monthly_limit = limits["monthly_limit"]
-
-#         now = datetime.utcnow()
-#         one_min_ago = now - timedelta(minutes=1)
-
-#         # Global per-minute check
-#         user_req_count = await self._count_requests(user_id, None, one_min_ago)
-#         print('min req count = ', user_req_count)
-#         if user_req_count >= global_per_min:
-#             raise ValueError("Global per-minute rate limit exceeded")
-
-#         # Per-route per-minute check
-#         route_req_count = await self._count_requests(
-#             user_id, route, one_min_ago
-#         )
-#         if route_req_count >= route_per_min:
-#             raise ValueError("Per-route per-minute rate limit exceeded")
-
-#         # Monthly limit check
-#         monthly_count = await self._count_monthly_requests(user_id)
-#         print('monthly_count = ', monthly_count)
-
-#         if monthly_count >= monthly_limit:
-#             raise ValueError("Monthly rate limit exceeded")
-
-#     async def log_request(self, user_id: UUID, route: str):
-#         query = f"""
-#         INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
-#         VALUES (NOW(), $1, $2)
-#         """
-#         await self.connection_manager.execute_query(query, [user_id, route])

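The limit-resolution order introduced in `_determine_limits_for` above — config-wide base limits, then route-specific overrides, then user-specific overrides, with only non-None fields taking effect — can be sketched standalone like this (the dataclass is a stand-in for the real `LimitSettings`, and the copy-before-merge is a deliberate choice in this sketch so the shared base config is never mutated):

    from dataclasses import dataclass, replace
    from typing import Optional

    @dataclass
    class LimitSettings:
        # Stand-in for core's LimitSettings; field names follow the diff above.
        global_per_min: Optional[int] = None
        route_per_min: Optional[int] = None
        monthly_limit: Optional[int] = None

    FIELDS = ("global_per_min", "route_per_min", "monthly_limit")

    def merge_limits(base, *overrides):
        merged = replace(base)  # shallow copy; the base settings stay untouched
        for override in overrides:
            if override is None:
                continue
            for field in FIELDS:
                value = getattr(override, field)
                if value is not None:  # None means "no override for this field"
                    setattr(merged, field, value)
        return merged

    base = LimitSettings(global_per_min=60, route_per_min=30, monthly_limit=10000)
    route = LimitSettings(route_per_min=5)
    user = LimitSettings(monthly_limit=100)
    assert merge_limits(base, route, user) == LimitSettings(60, 5, 100)
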
+ 4 - 5
core/database/postgres.py

@@ -28,7 +28,7 @@ from .tokens import PostgresTokensHandler
 from .users import PostgresUserHandler

 if TYPE_CHECKING:
-    from ..providers.crypto import BCryptProvider
+    from ..providers.crypto import NaClCryptoProvider

 logger = logging.getLogger()

@@ -57,7 +57,7 @@ class PostgresDatabaseProvider(DatabaseProvider):
     dimension: int
     conn: Optional[Any]

-    crypto_provider: "BCryptProvider"
+    crypto_provider: "NaClCryptoProvider"
     postgres_configuration_settings: PostgresConfigurationSettings
     default_collection_name: str
     default_collection_description: str
@@ -81,7 +81,7 @@ class PostgresDatabaseProvider(DatabaseProvider):
         self,
         config: DatabaseConfig,
         dimension: int,
-        crypto_provider: "BCryptProvider",
+        crypto_provider: "NaClCryptoProvider",
         quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
         *args,
         **kwargs,
@@ -203,8 +203,7 @@ class PostgresDatabaseProvider(DatabaseProvider):
         self.limits_handler = PostgresLimitsHandler(
             project_name=self.project_name,
             connection_manager=self.connection_manager,
-            # TODO - this should be set in the config
-            route_limits={},
+            config=self.config,
         )

     async def initialize(self):

+ 126 - 5
core/database/users.py

@@ -15,6 +15,7 @@ from .collections import PostgresCollectionsHandler

 class PostgresUserHandler(Handler):
     TABLE_NAME = "users"
+    API_KEYS_TABLE_NAME = "users_api_keys"

     def __init__(
         self,
@@ -26,7 +27,7 @@ class PostgresUserHandler(Handler):
         self.crypto_provider = crypto_provider

     async def create_tables(self):
-        query = f"""
+        user_table_query = f"""
         CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.TABLE_NAME)} (
             id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
             email TEXT UNIQUE NOT NULL,
@@ -46,7 +47,27 @@ class PostgresUserHandler(Handler):
             updated_at TIMESTAMPTZ DEFAULT NOW()
         );
         """
-        await self.connection_manager.execute_query(query)
+        # API keys table with updated_at instead of last_used_at
+        api_keys_table_query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (
+            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+            user_id UUID NOT NULL REFERENCES {self._get_table_name(PostgresUserHandler.TABLE_NAME)}(id) ON DELETE CASCADE,
+            public_key TEXT UNIQUE NOT NULL,
+            hashed_key TEXT NOT NULL,
+            name TEXT,
+            created_at TIMESTAMPTZ DEFAULT NOW(),
+            updated_at TIMESTAMPTZ DEFAULT NOW()
+        );
+
+        CREATE INDEX IF NOT EXISTS idx_api_keys_user_id
+        ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(user_id);
+
+        CREATE INDEX IF NOT EXISTS idx_api_keys_public_key
+        ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(public_key);
+        """
+
+        await self.connection_manager.execute_query(user_table_query)
+        await self.connection_manager.execute_query(api_keys_table_query)
 

     async def get_user_by_id(self, id: UUID) -> User:
         query, _ = (
 

     async def get_user_id_by_verification_code(
         self, verification_code: str
+    ) -> UUID:
         query = f"""
         query = f"""
             SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
             WHERE verification_code = $1 AND verification_code_expiry > NOW()
                     u.is_superuser,
                     u.is_superuser,
                     u.is_active,
                     u.is_verified,
+                    u.bio,
+                    u.profile_picture,
+                    u.collection_ids,
                     u.created_at,
                     u.created_at,
                     u.updated_at,
                     COUNT(d.id) AS num_files,
                     COUNT(d.id) AS num_files,
                     COALESCE(SUM(d.size_in_bytes), 0) AS total_size_in_bytes,
                     ud.doc_ids as document_ids
                 is_superuser=row["is_superuser"],
                 is_superuser=row["is_superuser"],
                 is_active=row["is_active"],
                 is_verified=row["is_verified"],
+                bio=row["bio"],
                 created_at=row["created_at"],
                 created_at=row["created_at"],
                 updated_at=row["updated_at"],
                 collection_ids=row["collection_ids"] or [],
                 document_ids=(
                 document_ids=(
                     []
                     if row["document_ids"] is None
+                    else list(row["document_ids"])
                 ),
                 ),
             )
             for row in results
                 ),
                 ),
             }
         }
+    # API Key methods
+    async def store_user_api_key(
+        self,
+        user_id: UUID,
+        key_id: str,
+        hashed_key: str,
+        name: Optional[str] = None,
+    ) -> UUID:
+        """Store a new API key for a user"""
+        query = f"""
+            INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+            (user_id, public_key, hashed_key, name)
+            VALUES ($1, $2, $3, $4)
+            RETURNING id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [user_id, key_id, hashed_key, name]
+        )
+        if not result:
+            raise R2RException(
+                status_code=500, message="Failed to store API key"
+            )
+        return result["id"]
+
+    async def get_api_key_record(self, key_id: str) -> Optional[dict]:
+        """Get API key record and update updated_at"""
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+            SET updated_at = NOW()
+            WHERE public_key = $1
+            RETURNING user_id, hashed_key
+        """
+        result = await self.connection_manager.fetchrow_query(query, [key_id])
+        if not result:
+            return None
+        return {
+            "user_id": result["user_id"],
+            "hashed_key": result["hashed_key"],
+        }
+
+    async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
+        """Get all API keys for a user"""
+        query = f"""
+            SELECT id, public_key, name, created_at, updated_at
+            FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+            WHERE user_id = $1
+            ORDER BY created_at DESC
+        """
+        results = await self.connection_manager.fetch_query(query, [user_id])
+        return [
+            {
+                "key_id": str(row["id"]),
+                "public_key": row["public_key"],
+                "name": row["name"] or "",
+                "updated_at": row["updated_at"],
+            }
+            for row in results
+        ]
+
+    async def delete_api_key(self, user_id: UUID, key_id: UUID) -> dict:
+        """Delete a specific API key"""
+        query = f"""
+            DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+            WHERE id = $1 AND user_id = $2
+            RETURNING id, public_key, name
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [key_id, user_id]
+        )
+        if result is None:
+            raise R2RException(status_code=404, message="API key not found")
+
+        return {
+            "key_id": str(result["id"]),
+            "public_key": str(result["public_key"]),
+            "name": result["name"] or "",
+        }
+
+    async def update_api_key_name(
+        self, user_id: UUID, key_id: UUID, name: str
+    ) -> bool:
+        """Update the name of an API key"""
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
+            SET name = $1, updated_at = NOW()
+            WHERE id = $2 AND user_id = $3
+            RETURNING id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [name, key_id, user_id]
+        )
+        if result is None:
+            raise R2RException(status_code=404, message="API key not found")
+        return True

BIN
core/examples/supported_file_types/bmp.bmp


+ 11 - 0
core/examples/supported_file_types/csv.csv

@@ -0,0 +1,11 @@
+Date,Customer ID,Product,Quantity,Unit Price,Total
+2024-01-15,C1001,Laptop Pro X,2,999.99,1999.98
+2024-01-15,C1002,Wireless Mouse,5,29.99,149.95
+2024-01-16,C1003,External SSD 1TB,3,159.99,479.97
+2024-01-16,C1001,USB-C Cable,4,19.99,79.96
+2024-01-17,C1004,Monitor 27",1,349.99,349.99
+2024-01-17,C1005,Keyboard Elite,2,129.99,259.98
+2024-01-18,C1002,Headphones Pro,1,199.99,199.99
+2024-01-18,C1006,Webcam HD,3,79.99,239.97
+2024-01-19,C1007,Power Bank,2,49.99,99.98
+2024-01-19,C1003,Phone Case,5,24.99,124.95

BIN
core/examples/supported_file_types/doc.doc


BIN
core/examples/supported_file_types/docx.docx


+ 61 - 0
core/examples/supported_file_types/eml.eml

@@ -0,0 +1,61 @@
+From: sender@example.com
+To: recipient@example.com
+Subject: Meeting Summary - Q4 Planning
+Date: Mon, 16 Dec 2024 10:30:00 -0500
+Content-Type: multipart/mixed; boundary="boundary123"
+
+--boundary123
+Content-Type: text/plain; charset="utf-8"
+Content-Transfer-Encoding: quoted-printable
+
+Hi Team,
+
+Here's a summary of our Q4 planning meeting:
+
+Key Points:
+1. Revenue targets increased by 15%
+2. New product launch scheduled for November
+3. Marketing budget approved for expansion
+
+Action Items:
+- Sarah: Prepare detailed product roadmap
+- Mike: Contact vendors for pricing
+- Jennifer: Update financial projections
+
+Please review and let me know if you have any questions.
+
+Best regards,
+Alex
+
+--boundary123
+Content-Type: text/html; charset="utf-8"
+Content-Transfer-Encoding: quoted-printable
+
+<html>
+<body>
+<p>Hi Team,</p>
+
+<p>Here's a summary of our Q4 planning meeting:</p>
+
+<h3>Key Points:</h3>
+<ul>
+<li>Revenue targets increased by 15%</li>
+<li>New product launch scheduled for November</li>
+<li>Marketing budget approved for expansion</li>
+</ul>
+
+<h3>Action Items:</h3>
+<ul>
+<li><strong>Sarah:</strong> Prepare detailed product roadmap</li>
+<li><strong>Mike:</strong> Contact vendors for pricing</li>
+<li><strong>Jennifer:</strong> Update financial projections</li>
+</ul>
+
+<p>Please review and let me know if you have any questions.</p>
+
+<p>Best regards,<br>
+Alex</p>
+</body>
+</html>
+
+--boundary123--

BIN
core/examples/supported_file_types/epub.epub


BIN
core/examples/supported_file_types/heic.heic


+ 69 - 0
core/examples/supported_file_types/html.html

@@ -0,0 +1,69 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Product Dashboard</title>
+    <style>
+        body {
+            font-family: Arial, sans-serif;
+            margin: 20px;
+            background-color: #f5f5f5;
+        }
+        .dashboard {
+            max-width: 800px;
+            margin: 0 auto;
+            padding: 20px;
+            background-color: white;
+            border-radius: 8px;
+            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+        }
+        .header {
+            text-align: center;
+            margin-bottom: 30px;
+        }
+        .metrics {
+            display: grid;
+            grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
+            gap: 20px;
+            margin-bottom: 30px;
+        }
+        .metric-card {
+            padding: 15px;
+            background-color: #f8f9fa;
+            border-radius: 4px;
+            text-align: center;
+        }
+    </style>
+</head>
+<body>
+    <div class="dashboard">
+        <div class="header">
+            <h1>Product Performance Dashboard</h1>
+            <p>Real-time metrics and analytics</p>
+        </div>
+        <div class="metrics">
+            <div class="metric-card">
+                <h3>Active Users</h3>
+                <p>1,234</p>
+            </div>
+            <div class="metric-card">
+                <h3>Revenue</h3>
+                <p>$45,678</p>
+            </div>
+            <div class="metric-card">
+                <h3>Conversion Rate</h3>
+                <p>2.34%</p>
+            </div>
+        </div>
+        <div class="recent-activity">
+            <h2>Recent Activity</h2>
+            <ul>
+                <li>New feature deployed: Enhanced search</li>
+                <li>Bug fix: Mobile navigation issue</li>
+                <li>Performance improvement: Cache optimization</li>
+            </ul>
+        </div>
+    </div>
+</body>
+</html>

BIN
core/examples/supported_file_types/jpeg.jpeg


BIN
core/examples/supported_file_types/jpg.jpg


+ 58 - 0
core/examples/supported_file_types/json.json

@@ -0,0 +1,58 @@
+{
+    "dashboard": {
+        "name": "Product Performance Dashboard",
+        "lastUpdated": "2024-12-16T10:30:00Z",
+        "metrics": {
+            "activeUsers": {
+                "current": 1234,
+                "previousPeriod": 1156,
+                "percentChange": 6.75
+            },
+            "revenue": {
+                "current": 45678.90,
+                "previousPeriod": 41234.56,
+                "percentChange": 10.78,
+                "currency": "USD"
+            },
+            "conversionRate": {
+                "current": 2.34,
+                "previousPeriod": 2.12,
+                "percentChange": 10.38,
+                "unit": "percent"
+            }
+        },
+        "recentActivity": [
+            {
+                "type": "deployment",
+                "title": "Enhanced search",
+                "description": "New feature deployed: Enhanced search functionality",
+                "timestamp": "2024-12-15T15:45:00Z",
+                "status": "successful"
+            },
+            {
+                "type": "bugfix",
+                "title": "Mobile navigation",
+                "description": "Bug fix: Mobile navigation issue resolved",
+                "timestamp": "2024-12-14T09:20:00Z",
+                "status": "successful"
+            },
+            {
+                "type": "performance",
+                "title": "Cache optimization",
+                "description": "Performance improvement: Cache optimization completed",
+                "timestamp": "2024-12-13T11:15:00Z",
+                "status": "successful"
+            }
+        ],
+        "settings": {
+            "refreshInterval": 300,
+            "timezone": "UTC",
+            "theme": "light",
+            "notifications": {
+                "email": true,
+                "slack": true,
+                "inApp": true
+            }
+        }
+    }
+}

+ 310 - 0
core/examples/supported_file_types/md.md

@@ -0,0 +1,310 @@
+# Markdown: Syntax
+
+*   [Overview](#overview)
+    *   [Philosophy](#philosophy)
+    *   [Inline HTML](#html)
+    *   [Automatic Escaping for Special Characters](#autoescape)
+*   [Block Elements](#block)
+    *   [Paragraphs and Line Breaks](#p)
+    *   [Headers](#header)
+    *   [Blockquotes](#blockquote)
+    *   [Lists](#list)
+    *   [Code Blocks](#precode)
+    *   [Horizontal Rules](#hr)
+*   [Span Elements](#span)
+    *   [Links](#link)
+    *   [Emphasis](#em)
+    *   [Code](#code)
+    *   [Images](#img)
+*   [Miscellaneous](#misc)
+    *   [Backslash Escapes](#backslash)
+    *   [Automatic Links](#autolink)
+
+
+**Note:** This document is itself written using Markdown; you
+can [see the source for it by adding '.text' to the URL](/projects/markdown/syntax.text).
+
+----
+
+## Overview
+
+### Philosophy
+
+Markdown is intended to be as easy-to-read and easy-to-write as is feasible.
+
+Readability, however, is emphasized above all else. A Markdown-formatted
+document should be publishable as-is, as plain text, without looking
+like it's been marked up with tags or formatting instructions. While
+Markdown's syntax has been influenced by several existing text-to-HTML
+filters -- including [Setext](http://docutils.sourceforge.net/mirror/setext.html), [atx](http://www.aaronsw.com/2002/atx/), [Textile](http://textism.com/tools/textile/), [reStructuredText](http://docutils.sourceforge.net/rst.html),
+[Grutatext](http://www.triptico.com/software/grutatxt.html), and [EtText](http://ettext.taint.org/doc/) -- the single biggest source of
+inspiration for Markdown's syntax is the format of plain text email.
+
+## Block Elements
+
+### Paragraphs and Line Breaks
+
+A paragraph is simply one or more consecutive lines of text, separated
+by one or more blank lines. (A blank line is any line that looks like a
+blank line -- a line containing nothing but spaces or tabs is considered
+blank.) Normal paragraphs should not be indented with spaces or tabs.
+
+The implication of the "one or more consecutive lines of text" rule is
+that Markdown supports "hard-wrapped" text paragraphs. This differs
+significantly from most other text-to-HTML formatters (including Movable
+Type's "Convert Line Breaks" option) which translate every line break
+character in a paragraph into a `<br />` tag.
+
+When you *do* want to insert a `<br />` break tag using Markdown, you
+end a line with two or more spaces, then type return.
+
+### Headers
+
+Markdown supports two styles of headers, [Setext] [1] and [atx] [2].
+
+Optionally, you may "close" atx-style headers. This is purely
+cosmetic -- you can use this if you think it looks better. The
+closing hashes don't even need to match the number of hashes
+used to open the header. (The number of opening hashes
+determines the header level.)
+
+
+### Blockquotes
+
+Markdown uses email-style `>` characters for blockquoting. If you're
+familiar with quoting passages of text in an email message, then you
+know how to create a blockquote in Markdown. It looks best if you hard
+wrap the text and put a `>` before every line:
+
+> This is a blockquote with two paragraphs. Lorem ipsum dolor sit amet,
+> consectetuer adipiscing elit. Aliquam hendrerit mi posuere lectus.
+> Vestibulum enim wisi, viverra nec, fringilla in, laoreet vitae, risus.
+>
+> Donec sit amet nisl. Aliquam semper ipsum sit amet velit. Suspendisse
+> id sem consectetuer libero luctus adipiscing.
+
+Markdown allows you to be lazy and only put the `>` before the first
+line of a hard-wrapped paragraph:
+
+> This is a blockquote with two paragraphs. Lorem ipsum dolor sit amet,
+consectetuer adipiscing elit. Aliquam hendrerit mi posuere lectus.
+Vestibulum enim wisi, viverra nec, fringilla in, laoreet vitae, risus.
+
+> Donec sit amet nisl. Aliquam semper ipsum sit amet velit. Suspendisse
+id sem consectetuer libero luctus adipiscing.
+
+Blockquotes can be nested (i.e. a blockquote-in-a-blockquote) by
+adding additional levels of `>`:
+
+> This is the first level of quoting.
+>
+> > This is nested blockquote.
+>
+> Back to the first level.
+
+Blockquotes can contain other Markdown elements, including headers, lists,
+and code blocks:
+
+> ## This is a header.
+>
+> 1.   This is the first list item.
+> 2.   This is the second list item.
+>
+> Here's some example code:
+>
+>     return shell_exec("echo $input | $markdown_script");
+
+Any decent text editor should make email-style quoting easy. For
+example, with BBEdit, you can make a selection and choose Increase
+Quote Level from the Text menu.
+
+
+### Lists
+
+Markdown supports ordered (numbered) and unordered (bulleted) lists.
+
+Unordered lists use asterisks, pluses, and hyphens -- interchangably
+-- as list markers:
+
+*   Red
+*   Green
+*   Blue
+
+is equivalent to:
+
++   Red
++   Green
++   Blue
+
+and:
+
+-   Red
+-   Green
+-   Blue
+
+Ordered lists use numbers followed by periods:
+
+1.  Bird
+2.  McHale
+3.  Parish
+
+It's important to note that the actual numbers you use to mark the
+list have no effect on the HTML output Markdown produces. The HTML
+Markdown produces from the above list is:
+
+If you instead wrote the list in Markdown like this:
+
+1.  Bird
+1.  McHale
+1.  Parish
+
+or even:
+
+3. Bird
+1. McHale
+8. Parish
+
+you'd get the exact same HTML output. The point is, if you want to,
+you can use ordinal numbers in your ordered Markdown lists, so that
+the numbers in your source match the numbers in your published HTML.
+But if you want to be lazy, you don't have to.
+
+To make lists look nice, you can wrap items with hanging indents:
+
+*   Lorem ipsum dolor sit amet, consectetuer adipiscing elit.
+    Aliquam hendrerit mi posuere lectus. Vestibulum enim wisi,
+    viverra nec, fringilla in, laoreet vitae, risus.
+*   Donec sit amet nisl. Aliquam semper ipsum sit amet velit.
+    Suspendisse id sem consectetuer libero luctus adipiscing.
+
+But if you want to be lazy, you don't have to:
+
+*   Lorem ipsum dolor sit amet, consectetuer adipiscing elit.
+Aliquam hendrerit mi posuere lectus. Vestibulum enim wisi,
+viverra nec, fringilla in, laoreet vitae, risus.
+*   Donec sit amet nisl. Aliquam semper ipsum sit amet velit.
+Suspendisse id sem consectetuer libero luctus adipiscing.
+
+List items may consist of multiple paragraphs. Each subsequent
+paragraph in a list item must be indented by either 4 spaces
+or one tab:
+
+1.  This is a list item with two paragraphs. Lorem ipsum dolor
+    sit amet, consectetuer adipiscing elit. Aliquam hendrerit
+    mi posuere lectus.
+
+    Vestibulum enim wisi, viverra nec, fringilla in, laoreet
+    vitae, risus. Donec sit amet nisl. Aliquam semper ipsum
+    sit amet velit.
+
+2.  Suspendisse id sem consectetuer libero luctus adipiscing.
+
+It looks nice if you indent every line of the subsequent
+paragraphs, but here again, Markdown will allow you to be
+lazy:
+
+*   This is a list item with two paragraphs.
+
+    This is the second paragraph in the list item. You're
+only required to indent the first line. Lorem ipsum dolor
+sit amet, consectetuer adipiscing elit.
+
+*   Another item in the same list.
+
+To put a blockquote within a list item, the blockquote's `>`
+delimiters need to be indented:
+
+*   A list item with a blockquote:
+
+    > This is a blockquote
+    > inside a list item.
+
+To put a code block within a list item, the code block needs
+to be indented *twice* -- 8 spaces or two tabs:
+
+*   A list item with a code block:
+
+        <code goes here>
+
+### Code Blocks
+
+Pre-formatted code blocks are used for writing about programming or
+markup source code. Rather than forming normal paragraphs, the lines
+of a code block are interpreted literally. Markdown wraps a code block
+in both `<pre>` and `<code>` tags.
+
+To produce a code block in Markdown, simply indent every line of the
+block by at least 4 spaces or 1 tab.
+
+This is a normal paragraph:
+
+    This is a code block.
+
+Here is an example of AppleScript:
+
+    tell application "Foo"
+        beep
+    end tell
+
+A code block continues until it reaches a line that is not indented
+(or the end of the article).
+
+Within a code block, ampersands (`&`) and angle brackets (`<` and `>`)
+are automatically converted into HTML entities. This makes it very
+easy to include example HTML source code using Markdown -- just paste
+it and indent it, and Markdown will handle the hassle of encoding the
+ampersands and angle brackets. For example, this:
+
+    <div class="footer">
+        &copy; 2004 Foo Corporation
+    </div>
+
+Regular Markdown syntax is not processed within code blocks. E.g.,
+asterisks are just literal asterisks within a code block. This means
+it's also easy to use Markdown to write about Markdown's own syntax.
+
+```
+tell application "Foo"
+    beep
+end tell
+```
+
+## Span Elements
+
+### Links
+
+Markdown supports two style of links: *inline* and *reference*.
+
+In both styles, the link text is delimited by [square brackets].
+
+To create an inline link, use a set of regular parentheses immediately
+after the link text's closing square bracket. Inside the parentheses,
+put the URL where you want the link to point, along with an *optional*
+title for the link, surrounded in quotes. For example:
+
+This is [an example](http://example.com/) inline link.
+
+[This link](http://example.net/) has no title attribute.
+
+### Emphasis
+
+Markdown treats asterisks (`*`) and underscores (`_`) as indicators of
+emphasis. Text wrapped with one `*` or `_` will be wrapped with an
+HTML `<em>` tag; double `*`'s or `_`'s will be wrapped with an HTML
+`<strong>` tag. E.g., this input:
+
+*single asterisks*
+
+_single underscores_
+
+**double asterisks**
+
+__double underscores__
+
+### Code
+
+To indicate a span of code, wrap it with backtick quotes (`` ` ``).
+Unlike a pre-formatted code block, a code span indicates code within a
+normal paragraph. For example:
+
+Use the `printf()` function.
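The entity-escaping rule in the Code Blocks section above is easy to verify from Python. A minimal sketch, assuming the third-party `markdown` package (not part of this commit):

```python
# Render the indented-code-block example from this file and confirm that
# `<`, `>` and `&` come back as HTML entities inside <pre><code>.
import markdown  # assumption: installed separately, not an R2R dependency

source = (
    "This is a normal paragraph:\n"
    "\n"
    '    <div class="footer">\n'
    "        &copy; 2004 Foo Corporation\n"
    "    </div>\n"
)
print(markdown.markdown(source))
# -> <p>This is a normal paragraph:</p>
#    <pre><code>&lt;div class="footer"&gt;
#        &amp;copy; 2004 Foo Corporation
#    &lt;/div&gt;</code></pre>
```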

BIN
core/examples/supported_file_types/msg.msg


BIN
core/examples/supported_file_types/odt.odt


+ 153 - 0
core/examples/supported_file_types/org.org

@@ -0,0 +1,153 @@
+#+title: Modern Org Example
+#+author: Daniel Mendler
+#+filetags: :example:org:
+
+This example Org file demonstrates the Org elements,
+which are styled by =org-modern=.
+
+-----
+
+* Headlines
+** Second level
+*** Third level
+**** Fourth level
+***** Fifth level
+
+* Task Lists [1/3]
+  - [X] Write =org-modern=
+  - [-] Publish =org-modern=
+  - [ ] Fix all the bugs
+
+* List Bullets
+  - Dash
+  + Plus
+  * Asterisk
+
+* Timestamps
+DEADLINE:  <2022-03-01 Tue>
+SCHEDULED: <2022-02-25 10:00>
+DRANGE:    [2022-03-01]--[2022-04-01]
+DRANGE:    <2022-03-01>--<2022-04-01>
+TRANGE:    [2022-03-01 Tue 10:42-11:00]
+TIMESTAMP: [2022-02-21 Mon 13:00]
+DREPEATED: <2022-02-26 Sat .+1d/2d +3d>
+TREPEATED: <2022-02-26 Sat 10:00 .+1d/2d>
+
+* Blocks
+
+#+begin_src emacs-lisp
+  ;; Taken from the well-structured Emacs config by @oantolin.
+  ;; Take a look at https://github.com/oantolin/emacs-config!
+  (defun command-of-the-day ()
+    "Show the documentation for a random command."
+    (interactive)
+    (let ((commands))
+      (mapatoms (lambda (s)
+                  (when (commandp s) (push s commands))))
+      (describe-function
+       (nth (random (length commands)) commands))))
+#+end_src
+
+#+begin_src calc
+  taylor(sin(x),x=0,3)
+#+end_src
+
+#+results:
+: pi x / 180 - 2.85779606768e-8 pi^3 x^3
+
+#+BEGIN_SRC C
+  printf("a|b\nc|d\n");
+#+END_SRC
+
+#+results:
+| a | b |
+| c | d |
+
+
+
+
+
+
+
+* Todo Labels and Tags
+** DONE Write =org-modern= :emacs:foss:coding:
+** TODO Publish =org-modern=
+** WAIT Fix all the bugs
+
+* Priorities
+** DONE [#A] Most important
+** TODO [#B] Less important
+** CANCEL [#C] Not that important
+** DONE [100%] [#A] Everything combined :tag:test:
+  * [X] First
+  * [X] Second
+  * [X] Third
+
+* Tables
+
+| N | N^2 | N^3 | N^4 | sqrt(n) | sqrt[4](N) |
+|---+----+----+----+---------+------------|
+| 2 |  4 |  8 | 16 |  1.4142 |     1.1892 |
+| 3 |  9 | 27 | 81 |  1.7321 |     1.3161 |
+
+|---+----+----+----+---------+------------|
+| N | N^2 | N^3 | N^4 | sqrt(n) | sqrt[4](N) |
+|---+----+----+----+---------+------------|
+| 2 |  4 |  8 | 16 |  1.4142 |     1.1892 |
+| 3 |  9 | 27 | 81 |  1.7321 |     1.3161 |
+|---+----+----+----+---------+------------|
+
+#+begin_example
+| a | b | c |
+| a | b | c |
+| a | b | c |
+#+end_example
+
+* Special Links
+
+Test numeric footnotes[fn:1] and named footnotes[fn:foo].
+
+<<This is an internal link>>
+
+<<<radio link>>>
+
+[[This is an internal link]]
+
+radio link
+
+[fn:1] This is footnote 1
+[fn:foo] This is the footnote
+
+* Progress bars
+
+- quotient [1/13]
+- quotient [2/13]
+- quotient [3/13]
+- quotient [4/13]
+- quotient [5/13]
+- quotient [6/13]
+- quotient [7/13]
+- quotient [8/13]
+- quotient [9/13]
+- quotient [10/13]
+- quotient [11/13]
+- quotient [12/13]
+- quotient [13/13]
+
+- percent [0%]
+- percent [1%]
+- percent [2%]
+- percent [5%]
+- percent [10%]
+- percent [20%]
+- percent [30%]
+- percent [40%]
+- percent [50%]
+- percent [60%]
+- percent [70%]
+- percent [80%]
+- percent [90%]
+- percent [100%]
+
+- overflow [110%]
+- overflow [20/10]
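The `[n/m]` and `[x%]` cookies above are derived from checkbox state; a tiny illustrative sketch (not from this commit) of how such a quotient cookie can be computed:

```python
# Count checked items in an Org task list and format the progress cookie.
lines = [
    "  - [X] Write =org-modern=",
    "  - [-] Publish =org-modern=",  # in-progress, not counted as done
    "  - [ ] Fix all the bugs",
]
done = sum(1 for line in lines if "[X]" in line)
print(f"[{done}/{len(lines)}]")  # -> [1/3], matching the heading above
```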

+ 50 - 0
core/examples/supported_file_types/p7s.p7s

@@ -0,0 +1,50 @@
+MIME-Version: 1.0
+Content-Type: multipart/signed; protocol="application/x-pkcs7-signature"; micalg="sha-256"; boundary="----2234CCF759A742BD58A8D9D012C3BC23"
+
+This is an S/MIME signed message
+
+------2234CCF759A742BD58A8D9D012C3BC23
+Hello World
+
+------2234CCF759A742BD58A8D9D012C3BC23
+Content-Type: application/x-pkcs7-signature; name="smime.p7s"
+Content-Transfer-Encoding: base64
+Content-Disposition: attachment; filename="smime.p7s"
+
+MIIGiwYJKoZIhvcNAQcCoIIGfDCCBngCAQExDzANBglghkgBZQMEAgEFADALBgkq
+hkiG9w0BBwGgggOpMIIDpTCCAo2gAwIBAgIUNUBhVZGwKQ9d8VLtLZLNvEwWnXUw
+DQYJKoZIhvcNAQELBQAwezELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3Ju
+aWExFjAUBgNVBAcMDVNhbiBGcmFuY2lzY28xDzANBgNVBAoMBlNjaVBoaTEOMAwG
+A1UEAwwFTm9sYW4xHjAcBgkqhkiG9w0BCQEWD25vbGFuQHNjaXBoaS5haTAeFw0y
+NDEyMTYyMDIxMjJaFw0yNTEyMTYyMDIxMjJaMHsxCzAJBgNVBAYTAlVTMRMwEQYD
+VQQIDApDYWxpZm9ybmlhMRYwFAYDVQQHDA1TYW4gRnJhbmNpc2NvMQ8wDQYDVQQK
+DAZTY2lQaGkxDjAMBgNVBAMMBU5vbGFuMR4wHAYJKoZIhvcNAQkBFg9ub2xhbkBz
+Y2lwaGkuYWkwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCcBfnCPjDl
+SBzauhd/Q0z2lQc1smO6eDmaly3CsHvFMvINQrX9adnQt9PQW35oV+lzikDfEfpv
+W60pYLQR1iZEDu6ELS5iGjHFtnQvj8BYm23CKdDY+NGlZYJXgw9J1Ezz0wgqruYU
+yduy2Tdp3uWxMXkEnR681u1PEPAFqMx3qYpTzEkdu6tmIF5QYHLle4qKyxknV1Yu
+RZYc7OVpBfKlpt9Ya+i+gugNZoSwPgouLxdZkM5XBGgS2iMD7X2C5819DAmXzdm5
+l95VsCISQ5bjpmXiS8LHdFaTEqtvgeqw8nmlcU8994t0PpfdKFr0lL8NoiDYXht7
+v1mLmEmrtAoTAgMBAAGjITAfMB0GA1UdDgQWBBQZW3RPHHKH4MsjXsdwNtI0BQDu
+DzANBgkqhkiG9w0BAQsFAAOCAQEAEqYqqM/8BgB6LfHdj+vo7S9kHauh2bhLOZnm
+ecZu+N/Dg1WwIaCtGL6L5UmLkcQ28pJNgnUyr5eQZxtOa7y1CfDFxO6bnY8oeAcU
+0PqLi6sdUtLTjLlt47rOysCnIx8MjscQRfopH3sUD5eKYk3yMGVcTAVLBUMSgaUJ
+a+tYhk9UEcIFtKrmRmNE+kW8+t/UKSv4xT4aDvmiiIQgel88YMgu3ADv1WWDjbd9
+u96blAHOR4FpfJzuEJ/4YVOND//A4Skqv4r82lu6ZoQx0u1CJd4UOZVcGF2itRgI
+OSm2hgEG/UpmWKdIwskBQM1dwdFpSzMtYWnDAcPB3S5onmE4OjGCAqYwggKiAgEB
+MIGTMHsxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYDVQQH
+DA1TYW4gRnJhbmNpc2NvMQ8wDQYDVQQKDAZTY2lQaGkxDjAMBgNVBAMMBU5vbGFu
+MR4wHAYJKoZIhvcNAQkBFg9ub2xhbkBzY2lwaGkuYWkCFDVAYVWRsCkPXfFS7S2S
+zbxMFp11MA0GCWCGSAFlAwQCAQUAoIHkMBgGCSqGSIb3DQEJAzELBgkqhkiG9w0B
+BwEwHAYJKoZIhvcNAQkFMQ8XDTI0MTIxNjIwMjEyOVowLwYJKoZIhvcNAQkEMSIE
+ILCAItMVzx6xLSZlve0OavQGU8CgvpdSMvtJvL0CHPw2MHkGCSqGSIb3DQEJDzFs
+MGowCwYJYIZIAWUDBAEqMAsGCWCGSAFlAwQBFjALBglghkgBZQMEAQIwCgYIKoZI
+hvcNAwcwDgYIKoZIhvcNAwICAgCAMA0GCCqGSIb3DQMCAgFAMAcGBSsOAwIHMA0G
+CCqGSIb3DQMCAgEoMA0GCSqGSIb3DQEBAQUABIIBAAFj405qE8q1KSpxckUqUwrp
+HFnkySyQnxHykeTrC3IwbwerL3lA9KBaP9F+yuweXro4dCKAMx/I0ajCJqiMWgDq
+6Gctn+RQURgP1ZEUViAonCOFMJ9a5bQs351DgH13qB48J8PnRmVQsoZNsjI+0atk
+2f5WBXrbv+onrUemFA5DdKOmb7ZWX6LmuJWg92JZQYuA56hdal0OZMBWvtZxLPaG
+z8CJSscfcbMEJhSDHSodnj4JpS0TkNW8LtqCaKnCFVYWOBsUPI/L6g7kPZ02BAy+
+XjtEf3BlXNq3nTZlppXN21y0thKrp0IMkwKrfLeEzY3ir1XrjkTy99gIz+lw++w=
+
+------2234CCF759A742BD58A8D9D012C3BC23--
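The .p7s sample is a multipart/signed S/MIME message; Python's stdlib `email` parser can split it into its plain-text body and its base64 PKCS#7 signature part. A sketch (actual signature verification would need a crypto library and is not shown):

```python
from email import message_from_string

with open("core/examples/supported_file_types/p7s.p7s") as f:
    msg = message_from_string(f.read())

print(msg.get_content_type())  # multipart/signed
for part in msg.get_payload():
    print(part.get_content_type(), part.get_filename())
# text/plain None                          <- the "Hello World" body
# application/x-pkcs7-signature smime.p7s  <- the detached signature
```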

BIN
core/examples/supported_file_types/pdf.pdf


BIN
core/examples/supported_file_types/png.png


BIN
core/examples/supported_file_types/ppt.ppt


BIN
core/examples/supported_file_types/pptx.pptx


+ 86 - 0
core/examples/supported_file_types/rst.rst

@@ -0,0 +1,86 @@
+Header 1
+========
+--------
+Subtitle
+--------
+
+Example text.
+
+.. contents:: Table of Contents
+
+Header 2
+--------
+
+1. Blah blah ``code`` blah
+
+2. More ``code``, hooray
+
+3. Somé UTF-8°
+
+The UTF-8 quote character in this table used to cause python to go boom. Now docutils just silently ignores it.
+
+.. csv-table:: Things that are Awesome (on a scale of 1-11)
+	:quote: ”
+
+	Thing,Awesomeness
+	Icecream, 7
+	Honey Badgers, 10.5
+	Nickelback, -2
+	Iron Man, 10
+	Iron Man 2, 3
+	Tabular Data, 5
+	Made up ratings, 11
+
+.. code::
+
+	A block of code
+
+.. code:: python
+
+	python.code('hooray')
+
+.. code:: javascript
+
+	export function ƒ(ɑ, β) {}
+
+.. doctest:: ignored
+
+	>>> some_function()
+	'result'
+
+>>> some_function()
+'result'
+
+==============  ==========================================================
+Travis          http://travis-ci.org/tony/pullv
+Docs            http://pullv.rtfd.org
+API             http://pullv.readthedocs.org/en/latest/api.html
+Issues          https://github.com/tony/pullv/issues
+Source          https://github.com/tony/pullv
+==============  ==========================================================
+
+
+.. image:: https://scan.coverity.com/projects/621/badge.svg
+	:target: https://scan.coverity.com/projects/621
+	:alt: Coverity Scan Build Status
+
+.. image:: https://scan.coverity.com/projects/621/badge.svg
+	:alt: Coverity Scan Build Status
+
+Field list
+----------
+
+:123456789 123456789 123456789 123456789 123456789 1: Uh-oh! This name is too long!
+:123456789 123456789 123456789 123456789 1234567890: this is a long name,
+	but no problem!
+:123456789 12345: this is not so long, but long enough for the default!
+:123456789 1234: this should work even with the default :)
+
+someone@somewhere.org
+
+Press :kbd:`Ctrl+C` to quit
+
+
+.. raw:: html
+
+    <p><strong>RAW HTML!</strong></p><style> p {color:blue;} </style>
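A quick way to check that the .rst sample renders is docutils; a sketch, assuming `docutils` is installed (not a dependency added by this commit). Unknown constructs such as the `doctest` directive are reported inline as docutils system messages rather than crashing:

```python
from docutils.core import publish_string

with open("core/examples/supported_file_types/rst.rst") as f:
    html = publish_string(f.read(), writer_name="html")  # returns bytes
print(html[:120])
```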

+ 5 - 0
core/examples/supported_file_types/rtf.rtf

@@ -0,0 +1,5 @@
+{\rtf1\ansi\deff0
+{\fonttbl{\f0\froman\fcharset0 Times New Roman;}}
+\viewkind4\uc1\pard\f0\fs24
+Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.\par
+}

BIN
core/examples/supported_file_types/tiff.tiff


+ 11 - 0
core/examples/supported_file_types/tsv.tsv

@@ -0,0 +1,11 @@
+Region	Year	Quarter	Sales	Employees	Growth Rate
+North America	2024	Q1	1250000	45	5.2
+Europe	2024	Q1	980000	38	4.8
+Asia Pacific	2024	Q1	1450000	52	6.1
+South America	2024	Q1	580000	25	3.9
+Africa	2024	Q1	320000	18	4.2
+North America	2024	Q2	1380000	47	5.5
+Europe	2024	Q2	1050000	40	4.9
+Asia Pacific	2024	Q2	1520000	54	5.8
+South America	2024	Q2	620000	27	4.1
+Africa	2024	Q2	350000	20	4.4
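Since the sample is tab-separated, the stdlib `csv` module reads it directly once the delimiter is set (a minimal sketch):

```python
import csv

with open("core/examples/supported_file_types/tsv.tsv", newline="") as f:
    rows = list(csv.reader(f, delimiter="\t"))

header, data = rows[0], rows[1:]
print(header)     # ['Region', 'Year', 'Quarter', 'Sales', 'Employees', 'Growth Rate']
print(len(data))  # 10 data rows: Q1 and Q2 for five regions
```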

+ 21 - 0
core/examples/supported_file_types/txt.txt

@@ -0,0 +1,21 @@
+Quod equidem non reprehendo;
+Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quibus natura iure responderit non esse verum aliunde finem beate vivendi, a se principia rei gerendae peti; Quae enim adhuc protulisti, popularia sunt, ego autem a te elegantiora desidero. Duo Reges: constructio interrete. Tum Lucius: Mihi vero ista valde probata sunt, quod item fratri puto. Bestiarum vero nullum iudicium puto. Nihil enim iam habes, quod ad corpus referas; Deinde prima illa, quae in congressu solemus: Quid tu, inquit, huc? Et homini, qui ceteris animantibus plurimum praestat, praecipue a natura nihil datum esse dicemus?
+
+Iam id ipsum absurdum, maximum malum neglegi. Quod ea non occurrentia fingunt, vincunt Aristonem; Atqui perspicuum est hominem e corpore animoque constare, cum primae sint animi partes, secundae corporis. Fieri, inquam, Triari, nullo pacto potest, ut non dicas, quid non probes eius, a quo dissentias. Equidem e Cn. An dubium est, quin virtus ita maximam partem optineat in rebus humanis, ut reliquas obruat?
+
+Quis istum dolorem timet?
+Summus dolor plures dies manere non potest? Dicet pro me ipsa virtus nec dubitabit isti vestro beato M. Tubulum fuisse, qua illum, cuius is condemnatus est rogatione, P. Quod si ita sit, cur opera philosophiae sit danda nescio.
+
+Ex eorum enim scriptis et institutis cum omnis doctrina liberalis, omnis historia.
+Quod si ita est, sequitur id ipsum, quod te velle video, omnes semper beatos esse sapientes. Cum enim fertur quasi torrens oratio, quamvis multa cuiusque modi rapiat, nihil tamen teneas, nihil apprehendas, nusquam orationem rapidam coerceas. Ita redarguitur ipse a sese, convincunturque scripta eius probitate ipsius ac moribus. At quanta conantur! Mundum hunc omnem oppidum esse nostrum! Incendi igitur eos, qui audiunt, vides. Vide, ne magis, inquam, tuum fuerit, cum re idem tibi, quod mihi, videretur, non nova te rebus nomina inponere. Qui-vere falsone, quaerere mittimus-dicitur oculis se privasse; Si ista mala sunt, in quae potest incidere sapiens, sapientem esse non esse ad beate vivendum satis. At vero si ad vitem sensus accesserit, ut appetitum quendam habeat et per se ipsa moveatur, quid facturam putas?
+
+Quem si tenueris, non modo meum Ciceronem, sed etiam me ipsum abducas licebit.
+Stulti autem malorum memoria torquentur, sapientes bona praeterita grata recordatione renovata delectant.
+Esse enim quam vellet iniquus iustus poterat inpune.
+Quae autem natura suae primae institutionis oblita est?
+Verum tamen cum de rebus grandioribus dicas, ipsae res verba rapiunt;
+Hoc est non modo cor non habere, sed ne palatum quidem.
+Voluptatem cum summum bonum diceret, primum in eo ipso parum vidit, deinde hoc quoque alienum; Sed tu istuc dixti bene Latine, parum plane. Nam haec ipsa mihi erunt in promptu, quae modo audivi, nec ante aggrediar, quam te ab istis, quos dicis, instructum videro. Fatebuntur Stoici haec omnia dicta esse praeclare, neque eam causam Zenoni desciscendi fuisse. Non autem hoc: igitur ne illud quidem. Ratio quidem vestra sic cogit. Cum audissem Antiochum, Brute, ut solebam, cum M. An quod ita callida est, ut optime possit architectari voluptates?
+
+Idemne, quod iucunde?
+Haec mihi videtur delicatior, ut ita dicam, molliorque ratio, quam virtutis vis gravitasque postulat. Sed quoniam et advesperascit et mihi ad villam revertendum est, nunc quidem hactenus; Cuius ad naturam apta ratio vera illa et summa lex a philosophis dicitur. Neque solum ea communia, verum etiam paria esse dixerunt. Sed nunc, quod agimus; A mene tu?

BIN
core/examples/supported_file_types/xls.xls


BIN
core/examples/supported_file_types/xlsx.xlsx


+ 17 - 0
core/examples/supported_file_types/xml.xml

@@ -0,0 +1,17 @@
+<root>
+  <person>
+    <name>John Doe</name>
+    <age>30</age>
+    <email>john.doe@example.com</email>
+  </person>
+  <person>
+    <name>Jane Smith</name>
+    <age>25</age>
+    <email>jane.smith@example.com</email>
+  </person>
+  <book>
+    <title>The Adventure Begins</title>
+    <author>Robert Johnson</author>
+    <year>2022</year>
+  </book>
+</root>
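The XML sample parses with the stdlib `ElementTree` API (a minimal sketch):

```python
import xml.etree.ElementTree as ET

tree = ET.parse("core/examples/supported_file_types/xml.xml")
for person in tree.getroot().findall("person"):
    print(person.findtext("name"), person.findtext("age"))
# John Doe 30
# Jane Smith 25
```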

+ 1 - 1
core/main/__init__.py

@@ -30,5 +30,5 @@ __all__ = [
     "IngestionService",
     "IngestionService",
     "ManagementService",
     "ManagementService",
     "RetrievalService",
     "RetrievalService",
-    "KgService",
+    "GraphService",
 ]
 ]

+ 49 - 15
core/main/abstractions.py

@@ -1,9 +1,27 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Optional
+
 from pydantic import BaseModel
 
 from core.agent import R2RRAGAgent, R2RStreamingRAGAgent
-from core.base.pipes import AsyncPipe
 from core.database import PostgresDatabaseProvider
 from core.pipelines import RAGPipeline, SearchPipeline
+from core.pipes import (
+    EmbeddingPipe,
+    GraphClusteringPipe,
+    GraphCommunitySummaryPipe,
+    GraphDeduplicationPipe,
+    GraphDeduplicationSummaryPipe,
+    GraphDescriptionPipe,
+    GraphExtractionPipe,
+    GraphSearchSearchPipe,
+    GraphStoragePipe,
+    ParsingPipe,
+    RAGPipe,
+    SearchPipe,
+    StreamingRAGPipe,
+    VectorStoragePipe,
+)
 from core.providers import (
     AsyncSMTPEmailProvider,
     ConsoleMockEmailProvider,
@@ -21,6 +39,13 @@ from core.providers import (
     UnstructuredIngestionProvider,
 )
 
+if TYPE_CHECKING:
+    from core.main.services.auth_service import AuthService
+    from core.main.services.graph_service import GraphService
+    from core.main.services.ingestion_service import IngestionService
+    from core.main.services.management_service import ManagementService
+    from core.main.services.retrieval_service import RetrievalService
+
 
 class R2RProviders(BaseModel):
     auth: R2RAuthProvider | SupabaseAuthProvider
@@ -44,20 +69,20 @@ class R2RProviders(BaseModel):
 
 
 class R2RPipes(BaseModel):
-    parsing_pipe: AsyncPipe
-    embedding_pipe: AsyncPipe
-    kg_search_pipe: AsyncPipe
-    kg_relationships_extraction_pipe: AsyncPipe
-    kg_storage_pipe: AsyncPipe
-    kg_entity_description_pipe: AsyncPipe
-    kg_clustering_pipe: AsyncPipe
-    kg_entity_deduplication_pipe: AsyncPipe
-    kg_entity_deduplication_summary_pipe: AsyncPipe
-    kg_community_summary_pipe: AsyncPipe
-    rag_pipe: AsyncPipe
-    streaming_rag_pipe: AsyncPipe
-    vector_storage_pipe: AsyncPipe
-    vector_search_pipe: AsyncPipe
+    parsing_pipe: ParsingPipe
+    embedding_pipe: EmbeddingPipe
+    graph_search_pipe: GraphSearchSearchPipe
+    graph_extraction_pipe: GraphExtractionPipe
+    graph_storage_pipe: GraphStoragePipe
+    graph_description_pipe: GraphDescriptionPipe
+    graph_clustering_pipe: GraphClusteringPipe
+    graph_deduplication_pipe: GraphDeduplicationPipe
+    graph_deduplication_summary_pipe: GraphDeduplicationSummaryPipe
+    graph_community_summary_pipe: GraphCommunitySummaryPipe
+    rag_pipe: RAGPipe
+    streaming_rag_pipe: StreamingRAGPipe
+    vector_storage_pipe: VectorStoragePipe
+    vector_search_pipe: Any  # TODO - Fix
 
     class Config:
         arbitrary_types_allowed = True
@@ -78,3 +103,12 @@ class R2RAgents(BaseModel):
 
     class Config:
         arbitrary_types_allowed = True
+
+
+@dataclass
+class R2RServices:
+    auth: Optional["AuthService"] = None
+    ingestion: Optional["IngestionService"] = None
+    management: Optional["ManagementService"] = None
+    retrieval: Optional["RetrievalService"] = None
+    graph: Optional["GraphService"] = None
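The move from a service dict to the `R2RServices` dataclass is what lets the routers below write `self.services.ingestion` instead of `self.services["ingestion"]`. A minimal sketch of the pattern with the service class stubbed out (the real ones live under `core/main/services`):

```python
from dataclasses import dataclass
from typing import Optional


class IngestionService:  # stand-in for core.main.services.ingestion_service
    async def get_chunk(self, chunk_id: str) -> dict:
        return {"id": chunk_id}


@dataclass
class R2RServices:
    ingestion: Optional[IngestionService] = None


services = R2RServices(ingestion=IngestionService())
getter = services.ingestion.get_chunk  # attribute access, visible to type checkers
```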

+ 28 - 22
core/main/api/v3/base_router.py

@@ -3,20 +3,20 @@ import logging
 from abc import abstractmethod
 from typing import Callable
 
-from fastapi import APIRouter, Depends, HTTPException, Request, status
+from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
 from fastapi.responses import StreamingResponse
 
 from core.base import R2RException, manage_run
 
+from ...abstractions import R2RProviders, R2RServices
+
 logger = logging.getLogger()
 
 
 class BaseRouterV3:
-    def __init__(self, providers, services, orchestration_provider, run_type):
+    def __init__(self, providers: R2RProviders, services: R2RServices):
         self.providers = providers
         self.services = services
-        self.run_type = run_type
-        self.orchestration_provider = orchestration_provider
         self.router = APIRouter()
         self.openapi_extras = self._load_openapi_extras()
         self._setup_routes()
@@ -29,14 +29,11 @@ class BaseRouterV3:
         @functools.wraps(func)
         async def wrapper(*args, **kwargs):
             async with manage_run(
-                self.services["ingestion"].run_manager, func.__name__
+                self.services.ingestion.run_manager, func.__name__
             ) as run_id:
                 auth_user = kwargs.get("auth_user")
                 if auth_user:
-                    await self.services[
-                        "ingestion"
-                    ].run_manager.log_run_info(  # TODO - this is a bit of a hack
-                        run_type=self.run_type,
+                    await self.services.ingestion.run_manager.log_run_info(  # TODO - this is a bit of a hack
                         user=auth_user,
                     )
 
@@ -93,22 +90,22 @@ class BaseRouterV3:
 import functools
 import logging
 from abc import abstractmethod
-from typing import Callable
+from typing import Callable, Optional
 
 from fastapi import APIRouter, Depends, HTTPException, Request
 from fastapi.responses import StreamingResponse
 
 from core.base import R2RException, manage_run
 
+from ...abstractions import R2RProviders, R2RServices
+
 logger = logging.getLogger()
 
 
 class BaseRouterV3:
-    def __init__(self, providers, services, orchestration_provider, run_type):
+    def __init__(self, providers: R2RProviders, services: R2RServices):
         self.providers = providers
         self.services = services
-        self.run_type = run_type
-        self.orchestration_provider = orchestration_provider
         self.router = APIRouter()
         self.openapi_extras = self._load_openapi_extras()
         self.set_rate_limiting()
@@ -122,12 +119,11 @@ class BaseRouterV3:
         @functools.wraps(func)
         async def wrapper(*args, **kwargs):
             async with manage_run(
-                self.services["ingestion"].run_manager, func.__name__
+                self.services.ingestion.run_manager, func.__name__
             ) as run_id:
                 auth_user = kwargs.get("auth_user")
                 if auth_user:
-                    await self.services["ingestion"].run_manager.log_run_info(
-                        run_type=self.run_type,
+                    await self.services.ingestion.run_manager.log_run_info(
                         user=auth_user,
                     )
 
@@ -186,29 +182,39 @@ class BaseRouterV3:
 
         async def rate_limit_dependency(
             request: Request,
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ):
             user_id = auth_user.id
             route = request.scope["path"]
             # Check the limits before proceeding
             try:
-                await self.providers.database.limits_handler.check_limits(
-                    user_id, route
-                )
+                if not auth_user.is_superuser:
+                    await self.providers.database.limits_handler.check_limits(
+                        user_id, route
+                    )
             except ValueError as e:
                 raise HTTPException(status_code=429, detail=str(e))
 
             request.state.user_id = user_id
             request.state.route = route
-            print("in rate limit dependency....")
             # Yield to run the route
             try:
                 yield
             finally:
-                print("finally....")
                 # After the route completes successfully, log the request
                 await self.providers.database.limits_handler.log_request(
                     user_id, route
                 )
 
+        async def websocket_rate_limit_dependency(
+            websocket: WebSocket,
+        ):
+            route = websocket.scope["path"]
+            try:
+                return True
+            except ValueError as e:
+                await websocket.close(code=4429, reason="Rate limit exceeded")
+                return False
+
         self.rate_limit_dependency = rate_limit_dependency
+        self.websocket_rate_limit_dependency = websocket_rate_limit_dependency
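The reworked dependency now yields around the route handler, skips the quota check for superusers, and logs the request afterwards; note that the websocket variant committed above is still a stub (its `except ValueError` branch is unreachable as written). A self-contained sketch of the same yielding shape, with `check_limits`/`log_request` as stand-ins for the real `limits_handler` methods:

```python
from fastapi import Depends, FastAPI, HTTPException, Request

app = FastAPI()


async def check_limits(user_id: str, route: str) -> None:
    """Stand-in: raise ValueError when the user is over quota."""


async def log_request(user_id: str, route: str) -> None:
    """Stand-in: persist one usage record per completed request."""


async def rate_limit_dependency(request: Request):
    user_id = "demo-user"  # the real code resolves this via auth_wrapper()
    route = request.scope["path"]
    try:
        await check_limits(user_id, route)  # skipped for superusers upstream
    except ValueError as e:
        raise HTTPException(status_code=429, detail=str(e))
    try:
        yield  # the route handler executes here
    finally:
        await log_request(user_id, route)


@app.get("/ping", dependencies=[Depends(rate_limit_dependency)])
async def ping():
    return {"ok": True}
```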

+ 25 - 28
core/main/api/v3/chunks_router.py

@@ -1,20 +1,16 @@
 import json
 import logging
 import textwrap
-from copy import copy
-from typing import Any, Optional
+from typing import Optional
 from uuid import UUID
 
 from fastapi import Body, Depends, Path, Query
 
 from core.base import (
     ChunkResponse,
-    ChunkSearchSettings,
     GraphSearchSettings,
     R2RException,
-    RunType,
     SearchSettings,
-    UnprocessedChunk,
     UpdateChunk,
     select_search_filters,
 )
@@ -29,8 +25,8 @@ from core.providers import (
     HatchetOrchestrationProvider,
     SimpleOrchestrationProvider,
 )
-from core.utils import generate_id
 
+from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3
 
 logger = logging.getLogger()
@@ -41,19 +37,16 @@ MAX_CHUNKS_PER_REQUEST = 1024 * 100
 class ChunksRouter(BaseRouterV3):
     def __init__(
         self,
-        providers,
-        services,
-        orchestration_provider: (
-            HatchetOrchestrationProvider | SimpleOrchestrationProvider
-        ),
-        run_type: RunType = RunType.INGESTION,
+        providers: R2RProviders,
+        services: R2RServices,
     ):
-        super().__init__(providers, services, orchestration_provider, run_type)
+        super().__init__(providers, services)
 
     def _setup_routes(self):
         @self.router.post(
             "/chunks/search",
             summary="Search Chunks",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -81,7 +74,7 @@ class ChunksRouter(BaseRouterV3):
             search_settings: SearchSettings = Body(
                 default_factory=SearchSettings,
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedVectorSearchResponse:  # type: ignore
             # TODO - Deduplicate this code by sharing the code on the retrieval router
             """
@@ -99,7 +92,7 @@ class ChunksRouter(BaseRouterV3):
 
             search_settings.graph_settings = GraphSearchSettings(enabled=False)
 
-            results = await self.services["retrieval"].search(
+            results = await self.services.retrieval.search(
                 query=query,
                 search_settings=search_settings,
             )
@@ -108,6 +101,7 @@ class ChunksRouter(BaseRouterV3):
         @self.router.get(
             "/chunks/{id}",
             summary="Retrieve Chunk",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -147,7 +141,7 @@ class ChunksRouter(BaseRouterV3):
         @self.base_endpoint
         async def retrieve_chunk(
             id: UUID = Path(...),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedChunkResponse:
             """
             Get a specific chunk by its ID.
@@ -155,12 +149,12 @@ class ChunksRouter(BaseRouterV3):
             Returns the chunk's content, metadata, and associated document/collection information.
             Users can only retrieve chunks they own or have access to through collections.
             """
-            chunk = await self.services["ingestion"].get_chunk(id)
+            chunk = await self.services.ingestion.get_chunk(id)
             if not chunk:
                 raise R2RException("Chunk not found", 404)
 
             # # Check access rights
-            # document = await self.services["management"].get_document(chunk.document_id)
+            # document = await self.services.management.get_document(chunk.document_id)
             # TODO - Add collection ID check
             if not auth_user.is_superuser and str(auth_user.id) != str(
                 chunk["owner_id"]
@@ -180,6 +174,7 @@ class ChunksRouter(BaseRouterV3):
         @self.router.post(
             "/chunks/{id}",
             summary="Update Chunk",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -227,7 +222,7 @@ class ChunksRouter(BaseRouterV3):
             id: UUID = Path(...),
             chunk_update: UpdateChunk = Body(...),
             # TODO: Run with orchestration?
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedChunkResponse:
             """
             Update an existing chunk's content and/or metadata.
@@ -236,7 +231,7 @@ class ChunksRouter(BaseRouterV3):
             Users can only update chunks they own unless they are superusers.
             """
             # Get the existing chunk to get its chunk_id
-            existing_chunk = await self.services["ingestion"].get_chunk(
+            existing_chunk = await self.services.ingestion.get_chunk(
                 chunk_update.id
             )
             if existing_chunk is None:
@@ -256,9 +251,7 @@ class ChunksRouter(BaseRouterV3):
 
             # TODO - CLEAN THIS UP
 
-            simple_ingestor = simple_ingestion_factory(
-                self.services["ingestion"]
-            )
+            simple_ingestor = simple_ingestion_factory(self.services.ingestion)
             await simple_ingestor["update-chunk"](workflow_input)
 
             return ChunkResponse(  # type: ignore
@@ -274,6 +267,7 @@ class ChunksRouter(BaseRouterV3):
         @self.router.delete(
             "/chunks/{id}",
             summary="Delete Chunk",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -313,7 +307,7 @@ class ChunksRouter(BaseRouterV3):
         @self.base_endpoint
         async def delete_chunk(
             id: UUID = Path(...),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Delete a specific chunk by ID.
@@ -323,7 +317,7 @@ class ChunksRouter(BaseRouterV3):
             own unless they are superusers.
             """
             # Get the existing chunk to get its chunk_id
-            existing_chunk = await self.services["ingestion"].get_chunk(id)
+            existing_chunk = await self.services.ingestion.get_chunk(id)
 
             if existing_chunk is None:
                 raise R2RException(
@@ -336,11 +330,14 @@ class ChunksRouter(BaseRouterV3):
                     {"chunk_id": {"$eq": str(id)}},
                     {"chunk_id": {"$eq": str(id)}},
                 ]
                 ]
             }
             }
-            await self.services["management"].delete(filters=filters)
+            await self.services.management.delete_documents_and_chunks_by_filter(
+                filters=filters
+            )
             return GenericBooleanResponse(success=True)  # type: ignore
             return GenericBooleanResponse(success=True)  # type: ignore
 
 
         @self.router.get(
         @self.router.get(
             "/chunks",
             "/chunks",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="List Chunks",
             summary="List Chunks",
             openapi_extra={
             openapi_extra={
                 "x-codeSamples": [
                 "x-codeSamples": [
@@ -403,7 +400,7 @@ class ChunksRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedChunksResponse:
             """
             List chunks with pagination support.
@@ -426,7 +423,7 @@ class ChunksRouter(BaseRouterV3):
                 metadata_filter = json.loads(metadata_filter)
 
             # Get chunks using the vector handler's list_chunks method
-            results = await self.services["ingestion"].list_chunks(
+            results = await self.services.ingestion.list_chunks(
                 filters=filters,
                 include_vectors=include_vectors,
                 offset=offset,
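For reference, the delete endpoint above builds a Mongo-style filter and hands it to `delete_documents_and_chunks_by_filter`. A sketch of the payload shape; only the `chunk_id` clause is visible in this hunk, so the companion clause below is purely a hypothetical example:

```python
from uuid import uuid4

chunk_id = uuid4()
filters = {
    "$and": [
        # hypothetical first clause; the hunk above truncates it
        {"owner_id": {"$eq": "some-user-uuid"}},
        {"chunk_id": {"$eq": str(chunk_id)}},
    ]
}
```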

+ 72 - 82
core/main/api/v3/collections_router.py

@@ -6,7 +6,7 @@ from uuid import UUID
 
 from fastapi import Body, Depends, Path, Query
 
-from core.base import KGCreationSettings, KGRunType, R2RException, RunType
+from core.base import KGCreationSettings, KGRunType, R2RException
 from core.base.api.models import (
     GenericBooleanResponse,
     WrappedBooleanResponse,
@@ -16,12 +16,9 @@ from core.base.api.models import (
     WrappedGenericMessageResponse,
     WrappedUsersResponse,
 )
-from core.providers import (
-    HatchetOrchestrationProvider,
-    SimpleOrchestrationProvider,
-)
 from core.utils import update_settings_from_dict
 
+from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3
 
 logger = logging.getLogger()
@@ -59,7 +56,7 @@ async def authorize_collection_action(
 
     # Fetch collection details: owner_id and members
     results = (
-        await services["management"].collections_overview(
+        await services.management.collections_overview(
             0, 1, collection_ids=[collection_id]
         )
     )["results"]
@@ -88,21 +85,14 @@ async def authorize_collection_action(
 
 
 class CollectionsRouter(BaseRouterV3):
-    def __init__(
-        self,
-        providers,
-        services,
-        orchestration_provider: (
-            HatchetOrchestrationProvider | SimpleOrchestrationProvider
-        ),
-        run_type: RunType = RunType.MANAGEMENT,
-    ):
-        super().__init__(providers, services, orchestration_provider, run_type)
+    def __init__(self, providers: R2RProviders, services: R2RServices):
+        super().__init__(providers, services)
 
     def _setup_routes(self):
         @self.router.post(
             "/collections",
             summary="Create a new collection",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -168,7 +158,7 @@ class CollectionsRouter(BaseRouterV3):
             description: Optional[str] = Body(
                 None, description="An optional description of the collection"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedCollectionResponse:
             """
             Create a new collection and automatically add the creating user to it.
@@ -176,13 +166,13 @@ class CollectionsRouter(BaseRouterV3):
             This endpoint allows authenticated users to create a new collection with a specified name
             and optional description. The user creating the collection is automatically added as a member.
             """
-            collection = await self.services["management"].create_collection(
+            collection = await self.services.management.create_collection(
                 owner_id=auth_user.id,
                 name=name,
                 description=description,
             )
             # Add the creating user to the collection
-            await self.services["management"].add_user_to_collection(
+            await self.services.management.add_user_to_collection(
                 auth_user.id, collection.id
             )
             return collection
@@ -190,6 +180,7 @@ class CollectionsRouter(BaseRouterV3):
         @self.router.get(
             "/collections",
             summary="List collections",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -261,7 +252,7 @@ class CollectionsRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedCollectionsResponse:
             """
             Returns a paginated list of collections the authenticated user has access to.
@@ -277,13 +268,13 @@ class CollectionsRouter(BaseRouterV3):
 
             collection_uuids = [UUID(collection_id) for collection_id in ids]
 
-            collections_overview_response = await self.services[
-                "management"
-            ].collections_overview(
-                user_ids=requesting_user_id,
-                collection_ids=collection_uuids,
-                offset=offset,
-                limit=limit,
+            collections_overview_response = (
+                await self.services.management.collections_overview(
+                    user_ids=requesting_user_id,
+                    collection_ids=collection_uuids,
+                    offset=offset,
+                    limit=limit,
+                )
             )
 
             return (  # type: ignore
@@ -298,6 +289,7 @@ class CollectionsRouter(BaseRouterV3):
         @self.router.get(
             "/collections/{id}",
             summary="Get collection details",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -354,7 +346,7 @@ class CollectionsRouter(BaseRouterV3):
             id: UUID = Path(
                 ..., description="The unique identifier of the collection"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedCollectionResponse:
             """
             Get details of a specific collection.
@@ -366,13 +358,13 @@ class CollectionsRouter(BaseRouterV3):
                 auth_user, id, CollectionAction.VIEW, self.services
             )
 
-            collections_overview_response = await self.services[
-                "management"
-            ].collections_overview(
-                user_ids=None,
-                collection_ids=[id],
-                offset=0,
-                limit=1,
+            collections_overview_response = (
+                await self.services.management.collections_overview(
+                    user_ids=None,
+                    collection_ids=[id],
+                    offset=0,
+                    limit=1,
+                )
             )
             overview = collections_overview_response["results"]
 
@@ -386,6 +378,7 @@ class CollectionsRouter(BaseRouterV3):
         @self.router.post(
             "/collections/{id}",
             summary="Update collection",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -455,7 +448,7 @@ class CollectionsRouter(BaseRouterV3):
                 False,
                 description="Whether to generate a new synthetic description for the collection",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedCollectionResponse:
             """
             Update an existing collection's configuration.
@@ -473,7 +466,7 @@ class CollectionsRouter(BaseRouterV3):
                     400,
                 )
 
-            return await self.services["management"].update_collection(  # type: ignore
+            return await self.services.management.update_collection(  # type: ignore
                 id,
                 name=name,
                 description=description,
@@ -483,6 +476,7 @@ class CollectionsRouter(BaseRouterV3):
         @self.router.delete(
             "/collections/{id}",
             summary="Delete collection",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -540,7 +534,7 @@ class CollectionsRouter(BaseRouterV3):
                 ...,
                 description="The unique identifier of the collection to delete",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Delete an existing collection.
@@ -553,12 +547,13 @@ class CollectionsRouter(BaseRouterV3):
                 auth_user, id, CollectionAction.DELETE, self.services
             )
 
-            await self.services["management"].delete_collection(id)
+            await self.services.management.delete_collection(id)
             return GenericBooleanResponse(success=True)  # type: ignore
 
         @self.router.post(
             "/collections/{id}/documents/{document_id}",
             summary="Add document to collection",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -612,7 +607,7 @@ class CollectionsRouter(BaseRouterV3):
         async def add_document_to_collection(
             id: UUID = Path(...),
             document_id: UUID = Path(...),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedGenericMessageResponse:
             """
             Add a document to a collection.
@@ -621,13 +616,16 @@ class CollectionsRouter(BaseRouterV3):
                 auth_user, id, CollectionAction.ADD_DOCUMENT, self.services
             )
 
-            return await self.services[
-                "management"
-            ].assign_document_to_collection(document_id, id)
+            return (
+                await self.services.management.assign_document_to_collection(
+                    document_id, id
+                )
+            )
 
         @self.router.get(
             "/collections/{id}/documents",
             summary="List documents in collection",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -699,7 +697,7 @@ class CollectionsRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedDocumentsResponse:
             """
             Get all documents in a collection with pagination and sorting options.
@@ -711,9 +709,11 @@ class CollectionsRouter(BaseRouterV3):
                 auth_user, id, CollectionAction.VIEW, self.services
             )
 
-            documents_in_collection_response = await self.services[
-                "management"
-            ].documents_in_collection(id, offset, limit)
+            documents_in_collection_response = (
+                await self.services.management.documents_in_collection(
+                    id, offset, limit
+                )
+            )
 
             return documents_in_collection_response["results"], {  # type: ignore
                 "total_entries": documents_in_collection_response[
@@ -724,6 +724,7 @@ class CollectionsRouter(BaseRouterV3):
         @self.router.delete(
             "/collections/{id}/documents/{document_id}",
             summary="Remove document from collection",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -782,7 +783,7 @@ class CollectionsRouter(BaseRouterV3):
                 ...,
                 description="The unique identifier of the document to remove",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Remove a document from a collection.
@@ -793,7 +794,7 @@ class CollectionsRouter(BaseRouterV3):
             await authorize_collection_action(
                 auth_user, id, CollectionAction.REMOVE_DOCUMENT, self.services
             )
-            await self.services["management"].remove_document_from_collection(
+            await self.services.management.remove_document_from_collection(
                 document_id, id
             )
             return GenericBooleanResponse(success=True)  # type: ignore
@@ -801,6 +802,7 @@ class CollectionsRouter(BaseRouterV3):
         @self.router.get(
             "/collections/{id}/users",
             summary="List users in collection",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -874,7 +876,7 @@ class CollectionsRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedUsersResponse:
             """
             Get all users in a collection with pagination and sorting options.
@@ -886,12 +888,12 @@ class CollectionsRouter(BaseRouterV3):
                 auth_user, id, CollectionAction.VIEW, self.services
                 auth_user, id, CollectionAction.VIEW, self.services
             )
             )
 
 
-            users_in_collection_response = await self.services[
-                "management"
-            ].get_users_in_collection(
-                collection_id=id,
-                offset=offset,
-                limit=min(max(limit, 1), 1000),
+            users_in_collection_response = (
+                await self.services.management.get_users_in_collection(
+                    collection_id=id,
+                    offset=offset,
+                    limit=min(max(limit, 1), 1000),
+                )
             )
             )
 
 
             return users_in_collection_response["results"], {  # type: ignore
             return users_in_collection_response["results"], {  # type: ignore
@@ -901,6 +903,7 @@ class CollectionsRouter(BaseRouterV3):
         @self.router.post(
             "/collections/{id}/users/{user_id}",
             summary="Add user to collection",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -958,7 +961,7 @@ class CollectionsRouter(BaseRouterV3):
             user_id: UUID = Path(
                 ..., description="The unique identifier of the user to add"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Add a user to a collection.
@@ -970,7 +973,7 @@ class CollectionsRouter(BaseRouterV3):
                 auth_user, id, CollectionAction.MANAGE_USERS, self.services
             )
 
-            result = await self.services["management"].add_user_to_collection(
+            result = await self.services.management.add_user_to_collection(
                 user_id, id
             )
             return GenericBooleanResponse(success=result)  # type: ignore
@@ -978,6 +981,7 @@ class CollectionsRouter(BaseRouterV3):
         @self.router.delete(
             "/collections/{id}/users/{user_id}",
             summary="Remove user from collection",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -1035,7 +1039,7 @@ class CollectionsRouter(BaseRouterV3):
             user_id: UUID = Path(
                 ..., description="The unique identifier of the user to remove"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Remove a user from a collection.
@@ -1047,15 +1051,17 @@ class CollectionsRouter(BaseRouterV3):
                 auth_user, id, CollectionAction.MANAGE_USERS, self.services
             )
 
-            result = await self.services[
-                "management"
-            ].remove_user_from_collection(user_id, id)
-            print("result = ", result)
+            result = (
+                await self.services.management.remove_user_from_collection(
+                    user_id, id
+                )
+            )
             return GenericBooleanResponse(success=True)  # type: ignore
 
         @self.router.post(
             "/collections/{id}/extract",
             summary="Extract entities and relationships",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -1094,7 +1100,7 @@ class CollectionsRouter(BaseRouterV3):
                 default=True,
                 description="Whether to run the entities and relationships extraction process with orchestration.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ):
             """
             Extracts entities and relationships from a document.
@@ -1124,22 +1130,6 @@ class CollectionsRouter(BaseRouterV3):
                     server_settings=server_graph_creation_settings,
                     settings_dict=settings,  # type: ignore
                 )
-
-            # If the run type is estimate, return an estimate of the creation cost
-            # if run_type is KGRunType.ESTIMATE:
-            #     return {  # type: ignore
-            #         "message": "Estimate retrieved successfully",
-            #         "task_id": None,
-            #         "id": id,
-            #         "estimate": await self.services[
-            #             "kg"
-            #         ].get_creation_estimate(
-            #             document_id=id,
-            #             graph_creation_settings=server_graph_creation_settings,
-            #         ),
-            #     }
-            # else:
-            # Otherwise, create the graph
             if run_with_orchestration:
                 workflow_input = {
                     "collection_id": str(id),
@@ -1147,14 +1137,14 @@ class CollectionsRouter(BaseRouterV3):
                     "user": auth_user.json(),
                 }
 
-                return await self.orchestration_provider.run_workflow(  # type: ignore
+                return await self.providers.orchestration.run_workflow(  # type: ignore
                     "extract-triples", {"request": workflow_input}, {}
                 )
             else:
                 from core.main.orchestration import simple_kg_factory
 
                 logger.info("Running extract-triples without orchestration.")
-                simple_kg = simple_kg_factory(self.services["kg"])
+                simple_kg = simple_kg_factory(self.services.graph)
                 await simple_kg["extract-triples"](workflow_input)  # type: ignore
                 return {  # type: ignore
                     "message": "Graph created successfully.",

+ 148 - 40
core/main/api/v3/conversations_router.py
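A note on two refactors that repeat across all three routers in this section: service handles move from dict lookups (`self.services["management"]`) to attribute access (`self.services.management`), and the auth dependency becomes a call, `auth_wrapper()`, rather than the bare method. A minimal sketch of the factory pattern behind the second change; the provider class here is illustrative only, not the actual R2R implementation:

    from fastapi import Depends, FastAPI, Header, HTTPException

    app = FastAPI()

    class IllustrativeAuthProvider:
        def auth_wrapper(self, public: bool = False):
            # Factory: builds and returns the actual FastAPI dependency.
            async def _authenticate(authorization: str | None = Header(None)):
                if public:
                    return None
                if not authorization:
                    raise HTTPException(status_code=401, detail="Missing credentials")
                return {"token": authorization}
            return _authenticate

    auth = IllustrativeAuthProvider()

    @app.get("/demo")
    async def demo(auth_user=Depends(auth.auth_wrapper())):
        # auth_wrapper() runs once at route-definition time and returns the
        # dependency; injecting the bound method itself (the old style) leaves
        # no room for per-route options such as public=True.
        return {"user": auth_user}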

@@ -5,7 +5,7 @@ from uuid import UUID
 
 from fastapi import Body, Depends, Path, Query
 
-from core.base import Message, R2RException, RunType
+from core.base import Message, R2RException
 from core.base.api.models import (
     GenericBooleanResponse,
     WrappedBooleanResponse,
@@ -14,11 +14,8 @@ from core.base.api.models import (
     WrappedConversationsResponse,
     WrappedMessageResponse,
 )
-from core.providers import (
-    HatchetOrchestrationProvider,
-    SimpleOrchestrationProvider,
-)
 
+from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3
 
 logger = logging.getLogger()
@@ -27,19 +24,16 @@ logger = logging.getLogger()
 class ConversationsRouter(BaseRouterV3):
     def __init__(
         self,
-        providers,
-        services,
-        orchestration_provider: (
-            HatchetOrchestrationProvider | SimpleOrchestrationProvider
-        ),
-        run_type: RunType = RunType.MANAGEMENT,
+        providers: R2RProviders,
+        services: R2RServices,
     ):
-        super().__init__(providers, services, orchestration_provider, run_type)
+        super().__init__(providers, services)
 
     def _setup_routes(self):
         @self.router.post(
             "/conversations",
             summary="Create a new conversation",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -93,18 +87,27 @@ class ConversationsRouter(BaseRouterV3):
         )
         @self.base_endpoint
         async def create_conversation(
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            name: Optional[str] = Body(
+                None, description="The name of the conversation", embed=True
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedConversationResponse:
             """
             Create a new conversation.
 
             This endpoint initializes a new conversation for the authenticated user.
             """
-            return await self.services["management"].create_conversation()
+            user_id = auth_user.id
+
+            return await self.services.management.create_conversation(
+                user_id=user_id,
+                name=name,
+            )
 
         @self.router.get(
             "/conversations",
             summary="List conversations",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -176,23 +179,28 @@ class ConversationsRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedConversationsResponse:
             """
             List conversations with pagination and sorting options.
 
             This endpoint returns a paginated list of conversations for the authenticated user.
             """
+            requesting_user_id = (
+                None if auth_user.is_superuser else [auth_user.id]
+            )
+
             conversation_uuids = [
                 UUID(conversation_id) for conversation_id in ids
             ]
 
-            conversations_response = await self.services[
-                "management"
-            ].conversations_overview(
-                conversation_ids=conversation_uuids,
-                offset=offset,
-                limit=limit,
+            conversations_response = (
+                await self.services.management.conversations_overview(
+                    offset=offset,
+                    limit=limit,
+                    conversation_ids=conversation_uuids,
+                    user_ids=requesting_user_id,
+                )
             )
             return conversations_response["results"], {  # type: ignore
                 "total_entries": conversations_response["total_entries"]
@@ -201,6 +209,7 @@ class ConversationsRouter(BaseRouterV3):
         @self.router.get(
             "/conversations/{id}",
             summary="Get conversation details",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -261,21 +270,110 @@ class ConversationsRouter(BaseRouterV3):
             id: UUID = Path(
                 ..., description="The unique identifier of the conversation"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedConversationMessagesResponse:
             """
             Get details of a specific conversation.
 
             This endpoint retrieves detailed information about a single conversation identified by its UUID.
             """
-            conversation = await self.services["management"].get_conversation(
-                str(id)
+            requesting_user_id = (
+                None if auth_user.is_superuser else [auth_user.id]
+            )
+
+            conversation = await self.services.management.get_conversation(
+                conversation_id=id,
+                user_ids=requesting_user_id,
             )
             return conversation
 
+        @self.router.post(
+            "/conversations/{id}",
+            summary="Update conversation",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # when using auth, do client.login(...)
+
+                            result = client.conversations.update("123e4567-e89b-12d3-a456-426614174000", "new_name")
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "JavaScript",
+                        "source": textwrap.dedent(
+                            """
+                            const { r2rClient } = require("r2r-js");
+
+                            const client = new r2rClient("http://localhost:7272");
+
+                            async function main() {
+                                const response = await client.conversations.update({
+                                    id: "123e4567-e89b-12d3-a456-426614174000",
+                                    name: "new_name",
+                                });
+                            }
+
+                            main();
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "CLI",
+                        "source": textwrap.dedent(
+                            """
+                            r2r conversations update 123e4567-e89b-12d3-a456-426614174000 --name "new_name"
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \
+                                -H "Authorization: Bearer YOUR_API_KEY" \
+                                -H "Content-Type: application/json" \
+                                -d '{"name": "new_name"}'
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def update_conversation(
+            id: UUID = Path(
+                ...,
+                description="The unique identifier of the conversation to update",
+            ),
+            name: str = Body(
+                ...,
+                description="The updated name for the conversation",
+                embed=True,
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> WrappedConversationResponse:
+            """
+            Update an existing conversation.
+
+            This endpoint updates the name of an existing conversation identified by its UUID.
+            """
+            return await self.services.management.update_conversation(
+                conversation_id=id,
+                name=name,
+            )
+
         @self.router.delete(
             "/conversations/{id}",
             summary="Delete conversation",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -335,19 +433,27 @@ class ConversationsRouter(BaseRouterV3):
                 ...,
                 description="The unique identifier of the conversation to delete",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Delete an existing conversation.
 
             This endpoint deletes a conversation identified by its UUID.
             """
-            await self.services["management"].delete_conversation(str(id))
+            requesting_user_id = (
+                None if auth_user.is_superuser else [auth_user.id]
+            )
+
+            await self.services.management.delete_conversation(
+                conversation_id=id,
+                user_ids=requesting_user_id,
+            )
             return GenericBooleanResponse(success=True)  # type: ignore
 
         @self.router.post(
             "/conversations/{id}/messages",
             summary="Add message to conversation",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -415,13 +521,13 @@ class ConversationsRouter(BaseRouterV3):
             role: str = Body(
                 ..., description="The role of the message to add"
             ),
-            parent_id: Optional[str] = Body(
+            parent_id: Optional[UUID] = Body(
                 None, description="The ID of the parent message, if any"
             ),
             metadata: Optional[dict[str, str]] = Body(
                 None, description="Additional metadata for the message"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedMessageResponse:
             """
             Add a new message to a conversation.
@@ -433,16 +539,17 @@ class ConversationsRouter(BaseRouterV3):
             if role not in ["user", "assistant", "system"]:
                 raise R2RException("Invalid role", status_code=400)
             message = Message(role=role, content=content)
-            return await self.services["management"].add_message(
-                str(id),
-                message,
-                parent_id,
-                metadata,
+            return await self.services.management.add_message(
+                conversation_id=id,
+                content=message,
+                parent_id=parent_id,
+                metadata=metadata,
             )
 
         @self.router.post(
             "/conversations/{id}/messages/{message_id}",
             summary="Update message in conversation",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -501,23 +608,24 @@ class ConversationsRouter(BaseRouterV3):
             id: UUID = Path(
                 ..., description="The unique identifier of the conversation"
             ),
-            message_id: str = Path(
+            message_id: UUID = Path(
                 ..., description="The ID of the message to update"
             ),
-            content: str = Body(
-                ..., description="The new content for the message"
+            content: Optional[str] = Body(
+                None, description="The new content for the message"
             ),
             metadata: Optional[dict[str, str]] = Body(
                 None, description="Additional metadata for the message"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedMessageResponse:
             """
             Update an existing message in a conversation.
 
             This endpoint updates the content of an existing message in a conversation.
             """
-            messge_response = await self.services["management"].edit_message(
-                message_id, content, metadata
+            return await self.services.management.edit_message(
+                message_id=message_id,
+                new_content=content,
+                additional_metadata=metadata,
             )
-            return messge_response
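The conversations endpoints above repeatedly scope queries with `None if auth_user.is_superuser else [auth_user.id]`: superusers pass `None` (no filter), while everyone else is restricted to rows they own. A condensed sketch of the idea, using a hypothetical helper name that is not part of the R2R codebase:

    from typing import Optional
    from uuid import UUID

    def scope_to_requesting_user(
        user_id: UUID, is_superuser: bool
    ) -> Optional[list[UUID]]:
        # None means "no user filter" downstream; a one-element list
        # restricts results to rows owned by the requesting user.
        return None if is_superuser else [user_id]

    # e.g. user_ids=scope_to_requesting_user(auth_user.id, auth_user.is_superuser)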

+ 215 - 140
core/main/api/v3/documents_router.py
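Two themes dominate the hunks below: per-user quota enforcement on document creation, and serializing workflow payloads with `model_dump(mode="json")`. The latter matters because Pydantic v2's default `model_dump()` keeps Python objects such as `UUID`, which `json.dumps` cannot encode; `mode="json"` coerces them to JSON-safe primitives. A small illustration (the `Chunk` model here is invented for the example):

    from uuid import UUID, uuid4
    from pydantic import BaseModel

    class Chunk(BaseModel):
        id: UUID
        text: str

    chunk = Chunk(id=uuid4(), text="hello")
    print(chunk.model_dump())             # {'id': UUID('...'), 'text': 'hello'}
    print(chunk.model_dump(mode="json"))  # {'id': '...', 'text': 'hello'}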

@@ -15,7 +15,6 @@ from core.base import (
     IngestionConfig,
     IngestionMode,
     R2RException,
-    RunType,
     SearchMode,
     SearchSettings,
     UnprocessedChunk,
@@ -37,12 +36,9 @@ from core.base.api.models import (
     WrappedIngestionResponse,
     WrappedRelationshipsResponse,
 )
-from core.providers import (
-    HatchetOrchestrationProvider,
-    SimpleOrchestrationProvider,
-)
 from core.utils import update_settings_from_dict
 
+from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3
 
 logger = logging.getLogger()
@@ -80,14 +76,10 @@ def merge_ingestion_config(
 class DocumentsRouter(BaseRouterV3):
     def __init__(
         self,
-        providers,
-        services,
-        orchestration_provider: (
-            HatchetOrchestrationProvider | SimpleOrchestrationProvider
-        ),
-        run_type: RunType = RunType.INGESTION,
+        providers: R2RProviders,
+        services: R2RServices,
     ):
-        super().__init__(providers, services, orchestration_provider, run_type)
+        super().__init__(providers, services)
         self._register_workflows()
 
     def _prepare_search_settings(
@@ -122,48 +114,48 @@ class DocumentsRouter(BaseRouterV3):
 
     # TODO - Remove this legacy method
     def _register_workflows(self):
-        self.orchestration_provider.register_workflows(
+        self.providers.orchestration.register_workflows(
             Workflow.INGESTION,
-            self.services["ingestion"],
+            self.services.ingestion,
             {
                 "ingest-files": (
                     "Ingest files task queued successfully."
-                    if self.orchestration_provider.config.provider != "simple"
+                    if self.providers.orchestration.config.provider != "simple"
                     else "Document created and ingested successfully."
                 ),
                 "ingest-chunks": (
                     "Ingest chunks task queued successfully."
-                    if self.orchestration_provider.config.provider != "simple"
+                    if self.providers.orchestration.config.provider != "simple"
                     else "Document created and ingested successfully."
                 ),
                 "update-files": (
                     "Update file task queued successfully."
-                    if self.orchestration_provider.config.provider != "simple"
+                    if self.providers.orchestration.config.provider != "simple"
                     else "Update task queued successfully."
                 ),
                 "update-chunk": (
                     "Update chunk task queued successfully."
-                    if self.orchestration_provider.config.provider != "simple"
+                    if self.providers.orchestration.config.provider != "simple"
                     else "Chunk update completed successfully."
                 ),
                 "update-document-metadata": (
                     "Update document metadata task queued successfully."
-                    if self.orchestration_provider.config.provider != "simple"
+                    if self.providers.orchestration.config.provider != "simple"
                     else "Document metadata update completed successfully."
                 ),
                 "create-vector-index": (
                     "Vector index creation task queued successfully."
-                    if self.orchestration_provider.config.provider != "simple"
+                    if self.providers.orchestration.config.provider != "simple"
                     else "Vector index creation task completed successfully."
                 ),
                 "delete-vector-index": (
                     "Vector index deletion task queued successfully."
-                    if self.orchestration_provider.config.provider != "simple"
+                    if self.providers.orchestration.config.provider != "simple"
                     else "Vector index deletion task completed successfully."
                 ),
                 "select-vector-index": (
                     "Vector index selection task queued successfully."
-                    if self.orchestration_provider.config.provider != "simple"
+                    if self.providers.orchestration.config.provider != "simple"
                     else "Vector index selection task completed successfully."
                 ),
             },
@@ -195,6 +187,7 @@ class DocumentsRouter(BaseRouterV3):
     def _setup_routes(self):
         @self.router.post(
             "/documents",
+            dependencies=[Depends(self.rate_limit_dependency)],
             status_code=202,
             summary="Create a new document",
             openapi_extra={
@@ -304,7 +297,7 @@ class DocumentsRouter(BaseRouterV3):
                 True,
                 description="Whether or not ingestion runs with orchestration, default is `True`. When set to `False`, the ingestion process will run synchronously and return the result directly.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedIngestionResponse:
             """
             Creates a new Document object from an input file, text content, or chunks. The chosen `ingestion_mode` determines
@@ -320,6 +313,63 @@ class DocumentsRouter(BaseRouterV3):
             The ingestion process runs asynchronously and its progress can be tracked using the returned
             task_id.
             """
+            if not auth_user.is_superuser:
+                user_document_count = (
+                    await self.services.management.documents_overview(
+                        user_ids=[auth_user.id],
+                        offset=0,
+                        limit=1,
+                    )
+                )["total_entries"]
+                user_max_documents = (
+                    await self.services.management.get_user_max_documents(
+                        auth_user.id
+                    )
+                )
+
+                if user_document_count >= user_max_documents:
+                    raise R2RException(
+                        status_code=403,
+                        message=f"User has reached the maximum number of documents allowed ({user_max_documents}).",
+                    )
+
+                # Get chunks using the vector handler's list_chunks method
+                user_chunk_count = (
+                    await self.services.ingestion.list_chunks(
+                        filters={"owner_id": {"$eq": str(auth_user.id)}},
+                        offset=0,
+                        limit=1,
+                    )
+                )["page_info"]["total_entries"]
+                user_max_chunks = (
+                    await self.services.management.get_user_max_chunks(
+                        auth_user.id
+                    )
+                )
+                if user_chunk_count >= user_max_chunks:
+                    raise R2RException(
+                        status_code=403,
+                        message=f"User has reached the maximum number of chunks allowed ({user_max_chunks}).",
+                    )
+
+                user_collections_count = (
+                    await self.services.management.collections_overview(
+                        user_ids=[auth_user.id],
+                        offset=0,
+                        limit=1,
+                    )
+                )["total_entries"]
+                user_max_collections = (
+                    await self.services.management.get_user_max_collections(
+                        auth_user.id
+                    )
+                )
+                if user_collections_count >= user_max_collections:
+                    raise R2RException(
+                        status_code=403,
+                        message=f"User has reached the maximum number of collections allowed ({user_max_collections}).",
+                    )
+
             effective_ingestion_config = self._prepare_ingestion_config(
                 ingestion_mode=ingestion_mode,
                 ingestion_config=ingestion_config,
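The quota checks added in the hunk above all follow one shape: read the current count from a limit-1 overview call, read the per-user maximum, and reject with a 403 once the count reaches the maximum. A condensed sketch of that shape; the helper and its callable parameters are assumptions for illustration, not R2R API:

    from core.base import R2RException  # imported by this router already

    async def enforce_user_limit(count_fn, max_fn, user_id, resource: str):
        # count_fn mirrors the overview calls above: it accepts
        # user_ids/offset/limit and returns a dict with "total_entries".
        current = (await count_fn(user_ids=[user_id], offset=0, limit=1))[
            "total_entries"
        ]
        maximum = await max_fn(user_id)
        if current >= maximum:
            raise R2RException(
                status_code=403,
                message=f"User has reached the maximum number of {resource} allowed ({maximum}).",
            )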
@@ -351,13 +401,15 @@ class DocumentsRouter(BaseRouterV3):
                         400,
                     )
                 document_id = generate_document_id(
-                    str(json.dumps(chunks)), auth_user.id
+                    json.dumps(chunks), auth_user.id
                 )
 
                 # FIXME: Metadata doesn't seem to be getting passed through
                 raw_chunks_for_doc = [
                     UnprocessedChunk(
-                        text=chunk, metadata=metadata, id=generate_id()
+                        text=chunk,
+                        metadata=metadata,
+                        id=generate_id(),
                     )
                     for chunk in chunks
                 ]
@@ -366,11 +418,14 @@ class DocumentsRouter(BaseRouterV3):
                 workflow_input = {
                     "document_id": str(document_id),
                     "chunks": [
-                        chunk.model_dump() for chunk in raw_chunks_for_doc
+                        chunk.model_dump(mode="json")
+                        for chunk in raw_chunks_for_doc
                     ],
                     "metadata": metadata,  # Base metadata for the document
                     "user": auth_user.model_dump_json(),
-                    "ingestion_config": effective_ingestion_config.model_dump(),
+                    "ingestion_config": effective_ingestion_config.model_dump(
+                        mode="json"
+                    ),
                 }
 
                 # TODO - Modify create_chunks so that we can add chunks to existing document
@@ -378,7 +433,7 @@ class DocumentsRouter(BaseRouterV3):
                 if run_with_orchestration:
                     # Run ingestion with orchestration
                     raw_message = (
-                        await self.orchestration_provider.run_workflow(
+                        await self.providers.orchestration.run_workflow(
                             "ingest-chunks",
                             {"request": workflow_input},
                             options={
@@ -400,7 +455,7 @@ class DocumentsRouter(BaseRouterV3):
                     )
 
                     simple_ingestor = simple_ingestion_factory(
-                        self.services["ingestion"]
+                        self.services.ingestion
                     )
                     await simple_ingestor["ingest-chunks"](workflow_input)
 
@@ -447,7 +502,9 @@ class DocumentsRouter(BaseRouterV3):
                     else None
                 ),
                 "metadata": metadata,
-                "ingestion_config": effective_ingestion_config.model_dump(),
+                "ingestion_config": effective_ingestion_config.model_dump(
+                    mode="json"
+                ),
                 "user": auth_user.model_dump_json(),
                 "size_in_bytes": content_length,
             }
@@ -461,7 +518,7 @@ class DocumentsRouter(BaseRouterV3):
             )
 
             if run_with_orchestration:
-                raw_message: dict[str, str | None] = await self.orchestration_provider.run_workflow(  # type: ignore
+                raw_message: dict[str, str | None] = await self.providers.orchestration.run_workflow(  # type: ignore
                     "ingest-files",
                     {"request": workflow_input},
                     options={
@@ -480,7 +537,7 @@ class DocumentsRouter(BaseRouterV3):
                 from core.main.orchestration import simple_ingestion_factory
 
                 simple_ingestor = simple_ingestion_factory(
-                    self.services["ingestion"]
+                    self.services.ingestion
                 )
                 await simple_ingestor["ingest-files"](workflow_input)
                 return {  # type: ignore
@@ -491,6 +548,7 @@ class DocumentsRouter(BaseRouterV3):
 
         @self.router.get(
             "/documents",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="List documents",
             openapi_extra={
                 "x-codeSamples": [
@@ -570,7 +628,7 @@ class DocumentsRouter(BaseRouterV3):
                 False,
                 description="Specifies whether or not to include embeddings of each document summary.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedDocumentsResponse:
             """
             Returns a paginated list of documents the authenticated user has access to.
@@ -588,14 +646,14 @@ class DocumentsRouter(BaseRouterV3):
             )
 
             document_uuids = [UUID(document_id) for document_id in ids]
-            documents_overview_response = await self.services[
-                "management"
-            ].documents_overview(
-                user_ids=requesting_user_id,
-                collection_ids=filter_collection_ids,
-                document_ids=document_uuids,
-                offset=offset,
-                limit=limit,
+            documents_overview_response = (
+                await self.services.management.documents_overview(
+                    user_ids=requesting_user_id,
+                    collection_ids=filter_collection_ids,
+                    document_ids=document_uuids,
+                    offset=offset,
+                    limit=limit,
+                )
             )
             if not include_summary_embeddings:
                 for document in documents_overview_response["results"]:
@@ -612,6 +670,7 @@ class DocumentsRouter(BaseRouterV3):
 
         @self.router.get(
             "/documents/{id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Retrieve a document",
             openapi_extra={
                 "x-codeSamples": [
@@ -674,7 +733,7 @@ class DocumentsRouter(BaseRouterV3):
                 ...,
                 description="The ID of the document to retrieve.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedDocumentResponse:
             """
             Retrieves detailed information about a specific document by its ID.
@@ -692,9 +751,7 @@ class DocumentsRouter(BaseRouterV3):
                 None if auth_user.is_superuser else auth_user.collection_ids
             )
 
-            documents_overview_response = await self.services[
-                "management"
-            ].documents_overview(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
+            documents_overview_response = await self.services.management.documents_overview(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
                 user_ids=request_user_ids,
                 collection_ids=filter_collection_ids,
                 document_ids=[id],
@@ -709,6 +766,7 @@ class DocumentsRouter(BaseRouterV3):
 
         @self.router.get(
             "/documents/{id}/chunks",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="List document chunks",
             openapi_extra={
                 "x-codeSamples": [
@@ -786,7 +844,7 @@ class DocumentsRouter(BaseRouterV3):
                 False,
                 description="Whether to include vector embeddings in the response.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedChunksResponse:
             """
             Retrieves the text chunks that were generated from a document during ingestion.
@@ -799,9 +857,11 @@ class DocumentsRouter(BaseRouterV3):
             Results are returned in chunk sequence order, representing their position in
             the original document.
             """
-            list_document_chunks = await self.services[
-                "management"
-            ].list_document_chunks(id, offset, limit, include_vectors)
+            list_document_chunks = (
+                await self.services.management.list_document_chunks(
+                    id, offset, limit, include_vectors
+                )
+            )
 
             if not list_document_chunks["results"]:
                 raise R2RException(
@@ -811,12 +871,12 @@ class DocumentsRouter(BaseRouterV3):
             is_owner = str(
                 list_document_chunks["results"][0].get("owner_id")
             ) == str(auth_user.id)
-            document_collections = await self.services[
-                "management"
-            ].collections_overview(
-                offset=0,
-                limit=-1,
-                document_ids=[id],
+            document_collections = (
+                await self.services.management.collections_overview(
+                    offset=0,
+                    limit=-1,
+                    document_ids=[id],
+                )
             )
 
             user_has_access = (
@@ -839,6 +899,7 @@ class DocumentsRouter(BaseRouterV3):
 
         @self.router.get(
             "/documents/{id}/download",
+            dependencies=[Depends(self.rate_limit_dependency)],
             response_class=StreamingResponse,
             summary="Download document content",
             openapi_extra={
@@ -891,7 +952,7 @@ class DocumentsRouter(BaseRouterV3):
         @self.base_endpoint
         async def get_document_file(
             id: str = Path(..., description="Document ID"),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> StreamingResponse:
             """
             Downloads the original file content of a document.
@@ -909,14 +970,14 @@ class DocumentsRouter(BaseRouterV3):
                 )
 
             # Retrieve the document's information
-            documents_overview_response = await self.services[
-                "management"
-            ].documents_overview(
-                user_ids=None,
-                collection_ids=None,
-                document_ids=[document_uuid],
-                offset=0,
-                limit=1,
+            documents_overview_response = (
+                await self.services.management.documents_overview(
+                    user_ids=None,
+                    collection_ids=None,
+                    document_ids=[document_uuid],
+                    offset=0,
+                    limit=1,
+                )
             )
 
             if not documents_overview_response["results"]:
@@ -927,21 +988,21 @@ class DocumentsRouter(BaseRouterV3):
             is_owner = str(document.owner_id) == str(auth_user.id)
 
             if not auth_user.is_superuser and not is_owner:
-                document_collections = await self.services[
-                    "management"
-                ].collections_overview(
-                    offset=0,
-                    limit=-1,
-                    document_ids=[document_uuid],
+                document_collections = (
+                    await self.services.management.collections_overview(
+                        offset=0,
+                        limit=-1,
+                        document_ids=[document_uuid],
+                    )
                 )
 
                 document_collection_ids = {
                     str(ele.id) for ele in document_collections["results"]
                 }
 
-                user_collection_ids = set(
+                user_collection_ids = {
                     str(cid) for cid in auth_user.collection_ids
-                )
+                }
 
                 has_collection_access = user_collection_ids.intersection(
                     document_collection_ids
@@ -952,7 +1013,7 @@ class DocumentsRouter(BaseRouterV3):
                         "Not authorized to access this document.", 403
                         "Not authorized to access this document.", 403
                     )
                     )
 
 
-            file_tuple = await self.services["management"].download_file(
+            file_tuple = await self.services.management.download_file(
                 document_uuid
                 document_uuid
             )
             )
             if not file_tuple:
             if not file_tuple:
@@ -983,6 +1044,7 @@ class DocumentsRouter(BaseRouterV3):
 
 
         @self.router.delete(
         @self.router.delete(
             "/documents/by-filter",
             "/documents/by-filter",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Delete documents by filter",
             summary="Delete documents by filter",
             openapi_extra={
             openapi_extra={
                 "x-codeSamples": [
                 "x-codeSamples": [
@@ -1016,7 +1078,7 @@ class DocumentsRouter(BaseRouterV3):
             filters: Json[dict] = Body(
                 ..., description="JSON-encoded filters"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Delete documents based on provided filters. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`. Deletion requests are limited to a user's own documents.
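
Every `Depends(self.providers.auth.auth_wrapper)` in this commit becomes `Depends(self.providers.auth.auth_wrapper())`, which suggests the auth provider now exposes a factory that returns the dependency callable instead of being the dependency itself. A minimal sketch of that pattern, assuming this reading (the provider's new signature is not shown in this diff, and the `public` flag is hypothetical):

from fastapi import HTTPException, Request

def auth_wrapper(public: bool = False):
    # Factory style: per-route options can be bound when the route is
    # registered, and the inner coroutine is what FastAPI resolves.
    async def _authenticate(request: Request):
        token = request.headers.get("Authorization")
        if token is None and not public:
            raise HTTPException(status_code=401, detail="Not authenticated")
        return token  # a real implementation would resolve a user object
    return _authenticate
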
@@ -1025,12 +1087,15 @@ class DocumentsRouter(BaseRouterV3):
             filters_dict = {
                 "$and": [{"owner_id": {"$eq": str(auth_user.id)}}, filters]
             }
-            await self.services["management"].delete(filters=filters_dict)
+            await self.services.management.delete_documents_and_chunks_by_filter(
+                filters=filters_dict
+            )

             return GenericBooleanResponse(success=True)  # type: ignore

         @self.router.delete(
             "/documents/{id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Delete a document",
             openapi_extra={
                 "x-codeSamples": [
@@ -1090,24 +1155,28 @@ class DocumentsRouter(BaseRouterV3):
         @self.base_endpoint
         async def delete_document_by_id(
             id: UUID = Path(..., description="Document ID"),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Delete a specific document. All chunks corresponding to the document are deleted, and all other references to the document are removed.

             NOTE - Deletions do not yet impact the knowledge graph or other derived data. This feature is planned for a future release.
             """
-            filters = {
-                "$and": [
-                    {"owner_id": {"$eq": str(auth_user.id)}},
-                    {"document_id": {"$eq": str(id)}},
-                ]
-            }
-            await self.services["management"].delete(filters=filters)
+
+            filters = {"document_id": {"$eq": str(id)}}
+            if not auth_user.is_superuser:
+                filters = {
+                    "$and": [{"owner_id": {"$eq": str(auth_user.id)}}, filters]
+                }
+
+            await self.services.management.delete_documents_and_chunks_by_filter(
+                filters=filters
+            )
             return GenericBooleanResponse(success=True)  # type: ignore

         @self.router.get(
             "/documents/{id}/collections",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="List document collections",
             openapi_extra={
                 "x-codeSamples": [
@@ -1178,7 +1247,7 @@ class DocumentsRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedCollectionsResponse:
             """
             Retrieves all collections that contain the specified document. This endpoint is restricted
@@ -1198,12 +1267,12 @@ class DocumentsRouter(BaseRouterV3):
                     403,
                 )

-            collections_response = await self.services[
-                "management"
-            ].collections_overview(
-                offset=offset,
-                limit=limit,
-                document_ids=[UUID(id)],  # Convert string ID to UUID
+            collections_response = (
+                await self.services.management.collections_overview(
+                    offset=offset,
+                    limit=limit,
+                    document_ids=[UUID(id)],  # Convert string ID to UUID
+                )
             )

             return collections_response["results"], {  # type: ignore
@@ -1212,6 +1281,7 @@ class DocumentsRouter(BaseRouterV3):

         @self.router.post(
             "/documents/{id}/extract",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Extract entities and relationships",
             openapi_extra={
                 "x-codeSamples": [
@@ -1251,7 +1321,7 @@ class DocumentsRouter(BaseRouterV3):
                 default=True,
                 description="Whether to run the entities and relationships extraction process with orchestration.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedGenericMessageResponse:
             """
             Extracts entities and relationships from a document.
@@ -1294,9 +1364,7 @@ class DocumentsRouter(BaseRouterV3):
                     "message": "Estimate retrieved successfully",
                     "message": "Estimate retrieved successfully",
                     "task_id": None,
                     "task_id": None,
                     "id": id,
                     "id": id,
-                    "estimate": await self.services[
-                        "kg"
-                    ].get_creation_estimate(
+                    "estimate": await self.services.graph.get_creation_estimate(
                         document_id=id,
                         graph_creation_settings=server_graph_creation_settings,
                     ),
@@ -1309,14 +1377,14 @@ class DocumentsRouter(BaseRouterV3):
                     "user": auth_user.json(),
                     "user": auth_user.json(),
                 }
                 }
 
 
-                return await self.orchestration_provider.run_workflow(
+                return await self.providers.orchestration.run_workflow(
                     "extract-triples", {"request": workflow_input}, {}
                     "extract-triples", {"request": workflow_input}, {}
                 )
                 )
             else:
             else:
                 from core.main.orchestration import simple_kg_factory
                 from core.main.orchestration import simple_kg_factory
 
 
                 logger.info("Running extract-triples without orchestration.")
                 logger.info("Running extract-triples without orchestration.")
-                simple_kg = simple_kg_factory(self.services["kg"])
+                simple_kg = simple_kg_factory(self.services.graph)
                 await simple_kg["extract-triples"](workflow_input)
                 await simple_kg["extract-triples"](workflow_input)
                 return {  # type: ignore
                 return {  # type: ignore
                     "message": "Graph created successfully.",
                     "message": "Graph created successfully.",
@@ -1325,6 +1393,7 @@ class DocumentsRouter(BaseRouterV3):

         @self.router.get(
             "/documents/{id}/entities",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Lists the entities from the document",
             openapi_extra={
                 "x-codeSamples": [
@@ -1367,7 +1436,7 @@ class DocumentsRouter(BaseRouterV3):
                 False,
                 description="Whether to include vector embeddings in the response.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedEntitiesResponse:
             """
             Retrieves the entities that were extracted from a document. These represent
@@ -1378,28 +1447,30 @@ class DocumentsRouter(BaseRouterV3):

             Results are returned in the order they were extracted from the document.
             """
-            if (
-                not auth_user.is_superuser
-                and id not in auth_user.collection_ids
-            ):
-                raise R2RException(
-                    "The currently authenticated user does not have access to the specified collection.",
-                    403,
-                )
+            # if (
+            #     not auth_user.is_superuser
+            #     and id not in auth_user.collection_ids
+            # ):
+            #     raise R2RException(
+            #         "The currently authenticated user does not have access to the specified collection.",
+            #         403,
+            #     )

             # First check if the document exists and user has access
-            documents_overview_response = await self.services[
-                "management"
-            ].documents_overview(
-                user_ids=None if auth_user.is_superuser else [auth_user.id],
-                collection_ids=(
-                    None
-                    if auth_user.is_superuser
-                    else auth_user.collection_ids
-                ),
-                document_ids=[id],
-                offset=0,
-                limit=1,
+            documents_overview_response = (
+                await self.services.management.documents_overview(
+                    user_ids=(
+                        None if auth_user.is_superuser else [auth_user.id]
+                    ),
+                    collection_ids=(
+                        None
+                        if auth_user.is_superuser
+                        else auth_user.collection_ids
+                    ),
+                    document_ids=[id],
+                    offset=0,
+                    limit=1,
+                )
             )

             if not documents_overview_response["results"]:
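
With the membership check commented out above, access control for this endpoint now rests entirely on the `documents_overview` lookup: a non-superuser's query is constrained to their own `user_ids` and `collection_ids`, so a document they cannot see simply returns no rows and the not-found branch fires. A condensed sketch of that flow (argument names mirror the diff; the helper itself is illustrative):

async def can_access_document(services, user, doc_id) -> bool:
    # Superusers query unscoped; everyone else is limited to their own
    # ids, so an inaccessible document yields an empty result set.
    results = await services.management.documents_overview(
        user_ids=None if user.is_superuser else [user.id],
        collection_ids=None if user.is_superuser else user.collection_ids,
        document_ids=[doc_id],
        offset=0,
        limit=1,
    )
    return bool(results["results"])
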
@@ -1420,6 +1491,7 @@ class DocumentsRouter(BaseRouterV3):

         @self.router.get(
             "/documents/{id}/relationships",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="List document relationships",
             openapi_extra={
                 "x-codeSamples": [
@@ -1505,7 +1577,7 @@ class DocumentsRouter(BaseRouterV3):
                 None,
                 description="Filter relationships by specific relationship types.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedRelationshipsResponse:
             """
             Retrieves the relationships between entities that were extracted from a document. These represent
@@ -1516,28 +1588,30 @@ class DocumentsRouter(BaseRouterV3):

             Results are returned in the order they were extracted from the document.
             """
-            if (
-                not auth_user.is_superuser
-                and id not in auth_user.collection_ids
-            ):
-                raise R2RException(
-                    "The currently authenticated user does not have access to the specified collection.",
-                    403,
-                )
+            # if (
+            #     not auth_user.is_superuser
+            #     and id not in auth_user.collection_ids
+            # ):
+            #     raise R2RException(
+            #         "The currently authenticated user does not have access to the specified collection.",
+            #         403,
+            #     )

             # First check if the document exists and user has access
-            documents_overview_response = await self.services[
-                "management"
-            ].documents_overview(
-                user_ids=None if auth_user.is_superuser else [auth_user.id],
-                collection_ids=(
-                    None
-                    if auth_user.is_superuser
-                    else auth_user.collection_ids
-                ),
-                document_ids=[id],
-                offset=0,
-                limit=1,
+            documents_overview_response = (
+                await self.services.management.documents_overview(
+                    user_ids=(
+                        None if auth_user.is_superuser else [auth_user.id]
+                    ),
+                    collection_ids=(
+                        None
+                        if auth_user.is_superuser
+                        else auth_user.collection_ids
+                    ),
+                    document_ids=[id],
+                    offset=0,
+                    limit=1,
+                )
             )

             if not documents_overview_response["results"]:
@@ -1559,6 +1633,7 @@ class DocumentsRouter(BaseRouterV3):

         @self.router.post(
             "/documents/search",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Search document summaries",
         )
         @self.base_endpoint
@@ -1583,7 +1658,7 @@ class DocumentsRouter(BaseRouterV3):
                 default_factory=SearchSettings,
                 description="Settings for document search",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ):  # -> WrappedDocumentSearchResponse:  # type: ignore
             """
             Perform a search query on the automatically generated document summaries in the system.
@@ -1601,7 +1676,7 @@ class DocumentsRouter(BaseRouterV3):
             query_embedding = (
                 await self.providers.embedding.async_get_embedding(query)
             )
-            results = await self.services["retrieval"].search_documents(
+            results = await self.services.retrieval.search_documents(
                 query=query,
                 query_embedding=query_embedding,
                 settings=effective_settings,

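`search_documents` now receives a precomputed `query_embedding` from the embedding provider rather than embedding internally, which lets the router embed once and reuse the vector. An illustrative call sequence under that assumption (container names as used elsewhere in this diff; the wrapper function is hypothetical):

async def search_document_summaries(providers, services, effective_settings):
    # Embed once in the router, then hand the vector to the retrieval service.
    query = "What does the document summary say about ingestion?"
    query_embedding = await providers.embedding.async_get_embedding(query)
    return await services.retrieval.search_documents(
        query=query,
        query_embedding=query_embedding,
        settings=effective_settings,  # a SearchSettings instance
    )
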
+ 93 - 73
core/main/api/v3/graph_router.py

@@ -5,7 +5,7 @@ from uuid import UUID

 from fastapi import Body, Depends, Path, Query

-from core.base import KGEnrichmentStatus, R2RException, RunType, Workflow
+from core.base import KGEnrichmentStatus, R2RException, Workflow
 from core.base.abstractions import KGRunType
 from core.base.api.models import (
     GenericBooleanResponse,
@@ -19,15 +19,12 @@ from core.base.api.models import (
     WrappedRelationshipResponse,
     WrappedRelationshipsResponse,
 )
-from core.providers import (
-    HatchetOrchestrationProvider,
-    SimpleOrchestrationProvider,
-)
 from core.utils import (
     generate_default_user_collection_id,
     update_settings_from_dict,
 )

+from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3

 logger = logging.getLogger()
@@ -36,22 +33,18 @@ logger = logging.getLogger()
 class GraphRouter(BaseRouterV3):
     def __init__(
         self,
-        providers,
-        services,
-        orchestration_provider: (
-            HatchetOrchestrationProvider | SimpleOrchestrationProvider
-        ),
-        run_type: RunType = RunType.KG,
+        providers: R2RProviders,
+        services: R2RServices,
     ):
-        super().__init__(providers, services, orchestration_provider, run_type)
+        super().__init__(providers, services)
         self._register_workflows()

     def _register_workflows(self):

         workflow_messages = {}
-        if self.orchestration_provider.config.provider == "hatchet":
+        if self.providers.orchestration.config.provider == "hatchet":
             workflow_messages["extract-triples"] = (
             workflow_messages["extract-triples"] = (
-                "Graph creation task queued successfully."
+                "Document extraction task queued successfully."
             )
             workflow_messages["build-communities"] = (
                 "Graph enrichment task queued successfully."
@@ -61,18 +54,18 @@ class GraphRouter(BaseRouterV3):
             )
         else:
             workflow_messages["extract-triples"] = (
-                "Document entities and relationships extracted successfully. To generate GraphRAG communities, POST to `/graphs/<collection_id>/communities/build` with a collection this document belongs to."
+                "Document entities and relationships extracted successfully."
             )
             workflow_messages["build-communities"] = (
-                "Graph communities created successfully. You can view the communities at http://localhost:7272/v2/communities"
+                "Graph communities created successfully."
             )
             workflow_messages["entity-deduplication"] = (
                 "KG Entity Deduplication completed successfully."
             )

-        self.orchestration_provider.register_workflows(
+        self.providers.orchestration.register_workflows(
             Workflow.KG,
-            self.services["kg"],
+            self.services.graph,
             workflow_messages,
         )

@@ -126,7 +119,7 @@ class GraphRouter(BaseRouterV3):

         # Return cost estimate if requested
         if run_type == KGRunType.ESTIMATE:
-            return await self.services["kg"].get_deduplication_estimate(
+            return await self.services.graph.get_deduplication_estimate(
                 collection_id, server_settings
             )

@@ -137,13 +130,13 @@ class GraphRouter(BaseRouterV3):
         }

         if run_with_orchestration:
-            return await self.orchestration_provider.run_workflow(  # type: ignore
+            return await self.providers.orchestration.run_workflow(  # type: ignore
                 "entity-deduplication", {"request": workflow_input}, {}
                 "entity-deduplication", {"request": workflow_input}, {}
             )
             )
         else:
         else:
             from core.main.orchestration import simple_kg_factory
             from core.main.orchestration import simple_kg_factory
 
 
-            simple_kg = simple_kg_factory(self.services["kg"])
+            simple_kg = simple_kg_factory(self.services.graph)
             await simple_kg["entity-deduplication"](workflow_input)
             await simple_kg["entity-deduplication"](workflow_input)
             return {  # type: ignore
             return {  # type: ignore
                 "message": "Entity deduplication completed successfully.",
                 "message": "Entity deduplication completed successfully.",
@@ -161,6 +154,7 @@ class GraphRouter(BaseRouterV3):
     def _setup_routes(self):
         @self.router.get(
             "/graphs",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="List graphs",
             openapi_extra={
                 "x-codeSamples": [
@@ -213,7 +207,7 @@ class GraphRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedGraphsResponse:
             """
             Returns a paginated list of graphs the authenticated user has access to.
@@ -229,7 +223,7 @@ class GraphRouter(BaseRouterV3):

             graph_uuids = [UUID(graph_id) for graph_id in collection_ids]

-            list_graphs_response = await self.services["kg"].list_graphs(
+            list_graphs_response = await self.services.graph.list_graphs(
                 # user_ids=requesting_user_id,
                 graph_ids=graph_uuids,
                 offset=offset,
@@ -243,6 +237,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.get(
             "/graphs/{collection_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Retrieve graph details",
             openapi_extra={
                 "x-codeSamples": [
@@ -292,7 +287,7 @@ class GraphRouter(BaseRouterV3):
         @self.base_endpoint
         async def get_graph(
             collection_id: UUID = Path(...),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedGraphResponse:
             """
             Retrieves detailed information about a specific graph by ID.
@@ -307,7 +302,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            list_graphs_response = await self.services["kg"].list_graphs(
+            list_graphs_response = await self.services.graph.list_graphs(
                 # user_ids=None,
                 graph_ids=[collection_id],
                 offset=0,
@@ -317,6 +312,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.post(
             "/graphs/{collection_id}/communities/build",
+            dependencies=[Depends(self.rate_limit_dependency)],
         )
         @self.base_endpoint
         async def build_communities(
@@ -332,7 +328,7 @@ class GraphRouter(BaseRouterV3):
                 description="Settings for the graph enrichment process.",
             ),
             run_with_orchestration: Optional[bool] = Body(True),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ):
             """
             Creates communities in the graph by analyzing entity relationships and similarities.
@@ -391,14 +387,14 @@ class GraphRouter(BaseRouterV3):

             if run_with_orchestration:

-                return await self.orchestration_provider.run_workflow(  # type: ignore
+                return await self.providers.orchestration.run_workflow(  # type: ignore
                     "build-communities", {"request": workflow_input}, {}
                     "build-communities", {"request": workflow_input}, {}
                 )
                 )
             else:
             else:
                 from core.main.orchestration import simple_kg_factory
                 from core.main.orchestration import simple_kg_factory
 
 
                 logger.info("Running build-communities without orchestration.")
                 logger.info("Running build-communities without orchestration.")
-                simple_kg = simple_kg_factory(self.services["kg"])
+                simple_kg = simple_kg_factory(self.services.graph)
                 await simple_kg["build-communities"](workflow_input)
                 await simple_kg["build-communities"](workflow_input)
                 return {
                 return {
                     "message": "Graph communities created successfully.",
                     "message": "Graph communities created successfully.",
@@ -407,6 +403,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.post(
             "/graphs/{collection_id}/reset",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Reset a graph back to the initial state.",
             openapi_extra={
                 "x-codeSamples": [
@@ -456,7 +453,7 @@ class GraphRouter(BaseRouterV3):
         @self.base_endpoint
         async def reset(
             collection_id: UUID = Path(...),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Deletes a graph and all its associated data.
@@ -479,13 +476,14 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            await self.services["kg"].reset_graph_v3(id=collection_id)
+            await self.services.graph.reset_graph_v3(id=collection_id)
             # await _pull(collection_id, auth_user)
             return GenericBooleanResponse(success=True)  # type: ignore

         # update graph
         @self.router.post(
             "/graphs/{collection_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Update graph",
             openapi_extra={
                 "x-codeSamples": [
@@ -542,7 +540,7 @@ class GraphRouter(BaseRouterV3):
             description: Optional[str] = Body(
                 None, description="An optional description of the graph"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ):
             """
             Update an existing graph's configuration.
@@ -564,7 +562,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            return await self.services["kg"].update_graph(  # type: ignore
+            return await self.services.graph.update_graph(  # type: ignore
                 collection_id,
                 name=name,
                 description=description,
@@ -572,6 +570,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.get(
             "/graphs/{collection_id}/entities",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -625,7 +624,7 @@ class GraphRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedEntitiesResponse:
             """Lists all entities in the graph with pagination support."""
             if (
@@ -638,7 +637,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            entities, count = await self.services["kg"].get_entities(
+            entities, count = await self.services.graph.get_entities(
                 parent_id=collection_id,
                 offset=offset,
                 limit=limit,
@@ -648,7 +647,10 @@ class GraphRouter(BaseRouterV3):
                 "total_entries": count,
                 "total_entries": count,
             }
             }
 
 
-        @self.router.post("/graphs/{collection_id}/entities")
+        @self.router.post(
+            "/graphs/{collection_id}/entities",
+            dependencies=[Depends(self.rate_limit_dependency)],
+        )
         @self.base_endpoint
         async def create_entity(
             collection_id: UUID = Path(
@@ -667,7 +669,7 @@ class GraphRouter(BaseRouterV3):
             metadata: Optional[dict] = Body(
                 None, description="The metadata of the entity to create."
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedEntityResponse:
             """Creates a new entity in the graph."""
             if (
@@ -680,7 +682,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            return await self.services["kg"].create_entity(
+            return await self.services.graph.create_entity(
                 name=name,
                 description=description,
                 parent_id=collection_id,
@@ -688,7 +690,10 @@ class GraphRouter(BaseRouterV3):
                 metadata=metadata,
             )

-        @self.router.post("/graphs/{collection_id}/relationships")
+        @self.router.post(
+            "/graphs/{collection_id}/relationships",
+            dependencies=[Depends(self.rate_limit_dependency)],
+        )
         @self.base_endpoint
         async def create_relationship(
             collection_id: UUID = Path(
@@ -722,7 +727,7 @@ class GraphRouter(BaseRouterV3):
             metadata: Optional[dict] = Body(
                 None, description="The metadata of the relationship to create."
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedRelationshipResponse:
             """Creates a new relationship in the graph."""
             if not auth_user.is_superuser:
@@ -739,7 +744,7 @@ class GraphRouter(BaseRouterV3):
                     "The currently authenticated user does not have access to the collection associated with the given graph.",
                     "The currently authenticated user does not have access to the collection associated with the given graph.",
                     403,
                     403,
                 )
                 )
-            return await self.services["kg"].create_relationship(
+            return await self.services.graph.create_relationship(
                 subject=subject,
                 subject_id=subject_id,
                 predicate=predicate,
@@ -753,6 +758,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.get(
             "/graphs/{collection_id}/entities/{entity_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -802,7 +808,7 @@ class GraphRouter(BaseRouterV3):
             entity_id: UUID = Path(
                 ..., description="The ID of the entity to retrieve."
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedEntityResponse:
             """Retrieves a specific entity by its ID."""
             if (
@@ -826,7 +832,10 @@ class GraphRouter(BaseRouterV3):
                 raise R2RException("Entity not found", 404)
             return result[0][0]

-        @self.router.post("/graphs/{collection_id}/entities/{entity_id}")
+        @self.router.post(
+            "/graphs/{collection_id}/entities/{entity_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
+        )
         @self.base_endpoint
         async def update_entity(
             collection_id: UUID = Path(
@@ -848,7 +857,7 @@ class GraphRouter(BaseRouterV3):
             metadata: Optional[dict] = Body(
                 None, description="The updated metadata of the entity."
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedEntityResponse:
             """Updates an existing entity in the graph."""
             if not auth_user.is_superuser:
@@ -865,7 +874,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            return await self.services["kg"].update_entity(
+            return await self.services.graph.update_entity(
                 entity_id=entity_id,
                 name=name,
                 category=category,
@@ -875,6 +884,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.delete(
             "/graphs/{collection_id}/entities/{entity_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Remove an entity",
             openapi_extra={
                 "x-codeSamples": [
@@ -926,7 +936,7 @@ class GraphRouter(BaseRouterV3):
                 ...,
                 description="The ID of the entity to remove from the graph.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """Removes an entity from the graph."""
             if not auth_user.is_superuser:
@@ -944,7 +954,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            await self.services["kg"].delete_entity(
+            await self.services.graph.delete_entity(
                 parent_id=collection_id,
                 entity_id=entity_id,
             )
@@ -953,6 +963,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.get(
             "/graphs/{collection_id}/relationships",
+            dependencies=[Depends(self.rate_limit_dependency)],
             description="Lists all relationships in the graph with pagination support.",
             openapi_extra={
                 "x-codeSamples": [
@@ -1007,7 +1018,7 @@ class GraphRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedRelationshipsResponse:
             """
             Lists all relationships in the graph with pagination support.
@@ -1022,7 +1033,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            relationships, count = await self.services["kg"].get_relationships(
+            relationships, count = await self.services.graph.get_relationships(
                 parent_id=collection_id,
                 offset=offset,
                 limit=limit,
@@ -1034,6 +1045,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.get(
             "/graphs/{collection_id}/relationships/{relationship_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             description="Retrieves a specific relationship by its ID.",
             openapi_extra={
                 "x-codeSamples": [
@@ -1084,7 +1096,7 @@ class GraphRouter(BaseRouterV3):
             relationship_id: UUID = Path(
                 ..., description="The ID of the relationship to retrieve."
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedRelationshipResponse:
             """Retrieves a specific relationship by its ID."""
             if (
@@ -1111,7 +1123,8 @@ class GraphRouter(BaseRouterV3):
             return results[0][0]

         @self.router.post(
-            "/graphs/{collection_id}/relationships/{relationship_id}"
+            "/graphs/{collection_id}/relationships/{relationship_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
         )
         @self.base_endpoint
         async def update_relationship(
@@ -1147,7 +1160,7 @@ class GraphRouter(BaseRouterV3):
             metadata: Optional[dict] = Body(
                 None, description="The updated metadata of the relationship."
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedRelationshipResponse:
             """Updates an existing relationship in the graph."""
             if not auth_user.is_superuser:
@@ -1165,7 +1178,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            return await self.services["kg"].update_relationship(
+            return await self.services.graph.update_relationship(
                 relationship_id=relationship_id,
                 subject=subject,
                 subject_id=subject_id,
@@ -1179,6 +1192,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.delete(
             "/graphs/{collection_id}/relationships/{relationship_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             description="Removes a relationship from the graph.",
             openapi_extra={
                 "x-codeSamples": [
@@ -1230,7 +1244,7 @@ class GraphRouter(BaseRouterV3):
                 ...,
                 description="The ID of the relationship to remove from the graph.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """Removes a relationship from the graph."""
             if not auth_user.is_superuser:
@@ -1247,7 +1261,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            await self.services["kg"].delete_relationship(
+            await self.services.graph.delete_relationship(
                 parent_id=collection_id,
                 relationship_id=relationship_id,
             )
@@ -1256,6 +1270,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.post(
             "/graphs/{collection_id}/communities",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Create a new community",
             openapi_extra={
                 "x-codeSamples": [
@@ -1322,7 +1337,7 @@ class GraphRouter(BaseRouterV3):
             rating_explanation: Optional[str] = Body(
                 default="", description="Explanation for the rating"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedCommunityResponse:
             """
             Creates a new community in the graph.
@@ -1353,7 +1368,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            return await self.services["kg"].create_community(
+            return await self.services.graph.create_community(
                 parent_id=collection_id,
                 name=name,
                 summary=summary,
@@ -1364,6 +1379,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.get(
             "/graphs/{collection_id}/communities",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="List communities",
             openapi_extra={
                 "x-codeSamples": [
@@ -1418,7 +1434,7 @@ class GraphRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedCommunitiesResponse:
             """
             Lists all communities in the graph with pagination support.
@@ -1433,7 +1449,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            communities, count = await self.services["kg"].get_communities(
+            communities, count = await self.services.graph.get_communities(
                 parent_id=collection_id,
                 offset=offset,
                 limit=limit,
@@ -1445,6 +1461,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.get(
             "/graphs/{collection_id}/communities/{community_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Retrieve a community",
             openapi_extra={
                 "x-codeSamples": [
@@ -1492,7 +1509,7 @@ class GraphRouter(BaseRouterV3):
                 ...,
                 description="The ID of the community to get.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedCommunityResponse:
             """
             Retrieves a specific community by its ID.
@@ -1507,14 +1524,14 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            results = await self.services[
-                "kg"
-            ].providers.database.graphs_handler.communities.get(
-                parent_id=collection_id,
-                community_ids=[community_id],
-                store_type="graphs",
-                offset=0,
-                limit=1,
+            results = (
+                await self.providers.database.graphs_handler.communities.get(
+                    parent_id=collection_id,
+                    community_ids=[community_id],
+                    store_type="graphs",
+                    offset=0,
+                    limit=1,
+                )
             )
             if len(results) == 0 or len(results[0]) == 0:
                 raise R2RException("Community not found", 404)
@@ -1522,6 +1539,7 @@ class GraphRouter(BaseRouterV3):

         @self.router.delete(
             "/graphs/{collection_id}/communities/{community_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Delete a community",
             openapi_extra={
                 "x-codeSamples": [
@@ -1573,7 +1591,7 @@ class GraphRouter(BaseRouterV3):
                 ...,
                 description="The ID of the community to delete.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ):
             if (
                 not auth_user.is_superuser
@@ -1593,7 +1611,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            await self.services["kg"].delete_community(
+            await self.services.graph.delete_community(
                 parent_id=collection_id,
                 community_id=community_id,
             )
@@ -1601,6 +1619,7 @@ class GraphRouter(BaseRouterV3):
 
 
         @self.router.post(
             "/graphs/{collection_id}/communities/{community_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Update community",
             openapi_extra={
                 "x-codeSamples": [
@@ -1661,7 +1680,7 @@ class GraphRouter(BaseRouterV3):
             findings: Optional[list[str]] = Body(None),
             rating: Optional[float] = Body(default=None, ge=1, le=10),
             rating_explanation: Optional[str] = Body(None),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedCommunityResponse:
             """
             Updates an existing community in the graph.
@@ -1684,7 +1703,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            return await self.services["kg"].update_community(
+            return await self.services.graph.update_community(
                 community_id=community_id,
                 name=name,
                 summary=summary,
@@ -1695,6 +1714,7 @@ class GraphRouter(BaseRouterV3):
 
 
         @self.router.post(
             "/graphs/{collection_id}/pull",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Pull latest entities to the graph",
             openapi_extra={
                 "x-codeSamples": [
@@ -1745,7 +1765,7 @@ class GraphRouter(BaseRouterV3):
             # document_ids: list[UUID] = Body(
             #     ..., description="List of document IDs to add to the graph."
             # ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Adds documents to a graph by copying their entities and relationships.
@@ -1781,7 +1801,7 @@ class GraphRouter(BaseRouterV3):
                     403,
                 )

-            list_graphs_response = await self.services["kg"].list_graphs(
+            list_graphs_response = await self.services.graph.list_graphs(
                 # user_ids=None,
                 graph_ids=[collection_id],
                 offset=0,
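
A pattern note on the `auth_wrapper` -> `auth_wrapper()` edits that run through every router in this commit: the provider method is now invoked at route-definition time, so it acts as a factory that returns the actual FastAPI dependency rather than being the dependency itself. A minimal sketch of that shape, assuming a `public` flag and bearer-token handling that are not shown in this diff:

```python
from fastapi import Header, HTTPException


class AuthProvider:
    """Sketch only: the real provider lives in core/base/providers/auth.py;
    the body of auth_wrapper here is an assumption, not the project's code."""

    def auth_wrapper(self, *, public: bool = False):
        # Returning the inner coroutine lets routes write
        # Depends(self.providers.auth.auth_wrapper()) and lets callers
        # parameterize the dependency per route (e.g. public endpoints).
        async def _authenticate(authorization: str | None = Header(None)):
            if authorization is None:
                if public:
                    return None
                raise HTTPException(status_code=401, detail="Missing credentials")
            return await self._decode_token(authorization.removeprefix("Bearer "))

        return _authenticate

    async def _decode_token(self, token: str):
        raise NotImplementedError  # hypothetical placeholder
```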

+ 16 - 19
core/main/api/v3/indices_router.py

@@ -8,18 +8,15 @@ from typing import Optional
 
 
 from fastapi import Body, Depends, Path, Query

-from core.base import IndexConfig, R2RException, RunType
+from core.base import IndexConfig, R2RException
 from core.base.abstractions import VectorTableName
 from core.base.api.models import (
     GenericMessageResponse,
     WrappedGenericMessageResponse,
     WrappedListVectorIndicesResponse,
 )
-from core.providers import (
-    HatchetOrchestrationProvider,
-    SimpleOrchestrationProvider,
-)

+from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3

 logger = logging.getLogger()
@@ -29,20 +26,17 @@ class IndicesRouter(BaseRouterV3):
 
 
     def __init__(
         self,
-        providers,
-        services,
-        orchestration_provider: (
-            HatchetOrchestrationProvider | SimpleOrchestrationProvider
-        ),
-        run_type: RunType = RunType.INGESTION,
+        providers: R2RProviders,
+        services: R2RServices,
     ):
-        super().__init__(providers, services, orchestration_provider, run_type)
+        super().__init__(providers, services)

     def _setup_routes(self):

         ## TODO - Allow developer to pass the index id with the request
         @self.router.post(
             "/indices",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Create Vector Index",
             openapi_extra={
                 "x-codeSamples": [
@@ -178,7 +172,7 @@ class IndicesRouter(BaseRouterV3):
                 True,
                 description="Whether to run index creation as an orchestrated task (recommended for large indices)",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedGenericMessageResponse:
             """
             Create a new vector similarity search index over the target table. Allowed tables include 'vectors', 'entity', 'document_collections'.
@@ -220,7 +214,7 @@ class IndicesRouter(BaseRouterV3):
                 f"Creating vector index for {config.table_name} with method {config.index_method}, measure {config.index_measure}, concurrently {config.concurrently}"
                 f"Creating vector index for {config.table_name} with method {config.index_method}, measure {config.index_measure}, concurrently {config.concurrently}"
             )
             )
 
 
-            result = await self.orchestration_provider.run_workflow(
+            result = await self.providers.orchestration.run_workflow(
                 "create-vector-index",
                 "create-vector-index",
                 {
                 {
                     "request": {
                     "request": {
@@ -242,6 +236,7 @@ class IndicesRouter(BaseRouterV3):
 
 
         @self.router.get(
             "/indices",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="List Vector Indices",
             openapi_extra={
                 "x-codeSamples": [
@@ -320,7 +315,7 @@ class IndicesRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 1000. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedListVectorIndicesResponse:
             """
             List existing vector similarity search indices with pagination support.
@@ -345,6 +340,7 @@ class IndicesRouter(BaseRouterV3):
 
 
         @self.router.get(
             "/indices/{table_name}/{index_name}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Get Vector Index Details",
             openapi_extra={
                 "x-codeSamples": [
@@ -411,7 +407,7 @@ class IndicesRouter(BaseRouterV3):
             index_name: str = Path(
                 ..., description="The name of the index to retrieve"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> dict:  #  -> WrappedGetIndexResponse:
             """
             Get detailed information about a specific vector index.
@@ -498,7 +494,7 @@ class IndicesRouter(BaseRouterV3):
         #             id: UUID = Path(...),
         #             config: IndexConfig = Body(...),
         #             run_with_orchestration: Optional[bool] = Body(True),
-        #             auth_user=Depends(self.providers.auth.auth_wrapper),
+        #             auth_user=Depends(self.providers.auth.auth_wrapper()),
         #         ):  # -> WrappedUpdateIndexResponse:
         #             """
         #             Update an existing index's configuration.
@@ -508,6 +504,7 @@ class IndicesRouter(BaseRouterV3):
 
 
         @self.router.delete(
             "/indices/{table_name}/{index_name}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Delete Vector Index",
             openapi_extra={
                 "x-codeSamples": [
@@ -584,7 +581,7 @@ class IndicesRouter(BaseRouterV3):
             #     description="Whether to delete the index concurrently (recommended for large indices)",
             # ),
             # run_with_orchestration: Optional[bool] = Body(True),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedGenericMessageResponse:
             """
             Delete an existing vector similarity search index.
@@ -604,7 +601,7 @@ class IndicesRouter(BaseRouterV3):
                 f"Deleting vector index {index_name} from table {table_name}"
                 f"Deleting vector index {index_name} from table {table_name}"
             )
             )
 
 
-            return await self.orchestration_provider.run_workflow(
+            return await self.providers.orchestration.run_workflow(
                 "delete-vector-index",
                 "delete-vector-index",
                 {
                 {
                     "request": {
                     "request": {

+ 13 - 16
core/main/api/v3/logs_router.py

@@ -4,30 +4,21 @@ import logging
 from pathlib import Path

 import aiofiles
-from fastapi import WebSocket
+from fastapi import Depends, WebSocket
 from fastapi.requests import Request
 from fastapi.templating import Jinja2Templates

-from core.base.logger.base import RunType
-from core.providers import (
-    HatchetOrchestrationProvider,
-    SimpleOrchestrationProvider,
-)
-
+from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3


 class LogsRouter(BaseRouterV3):
     def __init__(
         self,
-        providers,
-        services,
-        orchestration_provider: (
-            HatchetOrchestrationProvider | SimpleOrchestrationProvider
-        ),
-        run_type: RunType = RunType.UNSPECIFIED,
+        providers: R2RProviders,
+        services: R2RServices,
     ):
-        super().__init__(providers, services, orchestration_provider, run_type)
+        super().__init__(providers, services)
         CURRENT_DIR = Path(__file__).resolve().parent
         TEMPLATES_DIR = CURRENT_DIR.parent / "templates"
         self.templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
@@ -73,7 +64,10 @@ class LogsRouter(BaseRouterV3):
             return f"Error accessing log file: {str(e)}"
             return f"Error accessing log file: {str(e)}"
 
 
     def _setup_routes(self):
     def _setup_routes(self):
-        @self.router.websocket("/logs/stream")
+        @self.router.websocket(
+            "/logs/stream",
+            dependencies=[Depends(self.websocket_rate_limit_dependency)],
+        )
         async def stream_logs(websocket: WebSocket):
         async def stream_logs(websocket: WebSocket):
             await websocket.accept()
             await websocket.accept()
             try:
             try:
@@ -94,7 +88,10 @@ class LogsRouter(BaseRouterV3):
                 with contextlib.suppress(Exception):
                     await websocket.close()

-        @self.router.get("/logs/viewer")
+        @self.router.get(
+            "/logs/viewer",
+            dependencies=[Depends(self.rate_limit_dependency)],
+        )
         async def get_log_viewer(request: Request):
             return self.templates.TemplateResponse(
                 "log_viewer.html", {"request": request}

+ 22 - 24
core/main/api/v3/prompts_router.py

@@ -3,7 +3,7 @@ from typing import Optional
 
 
 from fastapi import Body, Depends, Path, Query

-from core.base import R2RException, RunType
+from core.base import R2RException
 from core.base.api.models import (
     GenericBooleanResponse,
     GenericMessageResponse,
@@ -12,29 +12,23 @@ from core.base.api.models import (
     WrappedPromptResponse,
     WrappedPromptsResponse,
 )
-from core.providers import (
-    HatchetOrchestrationProvider,
-    SimpleOrchestrationProvider,
-)

+from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3


 class PromptsRouter(BaseRouterV3):
     def __init__(
         self,
-        providers,
-        services,
-        orchestration_provider: (
-            HatchetOrchestrationProvider | SimpleOrchestrationProvider
-        ),
-        run_type: RunType = RunType.MANAGEMENT,
+        providers: R2RProviders,
+        services: R2RServices,
     ):
-        super().__init__(providers, services, orchestration_provider, run_type)
+        super().__init__(providers, services)

     def _setup_routes(self):
         @self.router.post(
             "/prompts",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Create a new prompt",
             openapi_extra={
                 "x-codeSamples": [
@@ -99,7 +93,7 @@ class PromptsRouter(BaseRouterV3):
                 default={},
                 description="A dictionary mapping input names to their types",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedGenericMessageResponse:
             """
             Create a new prompt with the given configuration.
@@ -111,13 +105,14 @@ class PromptsRouter(BaseRouterV3):
                     "Only a superuser can create prompts.",
                     "Only a superuser can create prompts.",
                     403,
                     403,
                 )
                 )
-            result = await self.services["management"].add_prompt(
+            result = await self.services.management.add_prompt(
                 name, template, input_types
                 name, template, input_types
             )
             )
             return GenericMessageResponse(message=result)  # type: ignore
             return GenericMessageResponse(message=result)  # type: ignore
 
 
         @self.router.get(
         @self.router.get(
             "/prompts",
             "/prompts",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="List all prompts",
             summary="List all prompts",
             openapi_extra={
             openapi_extra={
                 "x-codeSamples": [
                 "x-codeSamples": [
@@ -172,7 +167,7 @@ class PromptsRouter(BaseRouterV3):
         )
         @self.base_endpoint
         async def get_prompts(
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedPromptsResponse:
             """
             List all available prompts.
@@ -184,9 +179,9 @@ class PromptsRouter(BaseRouterV3):
                     "Only a superuser can list prompts.",
                     "Only a superuser can list prompts.",
                     403,
                     403,
                 )
                 )
-            get_prompts_response = await self.services[
-                "management"
-            ].get_all_prompts()
+            get_prompts_response = (
+                await self.services.management.get_all_prompts()
+            )
 
 
             return (  # type: ignore
             return (  # type: ignore
                 get_prompts_response["results"],
                 get_prompts_response["results"],
@@ -197,6 +192,7 @@ class PromptsRouter(BaseRouterV3):
 
 
         @self.router.post(
             "/prompts/{name}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Get a specific prompt",
             openapi_extra={
                 "x-codeSamples": [
@@ -266,7 +262,7 @@ class PromptsRouter(BaseRouterV3):
             prompt_override: Optional[str] = Query(
                 None, description="Prompt override"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedPromptResponse:
             """
             Get a specific prompt by name, optionally with inputs and override.
@@ -279,13 +275,14 @@ class PromptsRouter(BaseRouterV3):
                     "Only a superuser can retrieve prompts.",
                     "Only a superuser can retrieve prompts.",
                     403,
                     403,
                 )
                 )
-            result = await self.services["management"].get_prompt(
+            result = await self.services.management.get_prompt(
                 name, inputs, prompt_override
                 name, inputs, prompt_override
             )
             )
             return result  # type: ignore
             return result  # type: ignore
 
 
         @self.router.put(
         @self.router.put(
             "/prompts/{name}",
             "/prompts/{name}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Update an existing prompt",
             summary="Update an existing prompt",
             openapi_extra={
             openapi_extra={
                 "x-codeSamples": [
                 "x-codeSamples": [
@@ -350,7 +347,7 @@ class PromptsRouter(BaseRouterV3):
                 default={},
                 description="A dictionary mapping input names to their types",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedGenericMessageResponse:
             """
             Update an existing prompt's template and/or input types.
@@ -362,13 +359,14 @@ class PromptsRouter(BaseRouterV3):
                     "Only a superuser can update prompts.",
                     "Only a superuser can update prompts.",
                     403,
                     403,
                 )
                 )
-            result = await self.services["management"].update_prompt(
+            result = await self.services.management.update_prompt(
                 name, template, input_types
                 name, template, input_types
             )
             )
             return GenericMessageResponse(message=result)  # type: ignore
             return GenericMessageResponse(message=result)  # type: ignore
 
 
         @self.router.delete(
         @self.router.delete(
             "/prompts/{name}",
             "/prompts/{name}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Delete a prompt",
             summary="Delete a prompt",
             openapi_extra={
             openapi_extra={
                 "x-codeSamples": [
                 "x-codeSamples": [
@@ -426,7 +424,7 @@ class PromptsRouter(BaseRouterV3):
         @self.base_endpoint
         async def delete_prompt(
             name: str = Path(..., description="Prompt name"),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Delete a prompt by name.
@@ -438,5 +436,5 @@ class PromptsRouter(BaseRouterV3):
                     "Only a superuser can delete prompts.",
                     "Only a superuser can delete prompts.",
                     403,
                     403,
                 )
                 )
-            await self.services["management"].delete_prompt(name)
+            await self.services.management.delete_prompt(name)
             return GenericBooleanResponse(success=True)  # type: ignore
             return GenericBooleanResponse(success=True)  # type: ignore
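
The other change repeated through this file, `self.services["management"]` -> `self.services.management`, swaps a string-keyed dict for the typed `R2RServices` container. Attribute access fails immediately on a typo instead of raising `KeyError` mid-request, and type checkers can see through it. Assumed shape, with the field list inferred only from the usages visible in this diff:

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class R2RServices:
    auth: Any
    management: Any
    retrieval: Any
    graph: Any  # replaces the old services["kg"] entry


services = R2RServices(
    auth=object(), management=object(), retrieval=object(), graph=object()
)
services.management       # ok
# services.managment      # AttributeError right away, vs. KeyError at request time
```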

+ 29 - 39
core/main/api/v3/retrieval_router.py

@@ -1,6 +1,4 @@
-import asyncio
 import textwrap
-from copy import copy
 from typing import Any, Optional
 from uuid import UUID

@@ -21,12 +19,8 @@ from core.base.api.models import (
     WrappedRAGResponse,
     WrappedSearchResponse,
 )
-from core.base.logger.base import RunType
-from core.providers import (
-    HatchetOrchestrationProvider,
-    SimpleOrchestrationProvider,
-)

+from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3


@@ -49,14 +43,10 @@ def merge_search_settings(
 class RetrievalRouterV3(BaseRouterV3):
     def __init__(
         self,
-        providers,
-        services,
-        orchestration_provider: (
-            HatchetOrchestrationProvider | SimpleOrchestrationProvider
-        ),
-        run_type: RunType = RunType.RETRIEVAL,
+        providers: R2RProviders,
+        services: R2RServices,
     ):
-        super().__init__(providers, services, orchestration_provider, run_type)
+        super().__init__(providers, services)

     def _register_workflows(self):
         pass
@@ -152,12 +142,12 @@ class RetrievalRouterV3(BaseRouterV3):
                                     query: "Who is Aristotle?",
                                     query: "Who is Aristotle?",
                                     search_settings: {
                                     search_settings: {
                                         filters: {"document_id": {"$eq": "3e157b3a-8469-51db-90d9-52e7d896b49b"}},
                                         filters: {"document_id": {"$eq": "3e157b3a-8469-51db-90d9-52e7d896b49b"}},
-                                        use_semantic_search: true,
-                                        chunk_settings: {
+                                        useSemanticSearch: true,
+                                        chunkSettings: {
                                             limit: 20, # separate limit for chunk vs. graph
                                             limit: 20, # separate limit for chunk vs. graph
                                             enabled: true
                                             enabled: true
                                         },
                                         },
-                                        graph_settings: {
+                                        graphSettings: {
                                             enabled: true,
                                             enabled: true,
                                         }
                                         }
                                     }
                                     }
@@ -229,7 +219,7 @@ class RetrievalRouterV3(BaseRouterV3):
                     "Common overrides include `filters` to narrow results and `limit` to control how many results are returned."
                     "Common overrides include `filters` to narrow results and `limit` to control how many results are returned."
                 ),
                 ),
             ),
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedSearchResponse:
         ) -> WrappedSearchResponse:
             """
             """
             Perform a search query against vector and/or graph-based databases.
             Perform a search query against vector and/or graph-based databases.
@@ -269,7 +259,7 @@ class RetrievalRouterV3(BaseRouterV3):
             effective_settings = self._prepare_search_settings(
                 auth_user, search_mode, search_settings
             )
-            results = await self.services["retrieval"].search(
+            results = await self.services.retrieval.search(
                 query=query,
                 search_settings=effective_settings,
             )
@@ -326,19 +316,19 @@ class RetrievalRouterV3(BaseRouterV3):
                                     query: "Who is Aristotle?",
                                     query: "Who is Aristotle?",
                                     search_settings: {
                                     search_settings: {
                                         filters: {"document_id": {"$eq": "3e157b3a-8469-51db-90d9-52e7d896b49b"}},
                                         filters: {"document_id": {"$eq": "3e157b3a-8469-51db-90d9-52e7d896b49b"}},
-                                        use_semantic_search: true,
-                                        chunk_settings: {
+                                        useSemanticSearch: true,
+                                        chunkSettings: {
                                             limit: 20, # separate limit for chunk vs. graph
                                             limit: 20, # separate limit for chunk vs. graph
                                             enabled: true
                                             enabled: true
                                         },
                                         },
-                                        graph_settings: {
+                                        graphSettings: {
                                             enabled: true,
                                             enabled: true,
                                         },
                                         },
                                     },
                                     },
-                                    rag_generation_config: {
+                                    ragGenerationConfig: {
                                         stream: false,
                                         stream: false,
                                         temperature: 0.7,
                                         temperature: 0.7,
-                                        max_tokens: 150
+                                        maxTokens: 150
                                     }
                                     }
                                 });
                                 });
                             }
                             }
@@ -422,7 +412,7 @@ class RetrievalRouterV3(BaseRouterV3):
                 default=False,
                 description="Include document titles in responses when available",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedRAGResponse:
             """
             Execute a RAG (Retrieval-Augmented Generation) query.
@@ -438,7 +428,7 @@ class RetrievalRouterV3(BaseRouterV3):
                 auth_user, search_mode, search_settings
             )

-            response = await self.services["retrieval"].rag(
+            response = await self.services.retrieval.rag(
                 query=query,
                 search_settings=effective_settings,
                 rag_generation_config=rag_generation_config,
@@ -464,8 +454,8 @@ class RetrievalRouterV3(BaseRouterV3):
 
 
         @self.router.post(
             "/retrieval/agent",
-            summary="RAG-powered Conversational Agent",
             dependencies=[Depends(self.rate_limit_dependency)],
+            summary="RAG-powered Conversational Agent",
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -518,21 +508,21 @@ class RetrievalRouterV3(BaseRouterV3):
                                         role: "user",
                                         role: "user",
                                         content: "What were the key contributions of Aristotle to logic and how did they influence later philosophers?"
                                         content: "What were the key contributions of Aristotle to logic and how did they influence later philosophers?"
                                     },
                                     },
-                                    search_settings: {
+                                    searchSettings: {
                                         filters: {"document_id": {"$eq": "3e157b3a-8469-51db-90d9-52e7d896b49b"}},
                                         filters: {"document_id": {"$eq": "3e157b3a-8469-51db-90d9-52e7d896b49b"}},
-                                        use_semantic_search: true,
-                                        chunk_settings: {
+                                        useSemanticSearch: true,
+                                        chunkSettings: {
                                             limit: 20, # separate limit for chunk vs. graph
                                             limit: 20, # separate limit for chunk vs. graph
                                             enabled: true
                                             enabled: true
                                         },
                                         },
-                                        graph_settings: {
+                                        graphSettings: {
                                             enabled: true,
                                             enabled: true,
                                         },
                                         },
                                     },
                                     },
-                                    rag_generation_config: {
+                                    ragGenerationConfig: {
                                         stream: false,
                                         stream: false,
                                         temperature: 0.7,
                                         temperature: 0.7,
-                                        max_tokens: 150
+                                        maxTokens: 150
                                     },
                                     },
                                     includeTitleIfAvailable: true,
                                     includeTitleIfAvailable: true,
                                     conversationId: "550e8400-e29b-41d4-a716-446655440000"
                                     conversationId: "550e8400-e29b-41d4-a716-446655440000"
@@ -622,7 +612,7 @@ class RetrievalRouterV3(BaseRouterV3):
                 default=None,
                 description="ID of the conversation",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedAgentResponse:
             """
             Engage with an intelligent RAG-powered conversational agent for complex information retrieval and analysis.
@@ -661,7 +651,7 @@ class RetrievalRouterV3(BaseRouterV3):
             )

             try:
-                response = await self.services["retrieval"].agent(
+                response = await self.services.retrieval.agent(
                     message=message,
                     messages=messages,
                     search_settings=effective_settings,
@@ -810,7 +800,7 @@ class RetrievalRouterV3(BaseRouterV3):
                     "stream": False,
                     "stream": False,
                 },
                 },
             ),
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
             response_model=WrappedCompletionResponse,
             response_model=WrappedCompletionResponse,
         ):
         ):
             """
             """
@@ -823,7 +813,7 @@ class RetrievalRouterV3(BaseRouterV3):
             system message at the start. Each message should have a 'role' and 'content'.
             """

-            return await self.services["retrieval"].completion(
+            return await self.services.retrieval.completion(
                 messages=messages,
                 generation_config=generation_config,
             )
@@ -889,7 +879,7 @@ class RetrievalRouterV3(BaseRouterV3):
                 ...,
                 description="Text to generate embeddings for",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ):
             """
             Generate embeddings for the provided text using the specified model.
@@ -898,6 +888,6 @@ class RetrievalRouterV3(BaseRouterV3):
             The model parameter specifies the model to use for generating embeddings.
             """

-            return await self.services["retrieval"].embedding(
+            return await self.services.retrieval.embedding(
                 text=text,
             )
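
The JavaScript samples in this file swap snake_case keys (`rag_generation_config`, `max_tokens`) for camelCase (`ragGenerationConfig`, `maxTokens`), which suggests the r2r-js client accepts camelCase options and maps them to the API's snake_case field names; that mapping is an assumption here, since only the doc samples changed. The round-trip is mechanical either way:

```python
import re


def camel_to_snake(key: str) -> str:
    """'ragGenerationConfig' -> 'rag_generation_config'."""
    return re.sub(r"(?<!^)(?=[A-Z])", "_", key).lower()


def normalize(payload: dict) -> dict:
    """Recursively rewrite camelCase keys to snake_case (sketch only)."""
    return {
        camel_to_snake(k): normalize(v) if isinstance(v, dict) else v
        for k, v in payload.items()
    }


assert camel_to_snake("ragGenerationConfig") == "rag_generation_config"
assert normalize({"chunkSettings": {"maxTokens": 150}}) == {
    "chunk_settings": {"max_tokens": 150}
}
```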

+ 11 - 97
core/main/api/v3/system_router.py

@@ -5,7 +5,7 @@ from typing import Optional
 import psutil
 from fastapi import Depends, Query

-from core.base import R2RException, RunType
+from core.base import R2RException
 from core.base.api.models import (
     GenericMessageResponse,
     WrappedGenericMessageResponse,
@@ -13,30 +13,24 @@ from core.base.api.models import (
     WrappedServerStatsResponse,
     WrappedSettingsResponse,
 )
-from core.providers import (
-    HatchetOrchestrationProvider,
-    SimpleOrchestrationProvider,
-)

+from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3


 class SystemRouter(BaseRouterV3):
     def __init__(
         self,
-        providers,
-        services,
-        orchestration_provider: (
-            HatchetOrchestrationProvider | SimpleOrchestrationProvider
-        ),
-        run_type: RunType = RunType.MANAGEMENT,
+        providers: R2RProviders,
+        services: R2RServices,
     ):
-        super().__init__(providers, services, orchestration_provider, run_type)
+        super().__init__(providers, services)
         self.start_time = datetime.now(timezone.utc)

     def _setup_routes(self):
         @self.router.get(
             "/health",
+            # dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -95,6 +89,7 @@ class SystemRouter(BaseRouterV3):
 
 
         @self.router.get(
             "/system/settings",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -149,17 +144,18 @@ class SystemRouter(BaseRouterV3):
         )
         @self.base_endpoint
         async def app_settings(
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedSettingsResponse:
             if not auth_user.is_superuser:
                 raise R2RException(
                     "Only a superuser can call the `system/settings` endpoint.",
                     403,
                 )
-            return await self.services["management"].app_settings()
+            return await self.services.management.app_settings()

         @self.router.get(
             "/system/status",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
@@ -214,7 +210,7 @@ class SystemRouter(BaseRouterV3):
         )
         @self.base_endpoint
         async def server_stats(
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedServerStatsResponse:
             if not auth_user.is_superuser:
                 raise R2RException(
@@ -229,85 +225,3 @@ class SystemRouter(BaseRouterV3):
                 "cpu_usage": psutil.cpu_percent(),
                 "cpu_usage": psutil.cpu_percent(),
                 "memory_usage": psutil.virtual_memory().percent,
                 "memory_usage": psutil.virtual_memory().percent,
             }
             }
-
-        @self.router.get(
-            "/system/logs",
-            openapi_extra={
-                "x-codeSamples": [
-                    {
-                        "lang": "Python",
-                        "source": textwrap.dedent(
-                            """
-                            from r2r import R2RClient
-
-                            client = R2RClient("http://localhost:7272")
-                            # when using auth, do client.login(...)
-
-                            result = client.system.logs()
-                        """
-                        ),
-                    },
-                    {
-                        "lang": "JavaScript",
-                        "source": textwrap.dedent(
-                            """
-                            const { r2rClient } = require("r2r-js");
-
-                            const client = new r2rClient("http://localhost:7272");
-
-                            function main() {
-                                const response = await client.system.logs({});
-                            }
-
-                            main();
-                            """
-                        ),
-                    },
-                    {
-                        "lang": "CLI",
-                        "source": textwrap.dedent(
-                            """
-                            r2r system logs
-                            """
-                        ),
-                    },
-                    {
-                        "lang": "cURL",
-                        "source": textwrap.dedent(
-                            """
-                            curl -X POST "https://api.example.com/v3/system/logs" \\
-                                 -H "Content-Type: application/json" \\
-                                 -H "Authorization: Bearer YOUR_API_KEY" \\
-                        """
-                        ),
-                    },
-                ]
-            },
-        )
-        @self.base_endpoint
-        async def logs(
-            run_type_filter: Optional[str] = Query(""),
-            offset: int = Query(
-                0,
-                ge=0,
-                description="Specifies the number of objects to skip. Defaults to 0.",
-            ),
-            limit: int = Query(
-                100,
-                ge=1,
-                le=1000,
-                description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
-            ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
-        ) -> WrappedLogsResponse:
-            if not auth_user.is_superuser:
-                raise R2RException(
-                    "Only a superuser can call the `system/logs` endpoint.",
-                    403,
-                )
-
-            return await self.services["management"].logs(
-                run_type_filter=run_type_filter,
-                offset=offset,
-                limit=limit,
-            )

+ 271 - 50
core/main/api/v3/users_router.py

@@ -10,6 +10,8 @@ from core.base import R2RException
 from core.base.api.models import (
     GenericBooleanResponse,
     GenericMessageResponse,
+    WrappedAPIKeyResponse,
+    WrappedAPIKeysResponse,
     WrappedBooleanResponse,
     WrappedCollectionsResponse,
     WrappedGenericMessageResponse,
@@ -18,21 +20,21 @@ from core.base.api.models import (
     WrappedUsersResponse,
 )

+from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3

 oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


 class UsersRouter(BaseRouterV3):
-    def __init__(
-        self, providers, services, orchestration_provider=None, run_type=None
-    ):
-        super().__init__(providers, services, orchestration_provider, run_type)
+    def __init__(self, providers: R2RProviders, services: R2RServices):
+        super().__init__(providers, services)

     def _setup_routes(self):

         @self.router.post(
             "/users",
+            # dependencies=[Depends(self.rate_limit_dependency)],
             response_model=WrappedUserResponse,
             openapi_extra={
                 "x-codeSamples": [
@@ -104,18 +106,31 @@ class UsersRouter(BaseRouterV3):
             profile_picture: str | None = Body(
                 None, description="Updated user profile picture"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            # auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedUserResponse:
             """Register a new user with the given email and password."""
-            print('email = ', email)
-            print('making request.....')
-            registration_response = await self.services["auth"].register(
+
+            # TODO: Do we really want this validation? The default password for the superuser would not pass...
+            def validate_password(password: str) -> bool:
+                if len(password) < 10:
+                    return False
+                if not any(c.isupper() for c in password):
+                    return False
+                if not any(c.islower() for c in password):
+                    return False
+                if not any(c.isdigit() for c in password):
+                    return False
+                if not any(c in "!@#$%^&*" for c in password):
+                    return False
+                return True
+
+            validate_password(password)
+
+            registration_response = await self.services.auth.register(
                 email, password
             )
-            print('registration_response = ', registration_response)
-
             if name or bio or profile_picture:
-                return await self.services["auth"].update_user(
+                return await self.services.auth.update_user(
                     user_id=registration_response.id,
                     name=name,
                     bio=bio,
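
As written, the new register handler computes `validate_password(password)` and discards the boolean, so the policy is not actually enforced; the inline TODO suggests that is still an open question. If enforcement is wanted, the check has to raise. A hedged sketch, with `ValueError` standing in for the project's `R2RException` to keep it self-contained:

```python
def validate_password(password: str) -> bool:
    return (
        len(password) >= 10
        and any(c.isupper() for c in password)
        and any(c.islower() for c in password)
        and any(c.isdigit() for c in password)
        and any(c in "!@#$%^&*" for c in password)
    )


def enforce_password_policy(password: str) -> None:
    # The handler in the diff calls validate_password() and ignores the result;
    # enforcement would look like this instead:
    if not validate_password(password):
        raise ValueError(
            "Password must be at least 10 characters and include upper- and "
            "lower-case letters, a digit, and one of !@#$%^&*"
        )


enforce_password_policy("Str0ng!Password")  # passes silently
```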
@@ -127,6 +142,7 @@ class UsersRouter(BaseRouterV3):
         # TODO: deprecated, remove in next release
         @self.router.post(
             "/users/register",
+            # dependencies=[Depends(self.rate_limit_dependency)],
             response_model=WrappedUserResponse,
             openapi_extra={
                 "x-codeSamples": [
@@ -191,10 +207,11 @@ class UsersRouter(BaseRouterV3):
             password: str = Body(..., description="User's password"),
             password: str = Body(..., description="User's password"),
         ):
         ):
             """Register a new user with the given email and password."""
             """Register a new user with the given email and password."""
-            return await self.services["auth"].register(email, password)
+            return await self.services.auth.register(email, password)
 
 
         @self.router.post(
         @self.router.post(
             "/users/verify-email",
             "/users/verify-email",
+            # dependencies=[Depends(self.rate_limit_dependency)],
             response_model=WrappedGenericMessageResponse,
             response_model=WrappedGenericMessageResponse,
             openapi_extra={
             openapi_extra={
                 "x-codeSamples": [
                 "x-codeSamples": [
@@ -251,13 +268,25 @@ class UsersRouter(BaseRouterV3):
             ),
         ) -> WrappedGenericMessageResponse:
             """Verify a user's email address."""
-            result = await self.services["auth"].verify_email(
+            user = (
+                await self.providers.database.users_handler.get_user_by_email(
+                    email
+                )
+            )
+            if user and user.is_verified:
+                raise R2RException(
+                    status_code=400,
+                    message="This email is already verified. Please log in.",
+                )
+
+            result = await self.services.auth.verify_email(
                 email, verification_code
             )
             return GenericMessageResponse(message=result["message"])  # type: ignore

         @self.router.post(
             "/users/login",
+            # dependencies=[Depends(self.rate_limit_dependency)],
             response_model=WrappedTokenResponse,
             openapi_extra={
                 "x-codeSamples": [
@@ -310,7 +339,7 @@ class UsersRouter(BaseRouterV3):
         @self.base_endpoint
         async def login(form_data: OAuth2PasswordRequestForm = Depends()):
             """Authenticate a user and provide access tokens."""
-            return await self.services["auth"].login(
+            return await self.services.auth.login(
                 form_data.username, form_data.password
             )

@@ -362,14 +391,15 @@ class UsersRouter(BaseRouterV3):
         @self.base_endpoint
         async def logout(
             token: str = Depends(oauth2_scheme),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedGenericMessageResponse:
             """Log out the current user."""
-            result = await self.services["auth"].logout(token)
+            result = await self.services.auth.logout(token)
             return GenericMessageResponse(message=result["message"])  # type: ignore
 
         @self.router.post(
             "/users/refresh-token",
+            dependencies=[Depends(self.rate_limit_dependency)],
             openapi_extra={
                 "x-codeSamples": [
                     {
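
Note on the auth dependency change: every endpoint in this router swaps Depends(self.providers.auth.auth_wrapper) for Depends(self.providers.auth.auth_wrapper()), and the request-password-reset route below passes auth_wrapper(public=True). That only works if auth_wrapper is now a factory returning the actual FastAPI dependency, so each route can parameterize it. A minimal sketch of that pattern, not the actual R2R implementation (decode_token is a hypothetical stand-in for token validation):

    from fastapi import Depends, HTTPException
    from fastapi.security import OAuth2PasswordBearer

    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login", auto_error=False)

    class AuthProvider:
        async def decode_token(self, token: str):
            ...  # hypothetical: validate the token and load the user

        def auth_wrapper(self, public: bool = False):
            # The outer call configures the dependency; FastAPI resolves the
            # inner coroutine on each request.
            async def _authenticate(token: str | None = Depends(oauth2_scheme)):
                if token is None:
                    if public:
                        return None  # public routes tolerate anonymous callers
                    raise HTTPException(status_code=401, detail="Not authenticated")
                return await self.decode_token(token)

            return _authenticate
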
@@ -420,13 +450,14 @@ class UsersRouter(BaseRouterV3):
             refresh_token: str = Body(..., description="Refresh token")
         ) -> WrappedTokenResponse:
             """Refresh the access token using a refresh token."""
-            result = await self.services["auth"].refresh_access_token(
+            result = await self.services.auth.refresh_access_token(
                 refresh_token=refresh_token
             )
             return result
 
         @self.router.post(
             "/users/change-password",
+            dependencies=[Depends(self.rate_limit_dependency)],
             response_model=WrappedGenericMessageResponse,
             openapi_extra={
                 "x-codeSamples": [
@@ -484,16 +515,19 @@ class UsersRouter(BaseRouterV3):
         async def change_password(
             current_password: str = Body(..., description="Current password"),
             new_password: str = Body(..., description="New password"),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> GenericMessageResponse:
             """Change the authenticated user's password."""
-            result = await self.services["auth"].change_password(
+            result = await self.services.auth.change_password(
                 auth_user, current_password, new_password
             )
             return GenericMessageResponse(message=result["message"])  # type: ignore
 
         @self.router.post(
             "/users/request-password-reset",
+            dependencies=[
+                Depends(self.providers.auth.auth_wrapper(public=True))
+            ],
             response_model=WrappedGenericMessageResponse,
             openapi_extra={
                 "x-codeSamples": [
@@ -543,14 +577,15 @@ class UsersRouter(BaseRouterV3):
         )
         @self.base_endpoint
         async def request_password_reset(
-            email: EmailStr = Body(..., description="User's email address")
+            email: EmailStr = Body(..., description="User's email address"),
         ) -> WrappedGenericMessageResponse:
             """Request a password reset for a user."""
-            result = await self.services["auth"].request_password_reset(email)
+            result = await self.services.auth.request_password_reset(email)
             return GenericMessageResponse(message=result["message"])  # type: ignore
 
         @self.router.post(
             "/users/reset-password",
+            dependencies=[Depends(self.rate_limit_dependency)],
             response_model=WrappedGenericMessageResponse,
             openapi_extra={
                 "x-codeSamples": [
@@ -607,13 +642,14 @@ class UsersRouter(BaseRouterV3):
             new_password: str = Body(..., description="New password"),
         ) -> WrappedGenericMessageResponse:
             """Reset a user's password using a reset token."""
-            result = await self.services["auth"].confirm_password_reset(
+            result = await self.services.auth.confirm_password_reset(
                 reset_token, new_password
             )
             return GenericMessageResponse(message=result["message"])  # type: ignore
 
         @self.router.get(
             "/users",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="List Users",
             openapi_extra={
                 "x-codeSamples": [
@@ -679,7 +715,7 @@ class UsersRouter(BaseRouterV3):
             #     email: Optional[str] = Query(None, example="john@example.com"),
             #     is_active: Optional[bool] = Query(None, example=True),
             #     is_superuser: Optional[bool] = Query(None, example=False),
-            #     auth_user=Depends(self.providers.auth.auth_wrapper),
+            #     auth_user=Depends(self.providers.auth.auth_wrapper()),
             ids: list[str] = Query(
                 [], description="List of user IDs to filter by"
             ),
@@ -694,7 +730,7 @@ class UsersRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedUsersResponse:
             """
             List all users with pagination and filtering options.
@@ -709,15 +745,18 @@
 
             user_uuids = [UUID(user_id) for user_id in ids]
 
-            users_overview_response = await self.services[
-                "management"
-            ].users_overview(user_ids=user_uuids, offset=offset, limit=limit)
+            users_overview_response = (
+                await self.services.management.users_overview(
+                    user_ids=user_uuids, offset=offset, limit=limit
+                )
+            )
             return users_overview_response["results"], {  # type: ignore
                 "total_entries": users_overview_response["total_entries"]
             }
 
         @self.router.get(
             "/users/me",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Get the Current User",
             openapi_extra={
                 "x-codeSamples": [
@@ -773,7 +812,7 @@ class UsersRouter(BaseRouterV3):
         )
         @self.base_endpoint
         async def get_current_user(
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedUserResponse:
             """
             Get detailed information about the currently authenticated user.
@@ -782,6 +821,7 @@
 
         @self.router.get(
             "/users/{id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Get User Details",
             openapi_extra={
                 "x-codeSamples": [
@@ -844,7 +884,7 @@ class UsersRouter(BaseRouterV3):
             id: UUID = Path(
                 ..., example="550e8400-e29b-41d4-a716-446655440000"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedUserResponse:
             """
             Get detailed information about a specific user.
@@ -856,18 +896,19 @@
                     403,
                 )
 
-            users_overview_response = await self.services[
-                "management"
-            ].users_overview(
-                offset=0,
-                limit=1,
-                user_ids=[id],
+            users_overview_response = (
+                await self.services.management.users_overview(
+                    offset=0,
+                    limit=1,
+                    user_ids=[id],
+                )
             )
 
             return users_overview_response["results"][0]
 
         @self.router.delete(
             "/users/{id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Delete User",
             openapi_extra={
                 "x-codeSamples": [
@@ -919,7 +960,7 @@ class UsersRouter(BaseRouterV3):
                 False,
                 description="Whether to delete the user's vector data",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Delete a specific user.
@@ -931,7 +972,7 @@
                     403,
                 )
 
-            await self.services["auth"].delete_user(
+            await self.services.auth.delete_user(
                 user_id=id,
                 password=password,
                 delete_vector_data=delete_vector_data,
@@ -941,6 +982,7 @@
 
         @self.router.get(
             "/users/{id}/collections",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Get User Collections",
             openapi_extra={
                 "x-codeSamples": [
@@ -1018,7 +1060,7 @@ class UsersRouter(BaseRouterV3):
                 le=1000,
                 description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.",
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedCollectionsResponse:
             """
             Get all collections associated with a specific user.
@@ -1029,12 +1071,12 @@
                     "The currently authenticated user does not have access to the specified collection.",
                     403,
                 )
-            user_collection_response = await self.services[
-                "management"
-            ].collections_overview(
-                offset=offset,
-                limit=limit,
-                user_ids=[id],
+            user_collection_response = (
+                await self.services.management.collections_overview(
+                    offset=offset,
+                    limit=limit,
+                    user_ids=[id],
+                )
             )
             return user_collection_response["results"], {  # type: ignore
                 "total_entries": user_collection_response["total_entries"]
@@ -1042,6 +1084,7 @@
 
         @self.router.post(
             "/users/{id}/collections/{collection_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Add User to Collection",
             response_model=WrappedBooleanResponse,
             openapi_extra={
@@ -1110,7 +1153,7 @@ class UsersRouter(BaseRouterV3):
             collection_id: UUID = Path(
                 ..., example="750e8400-e29b-41d4-a716-446655440000"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             if auth_user.id != id and not auth_user.is_superuser:
                 raise R2RException(
@@ -1119,13 +1162,14 @@
                 )
 
             # TODO - Do we need a check on user access to the collection?
-            await self.services["management"].add_user_to_collection(  # type: ignore
+            await self.services.management.add_user_to_collection(  # type: ignore
                 id, collection_id
             )
             return GenericBooleanResponse(success=True)  # type: ignore
 
         @self.router.delete(
             "/users/{id}/collections/{collection_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Remove User from Collection",
             openapi_extra={
                 "x-codeSamples": [
@@ -1193,7 +1237,7 @@ class UsersRouter(BaseRouterV3):
             collection_id: UUID = Path(
                 ..., example="750e8400-e29b-41d4-a716-446655440000"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedBooleanResponse:
             """
             Remove a user from a collection.
@@ -1206,13 +1250,14 @@
                 )
 
             # TODO - Do we need a check on user access to the collection?
-            await self.services["management"].remove_user_from_collection(  # type: ignore
+            await self.services.management.remove_user_from_collection(  # type: ignore
                 id, collection_id
             )
             return GenericBooleanResponse(success=True)  # type: ignore
 
         @self.router.post(
             "/users/{id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
             summary="Update User",
             openapi_extra={
                 "x-codeSamples": [
@@ -1284,7 +1329,7 @@ class UsersRouter(BaseRouterV3):
             profile_picture: str | None = Body(
                 None, description="Updated profile picture URL"
             ),
-            auth_user=Depends(self.providers.auth.auth_wrapper),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedUserResponse:
             """
             Update user information.
@@ -1303,7 +1348,7 @@
                     403,
                 )
 
-            return await self.services["auth"].update_user(
+            return await self.services.auth.update_user(
                 user_id=id,
                 email=email,
                 is_superuser=is_superuser,
@@ -1311,3 +1356,179 @@
                 bio=bio,
                 profile_picture=profile_picture,
             )
+
+        @self.router.post(
+            "/users/{id}/api-keys",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            summary="Create User API Key",
+            response_model=WrappedAPIKeyResponse,
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # client.login(...)
+
+                            result = client.users.create_api_key(
+                                id="550e8400-e29b-41d4-a716-446655440000",
+                            )
+                            # result["api_key"] contains the newly created API key
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X POST "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys" \\
+                                -H "Authorization: Bearer YOUR_API_TOKEN"
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def create_user_api_key(
+            id: UUID = Path(
+                ..., description="ID of the user for whom to create an API key"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> WrappedAPIKeyResponse:
+            """
+            Create a new API key for the specified user.
+            Only superusers or the user themselves may create an API key.
+            """
+            if auth_user.id != id and not auth_user.is_superuser:
+                raise R2RException(
+                    "Only the user themselves or a superuser can create API keys for this user.",
+                    403,
+                )
+
+            api_key = await self.services.auth.create_user_api_key(id)
+            return api_key  # type: ignore
+
+        @self.router.get(
+            "/users/{id}/api-keys",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            summary="List User API Keys",
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+
+                            client = R2RClient("http://localhost:7272")
+                            # client.login(...)
+
+                            keys = client.users.list_api_keys(
+                                id="550e8400-e29b-41d4-a716-446655440000"
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X GET "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys" \\
+                                -H "Authorization: Bearer YOUR_API_TOKEN"
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def list_user_api_keys(
+            id: UUID = Path(
+                ..., description="ID of the user whose API keys to list"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> WrappedAPIKeysResponse:
+            """
+            List all API keys for the specified user.
+            Only superusers or the user themselves may list the API keys.
+            """
+            if auth_user.id != id and not auth_user.is_superuser:
+                raise R2RException(
+                    "Only the user themselves or a superuser can list API keys for this user.",
+                    403,
+                )
+
+            keys = (
+                await self.providers.database.users_handler.get_user_api_keys(
+                    id
+                )
+            )
+            return keys, {"total_entries": len(keys)}  # type: ignore
+
+        @self.router.delete(
+            "/users/{id}/api-keys/{key_id}",
+            dependencies=[Depends(self.rate_limit_dependency)],
+            summary="Delete User API Key",
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                            from r2r import R2RClient
+                            from uuid import UUID
+
+                            client = R2RClient("http://localhost:7272")
+                            # client.login(...)
+
+                            response = client.users.delete_api_key(
+                                id="550e8400-e29b-41d4-a716-446655440000",
+                                key_id="d9c562d4-3aef-43e8-8f08-0cf7cd5e0a25"
+                            )
+                            """
+                        ),
+                    },
+                    {
+                        "lang": "cURL",
+                        "source": textwrap.dedent(
+                            """
+                            curl -X DELETE "https://api.example.com/users/550e8400-e29b-41d4-a716-446655440000/api-keys/d9c562d4-3aef-43e8-8f08-0cf7cd5e0a25" \\
+                                -H "Authorization: Bearer YOUR_API_TOKEN"
+                            """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def delete_user_api_key(
+            id: UUID = Path(..., description="ID of the user"),
+            key_id: UUID = Path(
+                ..., description="ID of the API key to delete"
+            ),
+            auth_user=Depends(self.providers.auth.auth_wrapper()),
+        ) -> WrappedBooleanResponse:
+            """
+            Delete a specific API key for the specified user.
+            Only superusers or the user themselves may delete the API key.
+            """
+            if auth_user.id != id and not auth_user.is_superuser:
+                raise R2RException(
+                    "Only the user themselves or a superuser can delete this API key.",
+                    403,
+                )
+
+            success = (
+                await self.providers.database.users_handler.delete_api_key(
+                    id, key_id
+                )
+            )
+            if not success:
+                raise R2RException(
+                    "API key not found or could not be deleted", 400
+                )
+            return {"success": True}  # type: ignore

+ 21 - 38
core/main/assembly/builder.py

@@ -1,6 +1,5 @@
 import logging
-from dataclasses import dataclass
-from typing import Any, Optional, Type
+from typing import Any, Type
 
 from core.agent import R2RRAGAgent
 from core.base import (
@@ -13,6 +12,12 @@ from core.base import (
     OrchestrationProvider,
     RunManager,
 )
+from core.main.abstractions import R2RServices
+from core.main.services.auth_service import AuthService
+from core.main.services.graph_service import GraphService
+from core.main.services.ingestion_service import IngestionService
+from core.main.services.management_service import ManagementService
+from core.main.services.retrieval_service import RetrievalService
 from core.pipelines import KGEnrichmentPipeline, RAGPipeline, SearchPipeline
 
 from ..abstractions import R2RProviders
@@ -29,11 +34,6 @@ from ..api.v3.system_router import SystemRouter
 from ..api.v3.users_router import UsersRouter
 from ..app import R2RApp
 from ..config import R2RConfig
-from ..services.auth_service import AuthService
-from ..services.ingestion_service import IngestionService
-from ..services.kg_service import KgService
-from ..services.management_service import ManagementService
-from ..services.retrieval_service import RetrievalService
 from .factory import (
     R2RAgentFactory,
     R2RPipeFactory,
@@ -44,15 +44,6 @@ from .factory import (
 logger = logging.getLogger()
 
 
-@dataclass
-class Services:
-    auth: Optional["AuthService"] = None
-    ingestion: Optional["IngestionService"] = None
-    management: Optional["ManagementService"] = None
-    retrieval: Optional["RetrievalService"] = None
-    kg: Optional["KgService"] = None
-
-
 class R2RBuilder:
     def __init__(self, config: R2RConfig):
         self.config = config
@@ -60,7 +51,7 @@ class R2RBuilder:
     def _create_pipes(
         self,
         pipe_factory: type[R2RPipeFactory],
-        providers: Any,
+        providers: R2RProviders,
         *args,
         **kwargs,
     ) -> Any:
@@ -80,17 +71,22 @@
             self.config, providers, pipes
         ).create_pipelines(*args, **kwargs)
 
-    def _create_services(
-        self, service_params: dict[str, Any]
-    ) -> dict[str, Any]:
-        services = {}
-        for service_type, override in vars(Services()).items():
+    def _create_services(self, service_params: dict[str, Any]) -> R2RServices:
+        service_instances = {}
+        for service_type, override in vars(R2RServices()).items():
             logger.info(f"Creating {service_type} service")
             service_class = globals()[f"{service_type.capitalize()}Service"]
-            services[service_type] = override or service_class(
+            service_instances[service_type] = override or service_class(
                 **service_params
             )
-        return services
+
+        return R2RServices(
+            auth=service_instances["auth"],
+            ingestion=service_instances["ingestion"],
+            management=service_instances["management"],
+            retrieval=service_instances["retrieval"],
+            graph=service_instances["graph"],
+        )
 
     async def _create_providers(
         self, provider_factory: Type[R2RProviderFactory], *args, **kwargs
@@ -133,68 +129,55 @@
 
         services = self._create_services(service_params)
 
-        orchestration_provider = providers.orchestration
-
         routers = {
             "chunks_router": ChunksRouter(
                 providers=providers,
                 services=services,
-                orchestration_provider=orchestration_provider,
             ).get_router(),
             "collections_router": CollectionsRouter(
                 providers=providers,
                 services=services,
-                orchestration_provider=orchestration_provider,
             ).get_router(),
             "conversations_router": ConversationsRouter(
                 providers=providers,
                 services=services,
-                orchestration_provider=orchestration_provider,
             ).get_router(),
             "documents_router": DocumentsRouter(
                 providers=providers,
                 services=services,
-                orchestration_provider=orchestration_provider,
             ).get_router(),
             "graph_router": GraphRouter(
                 providers=providers,
                 services=services,
-                orchestration_provider=orchestration_provider,
             ).get_router(),
             "indices_router": IndicesRouter(
                 providers=providers,
                 services=services,
-                orchestration_provider=orchestration_provider,
             ).get_router(),
             "logs_router": LogsRouter(
                 providers=providers,
                 services=services,
-                orchestration_provider=orchestration_provider,
             ).get_router(),
             "prompts_router": PromptsRouter(
                 providers=providers,
                 services=services,
-                orchestration_provider=orchestration_provider,
             ).get_router(),
             "retrieval_router_v3": RetrievalRouterV3(
                 providers=providers,
                 services=services,
-                orchestration_provider=orchestration_provider,
             ).get_router(),
             "system_router": SystemRouter(
                 providers=providers,
                 services=services,
-                orchestration_provider=orchestration_provider,
             ).get_router(),
             "users_router": UsersRouter(
                 providers=providers,
                 services=services,
-                orchestration_provider=orchestration_provider,
             ).get_router(),
         }
 
         return R2RApp(
             config=self.config,
-            orchestration_provider=orchestration_provider,
+            orchestration_provider=providers.orchestration,
             **routers,
         )
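
This builder change is what underwrites the services["auth"] -> services.auth rewrite across the routers: the string-keyed dict (and the module-local Services dataclass) gives way to a typed R2RServices container imported from core.main.abstractions, with KgService renamed to GraphService along the way. A sketch of the pattern, assuming R2RServices keeps Optional fields defaulting to None (which the vars(R2RServices()) iteration above relies on):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class R2RServices:
        auth: Optional["AuthService"] = None
        ingestion: Optional["IngestionService"] = None
        management: Optional["ManagementService"] = None
        retrieval: Optional["RetrievalService"] = None
        graph: Optional["GraphService"] = None

    # Before: services["auth"].login(...)  -- typo-prone string keys
    # After:  services.auth.login(...)     -- attribute access a type checker can verify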

+ 111 - 111
core/main/assembly/factory.py

@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
 from core.agent import R2RRAGAgent, R2RStreamingRAGAgent
 from core.base import (
@@ -25,14 +25,16 @@ from ..config import R2RConfig
 
 logger = logging.getLogger()
 from core.database import PostgresDatabaseProvider
-from core.providers import (  # PostgresDatabaseProvider,
+from core.providers import (
     AsyncSMTPEmailProvider,
-    BCryptConfig,
-    BCryptProvider,
+    BcryptCryptoConfig,
+    BCryptCryptoProvider,
     ConsoleMockEmailProvider,
     HatchetOrchestrationProvider,
     LiteLLMCompletionProvider,
     LiteLLMEmbeddingProvider,
+    NaClCryptoConfig,
+    NaClCryptoProvider,
     OllamaEmbeddingProvider,
     OpenAICompletionProvider,
     OpenAIEmbeddingProvider,
@@ -53,16 +55,16 @@ class R2RProviderFactory:
     @staticmethod
     async def create_auth_provider(
         auth_config: AuthConfig,
-        crypto_provider: BCryptProvider,
+        crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
         database_provider: PostgresDatabaseProvider,
-        email_provider: Union[
-            AsyncSMTPEmailProvider,
-            ConsoleMockEmailProvider,
-            SendGridEmailProvider,
-        ],
+        email_provider: (
+            AsyncSMTPEmailProvider
+            | ConsoleMockEmailProvider
+            | SendGridEmailProvider
+        ),
         *args,
         **kwargs,
-    ) -> Union[R2RAuthProvider, SupabaseAuthProvider]:
+    ) -> R2RAuthProvider | SupabaseAuthProvider:
         if auth_config.provider == "r2r":
 
             r2r_auth = R2RAuthProvider(
@@ -82,9 +84,15 @@ class R2RProviderFactory:
     @staticmethod
     def create_crypto_provider(
         crypto_config: CryptoConfig, *args, **kwargs
-    ) -> BCryptProvider:
+    ) -> BCryptCryptoProvider | NaClCryptoProvider:
         if crypto_config.provider == "bcrypt":
-            return BCryptProvider(BCryptConfig(**crypto_config.dict()))
+            return BCryptCryptoProvider(
+                BcryptCryptoConfig(**crypto_config.model_dump())
+            )
+        if crypto_config.provider == "nacl":
+            return NaClCryptoProvider(
+                NaClCryptoConfig(**crypto_config.model_dump())
+            )
         else:
             raise ValueError(
                 f"Crypto provider {crypto_config.provider} not supported."
@@ -94,12 +102,10 @@ class R2RProviderFactory:
     def create_ingestion_provider(
         ingestion_config: IngestionConfig,
         database_provider: PostgresDatabaseProvider,
-        llm_provider: Union[
-            LiteLLMCompletionProvider, OpenAICompletionProvider
-        ],
+        llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
         *args,
         **kwargs,
-    ) -> Union[R2RIngestionProvider, UnstructuredIngestionProvider]:
+    ) -> R2RIngestionProvider | UnstructuredIngestionProvider:
 
         config_dict = (
             ingestion_config.model_dump()
@@ -135,7 +141,7 @@ class R2RProviderFactory:
     @staticmethod
     def create_orchestration_provider(
         config: OrchestrationConfig, *args, **kwargs
-    ) -> Union[HatchetOrchestrationProvider, SimpleOrchestrationProvider]:
+    ) -> HatchetOrchestrationProvider | SimpleOrchestrationProvider:
         if config.provider == "hatchet":
             orchestration_provider = HatchetOrchestrationProvider(config)
             orchestration_provider.get_worker("r2r-worker")
@@ -152,7 +158,7 @@ class R2RProviderFactory:
     async def create_database_provider(
         self,
         db_config: DatabaseConfig,
-        crypto_provider: BCryptProvider,
+        crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
         *args,
         **kwargs,
     ) -> PostgresDatabaseProvider:
@@ -184,11 +190,11 @@ class R2RProviderFactory:
     @staticmethod
     def create_embedding_provider(
         embedding: EmbeddingConfig, *args, **kwargs
-    ) -> Union[
-        LiteLLMEmbeddingProvider,
-        OllamaEmbeddingProvider,
-        OpenAIEmbeddingProvider,
-    ]:
+    ) -> (
+        LiteLLMEmbeddingProvider
+        | OllamaEmbeddingProvider
+        | OpenAIEmbeddingProvider
+    ):
         embedding_provider: Optional[EmbeddingProvider] = None
 
         if embedding.provider == "openai":
@@ -220,7 +226,7 @@ class R2RProviderFactory:
     @staticmethod
     def create_llm_provider(
         llm_config: CompletionConfig, *args, **kwargs
-    ) -> Union[LiteLLMCompletionProvider, OpenAICompletionProvider]:
+    ) -> LiteLLMCompletionProvider | OpenAICompletionProvider:
         llm_provider: Optional[CompletionProvider] = None
         if llm_config.provider == "openai":
             llm_provider = OpenAICompletionProvider(llm_config)
@@ -237,13 +243,15 @@ class R2RProviderFactory:
     @staticmethod
     async def create_email_provider(
         email_config: Optional[EmailConfig] = None, *args, **kwargs
-    ) -> Union[
-        AsyncSMTPEmailProvider, ConsoleMockEmailProvider, SendGridEmailProvider
-    ]:
+    ) -> (
+        AsyncSMTPEmailProvider
+        | ConsoleMockEmailProvider
+        | SendGridEmailProvider
+    ):
         """Creates an email provider based on configuration."""
         if not email_config:
             raise ValueError(
-                f"No email configuration provided for email provider, please add `[email]` to your `r2r.toml`."
+                "No email configuration provided for email provider, please add `[email]` to your `r2r.toml`."
             )
 
         if email_config.provider == "smtp":
@@ -260,29 +268,27 @@ class R2RProviderFactory:
     async def create_providers(
         self,
         auth_provider_override: Optional[
-            Union[R2RAuthProvider, SupabaseAuthProvider]
+            R2RAuthProvider | SupabaseAuthProvider
+        ] = None,
+        crypto_provider_override: Optional[
+            BCryptCryptoProvider | NaClCryptoProvider
         ] = None,
-        crypto_provider_override: Optional[BCryptProvider] = None,
         database_provider_override: Optional[PostgresDatabaseProvider] = None,
         email_provider_override: Optional[
-            Union[
-                AsyncSMTPEmailProvider,
-                ConsoleMockEmailProvider,
-                SendGridEmailProvider,
-            ]
+            AsyncSMTPEmailProvider
+            | ConsoleMockEmailProvider
+            | SendGridEmailProvider
         ] = None,
         embedding_provider_override: Optional[
-            Union[
-                LiteLLMEmbeddingProvider,
-                OpenAIEmbeddingProvider,
-                OllamaEmbeddingProvider,
-            ]
+            LiteLLMEmbeddingProvider
+            | OpenAIEmbeddingProvider
+            | OllamaEmbeddingProvider
         ] = None,
         ingestion_provider_override: Optional[
-            Union[R2RIngestionProvider, UnstructuredIngestionProvider]
+            R2RIngestionProvider | UnstructuredIngestionProvider
         ] = None,
         llm_provider_override: Optional[
-            Union[OpenAICompletionProvider, LiteLLMCompletionProvider]
+            OpenAICompletionProvider | LiteLLMCompletionProvider
         ] = None,
         orchestration_provider_override: Optional[Any] = None,
         *args,
@@ -364,18 +370,18 @@ class R2RPipeFactory:
         self,
         parsing_pipe_override: Optional[AsyncPipe] = None,
         embedding_pipe_override: Optional[AsyncPipe] = None,
-        kg_relationships_extraction_pipe_override: Optional[AsyncPipe] = None,
-        kg_storage_pipe_override: Optional[AsyncPipe] = None,
-        kg_search_pipe_override: Optional[AsyncPipe] = None,
+        graph_extraction_pipe_override: Optional[AsyncPipe] = None,
+        graph_storage_pipe_override: Optional[AsyncPipe] = None,
+        graph_search_pipe_override: Optional[AsyncPipe] = None,
         vector_storage_pipe_override: Optional[AsyncPipe] = None,
         vector_search_pipe_override: Optional[AsyncPipe] = None,
         rag_pipe_override: Optional[AsyncPipe] = None,
         streaming_rag_pipe_override: Optional[AsyncPipe] = None,
-        kg_entity_description_pipe: Optional[AsyncPipe] = None,
-        kg_clustering_pipe: Optional[AsyncPipe] = None,
-        kg_entity_deduplication_pipe: Optional[AsyncPipe] = None,
-        kg_entity_deduplication_summary_pipe: Optional[AsyncPipe] = None,
-        kg_community_summary_pipe: Optional[AsyncPipe] = None,
+        graph_description_pipe: Optional[AsyncPipe] = None,
+        graph_clustering_pipe: Optional[AsyncPipe] = None,
+        graph_deduplication_pipe: Optional[AsyncPipe] = None,
+        graph_deduplication_summary_pipe: Optional[AsyncPipe] = None,
+        graph_community_summary_pipe: Optional[AsyncPipe] = None,
         *args,
         **kwargs,
     ) -> R2RPipes:
@@ -388,32 +394,30 @@ class R2RPipeFactory:
             ),
             embedding_pipe=embedding_pipe_override
             or self.create_embedding_pipe(*args, **kwargs),
-            kg_relationships_extraction_pipe=kg_relationships_extraction_pipe_override
-            or self.create_kg_relationships_extraction_pipe(*args, **kwargs),
-            kg_storage_pipe=kg_storage_pipe_override
-            or self.create_kg_storage_pipe(*args, **kwargs),
+            graph_extraction_pipe=graph_extraction_pipe_override
+            or self.create_graph_extraction_pipe(*args, **kwargs),
+            graph_storage_pipe=graph_storage_pipe_override
+            or self.create_graph_storage_pipe(*args, **kwargs),
             vector_storage_pipe=vector_storage_pipe_override
             or self.create_vector_storage_pipe(*args, **kwargs),
             vector_search_pipe=vector_search_pipe_override
             or self.create_vector_search_pipe(*args, **kwargs),
-            kg_search_pipe=kg_search_pipe_override
-            or self.create_kg_search_pipe(*args, **kwargs),
+            graph_search_pipe=graph_search_pipe_override
+            or self.create_graph_search_pipe(*args, **kwargs),
             rag_pipe=rag_pipe_override
             or self.create_rag_pipe(*args, **kwargs),
             streaming_rag_pipe=streaming_rag_pipe_override
             or self.create_rag_pipe(True, *args, **kwargs),
-            kg_entity_description_pipe=kg_entity_description_pipe
-            or self.create_kg_entity_description_pipe(*args, **kwargs),
-            kg_clustering_pipe=kg_clustering_pipe
-            or self.create_kg_clustering_pipe(*args, **kwargs),
-            kg_entity_deduplication_pipe=kg_entity_deduplication_pipe
-            or self.create_kg_entity_deduplication_pipe(*args, **kwargs),
-            kg_entity_deduplication_summary_pipe=kg_entity_deduplication_summary_pipe
-            or self.create_kg_entity_deduplication_summary_pipe(
-                *args, **kwargs
-            ),
-            kg_community_summary_pipe=kg_community_summary_pipe
-            or self.create_kg_community_summary_pipe(*args, **kwargs),
+            graph_description_pipe=graph_description_pipe
+            or self.create_graph_description_pipe(*args, **kwargs),
+            graph_clustering_pipe=graph_clustering_pipe
+            or self.create_graph_clustering_pipe(*args, **kwargs),
+            graph_deduplication_pipe=graph_deduplication_pipe
+            or self.create_graph_deduplication_pipe(*args, **kwargs),
+            graph_deduplication_summary_pipe=graph_deduplication_summary_pipe
+            or self.create_graph_deduplication_summary_pipe(*args, **kwargs),
+            graph_community_summary_pipe=graph_community_summary_pipe
+            or self.create_graph_community_summary_pipe(*args, **kwargs),
         )
 
     def create_parsing_pipe(self, *args, **kwargs) -> Any:
@@ -525,29 +529,27 @@ class R2RPipeFactory:
             config=AsyncPipe.PipeConfig(name="routing_search_pipe"),
         )
 
-    def create_kg_relationships_extraction_pipe(self, *args, **kwargs) -> Any:
-        from core.pipes import KGExtractionPipe
+    def create_graph_extraction_pipe(self, *args, **kwargs) -> Any:
+        from core.pipes import GraphExtractionPipe
 
-        return KGExtractionPipe(
+        return GraphExtractionPipe(
             llm_provider=self.providers.llm,
             database_provider=self.providers.database,
-            config=AsyncPipe.PipeConfig(
-                name="kg_relationships_extraction_pipe"
-            ),
+            config=AsyncPipe.PipeConfig(name="graph_extraction_pipe"),
         )
 
-    def create_kg_storage_pipe(self, *args, **kwargs) -> Any:
-        from core.pipes import KGStoragePipe
+    def create_graph_storage_pipe(self, *args, **kwargs) -> Any:
+        from core.pipes import GraphStoragePipe
 
-        return KGStoragePipe(
+        return GraphStoragePipe(
             database_provider=self.providers.database,
-            config=AsyncPipe.PipeConfig(name="kg_storage_pipe"),
+            config=AsyncPipe.PipeConfig(name="graph_storage_pipe"),
         )
 
-    def create_kg_search_pipe(self, *args, **kwargs) -> Any:
-        from core.pipes import KGSearchSearchPipe
+    def create_graph_search_pipe(self, *args, **kwargs) -> Any:
+        from core.pipes import GraphSearchSearchPipe
 
-        return KGSearchSearchPipe(
+        return GraphSearchSearchPipe(
             database_provider=self.providers.database,
             llm_provider=self.providers.llm,
             embedding_provider=self.providers.embedding,
@@ -558,9 +560,9 @@
 
     def create_rag_pipe(self, stream: bool = False, *args, **kwargs) -> Any:
         if stream:
-            from core.pipes import StreamingSearchRAGPipe
+            from core.pipes import StreamingRAGPipe
 
-            return StreamingSearchRAGPipe(
+            return StreamingRAGPipe(
                 llm_provider=self.providers.llm,
                 database_provider=self.providers.database,
                 config=GeneratorPipe.PipeConfig(
@@ -568,9 +570,9 @@
                 ),
             )
         else:
-            from core.pipes import SearchRAGPipe
+            from core.pipes import RAGPipe
 
-            return SearchRAGPipe(
+            return RAGPipe(
                 llm_provider=self.providers.llm,
                 database_provider=self.providers.database,
                 config=GeneratorPipe.PipeConfig(
@@ -578,67 +580,65 @@ class R2RPipeFactory:
                 ),
                 ),
             )
             )
 
 
-    def create_kg_entity_description_pipe(self, *args, **kwargs) -> Any:
-        from core.pipes import KGEntityDescriptionPipe
+    def create_graph_description_pipe(self, *args, **kwargs) -> Any:
+        from core.pipes import GraphDescriptionPipe

-        return KGEntityDescriptionPipe(
+        return GraphDescriptionPipe(
            database_provider=self.providers.database,
            llm_provider=self.providers.llm,
            embedding_provider=self.providers.embedding,
-            config=AsyncPipe.PipeConfig(name="kg_entity_description_pipe"),
+            config=AsyncPipe.PipeConfig(name="graph_description_pipe"),
        )

-    def create_kg_clustering_pipe(self, *args, **kwargs) -> Any:
-        from core.pipes import KGClusteringPipe
+    def create_graph_clustering_pipe(self, *args, **kwargs) -> Any:
+        from core.pipes import GraphClusteringPipe

-        return KGClusteringPipe(
+        return GraphClusteringPipe(
            database_provider=self.providers.database,
            llm_provider=self.providers.llm,
            embedding_provider=self.providers.embedding,
-            config=AsyncPipe.PipeConfig(name="kg_clustering_pipe"),
+            config=AsyncPipe.PipeConfig(name="graph_clustering_pipe"),
        )

    def create_kg_deduplication_summary_pipe(self, *args, **kwargs) -> Any:
-        from core.pipes import KGEntityDeduplicationSummaryPipe
+        from core.pipes import GraphDeduplicationSummaryPipe

-        return KGEntityDeduplicationSummaryPipe(
+        return GraphDeduplicationSummaryPipe(
            database_provider=self.providers.database,
            llm_provider=self.providers.llm,
            embedding_provider=self.providers.embedding,
            config=AsyncPipe.PipeConfig(name="kg_deduplication_summary_pipe"),
        )

-    def create_kg_community_summary_pipe(self, *args, **kwargs) -> Any:
-        from core.pipes import KGCommunitySummaryPipe
+    def create_graph_community_summary_pipe(self, *args, **kwargs) -> Any:
+        from core.pipes import GraphCommunitySummaryPipe

-        return KGCommunitySummaryPipe(
+        return GraphCommunitySummaryPipe(
            database_provider=self.providers.database,
            llm_provider=self.providers.llm,
            embedding_provider=self.providers.embedding,
-            config=AsyncPipe.PipeConfig(name="kg_community_summary_pipe"),
+            config=AsyncPipe.PipeConfig(name="graph_community_summary_pipe"),
        )

-    def create_kg_entity_deduplication_pipe(self, *args, **kwargs) -> Any:
-        from core.pipes import KGEntityDeduplicationPipe
+    def create_graph_deduplication_pipe(self, *args, **kwargs) -> Any:
+        from core.pipes import GraphDeduplicationPipe

-        return KGEntityDeduplicationPipe(
+        return GraphDeduplicationPipe(
            database_provider=self.providers.database,
            llm_provider=self.providers.llm,
            embedding_provider=self.providers.embedding,
-            config=AsyncPipe.PipeConfig(name="kg_entity_deduplication_pipe"),
+            config=AsyncPipe.PipeConfig(name="graph_deduplication_pipe"),
        )

-    def create_kg_entity_deduplication_summary_pipe(
-        self, *args, **kwargs
-    ) -> Any:
-        from core.pipes import KGEntityDeduplicationSummaryPipe
+    def create_graph_deduplication_summary_pipe(self, *args, **kwargs) -> Any:
+        from core.pipes import GraphDeduplicationSummaryPipe

-        return KGEntityDeduplicationSummaryPipe(
+        return GraphDeduplicationSummaryPipe(
            database_provider=self.providers.database,
            llm_provider=self.providers.llm,
            embedding_provider=self.providers.embedding,
            config=AsyncPipe.PipeConfig(
-                name="kg_entity_deduplication_summary_pipe"
+                name="graph_deduplication_summary_pipe"
            ),
        )

@@ -664,7 +664,7 @@ class R2RPipelineFactory:
                self.pipes.vector_search_pipe, vector_search_pipe=True
            )
            search_pipeline.add_pipe(
-                self.pipes.kg_search_pipe, kg_search_pipe=True
+                self.pipes.graph_search_pipe, graph_search_pipe=True
            )

        return search_pipeline
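
Read together, the hunks above rename the factory's knowledge-graph surface one for one from kg_* to graph_*; only create_kg_deduplication_summary_pipe keeps its old method name (just the imported class changes). A compact summary of the mapping, collected from the hunks for easier review:

    # create_kg_relationships_extraction_pipe      -> create_graph_extraction_pipe
    # create_kg_storage_pipe                       -> create_graph_storage_pipe
    # create_kg_search_pipe                        -> create_graph_search_pipe
    # create_kg_entity_description_pipe            -> create_graph_description_pipe
    # create_kg_clustering_pipe                    -> create_graph_clustering_pipe
    # create_kg_community_summary_pipe             -> create_graph_community_summary_pipe
    # create_kg_entity_deduplication_pipe          -> create_graph_deduplication_pipe
    # create_kg_entity_deduplication_summary_pipe  -> create_graph_deduplication_summary_pipe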

+ 16 - 16
core/main/config.py

@@ -17,6 +17,7 @@ from ..base.providers.embedding import EmbeddingConfig
 from ..base.providers.ingestion import IngestionConfig
 from ..base.providers.llm import CompletionConfig
 from ..base.providers.orchestration import OrchestrationConfig
+from ..base.utils import deep_update

 logger = logging.getLogger()

@@ -77,19 +78,13 @@ class R2RConfig:
        default_config = self.load_default_config()

        # Override the default configuration with the passed configuration
-        for key in config_data:
-            if key in default_config:
-                default_config[key].update(config_data[key])
-            else:
-                default_config[key] = config_data[key]
+        default_config = deep_update(default_config, config_data)

        # Validate and set the configuration
        for section, keys in R2RConfig.REQUIRED_KEYS.items():
            # Check the keys when provider is set
            # TODO - remove after deprecation
-            if (
-                section == "kg" or section == "file"
-            ) and section not in default_config:
+            if section in ["kg", "file"] and section not in default_config:
                continue
            if "provider" in default_config[section] and (
                default_config[section]["provider"] is not None
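
The switch to deep_update changes merge semantics: the old loop merged only one level deep, so a user config that set a single nested key could drop that key's sibling defaults. A minimal sketch of a recursive deep_update, assuming conventional recursive-merge semantics (the actual helper lives in core/base/utils and may differ):

    from collections.abc import Mapping

    def deep_update(base: dict, updates: Mapping) -> dict:
        # Merge nested mappings key by key instead of replacing them
        # wholesale, so a partial override keeps the remaining defaults.
        for key, value in updates.items():
            if isinstance(value, Mapping) and isinstance(base.get(key), Mapping):
                base[key] = deep_update(base[key], value)
            else:
                base[key] = value
        return base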
@@ -151,10 +146,13 @@ class R2RConfig:
        return cls(config_data)

    def to_toml(self):
-        config_data = {
-            section: self._serialize_config(getattr(self, section))
-            for section in R2RConfig.REQUIRED_KEYS.keys()
-        }
+        config_data = {}
+        for section in R2RConfig.REQUIRED_KEYS.keys():
+            section_data = self._serialize_config(getattr(self, section))
+            if isinstance(section_data, dict):
+                # Remove app from nested configs before serializing
+                section_data.pop("app", None)
+            config_data[section] = section_data
        return toml.dumps(config_data)

    @classmethod
@@ -164,21 +162,23 @@ class R2RConfig:

    @staticmethod
    def _serialize_config(config_section: Any) -> dict:
+        """Serialize config section while excluding internal state"""
        if isinstance(config_section, dict):
            return {
                R2RConfig._serialize_key(k): R2RConfig._serialize_config(v)
                for k, v in config_section.items()
+                if k != "app"  # Exclude app from serialization
            }
        elif isinstance(config_section, (list, tuple)):
-            return [  # type: ignore
+            return [
                R2RConfig._serialize_config(item) for item in config_section
            ]
        elif isinstance(config_section, Enum):
            return config_section.value
        elif isinstance(config_section, BaseModel):
-            return R2RConfig._serialize_config(
-                config_section.model_dump(exclude_none=True)
-            )
+            data = config_section.model_dump(exclude_none=True)
+            data.pop("app", None)  # Remove app from the serialized data
+            return R2RConfig._serialize_config(data)
        else:
            return config_section
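
The effect of the new "app" filtering is easiest to see on a small model. A hypothetical section model for illustration (not from the codebase):

    from pydantic import BaseModel

    class ExampleSection(BaseModel):
        provider: str = "r2r"
        app: dict = {}  # internal back-reference; stripped during serialization

    # _serialize_config(ExampleSection()) now yields {"provider": "r2r"}:
    # model_dump() still includes "app", but pop("app", None) removes it
    # before the result is recursively serialized and written by to_toml().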
 
 

+ 8 - 8
core/main/orchestration/hatchet/kg_workflow.py

@@ -11,7 +11,7 @@ from core import GenerationConfig
 from core.base import OrchestrationProvider, R2RException
 from core.base.abstractions import KGEnrichmentStatus, KGExtractionStatus

-from ...services import KgService
+from ...services import GraphService

 logger = logging.getLogger()
 from typing import TYPE_CHECKING
@@ -21,7 +21,7 @@ if TYPE_CHECKING:


 def hatchet_kg_factory(
-    orchestration_provider: OrchestrationProvider, service: KgService
+    orchestration_provider: OrchestrationProvider, service: GraphService
 ) -> dict[str, "Hatchet.Workflow"]:

     def convert_to_dict(input_data):
@@ -124,7 +124,7 @@ def hatchet_kg_factory(

    @orchestration_provider.workflow(name="kg-extract", timeout="360m")
    class KGExtractDescribeEmbedWorkflow:
-        def __init__(self, kg_service: KgService):
+        def __init__(self, kg_service: GraphService):
            self.kg_service = kg_service

        @orchestration_provider.concurrency(  # type: ignore
@@ -273,7 +273,7 @@ def hatchet_kg_factory(
            except Exception as e:
                pass

-        def __init__(self, kg_service: KgService):
+        def __init__(self, kg_service: GraphService):
            self.kg_service = kg_service

        @orchestration_provider.step(retries=1)
@@ -392,7 +392,7 @@ def hatchet_kg_factory(
        name="entity-deduplication", timeout="360m"
    )
    class EntityDeduplicationWorkflow:
-        def __init__(self, kg_service: KgService):
+        def __init__(self, kg_service: GraphService):
            self.kg_service = kg_service

        @orchestration_provider.step(retries=0, timeout="360m")
@@ -460,7 +460,7 @@ def hatchet_kg_factory(
        name="kg-entity-deduplication-summary", timeout="360m"
    )
    class EntityDeduplicationSummaryWorkflow:
-        def __init__(self, kg_service: KgService):
+        def __init__(self, kg_service: GraphService):
            self.kg_service = kg_service

        @orchestration_provider.step(retries=0, timeout="360m")
@@ -490,7 +490,7 @@ def hatchet_kg_factory(

    @orchestration_provider.workflow(name="build-communities", timeout="360m")
    class EnrichGraphWorkflow:
-        def __init__(self, kg_service: KgService):
+        def __init__(self, kg_service: GraphService):
            self.kg_service = kg_service

        @orchestration_provider.step(retries=1, parents=[], timeout="360m")
@@ -642,7 +642,7 @@ def hatchet_kg_factory(
        name="kg-community-summary", timeout="360m"
    )
    class KGCommunitySummaryWorkflow:
-        def __init__(self, kg_service: KgService):
+        def __init__(self, kg_service: GraphService):
            self.kg_service = kg_service

        @orchestration_provider.concurrency(  # type: ignore

+ 2 - 2
core/main/orchestration/simple/kg_workflow.py

@@ -6,12 +6,12 @@ import uuid
 from core import GenerationConfig, R2RException
 from core.base.abstractions import KGEnrichmentStatus

-from ...services import KgService
+from ...services import GraphService

 logger = logging.getLogger()


-def simple_kg_factory(service: KgService):
+def simple_kg_factory(service: GraphService):

     def get_input_data_dict(input_data):
         for key, value in input_data.items():

+ 2 - 2
core/main/services/__init__.py

@@ -1,6 +1,6 @@
 from .auth_service import AuthService
+from .graph_service import GraphService
 from .ingestion_service import IngestionService, IngestionServiceAdapter
-from .kg_service import KgService
 from .management_service import ManagementService
 from .retrieval_service import RetrievalService

@@ -9,6 +9,6 @@ __all__ = [
     "IngestionService",
     "IngestionServiceAdapter",
     "ManagementService",
-    "KgService",
+    "GraphService",
     "RetrievalService",
 ]
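
Callers that imported the service by its old name must track the rename; illustrative only:

    # before: from core.main.services import KgService
    from core.main.services import GraphService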

+ 46 - 8
core/main/services/auth_service.py

@@ -1,3 +1,4 @@
+import logging
 from datetime import datetime
 from typing import Optional
 from uuid import UUID
@@ -11,6 +12,8 @@ from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
 from ..config import R2RConfig
 from .base import Service

+logger = logging.getLogger()
+

 class AuthService(Service):
     def __init__(
@@ -47,11 +50,6 @@ class AuthService(Service):
        user_id = await self.providers.database.users_handler.get_user_id_by_verification_code(
            verification_code
        )
-        if not user_id:
-            raise R2RException(
-                status_code=400, message="Invalid or expired verification code"
-            )
-
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
@@ -186,12 +184,13 @@ class AuthService(Service):
        )

        try:
-            await self.providers.database.graphs_handler.delete_graph_for_collection(
+            await self.providers.database.graphs_handler.delete(
                collection_id=collection_id,
            )
        except Exception as e:
-            # print(f"Error deleting graph for collection {collection_id}: {e}")
-            pass
+            logger.warning(
+                f"Error deleting graph for collection {collection_id}: {e}"
+            )

        if delete_vector_data:
            await self.providers.database.chunks_handler.delete_user_vector(
@@ -268,3 +267,42 @@ class AuthService(Service):
            dict: Contains verification_code and message
        """
        return await self.providers.auth.send_reset_email(email)
+
+    async def create_user_api_key(self, user_id: UUID) -> dict:
+        """
+        Generate a new API key for the user.
+
+        Args:
+            user_id (UUID): The ID of the user
+
+        Returns:
+            dict: Contains the API key and message
+        """
+        return await self.providers.auth.create_user_api_key(user_id)
+
+    async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> dict:
+        """
+        Delete the API key for the user.
+
+        Args:
+            user_id (UUID): The ID of the user
+            key_id (str): The ID of the API key
+
+        Returns:
+            dict: Contains the message
+        """
+        return await self.providers.auth.delete_user_api_key(
+            user_id=user_id, key_id=key_id
+        )
+
+    async def list_user_api_keys(self, user_id: UUID) -> dict:
+        """
+        List all API keys for the user.
+
+        Args:
+            user_id (UUID): The ID of the user
+
+        Returns:
+            dict: Contains the list of API keys
+        """
+        return await self.providers.auth.list_user_api_keys(user_id)
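
A sketch of how the new API-key methods chain together from a caller's perspective (the auth_service handle and the surrounding setup are assumptions; the docstrings above only promise dicts):

    # Hypothetical usage, assuming an AuthService instance `auth_service`
    # and known `user_id` / `key_id` values:
    new_key = await auth_service.create_user_api_key(user_id=user_id)
    keys = await auth_service.list_user_api_keys(user_id=user_id)
    await auth_service.delete_user_api_key(user_id=user_id, key_id=key_id)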

+ 1081 - 0
core/main/services/graph_service.py

@@ -0,0 +1,1081 @@
+import asyncio
+import json
+import logging
+import math
+import re
+import time
+from typing import Any, AsyncGenerator, Optional
+from uuid import UUID
+
+from core.base import (
+    DocumentChunk,
+    KGExtraction,
+    KGExtractionStatus,
+    R2RDocumentProcessingError,
+    RunManager,
+)
+from core.base.abstractions import (
+    Community,
+    Entity,
+    GenerationConfig,
+    KGCreationSettings,
+    KGEnrichmentSettings,
+    KGEnrichmentStatus,
+    KGEntityDeduplicationSettings,
+    KGEntityDeduplicationType,
+    R2RException,
+    Relationship,
+)
+from core.base.api.models import GraphResponse
+from core.telemetry.telemetry_decorator import telemetry_event
+
+from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
+from ..config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger()
+
+
+MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH = 128
+
+
+async def _collect_results(result_gen: AsyncGenerator) -> list[dict]:
+    results = []
+    async for res in result_gen:
+        results.append(res.json() if hasattr(res, "json") else res)
+    return results
+
+
+# TODO - Fix naming convention to read `KGService` instead of `GraphService`
+# this will require a minor change in how services are registered.
+class GraphService(Service):
+    def __init__(
+        self,
+        config: R2RConfig,
+        providers: R2RProviders,
+        pipes: R2RPipes,
+        pipelines: R2RPipelines,
+        agents: R2RAgents,
+        run_manager: RunManager,
+    ):
+        super().__init__(
+            config,
+            providers,
+            pipes,
+            pipelines,
+            agents,
+            run_manager,
+        )
+
+    @telemetry_event("kg_relationships_extraction")
+    async def kg_relationships_extraction(
+        self,
+        document_id: UUID,
+        generation_config: GenerationConfig,
+        chunk_merge_count: int,
+        max_knowledge_relationships: int,
+        entity_types: list[str],
+        relation_types: list[str],
+        **kwargs,
+    ):
+        try:
+
+            logger.info(
+                f"KGService: Processing document {document_id} for KG extraction"
+            )
+
+            await self.providers.database.documents_handler.set_workflow_status(
+                id=document_id,
+                status_type="extraction_status",
+                status=KGExtractionStatus.PROCESSING,
+            )
+
+            relationships = await self.pipes.graph_extraction_pipe.run(
+                input=self.pipes.graph_extraction_pipe.Input(
+                    message={
+                        "document_id": document_id,
+                        "generation_config": generation_config,
+                        "chunk_merge_count": chunk_merge_count,
+                        "max_knowledge_relationships": max_knowledge_relationships,
+                        "entity_types": entity_types,
+                        "relation_types": relation_types,
+                        "logger": logger,
+                    }
+                ),
+                state=None,
+                run_manager=self.run_manager,
+            )
+
+            logger.info(
+                f"KGService: Finished processing document {document_id} for KG extraction"
+            )
+
+            result_gen = await self.pipes.graph_storage_pipe.run(
+                input=self.pipes.graph_storage_pipe.Input(
+                    message=relationships
+                ),
+                state=None,
+                run_manager=self.run_manager,
+            )
+
+        except Exception as e:
+            logger.error(f"KGService: Error in kg_extraction: {e}")
+            await self.providers.database.documents_handler.set_workflow_status(
+                id=document_id,
+                status_type="extraction_status",
+                status=KGExtractionStatus.FAILED,
+            )
+            raise e
+
+        return await _collect_results(result_gen)
+
+    @telemetry_event("create_entity")
+    async def create_entity(
+        self,
+        name: str,
+        description: str,
+        parent_id: UUID,
+        category: Optional[str] = None,
+        metadata: Optional[dict] = None,
+    ) -> Entity:
+
+        description_embedding = str(
+            await self.providers.embedding.async_get_embedding(description)
+        )
+
+        return await self.providers.database.graphs_handler.entities.create(
+            name=name,
+            parent_id=parent_id,
+            store_type="graphs",  # type: ignore
+            category=category,
+            description=description,
+            description_embedding=description_embedding,
+            metadata=metadata,
+        )
+
+    @telemetry_event("update_entity")
+    async def update_entity(
+        self,
+        entity_id: UUID,
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+        category: Optional[str] = None,
+        metadata: Optional[dict] = None,
+    ) -> Entity:
+
+        description_embedding = None
+        if description is not None:
+            description_embedding = str(
+                await self.providers.embedding.async_get_embedding(description)
+            )
+
+        return await self.providers.database.graphs_handler.entities.update(
+            entity_id=entity_id,
+            store_type="graphs",  # type: ignore
+            name=name,
+            description=description,
+            description_embedding=description_embedding,
+            category=category,
+            metadata=metadata,
+        )
+
+    @telemetry_event("delete_entity")
+    async def delete_entity(
+        self,
+        parent_id: UUID,
+        entity_id: UUID,
+    ):
+        return await self.providers.database.graphs_handler.entities.delete(
+            parent_id=parent_id,
+            entity_ids=[entity_id],
+            store_type="graphs",  # type: ignore
+        )
+
+    @telemetry_event("get_entities")
+    async def get_entities(
+        self,
+        parent_id: UUID,
+        offset: int,
+        limit: int,
+        entity_ids: Optional[list[UUID]] = None,
+        entity_names: Optional[list[str]] = None,
+        include_embeddings: bool = False,
+    ):
+        return await self.providers.database.graphs_handler.get_entities(
+            parent_id=parent_id,
+            offset=offset,
+            limit=limit,
+            entity_ids=entity_ids,
+            entity_names=entity_names,
+            include_embeddings=include_embeddings,
+        )
+
+    @telemetry_event("create_relationship")
+    async def create_relationship(
+        self,
+        subject: str,
+        subject_id: UUID,
+        predicate: str,
+        object: str,
+        object_id: UUID,
+        parent_id: UUID,
+        description: str | None = None,
+        weight: float | None = 1.0,
+        metadata: Optional[dict[str, Any] | str] = None,
+    ) -> Relationship:
+        description_embedding = None
+        if description:
+            description_embedding = str(
+                await self.providers.embedding.async_get_embedding(description)
+            )
+
+        return (
+            await self.providers.database.graphs_handler.relationships.create(
+                subject=subject,
+                subject_id=subject_id,
+                predicate=predicate,
+                object=object,
+                object_id=object_id,
+                parent_id=parent_id,
+                description=description,
+                description_embedding=description_embedding,
+                weight=weight,
+                metadata=metadata,
+                store_type="graphs",  # type: ignore
+            )
+        )
+
+    @telemetry_event("delete_relationship")
+    async def delete_relationship(
+        self,
+        parent_id: UUID,
+        relationship_id: UUID,
+    ):
+        return (
+            await self.providers.database.graphs_handler.relationships.delete(
+                parent_id=parent_id,
+                relationship_ids=[relationship_id],
+                store_type="graphs",  # type: ignore
+            )
+        )
+
+    @telemetry_event("update_relationship")
+    async def update_relationship(
+        self,
+        relationship_id: UUID,
+        subject: Optional[str] = None,
+        subject_id: Optional[UUID] = None,
+        predicate: Optional[str] = None,
+        object: Optional[str] = None,
+        object_id: Optional[UUID] = None,
+        description: Optional[str] = None,
+        weight: Optional[float] = None,
+        metadata: Optional[dict[str, Any] | str] = None,
+    ) -> Relationship:
+
+        description_embedding = None
+        if description is not None:
+            description_embedding = str(
+                await self.providers.embedding.async_get_embedding(description)
+            )
+
+        return (
+            await self.providers.database.graphs_handler.relationships.update(
+                relationship_id=relationship_id,
+                subject=subject,
+                subject_id=subject_id,
+                predicate=predicate,
+                object=object,
+                object_id=object_id,
+                description=description,
+                description_embedding=description_embedding,
+                weight=weight,
+                metadata=metadata,
+                store_type="graphs",  # type: ignore
+            )
+        )
+
+    @telemetry_event("get_relationships")
+    async def get_relationships(
+        self,
+        parent_id: UUID,
+        offset: int,
+        limit: int,
+        relationship_ids: Optional[list[UUID]] = None,
+        entity_names: Optional[list[str]] = None,
+    ):
+        return await self.providers.database.graphs_handler.relationships.get(
+            parent_id=parent_id,
+            store_type="graphs",  # type: ignore
+            offset=offset,
+            limit=limit,
+            relationship_ids=relationship_ids,
+            entity_names=entity_names,
+        )
+
+    @telemetry_event("create_community")
+    async def create_community(
+        self,
+        parent_id: UUID,
+        name: str,
+        summary: str,
+        findings: Optional[list[str]],
+        rating: Optional[float],
+        rating_explanation: Optional[str],
+    ) -> Community:
+        description_embedding = str(
+            await self.providers.embedding.async_get_embedding(summary)
+        )
+        return await self.providers.database.graphs_handler.communities.create(
+            parent_id=parent_id,
+            store_type="graphs",  # type: ignore
+            name=name,
+            summary=summary,
+            description_embedding=description_embedding,
+            findings=findings,
+            rating=rating,
+            rating_explanation=rating_explanation,
+        )
+
+    @telemetry_event("update_community")
+    async def update_community(
+        self,
+        community_id: UUID,
+        name: Optional[str],
+        summary: Optional[str],
+        findings: Optional[list[str]],
+        rating: Optional[float],
+        rating_explanation: Optional[str],
+    ) -> Community:
+        summary_embedding = None
+        if summary is not None:
+            summary_embedding = str(
+                await self.providers.embedding.async_get_embedding(summary)
+            )
+
+        return await self.providers.database.graphs_handler.communities.update(
+            community_id=community_id,
+            store_type="graphs",  # type: ignore
+            name=name,
+            summary=summary,
+            summary_embedding=summary_embedding,
+            findings=findings,
+            rating=rating,
+            rating_explanation=rating_explanation,
+        )
+
+    @telemetry_event("delete_community")
+    async def delete_community(
+        self,
+        parent_id: UUID,
+        community_id: UUID,
+    ) -> None:
+        await self.providers.database.graphs_handler.communities.delete(
+            parent_id=parent_id,
+            community_id=community_id,
+        )
+
+    @telemetry_event("list_communities")
+    async def list_communities(
+        self,
+        collection_id: UUID,
+        offset: int,
+        limit: int,
+    ):
+        return await self.providers.database.graphs_handler.communities.get(
+            parent_id=collection_id,
+            store_type="graphs",  # type: ignore
+            offset=offset,
+            limit=limit,
+        )
+
+    @telemetry_event("get_communities")
+    async def get_communities(
+        self,
+        parent_id: UUID,
+        offset: int,
+        limit: int,
+        community_ids: Optional[list[UUID]] = None,
+        community_names: Optional[list[str]] = None,
+        include_embeddings: bool = False,
+    ):
+        return await self.providers.database.graphs_handler.get_communities(
+            parent_id=parent_id,
+            offset=offset,
+            limit=limit,
+            community_ids=community_ids,
+            include_embeddings=include_embeddings,
+        )
+
+    # @telemetry_event("create_new_graph")
+    # async def create_new_graph(
+    #     self,
+    #     collection_id: UUID,
+    #     user_id: UUID,
+    #     name: Optional[str],
+    #     description: str = "",
+    # ) -> GraphResponse:
+    #     return await self.providers.database.graphs_handler.create(
+    #         collection_id=collection_id,
+    #         user_id=user_id,
+    #         name=name,
+    #         description=description,
+    #         graph_id=collection_id,
+    #     )
+
+    async def list_graphs(
+        self,
+        offset: int,
+        limit: int,
+        # user_ids: Optional[list[UUID]] = None,
+        graph_ids: Optional[list[UUID]] = None,
+        collection_id: Optional[UUID] = None,
+    ) -> dict[str, list[GraphResponse] | int]:
+        return await self.providers.database.graphs_handler.list_graphs(
+            offset=offset,
+            limit=limit,
+            # filter_user_ids=user_ids,
+            filter_graph_ids=graph_ids,
+            filter_collection_id=collection_id,
+        )
+
+    @telemetry_event("update_graph")
+    async def update_graph(
+        self,
+        collection_id: UUID,
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+    ) -> GraphResponse:
+        return await self.providers.database.graphs_handler.update(
+            collection_id=collection_id,
+            name=name,
+            description=description,
+        )
+
+    @telemetry_event("reset_graph_v3")
+    async def reset_graph_v3(self, id: UUID) -> bool:
+        await self.providers.database.graphs_handler.reset(
+            parent_id=id,
+        )
+        await self.providers.database.documents_handler.set_workflow_status(
+            id=id,
+            status_type="graph_cluster_status",
+            status=KGEnrichmentStatus.PENDING,
+        )
+        return True
+
+    @telemetry_event("get_document_ids_for_create_graph")
+    async def get_document_ids_for_create_graph(
+        self,
+        collection_id: UUID,
+        force_kg_creation: bool = False,
+        **kwargs,
+    ):
+
+        document_status_filter = [
+            KGExtractionStatus.PENDING,
+            KGExtractionStatus.FAILED,
+        ]
+        if force_kg_creation:
+            document_status_filter += [
+                KGExtractionStatus.PROCESSING,
+            ]
+
+        return await self.providers.database.documents_handler.get_document_ids_by_status(
+            status_type="extraction_status",
+            status=[str(ele) for ele in document_status_filter],
+            collection_id=collection_id,
+        )
+
+    @telemetry_event("kg_entity_description")
+    async def kg_entity_description(
+        self,
+        document_id: UUID,
+        max_description_input_length: int,
+        **kwargs,
+    ):
+
+        start_time = time.time()
+
+        logger.info(
+            f"KGService: Running kg_entity_description for document {document_id}"
+        )
+
+        entity_count = (
+            await self.providers.database.graphs_handler.get_entity_count(
+                document_id=document_id,
+                distinct=True,
+                entity_table_name="documents_entities",
+            )
+        )
+
+        logger.info(
+            f"KGService: Found {entity_count} entities in document {document_id}"
+        )
+
+        # TODO - Do not hardcode the batch size,
+        # make it a configurable parameter at runtime & server-side defaults
+
+        # process 256 entities at a time
+        num_batches = math.ceil(entity_count / 256)
+        logger.info(
+            f"Calling `kg_entity_description` on document {document_id} with an entity count of {entity_count} and total batches of {num_batches}"
+        )
+        all_results = []
+        for i in range(num_batches):
+            logger.info(
+                f"KGService: Running kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
+            )
+
+            node_descriptions = await self.pipes.graph_description_pipe.run(
+                input=self.pipes.graph_description_pipe.Input(
+                    message={
+                        "offset": i * 256,
+                        "limit": 256,
+                        "max_description_input_length": max_description_input_length,
+                        "document_id": document_id,
+                        "logger": logger,
+                    }
+                ),
+                state=None,
+                run_manager=self.run_manager,
+            )
+
+            all_results.append(await _collect_results(node_descriptions))
+
+            logger.info(
+                f"KGService: Completed kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
+            )
+
+        await self.providers.database.documents_handler.set_workflow_status(
+            id=document_id,
+            status_type="extraction_status",
+            status=KGExtractionStatus.SUCCESS,
+        )
+
+        logger.info(
+            f"KGService: Completed kg_entity_description for document {document_id} in {time.time() - start_time:.2f} seconds",
+        )
+
+        return all_results
+
+    @telemetry_event("kg_clustering")
+    async def kg_clustering(
+        self,
+        collection_id: UUID,
+        # graph_id: UUID,
+        generation_config: GenerationConfig,
+        leiden_params: dict,
+        **kwargs,
+    ):
+
+        logger.info(
+            f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}"
+        )
+
+        clustering_result = await self.pipes.graph_clustering_pipe.run(
+            input=self.pipes.graph_clustering_pipe.Input(
+                message={
+                    "collection_id": collection_id,
+                    "generation_config": generation_config,
+                    "leiden_params": leiden_params,
+                    "logger": logger,
+                    "clustering_mode": self.config.database.graph_creation_settings.clustering_mode,
+                }
+            ),
+            state=None,
+            run_manager=self.run_manager,
+        )
+        return await _collect_results(clustering_result)
+
+    @telemetry_event("kg_community_summary")
+    async def kg_community_summary(
+        self,
+        offset: int,
+        limit: int,
+        max_summary_input_length: int,
+        generation_config: GenerationConfig,
+        collection_id: UUID | None,
+        # graph_id: UUID | None,
+        **kwargs,
+    ):
+        summary_results = await self.pipes.graph_community_summary_pipe.run(
+            input=self.pipes.graph_community_summary_pipe.Input(
+                message={
+                    "offset": offset,
+                    "limit": limit,
+                    "generation_config": generation_config,
+                    "max_summary_input_length": max_summary_input_length,
+                    "collection_id": collection_id,
+                    # "graph_id": graph_id,
+                    "logger": logger,
+                }
+            ),
+            state=None,
+            run_manager=self.run_manager,
+        )
+        return await _collect_results(summary_results)
+
+    @telemetry_event("delete_graph_for_documents")
+    async def delete_graph_for_documents(
+        self,
+        document_ids: list[UUID],
+        **kwargs,
+    ):
+        # TODO: Implement this, as it needs some checks.
+        raise NotImplementedError
+
+    @telemetry_event("delete_graph")
+    async def delete_graph(
+        self,
+        collection_id: UUID,
+        cascade: bool,
+        **kwargs,
+    ):
+        return await self.delete(collection_id=collection_id, cascade=cascade)
+
+    @telemetry_event("delete")
+    async def delete(
+        self,
+        collection_id: UUID,
+        cascade: bool,
+        **kwargs,
+    ):
+        return await self.providers.database.graphs_handler.delete(
+            collection_id=collection_id,
+            cascade=cascade,
+        )
+
+    @telemetry_event("get_creation_estimate")
+    async def get_creation_estimate(
+        self,
+        graph_creation_settings: KGCreationSettings,
+        document_id: Optional[UUID] = None,
+        collection_id: Optional[UUID] = None,
+        **kwargs,
+    ):
+        return (
+            await self.providers.database.graphs_handler.get_creation_estimate(
+                document_id=document_id,
+                collection_id=collection_id,
+                graph_creation_settings=graph_creation_settings,
+            )
+        )
+
+    @telemetry_event("get_enrichment_estimate")
+    async def get_enrichment_estimate(
+        self,
+        collection_id: Optional[UUID] = None,
+        graph_id: Optional[UUID] = None,
+        graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
+        **kwargs,
+    ):
+
+        if graph_id is None and collection_id is None:
+            raise ValueError(
+                "Either graph_id or collection_id must be provided"
+            )
+
+        return await self.providers.database.graphs_handler.get_enrichment_estimate(
+            collection_id=collection_id,
+            graph_id=graph_id,
+            graph_enrichment_settings=graph_enrichment_settings,
+        )
+
+    @telemetry_event("get_deduplication_estimate")
+    async def get_deduplication_estimate(
+        self,
+        collection_id: UUID,
+        kg_deduplication_settings: KGEntityDeduplicationSettings,
+        **kwargs,
+    ):
+        return await self.providers.database.graphs_handler.get_deduplication_estimate(
+            collection_id=collection_id,
+            kg_deduplication_settings=kg_deduplication_settings,
+        )
+
+    @telemetry_event("kg_entity_deduplication")
+    async def kg_entity_deduplication(
+        self,
+        collection_id: UUID,
+        graph_id: UUID,
+        graph_entity_deduplication_type: KGEntityDeduplicationType,
+        graph_entity_deduplication_prompt: str,
+        generation_config: GenerationConfig,
+        **kwargs,
+    ):
+        deduplication_results = await self.pipes.graph_deduplication_pipe.run(
+            input=self.pipes.graph_deduplication_pipe.Input(
+                message={
+                    "collection_id": collection_id,
+                    "graph_id": graph_id,
+                    "graph_entity_deduplication_type": graph_entity_deduplication_type,
+                    "graph_entity_deduplication_prompt": graph_entity_deduplication_prompt,
+                    "generation_config": generation_config,
+                    **kwargs,
+                }
+            ),
+            state=None,
+            run_manager=self.run_manager,
+        )
+        return await _collect_results(deduplication_results)
+
+    @telemetry_event("kg_entity_deduplication_summary")
+    async def kg_entity_deduplication_summary(
+        self,
+        collection_id: UUID,
+        offset: int,
+        limit: int,
+        graph_entity_deduplication_type: KGEntityDeduplicationType,
+        graph_entity_deduplication_prompt: str,
+        generation_config: GenerationConfig,
+        **kwargs,
+    ):
+
+        logger.info(
+            f"Running kg_entity_deduplication_summary for collection {collection_id} with settings {kwargs}"
+        )
+        deduplication_summary_results = await self.pipes.graph_deduplication_summary_pipe.run(
+            input=self.pipes.graph_deduplication_summary_pipe.Input(
+                message={
+                    "collection_id": collection_id,
+                    "offset": offset,
+                    "limit": limit,
+                    "graph_entity_deduplication_type": graph_entity_deduplication_type,
+                    "graph_entity_deduplication_prompt": graph_entity_deduplication_prompt,
+                    "generation_config": generation_config,
+                }
+            ),
+            state=None,
+            run_manager=self.run_manager,
+        )
+
+        return await _collect_results(deduplication_summary_results)
+
+    async def kg_extraction(  # type: ignore
+        self,
+        document_id: UUID,
+        generation_config: GenerationConfig,
+        max_knowledge_relationships: int,
+        entity_types: list[str],
+        relation_types: list[str],
+        chunk_merge_count: int,
+        filter_out_existing_chunks: bool = True,
+        total_tasks: Optional[int] = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[KGExtraction | R2RDocumentProcessingError, None]:
+        start_time = time.time()
+
+        logger.info(
+            f"GraphExtractionPipe: Processing document {document_id} for KG extraction",
+        )
+
+        # Then create the extractions from the results
+        limit = 100
+        offset = 0
+        chunks = []
+        while True:
+            chunk_req = await self.providers.database.chunks_handler.list_document_chunks(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
+                document_id=document_id,
+                offset=offset,
+                limit=limit,
+            )
+
+            chunks.extend(
+                [
+                    DocumentChunk(
+                        id=chunk["id"],
+                        document_id=chunk["document_id"],
+                        owner_id=chunk["owner_id"],
+                        collection_ids=chunk["collection_ids"],
+                        data=chunk["text"],
+                        metadata=chunk["metadata"],
+                    )
+                    for chunk in chunk_req["results"]
+                ]
+            )
+            if len(chunk_req["results"]) < limit:
+                break
+            offset += limit
+
+        logger.info(f"Found {len(chunks)} chunks for document {document_id}")
+        if len(chunks) == 0:
+            logger.info(f"No chunks found for document {document_id}")
+            raise R2RException(
+                message="No chunks found for document",
+                status_code=404,
+            )
+
+        if filter_out_existing_chunks:
+            existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids(
+                document_id=document_id
+            )
+            chunks = [
+                chunk for chunk in chunks if chunk.id not in existing_chunk_ids
+            ]
+            logger.info(
+                f"Filtered out {len(existing_chunk_ids)} existing chunks, remaining {len(chunks)} chunks for document {document_id}"
+            )
+
+            if len(chunks) == 0:
+                logger.info(f"No extractions left for document {document_id}")
+                return
+
+        logger.info(
+            f"GraphExtractionPipe: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds",
+        )
+
+        # sort the extractions according to the chunk_order field in metadata, in ascending order
+        chunks = sorted(
+            chunks,
+            key=lambda x: x.metadata.get("chunk_order", float("inf")),
+        )
+
+        # group these extractions into groups of chunk_merge_count
+        grouped_chunks = [
+            chunks[i : i + chunk_merge_count]
+            for i in range(0, len(chunks), chunk_merge_count)
+        ]
+
+        logger.info(
+            f"GraphExtractionPipe: Extracting KG Relationships for document and created {len(grouped_chunks)} tasks, time from start: {time.time() - start_time:.2f} seconds",
+        )
+
+        tasks = [
+            asyncio.create_task(
+                self._extract_kg(
+                    chunks=chunk_group,
+                    generation_config=generation_config,
+                    max_knowledge_relationships=max_knowledge_relationships,
+                    entity_types=entity_types,
+                    relation_types=relation_types,
+                    task_id=task_id,
+                    total_tasks=len(grouped_chunks),
+                )
+            )
+            for task_id, chunk_group in enumerate(grouped_chunks)
+        ]
+
+        completed_tasks = 0
+        total_tasks = len(tasks)
+
+        logger.info(
+            f"GraphExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
+        )
+
+        for completed_task in asyncio.as_completed(tasks):
+            try:
+                yield await completed_task
+                completed_tasks += 1
+                if completed_tasks % 100 == 0:
+                    logger.info(
+                        f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
+                    )
+            except Exception as e:
+                logger.error(f"Error in Extracting KG Relationships: {e}")
+                yield R2RDocumentProcessingError(
+                    document_id=document_id,
+                    error_message=str(e),
+                )
+
+        logger.info(
+            f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
+        )
+
+    async def _extract_kg(
+        self,
+        chunks: list[DocumentChunk],
+        generation_config: GenerationConfig,
+        max_knowledge_relationships: int,
+        entity_types: list[str],
+        relation_types: list[str],
+        retries: int = 5,
+        delay: int = 2,
+        task_id: Optional[int] = None,
+        total_tasks: Optional[int] = None,
+    ) -> KGExtraction:
+        """
+        Extracts entities and relationships from an extraction, with retries.
+        """
+
+        # combine all extractions into a single string
+        combined_extraction: str = " ".join([chunk.data for chunk in chunks])  # type: ignore
+
+        response = await self.providers.database.documents_handler.get_documents_overview(  # type: ignore
+            offset=0,
+            limit=1,
+            filter_document_ids=[chunks[0].document_id],
+        )
+        document_summary = (
+            response["results"][0].summary if response["results"] else None
+        )
+
+        messages = await self.providers.database.prompts_handler.get_message_payload(
+            task_prompt_name=self.providers.database.config.graph_creation_settings.graphrag_relationships_extraction_few_shot,
+            task_inputs={
+                "document_summary": document_summary,
+                "input": combined_extraction,
+                "max_knowledge_relationships": max_knowledge_relationships,
+                "entity_types": "\n".join(entity_types),
+                "relation_types": "\n".join(relation_types),
+            },
+        )
+
+        for attempt in range(retries):
+            try:
+                response = await self.providers.llm.aget_completion(
+                    messages,
+                    generation_config=generation_config,
+                )
+
+                kg_extraction = response.choices[0].message.content
+
+                if not kg_extraction:
+                    raise R2RException(
+                        "No knowledge graph extraction found in the response string, the selected LLM likely failed to format its response correctly.",
+                        400,
+                    )
+
+                entity_pattern = (
+                    r'\("entity"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\)'
+                )
+                relationship_pattern = r'\("relationship"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}(\d+(?:\.\d+)?)\)'
+
+                async def parse_fn(response_str: str) -> Any:
+                    entities = re.findall(entity_pattern, response_str)
+
+                    if (
+                        len(kg_extraction)
+                        > MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH
+                        and len(entities) == 0
+                    ):
+                        raise R2RException(
+                            f"No entities found in the response string, the selected LLM likely failed to format its response correctly. {response_str}",
+                            400,
+                        )
+
+                    relationships = re.findall(
+                        relationship_pattern, response_str
+                    )
+
+                    entities_arr = []
+                    for entity in entities:
+                        entity_value = entity[0]
+                        entity_category = entity[1]
+                        entity_description = entity[2]
+                        description_embedding = (
+                            await self.providers.embedding.async_get_embedding(
+                                entity_description
+                            )
+                        )
+                        entities_arr.append(
+                            Entity(
+                                category=entity_category,
+                                description=entity_description,
+                                name=entity_value,
+                                parent_id=chunks[0].document_id,
+                                chunk_ids=[chunk.id for chunk in chunks],
+                                description_embedding=description_embedding,
+                                attributes={},
+                            )
+                        )
+
+                    relations_arr = []
+                    for relationship in relationships:
+                        subject = relationship[0]
+                        object = relationship[1]
+                        predicate = relationship[2]
+                        description = relationship[3]
+                        weight = float(relationship[4])
+                        relationship_embedding = (
+                            await self.providers.embedding.async_get_embedding(
+                                description
+                            )
+                        )
+
+                        # check if subject and object are in entities_dict
+                        relations_arr.append(
+                            Relationship(
+                                subject=subject,
+                                predicate=predicate,
+                                object=object,
+                                description=description,
+                                weight=weight,
+                                parent_id=chunks[0].document_id,
+                                chunk_ids=[chunk.id for chunk in chunks],
+                                attributes={},
+                                description_embedding=relationship_embedding,
+                            )
+                        )
+
+                    return entities_arr, relations_arr
+
+                entities, relationships = await parse_fn(kg_extraction)
+                return KGExtraction(
+                    entities=entities,
+                    relationships=relationships,
+                )
+
+            except (
+                Exception,
+                json.JSONDecodeError,
+                KeyError,
+                IndexError,
+                R2RException,
+            ) as e:
+                if attempt < retries - 1:
+                    await asyncio.sleep(delay)
+                else:
+                    logger.warning(
+                        f"Failed after retries with for chunk {chunks[0].id} of document {chunks[0].document_id}: {e}"
+                    )
+
+        logger.info(
+            f"GraphExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
+        )
+
+        return KGExtraction(
+            entities=[],
+            relationships=[],
+        )
+
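A minimal, runnable sketch of how the relationship pattern above matches one delimited line. The sample string is an assumption about the expected format (subject, object, predicate, description, weight joined by "$$$$", inferred from the assignment order in parse_fn), not an output of R2R's actual prompts.

import re

relationship_pattern = (
    r'\("relationship"\${4}([^$]+)\${4}([^$]+)\${4}'
    r'([^$]+)\${4}([^$]+)\${4}(\d+(?:\.\d+)?)\)'
)
sample = '("relationship"$$$$Marie Curie$$$$Pierre Curie$$$$married_to$$$$They married in 1895$$$$0.9)'
print(re.findall(relationship_pattern, sample))
# [('Marie Curie', 'Pierre Curie', 'married_to', 'They married in 1895', '0.9')]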
+    async def store_kg_extractions(
+        self,
+        kg_extractions: list[KGExtraction],
+    ):
+        """
+        Stores a batch of knowledge graph extractions in the graph database.
+        """
+
+        for extraction in kg_extractions:
+            entities_id_map = {}
+            for entity in extraction.entities:
+                result = await self.providers.database.graphs_handler.entities.create(
+                    name=entity.name,
+                    parent_id=entity.parent_id,
+                    store_type="documents",  # type: ignore
+                    category=entity.category,
+                    description=entity.description,
+                    description_embedding=entity.description_embedding,
+                    chunk_ids=entity.chunk_ids,
+                    metadata=entity.metadata,
+                )
+                entities_id_map[entity.name] = result.id
+
+            if extraction.relationships:
+
+                for relationship in extraction.relationships:
+                    await self.providers.database.graphs_handler.relationships.create(
+                        subject=relationship.subject,
+                        subject_id=entities_id_map.get(relationship.subject),
+                        predicate=relationship.predicate,
+                        object=relationship.object,
+                        object_id=entities_id_map.get(relationship.object),
+                        parent_id=relationship.parent_id,
+                        description=relationship.description,
+                        description_embedding=relationship.description_embedding,
+                        weight=relationship.weight,
+                        metadata=relationship.metadata,
+                        store_type="documents",  # type: ignore
+                    )
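A toy illustration of the two-pass store above: entities are created first and their returned ids collected in a name-to-id map, so each relationship's subject and object names can then be resolved to ids. The names and stand-in ids below are made up.

entities_id_map = {}
for i, name in enumerate(["Marie Curie", "Pierre Curie"]):
    entities_id_map[name] = f"entity-{i}"  # stand-in for the id returned by entities.create

relationship = {"subject": "Marie Curie", "object": "Pierre Curie"}
print(entities_id_map.get(relationship["subject"]))  # entity-0
print(entities_id_map.get(relationship["object"]))   # entity-1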

+ 169 - 238
core/main/services/management_service.py

@@ -1,15 +1,14 @@
 import logging
 import os
 from collections import defaultdict
-from copy import copy
 from typing import Any, BinaryIO, Optional, Tuple
 from uuid import UUID
 
 import toml
-from fastapi.responses import StreamingResponse
 
 from core.base import (
     CollectionResponse,
+    ConversationResponse,
     DocumentResponse,
     GenerationConfig,
     KGEnrichmentStatus,
@@ -19,8 +18,6 @@ from core.base import (
     RunManager,
     User,
 )
-from core.base.logger.base import RunType
-from core.base.utils import validate_uuid
 from core.telemetry.telemetry_decorator import telemetry_event
 
 from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
@@ -69,8 +66,6 @@ class ManagementService(Service):
         offset: int,
         limit: int,
         user_ids: Optional[list[UUID]] = None,
-        *args,
-        **kwargs,
     ):
         return await self.providers.database.users_handler.get_users_overview(
             offset=offset,
@@ -78,212 +73,140 @@ class ManagementService(Service):
             user_ids=user_ids,
         )
 
-    @telemetry_event("Delete")
-    async def delete(
+    async def delete_documents_and_chunks_by_filter(
         self,
         filters: dict[str, Any],
-        *args,
-        **kwargs,
     ):
         """
-        Takes a list of filters like
-        "{key: {operator: value}, key: {operator: value}, ...}"
-        and deletes entries matching the given filters from both vector and relational databases.
+        Delete chunks matching the given filters. If any documents are now empty
+        (i.e., have no remaining chunks), delete those documents as well.
 
-        NOTE: This method is not atomic and may result in orphaned entries in the documents overview table.
-        NOTE: This method assumes that filters delete entire contents of any touched documents.
+        Args:
+            filters (dict[str, Any]): Filters specifying which chunks to delete.
+
+        Returns:
+            dict: A summary of what was deleted.
         """
         """
-        ### TODO - FIX THIS, ENSURE THAT DOCUMENTS OVERVIEW IS CLEARED
-
-        def validate_filters(filters: dict[str, Any]) -> None:
-            ALLOWED_FILTERS = {
-                "id",
-                "collection_ids",
-                "chunk_id",
-                # TODO - Modify these checks such that they can be used PROPERLY for nested filters
-                "$and",
-                "$or",
-            }
 
-            if not filters:
-                raise R2RException(
-                    status_code=422, message="No filters provided"
-                )
+        def transform_chunk_id_to_id(
+            filters: dict[str, Any]
+        ) -> dict[str, Any]:
+            """
+            Example transformation function if your filters use `chunk_id` instead of `id`.
+            Recursively transform `chunk_id` to `id`.
+            """
+            if isinstance(filters, dict):
+                transformed = {}
+                for key, value in filters.items():
+                    if key == "chunk_id":
+                        transformed["id"] = value
+                    elif key in ["$and", "$or"]:
+                        transformed[key] = [
+                            transform_chunk_id_to_id(item) for item in value
+                        ]
+                    else:
+                        transformed[key] = transform_chunk_id_to_id(value)
+                return transformed
+            return filters
+
+        # 1. (Optional) Validate the input filters based on your rules.
+        #    E.g., check if filters is not empty, allowed fields, etc.
+        # validate_filters(filters)
+
+        # 2. Transform filters if needed.
+        #    For example, if `chunk_id` is used, map it to `id`, or similar transformations.
+        transformed_filters = transform_chunk_id_to_id(filters)
+
+        # 3. First, find out which chunks match these filters *before* deleting, so we know which docs are affected.
+        #    You can do a list operation on chunks to see which chunk IDs and doc IDs would be hit.
+        interim_results = (
+            await self.providers.database.chunks_handler.list_chunks(
+                filters=transformed_filters,
+                offset=0,
+                limit=1_000,  # Arbitrary large limit or pagination logic
+                include_vectors=False,
+            )
+        )
 
-            for field in filters:
-                if field not in ALLOWED_FILTERS:
-                    raise R2RException(
-                        status_code=422,
-                        message=f"Invalid filter field: {field}",
-                    )
+        if interim_results["page_info"]["total_entries"] == 0:
+            raise R2RException(
+                status_code=404, message="No entries found for deletion."
+            )
 
-            for field in ["document_id", "owner_id", "chunk_id"]:
-                if field in filters:
-                    op = next(iter(filters[field].keys()))
-                    try:
-                        validate_uuid(filters[field][op])
-                    except ValueError:
-                        raise R2RException(
-                            status_code=422,
-                            message=f"Invalid UUID: {filters[field][op]}",
-                        )
+        results = interim_results["results"]
+        while interim_results["page_info"]["total_entries"] == 1_000:
+            # If we hit the limit, we need to paginate to get all results
 
-            if "collection_ids" in filters:
-                op = next(iter(filters["collection_ids"].keys()))
-                for id_str in filters["collection_ids"][op]:
-                    try:
-                        validate_uuid(id_str)
-                    except ValueError:
-                        raise R2RException(
-                            status_code=422, message=f"Invalid UUID: {id_str}"
-                        )
+            interim_results = (
+                await self.providers.database.chunks_handler.list_chunks(
+                    filters=transformed_filters,
+                    offset=interim_results["offset"] + 1_000,
+                    limit=1_000,
+                    include_vectors=False,
+                )
+            )
+            results.extend(interim_results["results"])
+        matched_chunk_docs = {UUID(chunk["document_id"]) for chunk in results}
 
-        validate_filters(filters)
+        # If no chunks match, raise or return a no-op result
+        if not matched_chunk_docs:
+            return {
+                "success": False,
+                "message": "No chunks match the given filters.",
+            }
 
-        logger.info(f"Deleting entries with filters: {filters}")
+        # 4. Delete the matching chunks from the database.
+        delete_results = await self.providers.database.chunks_handler.delete(
+            transformed_filters
+        )
 
-        try:
+        # 5. From `delete_results`, extract the document_ids that were affected.
+        #    The delete_results should map chunk_id to details including `document_id`.
+        affected_doc_ids = {
+            UUID(info["document_id"])
+            for info in delete_results.values()
+            if info.get("document_id")
+        }
 
-            def transform_chunk_id_to_id(
-                filters: dict[str, Any]
-            ) -> dict[str, Any]:
-                if isinstance(filters, dict):
-                    transformed = {}
-                    for key, value in filters.items():
-                        if key == "chunk_id":
-                            transformed["id"] = value
-                        elif key in ["$and", "$or"]:
-                            transformed[key] = [
-                                transform_chunk_id_to_id(item)
-                                for item in value
-                            ]
-                        else:
-                            transformed[key] = transform_chunk_id_to_id(value)
-                    return transformed
-                return filters
-
-            filters_xf = transform_chunk_id_to_id(copy(filters))
-
-            await self.providers.database.chunks_handler.delete(filters)
-
-            vector_delete_results = (
-                await self.providers.database.chunks_handler.delete(filters_xf)
+        # 6. For each affected document, check if the document still has any chunks left.
+        docs_to_delete = []
+        for doc_id in affected_doc_ids:
+            remaining = await self.providers.database.chunks_handler.list_document_chunks(
+                document_id=doc_id,
+                offset=0,
+                limit=1,  # Just need to know if there's at least one left
+                include_vectors=False,
             )
-        except Exception as e:
-            logger.error(f"Error deleting from vector database: {e}")
-            vector_delete_results = {}
-
-        document_ids_to_purge: set[UUID] = set()
-        if vector_delete_results:
-            document_ids_to_purge.update(
-                UUID(result.get("document_id"))
-                for result in vector_delete_results.values()
-                if result.get("document_id")
+            # If no remaining chunks, we should delete the document.
+            if remaining["total_entries"] == 0:
+                docs_to_delete.append(doc_id)
+
+        # 7. Delete documents that no longer have associated chunks.
+        #    Also update graphs if needed (entities/relationships).
+        for doc_id in docs_to_delete:
+            # Delete related entities & relationships if needed:
+            await self.providers.database.graphs_handler.entities.delete(
+                parent_id=doc_id, store_type="documents"
+            )
+            await self.providers.database.graphs_handler.relationships.delete(
+                parent_id=doc_id, store_type="documents"
             )
 
-        # TODO: This might be appropriate to move elsewhere and revisit filter logic in other methods
-        def extract_filters(filters: dict[str, Any]) -> dict[str, list[str]]:
-            relational_filters: dict = {}
-
-            def process_filter(filter_dict: dict[str, Any]):
-                if "document_id" in filter_dict:
-                    relational_filters.setdefault(
-                        "filter_document_ids", []
-                    ).append(filter_dict["document_id"]["$eq"])
-                if "owner_id" in filter_dict:
-                    relational_filters.setdefault(
-                        "filter_user_ids", []
-                    ).append(filter_dict["owner_id"]["$eq"])
-                if "collection_ids" in filter_dict:
-                    relational_filters.setdefault(
-                        "filter_collection_ids", []
-                    ).extend(filter_dict["collection_ids"]["$in"])
-
-            # Handle nested conditions
-            if "$and" in filters:
-                for condition in filters["$and"]:
-                    process_filter(condition)
-            elif "$or" in filters:
-                for condition in filters["$or"]:
-                    process_filter(condition)
-            else:
-                process_filter(filters)
-
-            return relational_filters
-
-        relational_filters = extract_filters(filters)
-        if relational_filters:
-            try:
-                documents_overview = (
-                    await self.providers.database.documents_handler.get_documents_overview(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
-                        offset=0,
-                        limit=1000,
-                        **relational_filters,  # type: ignore
-                    )
-                )["results"]
-            except Exception as e:
-                logger.error(
-                    f"Error fetching documents from relational database: {e}"
-                )
-                documents_overview = []
-
-            if documents_overview:
-                document_ids_to_purge.update(
-                    doc.id for doc in documents_overview
-                )
-
-            if not document_ids_to_purge:
-                raise R2RException(
-                    status_code=404, message="No entries found for deletion."
-                )
-
-            for document_id in document_ids_to_purge:
-                remaining_chunks = await self.providers.database.chunks_handler.list_document_chunks(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
-                    document_id=document_id,
-                    offset=0,
-                    limit=1000,
-                )
-                if remaining_chunks["total_entries"] == 0:
-                    try:
-                        await self.providers.database.chunks_handler.delete(
-                            {"document_id": {"$eq": document_id}}
-                        )
-                        logger.info(
-                            f"Deleted document ID {document_id} from documents_overview."
-                        )
-                    except Exception as e:
-                        logger.error(
-                            f"Error deleting document ID {document_id} from documents_overview: {e}"
-                        )
-                await self.providers.database.graphs_handler.entities.delete(
-                    parent_id=document_id,
-                    store_type="documents",  # type: ignore
-                )
-                await self.providers.database.graphs_handler.relationships.delete(
-                    parent_id=document_id,
-                    store_type="documents",  # type: ignore
-                )
-                await self.providers.database.documents_handler.delete(
-                    document_id=document_id
-                )
-
-                collections = await self.providers.database.collections_handler.get_collections_overview(
-                    offset=0, limit=1000, filter_document_ids=[document_id]
-                )
-                # TODO - Loop over all collections
-                for collection in collections["results"]:
-                    await self.providers.database.documents_handler.set_workflow_status(
-                        id=collection.id,
-                        status_type="graph_sync_status",
-                        status=KGEnrichmentStatus.OUTDATED,
-                    )
-                    await self.providers.database.documents_handler.set_workflow_status(
-                        id=collection.id,
-                        status_type="graph_cluster_status",
-                        status=KGEnrichmentStatus.OUTDATED,
-                    )
+            # Finally, delete the document from documents_overview:
+            await self.providers.database.documents_handler.delete(
+                document_id=doc_id
+            )
 
-        return None
+        # 8. Return a summary of what happened.
+        return {
+            "success": True,
+            "deleted_chunks_count": len(delete_results),
+            "deleted_documents_count": len(docs_to_delete),
+            "deleted_document_ids": [str(d) for d in docs_to_delete],
+        }
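A hedged usage sketch of the method above; the service handle and the UUID are assumptions for illustration. Note that a chunk_id filter is transparently rewritten to id before deletion.

summary = await management_service.delete_documents_and_chunks_by_filter(
    filters={"chunk_id": {"$eq": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1"}}
)
if summary["success"]:
    print(
        f"Deleted {summary['deleted_chunks_count']} chunks and "
        f"{summary['deleted_documents_count']} now-empty documents"
    )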
 
     @telemetry_event("DownloadFile")
     async def download_file(
@@ -303,8 +226,6 @@ class ManagementService(Service):
         user_ids: Optional[list[UUID]] = None,
         collection_ids: Optional[list[UUID]] = None,
         document_ids: Optional[list[UUID]] = None,
-        *args: Any,
-        **kwargs: Any,
     ):
         return await self.providers.database.documents_handler.get_documents_overview(
             offset=offset,
@@ -321,8 +242,6 @@ class ManagementService(Service):
         offset: int,
         limit: int,
         include_vectors: bool = False,
-        *args,
-        **kwargs,
     ):
         return (
             await self.providers.database.chunks_handler.list_document_chunks(
@@ -366,9 +285,9 @@ class ManagementService(Service):
         await self.providers.database.chunks_handler.remove_document_from_collection_vector(
             document_id, collection_id
         )
-        await self.providers.database.graphs_handler.delete_node_via_document_id(
-            document_id, collection_id
-        )
+        # await self.providers.database.graphs_handler.delete_node_via_document_id(
+        #     document_id, collection_id
+        # )
         return None
 
     def _process_relationships(
@@ -475,7 +394,6 @@ class ManagementService(Service):
             name=name,
             description=description,
         )
-
         return result
 
     @telemetry_event("UpdateCollection")
@@ -674,26 +592,23 @@ class ManagementService(Service):
     @telemetry_event("GetConversation")
     async def get_conversation(
         self,
-        conversation_id: str,
-        auth_user=None,
+        conversation_id: UUID,
+        user_ids: Optional[list[UUID]] = None,
     ) -> Tuple[str, list[Message], list[dict]]:
-        return await self.providers.database.conversations_handler.get_conversation(  # type: ignore
-            conversation_id
-        )
-
-    async def verify_conversation_access(
-        self, conversation_id: str, user_id: UUID
-    ) -> bool:
-        return await self.providers.database.conversations_handler.verify_conversation_access(
-            conversation_id, user_id
+        return await self.providers.database.conversations_handler.get_conversation(
+            conversation_id=conversation_id,
+            filter_user_ids=user_ids,
         )
 
     @telemetry_event("CreateConversation")
     async def create_conversation(
-        self, user_id: Optional[UUID] = None, auth_user=None
-    ) -> dict:
-        return await self.providers.database.conversations_handler.create_conversation(  # type: ignore
-            user_id=user_id
+        self,
+        user_id: Optional[UUID] = None,
+        name: Optional[str] = None,
+    ) -> ConversationResponse:
+        return await self.providers.database.conversations_handler.create_conversation(
+            user_id=user_id,
+            name=name,
         )
 
     @telemetry_event("ConversationsOverview")
@@ -702,53 +617,69 @@ class ManagementService(Service):
         offset: int,
         limit: int,
         conversation_ids: Optional[list[UUID]] = None,
-        user_ids: Optional[UUID | list[UUID]] = None,
-        auth_user=None,
+        user_ids: Optional[list[UUID]] = None,
     ) -> dict[str, list[dict] | int]:
         return await self.providers.database.conversations_handler.get_conversations_overview(
             offset=offset,
             limit=limit,
-            user_ids=user_ids,
+            filter_user_ids=user_ids,
             conversation_ids=conversation_ids,
         )
 
     @telemetry_event("AddMessage")
     async def add_message(
         self,
-        conversation_id: str,
+        conversation_id: UUID,
         content: Message,
-        parent_id: Optional[str] = None,
+        parent_id: Optional[UUID] = None,
         metadata: Optional[dict] = None,
-        auth_user=None,
     ) -> str:
         return await self.providers.database.conversations_handler.add_message(
-            conversation_id, content, parent_id, metadata
+            conversation_id=conversation_id,
+            content=content,
+            parent_id=parent_id,
+            metadata=metadata,
         )
 
     @telemetry_event("EditMessage")
     async def edit_message(
         self,
-        message_id: str,
-        new_content: str,
-        additional_metadata: dict,
-        auth_user=None,
-    ) -> Tuple[str, str]:
+        message_id: UUID,
+        new_content: Optional[str] = None,
+        additional_metadata: Optional[dict] = None,
+    ) -> dict[str, Any]:
         return (
             await self.providers.database.conversations_handler.edit_message(
-                message_id, new_content, additional_metadata
+                message_id=message_id,
+                new_content=new_content,
+                additional_metadata=additional_metadata or {},
             )
         )
 
-    @telemetry_event("updateMessageMetadata")
-    async def update_message_metadata(
-        self, message_id: str, metadata: dict, auth_user=None
-    ):
-        await self.providers.database.conversations_handler.update_message_metadata(
-            message_id, metadata
+    @telemetry_event("UpdateConversation")
+    async def update_conversation(
+        self, conversation_id: UUID, name: str
+    ) -> ConversationResponse:
+        return await self.providers.database.conversations_handler.update_conversation(
+            conversation_id=conversation_id, name=name
         )
 
     @telemetry_event("DeleteConversation")
-    async def delete_conversation(self, conversation_id: str, auth_user=None):
+    async def delete_conversation(
+        self,
+        conversation_id: UUID,
+        user_ids: Optional[list[UUID]] = None,
+    ) -> None:
         await self.providers.database.conversations_handler.delete_conversation(
-            conversation_id
+            conversation_id=conversation_id,
+            filter_user_ids=user_ids,
         )
+
+    async def get_user_max_documents(self, user_id: UUID) -> int:
+        return self.config.app.default_max_documents_per_user
+
+    async def get_user_max_chunks(self, user_id: UUID) -> int:
+        return self.config.app.default_max_chunks_per_user
+
+    async def get_user_max_collections(self, user_id: UUID) -> int:
+        return self.config.app.default_max_collections_per_user
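A hedged sketch of the reworked conversation API above; the service handle, the user object, and the response fields (.id) are assumptions for illustration, and Message is the type imported at the top of this file.

conversation = await management_service.create_conversation(
    user_id=user.id, name="Support thread"
)
message_id = await management_service.add_message(
    conversation_id=conversation.id,
    content=Message(role="user", content="Hello!"),
)
await management_service.update_conversation(
    conversation_id=conversation.id, name="Renamed thread"
)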

+ 7 - 13
core/main/services/retrieval_service.py

@@ -9,19 +9,15 @@ from fastapi import HTTPException
 from core import R2RStreamingRAGAgent
 from core.base import (
     DocumentResponse,
-    EmbeddingPurpose,
     GenerationConfig,
-    GraphSearchSettings,
     Message,
     R2RException,
     RunManager,
-    SearchMode,
     SearchSettings,
     manage_run,
     to_async_generator,
 )
 from core.base.api.models import CombinedSearchResponse, RAGResponse, User
-from core.base.logger.base import RunType
 from core.telemetry.telemetry_decorator import telemetry_event
 from shared.api.models.management.responses import MessageResponse
 
@@ -59,7 +55,7 @@ class RetrievalService(Service):
         *args,
         **kwargs,
     ) -> CombinedSearchResponse:
-        async with manage_run(self.run_manager, RunType.RETRIEVAL) as run_id:
+        async with manage_run(self.run_manager) as run_id:
             t0 = time.time()
 
             if (
@@ -149,7 +145,7 @@ class RetrievalService(Service):
         *args,
         **kwargs,
     ) -> RAGResponse:
-        async with manage_run(self.run_manager, RunType.RETRIEVAL) as run_id:
+        async with manage_run(self.run_manager) as run_id:
             try:
                 # TODO - Remove these transforms once we have a better way to handle this
                 for (
@@ -245,7 +241,7 @@ class RetrievalService(Service):
         message: Optional[Message] = None,
         messages: Optional[list[Message]] = None,
     ):
-        async with manage_run(self.run_manager, RunType.RETRIEVAL) as run_id:
+        async with manage_run(self.run_manager) as run_id:
             try:
                 if message and messages:
                     raise R2RException(
@@ -307,17 +303,15 @@ class RetrievalService(Service):
 
                 if conversation_id:  # Fetch the existing conversation
                     try:
-                        conversation = await self.providers.database.conversations_handler.get_conversations_overview(
-                            offset=0,
-                            limit=1,
-                            conversation_ids=[conversation_id],
+                        conversation_messages = await self.providers.database.conversations_handler.get_conversation(
+                            conversation_id=conversation_id,
                         )
                     except Exception as e:
                         logger.error(f"Error fetching conversation: {str(e)}")
 
-                    if conversation is not None:
+                    if conversation_messages is not None:
                         messages_from_conversation: list[Message] = []
-                        for message_response in conversation:
+                        for message_response in conversation_messages:
                             if isinstance(message_response, MessageResponse):
                                 messages_from_conversation.append(
                                     message_response.message

+ 14 - 0
core/parsers/__init__.py

@@ -5,17 +5,31 @@ from .text import *
 __all__ = [
     # Media parsers
     "AudioParser",
+    "BMPParser",
+    "DOCParser",
     "DOCXParser",
     "DOCXParser",
     "ImageParser",
     "ImageParser",
+    "ODTParser",
     "VLMPDFParser",
     "VLMPDFParser",
     "BasicPDFParser",
     "BasicPDFParser",
     "PDFParserUnstructured",
     "PDFParserUnstructured",
     "VLMPDFParser",
     "VLMPDFParser",
     "PPTParser",
     "PPTParser",
+    "PPTXParser",
+    "RTFParser",
     # Structured parsers
     "CSVParser",
     "CSVParserAdvanced",
+    "EMLParser",
+    "EPUBParser",
     "JSONParser",
     "JSONParser",
+    "MSGParser",
+    "ORGParser",
+    "P7SParser",
+    "RSTParser",
+    "TIFFParser",
+    "TSVParser",
+    "XLSParser",
     "XLSXParser",
     "XLSXParser",
     "XLSXParserAdvanced",
     "XLSXParserAdvanced",
     # Text parsers
     # Text parsers

+ 12 - 5
core/parsers/media/__init__.py

@@ -1,19 +1,26 @@
+# type: ignore
 from .audio_parser import AudioParser
+from .bmp_parser import BMPParser
+from .doc_parser import DOCParser
 from .docx_parser import DOCXParser
 from .img_parser import ImageParser
-from .pdf_parser import (  # type: ignore
-    BasicPDFParser,
-    PDFParserUnstructured,
-    VLMPDFParser,
-)
+from .odt_parser import ODTParser
+from .pdf_parser import BasicPDFParser, PDFParserUnstructured, VLMPDFParser
 from .ppt_parser import PPTParser
+from .pptx_parser import PPTXParser
+from .rtf_parser import RTFParser
 
 __all__ = [
     "AudioParser",
+    "BMPParser",
+    "DOCParser",
     "DOCXParser",
     "DOCXParser",
     "ImageParser",
     "ImageParser",
+    "ODTParser",
     "VLMPDFParser",
     "VLMPDFParser",
     "BasicPDFParser",
     "BasicPDFParser",
     "PDFParserUnstructured",
     "PDFParserUnstructured",
     "PPTParser",
     "PPTParser",
+    "PPTXParser",
+    "RTFParser",
 ]

+ 0 - 1
core/parsers/media/audio_parser.py

@@ -1,4 +1,3 @@
-import base64
 import logging
 import os
 import tempfile

+ 74 - 0
core/parsers/media/bmp_parser.py

@@ -0,0 +1,74 @@
+# type: ignore
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+    CompletionProvider,
+    DatabaseProvider,
+    IngestionConfig,
+)
+
+
+class BMPParser(AsyncParser[str | bytes]):
+    """A parser for BMP image data."""
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: DatabaseProvider,
+        llm_provider: CompletionProvider,
+    ):
+        self.database_provider = database_provider
+        self.llm_provider = llm_provider
+        self.config = config
+
+        import struct
+
+        self.struct = struct
+
+    async def extract_bmp_metadata(self, data: bytes) -> dict:
+        """Extract metadata from BMP file header."""
+        try:
+            # BMP header format
+            header_format = "<2sIHHI"
+            header_size = self.struct.calcsize(header_format)
+
+            # Unpack header data
+            signature, file_size, reserved, reserved2, data_offset = (
+                self.struct.unpack(header_format, data[:header_size])
+            )
+
+            # DIB header
+            dib_format = "<IiiHHIIiiII"
+            dib_size = self.struct.calcsize(dib_format)
+            dib_data = self.struct.unpack(dib_format, data[14 : 14 + dib_size])
+
+            width = dib_data[1]
+            height = abs(dib_data[2])  # Height can be negative
+            bits_per_pixel = dib_data[4]
+            compression = dib_data[5]
+
+            return {
+                "width": width,
+                "height": height,
+                "bits_per_pixel": bits_per_pixel,
+                "file_size": file_size,
+                "compression": compression,
+            }
+        except Exception as e:
+            return {"error": f"Failed to parse BMP header: {str(e)}"}
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
+        """Ingest BMP data and yield metadata description."""
+        if isinstance(data, str):
+            # Convert base64 string to bytes if needed
+            import base64
+
+            data = base64.b64decode(data)
+
+        metadata = await self.extract_bmp_metadata(data)
+
+        # Generate description of the BMP file
+        yield f"BMP image with dimensions {metadata.get('width', 'unknown')}x{metadata.get('height', 'unknown')} pixels, {metadata.get('bits_per_pixel', 'unknown')} bits per pixel, file size: {metadata.get('file_size', 'unknown')} bytes"
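A minimal sketch of the header layout extract_bmp_metadata relies on: a 14-byte BITMAPFILEHEADER ("<2sIHHI") followed by a 40-byte BITMAPINFOHEADER ("<IiiHHIIiiII"). The packed values below describe a hypothetical 1x1, 24-bit image, not real file data.

import struct

file_header = struct.pack("<2sIHHI", b"BM", 70, 0, 0, 54)
dib_header = struct.pack("<IiiHHIIiiII", 40, 1, 1, 1, 24, 0, 16, 2835, 2835, 0, 0)
data = file_header + dib_header

signature, file_size, _, _, data_offset = struct.unpack("<2sIHHI", data[:14])
dib = struct.unpack("<IiiHHIIiiII", data[14:54])
print(signature, file_size, dib[1], abs(dib[2]), dib[4])  # b'BM' 70 1 1 24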

+ 115 - 0
core/parsers/media/doc_parser.py

@@ -0,0 +1,115 @@
+# type: ignore
+import re
+from io import BytesIO
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+    CompletionProvider,
+    DatabaseProvider,
+    IngestionConfig,
+)
+
+
+class DOCParser(AsyncParser[str | bytes]):
+    """A parser for DOC (legacy Microsoft Word) data."""
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: DatabaseProvider,
+        llm_provider: CompletionProvider,
+    ):
+        self.database_provider = database_provider
+        self.llm_provider = llm_provider
+        self.config = config
+
+        try:
+            import olefile
+
+            self.olefile = olefile
+        except ImportError:
+            raise ImportError(
+                "Error: 'olefile' is required to run DOCParser. "
+                "Please install it using pip: pip install olefile"
+            )
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
+        """Ingest DOC data and yield text from the document."""
+        if isinstance(data, str):
+            raise ValueError("DOC data must be in bytes format.")
+
+        # Create BytesIO object from the data
+        file_obj = BytesIO(data)
+
+        try:
+            # Open the DOC file using olefile
+            ole = self.olefile.OleFileIO(file_obj)
+
+            # Check if it's a Word document
+            if not ole.exists("WordDocument"):
+                raise ValueError("Not a valid Word document")
+
+            # Read the WordDocument stream
+            word_stream = ole.openstream("WordDocument").read()
+
+            # Read the text from the 0Table or 1Table stream (contains the text)
+            if ole.exists("1Table"):
+                table_stream = ole.openstream("1Table").read()
+            elif ole.exists("0Table"):
+                table_stream = ole.openstream("0Table").read()
+            else:
+                table_stream = b""
+
+            # Extract text content
+            text = self._extract_text(word_stream, table_stream)
+
+            # Clean and split the text
+            paragraphs = self._clean_text(text)
+
+            # Yield non-empty paragraphs
+            for paragraph in paragraphs:
+                if paragraph.strip():
+                    yield paragraph.strip()
+
+        except Exception as e:
+            raise ValueError(f"Error processing DOC file: {str(e)}")
+        finally:
+            # Guard: `ole` is unbound if OleFileIO itself raised above.
+            if "ole" in locals():
+                ole.close()
+            file_obj.close()
+
+    def _extract_text(self, word_stream: bytes, table_stream: bytes) -> str:
+        """Extract text from Word document streams."""
+        try:
+            text = word_stream.replace(b"\x00", b"").decode(
+                "utf-8", errors="ignore"
+            )
+
+            # If table_stream exists, try to extract additional text
+            if table_stream:
+                table_text = table_stream.replace(b"\x00", b"").decode(
+                    "utf-8", errors="ignore"
+                )
+                text += table_text
+
+            return text
+        except Exception as e:
+            raise ValueError(f"Error extracting text: {str(e)}")
+
+    def _clean_text(self, text: str) -> list[str]:
+        """Clean and split the extracted text into paragraphs."""
+        # Remove binary artifacts and control characters
+        text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\xFF]", "", text)
+
+        # Remove multiple spaces and newlines
+        text = re.sub(r"\s+", " ", text)
+
+        # Split into paragraphs on double newlines or other common separators
+        paragraphs = re.split(r"\n\n|\r\n\r\n|\f", text)
+
+        # Remove empty or whitespace-only paragraphs
+        paragraphs = [p.strip() for p in paragraphs if p.strip()]
+
+        return paragraphs
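A hedged sketch of the olefile access pattern used above; "example.doc" is a placeholder path, and the listdir() output depends on the file.

import olefile

with open("example.doc", "rb") as f:
    ole = olefile.OleFileIO(f.read())
print(ole.exists("WordDocument"))  # True for Word binary files
print(ole.listdir())               # streams, e.g. [['WordDocument'], ['1Table'], ...]
ole.close()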

+ 75 - 25
core/parsers/media/img_parser.py

@@ -1,5 +1,7 @@
+# type: ignore
 import base64
 import logging
+from io import BytesIO
 from typing import AsyncGenerator
 
 from core.base.abstractions import GenerationConfig
@@ -14,8 +16,6 @@ logger = logging.getLogger()
 
 
 class ImageParser(AsyncParser[str | bytes]):
-    """A parser for image data using vision models."""
-
     def __init__(
         self,
         config: IngestionConfig,
@@ -28,52 +28,104 @@ class ImageParser(AsyncParser[str | bytes]):
         self.vision_prompt_text = None
 
         try:
+            import pillow_heif  # for HEIC support
             from litellm import supports_vision
+            from PIL import Image
 
             self.supports_vision = supports_vision
-        except ImportError:
-            logger.error("Failed to import LiteLLM vision support")
+            self.Image = Image
+            self.pillow_heif = pillow_heif
+            self.pillow_heif.register_heif_opener()
+        except ImportError as e:
+            logger.error(f"Failed to import required packages: {str(e)}")
             raise ImportError(
-                "Please install the `litellm` package to use the ImageParser."
+                "Please install the required packages: litellm, Pillow, pillow-heif"
             )
 
-    async def ingest(  # type: ignore
-        self, data: str | bytes, **kwargs
-    ) -> AsyncGenerator[str, None]:
-        """
-        Ingest image data and yield a description using vision model.
+    def _is_heic(self, data: bytes) -> bool:
+        """More robust HEIC detection using magic numbers and patterns."""
+        heic_patterns = [
+            b"ftyp",
+            b"heic",
+            b"heix",
+            b"hevc",
+            b"HEIC",
+            b"mif1",
+            b"msf1",
+            b"hevc",
+            b"hevx",
+        ]
+
+        # Check for HEIC file signature
+        try:
+            header = data[:32]  # Get first 32 bytes
+            return any(pattern in header for pattern in heic_patterns)
+        except Exception:
+            return False
+
+    async def _convert_heic_to_jpeg(self, data: bytes) -> bytes:
+        """Convert HEIC image to JPEG format."""
+        try:
+            # Create BytesIO object for input
+            input_buffer = BytesIO(data)
+
+            # Load HEIC image using pillow_heif
+            heif_file = self.pillow_heif.read_heif(input_buffer)
+
+            # Get the primary image - API changed, need to get first image
+            heif_image = heif_file[0]  # Get first image in the container
 
-        Args:
-            data: Image data (bytes or base64 string)
-            *args, **kwargs: Additional arguments passed to the completion call
+            # Convert to PIL Image directly from the HEIF image
+            pil_image = heif_image.to_pillow()
 
-        Yields:
-            Chunks of image description text
-        """
+            # Convert to RGB if needed
+            if pil_image.mode != "RGB":
+                pil_image = pil_image.convert("RGB")
+
+            # Save as JPEG
+            output_buffer = BytesIO()
+            pil_image.save(output_buffer, format="JPEG", quality=95)
+            return output_buffer.getvalue()
+
+        except Exception as e:
+            logger.error(f"Error converting HEIC to JPEG: {str(e)}")
+            raise
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
         if not self.vision_prompt_text:
-            self.vision_prompt_text = await self.database_provider.prompts_handler.get_cached_prompt(  # type: ignore
-                prompt_name=self.config.vision_img_prompt_name
+            self.vision_prompt_text = (
+                await self.database_provider.prompts_handler.get_cached_prompt(
+                    prompt_name=self.config.vision_img_prompt_name
+                )
             )
         try:
-            # Verify model supports vision
             if not self.supports_vision(model=self.config.vision_img_model):
                 raise ValueError(
                     f"Model {self.config.vision_img_model} does not support vision"
                 )
 
-            # Encode image data if needed
             if isinstance(data, bytes):
-                image_data = base64.b64encode(data).decode("utf-8")
+                try:
+                    # Check if it's HEIC and convert if necessary
+                    if self._is_heic(data):
+                        logger.debug(
+                            "Detected HEIC format, converting to JPEG"
+                        )
+                        data = await self._convert_heic_to_jpeg(data)
+                    image_data = base64.b64encode(data).decode("utf-8")
+                except Exception as e:
+                    logger.error(f"Error processing image data: {str(e)}")
+                    raise
             else:
             else:
                 image_data = data
                 image_data = data
 
 
-            # Configure the generation parameters
             generation_config = GenerationConfig(
                 model=self.config.vision_img_model,
                 stream=False,
             )
 
-            # Prepare message with image
             messages = [
                 {
                     "role": "user",
@@ -89,12 +141,10 @@ class ImageParser(AsyncParser[str | bytes]):
                 }
             ]
 
-            # Get completion from LiteLLM provider
             response = await self.llm_provider.aget_completion(
                 messages=messages, generation_config=generation_config
             )
 
-            # Extract description from response
             if response.choices and response.choices[0].message:
                 content = response.choices[0].message.content
                 if not content:
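A hedged aside on register_heif_opener(), which __init__ above calls: once registered, Pillow can also open HEIC directly, an alternative to the explicit read_heif() path used in _convert_heic_to_jpeg; "photo.heic" is a placeholder path.

import pillow_heif
from PIL import Image

pillow_heif.register_heif_opener()
img = Image.open("photo.heic").convert("RGB")
img.save("photo.jpg", format="JPEG", quality=95)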

+ 65 - 0
core/parsers/media/odt_parser.py

@@ -0,0 +1,65 @@
+# type: ignore
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+    CompletionProvider,
+    DatabaseProvider,
+    IngestionConfig,
+)
+
+
+class ODTParser(AsyncParser[str | bytes]):
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: DatabaseProvider,
+        llm_provider: CompletionProvider,
+    ):
+        self.database_provider = database_provider
+        self.llm_provider = llm_provider
+        self.config = config
+
+        try:
+            import xml.etree.ElementTree as ET
+            import zipfile
+
+            self.zipfile = zipfile
+            self.ET = ET
+        except ImportError:
+            raise ImportError("XML parsing libraries not available")
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
+        if isinstance(data, str):
+            raise ValueError("ODT data must be in bytes format.")
+
+        from io import BytesIO
+
+        file_obj = BytesIO(data)
+
+        try:
+            with self.zipfile.ZipFile(file_obj) as odt:
+                # ODT files are zip archives containing content.xml
+                content = odt.read("content.xml")
+                root = self.ET.fromstring(content)
+
+                # ODT XML namespace
+                ns = {"text": "urn:oasis:names:tc:opendocument:xmlns:text:1.0"}
+
+                # Extract paragraphs and headers
+                for p in root.findall(".//text:p", ns):
+                    text = "".join(p.itertext())
+                    if text.strip():
+                        yield text.strip()
+
+                for h in root.findall(".//text:h", ns):
+                    text = "".join(h.itertext())
+                    if text.strip():
+                        yield text.strip()
+
+        except Exception as e:
+            raise ValueError(f"Error processing ODT file: {str(e)}")
+        finally:
+            file_obj.close()
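A minimal standalone sketch of the same extraction ("example.odt" is a placeholder path): an ODT file is a zip archive whose body text lives in content.xml under the text: namespace.

import zipfile
import xml.etree.ElementTree as ET

with zipfile.ZipFile("example.odt") as odt:
    root = ET.fromstring(odt.read("content.xml"))
ns = {"text": "urn:oasis:names:tc:opendocument:xmlns:text:1.0"}
for p in root.findall(".//text:p", ns):
    text = "".join(p.itertext())
    if text.strip():
        print(text.strip())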

+ 9 - 4
core/parsers/media/pdf_parser.py

@@ -12,6 +12,7 @@ from typing import AsyncGenerator
 
 import aiofiles
 from pdf2image import convert_from_path
+from pdf2image.exceptions import PDFInfoNotInstalledError
 
 from core.base.abstractions import GenerationConfig
 from core.base.parsers.base_parser import AsyncParser
@@ -20,6 +21,7 @@ from core.base.providers import (
     DatabaseProvider,
     IngestionConfig,
 )
+from shared.abstractions import PDFParsingError, PopperNotFoundError
 
 logger = logging.getLogger()
 
@@ -70,11 +72,14 @@ class VLMPDFParser(AsyncParser[str | bytes]):
             "paths_only": True,
             "paths_only": True,
         }
         }
         try:
         try:
-            image_paths = await asyncio.to_thread(convert_from_path, **options)
-            return image_paths
+            return await asyncio.to_thread(convert_from_path, **options)
+        except PDFInfoNotInstalledError:
+            raise PopperNotFoundError()
         except Exception as err:
-            logger.error(f"Error converting PDF to images: {err}")
-            raise
+            logger.error(
+                f"Error converting PDF to images: {err} type: {type(err)}"
+            )
+            raise PDFParsingError(f"Failed to process PDF: {str(err)}", err)
 
     async def process_page(
         self, image_path: str, page_num: int
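A hedged sketch of the failure mode the new handler distinguishes: pdf2image shells out to Poppler, and PDFInfoNotInstalledError fires when pdfinfo is missing; "doc.pdf" is a placeholder path.

from pdf2image import convert_from_path
from pdf2image.exceptions import PDFInfoNotInstalledError

try:
    paths = convert_from_path("doc.pdf", dpi=150, output_folder="/tmp", paths_only=True)
except PDFInfoNotInstalledError:
    print("Poppler is not installed or not on PATH")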

+ 64 - 11
core/parsers/media/ppt_parser.py

@@ -1,3 +1,5 @@
+# type: ignore
+import struct
 from io import BytesIO
 from typing import AsyncGenerator
 
@@ -10,7 +12,7 @@ from core.base.providers import (
 
 
 class PPTParser(AsyncParser[str | bytes]):
-    """A parser for PPT data."""
+    """A parser for legacy PPT (PowerPoint 97-2003) files."""
 
     def __init__(
         self,
@@ -21,22 +23,73 @@ class PPTParser(AsyncParser[str | bytes]):
         self.database_provider = database_provider
         self.llm_provider = llm_provider
         self.config = config
+
         try:
-            from pptx import Presentation
+            import olefile
 
-            self.Presentation = Presentation
+            self.olefile = olefile
         except ImportError:
-            raise ValueError(
-                "Error, `python-pptx` is required to run `PPTParser`. Please install it using `pip install python-pptx`."
+            raise ImportError(
+                "Error: 'olefile' is required to run PPTParser. "
+                "Please install it using pip: pip install olefile"
             )
 
-    async def ingest(self, data: str | bytes, **kwargs) -> AsyncGenerator[str, None]:  # type: ignore
+    def _extract_text_from_record(self, data: bytes) -> str:
+        """Extract text from a PPT text record."""
+        try:
+            # Skip record header
+            text_data = data[8:]
+            # Convert from UTF-16-LE
+            return text_data.decode("utf-16-le", errors="ignore").strip()
+        except Exception:
+            return ""
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
         """Ingest PPT data and yield text from each slide."""
         """Ingest PPT data and yield text from each slide."""
         if isinstance(data, str):
         if isinstance(data, str):
             raise ValueError("PPT data must be in bytes format.")
             raise ValueError("PPT data must be in bytes format.")
 
 
-        prs = self.Presentation(BytesIO(data))
-        for slide in prs.slides:
-            for shape in slide.shapes:
-                if hasattr(shape, "text"):
-                    yield shape.text
+        try:
+            ole = self.olefile.OleFileIO(BytesIO(data))
+
+            # PPT stores text in PowerPoint Document stream
+            if not ole.exists("PowerPoint Document"):
+                raise ValueError("Not a valid PowerPoint file")
+
+            # Read PowerPoint Document stream
+            ppt_stream = ole.openstream("PowerPoint Document")
+            content = ppt_stream.read()
+
+            # Text records start with 0x0FA0 or 0x0FD0
+            text_markers = [b"\xA0\x0F", b"\xD0\x0F"]
+
+            current_position = 0
+            while current_position < len(content):
+                # Look for text markers
+                for marker in text_markers:
+                    marker_pos = content.find(marker, current_position)
+                    if marker_pos != -1:
+                        # Get record size from header (4 bytes after marker)
+                        size_bytes = content[marker_pos + 2 : marker_pos + 6]
+                        record_size = struct.unpack("<I", size_bytes)[0]
+
+                        # Extract record data
+                        record_data = content[
+                            marker_pos : marker_pos + record_size + 8
+                        ]
+                        text = self._extract_text_from_record(record_data)
+
+                        if text.strip():
+                            yield text.strip()
+
+                        current_position = marker_pos + record_size + 8
+                        break
+                else:
+                    current_position += 1
+
+        except Exception as e:
+            raise ValueError(f"Error processing PPT file: {str(e)}")
+        finally:
+            # Guard: `ole` is unbound if OleFileIO itself raised above.
+            if "ole" in locals():
+                ole.close()
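A toy round-trip of the record layout the scanner above assumes: a 2-byte marker, a 4-byte little-endian size at offset 2, an 8-byte header overall, then UTF-16-LE text. This mirrors the parser's own layout assumptions, not the full binary PPT specification.

import struct

text = "Slide title".encode("utf-16-le")
record = b"\xA0\x0F" + struct.pack("<I", len(text)) + b"\x00\x00" + text
size = struct.unpack("<I", record[2:6])[0]
print(record[8 : 8 + size].decode("utf-16-le"))  # Slide title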

+ 43 - 0
core/parsers/media/pptx_parser.py

@@ -0,0 +1,43 @@
+# type: ignore
+from io import BytesIO
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+    CompletionProvider,
+    DatabaseProvider,
+    IngestionConfig,
+)
+
+
+class PPTXParser(AsyncParser[str | bytes]):
+    """A parser for PPT data."""
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: DatabaseProvider,
+        llm_provider: CompletionProvider,
+    ):
+        self.database_provider = database_provider
+        self.llm_provider = llm_provider
+        self.config = config
+        try:
+            from pptx import Presentation
+
+            self.Presentation = Presentation
+        except ImportError:
+            raise ImportError(
+                "Error: 'python-pptx' is required to run PPTXParser. "
+                "Please install it using pip: pip install python-pptx"
+            )
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
+        """Ingest PPTX data and yield text from each slide."""
+        if isinstance(data, str):
+            raise ValueError("PPTX data must be in bytes format.")
+
+        prs = self.Presentation(BytesIO(data))
+        for slide in prs.slides:
+            for shape in slide.shapes:
+                if hasattr(shape, "text"):
+                    yield shape.text
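A minimal standalone sketch of the same python-pptx traversal, reading a hypothetical local deck instead of an in-memory upload:

    from pptx import Presentation

    prs = Presentation("deck.pptx")  # hypothetical file
    for number, slide in enumerate(prs.slides, 1):
        for shape in slide.shapes:
            if shape.has_text_frame:  # only text-bearing shapes carry text
                print(f"slide {number}: {shape.text_frame.text}")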

+ 52 - 0
core/parsers/media/rtf_parser.py

@@ -0,0 +1,52 @@
+# type: ignore
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+    CompletionProvider,
+    DatabaseProvider,
+    IngestionConfig,
+)
+
+
+class RTFParser(AsyncParser[str | bytes]):
+    """Parser for Rich Text Format (.rtf) files."""
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: DatabaseProvider,
+        llm_provider: CompletionProvider,
+    ):
+        self.database_provider = database_provider
+        self.llm_provider = llm_provider
+        self.config = config
+
+        try:
+            from striprtf.striprtf import rtf_to_text
+
+            self.striprtf = rtf_to_text
+        except ImportError:
+            raise ImportError(
+                "Error: 'striprtf' is required to run RTFParser. "
+                "Please install it using pip: pip install striprtf"
+            )
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
+        if isinstance(data, bytes):
+            data = data.decode("utf-8", errors="ignore")
+
+        try:
+            # Convert RTF to plain text
+            plain_text = self.striprtf(data)
+
+            # Split into paragraphs and yield non-empty ones
+            paragraphs = plain_text.split("\n\n")
+            for paragraph in paragraphs:
+                if paragraph.strip():
+                    yield paragraph.strip()
+
+        except Exception as e:
+            raise ValueError(f"Error processing RTF file: {str(e)}")

+ 18 - 0
core/parsers/structured/__init__.py

@@ -1,12 +1,30 @@
 # type: ignore
 from .csv_parser import CSVParser, CSVParserAdvanced
+from .eml_parser import EMLParser
+from .epub_parser import EPUBParser
 from .json_parser import JSONParser
+from .msg_parser import MSGParser
+from .org_parser import ORGParser
+from .p7s_parser import P7SParser
+from .rst_parser import RSTParser
+from .tiff_parser import TIFFParser
+from .tsv_parser import TSVParser
+from .xls_parser import XLSParser
 from .xlsx_parser import XLSXParser, XLSXParserAdvanced
 
 __all__ = [
     "CSVParser",
     "CSVParserAdvanced",
+    "EMLParser",
+    "EPUBParser",
     "JSONParser",
+    "MSGParser",
+    "ORGParser",
+    "P7SParser",
+    "RSTParser",
+    "TIFFParser",
+    "TSVParser",
+    "XLSParser",
     "XLSXParser",
     "XLSXParserAdvanced",
 ]

+ 63 - 0
core/parsers/structured/eml_parser.py

@@ -0,0 +1,63 @@
+# type: ignore
+from email import message_from_bytes, policy
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+    CompletionProvider,
+    DatabaseProvider,
+    IngestionConfig,
+)
+
+
+class EMLParser(AsyncParser[str | bytes]):
+    """Parser for EML (email) files."""
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: DatabaseProvider,
+        llm_provider: CompletionProvider,
+    ):
+        self.database_provider = database_provider
+        self.llm_provider = llm_provider
+        self.config = config
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
+        """Ingest EML data and yield email content."""
+        if isinstance(data, str):
+            raise ValueError("EML data must be in bytes format.")
+
+        # Parse email with policy for modern email handling
+        email_message = message_from_bytes(data, policy=policy.default)
+
+        # Extract and yield email metadata
+        metadata = []
+        if email_message["Subject"]:
+            metadata.append(f"Subject: {email_message['Subject']}")
+        if email_message["From"]:
+            metadata.append(f"From: {email_message['From']}")
+        if email_message["To"]:
+            metadata.append(f"To: {email_message['To']}")
+        if email_message["Date"]:
+            metadata.append(f"Date: {email_message['Date']}")
+
+        if metadata:
+            yield "\n".join(metadata)
+
+        # Extract and yield email body
+        if email_message.is_multipart():
+            for part in email_message.walk():
+                if part.get_content_type() == "text/plain":
+                    text = part.get_content()
+                    if text.strip():
+                        yield text.strip()
+                elif part.get_content_type() == "text/html":
+                    # Could add HTML parsing here if needed
+                    continue
+        else:
+            body = email_message.get_content()
+            if body.strip():
+                yield body.strip()
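A minimal standalone sketch of the stdlib calls used above, parsing a hypothetical single-part message:

    from email import message_from_bytes, policy

    raw = (
        b"Subject: Quarterly report\r\n"
        b"From: alice@example.com\r\n"
        b"To: bob@example.com\r\n"
        b"Content-Type: text/plain\r\n"
        b"\r\n"
        b"Figures attached.\r\n"
    )
    msg = message_from_bytes(raw, policy=policy.default)
    print(msg["Subject"])     # header lookup by name
    print(msg.get_content())  # decoded body; policy.default enables get_content()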

+ 128 - 0
core/parsers/structured/epub_parser.py

@@ -0,0 +1,128 @@
+# type: ignore
+import logging
+import re
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+    CompletionProvider,
+    DatabaseProvider,
+    IngestionConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class EPUBParser(AsyncParser[str | bytes]):
+    """Parser for EPUB electronic book files."""
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: DatabaseProvider,
+        llm_provider: CompletionProvider,
+    ):
+        self.database_provider = database_provider
+        self.llm_provider = llm_provider
+        self.config = config
+
+        try:
+            import epub
+
+            self.epub = epub
+        except ImportError:
+            raise ImportError(
+                "Error: 'epub' is required to run EPUBParser. "
+                "Please install it using pip: pip install epub"
+            )
+
+    def _safe_get_metadata(self, book, field: str) -> str | None:
+        """Safely extract metadata field from epub book."""
+        try:
+            return getattr(book, field, None) or getattr(book.opf, field, None)
+        except Exception as e:
+            logger.debug(f"Error getting {field} metadata: {e}")
+            return None
+
+    def _clean_text(self, content: bytes) -> str:
+        """Clean HTML content and return plain text."""
+        try:
+            text = content.decode("utf-8", errors="ignore")
+            # Remove HTML tags
+            text = re.sub(r"<[^>]+>", " ", text)
+            # Remove any remaining HTML entities
+            text = re.sub(r"&[^;]+;", " ", text)
+            # Normalize whitespace last, so gaps left by tags and
+            # entities are collapsed too
+            text = re.sub(r"\s+", " ", text)
+            return text.strip()
+        except Exception as e:
+            logger.warning(f"Error cleaning text: {e}")
+            return ""
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
+        """Ingest EPUB data and yield book content."""
+        if isinstance(data, str):
+            raise ValueError("EPUB data must be in bytes format.")
+
+        from io import BytesIO
+
+        file_obj = BytesIO(data)
+
+        try:
+            book = self.epub.open_epub(file_obj)
+
+            # Safely extract metadata
+            metadata = []
+            for field, label in [
+                ("title", "Title"),
+                ("creator", "Author"),
+                ("language", "Language"),
+                ("publisher", "Publisher"),
+                ("date", "Date"),
+            ]:
+                if value := self._safe_get_metadata(book, field):
+                    metadata.append(f"{label}: {value}")
+
+            if metadata:
+                yield "\n".join(metadata)
+
+            # Extract content from items
+            try:
+                manifest = getattr(book.opf, "manifest", {}) or {}
+                for item in manifest.values():
+                    try:
+                        if (
+                            getattr(item, "mime_type", "")
+                            == "application/xhtml+xml"
+                        ):
+                            if content := book.read_item(item):
+                                if cleaned_text := self._clean_text(content):
+                                    yield cleaned_text
+                    except Exception as e:
+                        logger.warning(f"Error processing item: {e}")
+                        continue
+
+            except Exception as e:
+                logger.warning(f"Error accessing manifest: {e}")
+                # Fallback: try to get content directly
+                if hasattr(book, "read_item"):
+                    for item_id in getattr(book, "items", []):
+                        try:
+                            if content := book.read_item(item_id):
+                                if cleaned_text := self._clean_text(content):
+                                    yield cleaned_text
+                        except Exception as e:
+                            logger.warning(f"Error in fallback reading: {e}")
+                            continue
+
+        except Exception as e:
+            logger.error(f"Error processing EPUB file: {str(e)}")
+            raise ValueError(f"Error processing EPUB file: {str(e)}")
+        finally:
+            try:
+                file_obj.close()
+            except Exception as e:
+                logger.warning(f"Error closing file: {e}")

+ 75 - 0
core/parsers/structured/msg_parser.py

@@ -0,0 +1,75 @@
+# type: ignore
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+    CompletionProvider,
+    DatabaseProvider,
+    IngestionConfig,
+)
+
+
+class MSGParser(AsyncParser[str | bytes]):
+    """Parser for MSG (Outlook Message) files."""
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: DatabaseProvider,
+        llm_provider: CompletionProvider,
+    ):
+        self.database_provider = database_provider
+        self.llm_provider = llm_provider
+        self.config = config
+
+        try:
+            import extract_msg
+
+            self.extract_msg = extract_msg
+        except ImportError:
+            raise ImportError(
+                "Error: 'extract-msg' is required to run MSGParser. "
+                "Please install it using pip: pip install extract-msg"
+            )
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
+        """Ingest MSG data and yield email content."""
+        if isinstance(data, str):
+            raise ValueError("MSG data must be in bytes format.")
+
+        from io import BytesIO
+
+        file_obj = BytesIO(data)
+
+        try:
+            msg = self.extract_msg.Message(file_obj)
+
+            # Extract metadata
+            metadata = []
+            if msg.subject:
+                metadata.append(f"Subject: {msg.subject}")
+            if msg.sender:
+                metadata.append(f"From: {msg.sender}")
+            if msg.to:
+                metadata.append(f"To: {msg.to}")
+            if msg.date:
+                metadata.append(f"Date: {msg.date}")
+
+            if metadata:
+                yield "\n".join(metadata)
+
+            # Extract body
+            if msg.body:
+                yield msg.body.strip()
+
+            # Extract attachments (optional)
+            for attachment in msg.attachments:
+                if hasattr(attachment, "name"):
+                    yield f"\nAttachment: {attachment.name}"
+
+        except Exception as e:
+            raise ValueError(f"Error processing MSG file: {str(e)}")
+        finally:
+            file_obj.close()
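A minimal usage sketch of the parser itself; it assumes the provider arguments are unused for plain text extraction (None stands in for real providers here) and a hypothetical mail.msg on disk:

    import asyncio

    async def main() -> None:
        # None providers are an assumption: ingest() only touches extract_msg
        parser = MSGParser(config=None, database_provider=None, llm_provider=None)
        with open("mail.msg", "rb") as f:  # hypothetical file
            async for chunk in parser.ingest(f.read()):
                print(chunk)

    asyncio.run(main())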

+ 79 - 0
core/parsers/structured/org_parser.py

@@ -0,0 +1,79 @@
+# type: ignore
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+    CompletionProvider,
+    DatabaseProvider,
+    IngestionConfig,
+)
+
+
+class ORGParser(AsyncParser[str | bytes]):
+    """Parser for ORG (Emacs Org-mode) files."""
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: DatabaseProvider,
+        llm_provider: CompletionProvider,
+    ):
+        self.database_provider = database_provider
+        self.llm_provider = llm_provider
+        self.config = config
+
+        try:
+            import orgparse
+
+            self.orgparse = orgparse
+        except ImportError:
+            raise ImportError(
+                "Error: 'orgparse' is required to run ORGParser. "
+                "Please install it using pip: pip install orgparse"
+            )
+
+    def _process_node(self, node) -> list[str]:
+        """Process an org-mode node and return its content."""
+        contents = []
+
+        # Add heading with proper level of asterisks
+        if node.level > 0:
+            contents.append(f"{'*' * node.level} {node.heading}")
+
+        # Add body content if exists
+        if node.body:
+            contents.append(node.body.strip())
+
+        return contents
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
+        """Ingest ORG data and yield document content."""
+        if isinstance(data, bytes):
+            data = data.decode("utf-8")
+
+        # Create an in-memory file-like object for orgparse; bind it before
+        # the try block so the finally clause never sees an unbound name
+        from io import StringIO
+
+        file_obj = StringIO(data)
+
+        try:
+            # Parse the org file
+            root = self.orgparse.load(file_obj)
+
+            # Process root node if it has content
+            if root.body:
+                yield root.body.strip()
+
+            # Process all nodes
+            for node in root[1:]:  # Skip root node in iteration
+                contents = self._process_node(node)
+                for content in contents:
+                    if content.strip():
+                        yield content.strip()
+
+        except Exception as e:
+            raise ValueError(f"Error processing ORG file: {str(e)}")
+        finally:
+            file_obj.close()
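A minimal sketch of the orgparse traversal, parsing an org string directly with orgparse.loads:

    import orgparse

    doc = """* Heading one
    Body text under heading one.
    ** Subheading
    More body text.
    """
    root = orgparse.loads(doc)
    for node in root[1:]:  # skip the root node, as above
        print("*" * node.level, node.heading)
        if node.body:
            print(node.body.strip())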

+ 184 - 0
core/parsers/structured/p7s_parser.py

@@ -0,0 +1,184 @@
+# type: ignore
+import email
+import logging
+from base64 import b64decode
+from datetime import datetime
+from email.message import Message
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+    CompletionProvider,
+    DatabaseProvider,
+    IngestionConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class P7SParser(AsyncParser[str | bytes]):
+    """Parser for S/MIME messages containing a P7S (PKCS#7 Signature) file."""
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: DatabaseProvider,
+        llm_provider: CompletionProvider,
+    ):
+        self.database_provider = database_provider
+        self.llm_provider = llm_provider
+        self.config = config
+
+        try:
+            from cryptography import x509
+            from cryptography.hazmat.primitives.serialization import pkcs7
+            from cryptography.x509.oid import NameOID
+
+            self.x509 = x509
+            self.pkcs7 = pkcs7
+            self.NameOID = NameOID
+        except ImportError:
+            raise ImportError(
+                "Error: 'cryptography' is required to run P7SParser. "
+                "Please install it using pip: pip install cryptography"
+            )
+
+    def _format_datetime(self, dt: datetime) -> str:
+        """Format datetime in a readable way."""
+        return dt.strftime("%Y-%m-%d %H:%M:%S UTC")
+
+    def _get_name_attribute(self, name, oid):
+        """Safely get name attribute."""
+        try:
+            return name.get_attributes_for_oid(oid)[0].value
+        except (IndexError, ValueError):
+            return None
+
+    def _extract_cert_info(self, cert) -> dict:
+        """Extract relevant information from a certificate."""
+        try:
+            subject = cert.subject
+            issuer = cert.issuer
+
+            info = {
+                "common_name": self._get_name_attribute(
+                    subject, self.NameOID.COMMON_NAME
+                ),
+                "organization": self._get_name_attribute(
+                    subject, self.NameOID.ORGANIZATION_NAME
+                ),
+                "email": self._get_name_attribute(
+                    subject, self.NameOID.EMAIL_ADDRESS
+                ),
+                "issuer_common_name": self._get_name_attribute(
+                    issuer, self.NameOID.COMMON_NAME
+                ),
+                "issuer_organization": self._get_name_attribute(
+                    issuer, self.NameOID.ORGANIZATION_NAME
+                ),
+                "serial_number": hex(cert.serial_number)[2:],
+                "not_valid_before": self._format_datetime(
+                    cert.not_valid_before
+                ),
+                "not_valid_after": self._format_datetime(cert.not_valid_after),
+                "version": cert.version.name,
+            }
+
+            return {k: v for k, v in info.items() if v is not None}
+
+        except Exception as e:
+            logger.warning(f"Error extracting certificate info: {e}")
+            return {}
+
+    def _try_parse_signature(self, data: bytes):
+        """Try to parse the signature data as PKCS7 containing certificates."""
+        exceptions = []
+
+        # Try DER format PKCS7
+        try:
+            certs = self.pkcs7.load_der_pkcs7_certificates(data)
+            if certs is not None:
+                return certs
+        except Exception as e:
+            exceptions.append(f"DER PKCS7 parsing failed: {str(e)}")
+
+        # Try PEM format PKCS7
+        try:
+            certs = self.pkcs7.load_pem_pkcs7_certificates(data)
+            if certs is not None:
+                return certs
+        except Exception as e:
+            exceptions.append(f"PEM PKCS7 parsing failed: {str(e)}")
+
+        raise ValueError(
+            "Unable to parse signature file as PKCS7 with certificates. Attempted methods:\n"
+            + "\n".join(exceptions)
+        )
+
+    def _extract_p7s_data_from_mime(self, raw_data: bytes) -> bytes:
+        """Extract the raw PKCS#7 signature data from a MIME message."""
+        msg: Message = email.message_from_bytes(raw_data)
+
+        # If the message is multipart, find the part with application/x-pkcs7-signature
+        if msg.is_multipart():
+            for part in msg.walk():
+                ctype = part.get_content_type()
+                if ctype == "application/x-pkcs7-signature":
+                    # Get the base64 encoded data from the payload
+                    payload = part.get_payload(decode=False)
+                    # payload at this stage is a base64 string
+                    try:
+                        return b64decode(payload)
+                    except Exception as e:
+                        raise ValueError(
+                            f"Failed to decode base64 PKCS#7 signature: {str(e)}"
+                        )
+            # If we reach here, no PKCS#7 part was found
+            raise ValueError(
+                "No application/x-pkcs7-signature part found in the MIME message."
+            )
+        else:
+            # Not multipart, try to parse directly if it's just a raw P7S
+            # This scenario is less common; usually it's multipart.
+            if msg.get_content_type() == "application/x-pkcs7-signature":
+                payload = msg.get_payload(decode=False)
+                return b64decode(payload)
+
+            raise ValueError(
+                "The provided data does not contain a valid S/MIME signed message."
+            )
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
+        """Ingest an S/MIME message and extract the PKCS#7 signature information."""
+        # If data is a string, it might be base64 encoded, or it might be the raw MIME text.
+        # We should assume it's raw MIME text here because the input includes MIME headers.
+        if isinstance(data, str):
+            # Convert to bytes (raw MIME)
+            data = data.encode("utf-8")
+
+        try:
+            # Extract the raw PKCS#7 data (der/pem) from the MIME message
+            p7s_data = self._extract_p7s_data_from_mime(data)
+
+            # Parse the PKCS#7 data for certificates
+            certificates = self._try_parse_signature(p7s_data)
+
+            if not certificates:
+                yield "No certificates found in the provided P7S file."
+                return
+
+            # Process each certificate
+            for i, cert in enumerate(certificates, 1):
+                if cert_info := self._extract_cert_info(cert):
+                    yield f"Certificate {i}:"
+                    for key, value in cert_info.items():
+                        if value:
+                            yield f"{key.replace('_', ' ').title()}: {value}"
+                    yield ""  # Empty line between certificates
+                else:
+                    yield f"Certificate {i}: No detailed information extracted."
+
+        except Exception as e:
+            raise ValueError(f"Error processing P7S file: {str(e)}")

+ 65 - 0
core/parsers/structured/rst_parser.py

@@ -0,0 +1,65 @@
+# type: ignore
+from typing import AsyncGenerator
+
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+    CompletionProvider,
+    DatabaseProvider,
+    IngestionConfig,
+)
+
+
+class RSTParser(AsyncParser[str | bytes]):
+    """Parser for reStructuredText (.rst) files."""
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: DatabaseProvider,
+        llm_provider: CompletionProvider,
+    ):
+        self.database_provider = database_provider
+        self.llm_provider = llm_provider
+        self.config = config
+
+        try:
+            from docutils.core import publish_string
+            from docutils.writers import html5_polyglot
+
+            self.publish_string = publish_string
+            self.html5_polyglot = html5_polyglot
+        except ImportError:
+            raise ImportError(
+                "Error: 'docutils' is required to run RSTParser. "
+                "Please install it using pip: pip install docutils"
+            )
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
+        if isinstance(data, bytes):
+            data = data.decode("utf-8")
+
+        try:
+            # Convert RST to HTML
+            html = self.publish_string(
+                source=data,
+                writer=self.html5_polyglot.Writer(),
+                settings_overrides={"report_level": 5},
+            )
+
+            # Basic HTML cleanup
+            import re
+
+            text = html.decode("utf-8")
+            text = re.sub(r"<[^>]+>", " ", text)
+            text = re.sub(r"\s+", " ", text)
+
+            # Split into paragraphs and yield non-empty ones
+            paragraphs = text.split("\n\n")
+            for paragraph in paragraphs:
+                if paragraph.strip():
+                    yield paragraph.strip()
+
+        except Exception as e:
+            raise ValueError(f"Error processing RST file: {str(e)}")

+ 116 - 0
core/parsers/structured/tiff_parser.py

@@ -0,0 +1,116 @@
+# type: ignore
+import base64
+import logging
+from io import BytesIO
+from typing import AsyncGenerator
+
+from core.base.abstractions import GenerationConfig
+from core.base.parsers.base_parser import AsyncParser
+from core.base.providers import (
+    CompletionProvider,
+    DatabaseProvider,
+    IngestionConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class TIFFParser(AsyncParser[str | bytes]):
+    """Parser for TIFF image files."""
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: DatabaseProvider,
+        llm_provider: CompletionProvider,
+    ):
+        self.database_provider = database_provider
+        self.llm_provider = llm_provider
+        self.config = config
+        self.vision_prompt_text = None
+
+        try:
+            from litellm import supports_vision
+            from PIL import Image
+
+            self.supports_vision = supports_vision
+            self.Image = Image
+        except ImportError:
+            raise ImportError(
+                "Error: 'litellm' and 'Pillow' are required to run TIFFParser. "
+                "Please install them using pip: pip install litellm Pillow"
+            )
+
+    async def _convert_tiff_to_jpeg(self, data: bytes) -> bytes:
+        """Convert TIFF image to JPEG format."""
+        try:
+            # Open TIFF image
+            with BytesIO(data) as input_buffer:
+                tiff_image = self.Image.open(input_buffer)
+
+                # Convert to RGB if needed
+                if tiff_image.mode not in ("RGB", "L"):
+                    tiff_image = tiff_image.convert("RGB")
+
+                # Save as JPEG
+                output_buffer = BytesIO()
+                tiff_image.save(output_buffer, format="JPEG", quality=95)
+                return output_buffer.getvalue()
+        except Exception as e:
+            raise ValueError(f"Error converting TIFF to JPEG: {str(e)}")
+
+    async def ingest(
+        self, data: str | bytes, **kwargs
+    ) -> AsyncGenerator[str, None]:
+        if not self.vision_prompt_text:
+            self.vision_prompt_text = (
+                await self.database_provider.prompts_handler.get_cached_prompt(
+                    prompt_name=self.config.vision_img_prompt_name
+                )
+            )
+
+        try:
+            if not self.supports_vision(model=self.config.vision_img_model):
+                raise ValueError(
+                    f"Model {self.config.vision_img_model} does not support vision"
+                )
+
+            # Convert TIFF to JPEG
+            if isinstance(data, bytes):
+                jpeg_data = await self._convert_tiff_to_jpeg(data)
+                image_data = base64.b64encode(jpeg_data).decode("utf-8")
+            else:
+                image_data = data
+
+            # Use vision model to analyze image
+            generation_config = GenerationConfig(
+                model=self.config.vision_img_model,
+                stream=False,
+            )
+
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": self.vision_prompt_text},
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{image_data}"
+                            },
+                        },
+                    ],
+                }
+            ]
+
+            response = await self.llm_provider.aget_completion(
+                messages=messages, generation_config=generation_config
+            )
+
+            if response.choices and response.choices[0].message:
+                content = response.choices[0].message.content
+                if not content:
+                    raise ValueError("No content in response")
+                yield content
+            else:
+                raise ValueError("No response content")
+
+        except Exception as e:
+            raise ValueError(f"Error processing TIFF file: {str(e)}")

Some files were not shown because too many files changed in this diff