
first commit

Gogs committed 3 months ago
Commit 6a36d2c526
100 changed files with 17,198 additions and 0 deletions
  1. .dockerignore (+20 -0)
  2. .isort.cfg (+10 -0)
  3. Dockerfile (+66 -0)
  4. Dockerfile.unstructured (+29 -0)
  5. README.md (+103 -0)
  6. cli/__init__.py (+0 -0)
  7. cli/command_group.py (+41 -0)
  8. cli/commands/__init__.py (+0 -0)
  9. cli/commands/collections.py (+141 -0)
  10. cli/commands/conversations.py (+124 -0)
  11. cli/commands/database.py (+138 -0)
  12. cli/commands/documents.py (+393 -0)
  13. cli/commands/graphs.py (+382 -0)
  14. cli/commands/indices.py (+89 -0)
  15. cli/commands/prompts.py (+60 -0)
  16. cli/commands/retrieval.py (+189 -0)
  17. cli/commands/system.py (+435 -0)
  18. cli/commands/users.py (+143 -0)
  19. cli/main.py (+61 -0)
  20. cli/utils/__init__.py (+0 -0)
  21. cli/utils/database_utils.py (+225 -0)
  22. cli/utils/docker_utils.py (+578 -0)
  23. cli/utils/param_types.py (+21 -0)
  24. cli/utils/telemetry.py (+152 -0)
  25. cli/utils/timer.py (+16 -0)
  26. compose.full.yaml (+400 -0)
  27. compose.full_with_replicas.yaml (+420 -0)
  28. compose.yaml.back (+123 -0)
  29. core/__init__.py (+232 -0)
  30. core/agent/__init__.py (+11 -0)
  31. core/agent/base.py (+240 -0)
  32. core/agent/rag.py (+159 -0)
  33. core/agent/serper.py (+104 -0)
  34. core/base/__init__.py (+139 -0)
  35. core/base/abstractions/__init__.py (+169 -0)
  36. core/base/agent/__init__.py (+10 -0)
  37. core/base/agent/agent.py (+247 -0)
  38. core/base/agent/base.py (+22 -0)
  39. core/base/api/models/__init__.py (+159 -0)
  40. core/base/logger/__init__.py (+11 -0)
  41. core/base/logger/base.py (+32 -0)
  42. core/base/logger/run_manager.py (+62 -0)
  43. core/base/parsers/__init__.py (+5 -0)
  44. core/base/parsers/base_parser.py (+13 -0)
  45. core/base/pipeline/__init__.py (+5 -0)
  46. core/base/pipeline/base_pipeline.py (+180 -0)
  47. core/base/pipes/__init__.py (+3 -0)
  48. core/base/pipes/base_pipe.py (+128 -0)
  49. core/base/providers/__init__.py (+57 -0)
  50. core/base/providers/auth.py (+155 -0)
  51. core/base/providers/base.py (+68 -0)
  52. core/base/providers/crypto.py (+39 -0)
  53. core/base/providers/database.py (+206 -0)
  54. core/base/providers/email.py (+73 -0)
  55. core/base/providers/embedding.py (+197 -0)
  56. core/base/providers/ingestion.py (+204 -0)
  57. core/base/providers/llm.py (+184 -0)
  58. core/base/providers/orchestration.py (+70 -0)
  59. core/base/utils/__init__.py (+41 -0)
  60. core/configs/full.toml (+25 -0)
  61. core/configs/full_azure.toml (+57 -0)
  62. core/configs/full_local_llm.toml (+70 -0)
  63. core/configs/local_llm.toml (+68 -0)
  64. core/configs/r2r_azure.toml (+46 -0)
  65. core/configs/r2r_with_auth.toml (+8 -0)
  66. core/database/__init__.py (+5 -0)
  67. core/database/base.py (+209 -0)
  68. core/database/chunks.py (+1488 -0)
  69. core/database/collections.py (+471 -0)
  70. core/database/conversations.py (+376 -0)
  71. core/database/documents.py (+933 -0)
  72. core/database/files.py (+275 -0)
  73. core/database/graphs.py (+2790 -0)
  74. core/database/limits.py (+229 -0)
  75. core/database/postgres.py (+296 -0)
  76. core/database/prompts/__init__.py (+0 -0)
  77. core/database/prompts/chunk_enrichment.yaml (+27 -0)
  78. core/database/prompts/default_collection_summary.yaml (+41 -0)
  79. core/database/prompts/default_rag.yaml (+28 -0)
  80. core/database/prompts/default_summary.yaml (+18 -0)
  81. core/database/prompts/default_system.yaml (+3 -0)
  82. core/database/prompts/graphrag_communities.yaml (+109 -0)
  83. core/database/prompts/graphrag_entity_deduplication.yaml (+24 -0)
  84. core/database/prompts/graphrag_entity_description.yaml (+39 -0)
  85. core/database/prompts/graphrag_map_system.yaml (+55 -0)
  86. core/database/prompts/graphrag_reduce_system.yaml (+43 -0)
  87. core/database/prompts/graphrag_relationships_extraction_few_shot.yaml (+134 -0)
  88. core/database/prompts/hyde.yaml (+29 -0)
  89. core/database/prompts/rag_agent.yaml (+16 -0)
  90. core/database/prompts/rag_context.yaml (+23 -0)
  91. core/database/prompts/rag_fusion.yaml (+27 -0)
  92. core/database/prompts/vision_img.yaml (+4 -0)
  93. core/database/prompts/vision_pdf.yaml (+42 -0)
  94. core/database/prompts_handler.py (+639 -0)
  95. core/database/tokens.py (+67 -0)
  96. core/database/users.py (+660 -0)
  97. core/database/vecs/__init__.py (+5 -0)
  98. core/database/vecs/adapter/__init__.py (+16 -0)
  99. core/database/vecs/adapter/base.py (+126 -0)
  100. core/database/vecs/adapter/markdown.py (+93 -0)

+ 20 - 0
.dockerignore

@@ -0,0 +1,20 @@
+__pycache__
+*.pyc
+*.pyo
+*.pyd
+.Python
+env
+pip-log.txt
+pip-delete-this-directory.txt
+.tox
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.log
+.git
+.mypy_cache
+.pytest_cache
+.hypothesis

+ 10 - 0
.isort.cfg

@@ -0,0 +1,10 @@
+[settings]
+profile = black
+multi_line_output = 3
+include_trailing_comma = true
+force_grid_wrap = 0
+use_parentheses = true
+ensure_newline_before_comments = true
+line_length = 79
+sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
+skip = .tox,__pycache__,*.pyc,venv*/*,reports,venv,env,node_modules,.env,.venv,dist,my_env

+ 66 - 0
Dockerfile

@@ -0,0 +1,66 @@
+FROM python:3.12-slim AS builder
+
+# Switch the APT sources to the mirror passed in
+
+RUN apt-get update && apt-get install -y gnupg2
+RUN apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv-keys 3B4FE6ACC0B21F32 871920D1991BC93C
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    gcc g++ musl-dev curl libffi-dev gfortran libopenblas-dev \
+    poppler-utils \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --no-cache-dir poetry -i https://pypi.tuna.tsinghua.edu.cn/simple
+
+
+# Add Rust to PATH
+ENV PATH="/root/.cargo/bin:${PATH}"
+
+RUN mkdir -p /app/py
+WORKDIR /app/py
+COPY pyproject.toml /app/py/pyproject.toml
+
+#RUN poetry config repositories.pypi https://mirrors.aliyun.com/pypi/simple/
+#RUN export POETRY_PYPI_REPOSITORIES="https://pypi.tuna.tsinghua.edu.cn/simple"
+
+# Install dependencies
+RUN poetry config virtualenvs.create false \
+    && poetry install --extras "core ingestion-bundle" --only main --no-root \
+    && pip install --no-cache-dir gunicorn uvicorn -i https://pypi.tuna.tsinghua.edu.cn/simple
+
+# Create the final image
+FROM python:3.12-slim
+
+# Install runtime dependencies
+RUN apt-get update \
+    && apt-get install -y --no-install-recommends curl poppler-utils \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+# Add poppler to PATH
+ENV PATH="/usr/bin:${PATH}"
+
+# Debugging steps
+RUN echo "PATH: $PATH"
+RUN which pdfinfo
+RUN pdfinfo -v
+
+WORKDIR /app
+
+COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
+
+# Expose the port and set environment variables
+ARG R2R_PORT=8000 R2R_HOST=0.0.0.0
+ENV R2R_PORT=$R2R_PORT R2R_HOST=$R2R_HOST
+EXPOSE $R2R_PORT
+
+COPY . /app
+# Copy the application and config
+COPY core /app/core
+COPY r2r /app/r2r
+COPY shared /app/shared
+COPY r2r.toml /app/r2r.toml
+COPY pyproject.toml /app/pyproject.toml
+
+# Run the application
+CMD ["sh", "-c", "uvicorn core.main.app_entry:app --host $R2R_HOST --port $R2R_PORT"]

+ 29 - 0
Dockerfile.unstructured

@@ -0,0 +1,29 @@
+FROM python:3.12-slim AS builder
+
+# Install system dependencies (including those needed for Unstructured and OpenCV)
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    gcc g++ musl-dev curl libffi-dev gfortran libopenblas-dev \
+    tesseract-ocr libtesseract-dev libleptonica-dev pkg-config \
+    poppler-utils libmagic1 pandoc libreoffice \
+    libgl1-mesa-glx libglib2.0-0 \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ENV TESSDATA_PREFIX=/usr/share/tesseract-ocr/5/tessdata
+
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+
+WORKDIR /app
+
+RUN pip install --no-cache-dir unstructured "unstructured[all-docs]"
+
+
+RUN python -c "from unstructured.partition.model_init import initialize; initialize()"
+
+RUN pip install gunicorn uvicorn fastapi httpx
+
+COPY core/integrations/unstructured/main.py .
+
+EXPOSE 7275
+
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7275", "--workers", "8"]

+ 103 - 0
README.md

@@ -0,0 +1,103 @@
+<p align="left">
+  <a href="https://r2r-docs.sciphi.ai"><img src="https://img.shields.io/badge/docs.sciphi.ai-3F16E4" alt="Docs"></a>
+  <a href="https://discord.gg/p6KqD2kjtB"><img src="https://img.shields.io/discord/1120774652915105934?style=social&logo=discord" alt="Discord"></a>
+  <a href="https://github.com/SciPhi-AI"><img src="https://img.shields.io/github/stars/SciPhi-AI/R2R" alt="Github Stars"></a>
+  <a href="https://github.com/SciPhi-AI/R2R/pulse"><img src="https://img.shields.io/github/commit-activity/w/SciPhi-AI/R2R" alt="Commits-per-week"></a>
+  <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-purple.svg" alt="License: MIT"></a>
+  <a href="https://gurubase.io/g/r2r"><img src="https://img.shields.io/badge/Gurubase-Ask%20R2R%20Guru-006BFF" alt="Gurubase: R2R Guru"></a>
+</p>
+
+<img width="1041" alt="r2r" src="https://github.com/user-attachments/assets/b6ee6a78-5d37-496d-ae10-ce18eee7a1d6">
+<h3 align="center">
+  Containerized, state of the art Retrieval-Augmented Generation (RAG) with a RESTful API
+</h3>
+
+# About
+R2R (RAG to Riches) is the most advanced AI retrieval system, supporting Retrieval-Augmented Generation (RAG) with production-ready features. Built around a containerized [RESTful API](https://r2r-docs.sciphi.ai/api-and-sdks/introduction), R2R offers multimodal content ingestion, hybrid search functionality, configurable GraphRAG, and comprehensive user and document management.
+
+For a more complete view of R2R, check out the [full documentation](https://r2r-docs.sciphi.ai/).
+
+## Key Features
+- [**📁 Multimodal Ingestion**](https://r2r-docs.sciphi.ai/documentation/configuration/ingestion): Parse `.txt`, `.pdf`, `.json`, `.png`, `.mp3`, and more.
+- [**🔍 Hybrid Search**](https://r2r-docs.sciphi.ai/cookbooks/hybrid-search): Combine semantic and keyword search with reciprocal rank fusion for enhanced relevancy.
+- [**🔗 Knowledge Graphs**](https://r2r-docs.sciphi.ai/cookbooks/knowledge-graphs): Automatically extract entities and relationships and build knowledge graphs.
+- [**📊 GraphRAG**](https://r2r-docs.sciphi.ai/cookbooks/graphrag): Cluster and summarize communities over your created graphs for even richer insights.
+- [**🗂️ User Management**](https://r2r-docs.sciphi.ai/cookbooks/user-auth): Efficiently manage documents and user roles within R2R.
+- [**🔭 Observability**](https://r2r-docs.sciphi.ai/cookbooks/observability): Observe and analyze your RAG engine performance.
+- [**🧩 Configuration**](https://r2r-docs.sciphi.ai/documentation/configuration/overview): Set up your application using intuitive configuration files.
+- [**🖥️ Dashboard**](https://r2r-docs.sciphi.ai/cookbooks/application): An open-source React+Next.js admin dashboard to interact with R2R via GUI.
+
+
+## [What's New](https://r2r-docs.sciphi.ai/introduction/whats-new)
+
+- Release 3.3.0&nbsp;&nbsp;&nbsp;&nbsp;December 3, 2024&nbsp;&nbsp;&nbsp;&nbsp;
+
+  Warning: These changes are breaking!
+  - [V3 API Specification](https://r2r-docs.sciphi.ai/api-and-sdks/introduction)
+
+## Install with pip
+The recommended way to get started with R2R is by using our CLI.
+
+```bash
+pip install r2r
+```
+
+
+You may run R2R directly from the Python package, but this requires the full R2R core to be installed and additional dependencies such as Postgres+pgvector to be configured:
+
+```bash
+# export OPENAI_API_KEY=sk-...
+# export POSTGRES...
+pip install 'r2r[core,ingestion-bundle]'
+r2r --config-name=default serve
+```
+
+Alternatively, R2R can be launched alongside its requirements inside Docker:
+
+```bash
+# export OPENAI_API_KEY=sk-...
+r2r serve --docker --full
+```
+
+The command above launches the `full` installation, which includes Hatchet for orchestration and Unstructured.io for parsing.
+
+## Getting Started
+
+- [Installation](https://r2r-docs.sciphi.ai/documentation/installation/overview): Quick installation of R2R using Docker or pip
+- [Quickstart](https://r2r-docs.sciphi.ai/documentation/quickstart): A quick introduction to R2R's core features
+- [Setup](https://r2r-docs.sciphi.ai/documentation/configuration/overview): Learn how to set up and configure R2R
+- [API & SDKs](https://r2r-docs.sciphi.ai/api-and-sdks/introduction): API reference and Python/JS SDKs for interacting with R2R
+
+## Cookbooks
+
+- Advanced RAG Pipelines
+  - [RAG Agent](https://r2r-docs.sciphi.ai/cookbooks/agent): R2R's powerful RAG agent
+  - [Hybrid Search](https://r2r-docs.sciphi.ai/cookbooks/hybrid-search): Introduction to hybrid search
+  - [Advanced RAG](https://r2r-docs.sciphi.ai/cookbooks/advanced-rag): Advanced RAG features
+
+- Orchestration
+  - [Orchestration](https://r2r-docs.sciphi.ai/cookbooks/orchestration): R2R event orchestration
+
+- User Management
+  - [Web Development](https://r2r-docs.sciphi.ai/cookbooks/web-dev): Building webapps using R2R
+  - [User Auth](https://r2r-docs.sciphi.ai/cookbooks/user-auth): Authenticating users
+  - [Collections](https://r2r-docs.sciphi.ai/cookbooks/collections): Document collections
+  - [Analytics & Observability](https://r2r-docs.sciphi.ai/cookbooks/observability): End-to-end logging and analytics
+  - [Web Application](https://r2r-docs.sciphi.ai/cookbooks/application): Connecting with the R2R Application
+
+
+## Community
+
+[Join our Discord server](https://discord.gg/p6KqD2kjtB) to get support and connect with both the R2R team and other developers in the community. Whether you're encountering issues, looking for advice on best practices, or just want to share your experiences, we're here to help.
+
+# Contributing
+
+We welcome contributions of all sizes! Here's how you can help:
+
+- Open a PR for new features, improvements, or better documentation.
+- Submit a [feature request](https://github.com/SciPhi-AI/R2R/issues/new?assignees=&labels=&projects=&template=feature_request.md&title=) or [bug report](https://github.com/SciPhi-AI/R2R/issues/new?assignees=&labels=&projects=&template=bug_report.md&title=)
+
+### Our Contributors
+<a href="https://github.com/SciPhi-AI/R2R/graphs/contributors">
+  <img src="https://contrib.rocks/image?repo=SciPhi-AI/R2R" />
+</a>

+ 0 - 0
cli/__init__.py


+ 41 - 0
cli/command_group.py

@@ -0,0 +1,41 @@
+from functools import wraps
+
+import asyncclick as click
+from asyncclick import pass_context
+from asyncclick.exceptions import Exit
+
+from sdk import R2RAsyncClient
+
+
+def deprecated_command(new_name):
+    def decorator(f):
+        @wraps(f)
+        async def wrapped(*args, **kwargs):
+            click.secho(
+                f"Warning: This command is deprecated. Please use '{new_name}' instead.",
+                fg="yellow",
+                err=True,
+            )
+            return await f(*args, **kwargs)
+
+        return wrapped
+
+    return decorator
+
+
+@click.group()
+@click.option(
+    "--base-url", default="http://localhost:7272", help="Base URL for the API"
+)
+@pass_context
+async def cli(ctx, base_url):
+    """R2R CLI for all core operations."""
+
+    ctx.obj = R2RAsyncClient(base_url=base_url)
+
+    # Override the default exit behavior
+    def silent_exit(self, code=0):
+        if code != 0:
+            raise Exit(code)
+
+    ctx.exit = silent_exit.__get__(ctx)
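
`deprecated_command` is defined above but not yet applied in this file; the sketch below (not part of the commit) shows how it is intended to compose with the `cli` group. The alias name `documents-overview` and its body are hypothetical, while `client.documents.list` mirrors the call used in `cli/commands/documents.py` further down.

```python
import asyncclick as click

from cli.command_group import cli, deprecated_command


@cli.command(name="documents-overview")
@deprecated_command("documents list")
@click.pass_context
async def documents_overview(ctx):
    """Hypothetical deprecated alias that forwards to the newer `documents list`."""
    client = ctx.obj  # R2RAsyncClient injected by the cli group above
    response = await client.documents.list(offset=0, limit=100)
    click.echo(response)
```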

+ 0 - 0
cli/commands/__init__.py


+ 141 - 0
cli/commands/collections.py

@@ -0,0 +1,141 @@
+import json
+
+import asyncclick as click
+from asyncclick import pass_context
+
+from cli.utils.timer import timer
+from r2r import R2RAsyncClient
+
+
+@click.group()
+def collections():
+    """Collections commands."""
+    pass
+
+
+@collections.command()
+@click.argument("name", required=True, type=str)
+@click.option("--description", type=str)
+@pass_context
+async def create(ctx, name, description):
+    """Create a collection."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.collections.create(
+            name=name,
+            description=description,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@collections.command()
+@click.option("--ids", multiple=True, help="Collection IDs to fetch")
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
+)
+@pass_context
+async def list(ctx, ids, offset, limit):
+    """Get an overview of collections."""
+    client: R2RAsyncClient = ctx.obj
+    ids = list(ids) if ids else None
+
+    with timer():
+        response = await client.collections.list(
+            ids=ids,
+            offset=offset,
+            limit=limit,
+        )
+
+    for collection in response["results"]:
+        click.echo(json.dumps(collection, indent=2))
+
+
+@collections.command()
+@click.argument("id", required=True, type=str)
+@pass_context
+async def retrieve(ctx, id):
+    """Retrieve a collection by ID."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.collections.retrieve(id=id)
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@collections.command()
+@click.argument("id", required=True, type=str)
+@pass_context
+async def delete(ctx, id):
+    """Delete a collection by ID."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.collections.delete(id=id)
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@collections.command()
+@click.argument("id", required=True, type=str)
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
+)
+@pass_context
+async def list_documents(ctx, id, offset, limit):
+    """Get an overview of collections."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.collections.list_documents(
+            id=id,
+            offset=offset,
+            limit=limit,
+        )
+
+    for document in response["results"]:
+        click.echo(json.dumps(document, indent=2))
+
+
+@collections.command()
+@click.argument("id", required=True, type=str)
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
+)
+@pass_context
+async def list_users(ctx, id, offset, limit):
+    """Get an overview of collections."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.collections.list_users(
+            id=id,
+            offset=offset,
+            limit=limit,
+        )
+
+    for user in response["results"]:
+        click.echo(json.dumps(user, indent=2))
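
Every command in this module (and in the modules that follow) wraps its client call in `timer()` from `cli/utils/timer.py`, a 16-line file whose diff is not shown in this excerpt. A plausible minimal reconstruction, assuming it simply reports elapsed wall-clock time, would be:

```python
# Hypothetical sketch of cli/utils/timer.py; the real 16-line file in this
# commit may differ in details.
import time
from contextlib import contextmanager


@contextmanager
def timer():
    """Print how long the wrapped block of client calls took."""
    start = time.time()
    try:
        yield
    finally:
        print(f"Time taken: {time.time() - start:.2f} seconds")
```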

+ 124 - 0
cli/commands/conversations.py

@@ -0,0 +1,124 @@
+import json
+
+import asyncclick as click
+from asyncclick import pass_context
+
+from cli.utils.timer import timer
+from r2r import R2RAsyncClient
+
+
+@click.group()
+def conversations():
+    """Conversations commands."""
+    pass
+
+
+@conversations.command()
+@pass_context
+async def create(ctx):
+    """Create a conversation."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.conversations.create()
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@conversations.command()
+@click.option("--ids", multiple=True, help="Conversation IDs to fetch")
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
+)
+@pass_context
+async def list(ctx, ids, offset, limit):
+    """Get an overview of conversations."""
+    client: R2RAsyncClient = ctx.obj
+    ids = list(ids) if ids else None
+
+    with timer():
+        response = await client.conversations.list(
+            ids=ids,
+            offset=offset,
+            limit=limit,
+        )
+
+    for conversation in response["results"]:
+        click.echo(json.dumps(conversation, indent=2))
+
+
+@conversations.command()
+@click.argument("id", required=True, type=str)
+@pass_context
+async def retrieve(ctx, id):
+    """Retrieve a collection by ID."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.conversations.retrieve(id=id)
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@conversations.command()
+@click.argument("id", required=True, type=str)
+@pass_context
+async def delete(ctx, id):
+    """Delete a collection by ID."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.conversations.delete(id=id)
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@conversations.command()
+@click.argument("id", required=True, type=str)
+@pass_context
+async def list_branches(ctx, id):
+    """List all branches in a conversation."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.conversations.list_branches(
+            id=id,
+        )
+
+    for branch in response["results"]:
+        click.echo(json.dumps(branch, indent=2))
+
+
+@conversations.command()
+@click.argument("id", required=True, type=str)
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
+)
+@pass_context
+async def list_users(ctx, id, offset, limit):
+    """Get an overview of collections."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.collections.list_users(
+            id=id,
+            offset=offset,
+            limit=limit,
+        )
+
+    for user in response["results"]:
+        click.echo(json.dumps(user, indent=2))

+ 138 - 0
cli/commands/database.py

@@ -0,0 +1,138 @@
+import sys
+
+import asyncclick as click
+
+from ..utils.database_utils import (
+    check_database_connection,
+    get_database_url_from_env,
+    run_alembic_command,
+)
+
+
+@click.group()
+def db():
+    """Database management commands."""
+    pass
+
+
+@db.command()
+@click.option(
+    "--schema", help="Schema name to operate on (defaults to R2R_PROJECT_NAME)"
+)
+async def history(schema):
+    """Show database migration history for a specific schema."""
+    try:
+        db_url = get_database_url_from_env(False)
+        if not await check_database_connection(db_url):
+            click.secho(
+                "Database connection failed. Please check your environment variables.",
+                fg="red",
+            )
+            sys.exit(1)
+
+        result = await run_alembic_command("history", schema_name=schema)
+        if result != 0:
+            click.secho("Failed to get migration history.", fg="red")
+            sys.exit(1)
+    except Exception as e:
+        click.secho(f"Error getting migration history: {str(e)}", fg="red")
+        sys.exit(1)
+
+
+@db.command()
+@click.option(
+    "--schema", help="Schema name to operate on (defaults to R2R_PROJECT_NAME)"
+)
+async def current(schema):
+    """Show current database revision for a specific schema."""
+    try:
+        db_url = get_database_url_from_env(False)
+        if not await check_database_connection(db_url):
+            click.secho(
+                "Database connection failed. Please check your environment variables.",
+                fg="red",
+            )
+            sys.exit(1)
+
+        result = await run_alembic_command("current", schema_name=schema)
+        if result != 0:
+            click.secho("Failed to get current revision.", fg="red")
+            sys.exit(1)
+    except Exception as e:
+        click.secho(f"Error getting current revision: {str(e)}", fg="red")
+        sys.exit(1)
+
+
+@db.command()
+@click.option(
+    "--schema", help="Schema name to operate on (defaults to R2R_PROJECT_NAME)"
+)
+@click.option("--revision", help="Upgrade to a specific revision")
+async def upgrade(schema, revision):
+    """Upgrade database schema to the latest revision or a specific revision."""
+    try:
+        db_url = get_database_url_from_env(False)
+        if not await check_database_connection(db_url):
+            click.secho(
+                "Database connection failed. Please check your environment variables.",
+                fg="red",
+            )
+            sys.exit(1)
+
+        click.echo(
+            f"Running database upgrade for schema {schema or 'default'}..."
+        )
+        print(f"Upgrading revision = {revision}")
+        command = f"upgrade {revision}" if revision else "upgrade"
+        result = await run_alembic_command(command, schema_name=schema)
+
+        if result == 0:
+            click.secho("Database upgrade completed successfully.", fg="green")
+        else:
+            click.secho("Database upgrade failed.", fg="red")
+            sys.exit(1)
+
+    except Exception as e:
+        click.secho(f"Unexpected error: {str(e)}", fg="red")
+        sys.exit(1)
+
+
+@db.command()
+@click.option(
+    "--schema", help="Schema name to operate on (defaults to R2R_PROJECT_NAME)"
+)
+@click.option("--revision", help="Downgrade to a specific revision")
+async def downgrade(schema, revision):
+    """Downgrade database schema to the previous revision or a specific revision."""
+    if not revision:
+        if not click.confirm(
+            "No revision specified. This will downgrade the database by one revision. Continue?"
+        ):
+            return
+
+    try:
+        db_url = get_database_url_from_env(log=False)
+        if not await check_database_connection(db_url):
+            click.secho(
+                "Database connection failed. Please check your environment variables.",
+                fg="red",
+            )
+            sys.exit(1)
+
+        click.echo(
+            f"Running database downgrade for schema {schema or 'default'}..."
+        )
+        command = f"downgrade {revision}" if revision else "downgrade"
+        result = await run_alembic_command(command, schema_name=schema)
+
+        if result == 0:
+            click.secho(
+                "Database downgrade completed successfully.", fg="green"
+            )
+        else:
+            click.secho("Database downgrade failed.", fg="red")
+            sys.exit(1)
+
+    except Exception as e:
+        click.secho(f"Unexpected error: {str(e)}", fg="red")
+        sys.exit(1)
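
The `db` commands above delegate to helpers in `cli/utils/database_utils.py` (225 lines, not shown in this excerpt). As a rough, hypothetical sketch of what `run_alembic_command` could look like when built on Alembic's command API (the schema option name is an assumption; the real helper also provides `get_database_url_from_env` and `check_database_connection`):

```python
# Hypothetical sketch only; not the actual cli/utils/database_utils.py.
from typing import Optional

from alembic import command
from alembic.config import Config


async def run_alembic_command(cmd: str, schema_name: Optional[str] = None) -> int:
    cfg = Config("alembic.ini")
    if schema_name:
        # Assumption: env.py reads this main option to pick the target schema.
        cfg.set_main_option("schema_name", schema_name)
    try:
        if cmd == "history":
            command.history(cfg)
        elif cmd == "current":
            command.current(cfg)
        elif cmd.startswith("upgrade"):
            _, _, rev = cmd.partition(" ")
            command.upgrade(cfg, rev or "head")
        elif cmd.startswith("downgrade"):
            _, _, rev = cmd.partition(" ")
            command.downgrade(cfg, rev or "-1")
        return 0
    except Exception:
        return 1
```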

+ 393 - 0
cli/commands/documents.py

@@ -0,0 +1,393 @@
+import json
+import os
+import tempfile
+import uuid
+from urllib.parse import urlparse
+
+import asyncclick as click
+import requests
+from asyncclick import pass_context
+
+from cli.utils.param_types import JSON
+from cli.utils.timer import timer
+from r2r import R2RAsyncClient
+
+
+@click.group()
+def documents():
+    """Documents commands."""
+    pass
+
+
+@documents.command()
+@click.argument(
+    "file_paths", nargs=-1, required=True, type=click.Path(exists=True)
+)
+@click.option("--ids", multiple=True, help="Document IDs for ingestion")
+@click.option(
+    "--metadatas", type=JSON, help="Metadatas for ingestion as a JSON string"
+)
+@click.option(
+    "--run-without-orchestration", is_flag=True, help="Run with orchestration"
+)
+@pass_context
+async def create(ctx, file_paths, ids, metadatas, run_without_orchestration):
+    """Ingest files into R2R."""
+    client: R2RAsyncClient = ctx.obj
+    run_with_orchestration = not run_without_orchestration
+    responses = []
+
+    for idx, file_path in enumerate(file_paths):
+        with timer():
+            current_id = [ids[idx]] if ids and idx < len(ids) else None
+            current_metadata = (
+                metadatas[idx] if metadatas and idx < len(metadatas) else None
+            )
+
+            click.echo(
+                f"Processing file {idx + 1}/{len(file_paths)}: {file_path}"
+            )
+            response = await client.documents.create(
+                file_path=file_path,
+                metadata=current_metadata,
+                id=current_id,
+                run_with_orchestration=run_with_orchestration,
+            )
+            responses.append(response)
+            click.echo(json.dumps(response, indent=2))
+            click.echo("-" * 40)
+
+    click.echo(f"\nProcessed {len(responses)} files successfully.")
+
+
+@documents.command()
+@click.argument("file_path", required=True, type=click.Path(exists=True))
+@click.option("--id", required=True, help="Existing document ID to update")
+@click.option(
+    "--metadata", type=JSON, help="Metadatas for ingestion as a JSON string"
+)
+@click.option(
+    "--run-without-orchestration", is_flag=True, help="Run with orchestration"
+)
+@pass_context
+async def update(ctx, file_path, id, metadata, run_without_orchestration):
+    """Update an existing file in R2R."""
+    client: R2RAsyncClient = ctx.obj
+    run_with_orchestration = not run_without_orchestration
+    responses = []
+
+    with timer():
+        click.echo(f"Updating file {id}: {file_path}")
+        response = await client.documents.update(
+            file_path=file_path,
+            metadata=metadata,
+            id=id,
+            run_with_orchestration=run_with_orchestration,
+        )
+        responses.append(response)
+        click.echo(json.dumps(response, indent=2))
+        click.echo("-" * 40)
+
+    click.echo(f"Updated file {id} file successfully.")
+
+
+@documents.command()
+@click.argument("id", required=True, type=str)
+@pass_context
+async def retrieve(ctx, id):
+    """Retrieve a document by ID."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.documents.retrieve(id=id)
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@documents.command()
+@click.argument("id", required=True, type=str)
+@pass_context
+async def delete(ctx, id):
+    """Delete a document by ID."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.documents.delete(id=id)
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@documents.command()
+@click.argument("id", required=True, type=str)
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
+)
+@pass_context
+async def list_chunks(ctx, id, offset, limit):
+    """List collections for a specific document."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.documents.list_chunks(
+            id=id,
+            offset=offset,
+            limit=limit,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@documents.command()
+@click.argument("id", required=True, type=str)
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
+)
+@pass_context
+async def list_collections(ctx, id, offset, limit):
+    """List collections for a specific document."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.documents.list_collections(
+            id=id,
+            offset=offset,
+            limit=limit,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+# TODO
+async def ingest_files_from_urls(client, urls):
+    """Download and ingest files from given URLs."""
+    files_to_ingest = []
+    metadatas = []
+    document_ids = []
+    temp_files = []
+
+    try:
+        for url in urls:
+            filename = os.path.basename(urlparse(url).path)
+            is_pdf = filename.lower().endswith(".pdf")
+
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="wb" if is_pdf else "w+",
+                delete=False,
+                suffix=f"_{filename}",
+            )
+            temp_files.append(temp_file)
+
+            response = requests.get(url)
+            response.raise_for_status()
+            if is_pdf:
+                temp_file.write(response.content)
+            else:
+                temp_file.write(response.text)
+            temp_file.close()
+
+            files_to_ingest.append(temp_file.name)
+            metadatas.append({"title": filename})
+            # TODO: use the utils function generate_document_id
+            document_ids.append(str(uuid.uuid5(uuid.NAMESPACE_DNS, url)))
+
+        for it, file in enumerate(files_to_ingest):
+            click.echo(f"Ingesting file: {file}")
+            response = await client.documents.create(
+                file, metadata=metadatas[it], id=document_ids[it]
+            )
+
+        return response["results"]
+    finally:
+        # Clean up temporary files
+        for temp_file in temp_files:
+            os.unlink(temp_file.name)
+
+
+# Missing CLI Commands
+@documents.command()
+@click.argument("id", required=True, type=str)
+@click.option("--run-type", help="Extraction run type (estimate or run)")
+@click.option("--settings", type=JSON, help="Extraction settings as JSON")
+@click.option(
+    "--run-without-orchestration",
+    is_flag=True,
+    help="Run without orchestration",
+)
+@pass_context
+async def extract(ctx, id, run_type, settings, run_without_orchestration):
+    """Extract entities and relationships from a document."""
+    client: R2RAsyncClient = ctx.obj
+    run_with_orchestration = not run_without_orchestration
+
+    with timer():
+        response = await client.documents.extract(
+            id=id,
+            run_type=run_type,
+            settings=settings,
+            run_with_orchestration=run_with_orchestration,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@documents.command()
+@click.argument("id", required=True, type=str)
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of items to return. Defaults to 100.",
+)
+@click.option(
+    "--include-embeddings",
+    is_flag=True,
+    help="Include embeddings in response",
+)
+@pass_context
+async def list_entities(ctx, id, offset, limit, include_embeddings):
+    """List entities extracted from a document."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.documents.list_entities(
+            id=id,
+            offset=offset,
+            limit=limit,
+            include_embeddings=include_embeddings,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@documents.command()
+@click.argument("id", required=True, type=str)
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of items to return. Defaults to 100.",
+)
+@click.option(
+    "--entity-names",
+    multiple=True,
+    help="Filter by entity names",
+)
+@click.option(
+    "--relationship-types",
+    multiple=True,
+    help="Filter by relationship types",
+)
+@pass_context
+async def list_relationships(
+    ctx, id, offset, limit, entity_names, relationship_types
+):
+    """List relationships extracted from a document."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.documents.list_relationships(
+            id=id,
+            offset=offset,
+            limit=limit,
+            entity_names=list(entity_names) if entity_names else None,
+            relationship_types=(
+                list(relationship_types) if relationship_types else None
+            ),
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@documents.command()
+@click.option(
+    "--v2", is_flag=True, help="use aristotle_v2.txt (a smaller file)"
+)
+@click.option(
+    "--v3", is_flag=True, help="use aristotle_v3.txt (a larger file)"
+)
+@pass_context
+async def create_sample(ctx, v2=True, v3=False):
+    """Ingest the first sample file into R2R."""
+    sample_file_url = f"https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/aristotle.txt"
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await ingest_files_from_urls(client, [sample_file_url])
+    click.echo(
+        f"Sample file ingestion completed. Ingest files response:\n\n{response}"
+    )
+
+
+@documents.command()
+@pass_context
+async def create_samples(ctx):
+    """Ingest multiple sample files into R2R."""
+    client: R2RAsyncClient = ctx.obj
+    urls = [
+        "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_3.html",
+        "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_4.html",
+        "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_5.html",
+        "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/lyft_2021.pdf",
+        "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/uber_2021.pdf",
+        "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/got.txt",
+        "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_1.html",
+        "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_2.html",
+        "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/aristotle.txt",
+    ]
+    with timer():
+        response = await ingest_files_from_urls(client, urls)
+
+    click.echo(
+        f"Sample files ingestion completed. Ingest files response:\n\n{response}"
+    )
+
+
+@documents.command()
+@click.option("--ids", multiple=True, help="Document IDs to fetch")
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
+)
+@pass_context
+async def list(ctx, ids, offset, limit):
+    """Get an overview of documents."""
+    client: R2RAsyncClient = ctx.obj
+    ids = list(ids) if ids else None
+
+    with timer():
+        response = await client.documents.list(
+            ids=ids,
+            offset=offset,
+            limit=limit,
+        )
+
+    for document in response["results"]:
+        click.echo(document)
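
The `--metadatas`, `--metadata`, and `--settings` options above rely on the `JSON` param type imported from `cli/utils/param_types.py` (21 lines, not shown in this excerpt). A plausible sketch, assuming it is a standard click `ParamType` that runs `json.loads` over the raw option string:

```python
# Hypothetical reconstruction of cli/utils/param_types.py; the real file in
# this commit may differ in details.
import json

import asyncclick as click


class JsonParamType(click.ParamType):
    """Convert a raw CLI string into a Python object with json.loads."""

    name = "json"

    def convert(self, value, param, ctx):
        if isinstance(value, (dict, list)):
            return value  # already parsed (e.g. passed programmatically)
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            self.fail(f"{value!r} is not valid JSON", param, ctx)


JSON = JsonParamType()
```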

+ 382 - 0
cli/commands/graphs.py

@@ -0,0 +1,382 @@
+import json
+
+import asyncclick as click
+from asyncclick import pass_context
+
+from cli.utils.param_types import JSON
+from cli.utils.timer import timer
+from r2r import R2RAsyncClient
+
+
+@click.group()
+def graphs():
+    """Graphs commands."""
+    pass
+
+
+@graphs.command()
+@click.option(
+    "--collection-ids", multiple=True, help="Collection IDs to fetch"
+)
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of graphs to return. Defaults to 100.",
+)
+@pass_context
+async def list(ctx, collection_ids, offset, limit):
+    """List available graphs."""
+    client: R2RAsyncClient = ctx.obj
+    collection_ids = list(collection_ids) if collection_ids else None
+
+    with timer():
+        response = await client.graphs.list(
+            collection_ids=collection_ids,
+            offset=offset,
+            limit=limit,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@pass_context
+async def retrieve(ctx, collection_id):
+    """Retrieve a specific graph by collection ID."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.retrieve(collection_id=collection_id)
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@pass_context
+async def reset(ctx, collection_id):
+    """Reset a graph, removing all its data."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.reset(collection_id=collection_id)
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.option("--name", help="New name for the graph")
+@click.option("--description", help="New description for the graph")
+@pass_context
+async def update(ctx, collection_id, name, description):
+    """Update graph information."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.update(
+            collection_id=collection_id,
+            name=name,
+            description=description,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of entities to return. Defaults to 100.",
+)
+@pass_context
+async def list_entities(ctx, collection_id, offset, limit):
+    """List entities in a graph."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.list_entities(
+            collection_id=collection_id,
+            offset=offset,
+            limit=limit,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.argument("entity_id", required=True, type=str)
+@pass_context
+async def get_entity(ctx, collection_id, entity_id):
+    """Get entity information from a graph."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.get_entity(
+            collection_id=collection_id,
+            entity_id=entity_id,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.argument("entity_id", required=True, type=str)
+@pass_context
+async def remove_entity(ctx, collection_id, entity_id):
+    """Remove an entity from a graph."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.remove_entity(
+            collection_id=collection_id,
+            entity_id=entity_id,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of relationships to return. Defaults to 100.",
+)
+@pass_context
+async def list_relationships(ctx, collection_id, offset, limit):
+    """List relationships in a graph."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.list_relationships(
+            collection_id=collection_id,
+            offset=offset,
+            limit=limit,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.argument("relationship_id", required=True, type=str)
+@pass_context
+async def get_relationship(ctx, collection_id, relationship_id):
+    """Get relationship information from a graph."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.get_relationship(
+            collection_id=collection_id,
+            relationship_id=relationship_id,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.argument("relationship_id", required=True, type=str)
+@pass_context
+async def remove_relationship(ctx, collection_id, relationship_id):
+    """Remove a relationship from a graph."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.remove_relationship(
+            collection_id=collection_id,
+            relationship_id=relationship_id,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.option(
+    "--settings", required=True, type=JSON, help="Build settings as JSON"
+)
+@click.option("--run-type", default="estimate", help="Type of build to run")
+@click.option(
+    "--run-without-orchestration",
+    is_flag=True,
+    help="Run without orchestration",
+)
+@pass_context
+async def build(
+    ctx, collection_id, settings, run_type, run_without_orchestration
+):
+    """Build a graph with specified settings."""
+    client: R2RAsyncClient = ctx.obj
+    run_with_orchestration = not run_without_orchestration
+
+    with timer():
+        response = await client.graphs.build(
+            collection_id=collection_id,
+            settings=settings,
+            run_type=run_type,
+            run_with_orchestration=run_with_orchestration,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of communities to return. Defaults to 100.",
+)
+@pass_context
+async def list_communities(ctx, collection_id, offset, limit):
+    """List communities in a graph."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.list_communities(
+            collection_id=collection_id,
+            offset=offset,
+            limit=limit,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.argument("community_id", required=True, type=str)
+@pass_context
+async def get_community(ctx, collection_id, community_id):
+    """Get community information from a graph."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.get_community(
+            collection_id=collection_id,
+            community_id=community_id,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.argument("community_id", required=True, type=str)
+@click.option("--name", help="New name for the community")
+@click.option("--summary", help="New summary for the community")
+@click.option(
+    "--findings",
+    type=JSON,
+    help="New findings for the community as JSON array",
+)
+@click.option("--rating", type=int, help="New rating for the community")
+@click.option(
+    "--rating-explanation", help="New rating explanation for the community"
+)
+@click.option("--level", type=int, help="New level for the community")
+@click.option(
+    "--attributes", type=JSON, help="New attributes for the community as JSON"
+)
+@pass_context
+async def update_community(
+    ctx,
+    collection_id,
+    community_id,
+    name,
+    summary,
+    findings,
+    rating,
+    rating_explanation,
+    level,
+    attributes,
+):
+    """Update community information."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.update_community(
+            collection_id=collection_id,
+            community_id=community_id,
+            name=name,
+            summary=summary,
+            findings=findings,
+            rating=rating,
+            rating_explanation=rating_explanation,
+            level=level,
+            attributes=attributes,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.argument("community_id", required=True, type=str)
+@pass_context
+async def delete_community(ctx, collection_id, community_id):
+    """Delete a community from a graph."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.delete_community(
+            collection_id=collection_id,
+            community_id=community_id,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@pass_context
+async def pull(ctx, collection_id):
+    """Pull documents into a graph."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.pull(collection_id=collection_id)
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@graphs.command()
+@click.argument("collection_id", required=True, type=str)
+@click.argument("document_id", required=True, type=str)
+@pass_context
+async def remove_document(ctx, collection_id, document_id):
+    """Remove a document from a graph."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.graphs.remove_document(
+            collection_id=collection_id,
+            document_id=document_id,
+        )
+
+    click.echo(json.dumps(response, indent=2))

+ 89 - 0
cli/commands/indices.py

@@ -0,0 +1,89 @@
+import json
+
+import asyncclick as click
+from asyncclick import pass_context
+
+from cli.utils.timer import timer
+from r2r import R2RAsyncClient
+
+
+@click.group()
+def indices():
+    """Indices commands."""
+    pass
+
+
+@indices.command()
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
+)
+@pass_context
+async def list(ctx, offset, limit):
+    """Get an overview of indices."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.indices.list(
+            offset=offset,
+            limit=limit,
+        )
+
+    for index in response["results"]:
+        click.echo(json.dumps(index, indent=2))
+
+
+@indices.command()
+@click.argument("index_name", required=True, type=str)
+@click.argument("table_name", required=True, type=str)
+@pass_context
+async def retrieve(ctx, index_name, table_name):
+    """Retrieve an index by name."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.indices.retrieve(
+            index_name=index_name,
+            table_name=table_name,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@indices.command()
+@click.argument("index_name", required=True, type=str)
+@click.argument("table_name", required=True, type=str)
+@pass_context
+async def delete(ctx, index_name, table_name):
+    """Delete an index by name."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.indices.delete(
+            index_name=index_name,
+            table_name=table_name,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@indices.command()
+@click.argument("id", required=True, type=str)
+@pass_context
+async def list_branches(ctx, id):
+    """List all branches in a conversation."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.indices.list_branches(
+            id=id,
+        )
+
+    for user in response["results"]:
+        click.echo(json.dumps(user, indent=2))

+ 60 - 0
cli/commands/prompts.py

@@ -0,0 +1,60 @@
+import json
+
+import asyncclick as click
+from asyncclick import pass_context
+
+from cli.utils.timer import timer
+from r2r import R2RAsyncClient
+
+
+@click.group()
+def prompts():
+    """Prompts commands."""
+    pass
+
+
+@prompts.command()
+@pass_context
+async def list(ctx):
+    """Get an overview of prompts."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.prompts.list()
+
+    for prompt in response["results"]:
+        click.echo(json.dumps(prompt, indent=2))
+
+
+@prompts.command()
+@click.argument("name", type=str)
+@click.option("--inputs", default=None, type=str)
+@click.option("--prompt-override", default=None, type=str)
+@pass_context
+async def retrieve(ctx, name, inputs, prompt_override):
+    """Retrieve an prompts by name."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.prompts.retrieve(
+            name=name,
+            inputs=inputs,
+            prompt_override=prompt_override,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@prompts.command()
+@click.argument("name", required=True, type=str)
+@pass_context
+async def delete(ctx, name):
+    """Delete an index by name."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.prompts.delete(
+            name=name,
+        )
+
+    click.echo(json.dumps(response, indent=2))

+ 189 - 0
cli/commands/retrieval.py

@@ -0,0 +1,189 @@
+import json
+
+import asyncclick as click
+from asyncclick import pass_context
+
+from cli.utils.param_types import JSON
+from cli.utils.timer import timer
+from r2r import R2RAsyncClient
+
+
+@click.group()
+def retrieval():
+    """Retrieval commands."""
+    pass
+
+
+@retrieval.command()
+@click.option(
+    "--query", prompt="Enter your search query", help="The search query"
+)
+@click.option(
+    "--limit", default=None, help="Number of search results to return"
+)
+@click.option(
+    "--use-hybrid-search",
+    default=None,
+    help="Perform hybrid search? Equivalent to `use-semantic-search` and `use-fulltext-search`",
+)
+@click.option(
+    "--use-semantic-search", default=None, help="Perform semantic search?"
+)
+@click.option(
+    "--use-fulltext-search", default=None, help="Perform fulltext search?"
+)
+@click.option(
+    "--filters",
+    type=JSON,
+    help="""Filters to apply to the vector search as a JSON, e.g. --filters='{"document_id":{"$in":["9fbe403b-c11c-5aae-8ade-ef22980c3ad1", "3e157b3a-8469-51db-90d9-52e7d896b49b"]}}'""",
+)
+@click.option(
+    "--search-strategy",
+    type=str,
+    default="vanilla",
+    help="Vanilla RAG or complex method like query fusion or HyDE.",
+)
+@click.option(
+    "--graph-search-enabled", default=None, help="Use knowledge graph search?"
+)
+@click.option(
+    "--chunk-search-enabled",
+    default=None,
+    help="Use search over document chunks?",
+)
+@pass_context
+async def search(ctx, query, **kwargs):
+    """Perform a search query."""
+    client: R2RAsyncClient = ctx.obj
+    search_settings = {
+        k: v
+        for k, v in kwargs.items()
+        if k
+        in [
+            "filters",
+            "limit",
+            "search_strategy",
+            "use_hybrid_search",
+            "use_semantic_search",
+            "use_fulltext_search",
+            "search_strategy",
+        ]
+        and v is not None
+    }
+    graph_search_enabled = kwargs.get("graph_search_enabled")
+    if graph_search_enabled is not None:
+        search_settings["graph_settings"] = {"enabled": graph_search_enabled}
+
+    chunk_search_enabled = kwargs.get("chunk_search_enabled")
+    if chunk_search_enabled is not None:
+        search_settings["chunk_settings"] = {"enabled": chunk_search_enabled}
+
+    with timer():
+        results = await client.retrieval.search(
+            query,
+            "custom",
+            search_settings,
+        )
+
+        if isinstance(results, dict) and "results" in results:
+            results = results["results"]
+
+        if "chunk_search_results" in results:
+            click.echo("Vector search results:")
+            for result in results["chunk_search_results"]:
+                click.echo(json.dumps(result, indent=2))
+
+        if (
+            "graph_search_results" in results
+            and results["graph_search_results"]
+        ):
+            click.echo("KG search results:")
+            for result in results["graph_search_results"]:
+                click.echo(json.dumps(result, indent=2))
+
+
+@retrieval.command()
+@click.option(
+    "--query", prompt="Enter your search query", help="The search query"
+)
+@click.option(
+    "--limit", default=None, help="Number of search results to return"
+)
+@click.option(
+    "--use-hybrid-search",
+    default=None,
+    help="Perform hybrid search? Equivalent to `use-semantic-search` and `use-fulltext-search`",
+)
+@click.option(
+    "--use-semantic-search", default=None, help="Perform semantic search?"
+)
+@click.option(
+    "--use-fulltext-search", default=None, help="Perform fulltext search?"
+)
+@click.option(
+    "--filters",
+    type=JSON,
+    help="""Filters to apply to the vector search as a JSON, e.g. --filters='{"document_id":{"$in":["9fbe403b-c11c-5aae-8ade-ef22980c3ad1", "3e157b3a-8469-51db-90d9-52e7d896b49b"]}}'""",
+)
+@click.option(
+    "--search-strategy",
+    type=str,
+    default="vanilla",
+    help="Vanilla RAG or complex method like query fusion or HyDE.",
+)
+@click.option(
+    "--graph-search-enabled", default=None, help="Use knowledge graph search?"
+)
+@click.option(
+    "--chunk-search-enabled",
+    default=None,
+    help="Use search over document chunks?",
+)
+@click.option("--stream", is_flag=True, help="Stream the RAG response")
+@click.option("--rag-model", default=None, help="Model for RAG")
+@pass_context
+async def rag(ctx, query, **kwargs):
+    """Perform a RAG query."""
+    client: R2RAsyncClient = ctx.obj
+    rag_generation_config = {
+        "stream": kwargs.get("stream", False),
+    }
+    if kwargs.get("rag_model"):
+        rag_generation_config["model"] = kwargs["rag_model"]
+
+    search_settings = {
+        k: v
+        for k, v in kwargs.items()
+        if k
+        in [
+            "filters",
+            "limit",
+            "search_strategy",
+            "use_hybrid_search",
+            "use_semantic_search",
+            "use_fulltext_search",
+            "search_strategy",
+        ]
+        and v is not None
+    }
+    graph_search_enabled = kwargs.get("graph_search_enabled")
+    if graph_search_enabled is not None:
+        search_settings["graph_settings"] = {"enabled": graph_search_enabled}
+
+    chunk_search_enabled = kwargs.get("chunk_search_enabled")
+    if chunk_search_enabled is not None:
+        search_settings["chunk_settings"] = {"enabled": chunk_search_enabled}
+
+    with timer():
+        response = await client.retrieval.rag(
+            query=query,
+            rag_generation_config=rag_generation_config,
+            search_settings={**search_settings},
+        )
+
+        if rag_generation_config.get("stream"):
+            async for chunk in response:
+                click.echo(chunk, nl=False)
+            click.echo()
+        else:
+            click.echo(json.dumps(response["results"]["completion"], indent=2))

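To make the payload concrete, here is a minimal sketch of calling the same search endpoint programmatically with a hand-built `search_settings` dictionary, mirroring the CLI logic above (the client constructor default and the example filter values are assumptions):

    import asyncio
    from r2r import R2RAsyncClient

    async def main():
        client = R2RAsyncClient()  # assumption: defaults to the local R2R server
        search_settings = {
            "use_hybrid_search": True,
            "limit": 10,
            # illustrative filter, same shape as the --filters example above
            "filters": {"document_id": {"$in": ["9fbe403b-c11c-5aae-8ade-ef22980c3ad1"]}},
        }
        results = await client.retrieval.search(
            "What does R2R do?",  # query
            "custom",             # search mode, as passed by the CLI above
            search_settings,
        )
        print(results)

    asyncio.run(main())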
+ 435 - 0
cli/commands/system.py

@@ -0,0 +1,435 @@
+import json
+import os
+import platform
+import subprocess
+import sys
+from importlib.metadata import version as get_version
+
+import asyncclick as click
+from asyncclick import pass_context
+from dotenv import load_dotenv
+
+from cli.command_group import cli
+from cli.utils.docker_utils import (
+    bring_down_docker_compose,
+    remove_r2r_network,
+    run_docker_serve,
+    run_local_serve,
+    wait_for_container_health,
+)
+from cli.utils.timer import timer
+from r2r import R2RAsyncClient
+
+
+@click.group()
+def system():
+    """System commands."""
+    pass
+
+
+@cli.command()
+@pass_context
+async def health(ctx):
+    """Check the health of the server."""
+    client: R2RAsyncClient = ctx.obj
+    with timer():
+        response = await client.system.health()
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@system.command()
+@click.option("--run-type-filter", help="Filter for log types")
+@click.option(
+    "--offset", default=None, help="Pagination offset. Default is None."
+)
+@click.option(
+    "--limit", default=None, help="Pagination limit. Defaults to 100."
+)
+@pass_context
+async def logs(ctx, run_type_filter, offset, limit):
+    """Retrieve logs with optional type filter."""
+    client: R2RAsyncClient = ctx.obj
+    with timer():
+        response = await client.system.logs(
+            run_type_filter=run_type_filter,
+            offset=offset,
+            limit=limit,
+        )
+
+    for log in response["results"]:
+        click.echo(f"Run ID: {log['run_id']}")
+        click.echo(f"Run Type: {log['run_type']}")
+        click.echo(f"Timestamp: {log['timestamp']}")
+        click.echo(f"User ID: {log['user_id']}")
+        click.echo("Entries:")
+        for entry in log["entries"]:
+            click.echo(f"  - {entry['key']}: {entry['value'][:100]}")
+        click.echo("---")
+
+    click.echo(f"Total runs: {len(response['results'])}")
+
+
+@system.command()
+@pass_context
+async def settings(ctx):
+    """Retrieve application settings."""
+    client: R2RAsyncClient = ctx.obj
+    with timer():
+        response = await client.system.settings()
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@system.command()
+@pass_context
+async def status(ctx):
+    """Get statistics about the server, including the start time, uptime, CPU usage, and memory usage."""
+    client: R2RAsyncClient = ctx.obj
+    with timer():
+        response = await client.system.status()
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@cli.command()
+@click.option("--host", default=None, help="Host to run the server on")
+@click.option(
+    "--port", default=None, type=int, help="Port to run the server on"
+)
+@click.option("--docker", is_flag=True, help="Run using Docker")
+@click.option(
+    "--full",
+    is_flag=True,
+    help="Run the full R2R compose? This includes Hatchet and Unstructured.",
+)
+@click.option(
+    "--project-name", default=None, help="Project name for Docker deployment"
+)
+@click.option(
+    "--config-name", default=None, help="Name of the R2R configuration to use"
+)
+@click.option(
+    "--config-path",
+    default=None,
+    help="Path to a custom R2R configuration file",
+)
+@click.option(
+    "--build",
+    is_flag=True,
+    default=False,
+    help="Run in debug mode. Only for development.",
+)
+@click.option("--image", help="Docker image to use")
+@click.option(
+    "--image-env",
+    default="prod",
+    help="Which dev environment to pull the image from?",
+)
+@click.option(
+    "--scale",
+    default=None,
+    help="How many instances of the R2R service to run",
+)
+@click.option(
+    "--exclude-postgres",
+    is_flag=True,
+    default=False,
+    help="Excludes creating a Postgres container in the Docker setup.",
+)
+async def serve(
+    host,
+    port,
+    docker,
+    full,
+    project_name,
+    config_name,
+    config_path,
+    build,
+    image,
+    image_env,
+    scale,
+    exclude_postgres,
+):
+    """Start the R2R server."""
+    load_dotenv()
+    click.echo("Spinning up an R2R deployment...")
+
+    if host is None:
+        host = os.getenv("R2R_HOST", "0.0.0.0")
+
+    if port is None:
+        port = int(os.getenv("R2R_PORT", (os.getenv("PORT", "7272"))))
+
+    click.echo(f"Running on {host}:{port}, with docker={docker}")
+
+    if full:
+        click.echo(
+            "Running the full R2R setup which includes `Hatchet` and `Unstructured.io`."
+        )
+
+    if config_path and config_name:
+        raise click.UsageError(
+            "Both `config-path` and `config-name` were provided. Please provide only one."
+        )
+    if config_name and os.path.isfile(config_name):
+        click.echo(
+            "Warning: `config-name` corresponds to an existing file. If you intended a custom config, use `config-path`."
+        )
+
+    if build:
+        click.echo(
+            "`build` flag detected. Building Docker image from local repository..."
+        )
+    if image and image_env:
+        click.echo(
+            "WARNING: Both `image` and `image_env` were provided. Using `image`."
+        )
+    if not image and docker:
+        r2r_version = get_version("r2r")
+
+        version_specific_image = f"ragtoriches/{image_env}:{r2r_version}"
+        latest_image = f"ragtoriches/{image_env}:latest"
+
+        def image_exists(img):
+            try:
+                subprocess.run(
+                    ["docker", "manifest", "inspect", img],
+                    check=True,
+                    capture_output=True,
+                    text=True,
+                )
+                return True
+            except subprocess.CalledProcessError:
+                return False
+
+        if image_exists(version_specific_image):
+            click.echo(f"Using image: {version_specific_image}")
+            image = version_specific_image
+        elif image_exists(latest_image):
+            click.echo(
+                f"Version-specific image not found. Using latest: {latest_image}"
+            )
+            image = latest_image
+        else:
+            click.echo(
+                f"Neither {version_specific_image} nor {latest_image} found in remote registry. Confirm the sanity of your output for `docker manifest inspect ragtoriches/{version_specific_image}` and  `docker manifest inspect ragtoriches/{latest_image}`."
+            )
+            click.echo(
+                "Please pull the required image or build it using the --build flag."
+            )
+            raise click.Abort()
+
+    if docker:
+        os.environ["R2R_IMAGE"] = image
+
+    if build:
+        subprocess.run(
+            ["docker", "build", "-t", image, "-f", "Dockerfile", "."],
+            check=True,
+        )
+
+    if config_path:
+        config_path = os.path.abspath(config_path)
+
+        # For Windows, convert backslashes to forward slashes and prepend /host_mnt/
+        if platform.system() == "Windows":
+            drive, path = os.path.splitdrive(config_path)
+            config_path = f"/host_mnt/{drive[0].lower()}" + path.replace(
+                "\\", "/"
+            )
+
+    if docker:
+        run_docker_serve(
+            host,
+            port,
+            full,
+            project_name,
+            image,
+            config_name,
+            config_path,
+            exclude_postgres,
+            scale,
+        )
+        if (
+            "pytest" in sys.modules
+            or "unittest" in sys.modules
+            or os.environ.get("PYTEST_CURRENT_TEST")
+        ):
+            click.echo("Test environment detected. Skipping browser open.")
+        else:
+            # Open browser after Docker setup is complete
+            import webbrowser
+
+            click.echo("Waiting for all services to become healthy...")
+            if not wait_for_container_health(
+                project_name or ("r2r-full" if full else "r2r"), "r2r"
+            ):
+                click.secho(
+                    "r2r container failed to become healthy.", fg="red"
+                )
+                return
+
+            traefik_port = os.environ.get("R2R_DASHBOARD_PORT", "80")
+            url = f"http://localhost:{traefik_port}"
+
+            click.secho(f"Navigating to R2R application at {url}.", fg="blue")
+            webbrowser.open(url)
+    else:
+        await run_local_serve(host, port, config_name, config_path, full)
+
+
+@cli.command()
+@click.option(
+    "--volumes",
+    is_flag=True,
+    help="Remove named volumes declared in the `volumes` section of the Compose file",
+)
+@click.option(
+    "--remove-orphans",
+    is_flag=True,
+    help="Remove containers for services not defined in the Compose file",
+)
+@click.option(
+    "--project-name",
+    default=None,
+    help="Which Docker Compose project to bring down",
+)
+def docker_down(volumes, remove_orphans, project_name):
+    """Bring down the Docker Compose setup and attempt to remove the network if necessary."""
+
+    if not project_name:
+        print("Bringing down the default R2R Docker setup(s)...")
+        try:
+            bring_down_docker_compose(
+                project_name or "r2r", volumes, remove_orphans
+            )
+        except Exception:
+            pass
+        try:
+            bring_down_docker_compose(
+                project_name or "r2r-full", volumes, remove_orphans
+            )
+        except Exception:
+            pass
+    else:
+        print(f"Bringing down the `{project_name}` R2R Docker setup...")
+        result = bring_down_docker_compose(
+            project_name, volumes, remove_orphans
+        )
+
+        if result != 0:
+            click.echo(
+                f"An error occurred while bringing down the {project_name} Docker Compose setup. Attempting to remove the network..."
+            )
+        else:
+            click.echo(
+                f"{project_name} Docker Compose setup has been successfully brought down."
+            )
+    remove_r2r_network()
+
+
+@cli.command()
+def generate_report():
+    """Generate a system report including R2R version, Docker info, and OS details."""
+
+    # Get R2R version
+    from importlib.metadata import version
+
+    report = {"r2r_version": version("r2r")}
+
+    # Get Docker info
+    try:
+        subprocess.run(
+            ["docker", "version"], check=True, capture_output=True, timeout=5
+        )
+
+        docker_ps_output = subprocess.check_output(
+            ["docker", "ps", "--format", "{{.ID}}\t{{.Names}}\t{{.Status}}"],
+            text=True,
+            timeout=5,
+        ).strip()
+        report["docker_ps"] = [
+            dict(zip(["id", "name", "status"], line.split("\t")))
+            for line in docker_ps_output.split("\n")
+            if line
+        ]
+
+        docker_network_output = subprocess.check_output(
+            ["docker", "network", "ls", "--format", "{{.ID}}\t{{.Name}}"],
+            text=True,
+            timeout=5,
+        ).strip()
+        networks = [
+            dict(zip(["id", "name"], line.split("\t")))
+            for line in docker_network_output.split("\n")
+            if line
+        ]
+
+        report["docker_subnets"] = []
+        for network in networks:
+            inspect_output = subprocess.check_output(
+                [
+                    "docker",
+                    "network",
+                    "inspect",
+                    network["id"],
+                    "--format",
+                    "{{range .IPAM.Config}}{{.Subnet}}{{end}}",
+                ],
+                text=True,
+                timeout=5,
+            ).strip()
+            if subnet := inspect_output:
+                network["subnet"] = subnet
+                report["docker_subnets"].append(network)
+
+    except subprocess.CalledProcessError as e:
+        report["docker_error"] = f"Error running Docker command: {e}"
+    except FileNotFoundError:
+        report["docker_error"] = (
+            "Docker command not found. Is Docker installed and in PATH?"
+        )
+    except subprocess.TimeoutExpired:
+        report["docker_error"] = (
+            "Docker command timed out. Docker might be unresponsive."
+        )
+
+    # Get OS information
+    report["os_info"] = {
+        "system": platform.system(),
+        "release": platform.release(),
+        "version": platform.version(),
+        "machine": platform.machine(),
+        "processor": platform.processor(),
+    }
+
+    click.echo("System Report:")
+    click.echo(json.dumps(report, indent=2))
+
+
+@cli.command()
+def update():
+    """Update the R2R package to the latest version."""
+    try:
+        cmd = [sys.executable, "-m", "pip", "install", "--upgrade", "r2r"]
+
+        click.echo("Updating R2R...")
+        result = subprocess.run(
+            cmd, check=True, capture_output=True, text=True
+        )
+        click.echo(result.stdout)
+        click.echo("R2R has been successfully updated.")
+    except subprocess.CalledProcessError as e:
+        click.echo(f"An error occurred while updating R2R: {e}")
+        click.echo(e.stderr)
+    except Exception as e:
+        click.echo(f"An unexpected error occurred: {e}")
+
+
+@cli.command()
+def version():
+    """Reports the SDK version."""
+    from importlib.metadata import version
+
+    click.echo(json.dumps(version("r2r"), indent=2))

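As a quick smoke test of a running deployment, the same health call the `health` command issues can be made directly from Python (a minimal sketch; the client constructor defaulting to your local server is an assumption):

    import asyncio
    import json

    from r2r import R2RAsyncClient

    async def main():
        client = R2RAsyncClient()  # assumption: defaults to the local R2R server
        response = await client.system.health()
        print(json.dumps(response, indent=2))

    asyncio.run(main())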
+ 143 - 0
cli/commands/users.py

@@ -0,0 +1,143 @@
+import json
+
+import asyncclick as click
+from asyncclick import pass_context
+
+from cli.utils.timer import timer
+from r2r import R2RAsyncClient
+
+
+@click.group()
+def users():
+    """Users commands."""
+    pass
+
+
+@users.command()
+@click.argument("email", required=True, type=str)
+@click.argument("password", required=True, type=str)
+@pass_context
+async def create(ctx, email, password):
+    """Create a new user."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.users.create(email=email, password=password)
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@users.command()
+@click.option("--ids", multiple=True, help="Document IDs to fetch")
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
+)
+@pass_context
+async def list(ctx, ids, offset, limit):
+    """Get an overview of users."""
+    client: R2RAsyncClient = ctx.obj
+    ids = list(ids) if ids else None
+
+    with timer():
+        response = await client.users.list(
+            ids=ids,
+            offset=offset,
+            limit=limit,
+        )
+
+    for user in response["results"]:
+        click.echo(json.dumps(user, indent=2))
+
+
+@users.command()
+@click.argument("id", required=True, type=str)
+@pass_context
+async def retrieve(ctx, id):
+    """Retrieve a user by ID."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.users.retrieve(id=id)
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@users.command()
+@pass_context
+async def me(ctx):
+    """Retrieve the current user."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.users.me()
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@users.command()
+@click.argument("id", required=True, type=str)
+@click.option(
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
+)
+@click.option(
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
+)
+@pass_context
+async def list_collections(ctx, id, offset, limit):
+    """List collections for a specific user."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.users.list_collections(
+            id=id,
+            offset=offset,
+            limit=limit,
+        )
+
+    for collection in response["results"]:
+        click.echo(json.dumps(collection, indent=2))
+
+
+@users.command()
+@click.argument("id", required=True, type=str)
+@click.argument("collection_id", required=True, type=str)
+@pass_context
+async def add_to_collection(ctx, id, collection_id):
+    """Retrieve a user by ID."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.users.add_to_collection(
+            id=id,
+            collection_id=collection_id,
+        )
+
+    click.echo(json.dumps(response, indent=2))
+
+
+@users.command()
+@click.argument("id", required=True, type=str)
+@click.argument("collection_id", required=True, type=str)
+@pass_context
+async def remove_from_collection(ctx, id, collection_id):
+    """Retrieve a user by ID."""
+    client: R2RAsyncClient = ctx.obj
+
+    with timer():
+        response = await client.users.remove_from_collection(
+            id=id,
+            collection_id=collection_id,
+        )
+
+    click.echo(json.dumps(response, indent=2))

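The user and collection commands above map one-to-one onto SDK calls; a minimal programmatic sketch of creating a user (the client constructor default and the credentials are illustrative assumptions):

    import asyncio
    from r2r import R2RAsyncClient

    async def main():
        client = R2RAsyncClient()  # assumption: defaults to the local R2R server
        user = await client.users.create(
            email="alice@example.com",  # illustrative credentials
            password="change-me",
        )
        print(user)
        # To attach the user to a collection, pass a real collection ID to
        # client.users.add_to_collection(id=..., collection_id=...) as the
        # add_to_collection command above does.

    asyncio.run(main())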
+ 61 - 0
cli/main.py

@@ -0,0 +1,61 @@
+from cli.command_group import cli
+from cli.commands import (
+    collections,
+    conversations,
+    database,
+    documents,
+    graphs,
+    indices,
+    prompts,
+    retrieval,
+    system,
+    users,
+)
+from cli.utils.telemetry import posthog, telemetry
+
+
+def add_command_with_telemetry(command):
+    cli.add_command(telemetry(command))
+
+
+# Collections, conversations, documents, and graphs
+add_command_with_telemetry(collections.collections)
+add_command_with_telemetry(conversations.conversations)
+add_command_with_telemetry(documents.documents)
+add_command_with_telemetry(graphs.graphs)
+
+# Indices, prompts, retrieval, users, and system
+add_command_with_telemetry(indices.indices)
+add_command_with_telemetry(prompts.prompts)
+add_command_with_telemetry(retrieval.retrieval)
+add_command_with_telemetry(users.users)
+add_command_with_telemetry(system.system)
+
+
+# Database
+add_command_with_telemetry(database.db)
+add_command_with_telemetry(database.upgrade)
+add_command_with_telemetry(database.downgrade)
+add_command_with_telemetry(database.current)
+add_command_with_telemetry(database.history)
+
+
+def main():
+    try:
+        cli()
+    except SystemExit:
+        # Silently exit without printing the traceback
+        pass
+    except Exception as e:
+        # Handle other exceptions if needed
+        print("CLI error: An error occurred")
+        raise e
+    finally:
+        # Ensure all events are flushed before exiting
+        if posthog:
+            posthog.flush()
+            posthog.shutdown()
+
+
+if __name__ == "__main__":
+    main()

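The registration pattern above extends naturally to additional command groups; a hypothetical sketch (the `example` group name is invented for illustration):

    import asyncclick as click

    from cli.command_group import cli
    from cli.utils.telemetry import telemetry

    @click.group()
    def example():
        """Example commands (hypothetical)."""
        pass

    # Wrap the group so its invocations are captured, then register it.
    cli.add_command(telemetry(example))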
+ 0 - 0
cli/utils/__init__.py


+ 225 - 0
cli/utils/database_utils.py

@@ -0,0 +1,225 @@
+import logging.config
+import os
+import sys
+from pathlib import Path
+from typing import Optional
+
+import alembic.config
+import asyncclick as click
+from alembic import command as alembic_command
+from sqlalchemy import create_engine, text
+from sqlalchemy.exc import OperationalError
+
+
+def get_default_db_vars() -> dict[str, str]:
+    """Get default database environment variables."""
+    return {
+        "R2R_POSTGRES_HOST": "localhost",
+        "R2R_POSTGRES_PORT": "5432",
+        "R2R_POSTGRES_DBNAME": "postgres",
+        "R2R_POSTGRES_USER": "postgres",
+        "R2R_POSTGRES_PASSWORD": "postgres",
+        "R2R_PROJECT_NAME": "r2r_default",
+    }
+
+
+def get_schema_version_table(schema_name: str) -> str:
+    """Get the schema-specific version of alembic_version table name."""
+    return f"{schema_name}_alembic_version"
+
+
+def get_database_url_from_env(log: bool = True) -> str:
+    """Construct database URL from environment variables."""
+    env_vars = {
+        k: os.environ.get(k, v) for k, v in get_default_db_vars().items()
+    }
+
+    if log:
+        for k, v in env_vars.items():
+            click.secho(
+                f"Using value for {k}: {v}",
+                fg="yellow" if v == get_default_db_vars()[k] else "green",
+            )
+
+    return (
+        f"postgresql://{env_vars['R2R_POSTGRES_USER']}:{env_vars['R2R_POSTGRES_PASSWORD']}"
+        f"@{env_vars['R2R_POSTGRES_HOST']}:{env_vars['R2R_POSTGRES_PORT']}"
+        f"/{env_vars['R2R_POSTGRES_DBNAME']}"
+    )
+
+
+def ensure_schema_exists(engine, schema_name: str):
+    """Create schema if it doesn't exist and set up schema-specific version table."""
+    with engine.begin() as conn:
+        # Create schema if it doesn't exist
+        conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name}"))
+
+        # Move or create alembic_version table in the specific schema
+        version_table = get_schema_version_table(schema_name)
+        conn.execute(
+            text(
+                f"""
+            CREATE TABLE IF NOT EXISTS {schema_name}.{version_table} (
+                version_num VARCHAR(32) NOT NULL
+            )
+        """
+            )
+        )
+
+
+def check_current_revision(engine, schema_name: str) -> Optional[str]:
+    """Check the current revision in the version table."""
+    version_table = get_schema_version_table(schema_name)
+    with engine.connect() as conn:
+        result = conn.execute(
+            text(f"SELECT version_num FROM {schema_name}.{version_table}")
+        ).fetchone()
+        return result[0] if result else None
+
+
+async def check_database_connection(db_url: str) -> bool:
+    """Check if we can connect to the database."""
+    try:
+        engine = create_engine(db_url)
+        with engine.connect():
+            return True
+    except OperationalError as e:
+        click.secho(f"Could not connect to database: {str(e)}", fg="red")
+        if "Connection refused" in str(e):
+            click.secho(
+                "Make sure PostgreSQL is running and accessible with the provided credentials.",
+                fg="yellow",
+            )
+        return False
+    except Exception as e:
+        click.secho(
+            f"Unexpected error checking database connection: {str(e)}",
+            fg="red",
+        )
+        return False
+
+
+def create_schema_config(
+    project_root: Path, schema_name: str, db_url: str
+) -> alembic.config.Config:
+    """Create an Alembic config for a specific schema."""
+    config = alembic.config.Config()
+
+    # Calculate the path to the migrations folder
+    current_file = Path(__file__)
+    migrations_path = current_file.parent.parent.parent / "migrations"
+
+    if not migrations_path.exists():
+        raise FileNotFoundError(
+            f"Migrations folder not found at {migrations_path}"
+        )
+
+    # Set basic options
+    config.set_main_option("script_location", str(migrations_path))
+    config.set_main_option("sqlalchemy.url", db_url)
+
+    # Set schema-specific version table
+    version_table = get_schema_version_table(schema_name)
+    config.set_main_option("version_table", version_table)
+    config.set_main_option("version_table_schema", schema_name)
+
+    return config
+
+
+def setup_alembic_logging():
+    """Set up logging configuration for Alembic."""
+    # Reset existing loggers to prevent duplication
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+
+    logging_config = {
+        "version": 1,
+        "formatters": {
+            "generic": {
+                "format": "%(levelname)s [%(name)s] %(message)s",
+                "datefmt": "%H:%M:%S",
+            },
+        },
+        "handlers": {
+            "console": {
+                "class": "logging.StreamHandler",
+                "formatter": "generic",
+                "stream": sys.stderr,
+            },
+        },
+        "loggers": {
+            "alembic": {
+                "level": "INFO",
+                "handlers": ["console"],
+                "propagate": False,  # Prevent propagation to root logger
+            },
+            "sqlalchemy": {
+                "level": "WARN",
+                "handlers": ["console"],
+                "propagate": False,  # Prevent propagation to root logger
+            },
+        },
+        "root": {
+            "level": "WARN",
+            "handlers": ["console"],
+        },
+    }
+    logging.config.dictConfig(logging_config)
+
+
+async def run_alembic_command(
+    command_name: str,
+    project_root: Optional[Path] = None,
+    schema_name: Optional[str] = None,
+) -> int:
+    """Run an Alembic command with schema awareness."""
+    try:
+        if project_root is None:
+            project_root = Path(__file__).parent.parent.parent
+
+        if schema_name is None:
+            schema_name = os.environ.get("R2R_PROJECT_NAME", "r2r_default")
+
+        # Set up logging
+        setup_alembic_logging()
+
+        # Get database URL and create engine
+        db_url = get_database_url_from_env()
+        engine = create_engine(db_url)
+
+        # Ensure schema exists and has version table
+        ensure_schema_exists(engine, schema_name)
+
+        # Create schema-specific config
+        config = create_schema_config(project_root, schema_name, db_url)
+
+        click.secho(f"\nRunning command for schema: {schema_name}", fg="blue")
+
+        # Execute the command
+        if command_name == "current":
+            current_rev = check_current_revision(engine, schema_name)
+            if current_rev:
+                click.secho(f"Current revision: {current_rev}", fg="green")
+            else:
+                click.secho("No migrations applied yet.", fg="yellow")
+            alembic_command.current(config)
+        elif command_name == "history":
+            alembic_command.history(config)
+        elif command_name.startswith("upgrade"):
+            revision = "head"
+            if " " in command_name:
+                _, revision = command_name.split(" ", 1)
+            alembic_command.upgrade(config, revision)
+        elif command_name.startswith("downgrade"):
+            revision = "-1"
+            if " " in command_name:
+                _, revision = command_name.split(" ", 1)
+            alembic_command.downgrade(config, revision)
+        else:
+            raise ValueError(f"Unsupported command: {command_name}")
+
+        return 0
+
+    except Exception as e:
+        click.secho(f"Error running migration command: {str(e)}", fg="red")
+        return 1

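The migration helper can also be driven directly from Python; a minimal sketch that applies all pending migrations for the schema named by `R2R_PROJECT_NAME` (falling back to `r2r_default`), using the `R2R_POSTGRES_*` connection settings:

    import asyncio

    from cli.utils.database_utils import run_alembic_command

    # "upgrade head" is parsed by run_alembic_command into the Alembic
    # upgrade command with the "head" revision.
    exit_code = asyncio.run(run_alembic_command("upgrade head"))
    print("migration exit code:", exit_code)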
+ 578 - 0
cli/utils/docker_utils.py

@@ -0,0 +1,578 @@
+import ipaddress
+import json
+import os
+import re
+import socket
+import subprocess
+import sys
+import time
+from typing import Optional
+
+import asyncclick as click
+import requests
+from requests.exceptions import RequestException
+
+
+def bring_down_docker_compose(project_name, volumes, remove_orphans):
+    compose_files = get_compose_files()
+    if project_name == "r2r":
+        docker_command = f"docker compose -f {compose_files['base']}"
+    elif project_name == "r2r-full":
+        docker_command = f"docker compose -f {compose_files['full']}"
+    else:
+        docker_command = f"docker compose  -f {compose_files['full']}"
+
+    docker_command += f" --project-name {project_name}"
+    docker_command += " --profile postgres"
+
+    if volumes:
+        docker_command += " --volumes"
+
+    if remove_orphans:
+        docker_command += " --remove-orphans"
+
+    docker_command += " down"
+
+    click.echo(
+        f"Bringing down {project_name} Docker Compose setup with command {docker_command}..."
+    )
+    return os.system(docker_command)
+
+
+def remove_r2r_network():
+    networks = (
+        subprocess.check_output(
+            ["docker", "network", "ls", "--format", "{{.Name}}"]
+        )
+        .decode()
+        .split()
+    )
+
+    r2r_network = next(
+        (
+            network
+            for network in networks
+            if network.startswith("r2r") and "network" in network
+        ),
+        None,
+    )
+
+    if not r2r_network:
+        click.echo("Could not find the r2r network to remove.")
+        return
+
+    for _ in range(2):  # Try twice
+        remove_command = f"docker network rm {r2r_network}"
+        if os.system(remove_command) == 0:
+            click.echo(f"Successfully removed network: {r2r_network}")
+            return
+        click.echo(
+            f"Failed to remove network: {r2r_network}. Retrying in 5 seconds..."
+        )
+        time.sleep(5)
+
+    click.echo(
+        "Failed to remove the network after multiple attempts. Please try the following steps:\n"
+        "1. Run 'docker ps' to check for any running containers using this network.\n"
+        "2. Stop any running containers with 'docker stop <container_id>'.\n"
+        f"3. Try removing the network manually with 'docker network rm {r2r_network}'.\n"
+        "4. If the above steps don't work, you may need to restart the Docker daemon."
+    )
+
+
+async def run_local_serve(
+    host: str,
+    port: int,
+    config_name: Optional[str] = None,
+    config_path: Optional[str] = None,
+    full: bool = False,
+) -> None:
+    try:
+        from core import R2RBuilder, R2RConfig
+    except ImportError as e:
+        click.echo(
+            "Error: You must install the `r2r core` package to run the R2R server locally."
+        )
+        raise e
+
+    if config_path and config_name:
+        raise ValueError("Cannot specify both config_path and config_name")
+    if not config_path and not config_name:
+        config_name = "full" if full else "default"
+
+    r2r_instance = await R2RBuilder(
+        config=R2RConfig.load(config_name, config_path)
+    ).build()
+
+    if config_name or config_path:
+        completion_config = r2r_instance.config.completion
+        llm_provider = completion_config.provider
+        llm_model = completion_config.generation_config.model
+        model_provider = llm_model.split("/")[0]
+        check_llm_reqs(llm_provider, model_provider)
+
+    click.echo("R2R now runs on port 7272 by default!")
+    available_port = find_available_port(port)
+
+    await r2r_instance.orchestration_provider.start_worker()
+
+    # TODO: make this work with autoreload; currently Hatchet causes a reload error
+    # import uvicorn
+    # uvicorn.run(
+    #     "core.main.app_entry:app", host=host, port=available_port, reload=False
+    # )
+
+    await r2r_instance.serve(host, available_port)
+
+
+def run_docker_serve(
+    host: str,
+    port: int,
+    full: bool,
+    project_name: str,
+    image: str,
+    config_name: Optional[str] = None,
+    config_path: Optional[str] = None,
+    exclude_postgres: bool = False,
+    scale: Optional[int] = None,
+):
+    check_docker_compose_version()
+    check_set_docker_env_vars(project_name, exclude_postgres)
+
+    if config_path and config_name:
+        raise ValueError("Cannot specify both config_path and config_name")
+
+    no_conflict, message = check_subnet_conflict()
+    if not no_conflict:
+        click.secho(f"Warning: {message}", fg="red", bold=True)
+        click.echo("This may cause issues when starting the Docker setup.")
+        if not click.confirm("Do you want to continue?", default=True):
+            click.echo("Aborting Docker setup.")
+            return
+
+    compose_files = get_compose_files()
+    pull_command, up_command = build_docker_command(
+        compose_files,
+        host,
+        port,
+        full,
+        project_name,
+        image,
+        config_name,
+        config_path,
+        exclude_postgres,
+        scale,
+    )
+
+    click.secho("R2R now runs on port 7272 by default!", fg="yellow")
+    click.echo("Pulling Docker images...")
+    click.echo(f"Calling `{pull_command}`")
+    os.system(pull_command)
+
+    click.echo("Starting Docker Compose setup...")
+    click.echo(f"Calling `{up_command}`")
+    os.system(up_command)
+
+
+def check_llm_reqs(llm_provider, model_provider):
+    providers = {
+        "openai": {"env_vars": ["OPENAI_API_KEY"]},
+        "anthropic": {"env_vars": ["ANTHROPIC_API_KEY"]},
+        "azure": {
+            "env_vars": [
+                "AZURE_API_KEY",
+                "AZURE_API_BASE",
+                "AZURE_API_VERSION",
+            ]
+        },
+        "vertex": {
+            "env_vars": [
+                "GOOGLE_APPLICATION_CREDENTIALS",
+                "VERTEX_PROJECT",
+                "VERTEX_LOCATION",
+            ]
+        },
+        "bedrock": {
+            "env_vars": [
+                "AWS_ACCESS_KEY_ID",
+                "AWS_SECRET_ACCESS_KEY",
+                "AWS_REGION_NAME",
+            ]
+        },
+        "groq": {"env_vars": ["GROQ_API_KEY"]},
+        "cohere": {"env_vars": ["COHERE_API_KEY"]},
+        "anyscale": {"env_vars": ["ANYSCALE_API_KEY"]},
+    }
+
+    for provider, config in providers.items():
+        if llm_provider == provider or model_provider == provider:
+            if missing_vars := [
+                var for var in config["env_vars"] if not os.environ.get(var)
+            ]:
+                message = f"You have specified `{provider}` as a default LLM provider, but the following environment variables are missing: {', '.join(missing_vars)}. Would you like to continue?"
+                if not click.confirm(message, default=False):
+                    click.echo("Aborting Docker setup.")
+                    sys.exit(1)
+
+    if model_provider == "ollama":
+        check_external_ollama()
+
+
+def check_external_ollama(ollama_url="http://localhost:11434/api/version"):
+    try:
+        response = requests.get(ollama_url, timeout=5)
+        if response.status_code == 200:
+            click.echo("External Ollama instance detected and responsive.")
+        else:
+            warning_text = click.style("Warning:", fg="red", bold=True)
+            click.echo(
+                f"{warning_text} External Ollama instance returned unexpected status code: {response.status_code}"
+            )
+            if not click.confirm(
+                "Do you want to continue without Ollama connection?",
+                default=False,
+            ):
+                click.echo("Aborting Docker setup.")
+                sys.exit(1)
+    except RequestException as e:
+        warning_text = click.style("Warning:", fg="red", bold=True)
+        click.echo(
+            f"{warning_text} Unable to connect to external Ollama instance. Error: {e}"
+        )
+        click.echo(
+            "Please ensure Ollama is running externally if you've excluded it from Docker and plan on running Local LLMs."
+        )
+        if not click.confirm(
+            "Do you want to continue without confirming an `Ollama` connection?",
+            default=False,
+        ):
+            click.echo("Aborting Docker setup.")
+            sys.exit(1)
+
+
+def check_set_docker_env_vars(
+    project_name: str, exclude_postgres: bool = False
+):
+    env_vars = {"R2R_PROJECT_NAME": "r2r_default"}
+    if project_name:
+        if os.environ.get("R2R_PROJECT_NAME"):
+            warning_text = click.style("Warning:", fg="red", bold=True)
+            prompt = f"{warning_text} You have set R2R_PROJECT_NAME in your environment. Do you want to override it with '{project_name}'?"
+            if not click.confirm(prompt, default=False):
+                project_name = os.environ["R2R_PROJECT_NAME"]
+        else:
+            env_vars["R2R_PROJECT_NAME"] = project_name
+
+    if not exclude_postgres:
+        env_vars |= {
+            "R2R_POSTGRES_HOST": "postgres",
+            "R2R_POSTGRES_PORT": "5432",
+            "R2R_POSTGRES_DBNAME": "postgres",
+            "R2R_POSTGRES_USER": "postgres",
+            "R2R_POSTGRES_PASSWORD": "postgres",
+        }
+
+    # Mapping of old variables to new variables
+    old_to_new_vars = {
+        "POSTGRES_HOST": "R2R_POSTGRES_HOST",
+        "POSTGRES_PORT": "R2R_POSTGRES_PORT",
+        "POSTGRES_DBNAME": "R2R_POSTGRES_DBNAME",
+        "POSTGRES_USER": "R2R_POSTGRES_USER",
+        "POSTGRES_PASSWORD": "R2R_POSTGRES_PASSWORD",
+    }
+
+    # Check for old variables and warn if found
+    for old_var, new_var in old_to_new_vars.items():
+        if old_var in os.environ:
+            warning_text = click.style("Warning:", fg="yellow", bold=True)
+            click.echo(
+                f"{warning_text} The environment variable {old_var} is deprecated and support for it will be removed in release 3.5.0. Please use {new_var} instead."
+            )
+
+    is_test = (
+        "pytest" in sys.modules
+        or "unittest" in sys.modules
+        or os.environ.get("PYTEST_CURRENT_TEST")
+    )
+
+    if not is_test:
+        for var in env_vars:
+            if value := os.environ.get(var):
+                warning_text = click.style("Warning:", fg="red", bold=True)
+
+                if value == env_vars[var]:
+                    continue
+
+                prompt = (
+                    f"{warning_text} It's only necessary to set this environment variable when connecting to an instance not managed by R2R.\n"
+                    f"Environment variable {var} is set to '{value}'. Unset it?"
+                )
+                if click.confirm(prompt, default=True):
+                    os.environ[var] = ""
+                    click.echo(f"Unset {var}")
+                else:
+                    click.echo(f"Kept {var}")
+
+
+def get_compose_files():
+    package_dir = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)),
+        "..",
+        "..",
+    )
+    compose_files = {
+        "base": os.path.join(package_dir, "compose.yaml"),
+        "full": os.path.join(package_dir, "compose.full.yaml"),
+        "full_scale": os.path.join(
+            package_dir, "compose.full_with_replicas.yaml"
+        ),
+    }
+
+    for name, path in compose_files.items():
+        if not os.path.exists(path):
+            click.echo(
+                f"Error: Docker Compose file {name} not found at {path}"
+            )
+            sys.exit(1)
+
+    return compose_files
+
+
+def find_available_port(start_port: int):
+    port = start_port
+    while True:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            if s.connect_ex(("localhost", port)) != 0:
+                if port != start_port:
+                    click.secho(
+                        f"Warning: Port {start_port} is in use. Using {port}",
+                        fg="red",
+                        bold=True,
+                    )
+                return port
+            port += 1
+
+
+def build_docker_command(
+    compose_files,
+    host,
+    port,
+    full,
+    project_name,
+    image,
+    config_name,
+    config_path,
+    exclude_postgres,
+    scale,
+):
+    if not full:
+        base_command = f"docker compose -f {compose_files['base']}"
+    else:
+        if not scale:
+            base_command = f"docker compose -f {compose_files['full']}"
+        else:
+            base_command = f"docker compose -f {compose_files['full_scale']}"
+
+    print("base_command = ", base_command)
+    base_command += (
+        f" --project-name {project_name or ('r2r-full' if full else 'r2r')}"
+    )
+
+    # Find available ports
+    r2r_dashboard_port = port + 1
+    hatchet_dashboard_port = r2r_dashboard_port + 1
+
+    os.environ["R2R_DASHBOARD_PORT"] = str(r2r_dashboard_port)
+    os.environ["HATCHET_DASHBOARD_PORT"] = str(hatchet_dashboard_port)
+    os.environ["R2R_IMAGE"] = image or ""
+
+    if config_name is not None:
+        os.environ["R2R_CONFIG_NAME"] = config_name
+    elif config_path:
+        os.environ["R2R_CONFIG_PATH"] = (
+            os.path.abspath(config_path) if config_path else ""
+        )
+    elif full:
+        os.environ["R2R_CONFIG_NAME"] = "full"
+
+    if not exclude_postgres:
+        pull_command = f"{base_command} --profile postgres pull"
+        up_command = f"{base_command} --profile postgres up -d"
+        if scale:
+            up_command += f" --scale r2r={scale}"
+    else:
+        pull_command = f"{base_command} pull"
+        up_command = f"{base_command} up -d"
+
+    return pull_command, up_command
+
+
+def check_subnet_conflict():
+    r2r_subnet = ipaddress.ip_network("172.28.0.0/16")
+
+    try:
+        networks_output = subprocess.check_output(
+            ["docker", "network", "ls", "--format", "{{json .}}"]
+        ).decode("utf-8")
+        networks = [
+            json.loads(line)
+            for line in networks_output.splitlines()
+            if line.strip()
+        ]
+
+        for network in networks:
+            network_id = network["ID"]
+            network_name = network["Name"]
+
+            if network_name == "r2r-network":
+                continue
+
+            try:
+                network_info_output = subprocess.check_output(
+                    ["docker", "network", "inspect", network_id]
+                ).decode("utf-8")
+
+                network_info = json.loads(network_info_output)
+
+                if (
+                    not network_info
+                    or not isinstance(network_info, list)
+                    or len(network_info) == 0
+                ):
+                    continue
+
+                network_data = network_info[0]
+                if "IPAM" in network_data and "Config" in network_data["IPAM"]:
+                    ipam_config = network_data["IPAM"]["Config"]
+                    if ipam_config is None:
+                        continue
+                    for config in ipam_config:
+                        if "Subnet" in config:
+                            existing_subnet = ipaddress.ip_network(
+                                config["Subnet"]
+                            )
+                            if r2r_subnet.overlaps(existing_subnet):
+                                return (
+                                    False,
+                                    f"Subnet conflict detected with network '{network_name}' using subnet {existing_subnet}",
+                                )
+            except subprocess.CalledProcessError as e:
+                click.echo(f"Error inspecting network {network_name}: {e}")
+            except json.JSONDecodeError as e:
+                click.echo(
+                    f"Error parsing network info for {network_name}: {e}"
+                )
+            except Exception as e:
+                click.echo(
+                    f"Unexpected error inspecting network {network_name}: {e}"
+                )
+
+        return True, "No subnet conflicts detected"
+    except subprocess.CalledProcessError as e:
+        return False, f"Error checking Docker networks: {e}"
+    except json.JSONDecodeError as e:
+        return False, f"Error parsing Docker network information: {e}"
+    except Exception as e:
+        return False, f"Unexpected error while checking Docker networks: {e}"
+
+
+def check_docker_compose_version():
+    try:
+        version_output = (
+            subprocess.check_output(
+                ["docker", "compose", "version"], stderr=subprocess.STDOUT
+            )
+            .decode("utf-8")
+            .strip()
+        )
+
+        version_match = re.search(r"v?(\d+\.\d+\.\d+)", version_output)
+        if not version_match:
+            raise ValueError(f"Unexpected version format: {version_output}")
+
+        compose_version = version_match[1]
+        min_version = "2.25.0"
+
+        # 2.29.6 throws an `invalid mount config` https://github.com/docker/compose/issues/12139
+        incompatible_versions = ["2.29.6"]
+
+        if parse_version(compose_version) < parse_version(min_version):
+            click.secho(
+                f"Warning: Docker Compose version {compose_version} is outdated. "
+                f"Please upgrade to version {min_version} or higher.",
+                fg="yellow",
+                bold=True,
+            )
+        elif compose_version in incompatible_versions:
+            click.secho(
+                f"Warning: Docker Compose version {compose_version} is known to be incompatible."
+                f"Please upgrade to a newer version.",
+                fg="red",
+                bold=True,
+            )
+
+        return True
+
+    except subprocess.CalledProcessError as e:
+        click.secho(
+            f"Error: Docker Compose is not installed or not working properly. "
+            f"Error message: {e.output.decode('utf-8').strip()}",
+            fg="red",
+            bold=True,
+        )
+    except Exception as e:
+        click.secho(
+            f"Error checking Docker Compose version: {e}",
+            fg="red",
+            bold=True,
+        )
+
+    return False
+
+
+def parse_version(version_string):
+    parts = version_string.split(".")
+    if len(parts) != 3:
+        raise ValueError("Invalid version format")
+    try:
+        return tuple(map(int, parts))
+    except ValueError as e:
+        raise ValueError("Invalid version format") from e
+
+
+def wait_for_container_health(project_name, service_name, timeout=300):
+    container_name = f"{project_name}-{service_name}-1"
+    end_time = time.time() + timeout
+
+    while time.time() < end_time:
+        try:
+            result = subprocess.run(
+                ["docker", "inspect", container_name],
+                capture_output=True,
+                text=True,
+                check=True,
+            )
+            container_info = json.loads(result.stdout)[0]
+
+            health_status = (
+                container_info["State"].get("Health", {}).get("Status")
+            )
+            if health_status == "healthy":
+                return True
+            if health_status is None:
+                click.echo(
+                    f"{service_name} does not have a health check defined."
+                )
+                return True
+
+        except subprocess.CalledProcessError:
+            click.echo(f"Error checking health of {service_name}")
+        except (json.JSONDecodeError, IndexError):
+            click.echo(
+                "Error parsing Docker inspect output or container not found"
+            )
+
+        time.sleep(5)
+
+    click.echo(f"Timeout waiting for {service_name} to be healthy.")
+    return False

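A quick sanity check of the version helper above, runnable as-is against the implementation shown:

    from cli.utils.docker_utils import parse_version

    assert parse_version("2.25.0") == (2, 25, 0)
    assert parse_version("2.24.7") < parse_version("2.25.0")

    try:
        parse_version("2.29")  # rejected: not MAJOR.MINOR.PATCH
    except ValueError as exc:
        print("rejected:", exc)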
+ 21 - 0
cli/utils/param_types.py

@@ -0,0 +1,21 @@
+import json
+from typing import Any, Optional
+
+import asyncclick as click
+
+
+class JsonParamType(click.ParamType):
+    name = "json"
+
+    def convert(self, value, param, ctx) -> Optional[dict[str, Any]]:
+        if value is None:
+            return None
+        if isinstance(value, dict):
+            return value
+        try:
+            return json.loads(value)
+        except json.JSONDecodeError:
+            self.fail(f"'{value}' is not a valid JSON string", param, ctx)
+
+
+JSON = JsonParamType()

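The `JSON` parameter type can be exercised outside of a command as well; Click `ParamType` instances are callable, so a minimal sketch looks like:

    from cli.utils.param_types import JSON

    # Same shape as the --filters example used by the retrieval commands.
    filters = JSON('{"document_id": {"$in": ["9fbe403b-c11c-5aae-8ade-ef22980c3ad1"]}}')
    print(filters["document_id"]["$in"])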
+ 152 - 0
cli/utils/telemetry.py

@@ -0,0 +1,152 @@
+import asyncio
+import functools
+import inspect
+import os
+import threading
+import uuid
+from importlib.metadata import version
+from typing import Optional
+
+import asyncclick as click
+from posthog import Posthog
+
+TELEMETRY_DISABLED = (
+    os.getenv("R2R_CLI_DISABLE_TELEMETRY", "false").lower() == "true"
+)
+
+posthog: Optional[Posthog] = None
+
+if not TELEMETRY_DISABLED:
+    posthog = Posthog(
+        project_api_key="phc_OPBbibOIErCGc4NDLQsOrMuYFTKDmRwXX6qxnTr6zpU",
+        host="https://us.i.posthog.com",
+    )
+    posthog.debug = True
+
+
+def telemetry(command):
+    if TELEMETRY_DISABLED or posthog is None:
+        # Return the command unmodified
+        return command
+
+    original_callback = command.callback
+    is_async = inspect.iscoroutinefunction(original_callback)
+
+    if is_async:
+
+        @functools.wraps(original_callback)
+        async def tracked_callback(*args, **kwargs):
+            command_name = command.name
+
+            # Extract context from args[0] if it's a Click Context
+            if args and isinstance(args[0], click.Context):
+                ctx = args[0]
+                command_args = ctx.args
+                command_params = ctx.params
+            else:
+                ctx = None
+                command_args = []
+                command_params = {}
+
+            distinct_id = str(uuid.uuid4())
+
+            try:
+                # Await the original async callback
+                result = await original_callback(*args, **kwargs)
+
+                # Run PostHog capture in a separate thread to avoid blocking
+                await asyncio.to_thread(
+                    posthog.capture,
+                    distinct_id=distinct_id,
+                    event="cli_command",
+                    properties={
+                        "command": command_name,
+                        "status": "success",
+                        "args": command_args,
+                        "params": command_params,
+                        "version": version("r2r"),
+                    },
+                )
+
+                return result
+            except Exception as e:
+                await asyncio.to_thread(
+                    posthog.capture,
+                    distinct_id=distinct_id,
+                    event="cli_command",
+                    properties={
+                        "command": command_name,
+                        "status": "error",
+                        "error_type": type(e).__name__,
+                        "error_message": str(e),
+                        "args": command_args,
+                        "params": command_params,
+                        "version": version("r2r"),
+                    },
+                )
+                raise
+
+    else:
+
+        @functools.wraps(original_callback)
+        def tracked_callback(*args, **kwargs):
+            command_name = command.name
+
+            # Extract context from args[0] if it's a Click Context
+            if args and isinstance(args[0], click.Context):
+                ctx = args[0]
+                command_args = ctx.args
+                command_params = ctx.params
+            else:
+                ctx = None
+                command_args = []
+                command_params = {}
+
+            distinct_id = str(uuid.uuid4())
+
+            try:
+                result = original_callback(*args, **kwargs)
+
+                # Run PostHog capture in a separate thread to avoid blocking
+                thread = threading.Thread(
+                    target=posthog.capture,
+                    args=(
+                        distinct_id,
+                        "cli_command",
+                        {
+                            "command": command_name,
+                            "status": "success",
+                            "args": command_args,
+                            "params": command_params,
+                            "version": version("r2r"),
+                        },
+                    ),
+                    daemon=True,
+                )
+                thread.start()
+
+                return result
+            except Exception as e:
+                # Run PostHog capture in a separate thread to avoid blocking
+                thread = threading.Thread(
+                    target=posthog.capture,
+                    args=(
+                        distinct_id,
+                        "cli_command",
+                        {
+                            "command": command_name,
+                            "status": "error",
+                            "error_type": type(e).__name__,
+                            "error_message": str(e),
+                            "args": command_args,
+                            "params": command_params,
+                            "version": version("r2r"),
+                        },
+                    ),
+                    daemon=True,
+                )
+                thread.start()
+                raise
+
+    command.callback = tracked_callback
+    return command

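Telemetry is checked once at import time via the environment, so opting out must happen before the CLI module is imported; a minimal sketch:

    import os

    # Must be set before `cli.utils.telemetry` is imported, since the module
    # reads R2R_CLI_DISABLE_TELEMETRY at import time.
    os.environ["R2R_CLI_DISABLE_TELEMETRY"] = "true"

    from cli.main import main

    main()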
+ 16 - 0
cli/utils/timer.py

@@ -0,0 +1,16 @@
+"""
+A timer context manager to measure the time taken to execute each command in the CLI.
+"""
+
+import time
+from contextlib import contextmanager
+
+import asyncclick as click
+
+
+@contextmanager
+def timer():
+    start = time.time()
+    yield
+    end = time.time()
+    click.echo(f"Time taken: {end - start:.2f} seconds")

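The timer context manager is generic and can wrap any block; a minimal sketch:

    import time

    from cli.utils.timer import timer

    with timer():  # prints "Time taken: X.XX seconds" on exit
        time.sleep(0.25)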
+ 400 - 0
compose.full.yaml

@@ -0,0 +1,400 @@
+networks:
+  r2r-network:
+    driver: bridge
+    attachable: true
+    labels:
+      - "com.docker.compose.recreate=always"
+
+volumes:
+  hatchet_certs:
+    name: ${VOLUME_HATCHET_CERTS:-hatchet_certs}
+  hatchet_config:
+    name: ${VOLUME_HATCHET_CONFIG:-hatchet_config}
+  hatchet_api_key:
+    name: ${VOLUME_HATCHET_API_KEY:-hatchet_api_key}
+  postgres_data:
+    name: ${VOLUME_POSTGRES_DATA:-postgres_data}
+  hatchet_rabbitmq_data:
+    name: ${VOLUME_HATCHET_RABBITMQ_DATA:-hatchet_rabbitmq_data}
+  hatchet_rabbitmq_conf:
+    name: ${VOLUME_HATCHET_RABBITMQ_CONF:-hatchet_rabbitmq_conf}
+  hatchet_postgres_data:
+    name: ${VOLUME_HATCHET_POSTGRES_DATA:-hatchet_postgres_data}
+
+services:
+  postgres:
+    image: pgvector/pgvector:pg16
+    profiles: [postgres]
+    environment:
+      - POSTGRES_USER=${R2R_POSTGRES_USER:-postgres}
+      - POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres}
+      - POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres}
+      - POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432}
+      - POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+      - PGPORT=${R2R_POSTGRES_PORT:-5432}
+    volumes:
+      - postgres_data:/var/lib/postgresql/data
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_POSTGRES_PORT:-5432}:${R2R_POSTGRES_PORT:-5432}"
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -U ${R2R_POSTGRES_USER:-postgres}"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    command: >
+      postgres
+      -c max_connections=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+
+  hatchet-postgres:
+    image: postgres:latest
+    environment:
+      POSTGRES_DB: ${HATCHET_POSTGRES_DBNAME:-hatchet}
+      POSTGRES_USER: ${HATCHET_POSTGRES_USER:-hatchet_user}
+      POSTGRES_PASSWORD: ${HATCHET_POSTGRES_PASSWORD:-hatchet_password}
+    volumes:
+      - hatchet_postgres_data:/var/lib/postgresql/data
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -U ${HATCHET_POSTGRES_USER:-hatchet_user} -d ${HATCHET_POSTGRES_DBNAME:-hatchet}"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+
+
+  hatchet-rabbitmq:
+    image: "rabbitmq:3-management"
+    hostname: "hatchet-rabbitmq"
+    ports:
+      - "${R2R_RABBITMQ_PORT:-5673}:5672"
+      - "${R2R_RABBITMQ_MGMT_PORT:-15673}:15672"
+    environment:
+      RABBITMQ_DEFAULT_USER: "user"
+      RABBITMQ_DEFAULT_PASS: "password"
+    volumes:
+      - hatchet_rabbitmq_data:/var/lib/rabbitmq
+      - hatchet_rabbitmq_conf:/etc/rabbitmq/rabbitmq.conf
+    healthcheck:
+      test: ["CMD", "rabbitmqctl", "status"]
+      interval: 10s
+      timeout: 10s
+      retries: 5
+    networks:
+      - r2r-network
+
+  hatchet-create-db:
+    image: postgres:latest
+    command: >
+      sh -c "
+        set -e
+        echo 'Waiting for PostgreSQL to be ready...'
+        while ! pg_isready -h hatchet-postgres -p 5432 -U ${HATCHET_POSTGRES_USER:-hatchet_user}; do
+          sleep 1
+        done
+        echo 'PostgreSQL is ready, checking if database exists...'
+        if ! PGPASSWORD=${HATCHET_POSTGRES_PASSWORD:-hatchet_password} psql -h hatchet-postgres -p 5432 -U ${HATCHET_POSTGRES_USER:-hatchet_user} -lqt | grep -qw ${HATCHET_POSTGRES_DBNAME:-hatchet}; then
+          echo 'Database does not exist, creating it...'
+          PGPASSWORD=${HATCHET_POSTGRES_PASSWORD:-hatchet_password} createdb -h hatchet-postgres -p 5432 -U ${HATCHET_POSTGRES_USER:-hatchet_user} -w ${HATCHET_POSTGRES_DBNAME:-hatchet}
+        else
+          echo 'Database already exists, skipping creation.'
+        fi
+      "
+    environment:
+      DATABASE_URL: "postgres://${HATCHET_POSTGRES_USER:-hatchet_user}:${HATCHET_POSTGRES_PASSWORD:-hatchet_password}@hatchet-postgres:5432/${HATCHET_POSTGRES_DBNAME:-hatchet}?sslmode=disable"
+    networks:
+      - r2r-network
+
+  hatchet-migration:
+    image: ghcr.io/hatchet-dev/hatchet/hatchet-migrate:latest
+    environment:
+      DATABASE_URL: "postgres://${HATCHET_POSTGRES_USER:-hatchet_user}:${HATCHET_POSTGRES_PASSWORD:-hatchet_password}@hatchet-postgres:5432/${HATCHET_POSTGRES_DBNAME:-hatchet}?sslmode=disable"
+    depends_on:
+      hatchet-create-db:
+        condition: service_completed_successfully
+    networks:
+      - r2r-network
+
+  hatchet-setup-config:
+    image: ghcr.io/hatchet-dev/hatchet/hatchet-admin:latest
+    command: /hatchet/hatchet-admin quickstart --skip certs --generated-config-dir /hatchet/config --overwrite=false
+    environment:
+      DATABASE_URL: "postgres://${HATCHET_POSTGRES_USER:-hatchet_user}:${HATCHET_POSTGRES_PASSWORD:-hatchet_password}@hatchet-postgres:5432/${HATCHET_POSTGRES_DBNAME:-hatchet}?sslmode=disable"
+
+      HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH: "${HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH:-134217728}"
+      HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH: "${HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH:-134217728}"
+
+      DATABASE_POSTGRES_PORT: "5432"
+      DATABASE_POSTGRES_HOST: hatchet-postgres
+      DATABASE_POSTGRES_USERNAME: "${HATCHET_POSTGRES_USER:-hatchet_user}"
+      DATABASE_POSTGRES_PASSWORD: "${HATCHET_POSTGRES_PASSWORD:-hatchet_password}"
+      HATCHET_DATABASE_POSTGRES_DB_NAME: "${HATCHET_POSTGRES_DBNAME:-hatchet}"
+
+      SERVER_TASKQUEUE_RABBITMQ_URL: amqp://user:password@hatchet-rabbitmq:5672/
+      SERVER_AUTH_COOKIE_DOMAIN: "http://host.docker.internal:${R2R_HATCHET_DASHBOARD_PORT:-7274}"
+      SERVER_URL: "http://host.docker.internal:${R2R_HATCHET_DASHBOARD_PORT:-7274}"
+      SERVER_AUTH_COOKIE_INSECURE: "t"
+      SERVER_GRPC_BIND_ADDRESS: "0.0.0.0"
+      SERVER_GRPC_INSECURE: "t"
+      SERVER_GRPC_BROADCAST_ADDRESS: "hatchet-engine:7077"
+      SERVER_GRPC_MAX_MSG_SIZE: 134217728
+    volumes:
+      - hatchet_certs:/hatchet/certs
+      - hatchet_config:/hatchet/config
+    depends_on:
+      hatchet-migration:
+        condition: service_completed_successfully
+      hatchet-rabbitmq:
+        condition: service_healthy
+    networks:
+      - r2r-network
+
+  hatchet-engine:
+    image: ghcr.io/hatchet-dev/hatchet/hatchet-engine:latest
+    command: /hatchet/hatchet-engine --config /hatchet/config
+    restart: on-failure
+    depends_on:
+      hatchet-setup-config:
+        condition: service_completed_successfully
+    ports:
+      - "${R2R_HATCHET_ENGINE_PORT:-7077}:7077"
+    environment:
+      DATABASE_URL: "postgres://${HATCHET_POSTGRES_USER:-hatchet_user}:${HATCHET_POSTGRES_PASSWORD:-hatchet_password}@hatchet-postgres:5432/${HATCHET_POSTGRES_DBNAME:-hatchet}?sslmode=disable"
+      SERVER_GRPC_BROADCAST_ADDRESS: "hatchet-engine:7077"
+      SERVER_GRPC_BIND_ADDRESS: "0.0.0.0"
+      SERVER_GRPC_PORT: "7077"
+      SERVER_GRPC_INSECURE: "t"
+      SERVER_GRPC_MAX_MSG_SIZE: 134217728
+    volumes:
+      - hatchet_certs:/hatchet/certs
+      - hatchet_config:/hatchet/config
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD", "wget", "-q", "-O", "-", "http://localhost:8733/live"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+
+  hatchet-dashboard:
+    image: ghcr.io/hatchet-dev/hatchet/hatchet-dashboard:latest
+    command: sh ./entrypoint.sh --config /hatchet/config
+    restart: on-failure
+    depends_on:
+      hatchet-setup-config:
+        condition: service_completed_successfully
+    environment:
+      DATABASE_URL: "postgres://${HATCHET_POSTGRES_USER:-hatchet_user}:${HATCHET_POSTGRES_PASSWORD:-hatchet_password}@hatchet-postgres:5432/${HATCHET_POSTGRES_DBNAME:-hatchet}?sslmode=disable"
+    volumes:
+      - hatchet_certs:/hatchet/certs
+      - hatchet_config:/hatchet/config
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_HATCHET_DASHBOARD_PORT:-7274}:80"
+
+  setup-token:
+    image: ghcr.io/hatchet-dev/hatchet/hatchet-admin:latest
+    command: >
+      sh -c "
+        set -e
+        echo 'Starting token creation process...'
+
+        # Attempt to create token and capture both stdout and stderr
+        TOKEN_OUTPUT=$$(/hatchet/hatchet-admin token create --config /hatchet/config --tenant-id 707d0855-80ab-4e1f-a156-f1c4546cbf52 2>&1)
+
+        # Extract the token (assuming it's the only part that looks like a JWT)
+        TOKEN=$$(echo \"$$TOKEN_OUTPUT\" | grep -Eo 'eyJ[A-Za-z0-9_-]*\.eyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*')
+
+        if [ -z \"$$TOKEN\" ]; then
+          echo 'Error: Failed to extract token. Full command output:' >&2
+          echo \"$$TOKEN_OUTPUT\" >&2
+          exit 1
+        fi
+
+        echo \"$$TOKEN\" > /tmp/hatchet_api_key
+        echo 'Token created and saved to /tmp/hatchet_api_key'
+
+        # Copy token to final destination
+        echo -n \"$$TOKEN\" > /hatchet_api_key/api_key.txt
+        echo 'Token copied to /hatchet_api_key/api_key.txt'
+
+        # Verify token was copied correctly
+        if [ \"$$(cat /tmp/hatchet_api_key)\" != \"$(cat /hatchet_api_key/api_key.txt)\" ]; then
+          echo 'Error: Token copy failed, files do not match' >&2
+          echo 'Content of /tmp/hatchet_api_key:'
+          cat /tmp/hatchet_api_key
+          echo 'Content of /hatchet_api_key/api_key.txt:'
+          cat /hatchet_api_key/api_key.txt
+          exit 1
+        fi
+
+        echo 'Hatchet API key has been saved successfully'
+        echo 'Token length:' $${#TOKEN}
+        echo 'Token (first 20 chars):' $${TOKEN:0:20}
+        echo 'Token structure:' $$(echo $$TOKEN | awk -F. '{print NF-1}') 'parts'
+        # Check each part of the token
+        for i in 1 2 3; do
+          PART=$$(echo $$TOKEN | cut -d. -f$$i)
+          echo 'Part' $$i 'length:' $${#PART}
+          echo 'Part' $$i 'base64 check:' $$(echo $$PART | base64 -d >/dev/null 2>&1 && echo 'Valid' || echo 'Invalid')
+        done
+        # Final validation attempt
+        if ! echo $$TOKEN | awk -F. '{print $$2}' | base64 -d 2>/dev/null | jq . >/dev/null 2>&1; then
+          echo 'Warning: Token payload is not valid JSON when base64 decoded' >&2
+        else
+          echo 'Token payload appears to be valid JSON'
+        fi
+      "
+    networks:
+      - r2r-network
+    volumes:
+      - hatchet_certs:/hatchet/certs
+      - hatchet_config:/hatchet/config
+      - hatchet_api_key:/hatchet_api_key
+    depends_on:
+      hatchet-setup-config:
+        condition: service_completed_successfully
+
+  unstructured:
+    image: ${UNSTRUCTURED_IMAGE:-ragtoriches/unst-prod}
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:7275/health"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+
+  graph_clustering:
+    image: ${GRAPH_CLUSTERING_IMAGE:-ragtoriches/cluster-prod}
+    ports:
+      - "${R2R_GRAPH_CLUSTERING_PORT:-7276}:7276"
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:7276/health"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+
+  r2r:
+    image: ${R2R_IMAGE:-ragtoriches/prod:latest}
+    build:
+      context: .
+      args:
+        PORT: ${R2R_PORT:-7272}
+        R2R_PORT: ${R2R_PORT:-7272}
+        HOST: ${R2R_HOST:-0.0.0.0}
+        R2R_HOST: ${R2R_HOST:-0.0.0.0}
+    ports:
+      - "${R2R_PORT:-7272}:${R2R_PORT:-7272}"
+    environment:
+      - PYTHONUNBUFFERED=1
+      - R2R_PORT=${R2R_PORT:-7272}
+      - R2R_HOST=${R2R_HOST:-0.0.0.0}
+
+      # R2R
+      - R2R_CONFIG_NAME=${R2R_CONFIG_NAME:-}
+      - R2R_CONFIG_PATH=${R2R_CONFIG_PATH:-}
+      - R2R_PROJECT_NAME=${R2R_PROJECT_NAME:-r2r_default}
+
+      # Postgres
+      - R2R_POSTGRES_USER=${R2R_POSTGRES_USER:-postgres}
+      - R2R_POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres}
+      - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres}
+      - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432}
+      - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-postgres}
+      - R2R_POSTGRES_PROJECT_NAME=${R2R_POSTGRES_PROJECT_NAME:-r2r_default}
+      - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+      - R2R_POSTGRES_STATEMENT_CACHE_SIZE=${R2R_POSTGRES_STATEMENT_CACHE_SIZE:-100}
+
+      # OpenAI
+      - OPENAI_API_KEY=${OPENAI_API_KEY:-}
+      - OPENAI_API_BASE=${OPENAI_API_BASE:-}
+
+      # Anthropic
+      - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
+
+      # Azure
+      - AZURE_API_KEY=${AZURE_API_KEY:-}
+      - AZURE_API_BASE=${AZURE_API_BASE:-}
+      - AZURE_API_VERSION=${AZURE_API_VERSION:-}
+
+      # Google Vertex AI
+      - GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS:-}
+      - VERTEX_PROJECT=${VERTEX_PROJECT:-}
+      - VERTEX_LOCATION=${VERTEX_LOCATION:-}
+
+      # AWS Bedrock
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
+      - AWS_REGION_NAME=${AWS_REGION_NAME:-}
+
+      # Groq
+      - GROQ_API_KEY=${GROQ_API_KEY:-}
+
+      # Cohere
+      - COHERE_API_KEY=${COHERE_API_KEY:-}
+
+      # Anyscale
+      - ANYSCALE_API_KEY=${ANYSCALE_API_KEY:-}
+
+      # Ollama
+      - OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434}
+
+      # Huggingface
+      - HUGGINGFACE_API_BASE=${HUGGINGFACE_API_BASE:-http://host.docker.internal:8080}
+      - HUGGINGFACE_API_KEY=${HUGGINGFACE_API_KEY:-}
+
+      # Unstructured
+      - UNSTRUCTURED_API_KEY=${UNSTRUCTURED_API_KEY:-}
+      - UNSTRUCTURED_API_URL=${UNSTRUCTURED_API_URL:-https://api.unstructured.io/general/v0/general}
+      - UNSTRUCTURED_SERVICE_URL=${UNSTRUCTURED_SERVICE_URL:-http://unstructured:7275}
+      - UNSTRUCTURED_NUM_WORKERS=${UNSTRUCTURED_NUM_WORKERS:-10}
+
+      # Hatchet
+      - HATCHET_CLIENT_TLS_STRATEGY=none
+      - HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH=${HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH:-134217728}
+      - HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH=${HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH:-134217728}
+
+      # Graph clustering
+      - CLUSTERING_SERVICE_URL=http://graph_clustering:7276
+
+    command: >
+      sh -c '
+        if [ -z "$${HATCHET_CLIENT_TOKEN}" ]; then
+          export HATCHET_CLIENT_TOKEN=$$(cat /hatchet_api_key/api_key.txt)
+        fi
+        exec uvicorn core.main.app_entry:app --host $${R2R_HOST} --port $${R2R_PORT}
+      '
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:${R2R_PORT:-7272}/v3/health"]
+      interval: 6s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    volumes:
+      - ${R2R_CONFIG_PATH:-/}:${R2R_CONFIG_PATH:-/app/config}
+      - hatchet_api_key:/hatchet_api_key:ro
+    extra_hosts:
+      - host.docker.internal:host-gateway
+    depends_on:
+      setup-token:
+        condition: service_completed_successfully
+      unstructured:
+        condition: service_healthy
+
+  r2r-dashboard:
+    image: emrgntcmplxty/r2r-dashboard:latest
+    environment:
+      - NEXT_PUBLIC_R2R_DEPLOYMENT_URL=${R2R_DEPLOYMENT_URL:-http://localhost:7272}
+      - NEXT_PUBLIC_HATCHET_DASHBOARD_URL=${HATCHET_DASHBOARD_URL:-http://localhost:${R2R_HATCHET_DASHBOARD_PORT:-7274}}
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_DASHBOARD_PORT:-7273}:3000"

+ 420 - 0
compose.full_with_replicas.yaml

@@ -0,0 +1,420 @@
+networks:
+  r2r-network:
+    driver: bridge
+    attachable: true
+    labels:
+      - "com.docker.compose.recreate=always"
+
+volumes:
+  hatchet_certs:
+    name: ${VOLUME_HATCHET_CERTS:-hatchet_certs}
+  hatchet_config:
+    name: ${VOLUME_HATCHET_CONFIG:-hatchet_config}
+  hatchet_api_key:
+    name: ${VOLUME_HATCHET_API_KEY:-hatchet_api_key}
+  postgres_data:
+    name: ${VOLUME_POSTGRES_DATA:-postgres_data}
+  hatchet_rabbitmq_data:
+    name: ${VOLUME_HATCHET_RABBITMQ_DATA:-hatchet_rabbitmq_data}
+  hatchet_rabbitmq_conf:
+    name: ${VOLUME_HATCHET_RABBITMQ_CONF:-hatchet_rabbitmq_conf}
+  hatchet_postgres_data:
+    name: ${VOLUME_HATCHET_POSTGRES_DATA:-hatchet_postgres_data}
+
+services:
+  postgres:
+    image: pgvector/pgvector:pg16
+    profiles: [postgres]
+    environment:
+      - POSTGRES_USER=${R2R_POSTGRES_USER:-postgres}
+      - POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres}
+      - POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres}
+      - POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432}
+      - POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+      - PGPORT=${R2R_POSTGRES_PORT:-5432}
+    volumes:
+      - postgres_data:/var/lib/postgresql/data
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_POSTGRES_PORT:-5432}:${R2R_POSTGRES_PORT:-5432}"
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -U ${R2R_POSTGRES_USER:-postgres}"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    command: >
+      postgres
+      -c max_connections=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+
+  hatchet-postgres:
+    image: postgres:latest
+    environment:
+      POSTGRES_DB: ${HATCHET_POSTGRES_DBNAME:-hatchet}
+      POSTGRES_USER: ${HATCHET_POSTGRES_USER:-hatchet_user}
+      POSTGRES_PASSWORD: ${HATCHET_POSTGRES_PASSWORD:-hatchet_password}
+    volumes:
+      - hatchet_postgres_data:/var/lib/postgresql/data
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -U ${HATCHET_POSTGRES_USER:-hatchet_user} -d ${HATCHET_POSTGRES_DBNAME:-hatchet}"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+
+
+  hatchet-rabbitmq:
+    image: "rabbitmq:3-management"
+    hostname: "hatchet-rabbitmq"
+    ports:
+      - "${R2R_RABBITMQ_PORT:-5673}:5672"
+      - "${R2R_RABBITMQ_MGMT_PORT:-15673}:15672"
+    environment:
+      RABBITMQ_DEFAULT_USER: "user"
+      RABBITMQ_DEFAULT_PASS: "password"
+    volumes:
+      - hatchet_rabbitmq_data:/var/lib/rabbitmq
+      - hatchet_rabbitmq_conf:/etc/rabbitmq/rabbitmq.conf
+    healthcheck:
+      test: ["CMD", "rabbitmqctl", "status"]
+      interval: 10s
+      timeout: 10s
+      retries: 5
+    networks:
+      - r2r-network
+
+  hatchet-create-db:
+    image: postgres:latest
+    command: >
+      sh -c "
+        set -e
+        echo 'Waiting for PostgreSQL to be ready...'
+        while ! pg_isready -h hatchet-postgres -p 5432 -U ${HATCHET_POSTGRES_USER:-hatchet_user}; do
+          sleep 1
+        done
+        echo 'PostgreSQL is ready, checking if database exists...'
+        if ! PGPASSWORD=${HATCHET_POSTGRES_PASSWORD:-hatchet_password} psql -h hatchet-postgres -p 5432 -U ${HATCHET_POSTGRES_USER:-hatchet_user} -lqt | grep -qw ${HATCHET_POSTGRES_DBNAME:-hatchet}; then
+          echo 'Database does not exist, creating it...'
+          PGPASSWORD=${HATCHET_POSTGRES_PASSWORD:-hatchet_password} createdb -h hatchet-postgres -p 5432 -U ${HATCHET_POSTGRES_USER:-hatchet_user} -w ${HATCHET_POSTGRES_DBNAME:-hatchet}
+        else
+          echo 'Database already exists, skipping creation.'
+        fi
+      "
+    environment:
+      DATABASE_URL: "postgres://${HATCHET_POSTGRES_USER:-hatchet_user}:${HATCHET_POSTGRES_PASSWORD:-hatchet_password}@hatchet-postgres:5432/${HATCHET_POSTGRES_DBNAME:-hatchet}?sslmode=disable"
+    networks:
+      - r2r-network
+
+  hatchet-migration:
+    image: ghcr.io/hatchet-dev/hatchet/hatchet-migrate:latest
+    environment:
+      DATABASE_URL: "postgres://${HATCHET_POSTGRES_USER:-hatchet_user}:${HATCHET_POSTGRES_PASSWORD:-hatchet_password}@hatchet-postgres:5432/${HATCHET_POSTGRES_DBNAME:-hatchet}?sslmode=disable"
+    depends_on:
+      hatchet-create-db:
+        condition: service_completed_successfully
+    networks:
+      - r2r-network
+
+  hatchet-setup-config:
+    image: ghcr.io/hatchet-dev/hatchet/hatchet-admin:latest
+    command: /hatchet/hatchet-admin quickstart --skip certs --generated-config-dir /hatchet/config --overwrite=false
+    environment:
+      DATABASE_URL: "postgres://${HATCHET_POSTGRES_USER:-hatchet_user}:${HATCHET_POSTGRES_PASSWORD:-hatchet_password}@hatchet-postgres:5432/${HATCHET_POSTGRES_DBNAME:-hatchet}?sslmode=disable"
+
+      HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH: "${HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH:-134217728}"
+      HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH: "${HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH:-134217728}"
+
+      DATABASE_POSTGRES_PORT: "5432"
+      DATABASE_POSTGRES_HOST: hatchet-postgres
+      DATABASE_POSTGRES_USERNAME: "${HATCHET_POSTGRES_USER:-hatchet_user}"
+      DATABASE_POSTGRES_PASSWORD: "${HATCHET_POSTGRES_PASSWORD:-hatchet_password}"
+      HATCHET_DATABASE_POSTGRES_DB_NAME: "${HATCHET_POSTGRES_DBNAME:-hatchet}"
+
+      SERVER_TASKQUEUE_RABBITMQ_URL: amqp://user:password@hatchet-rabbitmq:5672/
+      SERVER_AUTH_COOKIE_DOMAIN: "http://host.docker.internal:${R2R_HATCHET_DASHBOARD_PORT:-7274}"
+      SERVER_URL: "http://host.docker.internal:${R2R_HATCHET_DASHBOARD_PORT:-7274}"
+      SERVER_AUTH_COOKIE_INSECURE: "t"
+      SERVER_GRPC_BIND_ADDRESS: "0.0.0.0"
+      SERVER_GRPC_INSECURE: "t"
+      SERVER_GRPC_BROADCAST_ADDRESS: "hatchet-engine:7077"
+      SERVER_GRPC_MAX_MSG_SIZE: 134217728
+    volumes:
+      - hatchet_certs:/hatchet/certs
+      - hatchet_config:/hatchet/config
+    depends_on:
+      hatchet-migration:
+        condition: service_completed_successfully
+      hatchet-rabbitmq:
+        condition: service_healthy
+    networks:
+      - r2r-network
+
+  hatchet-engine:
+    image: ghcr.io/hatchet-dev/hatchet/hatchet-engine:latest
+    command: /hatchet/hatchet-engine --config /hatchet/config
+    restart: on-failure
+    depends_on:
+      hatchet-setup-config:
+        condition: service_completed_successfully
+    ports:
+      - "${R2R_HATCHET_ENGINE_PORT:-7077}:7077"
+    environment:
+      DATABASE_URL: "postgres://${HATCHET_POSTGRES_USER:-hatchet_user}:${HATCHET_POSTGRES_PASSWORD:-hatchet_password}@hatchet-postgres:5432/${HATCHET_POSTGRES_DBNAME:-hatchet}?sslmode=disable"
+      SERVER_GRPC_BROADCAST_ADDRESS: "hatchet-engine:7077"
+      SERVER_GRPC_BIND_ADDRESS: "0.0.0.0"
+      SERVER_GRPC_PORT: "7077"
+      SERVER_GRPC_INSECURE: "t"
+      SERVER_GRPC_MAX_MSG_SIZE: 134217728
+    volumes:
+      - hatchet_certs:/hatchet/certs
+      - hatchet_config:/hatchet/config
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD", "wget", "-q", "-O", "-", "http://localhost:8733/live"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+
+  hatchet-dashboard:
+    image: ghcr.io/hatchet-dev/hatchet/hatchet-dashboard:latest
+    command: sh ./entrypoint.sh --config /hatchet/config
+    restart: on-failure
+    depends_on:
+      hatchet-setup-config:
+        condition: service_completed_successfully
+    environment:
+      DATABASE_URL: "postgres://${HATCHET_POSTGRES_USER:-hatchet_user}:${HATCHET_POSTGRES_PASSWORD:-hatchet_password}@hatchet-postgres:5432/${HATCHET_POSTGRES_DBNAME:-hatchet}?sslmode=disable"
+    volumes:
+      - hatchet_certs:/hatchet/certs
+      - hatchet_config:/hatchet/config
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_HATCHET_DASHBOARD_PORT:-7274}:80"
+
+  setup-token:
+    image: ghcr.io/hatchet-dev/hatchet/hatchet-admin:latest
+    command: >
+      sh -c "
+        set -e
+        echo 'Starting token creation process...'
+
+        # Attempt to create token and capture both stdout and stderr
+        TOKEN_OUTPUT=$$(/hatchet/hatchet-admin token create --config /hatchet/config --tenant-id 707d0855-80ab-4e1f-a156-f1c4546cbf52 2>&1)
+
+        # Extract the token (assuming it's the only part that looks like a JWT)
+        TOKEN=$$(echo \"$$TOKEN_OUTPUT\" | grep -Eo 'eyJ[A-Za-z0-9_-]*\.eyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*')
+
+        if [ -z \"$$TOKEN\" ]; then
+          echo 'Error: Failed to extract token. Full command output:' >&2
+          echo \"$$TOKEN_OUTPUT\" >&2
+          exit 1
+        fi
+
+        echo \"$$TOKEN\" > /tmp/hatchet_api_key
+        echo 'Token created and saved to /tmp/hatchet_api_key'
+
+        # Copy token to final destination
+        echo -n \"$$TOKEN\" > /hatchet_api_key/api_key.txt
+        echo 'Token copied to /hatchet_api_key/api_key.txt'
+
+        # Verify token was copied correctly
+        if [ \"$$(cat /tmp/hatchet_api_key)\" != \"$(cat /hatchet_api_key/api_key.txt)\" ]; then
+          echo 'Error: Token copy failed, files do not match' >&2
+          echo 'Content of /tmp/hatchet_api_key:'
+          cat /tmp/hatchet_api_key
+          echo 'Content of /hatchet_api_key/api_key.txt:'
+          cat /hatchet_api_key/api_key.txt
+          exit 1
+        fi
+
+        echo 'Hatchet API key has been saved successfully'
+        echo 'Token length:' $${#TOKEN}
+        echo 'Token (first 20 chars):' $${TOKEN:0:20}
+        echo 'Token structure:' $$(echo $$TOKEN | awk -F. '{print NF-1}') 'parts'
+        # Check each part of the token
+        for i in 1 2 3; do
+          PART=$$(echo $$TOKEN | cut -d. -f$$i)
+          echo 'Part' $$i 'length:' $${#PART}
+          echo 'Part' $$i 'base64 check:' $$(echo $$PART | base64 -d >/dev/null 2>&1 && echo 'Valid' || echo 'Invalid')
+        done
+        # Final validation attempt
+        if ! echo $$TOKEN | awk -F. '{print $$2}' | base64 -d 2>/dev/null | jq . >/dev/null 2>&1; then
+          echo 'Warning: Token payload is not valid JSON when base64 decoded' >&2
+        else
+          echo 'Token payload appears to be valid JSON'
+        fi
+      "
+    networks:
+      - r2r-network
+    volumes:
+      - hatchet_certs:/hatchet/certs
+      - hatchet_config:/hatchet/config
+      - hatchet_api_key:/hatchet_api_key
+    depends_on:
+      hatchet-setup-config:
+        condition: service_completed_successfully
+
+  unstructured:
+    image: ${UNSTRUCTURED_IMAGE:-ragtoriches/unst-prod}
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:7275/health"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+
+  graph_clustering:
+    image: ${GRAPH_CLUSTERING_IMAGE:-ragtoriches/cluster-prod}
+    ports:
+      - "${R2R_GRAPH_CLUSTERING_PORT:-7276}:7276"
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:7276/health"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+
+  r2r:
+    image: ${R2R_IMAGE:-ragtoriches/prod:latest}
+    build:
+      context: .
+      args:
+        PORT: ${R2R_PORT:-7272}
+        R2R_PORT: ${R2R_PORT:-7272}
+        HOST: ${R2R_HOST:-0.0.0.0}
+        R2R_HOST: ${R2R_HOST:-0.0.0.0}
+    environment:
+      - PYTHONUNBUFFERED=1
+      - R2R_PORT=${R2R_PORT:-7272}
+      - R2R_HOST=${R2R_HOST:-0.0.0.0}
+
+      # R2R
+      - R2R_CONFIG_NAME=${R2R_CONFIG_NAME:-}
+      - R2R_CONFIG_PATH=${R2R_CONFIG_PATH:-}
+      - R2R_PROJECT_NAME=${R2R_PROJECT_NAME:-r2r_default}
+
+      # Postgres
+      - R2R_POSTGRES_USER=${R2R_POSTGRES_USER:-postgres}
+      - R2R_POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres}
+      - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres}
+      - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432}
+      - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-postgres}
+      - R2R_POSTGRES_PROJECT_NAME=${R2R_POSTGRES_PROJECT_NAME:-r2r_default}
+      - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+      - R2R_POSTGRES_STATEMENT_CACHE_SIZE=${R2R_POSTGRES_STATEMENT_CACHE_SIZE:-100}
+
+      # OpenAI
+      - OPENAI_API_KEY=${OPENAI_API_KEY:-}
+      - OPENAI_API_BASE=${OPENAI_API_BASE:-}
+
+      # Anthropic
+      - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
+
+      # Azure
+      - AZURE_API_KEY=${AZURE_API_KEY:-}
+      - AZURE_API_BASE=${AZURE_API_BASE:-}
+      - AZURE_API_VERSION=${AZURE_API_VERSION:-}
+
+      # Google Vertex AI
+      - GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS:-}
+      - VERTEX_PROJECT=${VERTEX_PROJECT:-}
+      - VERTEX_LOCATION=${VERTEX_LOCATION:-}
+
+      # AWS Bedrock
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
+      - AWS_REGION_NAME=${AWS_REGION_NAME:-}
+
+      # Groq
+      - GROQ_API_KEY=${GROQ_API_KEY:-}
+
+      # Cohere
+      - COHERE_API_KEY=${COHERE_API_KEY:-}
+
+      # Anyscale
+      - ANYSCALE_API_KEY=${ANYSCALE_API_KEY:-}
+
+      # Ollama
+      - OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434}
+
+      # Huggingface
+      - HUGGINGFACE_API_BASE=${HUGGINGFACE_API_BASE:-http://host.docker.internal:8080}
+      - HUGGINGFACE_API_KEY=${HUGGINGFACE_API_KEY:-}
+
+      # Unstructured
+      - UNSTRUCTURED_API_KEY=${UNSTRUCTURED_API_KEY:-}
+      - UNSTRUCTURED_API_URL=${UNSTRUCTURED_API_URL:-https://api.unstructured.io/general/v0/general}
+      - UNSTRUCTURED_SERVICE_URL=${UNSTRUCTURED_SERVICE_URL:-http://unstructured:7275}
+      - UNSTRUCTURED_NUM_WORKERS=${UNSTRUCTURED_NUM_WORKERS:-10}
+
+      # Hatchet
+      - HATCHET_CLIENT_TLS_STRATEGY=none
+      - HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH=${HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH:-134217728}
+      - HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH=${HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH:-134217728}
+
+      # Graph clustering
+      - CLUSTERING_SERVICE_URL=http://graph_clustering:7276
+
+    command: >
+      sh -c '
+        if [ -z "$${HATCHET_CLIENT_TOKEN}" ]; then
+          export HATCHET_CLIENT_TOKEN=$$(cat /hatchet_api_key/api_key.txt)
+        fi
+        exec uvicorn core.main.app_entry:app --host $${R2R_HOST} --port $${R2R_PORT}
+      '
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:${R2R_PORT:-7272}/v3/health"]
+      interval: 6s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    volumes:
+      - ${R2R_CONFIG_PATH:-/}:${R2R_CONFIG_PATH:-/app/config}
+      - hatchet_api_key:/hatchet_api_key:ro
+    extra_hosts:
+      - host.docker.internal:host-gateway
+    depends_on:
+      setup-token:
+        condition: service_completed_successfully
+      unstructured:
+        condition: service_healthy
+
+  r2r-dashboard:
+    image: emrgntcmplxty/r2r-dashboard:latest
+    environment:
+      - NEXT_PUBLIC_R2R_DEPLOYMENT_URL=${R2R_DEPLOYMENT_URL:-http://localhost:7272}
+      - NEXT_PUBLIC_HATCHET_DASHBOARD_URL=${HATCHET_DASHBOARD_URL:-http://localhost:${R2R_HATCHET_DASHBOARD_PORT:-7274}}
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_DASHBOARD_PORT:-7273}:3000"
+
+  nginx:
+    image: nginx:latest
+    ports:
+      - "${R2R_NGINX_PORT:-7280}:80"
+    volumes:
+      - ./nginx.conf:/etc/nginx/nginx.conf:ro
+    depends_on:
+      r2r:
+        condition: service_healthy
+    networks:
+      - r2r-network
+    deploy:
+      resources:
+        limits:
+          cpus: '0.5'
+          memory: 512M
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost/health"]
+      interval: 10s
+      timeout: 5s
+      retries: 3
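
This variant differs from compose.full.yaml mainly in that the r2r container no longer publishes its own port and traffic is expected to enter through the nginx service on ${R2R_NGINX_PORT:-7280}. Assuming nginx.conf (not shown here) serves the /health route that the nginx healthcheck relies on, a quick host-side sanity check could look like:

import urllib.request

# Default host port from the compose file above; adjust if R2R_NGINX_PORT is overridden.
NGINX_HEALTH_URL = "http://localhost:7280/health"

with urllib.request.urlopen(NGINX_HEALTH_URL, timeout=5) as resp:
    print(resp.status, resp.read().decode(errors="replace")[:200])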

+ 123 - 0
compose.yaml.back

@@ -0,0 +1,123 @@
+networks:
+  r2r-network:
+    driver: bridge
+    attachable: true
+    labels:
+      - "com.docker.compose.recreate=always"
+
+volumes:
+  postgres_data:
+    name: ${VOLUME_POSTGRES_DATA:-postgres_data}
+
+services:
+  postgres:
+    image: pgvector/pgvector:pg16
+    profiles: [postgres]
+    environment:
+      - POSTGRES_USER=${R2R_POSTGRES_USER:-${POSTGRES_USER:-postgres}} # un-prefixed POSTGRES_* fallbacks are legacy, kept for backwards compatibility
+      - POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-${POSTGRES_PASSWORD:-postgres}}
+      - POSTGRES_HOST=${R2R_POSTGRES_HOST:-${POSTGRES_HOST:-postgres}}
+      - POSTGRES_PORT=${R2R_POSTGRES_PORT:-${POSTGRES_PORT:-5432}}
+      - POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-${POSTGRES_MAX_CONNECTIONS:-1024}}
+      - PGPORT=${R2R_POSTGRES_PORT:-5432}
+    volumes:
+      - postgres_data:/var/lib/postgresql/data
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_POSTGRES_PORT:-5432}:${R2R_POSTGRES_PORT:-5432}"
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -U ${R2R_POSTGRES_USER:-postgres}"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    command: >
+      postgres
+      -c max_connections=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+
+  r2r:
+    image: ${R2R_IMAGE:-ragtoriches/prod:latest}
+    build:
+      context: .
+      args:
+        PORT: ${R2R_PORT:-7272}
+        R2R_PORT: ${R2R_PORT:-7272}
+        HOST: ${R2R_HOST:-0.0.0.0}
+        R2R_HOST: ${R2R_HOST:-0.0.0.0}
+    ports:
+      - "${R2R_PORT:-7272}:${R2R_PORT:-7272}"
+    environment:
+      - PYTHONUNBUFFERED=1
+      - R2R_PORT=${R2R_PORT:-7272}
+      - R2R_HOST=${R2R_HOST:-0.0.0.0}
+
+      # R2R
+      - R2R_CONFIG_NAME=${R2R_CONFIG_NAME:-}
+      - R2R_CONFIG_PATH=${R2R_CONFIG_PATH:-}
+      - R2R_PROJECT_NAME=${R2R_PROJECT_NAME:-r2r_default}
+
+      # Postgres
+      - R2R_POSTGRES_USER=${R2R_POSTGRES_USER:-postgres}
+      - R2R_POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres}
+      - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres}
+      - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432}
+      - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-postgres}
+      - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+      - R2R_POSTGRES_PROJECT_NAME=${R2R_POSTGRES_PROJECT_NAME:-r2r_default}
+
+      # OpenAI
+      - OPENAI_API_KEY=${OPENAI_API_KEY:-}
+      - OPENAI_API_BASE=${OPENAI_API_BASE:-}
+
+      # Anthropic
+      - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
+
+      # Azure
+      - AZURE_API_KEY=${AZURE_API_KEY:-}
+      - AZURE_API_BASE=${AZURE_API_BASE:-}
+      - AZURE_API_VERSION=${AZURE_API_VERSION:-}
+
+      # Google Vertex AI
+      - GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS:-}
+      - VERTEX_PROJECT=${VERTEX_PROJECT:-}
+      - VERTEX_LOCATION=${VERTEX_LOCATION:-}
+
+      # AWS Bedrock
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
+      - AWS_REGION_NAME=${AWS_REGION_NAME:-}
+
+      # Groq
+      - GROQ_API_KEY=${GROQ_API_KEY:-}
+
+      # Cohere
+      - COHERE_API_KEY=${COHERE_API_KEY:-}
+
+      # Anyscale
+      - ANYSCALE_API_KEY=${ANYSCALE_API_KEY:-}
+
+      # Ollama
+      - OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434}
+
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:${R2R_PORT:-7272}/v3/health"]
+      interval: 6s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    volumes:
+      - ${R2R_CONFIG_PATH:-/}:${R2R_CONFIG_PATH:-/app/config}
+    extra_hosts:
+      - host.docker.internal:host-gateway
+
+  r2r-dashboard:
+    image: emrgntcmplxty/r2r-dashboard:latest
+    environment:
+      - NEXT_PUBLIC_R2R_DEPLOYMENT_URL=${R2R_DEPLOYMENT_URL:-http://localhost:7272}
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_DASHBOARD_PORT:-7273}:3000"

+ 232 - 0
core/__init__.py

@@ -0,0 +1,232 @@
+import logging
+
+# Keep '*' imports for enhanced development velocity
+from .agent import *
+from .base import *
+from .database import *
+from .main import *
+from .parsers import *
+from .pipelines import *
+from .pipes import *
+from .providers import *
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+# Create a console handler and set the level to info
+ch = logging.StreamHandler()
+ch.setLevel(logging.INFO)
+
+# Create a formatter and set it for the handler
+formatter = logging.Formatter(
+    "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
+)
+ch.setFormatter(formatter)
+
+# Add the handler to the logger
+logger.addHandler(ch)
+
+# Optional: disable propagation (a no-op here, since logging.getLogger() with no name already returns the root logger)
+logger.propagate = False
+
+logging.getLogger("httpx").setLevel(logging.WARNING)
+logging.getLogger("LiteLLM").setLevel(logging.WARNING)
+
+
+__all__ = [
+    ## AGENT
+    # Base
+    "R2RAgent",
+    "R2RStreamingAgent",
+    # RAG Agents
+    "R2RRAGAgent",
+    "R2RStreamingRAGAgent",
+    ## BASE
+    # Base abstractions
+    "AsyncSyncMeta",
+    "syncable",
+    # Completion abstractions
+    "MessageType",
+    # Document abstractions
+    "Document",
+    "DocumentChunk",
+    "DocumentResponse",
+    "IngestionStatus",
+    "KGExtractionStatus",
+    "KGEnrichmentStatus",
+    "DocumentType",
+    # Embedding abstractions
+    "EmbeddingPurpose",
+    "default_embedding_prefixes",
+    # Exception abstractions
+    "R2RDocumentProcessingError",
+    "R2RException",
+    # KG abstractions
+    "Entity",
+    "KGExtraction",
+    "Relationship",
+    # LLM abstractions
+    "GenerationConfig",
+    "LLMChatCompletion",
+    "LLMChatCompletionChunk",
+    "RAGCompletion",
+    # Prompt abstractions
+    "Prompt",
+    # Search abstractions
+    "AggregateSearchResult",
+    "WebSearchResponse",
+    "GraphSearchResult",
+    "ChunkSearchSettings",
+    "GraphSearchSettings",
+    "ChunkSearchResult",
+    "SearchSettings",
+    "select_search_filters",
+    "SearchMode",
+    "HybridSearchSettings",
+    # User abstractions
+    "Token",
+    "TokenData",
+    # Vector abstractions
+    "Vector",
+    "VectorEntry",
+    "VectorType",
+    "IndexConfig",
+    ## AGENT
+    # Agent abstractions
+    "Agent",
+    "AgentConfig",
+    "Conversation",
+    "Message",
+    "Tool",
+    "ToolResult",
+    ## API
+    # Auth Responses
+    "TokenResponse",
+    "User",
+    ## LOGGING
+    # Basic types
+    "RunType",
+    # Run Manager
+    "RunManager",
+    "manage_run",
+    ## PARSERS
+    # Base parser
+    "AsyncParser",
+    ## PIPELINE
+    # Base pipeline
+    "AsyncPipeline",
+    ## PIPES
+    "AsyncPipe",
+    "AsyncState",
+    ## PROVIDERS
+    # Base provider classes
+    "AppConfig",
+    "Provider",
+    "ProviderConfig",
+    # Auth provider
+    "AuthConfig",
+    "AuthProvider",
+    # Crypto provider
+    "CryptoConfig",
+    "CryptoProvider",
+    # Email provider
+    "EmailConfig",
+    "EmailProvider",
+    # Database providers
+    "DatabaseConfig",
+    "DatabaseProvider",
+    # Embedding provider
+    "EmbeddingConfig",
+    "EmbeddingProvider",
+    # LLM provider
+    "CompletionConfig",
+    "CompletionProvider",
+    ## UTILS
+    "RecursiveCharacterTextSplitter",
+    "TextSplitter",
+    "run_pipeline",
+    "to_async_generator",
+    "generate_id",
+    "increment_version",
+    "validate_uuid",
+    ## MAIN
+    ## R2R ABSTRACTIONS
+    "R2RProviders",
+    "R2RPipes",
+    "R2RPipelines",
+    "R2RAgents",
+    ## R2R APP
+    "R2RApp",
+    ## R2R APP ENTRY
+    # "r2r_app",
+    ## R2R ASSEMBLY
+    # Builder
+    "R2RBuilder",
+    # Config
+    "R2RConfig",
+    # Factory
+    "R2RProviderFactory",
+    "R2RPipeFactory",
+    "R2RPipelineFactory",
+    "R2RAgentFactory",
+    ## R2R SERVICES
+    "AuthService",
+    "IngestionService",
+    "ManagementService",
+    "RetrievalService",
+    "KgService",
+    ## PARSERS
+    # Media parsers
+    "AudioParser",
+    "DOCXParser",
+    "ImageParser",
+    "VLMPDFParser",
+    "BasicPDFParser",
+    "PDFParserUnstructured",
+    "PPTParser",
+    # Structured parsers
+    "CSVParser",
+    "CSVParserAdvanced",
+    "JSONParser",
+    "XLSXParser",
+    "XLSXParserAdvanced",
+    # Text parsers
+    "MDParser",
+    "HTMLParser",
+    "TextParser",
+    ## PIPELINES
+    "SearchPipeline",
+    "RAGPipeline",
+    ## PIPES
+    "SearchPipe",
+    "EmbeddingPipe",
+    "KGExtractionPipe",
+    "ParsingPipe",
+    "QueryTransformPipe",
+    "SearchRAGPipe",
+    "StreamingSearchRAGPipe",
+    "VectorSearchPipe",
+    "VectorStoragePipe",
+    "KGStoragePipe",
+    "MultiSearchPipe",
+    ## PROVIDERS
+    # Auth
+    "SupabaseAuthProvider",
+    "R2RAuthProvider",
+    # Crypto
+    "BCryptProvider",
+    "BCryptConfig",
+    # Database
+    "PostgresDatabaseProvider",
+    # Embeddings
+    "LiteLLMEmbeddingProvider",
+    "OpenAIEmbeddingProvider",
+    "OllamaEmbeddingProvider",
+    # LLM
+    "OpenAICompletionProvider",
+    "LiteLLMCompletionProvider",
+    # Ingestion
+    "UnstructuredIngestionProvider",
+    "R2RIngestionProvider",
+    "ChunkingStrategy",
+]
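
Because core/__init__.py re-exports its submodules via star imports and the __all__ list above, downstream code can pull most building blocks straight from the package root. A small import sketch using names that appear in __all__ (the package must be importable):

from core import GenerationConfig, R2RBuilder, R2RConfig, R2RRAGAgent

print(GenerationConfig, R2RBuilder, R2RConfig, R2RRAGAgent)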

+ 11 - 0
core/agent/__init__.py

@@ -0,0 +1,11 @@
+from .base import R2RAgent, R2RStreamingAgent
+from .rag import R2RRAGAgent, R2RStreamingRAGAgent
+
+__all__ = [
+    # Base
+    "R2RAgent",
+    "R2RStreamingAgent",
+    # RAG Agents
+    "R2RRAGAgent",
+    "R2RStreamingRAGAgent",
+]

+ 240 - 0
core/agent/base.py

@@ -0,0 +1,240 @@
+import asyncio
+import logging
+from abc import ABCMeta
+from typing import AsyncGenerator, Generator, Optional
+
+from core.base.abstractions import (
+    AsyncSyncMeta,
+    LLMChatCompletion,
+    LLMChatCompletionChunk,
+    Message,
+    syncable,
+)
+from core.base.agent import Agent, Conversation
+
+logger = logging.getLogger()
+
+
+class CombinedMeta(AsyncSyncMeta, ABCMeta):
+    pass
+
+
+def sync_wrapper(async_gen):
+    loop = asyncio.get_event_loop()
+
+    def wrapper():
+        try:
+            while True:
+                try:
+                    yield loop.run_until_complete(async_gen.__anext__())
+                except StopAsyncIteration:
+                    break
+        finally:
+            loop.run_until_complete(async_gen.aclose())
+
+    return wrapper()
+
+
+class R2RAgent(Agent, metaclass=CombinedMeta):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._register_tools()
+        self._reset()
+
+    def _reset(self):
+        self._completed = False
+        self.conversation = Conversation()
+
+    @syncable
+    async def arun(
+        self,
+        messages: list[Message],
+        system_instruction: Optional[str] = None,
+        *args,
+        **kwargs,
+    ) -> list[dict]:
+        # TODO - Make this method return a list of messages.
+        self._reset()
+        await self._setup(system_instruction)
+
+        if messages:
+            for message in messages:
+                await self.conversation.add_message(message)
+
+        while not self._completed:
+            messages_list = await self.conversation.get_messages()
+            generation_config = self.get_generation_config(messages_list[-1])
+            response = await self.llm_provider.aget_completion(
+                messages_list,
+                generation_config,
+            )
+            await self.process_llm_response(response, *args, **kwargs)
+
+        # Get the output messages
+        all_messages: list[dict] = await self.conversation.get_messages()
+        all_messages.reverse()
+
+        output_messages = []
+        for message_2 in all_messages:
+            if (
+                message_2.get("content")
+                and message_2.get("content") != messages[-1].content
+            ):
+                output_messages.append(message_2)
+            else:
+                break
+        output_messages.reverse()
+
+        return output_messages
+
+    async def process_llm_response(
+        self, response: LLMChatCompletion, *args, **kwargs
+    ) -> None:
+        if not self._completed:
+            message = response.choices[0].message
+            if message.function_call:
+                await self.handle_function_or_tool_call(
+                    message.function_call.name,
+                    message.function_call.arguments,
+                    *args,
+                    **kwargs,
+                )
+            elif message.tool_calls:
+                for tool_call in message.tool_calls:
+                    await self.handle_function_or_tool_call(
+                        tool_call.function.name,
+                        tool_call.function.arguments,
+                        *args,
+                        **kwargs,
+                    )
+            else:
+                await self.conversation.add_message(
+                    Message(role="assistant", content=message.content)
+                )
+                self._completed = True
+
+
+class R2RStreamingAgent(R2RAgent):
+    async def arun(  # type: ignore
+        self,
+        system_instruction: Optional[str] = None,
+        messages: Optional[list[Message]] = None,
+        *args,
+        **kwargs,
+    ) -> AsyncGenerator[str, None]:
+        self._reset()
+        await self._setup(system_instruction)
+
+        if messages:
+            for message in messages:
+                await self.conversation.add_message(message)
+
+        while not self._completed:
+            messages_list = await self.conversation.get_messages()
+
+            generation_config = self.get_generation_config(
+                messages_list[-1], stream=True
+            )
+            stream = self.llm_provider.aget_completion_stream(
+                messages_list,
+                generation_config,
+            )
+            async for proc_chunk in self.process_llm_response(
+                stream, *args, **kwargs
+            ):
+                yield proc_chunk
+
+    def run(
+        self, system_instruction, messages, *args, **kwargs
+    ) -> Generator[str, None, None]:
+        return sync_wrapper(
+            self.arun(system_instruction, messages, *args, **kwargs)
+        )
+
+    async def process_llm_response(  # type: ignore
+        self,
+        stream: AsyncGenerator[LLMChatCompletionChunk, None],
+        *args,
+        **kwargs,
+    ) -> AsyncGenerator[str, None]:
+        function_name = None
+        function_arguments = ""
+        content_buffer = ""
+
+        async for chunk in stream:
+            delta = chunk.choices[0].delta
+            if delta.tool_calls:
+                for tool_call in delta.tool_calls:
+                    if not tool_call.function:
+                        logger.info("Tool function not found in tool call.")
+                        continue
+                    name = tool_call.function.name
+                    if not name:
+                        logger.info("Tool name not found in tool call.")
+                        continue
+                    arguments = tool_call.function.arguments
+                    if not arguments:
+                        logger.info("Tool arguments not found in tool call.")
+                        continue
+
+                    results = await self.handle_function_or_tool_call(
+                        name,
+                        arguments,
+                        # FIXME: tool_call.id,
+                        *args,
+                        **kwargs,
+                    )
+
+                    yield "<tool_call>"
+                    yield f"<name>{name}</name>"
+                    yield f"<arguments>{arguments}</arguments>"
+                    yield f"<results>{results.llm_formatted_result}</results>"
+                    yield "</tool_call>"
+
+            if delta.function_call:
+                if delta.function_call.name:
+                    function_name = delta.function_call.name
+                if delta.function_call.arguments:
+                    function_arguments += delta.function_call.arguments
+            elif delta.content:
+                if content_buffer == "":
+                    yield "<completion>"
+                content_buffer += delta.content
+                yield delta.content
+
+            if chunk.choices[0].finish_reason == "function_call":
+                if not function_name:
+                    logger.info("Function name not found in function call.")
+                    continue
+
+                yield "<function_call>"
+                yield f"<name>{function_name}</name>"
+                yield f"<arguments>{function_arguments}</arguments>"
+                tool_result = await self.handle_function_or_tool_call(
+                    function_name, function_arguments, *args, **kwargs
+                )
+                if tool_result.stream_result:
+                    yield f"<results>{tool_result.stream_result}</results>"
+                else:
+                    yield f"<results>{tool_result.llm_formatted_result}</results>"
+
+                yield "</function_call>"
+
+                function_name = None
+                function_arguments = ""
+
+            elif chunk.choices[0].finish_reason == "stop":
+                if content_buffer:
+                    await self.conversation.add_message(
+                        Message(role="assistant", content=content_buffer)
+                    )
+                self._completed = True
+                yield "</completion>"
+
+        # Handle any remaining content after the stream ends
+        if content_buffer and not self._completed:
+            await self.conversation.add_message(
+                Message(role="assistant", content=content_buffer)
+            )
+            self._completed = True
+            yield "</completion>"

+ 159 - 0
core/agent/rag.py

@@ -0,0 +1,159 @@
+from typing import Union
+
+from core.agent import R2RAgent, R2RStreamingAgent
+from core.base import (
+    format_search_results_for_llm,
+    format_search_results_for_stream,
+)
+from core.base.abstractions import (
+    AggregateSearchResult,
+    GraphSearchSettings,
+    SearchSettings,
+    WebSearchResponse,
+)
+from core.base.agent import AgentConfig, Tool
+from core.base.providers import CompletionProvider, DatabaseProvider
+from core.base.utils import to_async_generator
+from core.pipelines import SearchPipeline
+from core.providers import (  # PostgresDatabaseProvider,
+    LiteLLMCompletionProvider,
+    OpenAICompletionProvider,
+)
+
+
+class RAGAgentMixin:
+    def __init__(self, search_pipeline: SearchPipeline, *args, **kwargs):
+        self.search_pipeline = search_pipeline
+        super().__init__(*args, **kwargs)
+
+    def _register_tools(self):
+        if not self.config.tool_names:
+            return
+        for tool_name in self.config.tool_names:
+            if tool_name == "local_search":
+                self._tools.append(self.local_search())
+            elif tool_name == "web_search":
+                self._tools.append(self.web_search())
+            else:
+                raise ValueError(f"Unsupported tool name: {tool_name}")
+
+    def web_search(self) -> Tool:
+        return Tool(
+            name="web_search",
+            description="Search for information on the web.",
+            results_function=self._web_search,
+            llm_format_function=RAGAgentMixin.format_search_results_for_llm,
+            stream_function=RAGAgentMixin.format_search_results_for_stream,
+            parameters={
+                "type": "object",
+                "properties": {
+                    "query": {
+                        "type": "string",
+                        "description": "The query to search Google with.",
+                    },
+                },
+                "required": ["query"],
+            },
+        )
+
+    async def _web_search(
+        self,
+        query: str,
+        search_settings: SearchSettings,
+        *args,
+        **kwargs,
+    ) -> AggregateSearchResult:
+        from .serper import SerperClient
+
+        serper_client = SerperClient()
+        # TODO - make async!
+        # TODO - Move to search pipeline, make configurable.
+        raw_results = serper_client.get_raw(query)
+        web_response = WebSearchResponse.from_serper_results(raw_results)
+        return AggregateSearchResult(
+            chunk_search_results=None,
+            graph_search_results=None,
+            web_search_results=web_response.organic_results,  # TODO - How do we feel about throwing away so much info?
+        )
+
+    def local_search(self) -> Tool:
+        return Tool(
+            name="local_search",
+            description="Search your local knowledgebase using the R2R AI system",
+            results_function=self._local_search,
+            llm_format_function=RAGAgentMixin.format_search_results_for_llm,
+            stream_function=RAGAgentMixin.format_search_results_for_stream,
+            parameters={
+                "type": "object",
+                "properties": {
+                    "query": {
+                        "type": "string",
+                        "description": "The query to search the local knowledgebase with.",
+                    },
+                },
+                "required": ["query"],
+            },
+        )
+
+    async def _local_search(
+        self,
+        query: str,
+        search_settings: SearchSettings,
+        *args,
+        **kwargs,
+    ) -> AggregateSearchResult:
+        response = await self.search_pipeline.run(
+            to_async_generator([query]),
+            state=None,
+            search_settings=search_settings,
+        )
+        return response
+
+    @staticmethod
+    def format_search_results_for_stream(
+        results: AggregateSearchResult,
+    ) -> str:
+        return format_search_results_for_stream(results)
+
+    @staticmethod
+    def format_search_results_for_llm(
+        results: AggregateSearchResult,
+    ) -> str:
+        return format_search_results_for_llm(results)
+
+
+class R2RRAGAgent(RAGAgentMixin, R2RAgent):
+    def __init__(
+        self,
+        database_provider: DatabaseProvider,
+        llm_provider: Union[
+            LiteLLMCompletionProvider, OpenAICompletionProvider
+        ],
+        search_pipeline: SearchPipeline,
+        config: AgentConfig,
+    ):
+        super().__init__(
+            database_provider=database_provider,
+            search_pipeline=search_pipeline,
+            llm_provider=llm_provider,
+            config=config,
+        )
+
+
+class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent):
+    def __init__(
+        self,
+        database_provider: DatabaseProvider,
+        llm_provider: Union[
+            LiteLLMCompletionProvider, OpenAICompletionProvider
+        ],
+        search_pipeline: SearchPipeline,
+        config: AgentConfig,
+    ):
+        config.stream = True
+        super().__init__(
+            database_provider=database_provider,
+            search_pipeline=search_pipeline,
+            llm_provider=llm_provider,
+            config=config,
+        )
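
RAGAgentMixin._register_tools() above enables tools strictly by name, accepting only "local_search" and "web_search" from config.tool_names. A hedged configuration sketch (AgentConfig may require additional fields that are not shown in this file):

from core.base.agent import AgentConfig

# Any other name in tool_names raises ValueError in _register_tools().
config = AgentConfig(tool_names=["local_search", "web_search"])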

+ 104 - 0
core/agent/serper.py

@@ -0,0 +1,104 @@
+# TODO - relocate to a dedicated module
+import http.client
+import json
+import os
+
+
+# TODO - Move process json to dedicated data processing module
+def process_json(json_object, indent=0):
+    """
+    Recursively traverses the JSON object (dicts and lists) to create an unstructured text blob.
+    """
+    text_blob = ""
+    if isinstance(json_object, dict):
+        for key, value in json_object.items():
+            padding = "  " * indent
+            if isinstance(value, (dict, list)):
+                text_blob += (
+                    f"{padding}{key}:\n{process_json(value, indent + 1)}"
+                )
+            else:
+                text_blob += f"{padding}{key}: {value}\n"
+    elif isinstance(json_object, list):
+        for index, item in enumerate(json_object):
+            padding = "  " * indent
+            if isinstance(item, (dict, list)):
+                text_blob += f"{padding}Item {index + 1}:\n{process_json(item, indent + 1)}"
+            else:
+                text_blob += f"{padding}Item {index + 1}: {item}\n"
+    return text_blob
+
+
+# TODO - Introduce abstract "Integration" ABC.
+class SerperClient:
+    def __init__(self, api_base: str = "google.serper.dev") -> None:
+        api_key = os.getenv("SERPER_API_KEY")
+        if not api_key:
+            raise ValueError(
+                "Please set the `SERPER_API_KEY` environment variable to use `SerperClient`."
+            )
+
+        self.api_base = api_base
+        self.headers = {
+            "X-API-KEY": api_key,
+            "Content-Type": "application/json",
+        }
+
+    @staticmethod
+    def _extract_results(result_data: dict) -> list:
+        formatted_results = []
+
+        for key, value in result_data.items():
+            # Skip searchParameters as it's not a result entry
+            if key == "searchParameters":
+                continue
+
+            # Handle 'answerBox' as a single item
+            if key == "answerBox":
+                value["type"] = key  # Add the type key to the dictionary
+                formatted_results.append(value)
+            # Handle lists of results
+            elif isinstance(value, list):
+                for item in value:
+                    item["type"] = key  # Add the type key to the dictionary
+                    formatted_results.append(item)
+            # Handle 'peopleAlsoAsk' and potentially other single item formats
+            elif isinstance(value, dict):
+                value["type"] = key  # Add the type key to the dictionary
+                formatted_results.append(value)
+
+        return formatted_results
+
+    # TODO - Add explicit typing for the return value
+    def get_raw(self, query: str, limit: int = 10) -> list:
+        connection = http.client.HTTPSConnection(self.api_base)
+        payload = json.dumps({"q": query, "num_outputs": limit})
+        connection.request("POST", "/search", payload, self.headers)
+        response = connection.getresponse()
+        data = response.read()
+        json_data = json.loads(data.decode("utf-8"))
+        return SerperClient._extract_results(json_data)
+
+    @staticmethod
+    def construct_context(results: list) -> str:
+        # Organize results by type
+        organized_results = {}
+        for result in results:
+            result_type = result.metadata.pop(
+                "type", "Unknown"
+            )  # Pop the type and use as key
+            if result_type not in organized_results:
+                organized_results[result_type] = [result.metadata]
+            else:
+                organized_results[result_type].append(result.metadata)
+
+        context = ""
+        # Iterate over each result type
+        for result_type, items in organized_results.items():
+            context += f"# {result_type} Results:\n"
+            for index, item in enumerate(items, start=1):
+                # Process each item under the current type
+                context += f"Item {index}:\n"
+                context += process_json(item) + "\n"
+
+        return context
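
A usage sketch for SerperClient above; it requires SERPER_API_KEY in the environment and performs a live request against google.serper.dev:

from core.agent.serper import SerperClient

client = SerperClient()  # raises ValueError if SERPER_API_KEY is not set
for item in client.get_raw("retrieval augmented generation", limit=5):
    # Each entry is a raw Serper result dict tagged with a "type" key by _extract_results.
    print(item["type"])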

+ 139 - 0
core/base/__init__.py

@@ -0,0 +1,139 @@
+from .abstractions import *
+from .agent import *
+from .api.models import *
+from .logger import *
+from .parsers import *
+from .pipeline import *
+from .pipes import *
+from .providers import *
+from .utils import *
+
+__all__ = [
+    ## ABSTRACTIONS
+    # Base abstractions
+    "AsyncSyncMeta",
+    "syncable",
+    # Completion abstractions
+    "MessageType",
+    # Document abstractions
+    "Document",
+    "DocumentChunk",
+    "DocumentResponse",
+    "IngestionStatus",
+    "KGExtractionStatus",
+    "KGEnrichmentStatus",
+    "DocumentType",
+    # Embedding abstractions
+    "EmbeddingPurpose",
+    "default_embedding_prefixes",
+    # Exception abstractions
+    "R2RDocumentProcessingError",
+    "R2RException",
+    # KG abstractions
+    "Entity",
+    "KGExtraction",
+    "Relationship",
+    "Community",
+    "KGCreationSettings",
+    "KGEnrichmentSettings",
+    "KGRunType",
+    # LLM abstractions
+    "GenerationConfig",
+    "LLMChatCompletion",
+    "LLMChatCompletionChunk",
+    "RAGCompletion",
+    # Prompt abstractions
+    "Prompt",
+    # Search abstractions
+    "AggregateSearchResult",
+    "WebSearchResponse",
+    "GraphSearchResult",
+    "GraphSearchSettings",
+    "ChunkSearchSettings",
+    "ChunkSearchResult",
+    "SearchSettings",
+    "select_search_filters",
+    "SearchMode",
+    "HybridSearchSettings",
+    # User abstractions
+    "Token",
+    "TokenData",
+    # Vector abstractions
+    "Vector",
+    "VectorEntry",
+    "VectorType",
+    "StorageResult",
+    "IndexConfig",
+    ## AGENT
+    # Agent abstractions
+    "Agent",
+    "AgentConfig",
+    "Conversation",
+    "Message",
+    "Tool",
+    "ToolResult",
+    ## API
+    # Auth Responses
+    "TokenResponse",
+    "User",
+    ## LOGGING
+    # Basic types
+    "RunType",
+    # Run Manager
+    "RunManager",
+    "manage_run",
+    ## PARSERS
+    # Base parser
+    "AsyncParser",
+    ## PIPELINE
+    # Base pipeline
+    "AsyncPipeline",
+    ## PIPES
+    "AsyncPipe",
+    "AsyncState",
+    ## PROVIDERS
+    # Base provider classes
+    "AppConfig",
+    "Provider",
+    "ProviderConfig",
+    # Auth provider
+    "AuthConfig",
+    "AuthProvider",
+    # Crypto provider
+    "CryptoConfig",
+    "CryptoProvider",
+    # Email provider
+    "EmailConfig",
+    "EmailProvider",
+    # Database providers
+    "DatabaseConfig",
+    "DatabaseProvider",
+    "Handler",
+    "PostgresConfigurationSettings",
+    # Embedding provider
+    "EmbeddingConfig",
+    "EmbeddingProvider",
+    # Ingestion provider
+    "IngestionMode",
+    "IngestionConfig",
+    "IngestionProvider",
+    "ChunkingStrategy",
+    # LLM provider
+    "CompletionConfig",
+    "CompletionProvider",
+    ## UTILS
+    "RecursiveCharacterTextSplitter",
+    "TextSplitter",
+    "run_pipeline",
+    "to_async_generator",
+    "format_search_results_for_llm",
+    "format_search_results_for_stream",
+    "validate_uuid",
+    # ID generation
+    "generate_id",
+    "generate_document_id",
+    "generate_extraction_id",
+    "generate_default_user_collection_id",
+    "generate_user_id",
+    "increment_version",
+]

+ 169 - 0
core/base/abstractions/__init__.py

@@ -0,0 +1,169 @@
+from shared.abstractions.base import AsyncSyncMeta, R2RSerializable, syncable
+from shared.abstractions.document import (
+    Document,
+    DocumentChunk,
+    DocumentResponse,
+    DocumentType,
+    IngestionStatus,
+    KGEnrichmentStatus,
+    KGExtractionStatus,
+    RawChunk,
+    UnprocessedChunk,
+    UpdateChunk,
+)
+from shared.abstractions.embedding import (
+    EmbeddingPurpose,
+    default_embedding_prefixes,
+)
+from shared.abstractions.exception import (
+    R2RDocumentProcessingError,
+    R2RException,
+)
+from shared.abstractions.graph import (
+    Community,
+    Entity,
+    Graph,
+    KGExtraction,
+    Relationship,
+)
+from shared.abstractions.ingestion import (
+    ChunkEnrichmentSettings,
+    ChunkEnrichmentStrategy,
+)
+from shared.abstractions.kg import (
+    GraphBuildSettings,
+    GraphCommunitySettings,
+    GraphEntitySettings,
+    GraphRelationshipSettings,
+    KGCreationSettings,
+    KGEnrichmentSettings,
+    KGEntityDeduplicationSettings,
+    KGEntityDeduplicationType,
+    KGRunType,
+)
+from shared.abstractions.llm import (
+    GenerationConfig,
+    LLMChatCompletion,
+    LLMChatCompletionChunk,
+    Message,
+    MessageType,
+    RAGCompletion,
+)
+from shared.abstractions.prompt import Prompt
+from shared.abstractions.search import (
+    AggregateSearchResult,
+    ChunkSearchResult,
+    ChunkSearchSettings,
+    GraphSearchResult,
+    GraphSearchSettings,
+    HybridSearchSettings,
+    KGCommunityResult,
+    KGEntityResult,
+    KGGlobalResult,
+    KGRelationshipResult,
+    KGSearchResultType,
+    SearchMode,
+    SearchSettings,
+    WebSearchResponse,
+    select_search_filters,
+)
+from shared.abstractions.user import Token, TokenData, User
+from shared.abstractions.vector import (
+    IndexArgsHNSW,
+    IndexArgsIVFFlat,
+    IndexConfig,
+    IndexMeasure,
+    IndexMethod,
+    StorageResult,
+    Vector,
+    VectorEntry,
+    VectorQuantizationSettings,
+    VectorQuantizationType,
+    VectorTableName,
+    VectorType,
+)
+
+__all__ = [
+    # Base abstractions
+    "R2RSerializable",
+    "AsyncSyncMeta",
+    "syncable",
+    # Completion abstractions
+    "MessageType",
+    # Document abstractions
+    "Document",
+    "DocumentChunk",
+    "DocumentResponse",
+    "DocumentType",
+    "IngestionStatus",
+    "KGExtractionStatus",
+    "KGEnrichmentStatus",
+    "RawChunk",
+    "UnprocessedChunk",
+    "UpdateChunk",
+    # Embedding abstractions
+    "EmbeddingPurpose",
+    "default_embedding_prefixes",
+    # Exception abstractions
+    "R2RDocumentProcessingError",
+    "R2RException",
+    # Graph abstractions
+    "Entity",
+    "Community",
+    "KGExtraction",
+    "Relationship",
+    # Index abstractions
+    "IndexConfig",
+    # LLM abstractions
+    "GenerationConfig",
+    "LLMChatCompletion",
+    "LLMChatCompletionChunk",
+    "Message",
+    "RAGCompletion",
+    # Prompt abstractions
+    "Prompt",
+    # Search abstractions
+    "WebSearchResponse",
+    "AggregateSearchResult",
+    "GraphSearchResult",
+    "KGSearchResultType",
+    "KGEntityResult",
+    "KGRelationshipResult",
+    "KGCommunityResult",
+    "KGGlobalResult",
+    "GraphSearchSettings",
+    "ChunkSearchSettings",
+    "ChunkSearchResult",
+    "SearchSettings",
+    "select_search_filters",
+    "SearchMode",
+    "HybridSearchSettings",
+    # KG abstractions
+    "KGCreationSettings",
+    "KGEnrichmentSettings",
+    "KGEntityDeduplicationSettings",
+    "GraphBuildSettings",
+    "GraphEntitySettings",
+    "GraphRelationshipSettings",
+    "GraphCommunitySettings",
+    "KGEntityDeduplicationType",
+    "KGRunType",
+    # User abstractions
+    "Token",
+    "TokenData",
+    "User",
+    # Vector abstractions
+    "Vector",
+    "VectorEntry",
+    "VectorType",
+    "IndexMeasure",
+    "IndexMethod",
+    "VectorTableName",
+    "IndexArgsHNSW",
+    "IndexArgsIVFFlat",
+    "VectorQuantizationSettings",
+    "VectorQuantizationType",
+    "StorageResult",
+    "ChunkEnrichmentSettings",
+    "ChunkEnrichmentStrategy",
+]

+ 10 - 0
core/base/agent/__init__.py

@@ -0,0 +1,10 @@
+from .agent import Agent, AgentConfig, Conversation, Tool, ToolResult
+
+__all__ = [
+    # Agent abstractions
+    "Agent",
+    "AgentConfig",
+    "Conversation",
+    "Tool",
+    "ToolResult",
+]

+ 247 - 0
core/base/agent/agent.py

@@ -0,0 +1,247 @@
+import asyncio
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, AsyncGenerator, Optional, Type
+
+from pydantic import BaseModel
+
+from core.base.abstractions import (
+    GenerationConfig,
+    LLMChatCompletion,
+    Message,
+    MessageType,
+)
+from core.base.providers import CompletionProvider, DatabaseProvider
+
+from .base import Tool, ToolResult
+
+logger = logging.getLogger()
+
+
+class Conversation:
+    def __init__(self):
+        self.messages: list[Message] = []
+        self._lock = asyncio.Lock()
+
+    async def create_and_add_message(
+        self,
+        role: MessageType | str,
+        content: Optional[str] = None,
+        name: Optional[str] = None,
+        function_call: Optional[dict[str, Any]] = None,
+        tool_calls: Optional[list[dict[str, Any]]] = None,
+    ):
+        message = Message(
+            role=role,
+            content=content,
+            name=name,
+            function_call=function_call,
+            tool_calls=tool_calls,
+        )
+        await self.add_message(message)
+
+    async def add_message(self, message):
+        async with self._lock:
+            self.messages.append(message)
+
+    async def get_messages(self) -> list[dict[str, Any]]:
+        async with self._lock:
+            return [
+                {**msg.model_dump(exclude_none=True), "role": str(msg.role)}
+                for msg in self.messages
+            ]
+
+
+# TODO - Move agents to provider pattern
+class AgentConfig(BaseModel):
+    system_instruction_name: str = "rag_agent"
+    tool_names: list[str] = ["search"]
+    generation_config: GenerationConfig = GenerationConfig()
+    stream: bool = False
+
+    @classmethod
+    def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
+        base_args = cls.model_fields.keys()
+        filtered_kwargs = {
+            k: v if v != "None" else None
+            for k, v in kwargs.items()
+            if k in base_args
+        }
+        return cls(**filtered_kwargs)  # type: ignore
+
+
+class Agent(ABC):
+    def __init__(
+        self,
+        llm_provider: CompletionProvider,
+        database_provider: DatabaseProvider,
+        config: AgentConfig,
+    ):
+        self.llm_provider = llm_provider
+        self.database_provider: DatabaseProvider = database_provider
+        self.config = config
+        self.conversation = Conversation()
+        self._completed = False
+        self._tools: list[Tool] = []
+        self._register_tools()
+
+    @abstractmethod
+    def _register_tools(self):
+        pass
+
+    async def _setup(self, system_instruction: Optional[str] = None):
+        content = system_instruction or (
+            await self.database_provider.prompts_handler.get_cached_prompt(
+                self.config.system_instruction_name
+            )
+        )
+        await self.conversation.add_message(
+            Message(role="system", content=content)
+        )
+
+    @property
+    def tools(self) -> list[Tool]:
+        return self._tools
+
+    @tools.setter
+    def tools(self, tools: list[Tool]):
+        self._tools = tools
+
+    @abstractmethod
+    async def arun(
+        self,
+        system_instruction: Optional[str] = None,
+        messages: Optional[list[Message]] = None,
+        *args,
+        **kwargs,
+    ) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]:
+        pass
+
+    @abstractmethod
+    async def process_llm_response(
+        self,
+        response: Any,
+        *args,
+        **kwargs,
+    ) -> None | AsyncGenerator[str, None]:
+        pass
+
+    async def execute_tool(self, tool_name: str, *args, **kwargs) -> str:
+        if tool := next((t for t in self.tools if t.name == tool_name), None):
+            return await tool.results_function(*args, **kwargs)
+        else:
+            return f"Error: Tool {tool_name} not found."
+
+    def get_generation_config(
+        self, last_message: dict, stream: bool = False
+    ) -> GenerationConfig:
+        if (
+            last_message["role"] in ["tool", "function"]
+            and last_message["content"] != ""
+        ):
+            return GenerationConfig(
+                **self.config.generation_config.model_dump(
+                    exclude={"functions", "tools", "stream"}
+                ),
+                stream=stream,
+            )
+        return GenerationConfig(
+            **self.config.generation_config.model_dump(
+                exclude={"functions", "tools", "stream"}
+            ),
+            # FIXME: Use tools instead of functions
+            # TODO - Investigate why `tools` fails with OpenAI+LiteLLM
+            # tools=[
+            #     {
+            #         "function":{
+            #             "name": tool.name,
+            #             "description": tool.description,
+            #             "parameters": tool.parameters,
+            #         },
+            #         "type": "function"
+            #     }
+            #     for tool in self.tools
+            # ],
+            functions=[
+                {
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.parameters,
+                }
+                for tool in self.tools
+            ],
+            stream=stream,
+        )
+
+    async def handle_function_or_tool_call(
+        self,
+        function_name: str,
+        function_arguments: str,
+        tool_id: Optional[str] = None,
+        *args,
+        **kwargs,
+    ) -> ToolResult:
+        await self.conversation.add_message(
+            Message(
+                role="assistant",
+                tool_calls=(
+                    [
+                        {
+                            "id": tool_id,
+                            "function": {
+                                "name": function_name,
+                                "arguments": function_arguments,
+                            },
+                        }
+                    ]
+                    if tool_id
+                    else None
+                ),
+                function_call=(
+                    {
+                        "name": function_name,
+                        "arguments": function_arguments,
+                    }
+                    if not tool_id
+                    else None
+                ),
+            )
+        )
+
+        if tool := next(
+            (t for t in self.tools if t.name == function_name), None
+        ):
+            merged_kwargs = {**kwargs, **json.loads(function_arguments)}
+            raw_result = await tool.results_function(*args, **merged_kwargs)
+            llm_formatted_result = tool.llm_format_function(raw_result)
+            tool_result = ToolResult(
+                raw_result=raw_result,
+                llm_formatted_result=llm_formatted_result,
+            )
+            if tool.stream_function:
+                tool_result.stream_result = tool.stream_function(raw_result)
+        else:
+            error_message = f"The requested tool '{function_name}' is not available. Available tools: {', '.join(t.name for t in self.tools)}"
+            tool_result = ToolResult(
+                raw_result=error_message,
+                llm_formatted_result=error_message,
+            )
+
+        await self.conversation.add_message(
+            Message(
+                role="tool" if tool_id else "function",
+                content=str(tool_result.llm_formatted_result),
+                name=function_name,
+            )
+        )
+
+        return tool_result
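
A minimal sketch of a concrete subclass of the Agent ABC above, for orientation only: EchoAgent is hypothetical, the llm/database providers are assumed to be constructed elsewhere, and a real agent would complete with self.llm_provider instead of echoing the conversation.

from typing import Any, Optional

from core.base.abstractions import Message
from core.base.agent import Agent


class EchoAgent(Agent):
    def _register_tools(self):
        # No tools registered; self._tools stays empty.
        pass

    async def arun(
        self,
        system_instruction: Optional[str] = None,
        messages: Optional[list[Message]] = None,
        *args,
        **kwargs,
    ):
        await self._setup(system_instruction)
        for message in messages or []:
            await self.conversation.add_message(message)
        # Echo the accumulated conversation instead of calling the LLM.
        return await self.conversation.get_messages()

    async def process_llm_response(self, response: Any, *args, **kwargs) -> None:
        # Nothing to post-process in this sketch.
        pass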

+ 22 - 0
core/base/agent/base.py

@@ -0,0 +1,22 @@
+from typing import Any, Callable, Optional
+
+from ..abstractions import R2RSerializable
+
+
+class Tool(R2RSerializable):
+    name: str
+    description: str
+    results_function: Callable
+    llm_format_function: Callable
+    stream_function: Optional[Callable] = None
+    parameters: Optional[dict[str, Any]] = None
+
+    class Config:
+        populate_by_name = True
+        arbitrary_types_allowed = True
+
+
+class ToolResult(R2RSerializable):
+    raw_result: Any
+    llm_formatted_result: str
+    stream_result: Optional[str] = None
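
A sketch of declaring a Tool against these models; the web_search name, the helper functions, and the JSON schema below are illustrative, not part of the commit.

from core.base.agent import Tool


async def _run_search(query: str, **kwargs) -> list[str]:
    # Placeholder results; a real tool would call a search backend.
    return [f"result for {query}"]


def _format_for_llm(raw: list[str]) -> str:
    return "\n".join(raw)


search_tool = Tool(
    name="web_search",
    description="Search the web and return a short list of snippets.",
    results_function=_run_search,
    llm_format_function=_format_for_llm,
    parameters={
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"],
    },
)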

+ 159 - 0
core/base/api/models/__init__.py

@@ -0,0 +1,159 @@
+from shared.api.models.auth.responses import (
+    TokenResponse,
+    WrappedTokenResponse,
+)
+from shared.api.models.base import (
+    GenericBooleanResponse,
+    GenericMessageResponse,
+    PaginatedR2RResult,
+    R2RResults,
+    WrappedBooleanResponse,
+    WrappedGenericMessageResponse,
+)
+from shared.api.models.ingestion.responses import (
+    IngestionResponse,
+    UpdateResponse,
+    WrappedIngestionResponse,
+    WrappedListVectorIndicesResponse,
+    WrappedMetadataUpdateResponse,
+    WrappedUpdateResponse,
+)
+from shared.api.models.kg.responses import (  # TODO: Need to review anything above this
+    Community,
+    Entity,
+    GraphResponse,
+    Relationship,
+    WrappedCommunitiesResponse,
+    WrappedCommunityResponse,
+    WrappedEntitiesResponse,
+    WrappedEntityResponse,
+    WrappedGraphResponse,
+    WrappedGraphsResponse,
+    WrappedRelationshipResponse,
+    WrappedRelationshipsResponse,
+)
+from shared.api.models.management.responses import (  # Document Responses; Prompt Responses; Chunk Responses; Conversation Responses; User Responses; TODO: anything below this hasn't been reviewed
+    AnalyticsResponse,
+    ChunkResponse,
+    CollectionResponse,
+    ConversationResponse,
+    LogResponse,
+    PromptResponse,
+    ServerStats,
+    SettingsResponse,
+    User,
+    WrappedAnalyticsResponse,
+    WrappedChunkResponse,
+    WrappedChunksResponse,
+    WrappedCollectionResponse,
+    WrappedCollectionsResponse,
+    WrappedConversationMessagesResponse,
+    WrappedConversationResponse,
+    WrappedConversationsResponse,
+    WrappedDocumentResponse,
+    WrappedDocumentsResponse,
+    WrappedLogsResponse,
+    WrappedMessageResponse,
+    WrappedMessagesResponse,
+    WrappedPromptResponse,
+    WrappedPromptsResponse,
+    WrappedResetDataResult,
+    WrappedServerStatsResponse,
+    WrappedSettingsResponse,
+    WrappedUserResponse,
+    WrappedUsersResponse,
+    WrappedVerificationResult,
+)
+from shared.api.models.retrieval.responses import (
+    AgentResponse,
+    CombinedSearchResponse,
+    RAGResponse,
+    WrappedAgentResponse,
+    WrappedCompletionResponse,
+    WrappedDocumentSearchResponse,
+    WrappedRAGResponse,
+    WrappedSearchResponse,
+    WrappedVectorSearchResponse,
+)
+
+__all__ = [
+    # Auth Responses
+    "TokenResponse",
+    "WrappedTokenResponse",
+    "WrappedVerificationResult",
+    "WrappedGenericMessageResponse",
+    "WrappedResetDataResult",
+    # Ingestion Responses
+    "IngestionResponse",
+    "WrappedIngestionResponse",
+    "WrappedUpdateResponse",
+    "WrappedMetadataUpdateResponse",
+    "WrappedListVectorIndicesResponse",
+    "UpdateResponse",
+    # Knowledge Graph Responses
+    "Entity",
+    "Relationship",
+    "Community",
+    "WrappedEntityResponse",
+    "WrappedEntitiesResponse",
+    "WrappedRelationshipResponse",
+    "WrappedRelationshipsResponse",
+    "WrappedCommunityResponse",
+    "WrappedCommunitiesResponse",
+    # TODO: Need to review anything above this
+    "GraphResponse",
+    "WrappedGraphResponse",
+    "WrappedGraphsResponse",
+    # Management Responses
+    "PromptResponse",
+    "ServerStats",
+    "LogResponse",
+    "AnalyticsResponse",
+    "SettingsResponse",
+    "ChunkResponse",
+    "CollectionResponse",
+    "WrappedServerStatsResponse",
+    "WrappedLogsResponse",
+    "WrappedAnalyticsResponse",
+    "WrappedSettingsResponse",
+    "WrappedDocumentResponse",
+    "WrappedDocumentsResponse",
+    "WrappedCollectionResponse",
+    "WrappedCollectionsResponse",
+    # Conversation Responses
+    "ConversationResponse",
+    "WrappedConversationMessagesResponse",
+    "WrappedConversationResponse",
+    "WrappedConversationsResponse",
+    # Prompt Responses
+    "WrappedPromptResponse",
+    "WrappedPromptsResponse",
+    # Conversation Responses
+    "WrappedMessageResponse",
+    "WrappedMessagesResponse",
+    # Chunk Responses
+    "WrappedChunkResponse",
+    "WrappedChunksResponse",
+    # User Responses
+    "User",
+    "WrappedUserResponse",
+    "WrappedUsersResponse",
+    # Base Responses
+    "PaginatedR2RResult",
+    "R2RResults",
+    "GenericBooleanResponse",
+    "GenericMessageResponse",
+    "WrappedBooleanResponse",
+    "WrappedGenericMessageResponse",
+    # TODO: This needs to be cleaned up
+    # Retrieval Responses
+    "CombinedSearchResponse",
+    "RAGResponse",
+    "AgentResponse",
+    "WrappedDocumentSearchResponse",
+    "WrappedSearchResponse",
+    "WrappedVectorSearchResponse",
+    "WrappedCompletionResponse",
+    "WrappedRAGResponse",
+    "WrappedAgentResponse",
+]

+ 11 - 0
core/base/logger/__init__.py

@@ -0,0 +1,11 @@
+from .base import RunInfoLog, RunType
+from .run_manager import RunManager, manage_run
+
+__all__ = [
+    # Basic types
+    "RunType",
+    "RunInfoLog",
+    # Run Manager
+    "RunManager",
+    "manage_run",
+]

+ 32 - 0
core/base/logger/base.py

@@ -0,0 +1,32 @@
+import logging
+from abc import abstractmethod
+from datetime import datetime
+from enum import Enum
+from typing import Any, Optional, Tuple, Union
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from core.base import Message
+
+from ..providers.base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class RunInfoLog(BaseModel):
+    run_id: UUID
+    run_type: str
+    timestamp: datetime
+    user_id: UUID
+
+
+class RunType(str, Enum):
+    """Enumeration of the different types of runs."""
+
+    RETRIEVAL = "RETRIEVAL"
+    MANAGEMENT = "MANAGEMENT"
+    INGESTION = "INGESTION"
+    AUTH = "AUTH"
+    UNSPECIFIED = "UNSPECIFIED"
+    KG = "KG"

+ 62 - 0
core/base/logger/run_manager.py

@@ -0,0 +1,62 @@
+import asyncio
+import contextvars
+from contextlib import asynccontextmanager
+from typing import Optional
+from uuid import UUID
+
+from core.base.api.models import User
+from core.base.logger.base import RunType
+from core.base.utils import generate_id
+
+run_id_var = contextvars.ContextVar("run_id", default=generate_id())
+
+
+class RunManager:
+    def __init__(self):
+        self.run_info: dict[UUID, dict] = {}
+
+    async def set_run_info(self, run_type: str, run_id: Optional[UUID] = None):
+        run_id = run_id or run_id_var.get()
+        if run_id is None:
+            run_id = generate_id()
+            token = run_id_var.set(run_id)
+            self.run_info[run_id] = {"run_type": run_type}
+        else:
+            token = run_id_var.set(run_id)
+        return run_id, token
+
+    async def get_info_logs(self):
+        run_id = run_id_var.get()
+        return self.run_info.get(run_id, None)
+
+    async def log_run_info(
+        self,
+        run_type: RunType,
+        user: User,
+    ):
+        if asyncio.iscoroutine(user):
+            user = await user
+
+    async def clear_run_info(self, token: contextvars.Token):
+        run_id = run_id_var.get()
+        run_id_var.reset(token)
+        if run_id and run_id in self.run_info:
+            del self.run_info[run_id]
+
+
+@asynccontextmanager
+async def manage_run(
+    run_manager: RunManager,
+    run_type: RunType = RunType.UNSPECIFIED,
+    run_id: Optional[UUID] = None,
+):
+    run_id, token = await run_manager.set_run_info(run_type, run_id)
+    try:
+        yield run_id
+    finally:
+        # Check if we're in a test environment
+        if isinstance(token, contextvars.Token):
+            run_id_var.reset(token)
+        else:
+            # We're in a test environment, just reset the run_id_var
+            run_id_var.set(None)  # type: ignore
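
A minimal sketch of scoping a unit of work under manage_run so that a run id is bound to the current context for the duration of the block:

import asyncio

from core.base.logger.base import RunType
from core.base.logger.run_manager import RunManager, manage_run


async def main():
    run_manager = RunManager()
    async with manage_run(run_manager, RunType.INGESTION) as run_id:
        # run_id_var now resolves to run_id for any code awaited here.
        print(f"ingestion run {run_id}")


asyncio.run(main())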

+ 5 - 0
core/base/parsers/__init__.py

@@ -0,0 +1,5 @@
+from .base_parser import AsyncParser
+
+__all__ = [
+    "AsyncParser",
+]

+ 13 - 0
core/base/parsers/base_parser.py

@@ -0,0 +1,13 @@
+"""Abstract base class for parsers."""
+
+from abc import ABC, abstractmethod
+from typing import AsyncGenerator, Generic, TypeVar
+
+T = TypeVar("T")
+
+
+class AsyncParser(ABC, Generic[T]):
+
+    @abstractmethod
+    async def ingest(self, data: T, **kwargs) -> AsyncGenerator[str, None]:
+        pass
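
A minimal sketch of a concrete parser against this interface; LineParser is hypothetical and simply yields the non-empty lines of a text payload.

from typing import AsyncGenerator

from core.base.parsers import AsyncParser


class LineParser(AsyncParser[str]):
    async def ingest(self, data: str, **kwargs) -> AsyncGenerator[str, None]:
        for line in data.splitlines():
            if line.strip():
                yield line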

+ 5 - 0
core/base/pipeline/__init__.py

@@ -0,0 +1,5 @@
+from .base_pipeline import AsyncPipeline
+
+__all__ = [
+    "AsyncPipeline",
+]

+ 180 - 0
core/base/pipeline/base_pipeline.py

@@ -0,0 +1,180 @@
+"""Base pipeline class for running a sequence of pipes."""
+
+import asyncio
+import logging
+import traceback
+from typing import Any, AsyncGenerator, Optional
+
+from ..logger.run_manager import RunManager, manage_run
+from ..pipes.base_pipe import AsyncPipe, AsyncState
+
+logger = logging.getLogger()
+
+
+class AsyncPipeline:
+    """Pipeline class for running a sequence of pipes."""
+
+    def __init__(
+        self,
+        run_manager: Optional[RunManager] = None,
+    ):
+        self.pipes: list[AsyncPipe] = []
+        self.upstream_outputs: list[list[dict[str, str]]] = []
+        self.run_manager = run_manager or RunManager()
+        self.futures: dict[str, asyncio.Future] = {}
+        self.level = 0
+
+    def add_pipe(
+        self,
+        pipe: AsyncPipe,
+        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        """Add a pipe to the pipeline."""
+        self.pipes.append(pipe)
+        if not add_upstream_outputs:
+            add_upstream_outputs = []
+        self.upstream_outputs.append(add_upstream_outputs)
+
+    async def run(
+        self,
+        input: Any,
+        state: Optional[AsyncState] = None,
+        stream: bool = False,
+        run_manager: Optional[RunManager] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        """Run the pipeline."""
+        run_manager = run_manager or self.run_manager
+        self.state = state or AsyncState()
+        current_input = input
+        async with manage_run(run_manager):
+            try:
+                for pipe_num in range(len(self.pipes)):
+                    config_name = self.pipes[pipe_num].config.name
+                    self.futures[config_name] = asyncio.Future()
+
+                    current_input = self._run_pipe(
+                        pipe_num,
+                        current_input,
+                        run_manager,
+                        *args,
+                        **kwargs,
+                    )
+                    self.futures[config_name].set_result(current_input)
+
+            except Exception as error:
+                # TODO: improve error handling here
+                error_trace = traceback.format_exc()
+                logger.error(
+                    f"Pipeline failed with error: {error}\n\nStack trace:\n{error_trace}"
+                )
+                raise error
+
+            return (
+                current_input
+                if stream
+                else await self._consume_all(current_input)
+            )
+
+    async def _consume_all(self, gen: AsyncGenerator) -> list[Any]:
+        result = []
+        async for item in gen:
+            if hasattr(
+                item, "__aiter__"
+            ):  # Check if the item is an async generator
+                sub_result = await self._consume_all(item)
+                result.extend(sub_result)
+            else:
+                result.append(item)
+        return result
+
+    async def _run_pipe(
+        self,
+        pipe_num: int,
+        input: Any,
+        run_manager: RunManager,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        # Collect inputs, waiting for the necessary futures
+        pipe = self.pipes[pipe_num]
+        add_upstream_outputs = self.sort_upstream_outputs(
+            self.upstream_outputs[pipe_num]
+        )
+        input_dict = {"message": input}
+
+        # Collect upstream outputs by prev_pipe_name
+        grouped_upstream_outputs: dict[str, list] = {}
+        for upstream_input in add_upstream_outputs:
+            upstream_pipe_name = upstream_input["prev_pipe_name"]
+            if upstream_pipe_name not in grouped_upstream_outputs:
+                grouped_upstream_outputs[upstream_pipe_name] = []
+            grouped_upstream_outputs[upstream_pipe_name].append(upstream_input)
+
+        for (
+            upstream_pipe_name,
+            upstream_inputs,
+        ) in grouped_upstream_outputs.items():
+
+            async def resolve_future_output(future):
+                result = future.result()
+                # consume the async generator
+                return [item async for item in result]
+
+            async def replay_items_as_async_gen(items):
+                for item in items:
+                    yield item
+
+            temp_results = await resolve_future_output(
+                self.futures[upstream_pipe_name]
+            )
+            if upstream_pipe_name == self.pipes[pipe_num - 1].config.name:
+                input_dict["message"] = replay_items_as_async_gen(temp_results)
+
+            for upstream_input in upstream_inputs:
+                outputs = await self.state.get(upstream_pipe_name, "output")
+                prev_output_field = upstream_input.get(
+                    "prev_output_field", None
+                )
+                if not prev_output_field:
+                    raise ValueError(
+                        "`prev_output_field` must be specified in the upstream_input"
+                    )
+                input_dict[upstream_input["input_field"]] = outputs[
+                    prev_output_field
+                ]
+        async for ele in await pipe.run(
+            pipe.Input(**input_dict),
+            self.state,
+            run_manager,
+            *args,
+            **kwargs,
+        ):
+            yield ele
+
+    def sort_upstream_outputs(
+        self, add_upstream_outputs: list[dict[str, str]]
+    ) -> list[dict[str, str]]:
+        pipe_name_to_index = {
+            pipe.config.name: index for index, pipe in enumerate(self.pipes)
+        }
+
+        def get_pipe_index(upstream_output):
+            return pipe_name_to_index[upstream_output["prev_pipe_name"]]
+
+        sorted_outputs = sorted(
+            add_upstream_outputs, key=get_pipe_index, reverse=True
+        )
+        return sorted_outputs
+
+
+async def dequeue_requests(queue: asyncio.Queue) -> AsyncGenerator:
+    """Create an async generator to dequeue requests."""
+    while True:
+        request = await queue.get()
+        if request is None:
+            break
+        yield request

+ 3 - 0
core/base/pipes/__init__.py

@@ -0,0 +1,3 @@
+from .base_pipe import AsyncPipe, AsyncState
+
+__all__ = ["AsyncPipe", "AsyncState"]

+ 128 - 0
core/base/pipes/base_pipe.py

@@ -0,0 +1,128 @@
+import asyncio
+import logging
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, AsyncGenerator, Generic, Optional, TypeVar
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from core.base.logger.base import RunType
+from core.base.logger.run_manager import RunManager, manage_run
+
+logger = logging.getLogger()
+
+
+class AsyncState:
+    """A state object for storing data between pipes."""
+
+    def __init__(self):
+        self.data = {}
+        self.lock = asyncio.Lock()
+
+    async def update(self, outer_key: str, values: dict):
+        """Update the state with new values."""
+        async with self.lock:
+            if not isinstance(values, dict):
+                raise ValueError("Values must be contained in a dictionary.")
+            if outer_key not in self.data:
+                self.data[outer_key] = {}
+            for inner_key, inner_value in values.items():
+                self.data[outer_key][inner_key] = inner_value
+
+    async def get(self, outer_key: str, inner_key: str, default=None):
+        """Get a value from the state."""
+        async with self.lock:
+            if outer_key not in self.data:
+                raise ValueError(
+                    f"Key {outer_key} does not exist in the state."
+                )
+            if inner_key not in self.data[outer_key]:
+                return default or {}
+            return self.data[outer_key][inner_key]
+
+    async def delete(self, outer_key: str, inner_key: Optional[str] = None):
+        """Delete a value from the state."""
+        async with self.lock:
+            if outer_key in self.data and not inner_key:
+                del self.data[outer_key]
+            else:
+                if inner_key not in self.data[outer_key]:
+                    raise ValueError(
+                        f"Key {inner_key} does not exist in the state."
+                    )
+                del self.data[outer_key][inner_key]
+
+
+T = TypeVar("T")
+
+
+class AsyncPipe(Generic[T]):
+    """An asynchronous pipe for processing data with logging capabilities."""
+
+    class PipeConfig(BaseModel):
+        """Configuration for a pipe."""
+
+        name: str = "default_pipe"
+        max_log_queue_size: int = 100
+
+        class Config:
+            extra = "forbid"
+            arbitrary_types_allowed = True
+
+    class Input(BaseModel):
+        """Input for a pipe."""
+
+        message: Any
+
+        class Config:
+            extra = "forbid"
+            arbitrary_types_allowed = True
+
+    def __init__(
+        self,
+        config: PipeConfig,
+        run_manager: Optional[RunManager] = None,
+    ):
+        # TODO - Deprecate
+        self._config = config or self.PipeConfig()
+        self._run_manager = run_manager or RunManager()
+
+        logger.debug(f"Initialized pipe {self.config.name}")
+
+    @property
+    def config(self) -> PipeConfig:
+        return self._config
+
+    async def run(
+        self,
+        input: Input,
+        state: Optional[AsyncState],
+        run_manager: Optional[RunManager] = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[T, None]:
+        """Run the pipe with logging capabilities."""
+
+        run_manager = run_manager or self._run_manager
+        state = state or AsyncState()
+
+        async def wrapped_run() -> AsyncGenerator[Any, None]:
+            async with manage_run(run_manager, RunType.UNSPECIFIED) as run_id:  # type: ignore
+                async for result in self._run_logic(  # type: ignore
+                    input, state, run_id, *args, **kwargs  # type: ignore
+                ):
+                    yield result
+
+        return wrapped_run()
+
+    @abstractmethod
+    async def _run_logic(
+        self,
+        input: Input,
+        state: AsyncState,
+        run_id: UUID,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[T, None]:
+        pass
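
A minimal sketch that wires one custom pipe into the AsyncPipeline from the previous file; UpperCasePipe and the helper generator are illustrative stand-ins.

import asyncio
from typing import Any, AsyncGenerator
from uuid import UUID

from core.base.pipeline import AsyncPipeline
from core.base.pipes import AsyncPipe, AsyncState


class UpperCasePipe(AsyncPipe[str]):
    async def _run_logic(
        self,
        input: AsyncPipe.Input,
        state: AsyncState,
        run_id: UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
        async for item in input.message:
            yield item.upper()


async def to_gen(items):
    # Pipelines expect an async generator as their initial input.
    for item in items:
        yield item


async def main():
    pipeline = AsyncPipeline()
    pipeline.add_pipe(UpperCasePipe(AsyncPipe.PipeConfig(name="upper_pipe")))
    print(await pipeline.run(to_gen(["hello", "world"])))  # ['HELLO', 'WORLD']


asyncio.run(main())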

+ 57 - 0
core/base/providers/__init__.py

@@ -0,0 +1,57 @@
+from .auth import AuthConfig, AuthProvider
+from .base import AppConfig, Provider, ProviderConfig
+from .crypto import CryptoConfig, CryptoProvider
+from .database import (
+    DatabaseConfig,
+    DatabaseConnectionManager,
+    DatabaseProvider,
+    Handler,
+    PostgresConfigurationSettings,
+)
+from .email import EmailConfig, EmailProvider
+from .embedding import EmbeddingConfig, EmbeddingProvider
+from .ingestion import (
+    ChunkingStrategy,
+    IngestionConfig,
+    IngestionMode,
+    IngestionProvider,
+)
+from .llm import CompletionConfig, CompletionProvider
+from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow
+
+__all__ = [
+    # Auth provider
+    "AuthConfig",
+    "AuthProvider",
+    # Base provider classes
+    "AppConfig",
+    "Provider",
+    "ProviderConfig",
+    # Ingestion provider
+    "IngestionMode",
+    "IngestionConfig",
+    "IngestionProvider",
+    "ChunkingStrategy",
+    # Crypto provider
+    "CryptoConfig",
+    "CryptoProvider",
+    # Email provider
+    "EmailConfig",
+    "EmailProvider",
+    # Database providers
+    "DatabaseConnectionManager",
+    "DatabaseConfig",
+    "PostgresConfigurationSettings",
+    "DatabaseProvider",
+    "Handler",
+    # Embedding provider
+    "EmbeddingConfig",
+    "EmbeddingProvider",
+    # LLM provider
+    "CompletionConfig",
+    "CompletionProvider",
+    # Orchestration provider
+    "OrchestrationConfig",
+    "OrchestrationProvider",
+    "Workflow",
+]

+ 155 - 0
core/base/providers/auth.py

@@ -0,0 +1,155 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Optional
+
+from fastapi import Security
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+
+from ..abstractions import R2RException, Token, TokenData
+from ..api.models import User
+from .base import Provider, ProviderConfig
+from .crypto import CryptoProvider
+
+# from .database import DatabaseProvider
+from .email import EmailProvider
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+    from core.database import PostgresDatabaseProvider
+
+
+class AuthConfig(ProviderConfig):
+    secret_key: Optional[str] = None
+    require_authentication: bool = False
+    require_email_verification: bool = False
+    default_admin_email: str = "admin@example.com"
+    default_admin_password: str = "change_me_immediately"
+    access_token_lifetime_in_minutes: Optional[int] = None
+    refresh_token_lifetime_in_days: Optional[int] = None
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["r2r"]
+
+    def validate_config(self) -> None:
+        pass
+
+
+class AuthProvider(Provider, ABC):
+    security = HTTPBearer(auto_error=False)
+    crypto_provider: CryptoProvider
+    email_provider: EmailProvider
+    database_provider: "PostgresDatabaseProvider"
+
+    def __init__(
+        self,
+        config: AuthConfig,
+        crypto_provider: CryptoProvider,
+        database_provider: "PostgresDatabaseProvider",
+        email_provider: EmailProvider,
+    ):
+        if not isinstance(config, AuthConfig):
+            raise ValueError(
+                "AuthProvider must be initialized with an AuthConfig"
+            )
+        self.config = config
+        self.admin_email = config.default_admin_email
+        self.admin_password = config.default_admin_password
+        self.crypto_provider = crypto_provider
+        self.database_provider = database_provider
+        self.email_provider = email_provider
+        super().__init__(config)
+        self.config: AuthConfig = config  # for type hinting
+        self.database_provider: "PostgresDatabaseProvider" = (
+            database_provider  # for type hinting
+        )
+
+    async def _get_default_admin_user(self) -> User:
+        return await self.database_provider.users_handler.get_user_by_email(
+            self.admin_email
+        )
+
+    @abstractmethod
+    def create_access_token(self, data: dict) -> str:
+        pass
+
+    @abstractmethod
+    def create_refresh_token(self, data: dict) -> str:
+        pass
+
+    @abstractmethod
+    async def decode_token(self, token: str) -> TokenData:
+        pass
+
+    @abstractmethod
+    async def user(self, token: str) -> User:
+        pass
+
+    @abstractmethod
+    def get_current_active_user(self, current_user: User) -> User:
+        pass
+
+    @abstractmethod
+    async def register(self, email: str, password: str) -> User:
+        pass
+
+    @abstractmethod
+    async def verify_email(
+        self, email: str, verification_code: str
+    ) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def login(self, email: str, password: str) -> dict[str, Token]:
+        pass
+
+    @abstractmethod
+    async def refresh_access_token(
+        self, refresh_token: str
+    ) -> dict[str, Token]:
+        pass
+
+    async def auth_wrapper(
+        self, auth: Optional[HTTPAuthorizationCredentials] = Security(security)
+    ) -> User:
+        if not self.config.require_authentication and auth is None:
+            return await self._get_default_admin_user()
+
+        if auth is None:
+            raise R2RException(
+                message="Authentication required.",
+                status_code=401,
+            )
+
+        try:
+            return await self.user(auth.credentials)
+        except Exception as e:
+            raise R2RException(
+                message=f"Error '{e}' occurred during authentication.",
+                status_code=404,
+            )
+
+    @abstractmethod
+    async def change_password(
+        self, user: User, current_password: str, new_password: str
+    ) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def request_password_reset(self, email: str) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def confirm_password_reset(
+        self, reset_token: str, new_password: str
+    ) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def logout(self, token: str) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def send_reset_email(self, email: str) -> dict[str, str]:
+        pass
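
A sketch of consuming auth_wrapper as a FastAPI dependency; the concrete AuthProvider instance is assumed to be built elsewhere and injected into the factory.

from fastapi import Depends, FastAPI

from core.base.api.models import User
from core.base.providers import AuthProvider


def create_app(auth_provider: AuthProvider) -> FastAPI:
    app = FastAPI()

    @app.get("/users/me")
    async def whoami(user: User = Depends(auth_provider.auth_wrapper)):
        # With require_authentication disabled and no credentials supplied,
        # auth_wrapper falls back to the default admin user.
        return user

    return app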

+ 68 - 0
core/base/providers/base.py

@@ -0,0 +1,68 @@
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Sequence, Type
+
+from pydantic import BaseModel
+
+from ..abstractions import R2RSerializable
+
+
+class AppConfig(R2RSerializable):
+    project_name: Optional[str] = None
+
+    @classmethod
+    def create(cls, *args, **kwargs):
+        project_name = kwargs.get("project_name")
+        return AppConfig(project_name=project_name)
+
+
+class ProviderConfig(BaseModel, ABC):
+    """A base provider configuration class"""
+
+    app: AppConfig  # Add an app_config field
+    extra_fields: dict[str, Any] = {}
+    provider: Optional[str] = None
+
+    class Config:
+        populate_by_name = True
+        arbitrary_types_allowed = True
+        ignore_extra = True
+
+    @abstractmethod
+    def validate_config(self) -> None:
+        pass
+
+    @classmethod
+    def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
+        base_args = cls.model_fields.keys()
+        filtered_kwargs = {
+            k: v if v != "None" else None
+            for k, v in kwargs.items()
+            if k in base_args
+        }
+        instance = cls(**filtered_kwargs)  # type: ignore
+        for k, v in kwargs.items():
+            if k not in base_args:
+                instance.extra_fields[k] = v
+        return instance
+
+    @property
+    @abstractmethod
+    def supported_providers(self) -> list[str]:
+        """Define a list of supported providers."""
+        pass
+
+    @classmethod
+    def from_dict(
+        cls: Type["ProviderConfig"], data: dict[str, Any]
+    ) -> "ProviderConfig":
+        """Create a new instance of the config from a dictionary."""
+        return cls.create(**data)
+
+
+class Provider(ABC):
+    """A base provider class to provide a common interface for all providers."""
+
+    def __init__(self, config: ProviderConfig, *args, **kwargs):
+        if config:
+            config.validate_config()
+        self.config = config
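
A sketch of a custom ProviderConfig subclass (CacheConfig is hypothetical) showing how create() routes unknown keyword arguments into extra_fields:

from core.base.providers import AppConfig, ProviderConfig


class CacheConfig(ProviderConfig):
    ttl_seconds: int = 60

    @property
    def supported_providers(self) -> list[str]:
        return ["memory", "redis"]

    def validate_config(self) -> None:
        if self.provider not in self.supported_providers:
            raise ValueError(f"Unsupported cache provider: {self.provider}")


config = CacheConfig.create(
    app=AppConfig(project_name="demo"),
    provider="memory",
    ttl_seconds=30,
    eviction_policy="lru",  # not a declared field, lands in extra_fields
)
print(config.ttl_seconds, config.extra_fields)  # 30 {'eviction_policy': 'lru'}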

+ 39 - 0
core/base/providers/crypto.py

@@ -0,0 +1,39 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from .base import Provider, ProviderConfig
+
+
+class CryptoConfig(ProviderConfig):
+    provider: Optional[str] = None
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["bcrypt"]  # Add other crypto providers as needed
+
+    def validate_config(self) -> None:
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Unsupported crypto provider: {self.provider}")
+
+
+class CryptoProvider(Provider, ABC):
+    def __init__(self, config: CryptoConfig):
+        if not isinstance(config, CryptoConfig):
+            raise ValueError(
+                "CryptoProvider must be initialized with a CryptoConfig"
+            )
+        super().__init__(config)
+
+    @abstractmethod
+    def get_password_hash(self, password: str) -> str:
+        pass
+
+    @abstractmethod
+    def verify_password(
+        self, plain_password: str, hashed_password: str
+    ) -> bool:
+        pass
+
+    @abstractmethod
+    def generate_verification_code(self, length: int = 32) -> str:
+        pass

+ 206 - 0
core/base/providers/database.py

@@ -0,0 +1,206 @@
+import logging
+from abc import abstractmethod
+from datetime import datetime
+from io import BytesIO
+from typing import BinaryIO, Optional, Tuple
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from core.base.abstractions import (
+    ChunkSearchResult,
+    Community,
+    DocumentResponse,
+    Entity,
+    IndexArgsHNSW,
+    IndexArgsIVFFlat,
+    IndexMeasure,
+    IndexMethod,
+    KGCreationSettings,
+    KGEnrichmentSettings,
+    KGEntityDeduplicationSettings,
+    Message,
+    Relationship,
+    SearchSettings,
+    User,
+    VectorEntry,
+    VectorTableName,
+)
+from core.base.api.models import CollectionResponse, GraphResponse
+
+from .base import Provider, ProviderConfig
+
+"""Base classes for knowledge graph providers."""
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Sequence, Tuple, Type
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from ..abstractions import (
+    Community,
+    Entity,
+    GraphSearchSettings,
+    KGCreationSettings,
+    KGEnrichmentSettings,
+    KGEntityDeduplicationSettings,
+    KGExtraction,
+    R2RSerializable,
+    Relationship,
+)
+
+logger = logging.getLogger()
+
+
+class DatabaseConnectionManager(ABC):
+    @abstractmethod
+    def execute_query(
+        self,
+        query: str,
+        params: Optional[dict[str, Any] | Sequence[Any]] = None,
+        isolation_level: Optional[str] = None,
+    ):
+        pass
+
+    @abstractmethod
+    async def execute_many(self, query, params=None, batch_size=1000):
+        pass
+
+    @abstractmethod
+    def fetch_query(
+        self,
+        query: str,
+        params: Optional[dict[str, Any] | Sequence[Any]] = None,
+    ):
+        pass
+
+    @abstractmethod
+    def fetchrow_query(
+        self,
+        query: str,
+        params: Optional[dict[str, Any] | Sequence[Any]] = None,
+    ):
+        pass
+
+    @abstractmethod
+    async def initialize(self, pool: Any):
+        pass
+
+
+class Handler(ABC):
+    def __init__(
+        self,
+        project_name: str,
+        connection_manager: DatabaseConnectionManager,
+    ):
+        self.project_name = project_name
+        self.connection_manager = connection_manager
+
+    def _get_table_name(self, base_name: str) -> str:
+        return f"{self.project_name}.{base_name}"
+
+    @abstractmethod
+    def create_tables(self):
+        pass
+
+
+class PostgresConfigurationSettings(BaseModel):
+    """
+    Configuration settings with defaults defined by the PGVector docker image.
+
+    These settings are helpful in managing the connections to the database.
+    To tune these settings for a specific deployment, see https://pgtune.leopard.in.ua/
+    """
+
+    checkpoint_completion_target: Optional[float] = 0.9
+    default_statistics_target: Optional[int] = 100
+    effective_io_concurrency: Optional[int] = 1
+    effective_cache_size: Optional[int] = 524288
+    huge_pages: Optional[str] = "try"
+    maintenance_work_mem: Optional[int] = 65536
+    max_connections: Optional[int] = 256
+    max_parallel_workers_per_gather: Optional[int] = 2
+    max_parallel_workers: Optional[int] = 8
+    max_parallel_maintenance_workers: Optional[int] = 2
+    max_wal_size: Optional[int] = 1024
+    max_worker_processes: Optional[int] = 8
+    min_wal_size: Optional[int] = 80
+    shared_buffers: Optional[int] = 16384
+    statement_cache_size: Optional[int] = 100
+    random_page_cost: Optional[float] = 4
+    wal_buffers: Optional[int] = 512
+    work_mem: Optional[int] = 4096
+
+
+class DatabaseConfig(ProviderConfig):
+    """A base database configuration class"""
+
+    provider: str = "postgres"
+    user: Optional[str] = None
+    password: Optional[str] = None
+    host: Optional[str] = None
+    port: Optional[int] = None
+    db_name: Optional[str] = None
+    project_name: Optional[str] = None
+    postgres_configuration_settings: Optional[
+        PostgresConfigurationSettings
+    ] = None
+    default_collection_name: str = "Default"
+    default_collection_description: str = "Your default collection."
+    collection_summary_system_prompt: str = "default_system"
+    collection_summary_task_prompt: str = "default_collection_summary"
+    enable_fts: bool = False
+
+    # KG settings
+    batch_size: Optional[int] = 1
+    kg_store_path: Optional[str] = None
+    graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings()
+    graph_creation_settings: KGCreationSettings = KGCreationSettings()
+    graph_entity_deduplication_settings: KGEntityDeduplicationSettings = (
+        KGEntityDeduplicationSettings()
+    )
+    graph_search_settings: GraphSearchSettings = GraphSearchSettings()
+
+    def __post_init__(self):
+        self.validate_config()
+        # Capture additional fields
+        for key, value in self.extra_fields.items():
+            setattr(self, key, value)
+
+    def validate_config(self) -> None:
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["postgres"]
+
+
+class DatabaseProvider(Provider):
+    connection_manager: DatabaseConnectionManager
+    # documents_handler: DocumentHandler
+    # collections_handler: CollectionsHandler
+    # token_handler: TokenHandler
+    # users_handler: UserHandler
+    # chunks_handler: ChunkHandler
+    # entity_handler: EntityHandler
+    # relationship_handler: RelationshipHandler
+    # graphs_handler: GraphHandler
+    # prompts_handler: PromptHandler
+    # files_handler: FileHandler
+    config: DatabaseConfig
+    project_name: str
+
+    def __init__(self, config: DatabaseConfig):
+        logger.info(f"Initializing DatabaseProvider with config {config}.")
+        super().__init__(config)
+
+    @abstractmethod
+    async def __aenter__(self):
+        pass
+
+    @abstractmethod
+    async def __aexit__(self, exc_type, exc, tb):
+        pass
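
A sketch of a Handler subclass (NotesHandler is hypothetical); it assumes the async Postgres connection manager, so execute_query is awaited even though the ABC declares it without async def.

from core.base.providers import Handler


class NotesHandler(Handler):
    async def create_tables(self) -> None:
        query = f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("notes")} (
                id UUID PRIMARY KEY,
                body TEXT NOT NULL
            );
        """
        await self.connection_manager.execute_query(query)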

+ 73 - 0
core/base/providers/email.py

@@ -0,0 +1,73 @@
+# email_provider.py
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from .base import Provider, ProviderConfig
+
+
+class EmailConfig(ProviderConfig):
+    smtp_server: Optional[str] = None
+    smtp_port: Optional[int] = None
+    smtp_username: Optional[str] = None
+    smtp_password: Optional[str] = None
+    from_email: Optional[str] = None
+    use_tls: Optional[bool] = True
+    sendgrid_api_key: Optional[str] = None
+    verify_email_template_id: Optional[str] = None
+    reset_password_template_id: Optional[str] = None
+    frontend_url: Optional[str] = None
+    sender_name: Optional[str] = None
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return [
+            "smtp",
+            "console",
+            "sendgrid",
+        ]  # Could add more providers, e.g. AWS SES.
+
+    def validate_config(self) -> None:
+        if self.provider == "sendgrid":
+            if not (self.sendgrid_api_key or os.getenv("SENDGRID_API_KEY")):
+                raise ValueError(
+                    "SendGrid API key is required when using SendGrid provider"
+                )
+
+
+logger = logging.getLogger(__name__)
+
+
+class EmailProvider(Provider, ABC):
+    def __init__(self, config: EmailConfig):
+        if not isinstance(config, EmailConfig):
+            raise ValueError(
+                "EmailProvider must be initialized with an EmailConfig"
+            )
+        super().__init__(config)
+        self.config: EmailConfig = config  # for type hinting
+
+    @abstractmethod
+    async def send_email(
+        self,
+        to_email: str,
+        subject: str,
+        body: str,
+        html_body: Optional[str] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        pass
+
+    @abstractmethod
+    async def send_verification_email(
+        self, to_email: str, verification_code: str, *args, **kwargs
+    ) -> None:
+        pass
+
+    @abstractmethod
+    async def send_password_reset_email(
+        self, to_email: str, reset_token: str, *args, **kwargs
+    ) -> None:
+        pass
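
A minimal sketch of a console-backed EmailProvider (ConsoleEmailProvider is hypothetical) that prints outgoing mail instead of sending it:

from typing import Optional

from core.base.providers import AppConfig, EmailConfig, EmailProvider


class ConsoleEmailProvider(EmailProvider):
    async def send_email(
        self,
        to_email: str,
        subject: str,
        body: str,
        html_body: Optional[str] = None,
        *args,
        **kwargs,
    ) -> None:
        print(f"[email to={to_email}] {subject}\n{body}")

    async def send_verification_email(
        self, to_email: str, verification_code: str, *args, **kwargs
    ) -> None:
        await self.send_email(to_email, "Verify your account", verification_code)

    async def send_password_reset_email(
        self, to_email: str, reset_token: str, *args, **kwargs
    ) -> None:
        await self.send_email(to_email, "Reset your password", reset_token)


provider = ConsoleEmailProvider(EmailConfig(provider="console", app=AppConfig()))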

+ 197 - 0
core/base/providers/embedding.py

@@ -0,0 +1,197 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import VectorQuantizationSettings
+
+from ..abstractions import (
+    ChunkSearchResult,
+    EmbeddingPurpose,
+    default_embedding_prefixes,
+)
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class EmbeddingConfig(ProviderConfig):
+    provider: str
+    base_model: str
+    base_dimension: int
+    rerank_model: Optional[str] = None
+    rerank_url: Optional[str] = None
+    batch_size: int = 1
+    prefixes: Optional[dict[str, str]] = None
+    add_title_as_prefix: bool = True
+    concurrent_request_limit: int = 256
+    max_retries: int = 8
+    initial_backoff: float = 1
+    max_backoff: float = 64.0
+    quantization_settings: VectorQuantizationSettings = (
+        VectorQuantizationSettings()
+    )
+
+    ## deprecated
+    rerank_dimension: Optional[int] = None
+    rerank_transformer_type: Optional[str] = None
+
+    def validate_config(self) -> None:
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["litellm", "openai", "ollama"]
+
+
+class EmbeddingProvider(Provider):
+    class PipeStage(Enum):
+        BASE = 1
+        RERANK = 2
+
+    def __init__(self, config: EmbeddingConfig):
+        if not isinstance(config, EmbeddingConfig):
+            raise ValueError(
+                "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
+            )
+        logger.info(f"Initializing EmbeddingProvider with config {config}.")
+
+        super().__init__(config)
+        self.config: EmbeddingConfig = config
+        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
+        self.current_requests = 0
+
+    async def _execute_with_backoff_async(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                async with self.semaphore:
+                    return await self._execute_task(task)
+            except AuthenticationError as e:
+                raise
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                await asyncio.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
+    def _execute_with_backoff_sync(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                return self._execute_task_sync(task)
+            except AuthenticationError as e:
+                raise
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                time.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
+    @abstractmethod
+    async def _execute_task(self, task: dict[str, Any]):
+        pass
+
+    @abstractmethod
+    def _execute_task_sync(self, task: dict[str, Any]):
+        pass
+
+    async def async_get_embedding(
+        self,
+        text: str,
+        stage: PipeStage = PipeStage.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+    ):
+        task = {
+            "text": text,
+            "stage": stage,
+            "purpose": purpose,
+        }
+        return await self._execute_with_backoff_async(task)
+
+    def get_embedding(
+        self,
+        text: str,
+        stage: PipeStage = PipeStage.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+    ):
+        task = {
+            "text": text,
+            "stage": stage,
+            "purpose": purpose,
+        }
+        return self._execute_with_backoff_sync(task)
+
+    async def async_get_embeddings(
+        self,
+        texts: list[str],
+        stage: PipeStage = PipeStage.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+    ):
+        task = {
+            "texts": texts,
+            "stage": stage,
+            "purpose": purpose,
+        }
+        return await self._execute_with_backoff_async(task)
+
+    def get_embeddings(
+        self,
+        texts: list[str],
+        stage: PipeStage = PipeStage.BASE,
+        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
+    ) -> list[list[float]]:
+        task = {
+            "texts": texts,
+            "stage": stage,
+            "purpose": purpose,
+        }
+        return self._execute_with_backoff_sync(task)
+
+    @abstractmethod
+    def rerank(
+        self,
+        query: str,
+        results: list[ChunkSearchResult],
+        stage: PipeStage = PipeStage.RERANK,
+        limit: int = 10,
+    ):
+        pass
+
+    @abstractmethod
+    async def arerank(
+        self,
+        query: str,
+        results: list[ChunkSearchResult],
+        stage: PipeStage = PipeStage.RERANK,
+        limit: int = 10,
+    ):
+        pass
+
+    def set_prefixes(self, config_prefixes: dict[str, str], base_model: str):
+        self.prefixes = {}
+
+        for t, p in config_prefixes.items():
+            purpose = EmbeddingPurpose(t.lower())
+            self.prefixes[purpose] = p
+
+        if base_model in default_embedding_prefixes:
+            for t, p in default_embedding_prefixes[base_model].items():
+                if t not in self.prefixes:
+                    self.prefixes[t] = p
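+
+
+# Illustrative usage (a sketch, not executed here): `OpenAIEmbeddingProvider`
+# stands in for any concrete subclass that implements `_execute_task`,
+# `_execute_task_sync`, `rerank`, and `arerank`; prefix keys passed to
+# `set_prefixes` must correspond to `EmbeddingPurpose` values.
+#
+#     provider = OpenAIEmbeddingProvider(EmbeddingConfig(provider="openai", ...))
+#     provider.set_prefixes({"index": "passage: "}, base_model="...")
+#     vector = await provider.async_get_embedding("hello world")
+#
+# Every call is routed through `_execute_with_backoff_*`, so transient
+# failures are retried up to `max_retries` times with exponential backoff
+# and jitter, while `AuthenticationError` is re-raised immediately.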

+ 204 - 0
core/base/providers/ingestion.py

@@ -0,0 +1,204 @@
+import logging
+from abc import ABC
+from enum import Enum
+from typing import TYPE_CHECKING, ClassVar
+
+from pydantic import BaseModel, Field
+
+from core.base.abstractions import ChunkEnrichmentSettings
+
+from .base import AppConfig, Provider, ProviderConfig
+from .llm import CompletionProvider
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+    from core.database import PostgresDatabaseProvider
+
+
+class IngestionConfig(ProviderConfig):
+    _defaults: ClassVar[dict] = {
+        "app": AppConfig(),
+        "provider": "r2r",
+        "excluded_parsers": ["mp4"],
+        "chunk_enrichment_settings": ChunkEnrichmentSettings(),
+        "extra_parsers": {},
+        "audio_transcription_model": "openai/whisper-1",
+        "vision_img_prompt_name": "vision_img",
+        "vision_img_model": "openai/gpt-4o",
+        "vision_pdf_prompt_name": "vision_pdf",
+        "vision_pdf_model": "openai/gpt-4o",
+        "skip_document_summary": False,
+        "document_summary_system_prompt": "default_system",
+        "document_summary_task_prompt": "default_summary",
+        "chunks_for_document_summary": 128,
+        "document_summary_model": "openai/gpt-4o-mini",
+        "parser_overrides": {},
+        "extra_fields": {},
+    }
+
+    provider: str = Field(
+        default_factory=lambda: IngestionConfig._defaults["provider"]
+    )
+    excluded_parsers: list[str] = Field(
+        default_factory=lambda: IngestionConfig._defaults["excluded_parsers"]
+    )
+    chunk_enrichment_settings: ChunkEnrichmentSettings = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "chunk_enrichment_settings"
+        ]
+    )
+    extra_parsers: dict[str, str] = Field(
+        default_factory=lambda: IngestionConfig._defaults["extra_parsers"]
+    )
+    audio_transcription_model: str = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "audio_transcription_model"
+        ]
+    )
+    vision_img_prompt_name: str = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "vision_img_prompt_name"
+        ]
+    )
+    vision_img_model: str = Field(
+        default_factory=lambda: IngestionConfig._defaults["vision_img_model"]
+    )
+    vision_pdf_prompt_name: str = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "vision_pdf_prompt_name"
+        ]
+    )
+    vision_pdf_model: str = Field(
+        default_factory=lambda: IngestionConfig._defaults["vision_pdf_model"]
+    )
+    skip_document_summary: bool = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "skip_document_summary"
+        ]
+    )
+    document_summary_system_prompt: str = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "document_summary_system_prompt"
+        ]
+    )
+    document_summary_task_prompt: str = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "document_summary_task_prompt"
+        ]
+    )
+    chunks_for_document_summary: int = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "chunks_for_document_summary"
+        ]
+    )
+    document_summary_model: str = Field(
+        default_factory=lambda: IngestionConfig._defaults[
+            "document_summary_model"
+        ]
+    )
+    parser_overrides: dict[str, str] = Field(
+        default_factory=lambda: IngestionConfig._defaults["parser_overrides"]
+    )
+
+    @classmethod
+    def set_default(cls, **kwargs):
+        for key, value in kwargs.items():
+            if key in cls._defaults:
+                cls._defaults[key] = value
+            else:
+                raise AttributeError(
+                    f"No default attribute '{key}' in IngestionConfig"
+                )
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["r2r", "unstructured_local", "unstructured_api"]
+
+    def validate_config(self) -> None:
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider {self.provider} is not supported.")
+
+    @classmethod
+    def get_default(cls, mode: str, app) -> "IngestionConfig":
+        """Return default ingestion configuration for a given mode."""
+        if mode == "hi-res":
+            # More thorough parsing, no skipping summaries, possibly larger `chunks_for_document_summary`.
+            return cls(app=app, parser_overrides={"pdf": "zerox"})
+        # elif mode == "fast":
+        #     # Skip summaries and other enrichment steps for speed.
+        #     return cls(
+        #         app=app,
+        #     )
+        else:
+            # For `custom` or any unrecognized mode, return a base config
+            return cls(app=app)
+
+    class Config:
+        populate_by_name = True
+        json_schema_extra = {
+            "provider": "r2r",
+            "excluded_parsers": ["mp4"],
+            "chunk_enrichment_settings": ChunkEnrichmentSettings().dict(),
+            "extra_parsers": {},
+            "audio_transcription_model": "openai/whisper-1",
+            "vision_img_prompt_name": "vision_img",
+            "vision_img_model": "openai/gpt-4o",
+            "vision_pdf_prompt_name": "vision_pdf",
+            "vision_pdf_model": "openai/gpt-4o",
+            "skip_document_summary": False,
+            "document_summary_system_prompt": "default_system",
+            "document_summary_task_prompt": "default_summary",
+            "chunks_for_document_summary": 128,
+            "document_summary_model": "openai/gpt-4o-mini",
+            "parser_overrides": {},
+        }
+
+
+class IngestionProvider(Provider, ABC):
+
+    config: IngestionConfig
+    database_provider: "PostgresDatabaseProvider"
+    llm_provider: CompletionProvider
+
+    def __init__(
+        self,
+        config: IngestionConfig,
+        database_provider: "PostgresDatabaseProvider",
+        llm_provider: CompletionProvider,
+    ):
+        super().__init__(config)
+        self.config: IngestionConfig = config
+        self.llm_provider = llm_provider
+        self.database_provider: "PostgresDatabaseProvider" = database_provider
+
+
+class ChunkingStrategy(str, Enum):
+    RECURSIVE = "recursive"
+    CHARACTER = "character"
+    BASIC = "basic"
+    BY_TITLE = "by_title"
+
+
+class IngestionMode(str, Enum):
+    hi_res = "hi-res"
+    fast = "fast"
+    custom = "custom"
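+
+
+# Illustrative usage of the class-level defaults (a sketch):
+#
+#     IngestionConfig.set_default(document_summary_model="openai/gpt-4o")
+#     config = IngestionConfig.get_default(IngestionMode.hi_res.value, app=AppConfig())
+#     assert config.parser_overrides == {"pdf": "zerox"}
+#
+# `set_default` mutates the shared `_defaults` mapping and therefore only
+# affects instances created afterwards; the "hi-res" mode additionally routes
+# PDFs through the "zerox" parser.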

+ 184 - 0
core/base/providers/llm.py

@@ -0,0 +1,184 @@
+import asyncio
+import logging
+import random
+import time
+from abc import abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from litellm import AuthenticationError
+
+from core.base.abstractions import (
+    GenerationConfig,
+    LLMChatCompletion,
+    LLMChatCompletionChunk,
+)
+
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class CompletionConfig(ProviderConfig):
+    provider: Optional[str] = None
+    generation_config: GenerationConfig = GenerationConfig()
+    concurrent_request_limit: int = 256
+    max_retries: int = 8
+    initial_backoff: float = 1.0
+    max_backoff: float = 64.0
+
+    def validate_config(self) -> None:
+        if not self.provider:
+            raise ValueError("Provider must be set.")
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["litellm", "openai"]
+
+
+class CompletionProvider(Provider):
+    def __init__(self, config: CompletionConfig) -> None:
+        if not isinstance(config, CompletionConfig):
+            raise ValueError(
+                "CompletionProvider must be initialized with a `CompletionConfig`."
+            )
+        logger.info(f"Initializing CompletionProvider with config: {config}")
+        super().__init__(config)
+        self.config: CompletionConfig = config
+        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
+        self.thread_pool = ThreadPoolExecutor(
+            max_workers=config.concurrent_request_limit
+        )
+
+    async def _execute_with_backoff_async(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                async with self.semaphore:
+                    return await self._execute_task(task)
+            except AuthenticationError:
+                raise
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                await asyncio.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
+    async def _execute_with_backoff_async_stream(
+        self, task: dict[str, Any]
+    ) -> AsyncGenerator[Any, None]:
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                async with self.semaphore:
+                    async for chunk in await self._execute_task(task):
+                        yield chunk
+                return  # Successful completion of the stream
+            except AuthenticationError:
+                raise
+            except Exception as e:
+                logger.warning(
+                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                await asyncio.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
+    def _execute_with_backoff_sync(self, task: dict[str, Any]):
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                return self._execute_task_sync(task)
+            except Exception as e:
+                logger.warning(
+                    f"Request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                time.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
+    def _execute_with_backoff_sync_stream(
+        self, task: dict[str, Any]
+    ) -> Generator[Any, None, None]:
+        retries = 0
+        backoff = self.config.initial_backoff
+        while retries < self.config.max_retries:
+            try:
+                yield from self._execute_task_sync(task)
+                return  # Successful completion of the stream
+            except Exception as e:
+                logger.warning(
+                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
+                )
+                retries += 1
+                if retries == self.config.max_retries:
+                    raise
+                time.sleep(random.uniform(0, backoff))
+                backoff = min(backoff * 2, self.config.max_backoff)
+
+    @abstractmethod
+    async def _execute_task(self, task: dict[str, Any]):
+        pass
+
+    @abstractmethod
+    def _execute_task_sync(self, task: dict[str, Any]):
+        pass
+
+    async def aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        task = {
+            "messages": messages,
+            "generation_config": generation_config,
+            "kwargs": kwargs,
+        }
+        if modalities := kwargs.get("modalities"):
+            task["modalities"] = modalities
+        response = await self._execute_with_backoff_async(task)
+        return LLMChatCompletion(**response.dict())
+
+    async def aget_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> AsyncGenerator[LLMChatCompletionChunk, None]:
+        generation_config.stream = True
+        task = {
+            "messages": messages,
+            "generation_config": generation_config,
+            "kwargs": kwargs,
+        }
+        async for chunk in self._execute_with_backoff_async_stream(task):
+            yield LLMChatCompletionChunk(**chunk.dict())
+
+    def get_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Generator[LLMChatCompletionChunk, None, None]:
+        generation_config.stream = True
+        task = {
+            "messages": messages,
+            "generation_config": generation_config,
+            "kwargs": kwargs,
+        }
+        for chunk in self._execute_with_backoff_sync_stream(task):
+            yield LLMChatCompletionChunk(**chunk.dict())
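+
+
+# Illustrative usage (a sketch): `LiteLLMCompletionProvider` stands in for a
+# concrete subclass that implements `_execute_task` / `_execute_task_sync`.
+#
+#     provider = LiteLLMCompletionProvider(CompletionConfig(provider="litellm"))
+#     completion = await provider.aget_completion(
+#         messages=[{"role": "user", "content": "Hello"}],
+#         generation_config=GenerationConfig(model="openai/gpt-4o-mini"),
+#     )
+#     async for chunk in provider.aget_completion_stream(
+#         messages=[{"role": "user", "content": "Hi"}],
+#         generation_config=GenerationConfig(model="openai/gpt-4o-mini"),
+#     ):
+#         ...
+#
+# Note that a retried stream restarts from its first chunk, so consumers may
+# see duplicate chunks if a transient error occurs mid-stream.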

+ 70 - 0
core/base/providers/orchestration.py

@@ -0,0 +1,70 @@
+from abc import abstractmethod
+from enum import Enum
+from typing import Any
+
+from .base import Provider, ProviderConfig
+
+
+class Workflow(Enum):
+    INGESTION = "ingestion"
+    KG = "kg"
+
+
+class OrchestrationConfig(ProviderConfig):
+    provider: str
+    max_runs: int = 2_048
+    kg_creation_concurrency_limit: int = 32
+    ingestion_concurrency_limit: int = 16
+    kg_concurrency_limit: int = 4
+
+    def validate_config(self) -> None:
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider {self.provider} is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["hatchet", "simple"]
+
+
+class OrchestrationProvider(Provider):
+    def __init__(self, config: OrchestrationConfig):
+        super().__init__(config)
+        self.config = config
+        self.worker = None
+
+    @abstractmethod
+    async def start_worker(self):
+        pass
+
+    @abstractmethod
+    def get_worker(self, name: str, max_runs: int) -> Any:
+        pass
+
+    @abstractmethod
+    def step(self, *args, **kwargs) -> Any:
+        pass
+
+    @abstractmethod
+    def workflow(self, *args, **kwargs) -> Any:
+        pass
+
+    @abstractmethod
+    def failure(self, *args, **kwargs) -> Any:
+        pass
+
+    @abstractmethod
+    def register_workflows(
+        self, workflow: Workflow, service: Any, messages: dict
+    ) -> None:
+        pass
+
+    @abstractmethod
+    async def run_workflow(
+        self,
+        workflow_name: str,
+        parameters: dict,
+        options: dict,
+        *args,
+        **kwargs,
+    ) -> dict[str, str]:
+        pass
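+
+
+# Illustrative wiring of a concrete provider (a sketch; the provider class,
+# workflow name, and `service` object are placeholders):
+#
+#     provider = SomeOrchestrationProvider(OrchestrationConfig(provider="simple"))
+#     provider.register_workflows(Workflow.INGESTION, service, messages={})
+#     await provider.start_worker()
+#     result = await provider.run_workflow("my-workflow", parameters={}, options={})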

+ 41 - 0
core/base/utils/__init__.py

@@ -0,0 +1,41 @@
+from shared.utils import (
+    RecursiveCharacterTextSplitter,
+    TextSplitter,
+    _decorate_vector_type,
+    _get_str_estimation_output,
+    decrement_version,
+    format_search_results_for_llm,
+    format_search_results_for_stream,
+    generate_default_prompt_id,
+    generate_default_user_collection_id,
+    generate_document_id,
+    generate_extraction_id,
+    generate_id,
+    generate_user_id,
+    increment_version,
+    llm_cost_per_million_tokens,
+    run_pipeline,
+    to_async_generator,
+    validate_uuid,
+)
+
+__all__ = [
+    "format_search_results_for_stream",
+    "format_search_results_for_llm",
+    "generate_id",
+    "generate_default_user_collection_id",
+    "increment_version",
+    "decrement_version",
+    "run_pipeline",
+    "to_async_generator",
+    "generate_document_id",
+    "generate_extraction_id",
+    "generate_user_id",
+    "generate_default_prompt_id",
+    "RecursiveCharacterTextSplitter",
+    "TextSplitter",
+    "llm_cost_per_million_tokens",
+    "validate_uuid",
+    "_decorate_vector_type",
+    "_get_str_estimation_output",
+]

+ 25 - 0
core/configs/full.toml

@@ -0,0 +1,25 @@
+[completion]
+provider = "litellm"
+concurrent_request_limit = 128
+
+[database]
+  [database.graph_creation_settings]
+    clustering_mode = "remote"
+
+[ingestion]
+provider = "unstructured_local"
+strategy = "auto"
+chunking_strategy = "by_title"
+new_after_n_chars = 2_048
+max_characters = 4_096
+combine_under_n_chars = 1_024
+overlap = 1_024
+
+    [ingestion.extra_parsers]
+    pdf = "zerox"
+
+[orchestration]
+provider = "hatchet"
+kg_creation_concurrency_limit = 32
+ingestion_concurrency_limit = 16
+kg_concurrency_limit = 8

+ 57 - 0
core/configs/full_azure.toml

@@ -0,0 +1,57 @@
+# A config which overrides all instances of `openai` with `azure` in the `r2r.toml` config
+[completion]
+provider = "litellm"
+concurrent_request_limit = 128
+
+  [completion.generation_config]
+  model = "azure/gpt-4o"
+
+[agent]
+  [agent.generation_config]
+  model = "azure/gpt-4o"
+
+[database]
+  [database.graph_creation_settings]
+    clustering_mode = "remote"
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+  [database.graph_entity_deduplication_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+  [database.graph_enrichment_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+  [database.graph_search_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+[embedding]
+provider = "litellm"
+base_model = "azure/text-embedding-3-small"
+base_dimension = 512
+
+[file]
+provider = "postgres"
+
+[ingestion]
+provider = "unstructured_local"
+strategy = "auto"
+chunking_strategy = "by_title"
+new_after_n_chars = 2_048
+max_characters = 4_096
+combine_under_n_chars = 1_024
+overlap = 1_024
+document_summary_model = "azure/gpt-4o-mini"
+vision_img_model = "azure/gpt-4o"
+vision_pdf_model = "azure/gpt-4o"
+
+  [ingestion.extra_parsers]
+    pdf = "zerox"
+
+  [ingestion.chunk_enrichment_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+[orchestration]
+provider = "hatchet"
+kg_creation_concurrency_limit = 32
+ingestion_concurrency_limit = 4
+kg_concurrency_limit = 8

+ 70 - 0
core/configs/full_local_llm.toml

@@ -0,0 +1,70 @@
+[agent]
+system_instruction_name = "rag_agent"
+tool_names = ["local_search"]
+
+  [agent.generation_config]
+  model = "ollama/llama3.1"
+
+[completion]
+provider = "litellm"
+concurrent_request_limit = 1
+
+  [completion.generation_config]
+  model = "ollama/llama3.1"
+  temperature = 0.1
+  top_p = 1
+  max_tokens_to_sample = 1_024
+  stream = false
+  add_generation_kwargs = { }
+
+
+[database]
+provider = "postgres"
+
+  [database.graph_creation_settings]
+    clustering_mode = "remote"
+    graph_entity_description_prompt = "graphrag_entity_description"
+    entity_types = [] # if empty, all entities are extracted
+    relation_types = [] # if empty, all relations are extracted
+    fragment_merge_count = 4 # number of fragments to merge into a single extraction
+    max_knowledge_relationships = 100
+    max_description_input_length = 65536
+    generation_config = { model = "ollama/llama3.1" } # and other params; model used for relationship extraction
+
+  [database.graph_entity_deduplication_settings]
+    graph_entity_deduplication_type = "by_name"
+    graph_entity_deduplication_prompt = "graphrag_entity_deduplication"
+    max_description_input_length = 65536
+    generation_config = { model = "ollama/llama3.1" } # and other params, model used for deduplication
+
+  [database.graph_enrichment_settings]
+    community_reports_prompt = "graphrag_community_reports"
+    max_summary_input_length = 65536
+    generation_config = { model = "ollama/llama3.1" } # and other params, model used for node description and graph clustering
+    leiden_params = {}
+
+  [database.graph_search_settings]
+    generation_config = { model = "ollama/llama3.1" }
+
+
+[embedding]
+provider = "ollama"
+base_model = "mxbai-embed-large"
+base_dimension = 1_024
+batch_size = 128
+add_title_as_prefix = true
+concurrent_request_limit = 2
+
+[ingestion]
+provider = "unstructured_local"
+strategy = "auto"
+chunking_strategy = "by_title"
+new_after_n_chars = 512
+max_characters = 1_024
+combine_under_n_chars = 128
+overlap = 20
+chunks_for_document_summary = 16
+document_summary_model = "ollama/llama3.1"
+
+[orchestration]
+provider = "hatchet"

+ 68 - 0
core/configs/local_llm.toml

@@ -0,0 +1,68 @@
+[agent]
+system_instruction_name = "rag_agent"
+tool_names = ["local_search"]
+
+  [agent.generation_config]
+  model = "ollama/llama3.1"
+
+[completion]
+provider = "litellm"
+concurrent_request_limit = 1
+
+  [completion.generation_config]
+  model = "ollama/llama3.1"
+  temperature = 0.1
+  top_p = 1
+  max_tokens_to_sample = 1_024
+  stream = false
+  add_generation_kwargs = { }
+
+[embedding]
+provider = "ollama"
+base_model = "mxbai-embed-large"
+base_dimension = 1_024
+batch_size = 128
+add_title_as_prefix = true
+concurrent_request_limit = 2
+
+[database]
+provider = "postgres"
+
+  [database.graph_creation_settings]
+    graph_entity_description_prompt = "graphrag_entity_description"
+    entity_types = [] # if empty, all entities are extracted
+    relation_types = [] # if empty, all relations are extracted
+    fragment_merge_count = 4 # number of fragments to merge into a single extraction
+    max_knowledge_relationships = 100
+    max_description_input_length = 65536
+    generation_config = { model = "ollama/llama3.1" } # and other params; model used for relationship extraction
+
+  [database.graph_entity_deduplication_settings]
+    graph_entity_deduplication_type = "by_name"
+    graph_entity_deduplication_prompt = "graphrag_entity_deduplication"
+    max_description_input_length = 65536
+    generation_config = { model = "ollama/llama3.1" } # and other params, model used for deduplication
+
+  [database.graph_enrichment_settings]
+    community_reports_prompt = "graphrag_community_reports"
+    max_summary_input_length = 65536
+    generation_config = { model = "ollama/llama3.1" } # and other params, model used for node description and graph clustering
+    leiden_params = {}
+
+  [database.graph_search_settings]
+    generation_config = { model = "ollama/llama3.1" }
+
+
+[orchestration]
+provider = "simple"
+
+
+[ingestion]
+vision_img_model = "ollama/llama3.2-vision"
+vision_pdf_model = "ollama/llama3.2-vision"
+chunks_for_document_summary = 16
+document_summary_model = "ollama/llama3.1"
+
+  [ingestion.extra_parsers]
+    pdf = "zerox"

+ 46 - 0
core/configs/r2r_azure.toml

@@ -0,0 +1,46 @@
+# A config which overrides all instances of `openai` with `azure` in the `r2r.toml` config
+[agent]
+  [agent.generation_config]
+  model = "azure/gpt-4o"
+
+[completion]
+  [completion.generation_config]
+  model = "azure/gpt-4o"
+
+# KG settings
+[database]
+batch_size = 256
+
+  [database.graph_creation_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+  [database.graph_entity_deduplication_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+  [database.graph_enrichment_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+  [database.graph_search_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }
+
+[embedding]
+provider = "litellm"
+base_model = "openai/text-embedding-3-small" # continue with `openai` for embeddings, due to server rate limit on azure
+base_dimension = 512
+
+[file]
+provider = "postgres"
+
+[ingestion]
+provider = "r2r"
+chunking_strategy = "recursive"
+chunk_size = 1_024
+chunk_overlap = 512
+excluded_parsers = ["mp4"]
+
+audio_transcription_model="azure/whisper-1"
+document_summary_model = "azure/gpt-4o-mini"
+vision_img_model = "azure/gpt-4o"
+vision_pdf_model = "azure/gpt-4o"
+
+  [ingestion.chunk_enrichment_settings]
+    generation_config = { model = "azure/gpt-4o-mini" }

+ 8 - 0
core/configs/r2r_with_auth.toml

@@ -0,0 +1,8 @@
+[auth]
+provider = "r2r"
+access_token_lifetime_in_minutes = 60
+refresh_token_lifetime_in_days = 7
+require_authentication = true
+require_email_verification = false
+default_admin_email = "admin@example.com"
+default_admin_password = "change_me_immediately"

+ 5 - 0
core/database/__init__.py

@@ -0,0 +1,5 @@
+from .postgres import PostgresDatabaseProvider
+
+__all__ = [
+    "PostgresDatabaseProvider",
+]

+ 209 - 0
core/database/base.py

@@ -0,0 +1,209 @@
+import asyncio
+import logging
+import textwrap
+from contextlib import asynccontextmanager
+from typing import Optional
+
+import asyncpg
+
+from core.base.providers import DatabaseConnectionManager
+
+logger = logging.getLogger()
+
+
+class SemaphoreConnectionPool:
+    def __init__(self, connection_string, postgres_configuration_settings):
+        self.connection_string = connection_string
+        self.postgres_configuration_settings = postgres_configuration_settings
+
+    async def initialize(self):
+        try:
+            logger.info(
+                f"Connecting with {int(self.postgres_configuration_settings.max_connections * 0.9)} connections to `asyncpg.create_pool`."
+            )
+
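+            # Cap concurrent acquisitions slightly below max_connections so
+            # a few pool slots stay free for other connections (inferred
+            # intent; the 0.9 factor is not otherwise documented).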
+            self.semaphore = asyncio.Semaphore(
+                int(self.postgres_configuration_settings.max_connections * 0.9)
+            )
+
+            self.pool = await asyncpg.create_pool(
+                self.connection_string,
+                max_size=self.postgres_configuration_settings.max_connections,
+                statement_cache_size=self.postgres_configuration_settings.statement_cache_size,
+            )
+
+            logger.info(
+                "Successfully connected to Postgres database and created connection pool."
+            )
+        except Exception as e:
+            raise ValueError(
+                f"Error {e} occurred while attempting to connect to relational database."
+            ) from e
+
+    @asynccontextmanager
+    async def get_connection(self):
+        async with self.semaphore:
+            async with self.pool.acquire() as conn:
+                yield conn
+
+    async def close(self):
+        await self.pool.close()
+
+
+class QueryBuilder:
+    def __init__(self, table_name: str):
+        self.table_name = table_name
+        self.conditions: list[str] = []
+        self.params: dict = {}
+        self.select_fields = "*"
+        self.operation = "SELECT"
+        self.limit_value: Optional[int] = None
+        self.insert_data: Optional[dict] = None
+
+    def select(self, fields: list[str]):
+        self.select_fields = ", ".join(fields)
+        return self
+
+    def insert(self, data: dict):
+        self.operation = "INSERT"
+        self.insert_data = data
+        return self
+
+    def delete(self):
+        self.operation = "DELETE"
+        return self
+
+    def where(self, condition: str, **kwargs):
+        self.conditions.append(condition)
+        self.params.update(kwargs)
+        return self
+
+    def limit(self, value: int):
+        self.limit_value = value
+        return self
+
+    def build(self):
+        if self.operation == "SELECT":
+            query = f"SELECT {self.select_fields} FROM {self.table_name}"
+        elif self.operation == "INSERT":
+            columns = ", ".join(self.insert_data.keys())
+            values = ", ".join(f":{key}" for key in self.insert_data.keys())
+            query = (
+                f"INSERT INTO {self.table_name} ({columns}) VALUES ({values})"
+            )
+            self.params.update(self.insert_data)
+        elif self.operation == "DELETE":
+            query = f"DELETE FROM {self.table_name}"
+        else:
+            raise ValueError(f"Unsupported operation: {self.operation}")
+
+        if self.conditions:
+            query += " WHERE " + " AND ".join(self.conditions)
+
+        if self.limit_value is not None and self.operation == "SELECT":
+            query += f" LIMIT {self.limit_value}"
+
+        return query, self.params
+
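+
+# Example of the fluent builder above (illustrative; the table and column
+# names are placeholders, not the real schema):
+#
+#     query, params = (
+#         QueryBuilder("my_schema.chunks")
+#         .select(["id", "text"])
+#         .where("owner_id = :owner_id", owner_id="...")
+#         .limit(10)
+#         .build()
+#     )
+#     # query == "SELECT id, text FROM my_schema.chunks WHERE owner_id = :owner_id LIMIT 10"
+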
+
+class PostgresConnectionManager(DatabaseConnectionManager):
+
+    def __init__(self):
+        self.pool: Optional[SemaphoreConnectionPool] = None
+
+    async def initialize(self, pool: SemaphoreConnectionPool):
+        self.pool = pool
+
+    async def execute_query(self, query, params=None, isolation_level=None):
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+        async with self.pool.get_connection() as conn:
+            if isolation_level:
+                async with conn.transaction(isolation=isolation_level):
+                    if params:
+                        return await conn.execute(query, *params)
+                    else:
+                        return await conn.execute(query)
+            else:
+                if params:
+                    return await conn.execute(query, *params)
+                else:
+                    return await conn.execute(query)
+
+    async def execute_many(self, query, params=None, batch_size=1000):
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+        async with self.pool.get_connection() as conn:
+            async with conn.transaction():
+                if params:
+                    results = []
+                    for i in range(0, len(params), batch_size):
+                        param_batch = params[i : i + batch_size]
+                        result = await conn.executemany(query, param_batch)
+                        results.append(result)
+                    return results
+                else:
+                    # asyncpg's executemany requires an args iterable
+                    return await conn.executemany(query, [])
+
+    async def fetch_query(self, query, params=None):
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+        try:
+            async with self.pool.get_connection() as conn:
+                async with conn.transaction():
+                    return (
+                        await conn.fetch(query, *params)
+                        if params
+                        else await conn.fetch(query)
+                    )
+        except asyncpg.exceptions.DuplicatePreparedStatementError:
+            error_msg = textwrap.dedent(
+                """
+                Database Configuration Error
+
+                Your database provider does not support statement caching.
+
+                To fix this, either:
+                • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
+                • Add statement_cache_size = 0 to your database configuration:
+
+                    [database.postgres_configuration_settings]
+                    statement_cache_size = 0
+
+                This is required when using connection poolers like PgBouncer or
+                managed database services like Supabase.
+            """
+            ).strip()
+            raise ValueError(error_msg) from None
+
+    async def fetchrow_query(self, query, params=None):
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+        async with self.pool.get_connection() as conn:
+            async with conn.transaction():
+                if params:
+                    return await conn.fetchrow(query, *params)
+                else:
+                    return await conn.fetchrow(query)
+
+    @asynccontextmanager
+    async def transaction(self, isolation_level=None):
+        """
+        Async context manager for database transactions.
+
+        Args:
+            isolation_level: Optional isolation level for the transaction
+
+        Yields:
+            The connection manager instance for use within the transaction
+        """
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+
+        async with self.pool.get_connection() as conn:
+            async with conn.transaction(isolation=isolation_level):
+                try:
+                    yield self
+                except Exception as e:
+                    logger.error(f"Transaction failed: {str(e)}")
+                    raise
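+
+
+# Illustrative usage of the transaction context manager (a sketch; `manager`
+# is a PostgresConnectionManager whose pool has been initialized):
+#
+#     async with manager.transaction(isolation_level="serializable"):
+#         ...  # statements that should run under one transaction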

+ 1488 - 0
core/database/chunks.py

@@ -0,0 +1,1488 @@
+import copy
+import json
+import logging
+import time
+import uuid
+from typing import Any, Optional, TypedDict
+from uuid import UUID
+
+import numpy as np
+
+from core.base import (
+    ChunkSearchResult,
+    Handler,
+    IndexArgsHNSW,
+    IndexArgsIVFFlat,
+    IndexMeasure,
+    IndexMethod,
+    R2RException,
+    SearchSettings,
+    VectorEntry,
+    VectorQuantizationType,
+    VectorTableName,
+)
+
+from .base import PostgresConnectionManager
+from .vecs.exc import ArgError, FilterError
+
+from core.base.utils import _decorate_vector_type
+
+logger = logging.getLogger()
+
+
+def psql_quote_literal(value: str) -> str:
+    """
+    Safely quote a string literal for PostgreSQL to prevent SQL injection.
+    This is a simple implementation - in production, you should use proper parameterization
+    or your database driver's quoting functions.
+    """
+    return "'" + value.replace("'", "''") + "'"
+
+
+def index_measure_to_ops(
+    measure: IndexMeasure,
+    quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
+):
+    return _decorate_vector_type(measure.ops, quantization_type)
+
+
+def quantize_vector_to_binary(
+    vector: list[float] | np.ndarray,
+    threshold: float = 0.0,
+) -> bytes:
+    """
+    Quantizes a float vector to a binary vector string for PostgreSQL bit type.
+    Used when quantization_type is INT1.
+
+    Args:
+        vector (List[float] | np.ndarray): Input vector of floats
+        threshold (float, optional): Threshold for binarization. Defaults to 0.0.
+
+    Returns:
+        bytes: ASCII-encoded bit string (e.g. b"101") for the PostgreSQL bit type
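+
+    Example:
+        >>> quantize_vector_to_binary([0.5, -0.2, 0.1])
+        b'101'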
+    """
+    # Convert input to numpy array if it isn't already
+    if not isinstance(vector, np.ndarray):
+        vector = np.array(vector)
+
+    # Convert to binary (1 where value > threshold, 0 otherwise)
+    binary_vector = (vector > threshold).astype(int)
+
+    # Convert to a string of 1s and 0s, then to ASCII bytes
+    binary_string = "".join(map(str, binary_vector))
+    return binary_string.encode("ascii")
+
+
+class HybridSearchIntermediateResult(TypedDict):
+    semantic_rank: int
+    full_text_rank: int
+    data: ChunkSearchResult
+    rrf_score: float
+
+
+class PostgresChunksHandler(Handler):
+    TABLE_NAME = VectorTableName.CHUNKS
+
+    COLUMN_VARS = [
+        "id",
+        "document_id",
+        "owner_id",
+        "collection_ids",
+    ]
+
+    def __init__(
+        self,
+        project_name: str,
+        connection_manager: PostgresConnectionManager,
+        dimension: int,
+        quantization_type: VectorQuantizationType,
+    ):
+        super().__init__(project_name, connection_manager)
+        self.dimension = dimension
+        self.quantization_type = quantization_type
+
+    async def create_tables(self):
+        # Check for old table name first
+        check_query = """
+        SELECT EXISTS (
+            SELECT FROM pg_tables
+            WHERE schemaname = $1
+            AND tablename = $2
+        );
+        """
+        old_table_exists = await self.connection_manager.fetch_query(
+            check_query, (self.project_name, self.project_name)
+        )
+
+        if len(old_table_exists) > 0 and old_table_exists[0]["exists"]:
+            raise ValueError(
+                f"Found old vector table '{self.project_name}.{self.project_name}'. "
+                "Please run `r2r db upgrade` with the CLI, or to run manually, "
+                "run in R2R/py/migrations with 'alembic upgrade head' to update "
+                "your database schema to the new version."
+            )
+
+        binary_col = (
+            ""
+            if self.quantization_type != VectorQuantizationType.INT1
+            else f"vec_binary bit({self.dimension}),"
+        )
+
+        query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (
+            id UUID PRIMARY KEY,
+            document_id UUID,
+            owner_id UUID,
+            collection_ids UUID[],
+            vec vector({self.dimension}),
+            {binary_col}
+            text TEXT,
+            metadata JSONB,
+            fts tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
+        );
+        CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (document_id);
+        CREATE INDEX IF NOT EXISTS idx_vectors_owner_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (owner_id);
+        CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (collection_ids);
+        CREATE INDEX IF NOT EXISTS idx_vectors_text ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (to_tsvector('english', text));
+        """
+
+        await self.connection_manager.execute_query(query)
+
+    async def upsert(self, entry: VectorEntry) -> None:
+        """
+        Upsert function that handles vector quantization only when quantization_type is INT1.
+        Matches the table schema where vec_binary column only exists for INT1 quantization.
+        """
+        # Check the quantization type to determine which columns to use
+        if self.quantization_type == VectorQuantizationType.INT1:
+            # For quantized vectors, use vec_binary column
+            query = f"""
+            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+            (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
+            ON CONFLICT (id) DO UPDATE SET
+            document_id = EXCLUDED.document_id,
+            owner_id = EXCLUDED.owner_id,
+            collection_ids = EXCLUDED.collection_ids,
+            vec = EXCLUDED.vec,
+            vec_binary = EXCLUDED.vec_binary,
+            text = EXCLUDED.text,
+            metadata = EXCLUDED.metadata;
+            """
+            await self.connection_manager.execute_query(
+                query,
+                (
+                    entry.id,
+                    entry.document_id,
+                    entry.owner_id,
+                    entry.collection_ids,
+                    str(entry.vector.data),
+                    quantize_vector_to_binary(
+                        entry.vector.data
+                    ),  # Convert to binary
+                    entry.text,
+                    json.dumps(entry.metadata),
+                ),
+            )
+        else:
+            # For regular vectors, use vec column only
+            query = f"""
+            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+            (id, document_id, owner_id, collection_ids, vec, text, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6, $7)
+            ON CONFLICT (id) DO UPDATE SET
+            document_id = EXCLUDED.document_id,
+            owner_id = EXCLUDED.owner_id,
+            collection_ids = EXCLUDED.collection_ids,
+            vec = EXCLUDED.vec,
+            text = EXCLUDED.text,
+            metadata = EXCLUDED.metadata;
+            """
+
+            await self.connection_manager.execute_query(
+                query,
+                (
+                    entry.id,
+                    entry.document_id,
+                    entry.owner_id,
+                    entry.collection_ids,
+                    str(entry.vector.data),
+                    entry.text,
+                    json.dumps(entry.metadata),
+                ),
+            )
+
+    async def upsert_entries(self, entries: list[VectorEntry]) -> None:
+        """
+        Batch upsert function that handles vector quantization only when quantization_type is INT1.
+        Matches the table schema where vec_binary column only exists for INT1 quantization.
+        """
+        if self.quantization_type == VectorQuantizationType.INT1:
+            # For quantized vectors, use vec_binary column
+            query = f"""
+            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+            (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6::bit({self.dimension}), $7, $8)
+            ON CONFLICT (id) DO UPDATE SET
+            document_id = EXCLUDED.document_id,
+            owner_id = EXCLUDED.owner_id,
+            collection_ids = EXCLUDED.collection_ids,
+            vec = EXCLUDED.vec,
+            vec_binary = EXCLUDED.vec_binary,
+            text = EXCLUDED.text,
+            metadata = EXCLUDED.metadata;
+            """
+            bin_params = [
+                (
+                    entry.id,
+                    entry.document_id,
+                    entry.owner_id,
+                    entry.collection_ids,
+                    str(entry.vector.data),
+                    quantize_vector_to_binary(
+                        entry.vector.data
+                    ),  # Convert to binary
+                    entry.text,
+                    json.dumps(entry.metadata),
+                )
+                for entry in entries
+            ]
+            await self.connection_manager.execute_many(query, bin_params)
+
+        else:
+            # For regular vectors, use vec column only
+            query = f"""
+            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+            (id, document_id, owner_id, collection_ids, vec, text, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6, $7)
+            ON CONFLICT (id) DO UPDATE SET
+            document_id = EXCLUDED.document_id,
+            owner_id = EXCLUDED.owner_id,
+            collection_ids = EXCLUDED.collection_ids,
+            vec = EXCLUDED.vec,
+            text = EXCLUDED.text,
+            metadata = EXCLUDED.metadata;
+            """
+            params = [
+                (
+                    entry.id,
+                    entry.document_id,
+                    entry.owner_id,
+                    entry.collection_ids,
+                    str(entry.vector.data),
+                    entry.text,
+                    json.dumps(entry.metadata),
+                )
+                for entry in entries
+            ]
+
+            await self.connection_manager.execute_many(query, params)
+
+    async def semantic_search(
+        self, query_vector: list[float], search_settings: SearchSettings
+    ) -> list[ChunkSearchResult]:
+        try:
+            imeasure_obj = IndexMeasure(
+                search_settings.chunk_settings.index_measure
+            )
+        except ValueError:
+            raise ValueError("Invalid index measure")
+
+        table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
+        cols = [
+            f"{table_name}.id",
+            f"{table_name}.document_id",
+            f"{table_name}.owner_id",
+            f"{table_name}.collection_ids",
+            f"{table_name}.text",
+        ]
+
+        params: list[str | int | bytes] = []
+        # For binary vectors (INT1), implement two-stage search
+        if self.quantization_type == VectorQuantizationType.INT1:
+            # Convert query vector to binary format
+            binary_query = quantize_vector_to_binary(query_vector)
+            # TODO - Put depth multiplier in config / settings
+            extended_limit = (
+                search_settings.limit * 20
+            )  # Get 20x candidates for re-ranking
+            if (
+                imeasure_obj == IndexMeasure.hamming_distance
+                or imeasure_obj == IndexMeasure.jaccard_distance
+            ):
+                binary_search_measure_repr = imeasure_obj.pgvector_repr
+            else:
+                binary_search_measure_repr = (
+                    IndexMeasure.hamming_distance.pgvector_repr
+                )
+
+            # Use binary column and binary-specific distance measures for first stage
+            stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit({self.dimension})"
+            stage1_param = binary_query
+
+            cols.append(
+                f"{table_name}.vec"
+            )  # Need original vector for re-ranking
+            if search_settings.include_metadatas:
+                cols.append(f"{table_name}.metadata")
+
+            select_clause = ", ".join(cols)
+            where_clause = ""
+            params.append(stage1_param)
+
+            if search_settings.filters:
+                where_clause = self._build_filters(
+                    search_settings.filters, params
+                )
+                where_clause = f"WHERE {where_clause}"
+
+            # First stage: Get candidates using binary search
+            query = f"""
+            WITH candidates AS (
+                SELECT {select_clause},
+                    ({stage1_distance}) as binary_distance
+                FROM {table_name}
+                {where_clause}
+                ORDER BY {stage1_distance}
+                LIMIT ${len(params) + 1}
+                OFFSET ${len(params) + 2}
+            )
+            -- Second stage: Re-rank using original vectors
+            SELECT
+                id,
+                document_id,
+                owner_id,
+                collection_ids,
+                text,
+                {"metadata," if search_settings.include_metadatas else ""}
+                (vec <=> ${len(params) + 4}::vector({self.dimension})) as distance
+            FROM candidates
+            ORDER BY distance
+            LIMIT ${len(params) + 3}
+            """
+
+            params.extend(
+                [
+                    extended_limit,  # First stage limit
+                    search_settings.offset,
+                    search_settings.limit,  # Final limit
+                    str(query_vector),  # For re-ranking
+                ]
+            )
+
+        else:
+            # Standard float vector handling - unchanged from original
+            distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector({self.dimension})"
+            query_param = str(query_vector)
+
+            if search_settings.include_scores:
+                cols.append(f"({distance_calc}) AS distance")
+            if search_settings.include_metadatas:
+                cols.append(f"{table_name}.metadata")
+
+            select_clause = ", ".join(cols)
+            where_clause = ""
+            params.append(query_param)
+
+            if search_settings.filters:
+                where_clause = self._build_filters(
+                    search_settings.filters, params
+                )
+                where_clause = f"WHERE {where_clause}"
+
+            query = f"""
+            SELECT {select_clause}
+            FROM {table_name}
+            {where_clause}
+            ORDER BY {distance_calc}
+            LIMIT ${len(params) + 1}
+            OFFSET ${len(params) + 2}
+            """
+            params.extend([search_settings.limit, search_settings.offset])
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+        return [
+            ChunkSearchResult(
+                id=UUID(str(result["id"])),
+                document_id=UUID(str(result["document_id"])),
+                owner_id=UUID(str(result["owner_id"])),
+                collection_ids=result["collection_ids"],
+                text=result["text"],
+                score=(
+                    (1 - float(result["distance"]))
+                    if "distance" in result
+                    else -1
+                ),
+                metadata=(
+                    json.loads(result["metadata"])
+                    if search_settings.include_metadatas
+                    else {}
+                ),
+            )
+            for result in results
+        ]
+
+    async def full_text_search(
+        self, query_text: str, search_settings: SearchSettings
+    ) -> list[ChunkSearchResult]:
+
+        where_clauses = []
+        params: list[str | int | bytes] = [query_text]
+
+        if search_settings.filters:
+            filters_clause = self._build_filters(
+                search_settings.filters, params
+            )
+            where_clauses.append(filters_clause)
+
+        if where_clauses:
+            where_clause = (
+                "WHERE "
+                + " AND ".join(where_clauses)
+                + " AND fts @@ websearch_to_tsquery('english', $1)"
+            )
+        else:
+            where_clause = "WHERE fts @@ websearch_to_tsquery('english', $1)"
+
+        query = f"""
+            SELECT
+                id, document_id, owner_id, collection_ids, text, metadata,
+                ts_rank(fts, websearch_to_tsquery('english', $1), 32) as rank
+            FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+            {where_clause}
+        """
+
+        query += f"""
+            ORDER BY rank DESC
+            OFFSET ${len(params)+1} LIMIT ${len(params)+2}
+        """
+        params.extend(
+            [
+                search_settings.offset,
+                search_settings.hybrid_settings.full_text_limit,
+            ]
+        )
+
+        results = await self.connection_manager.fetch_query(query, params)
+        return [
+            ChunkSearchResult(
+                id=UUID(str(r["id"])),
+                document_id=UUID(str(r["document_id"])),
+                owner_id=UUID(str(r["owner_id"])),
+                collection_ids=r["collection_ids"],
+                text=r["text"],
+                score=float(r["rank"]),
+                metadata=json.loads(r["metadata"]),
+            )
+            for r in results
+        ]
+
+    async def hybrid_search(
+        self,
+        query_text: str,
+        query_vector: list[float],
+        search_settings: SearchSettings,
+        *args,
+        **kwargs,
+    ) -> list[ChunkSearchResult]:
+        if search_settings.hybrid_settings is None:
+            raise ValueError(
+                "Please provide a valid `hybrid_settings` in the `search_settings`."
+            )
+        if (
+            search_settings.hybrid_settings.full_text_limit
+            < search_settings.limit
+        ):
+            raise ValueError(
+                "The `full_text_limit` must be greater than or equal to the `limit`."
+            )
+
+        semantic_settings = copy.deepcopy(search_settings)
+        semantic_settings.limit += search_settings.offset
+
+        full_text_settings = copy.deepcopy(search_settings)
+        full_text_settings.hybrid_settings.full_text_limit += (
+            search_settings.offset
+        )
+
+        semantic_results: list[ChunkSearchResult] = await self.semantic_search(
+            query_vector, semantic_settings
+        )
+        full_text_results: list[ChunkSearchResult] = (
+            await self.full_text_search(query_text, full_text_settings)
+        )
+
+        semantic_limit = search_settings.limit
+        full_text_limit = search_settings.hybrid_settings.full_text_limit
+        semantic_weight = search_settings.hybrid_settings.semantic_weight
+        full_text_weight = search_settings.hybrid_settings.full_text_weight
+        rrf_k = search_settings.hybrid_settings.rrf_k
+
+        combined_results: dict[uuid.UUID, HybridSearchIntermediateResult] = {}
+
+        for rank, result in enumerate(semantic_results, 1):
+            combined_results[result.id] = {
+                "semantic_rank": rank,
+                "full_text_rank": full_text_limit,
+                "data": result,
+                "rrf_score": 0.0,  # Initialize with 0, will be calculated later
+            }
+
+        for rank, result in enumerate(full_text_results, 1):
+            if result.id in combined_results:
+                combined_results[result.id]["full_text_rank"] = rank
+            else:
+                combined_results[result.id] = {
+                    "semantic_rank": semantic_limit,
+                    "full_text_rank": rank,
+                    "data": result,
+                    "rrf_score": 0.0,  # Initialize with 0, will be calculated later
+                }
+
+        combined_results = {
+            k: v
+            for k, v in combined_results.items()
+            if v["semantic_rank"] <= semantic_limit * 2
+            and v["full_text_rank"] <= full_text_limit * 2
+        }
+
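+        # Reciprocal Rank Fusion: each ranking contributes 1 / (rrf_k + rank),
+        # blended by the configured weights. For example, with rrf_k = 50,
+        # semantic_rank = 1, full_text_rank = 3, and equal weights, the fused
+        # score is (1/51 + 1/53) / 2 ~= 0.0192.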
+        for hyb_result in combined_results.values():
+            semantic_score = 1 / (rrf_k + hyb_result["semantic_rank"])
+            full_text_score = 1 / (rrf_k + hyb_result["full_text_rank"])
+            hyb_result["rrf_score"] = (
+                semantic_score * semantic_weight
+                + full_text_score * full_text_weight
+            ) / (semantic_weight + full_text_weight)
+
+        sorted_results = sorted(
+            combined_results.values(),
+            key=lambda x: x["rrf_score"],
+            reverse=True,
+        )
+        offset_results = sorted_results[
+            search_settings.offset : search_settings.offset
+            + search_settings.limit
+        ]
+
+        return [
+            ChunkSearchResult(
+                id=result["data"].id,
+                document_id=result["data"].document_id,
+                owner_id=result["data"].owner_id,
+                collection_ids=result["data"].collection_ids,
+                text=result["data"].text,
+                score=result["rrf_score"],
+                metadata={
+                    **result["data"].metadata,
+                    "semantic_rank": result["semantic_rank"],
+                    "full_text_rank": result["full_text_rank"],
+                },
+            )
+            for result in offset_results
+        ]
+
+    async def delete(
+        self, filters: dict[str, Any]
+    ) -> dict[str, dict[str, str]]:
+        params: list[str | int | bytes] = []
+        where_clause = self._build_filters(filters, params)
+
+        query = f"""
+        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        WHERE {where_clause}
+        RETURNING id, document_id, text;
+        """
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+        return {
+            str(result["id"]): {
+                "status": "deleted",
+                "id": str(result["id"]),
+                "document_id": str(result["document_id"]),
+                "text": result["text"],
+            }
+            for result in results
+        }
+
+    async def assign_document_chunks_to_collection(
+        self, document_id: UUID, collection_id: UUID
+    ) -> None:
+        query = f"""
+        UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        SET collection_ids = array_append(collection_ids, $1)
+        WHERE document_id = $2 AND NOT ($1 = ANY(collection_ids));
+        """
+        await self.connection_manager.execute_query(
+            query, (str(collection_id), str(document_id))
+        )
+
+    async def remove_document_from_collection_vector(
+        self, document_id: UUID, collection_id: UUID
+    ) -> None:
+        query = f"""
+        UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        SET collection_ids = array_remove(collection_ids, $1)
+        WHERE document_id = $2;
+        """
+        await self.connection_manager.execute_query(
+            query, (collection_id, document_id)
+        )
+
+    async def delete_user_vector(self, owner_id: UUID) -> None:
+        query = f"""
+        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        WHERE owner_id = $1;
+        """
+        await self.connection_manager.execute_query(query, (owner_id,))
+
+    async def delete_collection_vector(self, collection_id: UUID) -> None:
+        query = f"""
+         DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+         WHERE $1 = ANY(collection_ids)
+         RETURNING collection_ids
+         """
+        await self.connection_manager.fetchrow_query(
+            query, (collection_id,)
+        )
+        return None
+
+    async def list_document_chunks(
+        self,
+        document_id: UUID,
+        offset: int,
+        limit: int,
+        include_vectors: bool = False,
+    ) -> dict[str, Any]:
+        vector_select = ", vec" if include_vectors else ""
+        limit_clause = f"LIMIT {limit}" if limit > -1 else ""
+
+        query = f"""
+        SELECT id, document_id, owner_id, collection_ids, text, metadata{vector_select}, COUNT(*) OVER() AS total
+        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        WHERE document_id = $1
+        ORDER BY (metadata->>'chunk_order')::integer
+        OFFSET $2
+        {limit_clause};
+        """
+
+        params = [document_id, offset]
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+        chunks = []
+        total = 0
+        if results:
+            total = results[0].get("total", 0)
+            chunks = [
+                {
+                    "id": result["id"],
+                    "document_id": result["document_id"],
+                    "owner_id": result["owner_id"],
+                    "collection_ids": result["collection_ids"],
+                    "text": result["text"],
+                    "metadata": json.loads(result["metadata"]),
+                    "vector": (
+                        json.loads(result["vec"]) if include_vectors else None
+                    ),
+                }
+                for result in results
+            ]
+
+        return {"results": chunks, "total_entries": total}
+
+    async def get_chunk(self, id: UUID) -> dict:
+        query = f"""
+        SELECT id, document_id, owner_id, collection_ids, text, metadata
+        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        WHERE id = $1;
+        """
+
+        result = await self.connection_manager.fetchrow_query(query, (id,))
+
+        if result:
+            return {
+                "id": result["id"],
+                "document_id": result["document_id"],
+                "owner_id": result["owner_id"],
+                "collection_ids": result["collection_ids"],
+                "text": result["text"],
+                "metadata": json.loads(result["metadata"]),
+            }
+        raise R2RException(
+            message=f"Chunk with ID {id} not found", status_code=404
+        )
+
+    async def create_index(
+        self,
+        table_name: Optional[VectorTableName] = None,
+        index_measure: IndexMeasure = IndexMeasure.cosine_distance,
+        index_method: IndexMethod = IndexMethod.auto,
+        index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
+        index_name: Optional[str] = None,
+        index_column: Optional[str] = None,
+        concurrently: bool = True,
+    ) -> None:
+        """
+        Creates an index for the collection.
+
+        Note:
+            When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step
+            process that enables performant indexes to be built for large collections with low end
+            database hardware.
+
+            Those steps are:
+
+            - Creates a new table with a different name
+            - Randomly selects records from the existing table
+            - Inserts the random records from the existing table into the new table
+            - Creates the requested vector index on the new table
+            - Upserts all data from the existing table into the new table
+            - Drops the existing table
+            - Renames the new table to the existing tables name
+
+            If you create dependencies (like views) on the table that underpins
+            a `vecs.Collection` the `create_index` step may require you to drop those dependencies before
+            it will succeed.
+
+        Args:
+            index_measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
+            index_method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
+            index_arguments: (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments
+            index_name (str, optional): The name of the index to create. Defaults to None.
+            concurrently (bool, optional): Whether to create the index concurrently. Defaults to True.
+        Raises:
+            ArgError: If an invalid index method is used, or if *replace* is False and an index already exists.
+        """
+
+        if table_name == VectorTableName.CHUNKS:
+            table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"  # TODO - Fix bug in vector table naming convention
+            if index_column:
+                col_name = index_column
+            else:
+                col_name = (
+                    "vec"
+                    if (
+                        index_measure != IndexMeasure.hamming_distance
+                        and index_measure != IndexMeasure.jaccard_distance
+                    )
+                    else "vec_binary"
+                )
+        elif table_name == VectorTableName.ENTITIES_DOCUMENT:
+            table_name_str = (
+                f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
+            )
+            col_name = "description_embedding"
+        elif table_name == VectorTableName.GRAPHS_ENTITIES:
+            table_name_str = (
+                f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
+            )
+            col_name = "description_embedding"
+        elif table_name == VectorTableName.COMMUNITIES:
+            table_name_str = (
+                f"{self.project_name}.{VectorTableName.COMMUNITIES}"
+            )
+            col_name = "embedding"
+        else:
+            raise ArgError("invalid table name")
+
+        if index_method not in (
+            IndexMethod.ivfflat,
+            IndexMethod.hnsw,
+            IndexMethod.auto,
+        ):
+            raise ArgError("invalid index method")
+
+        if index_arguments:
+            # Disallow case where user submits index arguments but uses the
+            # IndexMethod.auto index (index build arguments should only be
+            # used with a specific index)
+            if index_method == IndexMethod.auto:
+                raise ArgError(
+                    "Index build parameters are not allowed when using the IndexMethod.auto index."
+                )
+            # Disallow case where user specifies one index type but submits
+            # index build arguments for the other index type
+            if (
+                isinstance(index_arguments, IndexArgsHNSW)
+                and index_method != IndexMethod.hnsw
+            ) or (
+                isinstance(index_arguments, IndexArgsIVFFlat)
+                and index_method != IndexMethod.ivfflat
+            ):
+                raise ArgError(
+                    f"{index_arguments.__class__.__name__} build parameters were supplied but {index_method} index was specified."
+                )
+
+        if index_method == IndexMethod.auto:
+            index_method = IndexMethod.hnsw
+
+        ops = index_measure_to_ops(
+            index_measure  # , quantization_type=self.quantization_type
+        )
+
+        if ops is None:
+            raise ArgError("Unknown index measure")
+
+        concurrently_sql = "CONCURRENTLY" if concurrently else ""
+
+        index_name = (
+            index_name
+            or f"ix_{ops}_{index_method}__{col_name}_{time.strftime('%Y%m%d%H%M%S')}"
+        )
+
+        create_index_sql = f"""
+        CREATE INDEX {concurrently_sql} {index_name}
+        ON {table_name_str}
+        USING {index_method} ({col_name} {ops}) {self._get_index_options(index_method, index_arguments)};
+        """
+
+        try:
+            if concurrently:
+                async with (
+                    self.connection_manager.pool.get_connection() as conn  # type: ignore
+                ):
+                    # Disable automatic transaction management
+                    await conn.execute(
+                        "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
+                    )
+                    await conn.execute(create_index_sql)
+            else:
+                # Non-concurrent index creation can use normal query execution
+                await self.connection_manager.execute_query(create_index_sql)
+        except Exception as e:
+            raise Exception(f"Failed to create index: {e}")
+        return None
+
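+    # Usage sketch (hypothetical, not part of this commit): building an HNSW
+    # index over the chunk vectors with explicit build parameters, where
+    # `chunks_handler` is an instance of this handler.
+    #
+    #     await chunks_handler.create_index(
+    #         table_name=VectorTableName.CHUNKS,
+    #         index_measure=IndexMeasure.cosine_distance,
+    #         index_method=IndexMethod.hnsw,
+    #         index_arguments=IndexArgsHNSW(m=16, ef_construction=64),
+    #         concurrently=True,
+    #     )
+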
+    def _build_filters(
+        self, filters: dict, parameters: list[str | int | bytes]
+    ) -> str:
+
+        def parse_condition(key: str, value: Any) -> str:  # type: ignore
+            # nonlocal parameters
+            if key in self.COLUMN_VARS:
+                # Handle column-based filters
+                if isinstance(value, dict):
+                    op, clause = next(iter(value.items()))
+                    if op == "$eq":
+                        parameters.append(clause)
+                        return f"{key} = ${len(parameters)}"
+                    elif op == "$ne":
+                        parameters.append(clause)
+                        return f"{key} != ${len(parameters)}"
+                    elif op == "$in":
+                        parameters.append(clause)
+                        return f"{key} = ANY(${len(parameters)})"
+                    elif op == "$nin":
+                        parameters.append(clause)
+                        return f"{key} != ALL(${len(parameters)})"
+                    elif op == "$overlap":
+                        parameters.append(clause)
+                        return f"{key} && ${len(parameters)}"
+                    elif op == "$contains":
+                        parameters.append(clause)
+                        return f"{key} @> ${len(parameters)}"
+                    elif op == "$any":
+                        if key == "collection_ids":
+                            parameters.append(f"%{clause}%")
+                            return f"array_to_string({key}, ',') LIKE ${len(parameters)}"
+                        parameters.append(clause)
+                        return f"${len(parameters)} = ANY({key})"
+                    else:
+                        raise FilterError(
+                            f"Unsupported operator for column {key}: {op}"
+                        )
+                else:
+                    # Handle direct equality
+                    parameters.append(value)
+                    return f"{key} = ${len(parameters)}"
+            else:
+                # Handle JSON-based filters
+                json_col = "metadata"
+                if key.startswith("metadata."):
+                    key = key.split("metadata.")[1]
+                if isinstance(value, dict):
+                    op, clause = next(iter(value.items()))
+                    if op not in (
+                        "$eq",
+                        "$ne",
+                        "$lt",
+                        "$lte",
+                        "$gt",
+                        "$gte",
+                        "$in",
+                        "$contains",
+                    ):
+                        raise FilterError("unknown operator")
+
+                    if op == "$eq":
+                        parameters.append(json.dumps(clause))
+                        return (
+                            f"{json_col}->'{key}' = ${len(parameters)}::jsonb"
+                        )
+                    elif op == "$ne":
+                        parameters.append(json.dumps(clause))
+                        return (
+                            f"{json_col}->'{key}' != ${len(parameters)}::jsonb"
+                        )
+                    elif op == "$lt":
+                        parameters.append(json.dumps(clause))
+                        return f"({json_col}->'{key}')::float < (${len(parameters)}::jsonb)::float"
+                    elif op == "$lte":
+                        parameters.append(json.dumps(clause))
+                        return f"({json_col}->'{key}')::float <= (${len(parameters)}::jsonb)::float"
+                    elif op == "$gt":
+                        parameters.append(json.dumps(clause))
+                        return f"({json_col}->'{key}')::float > (${len(parameters)}::jsonb)::float"
+                    elif op == "$gte":
+                        parameters.append(json.dumps(clause))
+                        return f"({json_col}->'{key}')::float >= (${len(parameters)}::jsonb)::float"
+                    elif op == "$in":
+                        # Ensure clause is a list
+                        if not isinstance(clause, list):
+                            raise FilterError(
+                                "argument to $in filter must be a list"
+                            )
+                        # Append the Python list as a parameter; many drivers can convert Python lists to arrays
+                        parameters.append(clause)
+                        # Cast the parameter to a text array type
+                        return f"(metadata->>'{key}')::text = ANY(${len(parameters)}::text[])"
+                    elif op == "$contains":
+                        if isinstance(clause, (int, float, str)):
+                            clause = [clause]
+                        # Now clause is guaranteed to be a list or array-like structure.
+                        parameters.append(json.dumps(clause))
+                        return (
+                            f"{json_col}->'{key}' @> ${len(parameters)}::jsonb"
+                        )
+                else:
+                    # Direct equality against a metadata field when no
+                    # operator dictionary is supplied.
+                    parameters.append(json.dumps(value))
+                    return f"{json_col}->'{key}' = ${len(parameters)}::jsonb"
+
+        def parse_filter(filter_dict: dict) -> str:
+            filter_conditions = []
+            for key, value in filter_dict.items():
+                if key == "$and":
+                    and_conditions = [
+                        parse_filter(f) for f in value if f
+                    ]  # Skip empty dictionaries
+                    if and_conditions:
+                        filter_conditions.append(
+                            f"({' AND '.join(and_conditions)})"
+                        )
+                elif key == "$or":
+                    or_conditions = [
+                        parse_filter(f) for f in value if f
+                    ]  # Skip empty dictionaries
+                    if or_conditions:
+                        filter_conditions.append(
+                            f"({' OR '.join(or_conditions)})"
+                        )
+                else:
+                    filter_conditions.append(parse_condition(key, value))
+
+            # Check if there is only a single condition
+            if len(filter_conditions) == 1:
+                return filter_conditions[0]
+            else:
+                return " AND ".join(filter_conditions)
+
+        where_clause = parse_filter(filters)
+
+        return where_clause
+
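+    # Illustrative note (not part of the original commit): assuming
+    # "document_id" is listed in COLUMN_VARS, a filter such as
+    #     {"document_id": {"$eq": some_id}, "metadata.category": {"$in": ["a", "b"]}}
+    # appends both values to `parameters` and yields a clause roughly like
+    #     document_id = $1 AND (metadata->>'category')::text = ANY($2::text[])
+    # which callers place after WHERE and execute with `parameters`.
+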
+    async def list_indices(
+        self,
+        offset: int,
+        limit: int,
+        filters: Optional[dict[str, Any]] = None,
+    ) -> dict:
+        where_clauses = []
+        params: list[Any] = [self.project_name]  # Start with schema name
+        param_count = 1
+
+        # Handle filtering
+        if filters:
+            if "table_name" in filters:
+                where_clauses.append(f"i.tablename = ${param_count + 1}")
+                params.append(filters["table_name"])
+                param_count += 1
+            if "index_method" in filters:
+                where_clauses.append(f"am.amname = ${param_count + 1}")
+                params.append(filters["index_method"])
+                param_count += 1
+            if "index_name" in filters:
+                where_clauses.append(
+                    f"LOWER(i.indexname) LIKE LOWER(${param_count + 1})"
+                )
+                params.append(f"%{filters['index_name']}%")
+                param_count += 1
+
+        where_clause = " AND ".join(where_clauses) if where_clauses else ""
+        if where_clause:
+            where_clause = "AND " + where_clause
+
+        query = f"""
+        WITH index_info AS (
+            SELECT
+                i.indexname as name,
+                i.tablename as table_name,
+                i.indexdef as definition,
+                am.amname as method,
+                pg_relation_size(c.oid) as size_in_bytes,
+                c.reltuples::bigint as row_estimate,
+                COALESCE(psat.idx_scan, 0) as number_of_scans,
+                COALESCE(psat.idx_tup_read, 0) as tuples_read,
+                COALESCE(psat.idx_tup_fetch, 0) as tuples_fetched,
+                COUNT(*) OVER() as total_count
+            FROM pg_indexes i
+            JOIN pg_class c ON c.relname = i.indexname
+            JOIN pg_am am ON c.relam = am.oid
+            LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelname = i.indexname
+                AND psat.schemaname = i.schemaname
+            WHERE i.schemaname = $1
+            AND i.indexdef LIKE '%vector%'
+            {where_clause}
+        )
+        SELECT *
+        FROM index_info
+        ORDER BY name
+        LIMIT ${param_count + 1}
+        OFFSET ${param_count + 2}
+        """
+
+        # Add limit and offset to params
+        params.extend([limit, offset])
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+        indices = []
+        total_entries = 0
+
+        if results:
+            total_entries = results[0]["total_count"]
+            for result in results:
+                index_info = {
+                    "name": result["name"],
+                    "table_name": result["table_name"],
+                    "definition": result["definition"],
+                    "size_in_bytes": result["size_in_bytes"],
+                    "row_estimate": result["row_estimate"],
+                    "number_of_scans": result["number_of_scans"],
+                    "tuples_read": result["tuples_read"],
+                    "tuples_fetched": result["tuples_fetched"],
+                }
+                indices.append(index_info)
+
+        # Calculate pagination info
+        total_pages = (total_entries + limit - 1) // limit if limit > 0 else 1
+        current_page = (offset // limit) + 1 if limit > 0 else 1
+
+        page_info = {
+            "total_entries": total_entries,
+            "total_pages": total_pages,
+            "current_page": current_page,
+            "limit": limit,
+            "offset": offset,
+            "has_previous": offset > 0,
+            "has_next": offset + limit < total_entries,
+            "previous_offset": max(0, offset - limit) if offset > 0 else None,
+            "next_offset": (
+                offset + limit if offset + limit < total_entries else None
+            ),
+        }
+
+        return {"indices": indices, "page_info": page_info}
+
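+    # Worked example (not part of the original commit) for the page_info
+    # arithmetic above: with 23 matching indices, limit=10 and offset=10,
+    # total_pages = (23 + 10 - 1) // 10 = 3, current_page = (10 // 10) + 1 = 2,
+    # has_previous/has_next are both True, previous_offset=0 and next_offset=20.
+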
+    async def delete_index(
+        self,
+        index_name: str,
+        table_name: Optional[VectorTableName] = None,
+        concurrently: bool = True,
+    ) -> None:
+        """
+        Deletes a vector index.
+
+        Args:
+            index_name (str): Name of the index to delete
+            table_name (VectorTableName, optional): Table the index belongs to
+            concurrently (bool): Whether to drop the index concurrently
+
+        Raises:
+            ArgError: If table name is invalid or index doesn't exist
+            Exception: If index deletion fails
+        """
+        # Validate table name and get column name
+        if table_name == VectorTableName.CHUNKS:
+            table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"
+            col_name = "vec"
+        elif table_name == VectorTableName.ENTITIES_DOCUMENT:
+            table_name_str = (
+                f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
+            )
+            col_name = "description_embedding"
+        elif table_name == VectorTableName.GRAPHS_ENTITIES:
+            table_name_str = (
+                f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
+            )
+            col_name = "description_embedding"
+        elif table_name == VectorTableName.COMMUNITIES:
+            table_name_str = (
+                f"{self.project_name}.{VectorTableName.COMMUNITIES}"
+            )
+            col_name = "description_embedding"
+        else:
+            raise ArgError("invalid table name")
+
+        # Extract schema and base table name
+        schema_name, base_table_name = table_name_str.split(".")
+
+        # Verify index exists and is a vector index
+        query = """
+        SELECT indexdef
+        FROM pg_indexes
+        WHERE indexname = $1
+        AND schemaname = $2
+        AND tablename = $3
+        AND indexdef LIKE $4
+        """
+
+        result = await self.connection_manager.fetchrow_query(
+            query, (index_name, schema_name, base_table_name, f"%({col_name}%")
+        )
+
+        if not result:
+            raise ArgError(
+                f"Vector index '{index_name}' does not exist on table {table_name_str}"
+            )
+
+        # Drop the index
+        concurrently_sql = "CONCURRENTLY" if concurrently else ""
+        drop_query = (
+            f"DROP INDEX {concurrently_sql} {schema_name}.{index_name}"
+        )
+
+        try:
+            if concurrently:
+                async with (
+                    self.connection_manager.pool.get_connection() as conn  # type: ignore
+                ):
+                    # Disable automatic transaction management
+                    await conn.execute(
+                        "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
+                    )
+                    await conn.execute(drop_query)
+            else:
+                await self.connection_manager.execute_query(drop_query)
+        except Exception as e:
+            raise Exception(f"Failed to delete index: {e}")
+
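+    # Usage sketch (hypothetical, not part of this commit; the index name shown
+    # is only an example of the pattern generated by create_index above):
+    #
+    #     await chunks_handler.delete_index(
+    #         index_name="ix_vector_cosine_ops_hnsw__vec_20240101000000",
+    #         table_name=VectorTableName.CHUNKS,
+    #         concurrently=True,
+    #     )
+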
+    async def get_semantic_neighbors(
+        self,
+        offset: int,
+        limit: int,
+        document_id: UUID,
+        id: UUID,
+        similarity_threshold: float = 0.5,
+    ) -> list[dict[str, Any]]:
+
+        table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
+        query = f"""
+        WITH target_vector AS (
+            SELECT vec FROM {table_name}
+            WHERE document_id = $1 AND id = $2
+        )
+        SELECT t.id, t.text, t.metadata, t.document_id, (t.vec <=> tv.vec) AS similarity
+        FROM {table_name} t, target_vector tv
+        WHERE (t.vec <=> tv.vec) >= $3
+            AND t.document_id = $1
+            AND t.id != $2
+        ORDER BY similarity ASC
+        LIMIT $4
+        """
+        results = await self.connection_manager.fetch_query(
+            query,
+            (str(document_id), str(id), similarity_threshold, limit),
+        )
+
+        return [
+            {
+                "id": str(r["id"]),
+                "text": r["text"],
+                "metadata": json.loads(r["metadata"]),
+                "document_id": str(r["document_id"]),
+                "similarity": float(r["similarity"]),
+            }
+            for r in results
+        ]
+
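+    # Note (added for clarity, not in the original commit): pgvector's `<=>`
+    # operator returns cosine *distance*, so smaller values mean closer vectors;
+    # the column aliased "similarity" above therefore carries a distance, and
+    # ORDER BY similarity ASC lists the nearest chunks first. The `offset`
+    # argument is currently accepted but not applied to the query.
+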
+    async def list_chunks(
+        self,
+        offset: int,
+        limit: int,
+        filters: Optional[dict[str, Any]] = None,
+        include_vectors: bool = False,
+    ) -> dict[str, Any]:
+        """
+        List chunks with pagination support.
+
+        Args:
+            offset (int, optional): Number of records to skip. Defaults to 0.
+            limit (int, optional): Maximum number of records to return. Defaults to 10.
+            filters (dict, optional): Dictionary of filters to apply. Defaults to None.
+            include_vectors (bool, optional): Whether to include vector data. Defaults to False.
+
+        Returns:
+            dict: Dictionary containing:
+                - results: List of chunk records
+                - total_entries: Total number of chunks matching the filters
+                - page_info: Pagination information
+        """
+        # Sortable columns and their SQL expressions (not yet applied to the
+        # query below)
+        valid_sort_columns = {
+            "created_at": "metadata->>'created_at'",
+            "updated_at": "metadata->>'updated_at'",
+            "chunk_order": "metadata->>'chunk_order'",
+            "text": "text",
+        }
+
+        # Build the select clause
+        vector_select = ", vec" if include_vectors else ""
+        select_clause = f"""
+            id, document_id, owner_id, collection_ids,
+            text, metadata{vector_select}, COUNT(*) OVER() AS total
+        """
+
+        # Build the where clause if filters are provided
+        where_clause = ""
+        params: list[str | int | bytes] = []
+        if filters:
+            where_clause = self._build_filters(filters, params)
+            where_clause = f"WHERE {where_clause}"
+
+        # Construct the final query
+        query = f"""
+        SELECT {select_clause}
+        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        {where_clause}
+        LIMIT $%s
+        OFFSET $%s
+        """
+
+        # Add pagination parameters. Only the LIMIT/OFFSET placeholders use
+        # "%s" substitution; the filter parameters are already numbered by
+        # _build_filters, so substitute just the last two indices.
+        params.extend([limit, offset])
+        formatted_query = query % (len(params) - 1, len(params))
+
+        # Execute the query
+        results = await self.connection_manager.fetch_query(
+            formatted_query, params
+        )
+
+        # Process results
+        chunks = []
+        total = 0
+        if results:
+            total = results[0].get("total", 0)
+            chunks = [
+                {
+                    "id": str(result["id"]),
+                    "document_id": str(result["document_id"]),
+                    "owner_id": str(result["owner_id"]),
+                    "collection_ids": result["collection_ids"],
+                    "text": result["text"],
+                    "metadata": json.loads(result["metadata"]),
+                    "vector": (
+                        json.loads(result["vec"]) if include_vectors else None
+                    ),
+                }
+                for result in results
+            ]
+
+        # Calculate pagination info
+        total_pages = (total + limit - 1) // limit if limit > 0 else 1
+        current_page = (offset // limit) + 1 if limit > 0 else 1
+
+        page_info = {
+            "total_entries": total,
+            "total_pages": total_pages,
+            "current_page": current_page,
+            "limit": limit,
+            "offset": offset,
+            "has_previous": offset > 0,
+            "has_next": offset + limit < total,
+            "previous_offset": max(0, offset - limit) if offset > 0 else None,
+            "next_offset": offset + limit if offset + limit < total else None,
+        }
+
+        return {"results": chunks, "page_info": page_info}
+
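+    # Usage sketch (hypothetical, not part of this commit; assumes "owner_id"
+    # is one of the columns accepted by _build_filters): paging through a
+    # user's chunks without fetching the embedding vectors.
+    #
+    #     page = await chunks_handler.list_chunks(
+    #         offset=0,
+    #         limit=100,
+    #         filters={"owner_id": {"$eq": owner_id}},
+    #         include_vectors=False,
+    #     )
+    #     for chunk in page["results"]:
+    #         print(chunk["id"], chunk["text"][:80])
+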
+    async def search_documents(
+        self,
+        query_text: str,
+        settings: SearchSettings,
+    ) -> list[dict[str, Any]]:
+        """
+        Search for documents based on their metadata fields and/or body text.
+        Joins with documents table to get complete document metadata.
+
+        Args:
+            query_text (str): The search query text
+            settings (SearchSettings): Search settings including search preferences and filters
+
+        Returns:
+            list[dict[str, Any]]: List of documents with their search scores and complete metadata
+        """
+        where_clauses = []
+        params: list[str | int | bytes] = [query_text]
+
+        # Build the dynamic metadata field search expression
+        metadata_fields_expr = " || ' ' || ".join(
+            [
+                f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')"
+                for key in settings.metadata_keys  # type: ignore
+            ]
+        )
+
+        query = f"""
+            WITH
+            -- Metadata search scores
+            metadata_scores AS (
+                SELECT DISTINCT ON (v.document_id)
+                    v.document_id,
+                    d.metadata as doc_metadata,
+                    CASE WHEN $1 = '' THEN 0.0
+                    ELSE
+                        ts_rank_cd(
+                            setweight(to_tsvector('english', {metadata_fields_expr}), 'A'),
+                            websearch_to_tsquery('english', $1),
+                            32
+                        )
+                    END as metadata_rank
+                FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} v
+                LEFT JOIN {self._get_table_name('documents')} d ON v.document_id = d.id
+                WHERE v.metadata IS NOT NULL
+            ),
+            -- Body search scores
+            body_scores AS (
+                SELECT
+                    document_id,
+                    AVG(
+                        ts_rank_cd(
+                            setweight(to_tsvector('english', COALESCE(text, '')), 'B'),
+                            websearch_to_tsquery('english', $1),
+                            32
+                        )
+                    ) as body_rank
+                FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+                WHERE $1 != ''
+                {f"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if settings.search_over_body else ""}
+                GROUP BY document_id
+            ),
+            -- Combined scores with document metadata
+            combined_scores AS (
+                SELECT
+                    COALESCE(m.document_id, b.document_id) as document_id,
+                    m.doc_metadata as metadata,
+                    COALESCE(m.metadata_rank, 0) as debug_metadata_rank,
+                    COALESCE(b.body_rank, 0) as debug_body_rank,
+                    CASE
+                        WHEN {str(settings.search_over_metadata).lower()} AND {str(settings.search_over_body).lower()} THEN
+                            COALESCE(m.metadata_rank, 0) * {settings.metadata_weight} + COALESCE(b.body_rank, 0) * {settings.title_weight}
+                        WHEN {str(settings.search_over_metadata).lower()} THEN
+                            COALESCE(m.metadata_rank, 0)
+                        WHEN {str(settings.search_over_body).lower()} THEN
+                            COALESCE(b.body_rank, 0)
+                        ELSE 0
+                    END as rank
+                FROM metadata_scores m
+                FULL OUTER JOIN body_scores b ON m.document_id = b.document_id
+                WHERE (
+                    ($1 = '') OR
+                    ({str(settings.search_over_metadata).lower()} AND m.metadata_rank > 0) OR
+                    ({str(settings.search_over_body).lower()} AND b.body_rank > 0)
+                )
+        """
+
+        # Add any additional filters
+        if settings.filters:
+            filter_clause = self._build_filters(settings.filters, params)
+            where_clauses.append(filter_clause)
+
+        if where_clauses:
+            query += f" AND {' AND '.join(where_clauses)}"
+
+        query += """
+            )
+            SELECT
+                document_id,
+                metadata,
+                rank as score,
+                debug_metadata_rank,
+                debug_body_rank
+            FROM combined_scores
+            WHERE rank > 0
+            ORDER BY rank DESC
+            OFFSET ${offset_param} LIMIT ${limit_param}
+        """.format(
+            offset_param=len(params) + 1,
+            limit_param=len(params) + 2,
+        )
+
+        # Add offset and limit to params
+        params.extend([settings.offset, settings.limit])
+
+        # Execute query
+        results = await self.connection_manager.fetch_query(query, params)
+
+        # Format results with complete document metadata
+        return [
+            {
+                "document_id": str(r["document_id"]),
+                "metadata": (
+                    json.loads(r["metadata"])
+                    if isinstance(r["metadata"], str)
+                    else r["metadata"]
+                ),
+                "score": float(r["score"]),
+                "debug_metadata_rank": float(r["debug_metadata_rank"]),
+                "debug_body_rank": float(r["debug_body_rank"]),
+            }
+            for r in results
+        ]
+
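+    # Usage sketch (hypothetical, not part of this commit; the SearchSettings
+    # fields are inferred from their use above): searching document metadata
+    # and body text for a phrase.
+    #
+    #     settings = SearchSettings(
+    #         search_over_metadata=True,
+    #         search_over_body=True,
+    #         metadata_keys=["title", "description"],
+    #         offset=0,
+    #         limit=10,
+    #     )
+    #     hits = await chunks_handler.search_documents("quarterly revenue", settings)
+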
+    def _get_index_options(
+        self,
+        method: IndexMethod,
+        index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW],
+    ) -> str:
+        if method == IndexMethod.ivfflat:
+            if isinstance(index_arguments, IndexArgsIVFFlat):
+                return f"WITH (lists={index_arguments.n_lists})"
+            else:
+                # Default value if no arguments provided
+                return "WITH (lists=100)"
+        elif method == IndexMethod.hnsw:
+            if isinstance(index_arguments, IndexArgsHNSW):
+                return f"WITH (m={index_arguments.m}, ef_construction={index_arguments.ef_construction})"
+            else:
+                # Default values if no arguments provided
+                return "WITH (m=16, ef_construction=64)"
+        else:
+            return ""  # No options for other methods

+ 471 - 0
core/database/collections.py

@@ -0,0 +1,471 @@
+import json
+import logging
+from typing import Any, Optional
+from uuid import UUID, uuid4
+
+from asyncpg.exceptions import UniqueViolationError
+from fastapi import HTTPException
+
+from core.base import (
+    Handler,
+    DatabaseConfig,
+    KGExtractionStatus,
+    R2RException,
+    generate_default_user_collection_id,
+)
+from core.base.abstractions import (
+    DocumentResponse,
+    DocumentType,
+    IngestionStatus,
+)
+from core.base.api.models import CollectionResponse
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger()
+
+
+class PostgresCollectionsHandler(Handler):
+    TABLE_NAME = "collections"
+
+    def __init__(
+        self,
+        project_name: str,
+        connection_manager: PostgresConnectionManager,
+        config: DatabaseConfig,
+    ):
+        self.config = config
+        super().__init__(project_name, connection_manager)
+
+    async def create_tables(self) -> None:
+        query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)} (
+            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+            owner_id UUID,
+            name TEXT NOT NULL,
+            description TEXT,
+            graph_sync_status TEXT DEFAULT 'pending',
+            graph_cluster_status TEXT DEFAULT 'pending',
+            created_at TIMESTAMPTZ DEFAULT NOW(),
+            updated_at TIMESTAMPTZ DEFAULT NOW(),
+            user_count INT DEFAULT 0,
+            document_count INT DEFAULT 0
+        );
+        """
+        await self.connection_manager.execute_query(query)
+
+    async def collection_exists(self, collection_id: UUID) -> bool:
+        """Check if a collection exists."""
+        query = f"""
+            SELECT 1 FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+            WHERE id = $1
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [collection_id]
+        )
+        return result is not None
+
+    async def create_collection(
+        self,
+        owner_id: UUID,
+        name: Optional[str] = None,
+        description: str = "",
+        collection_id: Optional[UUID] = None,
+    ) -> CollectionResponse:
+
+        if not name and not collection_id:
+            name = self.config.default_collection_name
+            collection_id = generate_default_user_collection_id(owner_id)
+
+        query = f"""
+            INSERT INTO {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+            (id, owner_id, name, description)
+            VALUES ($1, $2, $3, $4)
+            RETURNING id, owner_id, name, description, graph_sync_status, graph_cluster_status, created_at, updated_at
+        """
+        params = [
+            collection_id or uuid4(),
+            owner_id,
+            name,
+            description,
+        ]
+
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query=query,
+                params=params,
+            )
+            if not result:
+                raise R2RException(
+                    status_code=404, message="Collection not found"
+                )
+
+            return CollectionResponse(
+                id=result["id"],
+                owner_id=result["owner_id"],
+                name=result["name"],
+                description=result["description"],
+                graph_cluster_status=result["graph_cluster_status"],
+                graph_sync_status=result["graph_sync_status"],
+                created_at=result["created_at"],
+                updated_at=result["updated_at"],
+                user_count=0,
+                document_count=0,
+            )
+        except UniqueViolationError:
+            raise R2RException(
+                message="Collection with this ID already exists",
+                status_code=409,
+            )
+
+    async def update_collection(
+        self,
+        collection_id: UUID,
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+    ) -> CollectionResponse:
+        """Update an existing collection."""
+        if not await self.collection_exists(collection_id):
+            raise R2RException(status_code=404, message="Collection not found")
+
+        update_fields = []
+        params: list = []
+        param_index = 1
+
+        if name is not None:
+            update_fields.append(f"name = ${param_index}")
+            params.append(name)
+            param_index += 1
+
+        if description is not None:
+            update_fields.append(f"description = ${param_index}")
+            params.append(description)
+            param_index += 1
+
+        if not update_fields:
+            raise R2RException(status_code=400, message="No fields to update")
+
+        update_fields.append("updated_at = NOW()")
+        params.append(collection_id)
+
+        query = f"""
+            WITH updated_collection AS (
+                UPDATE {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+                SET {', '.join(update_fields)}
+                WHERE id = ${param_index}
+                RETURNING id, owner_id, name, description, graph_sync_status, graph_cluster_status, created_at, updated_at
+            )
+            SELECT
+                uc.*,
+                COUNT(DISTINCT u.id) FILTER (WHERE u.id IS NOT NULL) as user_count,
+                COUNT(DISTINCT d.id) FILTER (WHERE d.id IS NOT NULL) as document_count
+            FROM updated_collection uc
+            LEFT JOIN {self._get_table_name('users')} u ON uc.id = ANY(u.collection_ids)
+            LEFT JOIN {self._get_table_name('documents')} d ON uc.id = ANY(d.collection_ids)
+            GROUP BY uc.id, uc.owner_id, uc.name, uc.description, uc.graph_sync_status, uc.graph_cluster_status, uc.created_at, uc.updated_at
+        """
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query, params
+            )
+            if not result:
+                raise R2RException(
+                    status_code=404, message="Collection not found"
+                )
+
+            return CollectionResponse(
+                id=result["id"],
+                owner_id=result["owner_id"],
+                name=result["name"],
+                description=result["description"],
+                graph_sync_status=result["graph_sync_status"],
+                graph_cluster_status=result["graph_cluster_status"],
+                created_at=result["created_at"],
+                updated_at=result["updated_at"],
+                user_count=result["user_count"],
+                document_count=result["document_count"],
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while updating the collection: {e}",
+            )
+
+    async def delete_collection_relational(self, collection_id: UUID) -> None:
+        # Remove collection_id from users
+        user_update_query = f"""
+            UPDATE {self._get_table_name('users')}
+            SET collection_ids = array_remove(collection_ids, $1)
+            WHERE $1 = ANY(collection_ids)
+        """
+        await self.connection_manager.execute_query(
+            user_update_query, [collection_id]
+        )
+
+        # Remove collection_id from documents
+        document_update_query = f"""
+            WITH updated AS (
+                UPDATE {self._get_table_name('documents')}
+                SET collection_ids = array_remove(collection_ids, $1)
+                WHERE $1 = ANY(collection_ids)
+                RETURNING 1
+            )
+            SELECT COUNT(*) AS affected_rows FROM updated
+        """
+        await self.connection_manager.fetchrow_query(
+            document_update_query, [collection_id]
+        )
+
+        # Delete the collection
+        delete_query = f"""
+            DELETE FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+            WHERE id = $1
+            RETURNING id
+        """
+        deleted = await self.connection_manager.fetchrow_query(
+            delete_query, [collection_id]
+        )
+
+        if not deleted:
+            raise R2RException(status_code=404, message="Collection not found")
+
+    async def documents_in_collection(
+        self, collection_id: UUID, offset: int, limit: int
+    ) -> dict[str, list[DocumentResponse] | int]:
+        """
+        Get all documents in a specific collection with pagination.
+        Args:
+            collection_id (UUID): The ID of the collection to get documents from.
+            offset (int): The number of documents to skip.
+            limit (int): The maximum number of documents to return.
+        Returns:
+            List[DocumentResponse]: A list of DocumentResponse objects representing the documents in the collection.
+        Raises:
+            R2RException: If the collection doesn't exist.
+        """
+        if not await self.collection_exists(collection_id):
+            raise R2RException(status_code=404, message="Collection not found")
+        query = f"""
+            SELECT d.id, d.owner_id, d.type, d.metadata, d.title, d.version,
+                d.size_in_bytes, d.ingestion_status, d.extraction_status, d.created_at, d.updated_at, d.summary,
+                COUNT(*) OVER() AS total_entries
+            FROM {self._get_table_name('documents')} d
+            WHERE $1 = ANY(d.collection_ids)
+            ORDER BY d.created_at DESC
+            OFFSET $2
+        """
+
+        conditions = [collection_id, offset]
+        if limit != -1:
+            query += " LIMIT $3"
+            conditions.append(limit)
+
+        results = await self.connection_manager.fetch_query(query, conditions)
+        documents = [
+            DocumentResponse(
+                id=row["id"],
+                collection_ids=[collection_id],
+                owner_id=row["owner_id"],
+                document_type=DocumentType(row["type"]),
+                metadata=json.loads(row["metadata"]),
+                title=row["title"],
+                version=row["version"],
+                size_in_bytes=row["size_in_bytes"],
+                ingestion_status=IngestionStatus(row["ingestion_status"]),
+                extraction_status=KGExtractionStatus(row["extraction_status"]),
+                created_at=row["created_at"],
+                updated_at=row["updated_at"],
+                summary=row["summary"],
+            )
+            for row in results
+        ]
+        total_entries = results[0]["total_entries"] if results else 0
+
+        return {"results": documents, "total_entries": total_entries}
+
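+    # Usage sketch (hypothetical, not part of this commit): listing the first
+    # page of documents in a collection through this handler.
+    #
+    #     page = await collections_handler.documents_in_collection(
+    #         collection_id=collection_id, offset=0, limit=25
+    #     )
+    #     titles = [doc.title for doc in page["results"]]
+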
+    async def get_collections_overview(
+        self,
+        offset: int,
+        limit: int,
+        filter_user_ids: Optional[list[UUID]] = None,
+        filter_document_ids: Optional[list[UUID]] = None,
+        filter_collection_ids: Optional[list[UUID]] = None,
+    ) -> dict[str, list[CollectionResponse] | int]:
+        conditions = []
+        params: list[Any] = []
+        param_index = 1
+
+        if filter_user_ids:
+            conditions.append(
+                f"""
+                c.id IN (
+                    SELECT unnest(collection_ids)
+                    FROM {self.project_name}.users
+                    WHERE id = ANY(${param_index})
+                )
+            """
+            )
+            params.append(filter_user_ids)
+            param_index += 1
+
+        if filter_document_ids:
+            conditions.append(
+                f"""
+                c.id IN (
+                    SELECT unnest(collection_ids)
+                    FROM {self.project_name}.documents
+                    WHERE id = ANY(${param_index})
+                )
+            """
+            )
+            params.append(filter_document_ids)
+            param_index += 1
+
+        if filter_collection_ids:
+            conditions.append(f"c.id = ANY(${param_index})")
+            params.append(filter_collection_ids)
+            param_index += 1
+
+        where_clause = (
+            f"WHERE {' AND '.join(conditions)}" if conditions else ""
+        )
+
+        query = f"""
+            SELECT
+                c.*,
+                COUNT(*) OVER() as total_entries
+            FROM {self.project_name}.collections c
+            {where_clause}
+            ORDER BY created_at DESC
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            query += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        try:
+            results = await self.connection_manager.fetch_query(query, params)
+
+            if not results:
+                return {"results": [], "total_entries": 0}
+
+            total_entries = results[0]["total_entries"] if results else 0
+
+            collections = [CollectionResponse(**row) for row in results]
+
+            return {"results": collections, "total_entries": total_entries}
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while fetching collections: {e}",
+            )
+
+    async def assign_document_to_collection_relational(
+        self,
+        document_id: UUID,
+        collection_id: UUID,
+    ) -> UUID:
+        """
+        Assign a document to a collection.
+
+        Args:
+            document_id (UUID): The ID of the document to assign.
+            collection_id (UUID): The ID of the collection to assign the document to.
+
+        Raises:
+            R2RException: If the collection doesn't exist, if the document is not found,
+                        or if there's a database error.
+        """
+        try:
+            if not await self.collection_exists(collection_id):
+                raise R2RException(
+                    status_code=404, message="Collection not found"
+                )
+
+            # First, check if the document exists
+            document_check_query = f"""
+                SELECT 1 FROM {self._get_table_name('documents')}
+                WHERE id = $1
+            """
+            document_exists = await self.connection_manager.fetchrow_query(
+                document_check_query, [document_id]
+            )
+
+            if not document_exists:
+                raise R2RException(
+                    status_code=404, message="Document not found"
+                )
+
+            # If document exists, proceed with the assignment
+            assign_query = f"""
+                UPDATE {self._get_table_name('documents')}
+                SET collection_ids = array_append(collection_ids, $1)
+                WHERE id = $2 AND NOT ($1 = ANY(collection_ids))
+                RETURNING id
+            """
+            result = await self.connection_manager.fetchrow_query(
+                assign_query, [collection_id, document_id]
+            )
+
+            if not result:
+                # Document exists but was already assigned to the collection
+                raise R2RException(
+                    status_code=409,
+                    message="Document is already assigned to the collection",
+                )
+
+            update_collection_query = f"""
+                UPDATE {self._get_table_name('collections')}
+                SET document_count = document_count + 1
+                WHERE id = $1
+            """
+            await self.connection_manager.execute_query(
+                query=update_collection_query, params=[collection_id]
+            )
+
+            return collection_id
+
+        except R2RException:
+            # Re-raise R2RExceptions as they are already handled
+            raise
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error '{e}' occurred while assigning the document to the collection",
+            )
+
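+    # Usage sketch (hypothetical, not part of this commit): the relational
+    # assignment here is presumably paired with the vector-side update in
+    # PostgresChunksHandler so both tables stay in sync.
+    #
+    #     await collections_handler.assign_document_to_collection_relational(
+    #         document_id=document_id, collection_id=collection_id
+    #     )
+    #     await chunks_handler.assign_document_chunks_to_collection(
+    #         document_id=document_id, collection_id=collection_id
+    #     )
+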
+    async def remove_document_from_collection_relational(
+        self, document_id: UUID, collection_id: UUID
+    ) -> None:
+        """
+        Remove a document from a collection.
+
+        Args:
+            document_id (UUID): The ID of the document to remove.
+            collection_id (UUID): The ID of the collection to remove the document from.
+
+        Raises:
+            R2RException: If the collection doesn't exist or if the document is not in the collection.
+        """
+        if not await self.collection_exists(collection_id):
+            raise R2RException(status_code=404, message="Collection not found")
+
+        query = f"""
+            UPDATE {self._get_table_name('documents')}
+            SET collection_ids = array_remove(collection_ids, $1)
+            WHERE id = $2 AND $1 = ANY(collection_ids)
+            RETURNING id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [collection_id, document_id]
+        )
+
+        if not result:
+            raise R2RException(
+                status_code=404,
+                message="Document not found in the specified collection",
+            )

+ 376 - 0
core/database/conversations.py

@@ -0,0 +1,376 @@
+import json
+from typing import Any, Dict, List, Optional
+from uuid import UUID, uuid4
+
+from core.base import Handler, Message, R2RException
+from shared.api.models.management.responses import (
+    ConversationResponse,
+    MessageResponse,
+)
+
+from .base import PostgresConnectionManager
+
+
+class PostgresConversationsHandler(Handler):
+    def __init__(
+        self, project_name: str, connection_manager: PostgresConnectionManager
+    ):
+        self.project_name = project_name
+        self.connection_manager = connection_manager
+
+    async def create_tables(self):
+        # uuid_generate_v4() requires the "uuid-ossp" extension.
+        # Depending on your environment, you may need a separate call:
+        # await self.connection_manager.execute_query("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";")
+
+        create_conversations_query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name("conversations")} (
+            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+            user_id UUID,
+            created_at TIMESTAMPTZ DEFAULT NOW(),
+            name TEXT
+        );
+        """
+
+        create_messages_query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name("messages")} (
+            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+            conversation_id UUID NOT NULL,
+            parent_id UUID,
+            content JSONB,
+            metadata JSONB,
+            created_at TIMESTAMPTZ DEFAULT NOW(),
+            FOREIGN KEY (conversation_id) REFERENCES {self._get_table_name("conversations")}(id),
+            FOREIGN KEY (parent_id) REFERENCES {self._get_table_name("messages")}(id)
+        );
+        """
+        await self.connection_manager.execute_query(create_conversations_query)
+        await self.connection_manager.execute_query(create_messages_query)
+
+    async def create_conversation(
+        self, user_id: Optional[UUID] = None, name: Optional[str] = None
+    ) -> ConversationResponse:
+        query = f"""
+            INSERT INTO {self._get_table_name("conversations")} (user_id, name)
+            VALUES ($1, $2)
+            RETURNING id, extract(epoch from created_at) as created_at_epoch
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [user_id, name]
+        )
+
+        if not result:
+            raise R2RException(
+                status_code=500, message="Failed to create conversation."
+            )
+
+        return ConversationResponse(
+            id=str(result["id"]),
+            created_at=result["created_at_epoch"],
+        )
+
+    async def verify_conversation_access(
+        self, conversation_id: UUID, user_id: UUID
+    ) -> bool:
+        query = f"""
+            SELECT 1 FROM {self._get_table_name("conversations")}
+            WHERE id = $1 AND (user_id IS NULL OR user_id = $2)
+        """
+        row = await self.connection_manager.fetchrow_query(
+            query, [conversation_id, user_id]
+        )
+        return row is not None
+
+    async def get_conversations_overview(
+        self,
+        offset: int,
+        limit: int,
+        user_ids: Optional[UUID | List[UUID]] = None,
+        conversation_ids: Optional[List[UUID]] = None,
+    ) -> Dict[str, Any]:
+        # Construct conditions
+        conditions = []
+        params = []
+        param_index = 1
+
+        if user_ids is not None:
+            if isinstance(user_ids, UUID):
+                conditions.append(f"user_id = ${param_index}")
+                params.append(user_ids)
+                param_index += 1
+            else:
+                # user_ids is a list of UUIDs
+                placeholders = ", ".join(
+                    f"${i+param_index}" for i in range(len(user_ids))
+                )
+                conditions.append(
+                    f"user_id = ANY(ARRAY[{placeholders}]::uuid[])"
+                )
+                params.extend(user_ids)
+                param_index += len(user_ids)
+
+        if conversation_ids:
+            placeholders = ", ".join(
+                f"${i+param_index}" for i in range(len(conversation_ids))
+            )
+            conditions.append(f"id = ANY(ARRAY[{placeholders}]::uuid[])")
+            params.extend(conversation_ids)
+            param_index += len(conversation_ids)
+
+        where_clause = ""
+        if conditions:
+            where_clause = "WHERE " + " AND ".join(conditions)
+
+        limit_clause = ""
+        if limit != -1:
+            limit_clause = f"LIMIT ${param_index}"
+            params.append(limit)
+            param_index += 1
+
+        offset_clause = f"OFFSET ${param_index}"
+        params.append(offset)
+
+        query = f"""
+            WITH conversation_overview AS (
+                SELECT id, extract(epoch from created_at) as created_at_epoch, user_id, name
+                FROM {self._get_table_name("conversations")}
+                {where_clause}
+            ),
+            counted_overview AS (
+                SELECT *,
+                       COUNT(*) OVER() AS total_entries
+                FROM conversation_overview
+            )
+            SELECT * FROM counted_overview
+            ORDER BY created_at_epoch DESC
+            {limit_clause} {offset_clause}
+        """
+        results = await self.connection_manager.fetch_query(query, params)
+
+        if not results:
+            return {"results": [], "total_entries": 0}
+
+        total_entries = results[0]["total_entries"]
+        conversations = [
+            {
+                "id": str(row["id"]),
+                "created_at": row["created_at_epoch"],
+                "user_id": str(row["user_id"]) if row["user_id"] else None,
+                "name": row["name"] or None,
+            }
+            for row in results
+        ]
+
+        return {"results": conversations, "total_entries": total_entries}
+
+    async def add_message(
+        self,
+        conversation_id: UUID,
+        content: Message,
+        parent_id: Optional[UUID] = None,
+        metadata: Optional[dict] = None,
+    ) -> MessageResponse:
+        # Check if conversation exists
+        conv_check_query = f"""
+            SELECT 1 FROM {self._get_table_name("conversations")}
+            WHERE id = $1
+        """
+        conv_row = await self.connection_manager.fetchrow_query(
+            conv_check_query, [conversation_id]
+        )
+        if not conv_row:
+            raise R2RException(
+                status_code=404,
+                message=f"Conversation {conversation_id} not found.",
+            )
+
+        # Check parent message if provided
+        if parent_id:
+            parent_check_query = f"""
+                SELECT 1 FROM {self._get_table_name("messages")}
+                WHERE id = $1 AND conversation_id = $2
+            """
+            parent_row = await self.connection_manager.fetchrow_query(
+                parent_check_query, [parent_id, conversation_id]
+            )
+            if not parent_row:
+                raise R2RException(
+                    status_code=404,
+                    message=f"Parent message {parent_id} not found in conversation {conversation_id}.",
+                )
+
+        message_id = uuid4()
+        content_str = json.dumps(content.model_dump())
+        metadata_str = json.dumps(metadata or {})
+
+        query = f"""
+            INSERT INTO {self._get_table_name("messages")}
+            (id, conversation_id, parent_id, content, created_at, metadata)
+            VALUES ($1, $2, $3, $4::jsonb, NOW(), $5::jsonb)
+            RETURNING id
+        """
+        inserted = await self.connection_manager.fetchrow_query(
+            query,
+            [
+                message_id,
+                conversation_id,
+                parent_id,
+                content_str,
+                metadata_str,
+            ],
+        )
+        if not inserted:
+            raise R2RException(
+                status_code=500, message="Failed to insert message."
+            )
+
+        return MessageResponse(id=str(message_id), message=content)
+
+    async def edit_message(
+        self,
+        message_id: UUID,
+        new_content: str,
+        additional_metadata: dict = {},
+    ) -> Dict[str, Any]:
+        # Get the original message
+        query = f"""
+            SELECT conversation_id, parent_id, content, metadata
+            FROM {self._get_table_name("messages")}
+            WHERE id = $1
+        """
+        row = await self.connection_manager.fetchrow_query(query, [message_id])
+        if not row:
+            raise R2RException(
+                status_code=404, message=f"Message {message_id} not found."
+            )
+
+        old_content = json.loads(row["content"])
+        old_metadata = json.loads(row["metadata"])
+
+        # Update the content
+        old_message = Message(**old_content)
+        edited_message = Message(
+            role=old_message.role,
+            content=new_content,
+            name=old_message.name,
+            function_call=old_message.function_call,
+            tool_calls=old_message.tool_calls,
+        )
+
+        # Merge metadata and mark edited
+        new_metadata = {**old_metadata, **additional_metadata, "edited": True}
+
+        # Instead of branching, we'll simply replace the message content and metadata:
+        # NOTE: If you prefer versioning or forking behavior, you'd add a new message.
+        # For simplicity, we just edit the existing message.
+        update_query = f"""
+            UPDATE {self._get_table_name("messages")}
+            SET content = $1::jsonb, metadata = $2::jsonb, created_at = NOW()
+            WHERE id = $3
+            RETURNING id
+        """
+        updated = await self.connection_manager.fetchrow_query(
+            update_query,
+            [
+                json.dumps(edited_message.model_dump()),
+                json.dumps(new_metadata),
+                message_id,
+            ],
+        )
+        if not updated:
+            raise R2RException(
+                status_code=500, message="Failed to update message."
+            )
+
+        return {
+            "id": str(message_id),
+            "message": edited_message,
+            "metadata": new_metadata,
+        }
+
+    async def update_message_metadata(
+        self, message_id: UUID, metadata: dict
+    ) -> None:
+        # Fetch current metadata
+        query = f"""
+            SELECT metadata FROM {self._get_table_name("messages")}
+            WHERE id = $1
+        """
+        row = await self.connection_manager.fetchrow_query(query, [message_id])
+        if not row:
+            raise R2RException(
+                status_code=404, message=f"Message {message_id} not found."
+            )
+
+        current_metadata = json.loads(row["metadata"]) if row["metadata"] else {}
+        updated_metadata = {**current_metadata, **metadata}
+
+        update_query = f"""
+            UPDATE {self._get_table_name("messages")}
+            SET metadata = $1::jsonb
+            WHERE id = $2
+        """
+        await self.connection_manager.execute_query(
+            update_query, [json.dumps(updated_metadata), message_id]
+        )
+
+    async def get_conversation(
+        self, conversation_id: UUID
+    ) -> List[MessageResponse]:
+        # Check conversation
+        conv_query = f"SELECT extract(epoch from created_at) AS created_at_epoch FROM {self._get_table_name('conversations')} WHERE id = $1"
+        conv_row = await self.connection_manager.fetchrow_query(
+            conv_query, [conversation_id]
+        )
+        if not conv_row:
+            raise R2RException(
+                status_code=404,
+                message=f"Conversation {conversation_id} not found.",
+            )
+
+        # Retrieve messages in chronological order. Branching via parent_id is
+        # not resolved here; messages are simply ordered by created_at.
+        msg_query = f"""
+            SELECT id, content, metadata
+            FROM {self._get_table_name("messages")}
+            WHERE conversation_id = $1
+            ORDER BY created_at ASC
+        """
+        results = await self.connection_manager.fetch_query(
+            msg_query, [conversation_id]
+        )
+
+        return [
+            MessageResponse(
+                id=str(row["id"]),
+                message=Message(**json.loads(row["content"])),
+                metadata=json.loads(row["metadata"]),
+            )
+            for row in results
+        ]
+
+    async def delete_conversation(self, conversation_id: UUID):
+        # Check if conversation exists
+        conv_query = f"SELECT 1 FROM {self._get_table_name('conversations')} WHERE id = $1"
+        conv_row = await self.connection_manager.fetchrow_query(
+            conv_query, [conversation_id]
+        )
+        if not conv_row:
+            raise R2RException(
+                status_code=404,
+                message=f"Conversation {conversation_id} not found.",
+            )
+
+        # Delete all messages
+        del_messages_query = f"DELETE FROM {self._get_table_name('messages')} WHERE conversation_id = $1"
+        await self.connection_manager.execute_query(
+            del_messages_query, [conversation_id]
+        )
+
+        # Delete conversation
+        del_conv_query = f"DELETE FROM {self._get_table_name('conversations')} WHERE id = $1"
+        await self.connection_manager.execute_query(
+            del_conv_query, [conversation_id]
+        )

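A brief usage sketch for PostgresConversationsHandler as defined above. Only methods shown in this file are called; the connection-manager instance, the project name, and the Message(role=..., content=...) constructor arguments are assumptions for illustration.

from uuid import UUID

from core.base import Message
from core.database.conversations import PostgresConversationsHandler

async def conversation_roundtrip(connection_manager) -> None:
    # `connection_manager` is assumed to be an initialized PostgresConnectionManager.
    handler = PostgresConversationsHandler("my_project", connection_manager)
    await handler.create_tables()

    conv = await handler.create_conversation(name="demo")
    conv_id = UUID(conv.id)

    # Root message, then a reply that references it via parent_id.
    root = await handler.add_message(conv_id, Message(role="user", content="Hi"))
    await handler.add_message(
        conv_id,
        Message(role="assistant", content="Hello!"),
        parent_id=UUID(root.id),
    )

    # Messages come back ordered by created_at.
    messages = await handler.get_conversation(conv_id)
    assert len(messages) == 2

    await handler.delete_conversation(conv_id)
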
+ 933 - 0
core/database/documents.py

@@ -0,0 +1,933 @@
+import asyncio
+import copy
+import json
+import logging
+from typing import Any, Optional
+from uuid import UUID
+
+import asyncpg
+from fastapi import HTTPException
+
+from core.base import (
+    Handler,
+    DocumentResponse,
+    DocumentType,
+    IngestionStatus,
+    KGEnrichmentStatus,
+    KGExtractionStatus,
+    R2RException,
+    SearchSettings,
+)
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger()
+
+
+class PostgresDocumentsHandler(Handler):
+    TABLE_NAME = "documents"
+    COLUMN_VARS = [
+        "extraction_id",
+        "id",
+        "owner_id",
+        "collection_ids",
+    ]
+
+    def __init__(
+        self,
+        project_name: str,
+        connection_manager: PostgresConnectionManager,
+        dimension: int,
+    ):
+        self.dimension = dimension
+        super().__init__(project_name, connection_manager)
+
+    async def create_tables(self):
+        logger.info(
+            f"Creating table if it does not exist: {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
+        )
+        try:
+            query = f"""
+            CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} (
+                id UUID PRIMARY KEY,
+                collection_ids UUID[],
+                owner_id UUID,
+                type TEXT,
+                metadata JSONB,
+                title TEXT,
+                summary TEXT NULL,
+                summary_embedding vector({self.dimension}) NULL,
+                version TEXT,
+                size_in_bytes INT,
+                ingestion_status TEXT DEFAULT 'pending',
+                extraction_status TEXT DEFAULT 'pending',
+                created_at TIMESTAMPTZ DEFAULT NOW(),
+                updated_at TIMESTAMPTZ DEFAULT NOW(),
+                ingestion_attempt_number INT DEFAULT 0,
+                raw_tsvector tsvector GENERATED ALWAYS AS (
+                    setweight(to_tsvector('english', COALESCE(title, '')), 'A') ||
+                    setweight(to_tsvector('english', COALESCE(summary, '')), 'B') ||
+                    setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C')
+                ) STORED
+            );
+            CREATE INDEX IF NOT EXISTS idx_collection_ids_{self.project_name}
+            ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} USING GIN (collection_ids);
+
+            -- Full text search index
+            CREATE INDEX IF NOT EXISTS idx_doc_search_{self.project_name}
+            ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+            USING GIN (raw_tsvector);
+            """
+            await self.connection_manager.execute_query(query)
+        except Exception as e:
+            logger.warning(f"Error when creating document table: {e}")
+
+    async def upsert_documents_overview(
+        self, documents_overview: DocumentResponse | list[DocumentResponse]
+    ) -> None:
+        if isinstance(documents_overview, DocumentResponse):
+            documents_overview = [documents_overview]
+
+        # TODO: make this an arg
+        max_retries = 20
+        for document in documents_overview:
+            retries = 0
+            while retries < max_retries:
+                try:
+                    async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+                        async with conn.transaction():
+                            # Lock the row for update
+                            check_query = f"""
+                            SELECT ingestion_attempt_number, ingestion_status FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+                            WHERE id = $1 FOR UPDATE
+                            """
+                            existing_doc = await conn.fetchrow(
+                                check_query, document.id
+                            )
+
+                            db_entry = document.convert_to_db_entry()
+
+                            if existing_doc:
+                                db_version = existing_doc[
+                                    "ingestion_attempt_number"
+                                ]
+                                db_status = existing_doc["ingestion_status"]
+                                new_version = db_entry[
+                                    "ingestion_attempt_number"
+                                ]
+
+                                # Only increment version if status is changing to 'success' or if it's a new version
+                                if (
+                                    db_status != "success"
+                                    and db_entry["ingestion_status"]
+                                    == "success"
+                                ) or (new_version > db_version):
+                                    new_attempt_number = db_version + 1
+                                else:
+                                    new_attempt_number = db_version
+
+                                db_entry["ingestion_attempt_number"] = (
+                                    new_attempt_number
+                                )
+
+                                update_query = f"""
+                                UPDATE {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+                                SET collection_ids = $1, owner_id = $2, type = $3, metadata = $4,
+                                    title = $5, version = $6, size_in_bytes = $7, ingestion_status = $8,
+                                    extraction_status = $9, updated_at = $10, ingestion_attempt_number = $11,
+                                    summary = $12, summary_embedding = $13
+                                WHERE id = $14
+                                """
+
+                                await conn.execute(
+                                    update_query,
+                                    db_entry["collection_ids"],
+                                    db_entry["owner_id"],
+                                    db_entry["document_type"],
+                                    db_entry["metadata"],
+                                    db_entry["title"],
+                                    db_entry["version"],
+                                    db_entry["size_in_bytes"],
+                                    db_entry["ingestion_status"],
+                                    db_entry["extraction_status"],
+                                    db_entry["updated_at"],
+                                    new_attempt_number,
+                                    db_entry["summary"],
+                                    db_entry["summary_embedding"],
+                                    document.id,
+                                )
+                            else:
+                                insert_query = f"""
+                                INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+                                (id, collection_ids, owner_id, type, metadata, title, version,
+                                size_in_bytes, ingestion_status, extraction_status, created_at,
+                                updated_at, ingestion_attempt_number, summary, summary_embedding)
+                                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
+                                """
+                                await conn.execute(
+                                    insert_query,
+                                    db_entry["id"],
+                                    db_entry["collection_ids"],
+                                    db_entry["owner_id"],
+                                    db_entry["document_type"],
+                                    db_entry["metadata"],
+                                    db_entry["title"],
+                                    db_entry["version"],
+                                    db_entry["size_in_bytes"],
+                                    db_entry["ingestion_status"],
+                                    db_entry["extraction_status"],
+                                    db_entry["created_at"],
+                                    db_entry["updated_at"],
+                                    db_entry["ingestion_attempt_number"],
+                                    db_entry["summary"],
+                                    db_entry["summary_embedding"],
+                                )
+
+                    break  # Success, exit the retry loop
+                except (
+                    asyncpg.exceptions.UniqueViolationError,
+                    asyncpg.exceptions.DeadlockDetectedError,
+                ) as e:
+                    retries += 1
+                    if retries == max_retries:
+                        logger.error(
+                            f"Failed to update document {document.id} after {max_retries} attempts. Error: {str(e)}"
+                        )
+                        raise
+                    else:
+                        wait_time = 0.1 * (2**retries)  # Exponential backoff
+                        await asyncio.sleep(wait_time)
+
+    async def delete(
+        self, document_id: UUID, version: Optional[str] = None
+    ) -> None:
+        query = f"""
+        DELETE FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+        WHERE id = $1
+        """
+
+        params = [str(document_id)]
+
+        if version:
+            query += " AND version = $2"
+            params.append(version)
+
+        await self.connection_manager.execute_query(query=query, params=params)
+
+    async def _get_status_from_table(
+        self,
+        ids: list[UUID],
+        table_name: str,
+        status_type: str,
+        column_name: str,
+    ):
+        """
+        Get the workflow status for a given document or list of documents.
+
+        Args:
+            ids (list[UUID]): The document IDs.
+            table_name (str): The table name.
+            status_type (str): The type of status to retrieve.
+            column_name (str): The name of the ID column to match against.
+
+        Returns:
+            The workflow status for the given document or list of documents.
+        """
+        query = f"""
+            SELECT {status_type} FROM {self._get_table_name(table_name)}
+            WHERE {column_name} = ANY($1)
+        """
+        return [
+            row[status_type]
+            for row in await self.connection_manager.fetch_query(query, [ids])
+        ]
+
+    async def _get_ids_from_table(
+        self,
+        status: list[str],
+        table_name: str,
+        status_type: str,
+        collection_id: Optional[UUID] = None,
+    ):
+        """
+        Get the IDs from a given table.
+
+        Args:
+            status (list[str]): The list of statuses to retrieve.
+            table_name (str): The table name.
+            status_type (str): The type of status to retrieve.
+            collection_id (Optional[UUID]): The collection to filter by.
+        """
+        query = f"""
+            SELECT id FROM {self._get_table_name(table_name)}
+            WHERE {status_type} = ANY($1) and $2 = ANY(collection_ids)
+        """
+        records = await self.connection_manager.fetch_query(
+            query, [status, collection_id]
+        )
+        return [record["id"] for record in records]
+
+    async def _set_status_in_table(
+        self,
+        ids: list[UUID],
+        status: str,
+        table_name: str,
+        status_type: str,
+        column_name: str,
+    ):
+        """
+        Set the workflow status for a given document or list of documents.
+
+        Args:
+            ids (list[UUID]): The document IDs.
+            status (str): The status to set.
+            table_name (str): The table name.
+            status_type (str): The type of status to set.
+            column_name (str): The column name in the table to update.
+        """
+        query = f"""
+            UPDATE {self._get_table_name(table_name)}
+            SET {status_type} = $1
+            WHERE {column_name} = Any($2)
+        """
+        await self.connection_manager.execute_query(query, [status, ids])
+
+    def _get_status_model(self, status_type: str):
+        """
+        Get the status model for a given status type.
+
+        Args:
+            status_type (str): The type of status to retrieve.
+
+        Returns:
+            The status model for the given status type.
+        """
+        if status_type == "ingestion":
+            return IngestionStatus
+        elif status_type == "extraction_status":
+            return KGExtractionStatus
+        elif status_type == "graph_cluster_status":
+            return KGEnrichmentStatus
+        elif status_type == "graph_sync_status":
+            return KGEnrichmentStatus
+        else:
+            raise R2RException(
+                status_code=400, message=f"Invalid status type: {status_type}"
+            )
+
+    async def get_workflow_status(
+        self, id: UUID | list[UUID], status_type: str
+    ):
+        """
+        Get the workflow status for a given document or list of documents.
+
+        Args:
+            id (Union[UUID, list[UUID]]): The document ID or list of document IDs.
+            status_type (str): The type of status to retrieve.
+
+        Returns:
+            The workflow status for the given document or list of documents.
+        """
+
+        ids = [id] if isinstance(id, UUID) else id
+        out_model = self._get_status_model(status_type)
+        result = await self._get_status_from_table(
+            ids,
+            out_model.table_name(),
+            status_type,
+            out_model.id_column(),
+        )
+
+        result = [out_model[status.upper()] for status in result]
+        return result[0] if isinstance(id, UUID) else result
+
+    async def set_workflow_status(
+        self, id: UUID | list[UUID], status_type: str, status: str
+    ):
+        """
+        Set the workflow status for a given document or list of documents.
+
+        Args:
+            id (Union[UUID, list[UUID]]): The document ID or list of document IDs.
+            status_type (str): The type of status to set.
+            status (str): The status to set.
+        """
+        ids = [id] if isinstance(id, UUID) else id
+        out_model = self._get_status_model(status_type)
+
+        return await self._set_status_in_table(
+            ids,
+            status,
+            out_model.table_name(),
+            status_type,
+            out_model.id_column(),
+        )
+
+    async def get_document_ids_by_status(
+        self,
+        status_type: str,
+        status: str | list[str],
+        collection_id: Optional[UUID] = None,
+    ):
+        """
+        Get the IDs for a given status.
+
+        Args:
+            status_type (str): The type of status to retrieve.
+            status (Union[str, list[str]]): The status or list of statuses to retrieve.
+            collection_id (Optional[UUID]): The collection to filter by.
+        """
+
+        if isinstance(status, str):
+            status = [status]
+
+        out_model = self._get_status_model(status_type)
+        return await self._get_ids_from_table(
+            status, out_model.table_name(), status_type, collection_id
+        )
+
+    async def get_documents_overview(
+        self,
+        offset: int,
+        limit: int,
+        filter_user_ids: Optional[list[UUID]] = None,
+        filter_document_ids: Optional[list[UUID]] = None,
+        filter_collection_ids: Optional[list[UUID]] = None,
+    ) -> dict[str, Any]:
+        conditions = []
+        or_conditions = []
+        params: list[Any] = []
+        param_index = 1
+
+        # Handle document IDs with AND
+        if filter_document_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(filter_document_ids)
+            param_index += 1
+
+        # Handle user_ids and collection_ids with OR
+        if filter_user_ids:
+            or_conditions.append(f"owner_id = ANY(${param_index})")
+            params.append(filter_user_ids)
+            param_index += 1
+
+        if filter_collection_ids:
+            or_conditions.append(f"collection_ids && ${param_index}")
+            params.append(filter_collection_ids)
+            param_index += 1
+
+        base_query = f"""
+            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+        """
+
+        # Combine conditions with appropriate AND/OR logic
+        where_conditions = []
+        if conditions:
+            where_conditions.append("(" + " AND ".join(conditions) + ")")
+        if or_conditions:
+            where_conditions.append("(" + " OR ".join(or_conditions) + ")")
+
+        if where_conditions:
+            base_query += " WHERE " + " AND ".join(where_conditions)
+
+        # Construct the SELECT part of the query
+        select_fields = """
+            SELECT id, collection_ids, owner_id, type, metadata, title, version,
+                size_in_bytes, ingestion_status, extraction_status, created_at, updated_at,
+                summary, summary_embedding,
+                COUNT(*) OVER() AS total_entries
+        """
+
+        query = f"""
+            {select_fields}
+            {base_query}
+            ORDER BY created_at DESC
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            query += f" LIMIT ${param_index}"
+            params.append(limit)
+            param_index += 1
+
+        try:
+            results = await self.connection_manager.fetch_query(query, params)
+            total_entries = results[0]["total_entries"] if results else 0
+
+            documents = []
+            for row in results:
+                # Safely handle the embedding
+                embedding = None
+                if (
+                    "summary_embedding" in row
+                    and row["summary_embedding"] is not None
+                ):
+                    try:
+                        # Parse the vector string returned by Postgres
+                        embedding_str = row["summary_embedding"]
+                        if embedding_str.startswith(
+                            "["
+                        ) and embedding_str.endswith("]"):
+                            embedding = [
+                                float(x)
+                                for x in embedding_str[1:-1].split(",")
+                                if x
+                            ]
+                    except Exception as e:
+                        logger.warning(
+                            f"Failed to parse embedding for document {row['id']}: {e}"
+                        )
+
+                documents.append(
+                    DocumentResponse(
+                        id=row["id"],
+                        collection_ids=row["collection_ids"],
+                        owner_id=row["owner_id"],
+                        document_type=DocumentType(row["type"]),
+                        metadata=json.loads(row["metadata"]),
+                        title=row["title"],
+                        version=row["version"],
+                        size_in_bytes=row["size_in_bytes"],
+                        ingestion_status=IngestionStatus(
+                            row["ingestion_status"]
+                        ),
+                        extraction_status=KGExtractionStatus(
+                            row["extraction_status"]
+                        ),
+                        created_at=row["created_at"],
+                        updated_at=row["updated_at"],
+                        summary=row["summary"] if "summary" in row else None,
+                        summary_embedding=embedding,
+                    )
+                )
+            return {"results": documents, "total_entries": total_entries}
+        except Exception as e:
+            logger.error(f"Error in get_documents_overview: {str(e)}")
+            raise HTTPException(
+                status_code=500,
+                detail="Database query failed",
+            )
+
+    async def semantic_document_search(
+        self, query_embedding: list[float], search_settings: SearchSettings
+    ) -> list[DocumentResponse]:
+        """Search documents using semantic similarity with their summary embeddings."""
+
+        where_clauses = ["summary_embedding IS NOT NULL"]
+        params: list[str | int | bytes] = [str(query_embedding)]
+
+        # Handle filters
+        if search_settings.filters:
+            filter_clause = self._build_filters(
+                search_settings.filters, params
+            )
+            where_clauses.append(filter_clause)
+
+        where_clause = " AND ".join(where_clauses)
+
+        query = f"""
+        WITH document_scores AS (
+            SELECT
+                id,
+                collection_ids,
+                owner_id,
+                type,
+                metadata,
+                title,
+                version,
+                size_in_bytes,
+                ingestion_status,
+                extraction_status,
+                created_at,
+                updated_at,
+                summary,
+                summary_embedding,
+                (summary_embedding <=> $1::vector({self.dimension})) as semantic_distance
+            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+            WHERE {where_clause}
+            ORDER BY semantic_distance ASC
+            LIMIT ${len(params) + 1}
+            OFFSET ${len(params) + 2}
+        )
+        SELECT *,
+            1.0 - semantic_distance as semantic_score
+        FROM document_scores
+        """
+
+        params.extend([search_settings.limit, search_settings.offset])
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+        return [
+            DocumentResponse(
+                id=row["id"],
+                collection_ids=row["collection_ids"],
+                owner_id=row["owner_id"],
+                document_type=DocumentType(row["type"]),
+                metadata={
+                    **(
+                        json.loads(row["metadata"])
+                        if search_settings.include_metadatas
+                        else {}
+                    ),
+                    "search_score": float(row["semantic_score"]),
+                    "search_type": "semantic",
+                },
+                title=row["title"],
+                version=row["version"],
+                size_in_bytes=row["size_in_bytes"],
+                ingestion_status=IngestionStatus(row["ingestion_status"]),
+                extraction_status=KGExtractionStatus(row["extraction_status"]),
+                created_at=row["created_at"],
+                updated_at=row["updated_at"],
+                summary=row["summary"],
+                summary_embedding=[
+                    float(x)
+                    for x in row["summary_embedding"][1:-1].split(",")
+                    if x
+                ],
+            )
+            for row in results
+        ]
+
+    async def full_text_document_search(
+        self, query_text: str, search_settings: SearchSettings
+    ) -> list[DocumentResponse]:
+        """Enhanced full-text search using generated tsvector."""
+
+        where_clauses = ["raw_tsvector @@ websearch_to_tsquery('english', $1)"]
+        params: list[str | int | bytes] = [query_text]
+
+        # Handle filters
+        if search_settings.filters:
+            filter_clause = self._build_filters(
+                search_settings.filters, params
+            )
+            where_clauses.append(filter_clause)
+
+        where_clause = " AND ".join(where_clauses)
+
+        query = f"""
+        WITH document_scores AS (
+            SELECT
+                id,
+                collection_ids,
+                owner_id,
+                type,
+                metadata,
+                title,
+                version,
+                size_in_bytes,
+                ingestion_status,
+                extraction_status,
+                created_at,
+                updated_at,
+                summary,
+                summary_embedding,
+                ts_rank_cd(raw_tsvector, websearch_to_tsquery('english', $1), 32) as text_score
+            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+            WHERE {where_clause}
+            ORDER BY text_score DESC
+            LIMIT ${len(params) + 1}
+            OFFSET ${len(params) + 2}
+        )
+        SELECT * FROM document_scores
+        """
+
+        params.extend([search_settings.limit, search_settings.offset])
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+        return [
+            DocumentResponse(
+                id=row["id"],
+                collection_ids=row["collection_ids"],
+                owner_id=row["owner_id"],
+                document_type=DocumentType(row["type"]),
+                metadata={
+                    **(
+                        json.loads(row["metadata"])
+                        if search_settings.include_metadatas
+                        else {}
+                    ),
+                    "search_score": float(row["text_score"]),
+                    "search_type": "full_text",
+                },
+                title=row["title"],
+                version=row["version"],
+                size_in_bytes=row["size_in_bytes"],
+                ingestion_status=IngestionStatus(row["ingestion_status"]),
+                extraction_status=KGExtractionStatus(row["extraction_status"]),
+                created_at=row["created_at"],
+                updated_at=row["updated_at"],
+                summary=row["summary"],
+                summary_embedding=(
+                    [
+                        float(x)
+                        for x in row["summary_embedding"][1:-1].split(",")
+                        if x
+                    ]
+                    if row["summary_embedding"]
+                    else None
+                ),
+            )
+            for row in results
+        ]
+
+    async def hybrid_document_search(
+        self,
+        query_text: str,
+        query_embedding: list[float],
+        search_settings: SearchSettings,
+    ) -> list[DocumentResponse]:
+        """Search documents using both semantic and full-text search with RRF fusion."""
+
+        # Get more results than needed for better fusion
+        extended_settings = copy.deepcopy(search_settings)
+        extended_settings.limit = search_settings.limit * 3
+
+        # Get results from both search methods
+        semantic_results = await self.semantic_document_search(
+            query_embedding, extended_settings
+        )
+        full_text_results = await self.full_text_document_search(
+            query_text, extended_settings
+        )
+
+        # Combine results using RRF
+        doc_scores: dict[str, dict] = {}
+
+        # Process semantic results
+        for rank, result in enumerate(semantic_results, 1):
+            doc_id = str(result.id)
+            doc_scores[doc_id] = {
+                "semantic_rank": rank,
+                "full_text_rank": len(full_text_results)
+                + 1,  # Default rank if not found
+                "data": result,
+            }
+
+        # Process full-text results
+        for rank, result in enumerate(full_text_results, 1):
+            doc_id = str(result.id)
+            if doc_id in doc_scores:
+                doc_scores[doc_id]["full_text_rank"] = rank
+            else:
+                doc_scores[doc_id] = {
+                    "semantic_rank": len(semantic_results)
+                    + 1,  # Default rank if not found
+                    "full_text_rank": rank,
+                    "data": result,
+                }
+
+        # Calculate RRF scores using hybrid search settings
+        rrf_k = search_settings.hybrid_settings.rrf_k
+        semantic_weight = search_settings.hybrid_settings.semantic_weight
+        full_text_weight = search_settings.hybrid_settings.full_text_weight
+
+        for scores in doc_scores.values():
+            semantic_score = 1 / (rrf_k + scores["semantic_rank"])
+            full_text_score = 1 / (rrf_k + scores["full_text_rank"])
+
+            # Weighted combination
+            combined_score = (
+                semantic_score * semantic_weight
+                + full_text_score * full_text_weight
+            ) / (semantic_weight + full_text_weight)
+
+            scores["final_score"] = combined_score
+
+        # Sort by final score and apply offset/limit
+        sorted_results = sorted(
+            doc_scores.values(), key=lambda x: x["final_score"], reverse=True
+        )[
+            search_settings.offset : search_settings.offset
+            + search_settings.limit
+        ]
+
+        return [
+            DocumentResponse(
+                **{
+                    **result["data"].__dict__,
+                    "metadata": {
+                        **(
+                            result["data"].metadata
+                            if search_settings.include_metadatas
+                            else {}
+                        ),
+                        "search_score": result["final_score"],
+                        "semantic_rank": result["semantic_rank"],
+                        "full_text_rank": result["full_text_rank"],
+                        "search_type": "hybrid",
+                    },
+                }
+            )
+            for result in sorted_results
+        ]
+
+    async def search_documents(
+        self,
+        query_text: str,
+        query_embedding: Optional[list[float]] = None,
+        settings: Optional[SearchSettings] = None,
+    ) -> list[DocumentResponse]:
+        """
+        Main search method that delegates to the appropriate search method based on settings.
+        """
+        if settings is None:
+            settings = SearchSettings()
+
+        if (
+            settings.use_semantic_search and settings.use_fulltext_search
+        ) or settings.use_hybrid_search:
+            if query_embedding is None:
+                raise ValueError(
+                    "query_embedding is required for hybrid search"
+                )
+            return await self.hybrid_document_search(
+                query_text, query_embedding, settings
+            )
+        elif settings.use_semantic_search:
+            if query_embedding is None:
+                raise ValueError(
+                    "query_embedding is required for vector search"
+                )
+            return await self.semantic_document_search(
+                query_embedding, settings
+            )
+        else:
+            return await self.full_text_document_search(query_text, settings)
+
+    # TODO - Remove copy pasta, consolidate
+    def _build_filters(
+        self, filters: dict, parameters: list[str | int | bytes]
+    ) -> str:
+
+        def parse_condition(key: str, value: Any) -> str:  # type: ignore
+            # nonlocal parameters
+            if key in self.COLUMN_VARS:
+                # Handle column-based filters
+                if isinstance(value, dict):
+                    op, clause = next(iter(value.items()))
+                    if op == "$eq":
+                        parameters.append(clause)
+                        return f"{key} = ${len(parameters)}"
+                    elif op == "$ne":
+                        parameters.append(clause)
+                        return f"{key} != ${len(parameters)}"
+                    elif op == "$in":
+                        parameters.append(clause)
+                        return f"{key} = ANY(${len(parameters)})"
+                    elif op == "$nin":
+                        parameters.append(clause)
+                        return f"{key} != ALL(${len(parameters)})"
+                    elif op == "$overlap":
+                        parameters.append(clause)
+                        return f"{key} && ${len(parameters)}"
+                    elif op == "$contains":
+                        parameters.append(clause)
+                        return f"{key} @> ${len(parameters)}"
+                    elif op == "$any":
+                        if key == "collection_ids":
+                            parameters.append(f"%{clause}%")
+                            return f"array_to_string({key}, ',') LIKE ${len(parameters)}"
+                        parameters.append(clause)
+                        return f"${len(parameters)} = ANY({key})"
+                    else:
+                        raise ValueError(
+                            f"Unsupported operator for column {key}: {op}"
+                        )
+                else:
+                    # Handle direct equality
+                    parameters.append(value)
+                    return f"{key} = ${len(parameters)}"
+            else:
+                # Handle JSON-based filters
+                json_col = "metadata"
+                if key.startswith("metadata."):
+                    key = key.split("metadata.")[1]
+                if isinstance(value, dict):
+                    op, clause = next(iter(value.items()))
+                    if op not in (
+                        "$eq",
+                        "$ne",
+                        "$lt",
+                        "$lte",
+                        "$gt",
+                        "$gte",
+                        "$in",
+                        "$contains",
+                    ):
+                        raise ValueError("unknown operator")
+
+                    if op == "$eq":
+                        parameters.append(json.dumps(clause))
+                        return (
+                            f"{json_col}->'{key}' = ${len(parameters)}::jsonb"
+                        )
+                    elif op == "$ne":
+                        parameters.append(json.dumps(clause))
+                        return (
+                            f"{json_col}->'{key}' != ${len(parameters)}::jsonb"
+                        )
+                    elif op == "$lt":
+                        parameters.append(json.dumps(clause))
+                        return f"({json_col}->'{key}')::float < (${len(parameters)}::jsonb)::float"
+                    elif op == "$lte":
+                        parameters.append(json.dumps(clause))
+                        return f"({json_col}->'{key}')::float <= (${len(parameters)}::jsonb)::float"
+                    elif op == "$gt":
+                        parameters.append(json.dumps(clause))
+                        return f"({json_col}->'{key}')::float > (${len(parameters)}::jsonb)::float"
+                    elif op == "$gte":
+                        parameters.append(json.dumps(clause))
+                        return f"({json_col}->'{key}')::float >= (${len(parameters)}::jsonb)::float"
+                    elif op == "$in":
+                        if not isinstance(clause, list):
+                            raise ValueError(
+                                "argument to $in filter must be a list"
+                            )
+                        parameters.append(json.dumps(clause))
+                        return f"{json_col}->'{key}' = ANY(SELECT jsonb_array_elements(${len(parameters)}::jsonb))"
+                    elif op == "$contains":
+                        if not isinstance(clause, (int, str, float, list)):
+                            raise ValueError(
+                                "argument to $contains filter must be a scalar or array"
+                            )
+                        parameters.append(json.dumps(clause))
+                        return (
+                            f"{json_col}->'{key}' @> ${len(parameters)}::jsonb"
+                        )
+                else:
+                    # Handle direct equality against a metadata value
+                    parameters.append(json.dumps(value))
+                    return f"{json_col}->'{key}' = ${len(parameters)}::jsonb"
+
+        def parse_filter(filter_dict: dict) -> str:
+            filter_conditions = []
+            for key, value in filter_dict.items():
+                if key == "$and":
+                    and_conditions = [
+                        parse_filter(f) for f in value if f
+                    ]  # Skip empty dictionaries
+                    if and_conditions:
+                        filter_conditions.append(
+                            f"({' AND '.join(and_conditions)})"
+                        )
+                elif key == "$or":
+                    or_conditions = [
+                        parse_filter(f) for f in value if f
+                    ]  # Skip empty dictionaries
+                    if or_conditions:
+                        filter_conditions.append(
+                            f"({' OR '.join(or_conditions)})"
+                        )
+                else:
+                    filter_conditions.append(parse_condition(key, value))
+
+            # Check if there is only a single condition
+            if len(filter_conditions) == 1:
+                return filter_conditions[0]
+            else:
+                return " AND ".join(filter_conditions)
+
+        where_clause = parse_filter(filters)
+
+        return where_clause

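A hedged sketch of how search_documents above chooses between the semantic, full-text, and hybrid (RRF-fused) paths, together with the filter syntax accepted by _build_filters. The SearchSettings field names mirror the attributes referenced in the handler code; their constructor defaults (including hybrid_settings) are assumptions, and the UUIDs are placeholders.

from uuid import UUID

from core.base import SearchSettings
from core.database.documents import PostgresDocumentsHandler

async def document_search_example(connection_manager, query_embedding: list[float]):
    # `connection_manager` is assumed to be an initialized PostgresConnectionManager;
    # `dimension` must match the vector size used when the documents table was created.
    handler = PostgresDocumentsHandler(
        project_name="my_project",
        connection_manager=connection_manager,
        dimension=len(query_embedding),
    )

    settings = SearchSettings(
        use_semantic_search=True,
        use_fulltext_search=True,  # both flags enabled -> hybrid path with RRF fusion
        limit=10,
        offset=0,
        filters={
            # Column filter (owner_id is in COLUMN_VARS) AND a JSONB metadata filter.
            "$and": [
                {"owner_id": {"$eq": UUID("11111111-1111-1111-1111-111111111111")}},
                {"metadata.category": {"$eq": "report"}},
            ]
        },
    )

    # search_documents delegates to hybrid_document_search because both modes are
    # enabled; with only use_semantic_search it would call semantic_document_search,
    # otherwise full_text_document_search.
    return await handler.search_documents(
        query_text="quarterly revenue",
        query_embedding=query_embedding,
        settings=settings,
    )
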
+ 275 - 0
core/database/files.py

@@ -0,0 +1,275 @@
+import io
+import logging
+from typing import BinaryIO, Optional, Union
+from uuid import UUID
+
+import asyncpg
+from fastapi import HTTPException
+
+from core.base import Handler, R2RException
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger()
+
+
+class PostgresFilesHandler(Handler):
+    """PostgreSQL implementation of the FileHandler."""
+
+    TABLE_NAME = "files"
+
+    connection_manager: PostgresConnectionManager
+
+    async def create_tables(self) -> None:
+        """Create the necessary tables for file storage."""
+        query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresFilesHandler.TABLE_NAME)} (
+            document_id UUID PRIMARY KEY,
+            name TEXT NOT NULL,
+            oid OID NOT NULL,
+            size BIGINT NOT NULL,
+            type TEXT,
+            created_at TIMESTAMPTZ DEFAULT NOW(),
+            updated_at TIMESTAMPTZ DEFAULT NOW()
+        );
+
+        -- Create trigger for updating the updated_at timestamp
+        CREATE OR REPLACE FUNCTION {self.project_name}.update_files_updated_at()
+        RETURNS TRIGGER AS $$
+        BEGIN
+            NEW.updated_at = CURRENT_TIMESTAMP;
+            RETURN NEW;
+        END;
+        $$ LANGUAGE plpgsql;
+
+        DROP TRIGGER IF EXISTS update_files_updated_at
+        ON {self._get_table_name(PostgresFilesHandler.TABLE_NAME)};
+
+        CREATE TRIGGER update_files_updated_at
+            BEFORE UPDATE ON {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+            FOR EACH ROW
+            EXECUTE FUNCTION {self.project_name}.update_files_updated_at();
+        """
+        await self.connection_manager.execute_query(query)
+
+    async def upsert_file(
+        self,
+        document_id: UUID,
+        file_name: str,
+        file_oid: int,
+        file_size: int,
+        file_type: Optional[str] = None,
+    ) -> None:
+        """Add or update a file entry in storage."""
+        query = f"""
+        INSERT INTO {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+        (document_id, name, oid, size, type)
+        VALUES ($1, $2, $3, $4, $5)
+        ON CONFLICT (document_id) DO UPDATE SET
+            name = EXCLUDED.name,
+            oid = EXCLUDED.oid,
+            size = EXCLUDED.size,
+            type = EXCLUDED.type,
+            updated_at = NOW();
+        """
+        await self.connection_manager.execute_query(
+            query, [document_id, file_name, file_oid, file_size, file_type]
+        )
+
+    async def store_file(
+        self,
+        document_id: UUID,
+        file_name: str,
+        file_content: io.BytesIO,
+        file_type: Optional[str] = None,
+    ) -> None:
+        """Store a new file in the database."""
+        size = file_content.getbuffer().nbytes
+
+        async with (  # type: ignore
+            self.connection_manager.pool.get_connection() as conn
+        ):
+            async with conn.transaction():
+                oid = await conn.fetchval("SELECT lo_create(0)")
+                await self._write_lobject(conn, oid, file_content)
+                await self.upsert_file(
+                    document_id, file_name, oid, size, file_type
+                )
+
+    async def _write_lobject(
+        self, conn, oid: int, file_content: io.BytesIO
+    ) -> None:
+        """Write content to a large object."""
+        # 0x20000 is the libpq INV_WRITE flag
+        lobject = await conn.fetchval("SELECT lo_open($1, $2)", oid, 0x20000)
+
+        try:
+            chunk_size = 8192  # 8 KB chunks
+            while True:
+                if chunk := file_content.read(chunk_size):
+                    await conn.execute(
+                        "SELECT lowrite($1, $2)", lobject, chunk
+                    )
+                else:
+                    break
+
+            await conn.execute("SELECT lo_close($1)", lobject)
+
+        except Exception as e:
+            await conn.execute("SELECT lo_unlink($1)", oid)
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to write to large object: {e}",
+            )
+
+    async def retrieve_file(
+        self, document_id: UUID
+    ) -> Optional[tuple[str, BinaryIO, int]]:
+        """Retrieve a file from storage."""
+        query = f"""
+        SELECT name, oid, size
+        FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+        WHERE document_id = $1
+        """
+
+        result = await self.connection_manager.fetchrow_query(
+            query, [document_id]
+        )
+        if not result:
+            raise R2RException(
+                status_code=404,
+                message=f"File for document {document_id} not found",
+            )
+
+        file_name, oid, size = (
+            result["name"],
+            result["oid"],
+            result["size"],
+        )
+
+        async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+            file_content = await self._read_lobject(conn, oid)
+            return file_name, io.BytesIO(file_content), size
+
+    async def _read_lobject(self, conn, oid: int) -> bytes:
+        """Read content from a large object."""
+        file_data = io.BytesIO()
+        chunk_size = 8192
+        lobject = None
+
+        async with conn.transaction():
+            try:
+                lo_exists = await conn.fetchval(
+                    "SELECT EXISTS(SELECT 1 FROM pg_largeobject WHERE loid = $1)",
+                    oid,
+                )
+                if not lo_exists:
+                    raise R2RException(
+                        status_code=404,
+                        message=f"Large object {oid} not found.",
+                    )
+
+                # 262144 (0x40000) is the libpq INV_READ flag
+                lobject = await conn.fetchval(
+                    "SELECT lo_open($1, 262144)", oid
+                )
+
+                if lobject is None:
+                    raise R2RException(
+                        status_code=404,
+                        message=f"Failed to open large object {oid}.",
+                    )
+
+                while True:
+                    chunk = await conn.fetchval(
+                        "SELECT loread($1, $2)", lobject, chunk_size
+                    )
+                    if not chunk:
+                        break
+                    file_data.write(chunk)
+            except asyncpg.exceptions.UndefinedObjectError as e:
+                raise R2RException(
+                    status_code=404,
+                    message=f"Failed to read large object {oid}: {e}",
+                )
+            finally:
+                # Close only if the large object was actually opened
+                if lobject is not None:
+                    await conn.execute("SELECT lo_close($1)", lobject)
+
+        return file_data.getvalue()
+
+    async def delete_file(self, document_id: UUID) -> bool:
+        """Delete a file from storage."""
+        query = f"""
+        SELECT oid FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+        WHERE document_id = $1
+        """
+
+        async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
+            async with conn.transaction():
+                oid = await conn.fetchval(query, document_id)
+                if not oid:
+                    raise R2RException(
+                        status_code=404,
+                        message=f"File for document {document_id} not found",
+                    )
+
+                await self._delete_lobject(conn, oid)
+
+                delete_query = f"""
+                DELETE FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+                WHERE document_id = $1
+                """
+                await conn.execute(delete_query, document_id)
+
+        return True
+
+    async def _delete_lobject(self, conn, oid: int) -> None:
+        """Delete a large object."""
+        await conn.execute("SELECT lo_unlink($1)", oid)
+
+    async def get_files_overview(
+        self,
+        offset: int,
+        limit: int,
+        filter_document_ids: Optional[list[UUID]] = None,
+        filter_file_names: Optional[list[str]] = None,
+    ) -> list[dict]:
+        """Get an overview of stored files."""
+        conditions = []
+        params: list[Union[str, list[str], int]] = []
+        query = f"""
+        SELECT document_id, name, oid, size, type, created_at, updated_at
+        FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+        """
+
+        if filter_document_ids:
+            conditions.append(f"document_id = ANY(${len(params) + 1})")
+            params.append([str(doc_id) for doc_id in filter_document_ids])
+
+        if filter_file_names:
+            conditions.append(f"name = ANY(${len(params) + 1})")
+            params.append(filter_file_names)
+
+        if conditions:
+            query += " WHERE " + " AND ".join(conditions)
+
+        query += f" ORDER BY created_at DESC OFFSET ${len(params) + 1} LIMIT ${len(params) + 2}"
+        params.extend([offset, limit])
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+        if not results:
+            raise R2RException(
+                status_code=404,
+                message="No files found with the given filters",
+            )
+
+        return [
+            {
+                "document_id": row["document_id"],
+                "file_name": row["name"],
+                "file_oid": row["oid"],
+                "file_size": row["size"],
+                "file_type": row["type"],
+                "created_at": row["created_at"],
+                "updated_at": row["updated_at"],
+            }
+            for row in results
+        ]
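+
+    # Illustrative call pattern (editorial sketch; `files_handler` is a hypothetical,
+    # already-initialized PostgresFilesHandler):
+    #
+    #   overview = await files_handler.get_files_overview(
+    #       offset=0,
+    #       limit=10,
+    #       filter_file_names=["report.pdf"],
+    #   )
+    #   # -> [{"document_id": ..., "file_name": "report.pdf", "file_size": ..., ...}]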

+ 2790 - 0
core/database/graphs.py

@@ -0,0 +1,2790 @@
+import asyncio
+import datetime
+import json
+import logging
+import os
+import time
+from enum import Enum
+from typing import Any, AsyncGenerator, Optional, Tuple, Union
+from uuid import UUID
+
+import asyncpg
+import httpx
+from asyncpg.exceptions import UndefinedTableError, UniqueViolationError
+from fastapi import HTTPException
+
+from core.base.abstractions import (
+    Community,
+    Entity,
+    Graph,
+    KGCreationSettings,
+    KGEnrichmentSettings,
+    KGEnrichmentStatus,
+    KGEntityDeduplicationSettings,
+    KGExtractionStatus,
+    R2RException,
+    Relationship,
+    VectorQuantizationType,
+)
+from core.base.api.models import GraphResponse
+from core.base.providers.database import Handler
+from core.base.utils import (
+    _decorate_vector_type,
+    _get_str_estimation_output,
+    llm_cost_per_million_tokens,
+)
+
+from .base import PostgresConnectionManager
+from .collections import PostgresCollectionsHandler
+
+
+class StoreType(str, Enum):
+    GRAPHS = "graphs"
+    DOCUMENTS = "documents"
+
+
+logger = logging.getLogger()
+
+
+class PostgresEntitiesHandler(Handler):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.project_name: str = kwargs.get("project_name")  # type: ignore
+        self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
+        self.dimension: int = kwargs.get("dimension")  # type: ignore
+        self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore
+
+    def _get_table_name(self, table: str) -> str:
+        """Get the fully qualified table name."""
+        return f'"{self.project_name}"."{table}"'
+
+    def _get_entity_table_for_store(self, store_type: StoreType) -> str:
+        """Get the appropriate table name for the store type."""
+        if isinstance(store_type, StoreType):
+            store_type = store_type.value
+        return f"{store_type}_entities"
+
+    def _get_parent_constraint(self, store_type: StoreType) -> str:
+        """Get the appropriate foreign key constraint for the store type."""
+        if store_type == StoreType.GRAPHS:
+            return f"""
+                CONSTRAINT fk_graph
+                    FOREIGN KEY(parent_id)
+                    REFERENCES {self._get_table_name("graphs")}(id)
+                    ON DELETE CASCADE
+            """
+        else:
+            return f"""
+                CONSTRAINT fk_document
+                    FOREIGN KEY(parent_id)
+                    REFERENCES {self._get_table_name("documents")}(id)
+                    ON DELETE CASCADE
+            """
+
+    async def create_tables(self) -> None:
+        """Create separate tables for graph and document entities."""
+        vector_column_str = _decorate_vector_type(
+            f"({self.dimension})", self.quantization_type
+        )
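+        # Builds the pgvector column type for the configured embedding dimension
+        # and quantization (e.g. something like vector(<dimension>)).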
+
+        for store_type in StoreType:
+            table_name = self._get_entity_table_for_store(store_type)
+            parent_constraint = self._get_parent_constraint(store_type)
+
+            QUERY = f"""
+                CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
+                    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+                    name TEXT NOT NULL,
+                    category TEXT,
+                    description TEXT,
+                    parent_id UUID NOT NULL,
+                    description_embedding {vector_column_str},
+                    chunk_ids UUID[],
+                    metadata JSONB,
+                    created_at TIMESTAMPTZ DEFAULT NOW(),
+                    updated_at TIMESTAMPTZ DEFAULT NOW(),
+                    {parent_constraint}
+                );
+                CREATE INDEX IF NOT EXISTS {table_name}_name_idx
+                    ON {self._get_table_name(table_name)} (name);
+                CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
+                    ON {self._get_table_name(table_name)} (parent_id);
+                CREATE INDEX IF NOT EXISTS {table_name}_category_idx
+                    ON {self._get_table_name(table_name)} (category);
+            """
+            await self.connection_manager.execute_query(QUERY)
+
+    async def create(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        name: str,
+        category: Optional[str] = None,
+        description: Optional[str] = None,
+        description_embedding: Optional[list[float] | str] = None,
+        chunk_ids: Optional[list[UUID]] = None,
+        metadata: Optional[dict[str, Any] | str] = None,
+    ) -> Entity:
+        """Create a new entity in the specified store."""
+        table_name = self._get_entity_table_for_store(store_type)
+
+        if isinstance(metadata, str):
+            try:
+                metadata = json.loads(metadata)
+            except json.JSONDecodeError:
+                pass
+
+        if isinstance(description_embedding, list):
+            description_embedding = str(description_embedding)
+
+        query = f"""
+            INSERT INTO {self._get_table_name(table_name)}
+            (name, category, description, parent_id, description_embedding, chunk_ids, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6, $7)
+            RETURNING id, name, category, description, parent_id, chunk_ids, metadata
+        """
+
+        params = [
+            name,
+            category,
+            description,
+            parent_id,
+            description_embedding,
+            chunk_ids,
+            json.dumps(metadata) if metadata else None,
+        ]
+
+        result = await self.connection_manager.fetchrow_query(
+            query=query,
+            params=params,
+        )
+
+        return Entity(
+            id=result["id"],
+            name=result["name"],
+            category=result["category"],
+            description=result["description"],
+            parent_id=result["parent_id"],
+            chunk_ids=result["chunk_ids"],
+            metadata=result["metadata"],
+        )
+
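+    # Illustrative create() call (editorial sketch; `entities_handler` and
+    # `graph_id` are hypothetical, pre-existing objects):
+    #
+    #   entity = await entities_handler.create(
+    #       parent_id=graph_id,
+    #       store_type=StoreType.GRAPHS,
+    #       name="Marie Curie",
+    #       category="person",
+    #       description="Physicist and chemist",
+    #   )
+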
+    async def get(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        offset: int,
+        limit: int,
+        entity_ids: Optional[list[UUID]] = None,
+        entity_names: Optional[list[str]] = None,
+        include_embeddings: bool = False,
+    ):
+        """Retrieve entities from the specified store."""
+        table_name = self._get_entity_table_for_store(store_type)
+
+        conditions = ["parent_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if entity_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(entity_ids)
+            param_index += 1
+
+        if entity_names:
+            conditions.append(f"name = ANY(${param_index})")
+            params.append(entity_names)
+            param_index += 1
+
+        select_fields = """
+            id, name, category, description, parent_id,
+            chunk_ids, metadata
+        """
+        if include_embeddings:
+            select_fields += ", description_embedding"
+
+        COUNT_QUERY = f"""
+            SELECT COUNT(*)
+            FROM {self._get_table_name(table_name)}
+            WHERE {' AND '.join(conditions)}
+        """
+
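+        # At this point params holds only the filter values; offset and limit are
+        # appended after the count query below.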
+        count_params = params[: param_index - 1]
+        count = (
+            await self.connection_manager.fetch_query(
+                COUNT_QUERY, count_params
+            )
+        )[0]["count"]
+
+        QUERY = f"""
+            SELECT {select_fields}
+            FROM {self._get_table_name(table_name)}
+            WHERE {' AND '.join(conditions)}
+            ORDER BY created_at
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            QUERY += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        entities = []
+        for row in rows:
+            # Convert the Record to a dictionary
+            entity_dict = dict(row)
+
+            # Process metadata if it exists and is a string
+            if isinstance(entity_dict["metadata"], str):
+                try:
+                    entity_dict["metadata"] = json.loads(
+                        entity_dict["metadata"]
+                    )
+                except json.JSONDecodeError:
+                    pass
+
+            entities.append(Entity(**entity_dict))
+
+        return entities, count
+
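+    # Note (editorial): passing limit=-1 omits the LIMIT clause, so a call such as
+    #
+    #   entities, total = await entities_handler.get(
+    #       parent_id=graph_id, store_type=StoreType.GRAPHS, offset=0, limit=-1
+    #   )
+    #
+    # returns every entity under the parent (names above are hypothetical).
+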
+    async def update(
+        self,
+        entity_id: UUID,
+        store_type: StoreType,
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+        description_embedding: Optional[list[float] | str] = None,
+        category: Optional[str] = None,
+        metadata: Optional[dict] = None,
+    ) -> Entity:
+        """Update an entity in the specified store."""
+        table_name = self._get_entity_table_for_store(store_type)
+        update_fields = []
+        params: list[Any] = []
+        param_index = 1
+
+        if isinstance(metadata, str):
+            try:
+                metadata = json.loads(metadata)
+            except json.JSONDecodeError:
+                pass
+
+        if name is not None:
+            update_fields.append(f"name = ${param_index}")
+            params.append(name)
+            param_index += 1
+
+        if description is not None:
+            update_fields.append(f"description = ${param_index}")
+            params.append(description)
+            param_index += 1
+
+        if description_embedding is not None:
+            update_fields.append(f"description_embedding = ${param_index}")
+            params.append(description_embedding)
+            param_index += 1
+
+        if category is not None:
+            update_fields.append(f"category = ${param_index}")
+            params.append(category)
+            param_index += 1
+
+        if metadata is not None:
+            update_fields.append(f"metadata = ${param_index}")
+            params.append(json.dumps(metadata))
+            param_index += 1
+
+        if not update_fields:
+            raise R2RException(status_code=400, message="No fields to update")
+
+        update_fields.append("updated_at = NOW()")
+        params.append(entity_id)
+
+        query = f"""
+            UPDATE {self._get_table_name(table_name)}
+            SET {', '.join(update_fields)}
+            WHERE id = ${param_index}
+            RETURNING id, name, category, description, parent_id, chunk_ids, metadata
+        """
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query=query,
+                params=params,
+            )
+
+            return Entity(
+                id=result["id"],
+                name=result["name"],
+                category=result["category"],
+                description=result["description"],
+                parent_id=result["parent_id"],
+                chunk_ids=result["chunk_ids"],
+                metadata=result["metadata"],
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while updating the entity: {e}",
+            )
+
+    async def delete(
+        self,
+        parent_id: UUID,
+        entity_ids: Optional[list[UUID]] = None,
+        store_type: StoreType = StoreType.GRAPHS,
+    ) -> None:
+        """
+        Delete entities from the specified store.
+        If entity_ids is not provided, deletes all entities for the given parent_id.
+
+        Args:
+            parent_id (UUID): Parent ID (collection_id or document_id)
+            entity_ids (Optional[list[UUID]]): Specific entity IDs to delete. If None, deletes all entities for parent_id
+            store_type (StoreType): Type of store (graph or document)
+
+        Returns:
+            None
+
+        Raises:
+            R2RException: If specific entities were requested but not all found
+        """
+        table_name = self._get_entity_table_for_store(store_type)
+
+        if entity_ids is None:
+            # Delete all entities for the parent_id
+            QUERY = f"""
+                DELETE FROM {self._get_table_name(table_name)}
+                WHERE parent_id = $1
+                RETURNING id
+            """
+            results = await self.connection_manager.fetch_query(
+                QUERY, [parent_id]
+            )
+        else:
+            # Delete specific entities
+            QUERY = f"""
+                DELETE FROM {self._get_table_name(table_name)}
+                WHERE id = ANY($1) AND parent_id = $2
+                RETURNING id
+            """
+
+            results = await self.connection_manager.fetch_query(
+                QUERY, [entity_ids, parent_id]
+            )
+
+            # Check if all requested entities were deleted
+            deleted_ids = [row["id"] for row in results]
+            if entity_ids and len(deleted_ids) != len(entity_ids):
+                raise R2RException(
+                    f"Some entities not found in {store_type} store or no permission to delete",
+                    404,
+                )
+
+
+class PostgresRelationshipsHandler(Handler):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.project_name: str = kwargs.get("project_name")  # type: ignore
+        self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
+        self.dimension: int = kwargs.get("dimension")  # type: ignore
+        self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore
+
+    def _get_table_name(self, table: str) -> str:
+        """Get the fully qualified table name."""
+        return f'"{self.project_name}"."{table}"'
+
+    def _get_relationship_table_for_store(self, store_type: StoreType) -> str:
+        """Get the appropriate table name for the store type."""
+        if isinstance(store_type, StoreType):
+            store_type = store_type.value
+        return f"{store_type}_relationships"
+
+    def _get_parent_constraint(self, store_type: StoreType) -> str:
+        """Get the appropriate foreign key constraint for the store type."""
+        if store_type == StoreType.GRAPHS:
+            return f"""
+                CONSTRAINT fk_graph
+                    FOREIGN KEY(parent_id)
+                    REFERENCES {self._get_table_name("graphs")}(id)
+                    ON DELETE CASCADE
+            """
+        else:
+            return f"""
+                CONSTRAINT fk_document
+                    FOREIGN KEY(parent_id)
+                    REFERENCES {self._get_table_name("documents")}(id)
+                    ON DELETE CASCADE
+            """
+
+    async def create_tables(self) -> None:
+        """Create separate tables for graph and document relationships."""
+        for store_type in StoreType:
+            table_name = self._get_relationship_table_for_store(store_type)
+            parent_constraint = self._get_parent_constraint(store_type)
+            vector_column_str = _decorate_vector_type(
+                f"({self.dimension})", self.quantization_type
+            )
+            QUERY = f"""
+                CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
+                    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+                    subject TEXT NOT NULL,
+                    predicate TEXT NOT NULL,
+                    object TEXT NOT NULL,
+                    description TEXT,
+                    description_embedding {vector_column_str},
+                    subject_id UUID,
+                    object_id UUID,
+                    weight FLOAT DEFAULT 1.0,
+                    chunk_ids UUID[],
+                    parent_id UUID NOT NULL,
+                    metadata JSONB,
+                    created_at TIMESTAMPTZ DEFAULT NOW(),
+                    updated_at TIMESTAMPTZ DEFAULT NOW(),
+                    {parent_constraint}
+                );
+
+                CREATE INDEX IF NOT EXISTS {table_name}_subject_idx
+                    ON {self._get_table_name(table_name)} (subject);
+                CREATE INDEX IF NOT EXISTS {table_name}_object_idx
+                    ON {self._get_table_name(table_name)} (object);
+                CREATE INDEX IF NOT EXISTS {table_name}_predicate_idx
+                    ON {self._get_table_name(table_name)} (predicate);
+                CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
+                    ON {self._get_table_name(table_name)} (parent_id);
+                CREATE INDEX IF NOT EXISTS {table_name}_subject_id_idx
+                    ON {self._get_table_name(table_name)} (subject_id);
+                CREATE INDEX IF NOT EXISTS {table_name}_object_id_idx
+                    ON {self._get_table_name(table_name)} (object_id);
+            """
+            await self.connection_manager.execute_query(QUERY)
+
+    async def create(
+        self,
+        subject: str,
+        subject_id: UUID,
+        predicate: str,
+        object: str,
+        object_id: UUID,
+        parent_id: UUID,
+        store_type: StoreType,
+        description: str | None = None,
+        weight: float | None = 1.0,
+        chunk_ids: Optional[list[UUID]] = None,
+        description_embedding: Optional[list[float] | str] = None,
+        metadata: Optional[dict[str, Any] | str] = None,
+    ) -> Relationship:
+        """Create a new relationship in the specified store."""
+        table_name = self._get_relationship_table_for_store(store_type)
+
+        if isinstance(metadata, str):
+            try:
+                metadata = json.loads(metadata)
+            except json.JSONDecodeError:
+                pass
+
+        if isinstance(description_embedding, list):
+            description_embedding = str(description_embedding)
+
+        query = f"""
+            INSERT INTO {self._get_table_name(table_name)}
+            (subject, predicate, object, description, subject_id, object_id,
+             weight, chunk_ids, parent_id, description_embedding, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
+            RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
+        """
+
+        params = [
+            subject,
+            predicate,
+            object,
+            description,
+            subject_id,
+            object_id,
+            weight,
+            chunk_ids,
+            parent_id,
+            description_embedding,
+            json.dumps(metadata) if metadata else None,
+        ]
+
+        result = await self.connection_manager.fetchrow_query(
+            query=query,
+            params=params,
+        )
+
+        return Relationship(
+            id=result["id"],
+            subject=result["subject"],
+            predicate=result["predicate"],
+            object=result["object"],
+            description=result["description"],
+            subject_id=result["subject_id"],
+            object_id=result["object_id"],
+            weight=result["weight"],
+            chunk_ids=result["chunk_ids"],
+            parent_id=result["parent_id"],
+            metadata=result["metadata"],
+        )
+
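+    # Illustrative create() call (editorial sketch; all names are hypothetical):
+    #
+    #   rel = await relationships_handler.create(
+    #       subject="Marie Curie",
+    #       subject_id=curie_entity.id,
+    #       predicate="won",
+    #       object="Nobel Prize in Physics",
+    #       object_id=prize_entity.id,
+    #       parent_id=graph_id,
+    #       store_type=StoreType.GRAPHS,
+    #       description="Awarded in 1903",
+    #   )
+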
+    async def get(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        offset: int,
+        limit: int,
+        relationship_ids: Optional[list[UUID]] = None,
+        entity_names: Optional[list[str]] = None,
+        relationship_types: Optional[list[str]] = None,
+        include_metadata: bool = False,
+    ):
+        """
+        Get relationships from the specified store.
+
+        Args:
+            parent_id: UUID of the parent (collection_id or document_id)
+            store_type: Type of store (graph or document)
+            offset: Number of records to skip
+            limit: Maximum number of records to return (-1 for no limit)
+            relationship_ids: Optional list of specific relationship IDs to retrieve
+            entity_names: Optional list of entity names to filter by (matches subject or object)
+            relationship_types: Optional list of relationship types (predicates) to filter by
+            include_metadata: Whether to include metadata in the response
+
+        Returns:
+            Tuple of (list of relationships, total count)
+        """
+        table_name = self._get_relationship_table_for_store(store_type)
+
+        conditions = ["parent_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if relationship_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(relationship_ids)
+            param_index += 1
+
+        if entity_names:
+            conditions.append(
+                f"(subject = ANY(${param_index}) OR object = ANY(${param_index}))"
+            )
+            params.append(entity_names)
+            param_index += 1
+
+        if relationship_types:
+            conditions.append(f"predicate = ANY(${param_index})")
+            params.append(relationship_types)
+            param_index += 1
+
+        select_fields = """
+            id, subject, predicate, object, description,
+            subject_id, object_id, weight, chunk_ids,
+            parent_id
+        """
+        if include_metadata:
+            select_fields += ", metadata"
+
+        # Count query
+        COUNT_QUERY = f"""
+            SELECT COUNT(*)
+            FROM {self._get_table_name(table_name)}
+            WHERE {' AND '.join(conditions)}
+        """
+        count_params = params[: param_index - 1]
+        count = (
+            await self.connection_manager.fetch_query(
+                COUNT_QUERY, count_params
+            )
+        )[0]["count"]
+
+        # Main query
+        QUERY = f"""
+            SELECT {select_fields}
+            FROM {self._get_table_name(table_name)}
+            WHERE {' AND '.join(conditions)}
+            ORDER BY created_at
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            QUERY += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        relationships = []
+        for row in rows:
+            relationship_dict = dict(row)
+            if include_metadata and isinstance(
+                relationship_dict["metadata"], str
+            ):
+                try:
+                    relationship_dict["metadata"] = json.loads(
+                        relationship_dict["metadata"]
+                    )
+                except json.JSONDecodeError:
+                    pass
+            elif not include_metadata:
+                relationship_dict.pop("metadata", None)
+            relationships.append(Relationship(**relationship_dict))
+
+        return relationships, count
+
+    async def update(
+        self,
+        relationship_id: UUID,
+        store_type: StoreType,
+        subject: Optional[str],
+        subject_id: Optional[UUID],
+        predicate: Optional[str],
+        object: Optional[str],
+        object_id: Optional[UUID],
+        description: Optional[str],
+        description_embedding: Optional[list[float] | str],
+        weight: Optional[float],
+        metadata: Optional[dict[str, Any] | str],
+    ) -> Relationship:
+        """Update an existing relationship in the specified store."""
+        table_name = self._get_relationship_table_for_store(store_type)
+        update_fields = []
+        params: list = []
+        param_index = 1
+
+        if isinstance(metadata, str):
+            try:
+                metadata = json.loads(metadata)
+            except json.JSONDecodeError:
+                pass
+
+        if subject is not None:
+            update_fields.append(f"subject = ${param_index}")
+            params.append(subject)
+            param_index += 1
+
+        if subject_id is not None:
+            update_fields.append(f"subject_id = ${param_index}")
+            params.append(subject_id)
+            param_index += 1
+
+        if predicate is not None:
+            update_fields.append(f"predicate = ${param_index}")
+            params.append(predicate)
+            param_index += 1
+
+        if object is not None:
+            update_fields.append(f"object = ${param_index}")
+            params.append(object)
+            param_index += 1
+
+        if object_id is not None:
+            update_fields.append(f"object_id = ${param_index}")
+            params.append(object_id)
+            param_index += 1
+
+        if description is not None:
+            update_fields.append(f"description = ${param_index}")
+            params.append(description)
+            param_index += 1
+
+        if description_embedding is not None:
+            update_fields.append(f"description_embedding = ${param_index}")
+            params.append(description_embedding)
+            param_index += 1
+
+        if weight is not None:
+            update_fields.append(f"weight = ${param_index}")
+            params.append(weight)
+            param_index += 1
+
+        if not update_fields:
+            raise R2RException(status_code=400, message="No fields to update")
+
+        update_fields.append("updated_at = NOW()")
+        params.append(relationship_id)
+
+        query = f"""
+            UPDATE {self._get_table_name(table_name)}
+            SET {', '.join(update_fields)}
+            WHERE id = ${param_index}
+            RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
+        """
+
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query=query,
+                params=params,
+            )
+
+            return Relationship(
+                id=result["id"],
+                subject=result["subject"],
+                predicate=result["predicate"],
+                object=result["object"],
+                description=result["description"],
+                subject_id=result["subject_id"],
+                object_id=result["object_id"],
+                weight=result["weight"],
+                chunk_ids=result["chunk_ids"],
+                parent_id=result["parent_id"],
+                metadata=result["metadata"],
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while updating the relationship: {e}",
+            )
+
+    async def delete(
+        self,
+        parent_id: UUID,
+        relationship_ids: Optional[list[UUID]] = None,
+        store_type: StoreType = StoreType.GRAPHS,
+    ) -> None:
+        """
+        Delete relationships from the specified store.
+        If relationship_ids is not provided, deletes all relationships for the given parent_id.
+
+        Args:
+            parent_id: UUID of the parent (collection_id or document_id)
+            relationship_ids: Optional list of specific relationship IDs to delete
+            store_type: Type of store (graph or document)
+
+        Returns:
+            None
+
+        Raises:
+            R2RException: If specific relationships were requested but not all found
+        """
+        table_name = self._get_relationship_table_for_store(store_type)
+
+        if relationship_ids is None:
+            QUERY = f"""
+                DELETE FROM {self._get_table_name(table_name)}
+                WHERE parent_id = $1
+                RETURNING id
+            """
+            results = await self.connection_manager.fetch_query(
+                QUERY, [parent_id]
+            )
+        else:
+            QUERY = f"""
+                DELETE FROM {self._get_table_name(table_name)}
+                WHERE id = ANY($1) AND parent_id = $2
+                RETURNING id
+            """
+            results = await self.connection_manager.fetch_query(
+                QUERY, [relationship_ids, parent_id]
+            )
+
+            deleted_ids = [row["id"] for row in results]
+            if relationship_ids and len(deleted_ids) != len(relationship_ids):
+                raise R2RException(
+                    f"Some relationships not found in {store_type} store or no permission to delete",
+                    404,
+                )
+
+
+class PostgresCommunitiesHandler(Handler):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.project_name: str = kwargs.get("project_name")  # type: ignore
+        self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
+        self.dimension: int = kwargs.get("dimension")  # type: ignore
+        self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore
+
+    async def create_tables(self) -> None:
+
+        vector_column_str = _decorate_vector_type(
+            f"({self.dimension})", self.quantization_type
+        )
+
+        query = f"""
+            CREATE TABLE IF NOT EXISTS {self._get_table_name("graphs_communities")} (
+            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+            collection_id UUID,
+            community_id UUID,
+            level INT,
+            name TEXT NOT NULL,
+            summary TEXT NOT NULL,
+            findings TEXT[],
+            rating FLOAT,
+            rating_explanation TEXT,
+            description_embedding {vector_column_str} NOT NULL,
+            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+            metadata JSONB,
+            UNIQUE (community_id, level, collection_id)
+        );"""
+
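+        # UNIQUE (community_id, level, collection_id) keeps a given community level
+        # from being stored twice for the same collection.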
+        await self.connection_manager.execute_query(query)
+
+    async def create(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        name: str,
+        summary: str,
+        findings: Optional[list[str]],
+        rating: Optional[float],
+        rating_explanation: Optional[str],
+        description_embedding: Optional[list[float] | str] = None,
+    ) -> Community:
+        # Do we ever want to get communities from document store?
+        table_name = "graphs_communities"
+
+        if isinstance(description_embedding, list):
+            description_embedding = str(description_embedding)
+
+        query = f"""
+            INSERT INTO {self._get_table_name(table_name)}
+            (collection_id, name, summary, findings, rating, rating_explanation, description_embedding)
+            VALUES ($1, $2, $3, $4, $5, $6, $7)
+            RETURNING id, collection_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
+        """
+
+        params = [
+            parent_id,
+            name,
+            summary,
+            findings,
+            rating,
+            rating_explanation,
+            description_embedding,
+        ]
+
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query=query,
+                params=params,
+            )
+
+            return Community(
+                id=result["id"],
+                collection_id=result["collection_id"],
+                name=result["name"],
+                summary=result["summary"],
+                findings=result["findings"],
+                rating=result["rating"],
+                rating_explanation=result["rating_explanation"],
+                created_at=result["created_at"],
+                updated_at=result["updated_at"],
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while creating the community: {e}",
+            )
+
+    async def update(
+        self,
+        community_id: UUID,
+        store_type: StoreType,
+        name: Optional[str] = None,
+        summary: Optional[str] = None,
+        summary_embedding: Optional[list[float] | str] = None,
+        findings: Optional[list[str]] = None,
+        rating: Optional[float] = None,
+        rating_explanation: Optional[str] = None,
+    ) -> Community:
+        table_name = "graphs_communities"
+        update_fields = []
+        params: list[Any] = []
+        param_index = 1
+
+        if name is not None:
+            update_fields.append(f"name = ${param_index}")
+            params.append(name)
+            param_index += 1
+
+        if summary is not None:
+            update_fields.append(f"summary = ${param_index}")
+            params.append(summary)
+            param_index += 1
+
+        if summary_embedding is not None:
+            update_fields.append(f"description_embedding = ${param_index}")
+            params.append(summary_embedding)
+            param_index += 1
+
+        if findings is not None:
+            update_fields.append(f"findings = ${param_index}")
+            params.append(findings)
+            param_index += 1
+
+        if rating is not None:
+            update_fields.append(f"rating = ${param_index}")
+            params.append(rating)
+            param_index += 1
+
+        if rating_explanation is not None:
+            update_fields.append(f"rating_explanation = ${param_index}")
+            params.append(rating_explanation)
+            param_index += 1
+
+        if not update_fields:
+            raise R2RException(status_code=400, message="No fields to update")
+
+        update_fields.append("updated_at = NOW()")
+        params.append(community_id)
+
+        query = f"""
+            UPDATE {self._get_table_name(table_name)}
+            SET {", ".join(update_fields)}
+            WHERE id = ${param_index}
+            RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
+        """
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query, params
+            )
+
+            return Community(
+                id=result["id"],
+                community_id=result["community_id"],
+                name=result["name"],
+                summary=result["summary"],
+                findings=result["findings"],
+                rating=result["rating"],
+                rating_explanation=result["rating_explanation"],
+                created_at=result["created_at"],
+                updated_at=result["updated_at"],
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while updating the community: {e}",
+            )
+
+    async def delete(
+        self,
+        parent_id: UUID,
+        community_id: UUID,
+    ) -> None:
+        table_name = "graphs_communities"
+
+        query = f"""
+            DELETE FROM {self._get_table_name(table_name)}
+            WHERE id = $1 AND collection_id = $2
+        """
+
+        params = [community_id, parent_id]
+
+        try:
+            await self.connection_manager.execute_query(query, params)
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while deleting the community: {e}",
+            )
+
+    async def get(
+        self,
+        parent_id: UUID,
+        store_type: StoreType,
+        offset: int,
+        limit: int,
+        community_ids: Optional[list[UUID]] = None,
+        community_names: Optional[list[str]] = None,
+        include_embeddings: bool = False,
+    ):
+        """Retrieve communities from the specified store."""
+        # Do we ever want to get communities from document store?
+        table_name = "graphs_communities"
+
+        conditions = ["collection_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if community_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(community_ids)
+            param_index += 1
+
+        if community_names:
+            conditions.append(f"name = ANY(${param_index})")
+            params.append(community_names)
+            param_index += 1
+
+        select_fields = """
+            id, community_id, name, summary, findings, rating,
+            rating_explanation, level, created_at, updated_at
+        """
+        if include_embeddings:
+            select_fields += ", description_embedding"
+
+        COUNT_QUERY = f"""
+            SELECT COUNT(*)
+            FROM {self._get_table_name(table_name)}
+            WHERE {' AND '.join(conditions)}
+        """
+
+        count = (
+            await self.connection_manager.fetch_query(
+                COUNT_QUERY, params[: param_index - 1]
+            )
+        )[0]["count"]
+
+        QUERY = f"""
+            SELECT {select_fields}
+            FROM {self._get_table_name(table_name)}
+            WHERE {' AND '.join(conditions)}
+            ORDER BY created_at
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            QUERY += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        communities = []
+        for row in rows:
+            community_dict = dict(row)
+
+            communities.append(Community(**community_dict))
+
+        return communities, count
+
+
+class PostgresGraphsHandler(Handler):
+    """Handler for knowledge graph operations in PostgreSQL."""
+
+    TABLE_NAME = "graphs"
+
+    def __init__(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+
+        self.project_name: str = kwargs.get("project_name")  # type: ignore
+        self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
+        self.dimension: int = kwargs.get("dimension")  # type: ignore
+        self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore
+        self.collections_handler: PostgresCollectionsHandler = kwargs.get("collections_handler")  # type: ignore
+
+        self.entities = PostgresEntitiesHandler(*args, **kwargs)
+        self.relationships = PostgresRelationshipsHandler(*args, **kwargs)
+        self.communities = PostgresCommunitiesHandler(*args, **kwargs)
+
+        self.handlers = [
+            self.entities,
+            self.relationships,
+            self.communities,
+        ]
+
+        import networkx as nx
+
+        self.nx = nx
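+        # networkx is imported at construction time rather than at module import,
+        # presumably so the dependency is only loaded when graphs are actually used.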
+
+    async def create_tables(self) -> None:
+        """Create the graph tables with mandatory collection_id support."""
+        QUERY = f"""
+            CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} (
+                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+                collection_id UUID NOT NULL,
+                name TEXT NOT NULL,
+                description TEXT,
+                status TEXT NOT NULL,
+                document_ids UUID[],
+                metadata JSONB,
+                created_at TIMESTAMPTZ DEFAULT NOW(),
+                updated_at TIMESTAMPTZ DEFAULT NOW()
+            );
+
+            CREATE INDEX IF NOT EXISTS graph_collection_id_idx
+                ON {self._get_table_name("graphs")} (collection_id);
+        """
+
+        await self.connection_manager.execute_query(QUERY)
+
+        for handler in self.handlers:
+            await handler.create_tables()
+
+    async def create(
+        self,
+        collection_id: UUID,
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+        status: str = "pending",
+    ) -> GraphResponse:
+        """Create a new graph associated with a collection."""
+
+        name = name or f"Graph {collection_id}"
+        description = description or ""
+
+        query = f"""
+            INSERT INTO {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+            (id, collection_id, name, description, status)
+            VALUES ($1, $2, $3, $4, $5)
+            RETURNING id, collection_id, name, description, status, created_at, updated_at, document_ids
+        """
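+        # The graph id is set to the collection id (passed twice below), so each
+        # collection maps to a single graph row.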
+        params = [
+            collection_id,
+            collection_id,
+            name,
+            description,
+            status,
+        ]
+
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query=query,
+                params=params,
+            )
+
+            return GraphResponse(
+                id=result["id"],
+                collection_id=result["collection_id"],
+                name=result["name"],
+                description=result["description"],
+                status=result["status"],
+                created_at=result["created_at"],
+                updated_at=result["updated_at"],
+                document_ids=result["document_ids"] or [],
+            )
+        except UniqueViolationError:
+            raise R2RException(
+                message="Graph with this ID already exists",
+                status_code=409,
+            )
+
+    async def reset(self, parent_id: UUID) -> None:
+        """
+        Completely reset a graph and all associated data.
+        """
+        try:
+            entity_delete_query = f"""
+                DELETE FROM {self._get_table_name("graphs_entities")}
+                WHERE parent_id = $1
+            """
+            await self.connection_manager.execute_query(
+                entity_delete_query, [parent_id]
+            )
+
+            # Delete all graph relationships
+            relationship_delete_query = f"""
+                DELETE FROM {self._get_table_name("graphs_relationships")}
+                WHERE parent_id = $1
+            """
+            await self.connection_manager.execute_query(
+                relationship_delete_query, [parent_id]
+            )
+
+            # Delete all graph communities
+            community_delete_query = f"""
+                DELETE FROM {self._get_table_name("graphs_communities")}
+                WHERE collection_id = $1
+            """
+            await self.connection_manager.execute_query(
+                community_delete_query, [parent_id]
+            )
+
+        except Exception as e:
+            logger.error(f"Error deleting graph {parent_id}: {str(e)}")
+            raise R2RException(f"Failed to delete graph: {str(e)}", 500)
+
+    async def list_graphs(
+        self,
+        offset: int,
+        limit: int,
+        # filter_user_ids: Optional[list[UUID]] = None,
+        filter_graph_ids: Optional[list[UUID]] = None,
+        filter_collection_id: Optional[UUID] = None,
+    ) -> dict[str, list[GraphResponse] | int]:
+        conditions = []
+        params: list[Any] = []
+        param_index = 1
+
+        if filter_graph_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(filter_graph_ids)
+            param_index += 1
+
+        # if filter_user_ids:
+        #     conditions.append(f"user_id = ANY(${param_index})")
+        #     params.append(filter_user_ids)
+        #     param_index += 1
+
+        if filter_collection_id:
+            conditions.append(f"collection_id = ${param_index}")
+            params.append(filter_collection_id)
+            param_index += 1
+
+        where_clause = (
+            f"WHERE {' AND '.join(conditions)}" if conditions else ""
+        )
+
+        query = f"""
+            WITH RankedGraphs AS (
+                SELECT
+                    id, collection_id, name, description, status, created_at, updated_at, document_ids,
+                    COUNT(*) OVER() as total_entries,
+                    ROW_NUMBER() OVER (PARTITION BY collection_id ORDER BY created_at DESC) as rn
+                FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+                {where_clause}
+            )
+            SELECT * FROM RankedGraphs
+            WHERE rn = 1
+            ORDER BY created_at DESC
+            OFFSET ${param_index} LIMIT ${param_index + 1}
+        """
+
+        params.extend([offset, limit])
+
+        try:
+            results = await self.connection_manager.fetch_query(query, params)
+            if not results:
+                return {"results": [], "total_entries": 0}
+
+            total_entries = results[0]["total_entries"] if results else 0
+
+            graphs = [
+                GraphResponse(
+                    id=row["id"],
+                    document_ids=row["document_ids"] or [],
+                    name=row["name"],
+                    collection_id=row["collection_id"],
+                    description=row["description"],
+                    status=row["status"],
+                    created_at=row["created_at"],
+                    updated_at=row["updated_at"],
+                )
+                for row in results
+            ]
+
+            return {"results": graphs, "total_entries": total_entries}
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while fetching graphs: {e}",
+            )
+
+    async def get(
+        self, offset: int, limit: int, graph_id: Optional[UUID] = None
+    ):
+
+        if graph_id is None:
+
+            params = [offset, limit]
+
+            QUERY = f"""
+                SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+                OFFSET $1 LIMIT $2
+            """
+
+            ret = await self.connection_manager.fetch_query(QUERY, params)
+
+            COUNT_QUERY = f"""
+                SELECT COUNT(*) FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+            """
+            count = (await self.connection_manager.fetch_query(COUNT_QUERY))[
+                0
+            ]["count"]
+
+            return {
+                "results": [Graph(**row) for row in ret],
+                "total_entries": count,
+            }
+
+        else:
+            QUERY = f"""
+                SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} WHERE id = $1
+            """
+
+            params = [graph_id]  # type: ignore
+
+            return {
+                "results": [
+                    Graph(
+                        **await self.connection_manager.fetchrow_query(
+                            QUERY, params
+                        )
+                    )
+                ]
+            }
+
+    async def add_documents(self, id: UUID, document_ids: list[UUID]) -> bool:
+        """
+        Add documents to the graph by copying their entities and relationships.
+        """
+        # Copy entities from documents_entities to graphs_entities
+        ENTITY_COPY_QUERY = f"""
+            INSERT INTO {self._get_table_name("graphs_entities")} (
+                name, category, description, parent_id, description_embedding,
+                chunk_ids, metadata
+            )
+            SELECT
+                name, category, description, $1, description_embedding,
+                chunk_ids, metadata
+            FROM {self._get_table_name("documents_entities")}
+            WHERE parent_id = ANY($2)
+        """
+        await self.connection_manager.execute_query(
+            ENTITY_COPY_QUERY, [id, document_ids]
+        )
+
+        # Copy relationships from documents_relationships to graphs_relationships
+        RELATIONSHIP_COPY_QUERY = f"""
+            INSERT INTO {self._get_table_name("graphs_relationships")} (
+                subject, predicate, object, description, subject_id, object_id,
+                weight, chunk_ids, parent_id, metadata, description_embedding
+            )
+            SELECT
+                subject, predicate, object, description, subject_id, object_id,
+                weight, chunk_ids, $1, metadata, description_embedding
+            FROM {self._get_table_name("documents_relationships")}
+            WHERE parent_id = ANY($2)
+        """
+        await self.connection_manager.execute_query(
+            RELATIONSHIP_COPY_QUERY, [id, document_ids]
+        )
+
+        # Add document_ids to the graph
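+        # (array_cat with the CASE guard appends to a NULL document_ids column as
+        # if it were an empty uuid[] array.)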
+        UPDATE_GRAPH_QUERY = f"""
+            UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+            SET document_ids = array_cat(
+                CASE
+                    WHEN document_ids IS NULL THEN ARRAY[]::uuid[]
+                    ELSE document_ids
+                END,
+                $2::uuid[]
+            )
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(
+            UPDATE_GRAPH_QUERY, [id, document_ids]
+        )
+
+        return True
+
+    async def update(
+        self,
+        collection_id: UUID,
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+    ) -> GraphResponse:
+        """Update an existing graph."""
+        update_fields = []
+        params: list = []
+        param_index = 1
+
+        if name is not None:
+            update_fields.append(f"name = ${param_index}")
+            params.append(name)
+            param_index += 1
+
+        if description is not None:
+            update_fields.append(f"description = ${param_index}")
+            params.append(description)
+            param_index += 1
+
+        if not update_fields:
+            raise R2RException(status_code=400, message="No fields to update")
+
+        update_fields.append("updated_at = NOW()")
+        params.append(collection_id)
+
+        query = f"""
+            UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
+            SET {', '.join(update_fields)}
+            WHERE id = ${param_index}
+            RETURNING id, name, description, status, created_at, updated_at, collection_id, document_ids
+        """
+
+        try:
+            result = await self.connection_manager.fetchrow_query(
+                query, params
+            )
+
+            if not result:
+                raise R2RException(status_code=404, message="Graph not found")
+
+            return GraphResponse(
+                id=result["id"],
+                collection_id=result["collection_id"],
+                name=result["name"],
+                description=result["description"],
+                status=result["status"],
+                created_at=result["created_at"],
+                document_ids=result["document_ids"] or [],
+                updated_at=result["updated_at"],
+            )
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"An error occurred while updating the graph: {e}",
+            )
+
+    async def get_creation_estimate(
+        self,
+        graph_creation_settings: KGCreationSettings,
+        document_id: Optional[UUID] = None,
+        collection_id: Optional[UUID] = None,
+    ):
+        """Get the estimated cost and time for creating a KG."""
+
+        if not (bool(document_id) ^ bool(collection_id)):
+            raise ValueError(
+                "Exactly one of document_id or collection_id must be provided."
+            )
+
+        # todo: harmonize the document_id and id fields: postgres table contains document_id, but other places use id.
+
+        document_ids = (
+            [document_id]
+            if document_id
+            else [
+                doc.id for doc in (await self.collections_handler.documents_in_collection(collection_id, offset=0, limit=-1))["results"]  # type: ignore
+            ]
+        )
+
+        chunk_counts = await self.connection_manager.fetch_query(
+            f"SELECT document_id, COUNT(*) as chunk_count FROM {self._get_table_name('vectors')} "
+            f"WHERE document_id = ANY($1) GROUP BY document_id",
+            [document_ids],
+        )
+
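+        # Heuristic ranges used below: roughly 10-20 entities per merged chunk,
+        # 1.25x-1.5x that many relationships, and two LLM calls per chunk plus one
+        # per entity, at an assumed ~2000 tokens per call.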
+        total_chunks = (
+            sum(doc["chunk_count"] for doc in chunk_counts)
+            // graph_creation_settings.chunk_merge_count
+        )
+        estimated_entities = (total_chunks * 10, total_chunks * 20)
+        estimated_relationships = (
+            int(estimated_entities[0] * 1.25),
+            int(estimated_entities[1] * 1.5),
+        )
+        estimated_llm_calls = (
+            total_chunks * 2 + estimated_entities[0],
+            total_chunks * 2 + estimated_entities[1],
+        )
+        # Use float division so small corpora do not round down to zero estimated
+        # tokens (consistent with the enrichment estimate below).
+        total_in_out_tokens = tuple(
+            2000 * calls / 1000000 for calls in estimated_llm_calls
+        )
+        cost_per_million = llm_cost_per_million_tokens(
+            graph_creation_settings.generation_config.model
+        )
+        estimated_cost = tuple(
+            tokens * cost_per_million for tokens in total_in_out_tokens
+        )
+        total_time_in_minutes = tuple(
+            tokens * 10 / 60 for tokens in total_in_out_tokens
+        )
+
+        return {
+            "message": 'Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG creation process, run `extract-triples` with `--run` in the cli, or `run_type="run"` in the client.',
+            "document_count": len(document_ids),
+            "number_of_jobs_created": len(document_ids) + 1,
+            "total_chunks": total_chunks,
+            "estimated_entities": _get_str_estimation_output(
+                estimated_entities
+            ),
+            "estimated_relationships": _get_str_estimation_output(
+                estimated_relationships
+            ),
+            "estimated_llm_calls": _get_str_estimation_output(
+                estimated_llm_calls
+            ),
+            "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(
+                total_in_out_tokens
+            ),
+            "estimated_cost_in_usd": _get_str_estimation_output(
+                estimated_cost
+            ),
+            "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: "
+            + _get_str_estimation_output(total_time_in_minutes),
+        }
+
+    async def get_enrichment_estimate(
+        self,
+        collection_id: UUID | None = None,
+        graph_id: UUID | None = None,
+        graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
+    ):
+        """Get the estimated cost and time for enriching a KG."""
+        if collection_id is not None:
+
+            document_ids = [
+                doc.id
+                for doc in (
+                    await self.collections_handler.documents_in_collection(collection_id, offset=0, limit=-1)  # type: ignore
+                )["results"]
+            ]
+
+            # Get entity and relationship counts
+            entity_count = (
+                await self.connection_manager.fetch_query(
+                    f"SELECT COUNT(*) FROM {self._get_table_name('entity')} WHERE document_id = ANY($1);",
+                    [document_ids],
+                )
+            )[0]["count"]
+
+            if not entity_count:
+                raise ValueError(
+                    "No entities found in the graph. Please run `extract-triples` first."
+                )
+
+            relationship_count = (
+                await self.connection_manager.fetch_query(
+                    f"""SELECT COUNT(*) FROM {self._get_table_name("documents_relationships")} WHERE document_id = ANY($1);""",
+                    [document_ids],
+                )
+            )[0]["count"]
+
+        else:
+            entity_count = (
+                await self.connection_manager.fetch_query(
+                    f"SELECT COUNT(*) FROM {self._get_table_name('entity')} WHERE $1 = ANY(graph_ids);",
+                    [graph_id],
+                )
+            )[0]["count"]
+
+            if not entity_count:
+                raise ValueError(
+                    "No entities found in the graph. Please run `extract-triples` first."
+                )
+
+            relationship_count = (
+                await self.connection_manager.fetch_query(
+                    f"SELECT COUNT(*) FROM {self._get_table_name('relationship')} WHERE $1 = ANY(graph_ids);",
+                    [graph_id],
+                )
+            )[0]["count"]
+
+        # Calculate estimates
+        estimated_llm_calls = (entity_count // 10, entity_count // 5)
+        tokens_in_millions = tuple(
+            2000 * calls / 1000000 for calls in estimated_llm_calls
+        )
+        cost_per_million = llm_cost_per_million_tokens(
+            graph_enrichment_settings.generation_config.model  # type: ignore
+        )
+        estimated_cost = tuple(
+            tokens * cost_per_million for tokens in tokens_in_millions
+        )
+        estimated_time = tuple(
+            tokens * 10 / 60 for tokens in tokens_in_millions
+        )
+
+        return {
+            "message": 'Ran Graph Enrichment Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG enrichment process, run `build-communities` with `--run` in the cli, or `run_type="run"` in the client.',
+            "total_entities": entity_count,
+            "total_relationships": relationship_count,
+            "estimated_llm_calls": _get_str_estimation_output(
+                estimated_llm_calls
+            ),
+            "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(
+                tokens_in_millions
+            ),
+            "estimated_cost_in_usd": _get_str_estimation_output(
+                estimated_cost
+            ),
+            "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: "
+            + _get_str_estimation_output(estimated_time),
+        }
+
+    async def get_deduplication_estimate(
+        self,
+        collection_id: UUID,
+        kg_deduplication_settings: KGEntityDeduplicationSettings,
+    ):
+        """Get the estimated cost and time for deduplicating entities in a KG."""
+        try:
+            query = f"""
+                SELECT name, count(name)
+                FROM {self._get_table_name("entity")}
+                WHERE document_id = ANY(
+                    SELECT document_id FROM {self._get_table_name("documents")}
+                    WHERE $1 = ANY(collection_ids)
+                )
+                GROUP BY name
+                HAVING count(name) >= 5
+            """
+            entities = await self.connection_manager.fetch_query(
+                query, [collection_id]
+            )
+            num_entities = len(entities)
+
+            estimated_llm_calls = (num_entities, num_entities)
+            tokens_in_millions = (
+                estimated_llm_calls[0] * 1000 / 1000000,
+                estimated_llm_calls[1] * 5000 / 1000000,
+            )
+            cost_per_million = llm_cost_per_million_tokens(
+                kg_deduplication_settings.generation_config.model
+            )
+            estimated_cost = (
+                tokens_in_millions[0] * cost_per_million,
+                tokens_in_millions[1] * cost_per_million,
+            )
+            estimated_time = (
+                tokens_in_millions[0] * 10 / 60,
+                tokens_in_millions[1] * 10 / 60,
+            )
+
+            return {
+                "message": "Ran Deduplication Estimate (not the actual run). Note that these are estimated ranges.",
+                "num_entities": num_entities,
+                "estimated_llm_calls": _get_str_estimation_output(
+                    estimated_llm_calls
+                ),
+                "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(
+                    tokens_in_millions
+                ),
+                "estimated_cost_in_usd": _get_str_estimation_output(
+                    estimated_cost
+                ),
+                "estimated_total_time_in_minutes": _get_str_estimation_output(
+                    estimated_time
+                ),
+            }
+        except UndefinedTableError:
+            raise R2RException(
+                "Entity embedding table not found. Please run `extract-triples` first.",
+                404,
+            )
+        except Exception as e:
+            logger.error(f"Error in get_deduplication_estimate: {str(e)}")
+            raise HTTPException(500, "Error fetching deduplication estimate.")
+
+    async def get_entities(
+        self,
+        parent_id: UUID,
+        offset: int,
+        limit: int,
+        entity_ids: Optional[list[UUID]] = None,
+        entity_names: Optional[list[str]] = None,
+        include_embeddings: bool = False,
+    ) -> tuple[list[Entity], int]:
+        """
+        Get entities for a graph.
+
+        Args:
+            parent_id: UUID of the graph to fetch entities for
+            offset: Number of records to skip
+            limit: Maximum number of records to return (-1 for no limit)
+            entity_ids: Optional list of entity IDs to filter by
+            entity_names: Optional list of entity names to filter by
+            include_embeddings: Whether to include embeddings in the response
+
+        Returns:
+            Tuple of (list of entities, total count)
+        """
+        conditions = ["parent_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if entity_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(entity_ids)
+            param_index += 1
+
+        if entity_names:
+            conditions.append(f"name = ANY(${param_index})")
+            params.append(entity_names)
+            param_index += 1
+
+        # Count query - uses the same conditions but without offset/limit
+        COUNT_QUERY = f"""
+            SELECT COUNT(*)
+            FROM {self._get_table_name("graphs_entities")}
+            WHERE {' AND '.join(conditions)}
+        """
+        count = (
+            await self.connection_manager.fetch_query(COUNT_QUERY, params)
+        )[0]["count"]
+
+        # Define base columns to select
+        select_fields = """
+            id, name, category, description, parent_id,
+            chunk_ids, metadata
+        """
+        if include_embeddings:
+            select_fields += ", description_embedding"
+
+        # Main query for fetching entities with pagination
+        QUERY = f"""
+            SELECT {select_fields}
+            FROM {self._get_table_name("graphs_entities")}
+            WHERE {' AND '.join(conditions)}
+            ORDER BY created_at
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            QUERY += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        entities = []
+        for row in rows:
+            entity_dict = dict(row)
+            if isinstance(entity_dict["metadata"], str):
+                try:
+                    entity_dict["metadata"] = json.loads(
+                        entity_dict["metadata"]
+                    )
+                except json.JSONDecodeError:
+                    pass
+
+            entities.append(Entity(**entity_dict))
+
+        return entities, count
+
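A minimal usage sketch for the paginated accessor above; the handler instance and graph UUID are assumed to exist and do not appear in the source.

from uuid import UUID

async def fetch_all_entities(handler, graph_id: UUID, page_size: int = 100):
    """Page through get_entities until the reported total has been collected."""
    offset, collected = 0, []
    while True:
        page, total = await handler.get_entities(
            parent_id=graph_id, offset=offset, limit=page_size
        )
        collected.extend(page)
        offset += len(page)
        if not page or offset >= total:
            break
    return collected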
+    async def get_relationships(
+        self,
+        parent_id: UUID,
+        offset: int,
+        limit: int,
+        relationship_ids: Optional[list[UUID]] = None,
+        relationship_types: Optional[list[str]] = None,
+        include_embeddings: bool = False,
+    ) -> tuple[list[Relationship], int]:
+        """
+        Get relationships for a graph.
+
+        Args:
+            parent_id: UUID of the graph
+            offset: Number of records to skip
+            limit: Maximum number of records to return (-1 for no limit)
+            relationship_ids: Optional list of relationship IDs to filter by
+            relationship_types: Optional list of relationship types to filter by
+            include_embeddings: Whether to include embeddings in the response
+
+        Returns:
+            Tuple of (list of relationships, total count)
+        """
+        conditions = ["parent_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if relationship_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(relationship_ids)
+            param_index += 1
+
+        if relationship_types:
+            conditions.append(f"predicate = ANY(${param_index})")
+            params.append(relationship_types)
+            param_index += 1
+
+        # Count query - uses the same conditions but without offset/limit
+        COUNT_QUERY = f"""
+            SELECT COUNT(*)
+            FROM {self._get_table_name("graphs_relationships")}
+            WHERE {' AND '.join(conditions)}
+        """
+        count = (
+            await self.connection_manager.fetch_query(COUNT_QUERY, params)
+        )[0]["count"]
+
+        # Define base columns to select
+        select_fields = """
+            id, subject, predicate, object, weight, chunk_ids, parent_id, metadata
+        """
+        if include_embeddings:
+            select_fields += ", description_embedding"
+
+        # Main query for fetching relationships with pagination
+        QUERY = f"""
+            SELECT {select_fields}
+            FROM {self._get_table_name("graphs_relationships")}
+            WHERE {' AND '.join(conditions)}
+            ORDER BY created_at
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            QUERY += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        relationships = []
+        for row in rows:
+            relationship_dict = dict(row)
+            if isinstance(relationship_dict["metadata"], str):
+                try:
+                    relationship_dict["metadata"] = json.loads(
+                        relationship_dict["metadata"]
+                    )
+                except json.JSONDecodeError:
+                    pass
+
+            relationships.append(Relationship(**relationship_dict))
+
+        return relationships, count
+
+    async def add_entities(
+        self,
+        entities: list[Entity],
+        table_name: str,
+        conflict_columns: list[str] = [],
+    ) -> list[UUID]:
+        """
+        Upsert entities into the given entities table. These are raw entities extracted from the document.
+
+        Args:
+            entities: list[Entity]: list of entities to upsert
+            table_name: str: name of the table to upsert into
+            conflict_columns: list[str]: columns used for ON CONFLICT resolution
+
+        Returns:
+            result: list[UUID]: ids of the upserted entities
+        """
+        cleaned_entities = []
+        for entity in entities:
+            entity_dict = entity.to_dict()
+            entity_dict["chunk_ids"] = (
+                entity_dict["chunk_ids"]
+                if entity_dict.get("chunk_ids")
+                else []
+            )
+            entity_dict["description_embedding"] = (
+                str(entity_dict["description_embedding"])
+                if entity_dict.get("description_embedding")  # type: ignore
+                else None
+            )
+            cleaned_entities.append(entity_dict)
+
+        return await _add_objects(
+            objects=cleaned_entities,
+            full_table_name=self._get_table_name(table_name),
+            connection_manager=self.connection_manager,
+            conflict_columns=conflict_columns,
+        )
+
+    async def delete_node_via_document_id(
+        self, document_id: UUID, collection_id: UUID
+    ) -> None:
+        # don't delete if status is PROCESSING.
+        QUERY = f"""
+            SELECT graph_cluster_status FROM {self._get_table_name("collections")} WHERE id = $1
+        """
+        status = (
+            await self.connection_manager.fetch_query(QUERY, [collection_id])
+        )[0]["graph_cluster_status"]
+        if status == KGExtractionStatus.PROCESSING.value:
+            return
+
+        # Execute separate DELETE queries
+        delete_queries = [
+            f"""DELETE FROM {self._get_table_name("documents_relationships")} WHERE parent_id = $1""",
+            f"""DELETE FROM {self._get_table_name("documents_entities")} WHERE parent_id = $1""",
+        ]
+
+        for query in delete_queries:
+            await self.connection_manager.execute_query(query, [document_id])
+        return None
+
+    async def get_all_relationships(
+        self,
+        collection_id: UUID | None,
+        graph_id: UUID | None,
+        document_ids: Optional[list[UUID]] = None,
+    ) -> list[Relationship]:
+
+        QUERY = f"""
+            SELECT id, subject, predicate, weight, object, parent_id FROM {self._get_table_name("graphs_relationships")} WHERE parent_id = ANY($1)
+        """
+        relationships = await self.connection_manager.fetch_query(
+            QUERY, [collection_id]
+        )
+
+        return [Relationship(**relationship) for relationship in relationships]
+
+    async def has_document(self, graph_id: UUID, document_id: UUID) -> bool:
+        """
+        Check if a document exists in the graph's document_ids array.
+
+        Args:
+            graph_id (UUID): ID of the graph to check
+            document_id (UUID): ID of the document to look for
+
+        Returns:
+            bool: True if document exists in graph, False otherwise
+
+        Raises:
+            R2RException: If graph not found
+        """
+        QUERY = f"""
+            SELECT EXISTS (
+                SELECT 1
+                FROM {self._get_table_name("graphs")}
+                WHERE id = $1
+                AND document_ids IS NOT NULL
+                AND $2 = ANY(document_ids)
+            ) as exists;
+        """
+
+        result = await self.connection_manager.fetchrow_query(
+            QUERY, [graph_id, document_id]
+        )
+
+        if result is None:
+            raise R2RException(f"Graph {graph_id} not found", 404)
+
+        return result["exists"]
+
+    async def get_communities(
+        self,
+        parent_id: UUID,
+        offset: int,
+        limit: int,
+        community_ids: Optional[list[UUID]] = None,
+        include_embeddings: bool = False,
+    ) -> tuple[list[Community], int]:
+        """
+        Get communities for a graph.
+
+        Args:
+            parent_id: UUID of the collection the communities belong to
+            offset: Number of records to skip
+            limit: Maximum number of records to return (-1 for no limit)
+            community_ids: Optional list of community IDs to filter by
+            include_embeddings: Whether to include embeddings in the response
+
+        Returns:
+            Tuple of (list of communities, total count)
+        """
+        conditions = ["collection_id = $1"]
+        params: list[Any] = [parent_id]
+        param_index = 2
+
+        if community_ids:
+            conditions.append(f"id = ANY(${param_index})")
+            params.append(community_ids)
+            param_index += 1
+
+        select_fields = """
+            id, collection_id, name, summary, findings, rating, rating_explanation
+        """
+        if include_embeddings:
+            select_fields += ", description_embedding"
+
+        COUNT_QUERY = f"""
+            SELECT COUNT(*)
+            FROM {self._get_table_name("graphs_communities")}
+            WHERE {' AND '.join(conditions)}
+        """
+        count = (
+            await self.connection_manager.fetch_query(COUNT_QUERY, params)
+        )[0]["count"]
+
+        QUERY = f"""
+            SELECT {select_fields}
+            FROM {self._get_table_name("graphs_communities")}
+            WHERE {' AND '.join(conditions)}
+            ORDER BY created_at
+            OFFSET ${param_index}
+        """
+        params.append(offset)
+        param_index += 1
+
+        if limit != -1:
+            QUERY += f" LIMIT ${param_index}"
+            params.append(limit)
+
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        communities = []
+        for row in rows:
+            community_dict = dict(row)
+            communities.append(Community(**community_dict))
+
+        return communities, count
+
+    async def add_community(self, community: Community) -> None:
+
+        # TODO: Fix in the short term.
+        # we need to do this because postgres insert needs to be a string
+        community.description_embedding = str(community.description_embedding)  # type: ignore[assignment]
+
+        non_null_attrs = {
+            k: v for k, v in community.__dict__.items() if v is not None
+        }
+        columns = ", ".join(non_null_attrs.keys())
+        placeholders = ", ".join(f"${i+1}" for i in range(len(non_null_attrs)))
+
+        conflict_columns = ", ".join(
+            [f"{k} = EXCLUDED.{k}" for k in non_null_attrs]
+        )
+
+        QUERY = f"""
+            INSERT INTO {self._get_table_name("graphs_communities")} ({columns})
+            VALUES ({placeholders})
+            ON CONFLICT (community_id, level, collection_id) DO UPDATE SET
+                {conflict_columns}
+            """
+
+        await self.connection_manager.execute_many(
+            QUERY, [tuple(non_null_attrs.values())]
+        )
+
+    async def delete_graph_for_collection(
+        self, collection_id: UUID, cascade: bool = False
+    ) -> None:
+
+        # don't delete if status is PROCESSING.
+        QUERY = f"""
+            SELECT graph_cluster_status FROM {self._get_table_name("collections")} WHERE id = $1
+        """
+        status = (
+            await self.connection_manager.fetch_query(QUERY, [collection_id])
+        )[0]["graph_cluster_status"]
+        if status == KGExtractionStatus.PROCESSING.value:
+            return
+
+        # remove all relationships for these documents.
+        DELETE_QUERIES = [
+            f"DELETE FROM {self._get_table_name('graphs_communities')} WHERE collection_id = $1;",
+        ]
+
+        # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
+        document_ids_response = (
+            await self.collections_handler.documents_in_collection(
+                offset=0,
+                limit=100,
+                collection_id=collection_id,
+            )
+        )
+
+        # This type ignore is due to insufficient typing of the documents_in_collection method
+        document_ids = [doc.id for doc in document_ids_response["results"]]  # type: ignore
+
+        # TODO: make these queries more efficient. Pass the document_ids as params.
+        if cascade:
+            DELETE_QUERIES += [
+                f"DELETE FROM {self._get_table_name('graphs_relationships')} WHERE document_id = ANY($1::uuid[]);",
+                f"DELETE FROM {self._get_table_name('graphs_entities')} WHERE document_id = ANY($1::uuid[]);",
+                f"DELETE FROM {self._get_table_name('graphs_entities')} WHERE collection_id = $1;",
+            ]
+
+            # setting the kg_creation_status to PENDING for this collection.
+            QUERY = f"""
+                UPDATE {self._get_table_name("documents")} SET extraction_status = $1 WHERE $2::uuid = ANY(collection_ids)
+            """
+            await self.connection_manager.execute_query(
+                QUERY, [KGExtractionStatus.PENDING, collection_id]
+            )
+
+        if document_ids:
+            for query in DELETE_QUERIES:
+                if "community" in query or "graphs_entities" in query:
+                    await self.connection_manager.execute_query(
+                        query, [collection_id]
+                    )
+                else:
+                    await self.connection_manager.execute_query(
+                        query, [document_ids]
+                    )
+
+        # set status to PENDING for this collection.
+        QUERY = f"""
+            UPDATE {self._get_table_name("collections")} SET graph_cluster_status = $1 WHERE id = $2
+        """
+        await self.connection_manager.execute_query(
+            QUERY, [KGExtractionStatus.PENDING, collection_id]
+        )
+
+    async def perform_graph_clustering(
+        self,
+        collection_id: UUID,
+        leiden_params: dict[str, Any],
+        clustering_mode: str,
+    ) -> Tuple[int, Any]:
+        """
+        Calls the external clustering service to cluster the KG.
+        """
+
+        offset = 0
+        page_size = 1000
+        all_relationships = []
+        while True:
+            relationships, count = await self.relationships.get(
+                parent_id=collection_id,
+                store_type=StoreType.GRAPHS,
+                offset=offset,
+                limit=page_size,
+            )
+
+            if not relationships:
+                break
+
+            all_relationships.extend(relationships)
+            offset += len(relationships)
+
+            if offset >= count:
+                break
+
+        relationship_ids_cache = await self._get_relationship_ids_cache(
+            all_relationships
+        )
+
+        logger.info(
+            f"Clustering over {len(all_relationships)} relationships for {collection_id} with settings: {leiden_params}"
+        )
+
+        return await self._cluster_and_add_community_info(
+            relationships=all_relationships,
+            relationship_ids_cache=relationship_ids_cache,
+            leiden_params=leiden_params,
+            collection_id=collection_id,
+            clustering_mode=clustering_mode,
+        )
+
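For reference, a plausible call into the clustering entry point above; handler and collection_id are assumed to exist, and max_cluster_size is an illustrative hierarchical_leiden parameter rather than a value taken from the source.

async def cluster_collection(handler, collection_id):
    leiden_params = {"max_cluster_size": 1000, "random_seed": 7272}
    num_communities, mapping = await handler.perform_graph_clustering(
        collection_id=collection_id,
        leiden_params=leiden_params,
        clustering_mode="local",  # "remote" would call CLUSTERING_SERVICE_URL instead
    )
    return num_communities, mapping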
+    async def _call_clustering_service(
+        self, relationships: list[Relationship], leiden_params: dict[str, Any]
+    ) -> list[dict]:
+        """
+        Calls the external Graspologic clustering service, sending relationships and parameters.
+        Expects a response with 'communities' field.
+        """
+        # Convert relationships to a JSON-friendly format
+        rel_data = []
+        for r in relationships:
+            rel_data.append(
+                {
+                    "id": str(r.id),
+                    "subject": r.subject,
+                    "object": r.object,
+                    "weight": r.weight if r.weight is not None else 1.0,
+                }
+            )
+
+        endpoint = os.environ.get("CLUSTERING_SERVICE_URL")
+        if not endpoint:
+            raise ValueError("CLUSTERING_SERVICE_URL not set.")
+
+        url = f"{endpoint}/cluster"
+
+        payload = {"relationships": rel_data, "leiden_params": leiden_params}
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(url, json=payload, timeout=3600)
+            response.raise_for_status()
+
+        data = response.json()
+        communities = data.get("communities", [])
+        return communities
+
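The request body built above and the response consumed further down (each community item carries node, cluster, and level) suggest the following wire format; all values here are placeholders, not taken from the source.

# Hypothetical request body POSTed to {CLUSTERING_SERVICE_URL}/cluster
example_payload = {
    "relationships": [
        {"id": "rel-id-placeholder", "subject": "Alice", "object": "Acme Corp", "weight": 1.0},
    ],
    "leiden_params": {"random_seed": 7272},
}

# Hypothetical response body, as read by _cluster_and_add_community_info in remote mode
example_response = {
    "communities": [
        {"node": "Alice", "cluster": 0, "level": 0},
        {"node": "Acme Corp", "cluster": 0, "level": 0},
    ]
}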
+    async def _create_graph_and_cluster(
+        self,
+        relationships: list[Relationship],
+        leiden_params: dict[str, Any],
+        clustering_mode: str = "remote",
+    ) -> Any:
+        """
+        Create a graph and cluster it. If clustering_mode='local', use hierarchical_leiden locally.
+        If clustering_mode='remote', call the external service.
+        """
+
+        if clustering_mode == "remote":
+            logger.info("Sending request to external clustering service...")
+            communities = await self._call_clustering_service(
+                relationships, leiden_params
+            )
+            logger.info("Received communities from clustering service.")
+            return communities
+        else:
+            # Local mode: run hierarchical_leiden directly
+            G = self.nx.Graph()
+            for relationship in relationships:
+                G.add_edge(
+                    relationship.subject,
+                    relationship.object,
+                    weight=relationship.weight,
+                    id=relationship.id,
+                )
+
+            logger.info(
+                f"Graph has {len(G.nodes)} nodes and {len(G.edges)} edges"
+            )
+            return await self._compute_leiden_communities(G, leiden_params)
+
+    async def _cluster_and_add_community_info(
+        self,
+        relationships: list[Relationship],
+        relationship_ids_cache: dict[str, list[int]],
+        leiden_params: dict[str, Any],
+        collection_id: Optional[UUID] = None,
+        clustering_mode: str = "local",
+    ) -> Tuple[int, Any]:
+
+        # clear if there is any old information
+        conditions = []
+        if collection_id is not None:
+            conditions.append("collection_id = $1")
+
+        await asyncio.sleep(0.1)
+
+        start_time = time.time()
+
+        logger.info(f"Creating graph and clustering for {collection_id}")
+
+        hierarchical_communities = await self._create_graph_and_cluster(
+            relationships=relationships,
+            leiden_params=leiden_params,
+            clustering_mode=clustering_mode,
+        )
+
+        logger.info(
+            f"Computing Leiden communities completed, time {time.time() - start_time:.2f} seconds."
+        )
+
+        def relationship_ids(node: str) -> list[int]:
+            return relationship_ids_cache.get(node, [])
+
+        logger.info(
+            f"Cached {len(relationship_ids_cache)} relationship ids, time {time.time() - start_time:.2f} seconds."
+        )
+
+        # If remote: hierarchical_communities is a list of dicts like:
+        # [{"node": str, "cluster": int, "level": int}, ...]
+        # If local: hierarchical_communities is the returned structure from hierarchical_leiden (list of named tuples)
+
+        if clustering_mode == "remote":
+            if not hierarchical_communities:
+                num_communities = 0
+            else:
+                num_communities = (
+                    max(item["cluster"] for item in hierarchical_communities)
+                    + 1
+                )
+        else:
+            # Local mode: hierarchical_communities returned by hierarchical_leiden
+            # According to the original code, it's likely a list of items with .cluster attribute
+            if not hierarchical_communities:
+                num_communities = 0
+            else:
+                num_communities = (
+                    max(item.cluster for item in hierarchical_communities) + 1
+                )
+
+        logger.info(
+            f"Generated {num_communities} communities, time {time.time() - start_time:.2f} seconds."
+        )
+
+        return num_communities, hierarchical_communities
+
+    async def _get_relationship_ids_cache(
+        self, relationships: list[Relationship]
+    ) -> dict[str, list[int]]:
+        relationship_ids_cache: dict[str, list[int]] = {}
+        for relationship in relationships:
+            if relationship.subject is not None:
+                relationship_ids_cache.setdefault(relationship.subject, [])
+                if relationship.id is not None:
+                    relationship_ids_cache[relationship.subject].append(
+                        relationship.id
+                    )
+            if relationship.object is not None:
+                relationship_ids_cache.setdefault(relationship.object, [])
+                if relationship.id is not None:
+                    relationship_ids_cache[relationship.object].append(
+                        relationship.id
+                    )
+
+        return relationship_ids_cache
+
+    async def get_entity_map(
+        self, offset: int, limit: int, document_id: UUID
+    ) -> dict[str, dict[str, list[dict[str, Any]]]]:
+
+        QUERY1 = f"""
+            WITH entities_list AS (
+                SELECT DISTINCT name
+                FROM {self._get_table_name("documents_entities")}
+                WHERE parent_id = $1
+                ORDER BY name ASC
+                LIMIT {limit} OFFSET {offset}
+            )
+            SELECT e.name, e.description, e.category,
+                   (SELECT array_agg(DISTINCT x) FROM unnest(e.chunk_ids) x) AS chunk_ids,
+                   e.parent_id
+            FROM {self._get_table_name("documents_entities")} e
+            JOIN entities_list el ON e.name = el.name
+            GROUP BY e.name, e.description, e.category, e.chunk_ids, e.parent_id
+            ORDER BY e.name;"""
+
+        entities_list = await self.connection_manager.fetch_query(
+            QUERY1, [document_id]
+        )
+        entities_list = [Entity(**entity) for entity in entities_list]
+
+        QUERY2 = f"""
+            WITH entities_list AS (
+                SELECT DISTINCT name
+                FROM {self._get_table_name("documents_entities")}
+                WHERE parent_id = $1
+                ORDER BY name ASC
+                LIMIT {limit} OFFSET {offset}
+            )
+            SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description,
+                   (SELECT array_agg(DISTINCT x) FROM unnest(t.chunk_ids) x) AS chunk_ids, t.parent_id
+            FROM {self._get_table_name("documents_relationships")} t
+            JOIN entities_list el ON t.subject = el.name
+            ORDER BY t.subject, t.predicate, t.object;
+        """
+
+        relationships_list = await self.connection_manager.fetch_query(
+            QUERY2, [document_id]
+        )
+        relationships_list = [
+            Relationship(**relationship) for relationship in relationships_list
+        ]
+
+        entity_map: dict[str, dict[str, list[Any]]] = {}
+        for entity in entities_list:
+            if entity.name not in entity_map:
+                entity_map[entity.name] = {"entities": [], "relationships": []}
+            entity_map[entity.name]["entities"].append(entity)
+
+        for relationship in relationships_list:
+            if relationship.subject in entity_map:
+                entity_map[relationship.subject]["relationships"].append(
+                    relationship
+                )
+            if relationship.object in entity_map:
+                entity_map[relationship.object]["relationships"].append(
+                    relationship
+                )
+
+        return entity_map
+
+    async def graph_search(
+        self, query: str, **kwargs: Any
+    ) -> AsyncGenerator[Any, None]:
+        """
+        Perform semantic search with similarity scores while maintaining exact same structure.
+        """
+
+        query_embedding = kwargs.get("query_embedding", None)
+        if query_embedding is None:
+            raise ValueError(
+                "query_embedding must be provided for semantic search"
+            )
+
+        search_type = kwargs.get(
+            "search_type", "entities"
+        )  # entities | relationships | communities
+        embedding_type = kwargs.get("embedding_type", "description_embedding")
+        property_names = kwargs.get("property_names", ["name", "description"])
+
+        # Add metadata if not present
+        if "metadata" not in property_names:
+            property_names.append("metadata")
+
+        filters = kwargs.get("filters", {})
+        limit = kwargs.get("limit", 10)
+        use_fulltext_search = kwargs.get("use_fulltext_search", True)
+        use_hybrid_search = kwargs.get("use_hybrid_search", True)
+
+        if use_hybrid_search or use_fulltext_search:
+            logger.warning(
+                "Hybrid and fulltext search not supported for graph search, ignoring."
+            )
+
+        table_name = f"graphs_{search_type}"
+        property_names_str = ", ".join(property_names)
+
+        # Build the WHERE clause from filters
+        params: list[Union[str, int, bytes]] = [
+            json.dumps(query_embedding),
+            limit,
+        ]
+        conditions_clause = self._build_filters(filters, params, search_type)
+        where_clause = (
+            f"WHERE {conditions_clause}" if conditions_clause else ""
+        )
+
+        # Construct the query
+        # Note: For vector similarity, we use <=> for distance. The smaller the number, the more similar.
+        # We'll convert that to similarity_score by doing (1 - distance).
+        QUERY = f"""
+            SELECT
+                {property_names_str},
+                ({embedding_type} <=> $1) as similarity_score
+            FROM {self._get_table_name(table_name)}
+            {where_clause}
+            ORDER BY {embedding_type} <=> $1
+            LIMIT $2;
+        """
+
+        results = await self.connection_manager.fetch_query(
+            QUERY, tuple(params)
+        )
+
+        for result in results:
+            output = {
+                prop: result[prop] for prop in property_names if prop in result
+            }
+            output["similarity_score"] = 1 - float(result["similarity_score"])
+            yield output
+
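Because graph_search is an async generator, callers iterate over it rather than awaiting a list. A minimal consumption sketch; the handler, embedding, and graph UUID are assumptions, and the filter keys follow _build_filters below (entities and relationships filter on parent_id, communities on collection_id).

async def top_entities(handler, query_embedding, graph_id, k=10):
    hits = []
    async for hit in handler.graph_search(
        "",  # the query text itself is unused here; only the embedding matters
        query_embedding=query_embedding,
        search_type="entities",
        property_names=["name", "description"],
        filters={"parent_id": {"$eq": str(graph_id)}},
        limit=k,
    ):
        hits.append((hit["name"], hit["similarity_score"]))
    return hits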
+    def _build_filters(
+        self, filter_dict: dict, parameters: list[Any], search_type: str
+    ) -> str:
+        """
+        Build a WHERE clause from a nested filter dictionary for the graph search.
+        For communities we use collection_id as primary key filter; for entities/relationships we use parent_id.
+        """
+
+        # Determine primary identifier column depending on search_type
+        # communities: use collection_id
+        # entities/relationships: use parent_id
+        base_id_column = (
+            "collection_id" if search_type == "communities" else "parent_id"
+        )
+
+        def parse_condition(key: str, value: Any) -> str:
+            # This function returns a single condition (string) or empty if no valid condition.
+            # Supported keys:
+            # - base_id_column (collection_id or parent_id)
+            # - metadata fields: metadata.some_field
+            # Supported ops: $eq, $ne, $lt, $lte, $gt, $gte, $in, $contains
+            if key == base_id_column:
+                # e.g. {"collection_id": {"$eq": "<some-uuid>"}}
+                if isinstance(value, dict):
+                    op, clause = next(iter(value.items()))
+                    if op == "$eq":
+                        parameters.append(str(clause))
+                        return f"{base_id_column} = ${len(parameters)}::uuid"
+                    elif op == "$in":
+                        # $in expects a list of UUIDs
+                        parameters.append([str(x) for x in clause])
+                        return f"{base_id_column} = ANY(${len(parameters)}::uuid[])"
+                else:
+                    # direct equality?
+                    parameters.append(str(value))
+                    return f"{base_id_column} = ${len(parameters)}::uuid"
+
+            elif key.startswith("metadata."):
+                # Handle metadata filters
+                # Example: {"metadata.some_key": {"$eq": "value"}}
+                field = key.split("metadata.")[1]
+
+                if isinstance(value, dict):
+                    op, clause = next(iter(value.items()))
+                    if op == "$eq":
+                        parameters.append(clause)
+                        return f"(metadata->>'{field}') = ${len(parameters)}"
+                    elif op == "$ne":
+                        parameters.append(clause)
+                        return f"(metadata->>'{field}') != ${len(parameters)}"
+                    elif op == "$lt":
+                        parameters.append(clause)
+                        return f"(metadata->>'{field}')::float < ${len(parameters)}::float"
+                    elif op == "$lte":
+                        parameters.append(clause)
+                        return f"(metadata->>'{field}')::float <= ${len(parameters)}::float"
+                    elif op == "$gt":
+                        parameters.append(clause)
+                        return f"(metadata->>'{field}')::float > ${len(parameters)}::float"
+                    elif op == "$gte":
+                        parameters.append(clause)
+                        return f"(metadata->>'{field}')::float >= ${len(parameters)}::float"
+                    elif op == "$in":
+                        # Ensure clause is a list
+                        if not isinstance(clause, list):
+                            raise Exception(
+                                "argument to $in filter must be a list"
+                            )
+                        # Append the Python list as a parameter; many drivers can convert Python lists to arrays
+                        parameters.append(clause)
+                        # Cast the parameter to a text array type
+                        return f"(metadata->>'{key}')::text = ANY(${len(parameters)}::text[])"
+                    elif op == "$contains":
+                        # $contains for metadata likely means metadata @> clause in JSON.
+                        # If clause is dict or list, we use json containment.
+                        parameters.append(json.dumps(clause))
+                        return f"metadata @> ${len(parameters)}::jsonb"
+                else:
+                    # direct equality
+                    parameters.append(value)
+                    return f"(metadata->>'{field}') = ${len(parameters)}"
+
+            # Add additional conditions for other columns if needed
+            # If key not recognized, return empty so it doesn't break query
+            return ""
+
+        def parse_filter(fd: dict) -> str:
+            filter_conditions = []
+            for k, v in fd.items():
+                if k == "$and":
+                    and_parts = [parse_filter(sub) for sub in v if sub]
+                    # Remove empty strings
+                    and_parts = [x for x in and_parts if x.strip()]
+                    if and_parts:
+                        filter_conditions.append(
+                            f"({' AND '.join(and_parts)})"
+                        )
+                elif k == "$or":
+                    or_parts = [parse_filter(sub) for sub in v if sub]
+                    # Remove empty strings
+                    or_parts = [x for x in or_parts if x.strip()]
+                    if or_parts:
+                        filter_conditions.append(f"({' OR '.join(or_parts)})")
+                else:
+                    # Regular condition
+                    c = parse_condition(k, v)
+                    if c and c.strip():
+                        filter_conditions.append(c)
+
+            if not filter_conditions:
+                return ""
+            if len(filter_conditions) == 1:
+                return filter_conditions[0]
+            return " AND ".join(filter_conditions)
+
+        return parse_filter(filter_dict)
+
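To illustrate how parse_filter flattens a nested filter into SQL, a hedged example of an input and the clause it would produce for an entity search; parameter numbering starts at $3 because graph_search has already appended the embedding and the limit.

# Hypothetical input to _build_filters (search_type="entities")
filters = {
    "$and": [
        {"parent_id": {"$eq": "00000000-0000-0000-0000-000000000000"}},
        {"metadata.source": {"$eq": "wikipedia"}},
    ]
}
# Approximate resulting WHERE fragment, with $3 and $4 appended to the parameter list:
#   (parent_id = $3::uuid AND (metadata->>'source') = $4)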
+
+    async def _compute_leiden_communities(
+        self,
+        graph: Any,
+        leiden_params: dict[str, Any],
+    ) -> Any:
+        """Compute Leiden communities."""
+        try:
+            from graspologic.partition import hierarchical_leiden
+
+            if "random_seed" not in leiden_params:
+                leiden_params["random_seed"] = (
+                    7272  # add seed to control randomness
+                )
+
+            start_time = time.time()
+            logger.info(
+                f"Running Leiden clustering with params: {leiden_params}"
+            )
+
+            community_mapping = hierarchical_leiden(graph, **leiden_params)
+
+            logger.info(
+                f"Leiden clustering completed in {time.time() - start_time:.2f} seconds."
+            )
+            return community_mapping
+
+        except ImportError as e:
+            raise ImportError("Please install the graspologic package.") from e
+
+    async def get_existing_document_entity_chunk_ids(
+        self, document_id: UUID
+    ) -> list[str]:
+        QUERY = f"""
+            SELECT DISTINCT unnest(chunk_ids) AS chunk_id FROM {self._get_table_name("documents_entities")} WHERE parent_id = $1
+        """
+        return [
+            item["chunk_id"]
+            for item in await self.connection_manager.fetch_query(
+                QUERY, [document_id]
+            )
+        ]
+
+    async def get_entity_count(
+        self,
+        collection_id: Optional[UUID] = None,
+        document_id: Optional[UUID] = None,
+        distinct: bool = False,
+        entity_table_name: str = "entity",
+    ) -> int:
+
+        if collection_id is None and document_id is None:
+            raise ValueError(
+                "Either collection_id or document_id must be provided."
+            )
+
+        conditions = ["parent_id = $1"]
+        params = [str(document_id)]
+
+        count_value = "DISTINCT name" if distinct else "*"
+
+        QUERY = f"""
+            SELECT COUNT({count_value}) FROM {self._get_table_name(entity_table_name)}
+            WHERE {" AND ".join(conditions)}
+        """
+
+        return (await self.connection_manager.fetch_query(QUERY, params))[0][
+            "count"
+        ]
+
+    async def update_entity_descriptions(self, entities: list[Entity]):
+
+        query = f"""
+            UPDATE {self._get_table_name("graphs_entities")}
+            SET description = $3, description_embedding = $4
+            WHERE name = $1 AND graph_id = $2
+        """
+
+        inputs = [
+            (
+                entity.name,
+                entity.parent_id,
+                entity.description,
+                entity.description_embedding,
+            )
+            for entity in entities
+        ]
+
+        await self.connection_manager.execute_many(query, inputs)  # type: ignore
+
+
+def _json_serialize(obj):
+    if isinstance(obj, UUID):
+        return str(obj)
+    elif isinstance(obj, (datetime.datetime, datetime.date)):
+        return obj.isoformat()
+    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
+
+
+async def _add_objects(
+    objects: list[dict],
+    full_table_name: str,
+    connection_manager: PostgresConnectionManager,
+    conflict_columns: list[str] = [],
+    exclude_metadata: list[str] = [],
+) -> list[UUID]:
+    """
+    Bulk insert objects into the specified table using jsonb_to_recordset.
+    """
+
+    # Exclude specified metadata and prepare data
+    cleaned_objects = []
+    for obj in objects:
+        cleaned_obj = {
+            k: v
+            for k, v in obj.items()
+            if k not in exclude_metadata and v is not None
+        }
+        cleaned_objects.append(cleaned_obj)
+
+    # Serialize the list of objects to JSON
+    json_data = json.dumps(cleaned_objects, default=_json_serialize)
+
+    if not cleaned_objects:
+        return []
+
+    # Prepare the column definitions for jsonb_to_recordset
+    columns = cleaned_objects[0].keys()
+    column_defs = []
+    for col in columns:
+        # Map Python types to PostgreSQL types
+        sample_value = cleaned_objects[0][col]
+        if "embedding" in col:
+            pg_type = "vector"
+        elif "chunk_ids" in col or "document_ids" in col or "graph_ids" in col:
+            pg_type = "uuid[]"
+        elif col == "id" or "_id" in col:
+            pg_type = "uuid"
+        elif isinstance(sample_value, str):
+            pg_type = "text"
+        elif isinstance(sample_value, UUID):
+            pg_type = "uuid"
+        elif isinstance(sample_value, bool):
+            # Checked before int/float because bool is a subclass of int
+            pg_type = "boolean"
+        elif isinstance(sample_value, (int, float)):
+            pg_type = "numeric"
+        elif isinstance(sample_value, list) and all(
+            isinstance(x, UUID) for x in sample_value
+        ):
+            pg_type = "uuid[]"
+        elif isinstance(sample_value, list):
+            pg_type = "jsonb"
+        elif isinstance(sample_value, dict):
+            pg_type = "jsonb"
+        elif isinstance(sample_value, (datetime.datetime, datetime.date)):
+            pg_type = "timestamp"
+        else:
+            raise TypeError(
+                f"Unsupported data type for column '{col}': {type(sample_value)}"
+            )
+
+        column_defs.append(f"{col} {pg_type}")
+
+    columns_str = ", ".join(columns)
+    column_defs_str = ", ".join(column_defs)
+
+    if conflict_columns:
+        conflict_columns_str = ", ".join(conflict_columns)
+        update_columns_str = ", ".join(
+            f"{col}=EXCLUDED.{col}"
+            for col in columns
+            if col not in conflict_columns
+        )
+        on_conflict_clause = f"ON CONFLICT ({conflict_columns_str}) DO UPDATE SET {update_columns_str}"
+    else:
+        on_conflict_clause = ""
+
+    QUERY = f"""
+        INSERT INTO {full_table_name} ({columns_str})
+        SELECT {columns_str}
+        FROM jsonb_to_recordset($1::jsonb)
+        AS x({column_defs_str})
+        {on_conflict_clause}
+        RETURNING id;
+    """
+
+    # Execute the query
+    result = await connection_manager.fetch_query(QUERY, [json_data])
+
+    # Extract and return the IDs
+    return [record["id"] for record in result]

+ 229 - 0
core/database/limits.py

@@ -0,0 +1,229 @@
+import logging
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+from uuid import UUID
+
+from core.base import Handler, R2RException
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger()
+
+
+class PostgresLimitsHandler(Handler):
+    TABLE_NAME = "request_log"
+
+    def __init__(
+        self,
+        project_name: str,
+        connection_manager: PostgresConnectionManager,
+        route_limits: dict,
+    ):
+        super().__init__(project_name, connection_manager)
+        self.route_limits = route_limits
+
+    async def create_tables(self):
+        query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
+            time TIMESTAMPTZ NOT NULL,
+            user_id UUID NOT NULL,
+            route TEXT NOT NULL
+        );
+        """
+        await self.connection_manager.execute_query(query)
+
+    async def _count_requests(
+        self, user_id: UUID, route: Optional[str], since: datetime
+    ) -> int:
+        if route:
+            query = f"""
+            SELECT COUNT(*)::int
+            FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+            WHERE user_id = $1
+              AND route = $2
+              AND time >= $3
+            """
+            params = [user_id, route, since]
+        else:
+            query = f"""
+            SELECT COUNT(*)::int
+            FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+            WHERE user_id = $1
+              AND time >= $2
+            """
+            params = [user_id, since]
+
+        result = await self.connection_manager.fetchrow_query(query, params)
+        return result["count"] if result else 0
+
+    async def _count_monthly_requests(self, user_id: UUID) -> int:
+        now = datetime.now(timezone.utc)
+        start_of_month = now.replace(
+            day=1, hour=0, minute=0, second=0, microsecond=0
+        )
+        return await self._count_requests(
+            user_id, route=None, since=start_of_month
+        )
+
+    async def check_limits(self, user_id: UUID, route: str):
+        limits = self.route_limits.get(
+            route,
+            {
+                "global_per_min": 60,
+                "route_per_min": 30,
+                "monthly_limit": 10000,
+            },
+        )
+
+        global_per_min = limits["global_per_min"]
+        route_per_min = limits["route_per_min"]
+        monthly_limit = limits["monthly_limit"]
+
+        now = datetime.now(timezone.utc)
+        one_min_ago = now - timedelta(minutes=1)
+
+        # Global per-minute check
+        user_req_count = await self._count_requests(user_id, None, one_min_ago)
+        print("min req count = ", user_req_count)
+        if user_req_count >= global_per_min:
+            raise ValueError("Global per-minute rate limit exceeded")
+
+        # Per-route per-minute check
+        route_req_count = await self._count_requests(
+            user_id, route, one_min_ago
+        )
+        if route_req_count >= route_per_min:
+            raise ValueError("Per-route per-minute rate limit exceeded")
+
+        # Monthly limit check
+        monthly_count = await self._count_monthly_requests(user_id)
+        print("monthly_count = ", monthly_count)
+
+        if monthly_count >= monthly_limit:
+            raise ValueError("Monthly rate limit exceeded")
+
+    async def log_request(self, user_id: UUID, route: str):
+        query = f"""
+        INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
+        VALUES (CURRENT_TIMESTAMP, $1, $2)
+        """
+        await self.connection_manager.execute_query(query, [user_id, route])
+
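A minimal sketch of wiring the handler above into a request path; the route string, limits dict, and surrounding HTTP layer are assumptions, while check_limits, log_request, and the default limits (60/30/10000) come from the code.

route_limits = {
    "/v3/retrieval/search": {   # illustrative route key
        "global_per_min": 60,
        "route_per_min": 30,
        "monthly_limit": 10_000,
    }
}

async def guarded_request(limits_handler, user_id):
    route = "/v3/retrieval/search"
    await limits_handler.check_limits(user_id, route)  # raises ValueError when over a limit
    await limits_handler.log_request(user_id, route)
    # ... perform the actual work for the request here ...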

+ 296 - 0
core/database/postgres.py

@@ -0,0 +1,296 @@
+# TODO: Clean this up and make it more congruent across the vector database and the relational database.
+import logging
+import os
+import warnings
+from typing import TYPE_CHECKING, Any, Optional
+
+from ..base.abstractions import VectorQuantizationType
+from ..base.providers import (
+    DatabaseConfig,
+    DatabaseProvider,
+    PostgresConfigurationSettings,
+)
+from .base import PostgresConnectionManager, SemaphoreConnectionPool
+from .chunks import PostgresChunksHandler
+from .collections import PostgresCollectionsHandler
+from .conversations import PostgresConversationsHandler
+from .documents import PostgresDocumentsHandler
+from .files import PostgresFilesHandler
+from .graphs import (
+    PostgresCommunitiesHandler,
+    PostgresEntitiesHandler,
+    PostgresGraphsHandler,
+    PostgresRelationshipsHandler,
+)
+from .limits import PostgresLimitsHandler
+from .prompts_handler import PostgresPromptsHandler
+from .tokens import PostgresTokensHandler
+from .users import PostgresUserHandler
+
+if TYPE_CHECKING:
+    from ..providers.crypto import BCryptProvider
+
+logger = logging.getLogger()
+
+
+def get_env_var(new_var, old_var, config_value):
+    value = config_value or os.getenv(new_var) or os.getenv(old_var)
+    if os.getenv(old_var) and not os.getenv(new_var):
+        warnings.warn(
+            f"{old_var} is deprecated and support for it will be removed in release 3.5.0. Use {new_var} instead."
+        )
+    return value
+
+
+class PostgresDatabaseProvider(DatabaseProvider):
+    # R2R configuration settings
+    config: DatabaseConfig
+    project_name: str
+
+    # Postgres connection settings
+    user: str
+    password: str
+    host: str
+    port: int
+    db_name: str
+    connection_string: str
+    dimension: int
+    conn: Optional[Any]
+
+    crypto_provider: "BCryptProvider"
+    postgres_configuration_settings: PostgresConfigurationSettings
+    default_collection_name: str
+    default_collection_description: str
+
+    connection_manager: PostgresConnectionManager
+    documents_handler: PostgresDocumentsHandler
+    collections_handler: PostgresCollectionsHandler
+    token_handler: PostgresTokensHandler
+    users_handler: PostgresUserHandler
+    chunks_handler: PostgresChunksHandler
+    entities_handler: PostgresEntitiesHandler
+    communities_handler: PostgresCommunitiesHandler
+    relationships_handler: PostgresRelationshipsHandler
+    graphs_handler: PostgresGraphsHandler
+    prompts_handler: PostgresPromptsHandler
+    files_handler: PostgresFilesHandler
+    conversations_handler: PostgresConversationsHandler
+    limits_handler: PostgresLimitsHandler
+
+    def __init__(
+        self,
+        config: DatabaseConfig,
+        dimension: int,
+        crypto_provider: "BCryptProvider",
+        quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(config)
+
+        env_vars = [
+            ("user", "R2R_POSTGRES_USER", "POSTGRES_USER"),
+            ("password", "R2R_POSTGRES_PASSWORD", "POSTGRES_PASSWORD"),
+            ("host", "R2R_POSTGRES_HOST", "POSTGRES_HOST"),
+            ("port", "R2R_POSTGRES_PORT", "POSTGRES_PORT"),
+            ("db_name", "R2R_POSTGRES_DBNAME", "POSTGRES_DBNAME"),
+        ]
+
+        for attr, new_var, old_var in env_vars:
+            if value := get_env_var(new_var, old_var, getattr(config, attr)):
+                setattr(self, attr, value)
+            else:
+                raise ValueError(
+                    f"Error, please set a valid {new_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`."
+                )
+
+        self.port = int(self.port)
+
+        self.project_name = (
+            get_env_var(
+                "R2R_PROJECT_NAME",
+                "R2R_POSTGRES_PROJECT_NAME",  # Remove this after deprecation
+                config.app.project_name,
+            )
+            or "r2r_default"
+        )
+
+        if not self.project_name:
+            raise ValueError(
+                "Error, please set a valid R2R_PROJECT_NAME environment variable or set a 'project_name' in the 'database' settings of your `r2r.toml`."
+            )
+
+        # Check if it's a Unix socket connection
+        if self.host.startswith("/") and not self.port:
+            self.connection_string = f"postgresql://{self.user}:{self.password}@/{self.db_name}?host={self.host}"
+            logger.info("Connecting to Postgres via Unix socket")
+        else:
+            self.connection_string = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}"
+            logger.info("Connecting to Postgres via TCP/IP")
+
+        self.dimension = dimension
+        self.quantization_type = quantization_type
+        self.conn = None
+        self.config: DatabaseConfig = config
+        self.crypto_provider = crypto_provider
+        self.postgres_configuration_settings: PostgresConfigurationSettings = (
+            self._get_postgres_configuration_settings(config)
+        )
+        self.default_collection_name = config.default_collection_name
+        self.default_collection_description = (
+            config.default_collection_description
+        )
+
+        self.connection_manager: PostgresConnectionManager = (
+            PostgresConnectionManager()
+        )
+        self.documents_handler = PostgresDocumentsHandler(
+            self.project_name, self.connection_manager, self.dimension
+        )
+        self.token_handler = PostgresTokensHandler(
+            self.project_name, self.connection_manager
+        )
+        self.collections_handler = PostgresCollectionsHandler(
+            self.project_name, self.connection_manager, self.config
+        )
+        self.users_handler = PostgresUserHandler(
+            self.project_name, self.connection_manager, self.crypto_provider
+        )
+        self.chunks_handler = PostgresChunksHandler(
+            self.project_name,
+            self.connection_manager,
+            self.dimension,
+            self.quantization_type,
+        )
+        self.conversations_handler = PostgresConversationsHandler(
+            self.project_name, self.connection_manager
+        )
+        self.entities_handler = PostgresEntitiesHandler(
+            project_name=self.project_name,
+            connection_manager=self.connection_manager,
+            collections_handler=self.collections_handler,
+            dimension=self.dimension,
+            quantization_type=self.quantization_type,
+        )
+        self.relationships_handler = PostgresRelationshipsHandler(
+            project_name=self.project_name,
+            connection_manager=self.connection_manager,
+            collections_handler=self.collections_handler,
+            dimension=self.dimension,
+            quantization_type=self.quantization_type,
+        )
+        self.communities_handler = PostgresCommunitiesHandler(
+            project_name=self.project_name,
+            connection_manager=self.connection_manager,
+            collections_handler=self.collections_handler,
+            dimension=self.dimension,
+            quantization_type=self.quantization_type,
+        )
+        self.graphs_handler = PostgresGraphsHandler(
+            project_name=self.project_name,
+            connection_manager=self.connection_manager,
+            collections_handler=self.collections_handler,
+            dimension=self.dimension,
+            quantization_type=self.quantization_type,
+        )
+        self.prompts_handler = PostgresPromptsHandler(
+            self.project_name, self.connection_manager
+        )
+        self.files_handler = PostgresFilesHandler(
+            self.project_name, self.connection_manager
+        )
+
+        self.limits_handler = PostgresLimitsHandler(
+            project_name=self.project_name,
+            connection_manager=self.connection_manager,
+            # TODO - this should be set in the config
+            route_limits={},
+        )
+
+    async def initialize(self):
+        logger.info("Initializing `PostgresDatabaseProvider`.")
+        self.pool = SemaphoreConnectionPool(
+            self.connection_string, self.postgres_configuration_settings
+        )
+        await self.pool.initialize()
+        await self.connection_manager.initialize(self.pool)
+
+        async with self.pool.get_connection() as conn:
+            await conn.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
+            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
+            await conn.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
+            await conn.execute("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;")
+
+            # Create schema if it doesn't exist
+            await conn.execute(
+                f'CREATE SCHEMA IF NOT EXISTS "{self.project_name}";'
+            )
+
+        await self.documents_handler.create_tables()
+        await self.collections_handler.create_tables()
+        await self.token_handler.create_tables()
+        await self.users_handler.create_tables()
+        await self.chunks_handler.create_tables()
+        await self.prompts_handler.create_tables()
+        await self.files_handler.create_tables()
+        await self.graphs_handler.create_tables()
+        await self.communities_handler.create_tables()
+        await self.entities_handler.create_tables()
+        await self.relationships_handler.create_tables()
+        await self.conversations_handler.create_tables()
+        await self.limits_handler.create_tables()
+
+    def _get_postgres_configuration_settings(
+        self, config: DatabaseConfig
+    ) -> PostgresConfigurationSettings:
+        settings = PostgresConfigurationSettings()
+
+        env_mapping = {
+            "checkpoint_completion_target": "R2R_POSTGRES_CHECKPOINT_COMPLETION_TARGET",
+            "default_statistics_target": "R2R_POSTGRES_DEFAULT_STATISTICS_TARGET",
+            "effective_cache_size": "R2R_POSTGRES_EFFECTIVE_CACHE_SIZE",
+            "effective_io_concurrency": "R2R_POSTGRES_EFFECTIVE_IO_CONCURRENCY",
+            "huge_pages": "R2R_POSTGRES_HUGE_PAGES",
+            "maintenance_work_mem": "R2R_POSTGRES_MAINTENANCE_WORK_MEM",
+            "min_wal_size": "R2R_POSTGRES_MIN_WAL_SIZE",
+            "max_connections": "R2R_POSTGRES_MAX_CONNECTIONS",
+            "max_parallel_workers_per_gather": "R2R_POSTGRES_MAX_PARALLEL_WORKERS_PER_GATHER",
+            "max_parallel_workers": "R2R_POSTGRES_MAX_PARALLEL_WORKERS",
+            "max_parallel_maintenance_workers": "R2R_POSTGRES_MAX_PARALLEL_MAINTENANCE_WORKERS",
+            "max_wal_size": "R2R_POSTGRES_MAX_WAL_SIZE",
+            "max_worker_processes": "R2R_POSTGRES_MAX_WORKER_PROCESSES",
+            "random_page_cost": "R2R_POSTGRES_RANDOM_PAGE_COST",
+            "statement_cache_size": "R2R_POSTGRES_STATEMENT_CACHE_SIZE",
+            "shared_buffers": "R2R_POSTGRES_SHARED_BUFFERS",
+            "wal_buffers": "R2R_POSTGRES_WAL_BUFFERS",
+            "work_mem": "R2R_POSTGRES_WORK_MEM",
+        }
+
+        for setting, env_var in env_mapping.items():
+            value = getattr(
+                config.postgres_configuration_settings, setting, None
+            )
+            if value is None:
+                value = os.getenv(env_var)
+
+            if value is not None:
+                field_type = settings.__annotations__[setting]
+                if field_type == Optional[int]:
+                    value = int(value)
+                elif field_type == Optional[float]:
+                    value = float(value)
+
+                setattr(settings, setting, value)
+
+        return settings
+
+    async def close(self):
+        if self.pool:
+            await self.pool.close()
+
+    async def __aenter__(self):
+        await self.initialize()
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.close()
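Because the provider defines `__aenter__`/`__aexit__`, it can be used as an async context manager: `initialize()` builds the connection pool, installs the required Postgres extensions, creates the project schema, and creates every handler's tables, while `close()` tears the pool down. A minimal sketch, assuming `config` and `crypto_provider` are constructed elsewhere and that the import path matches this file's location:

```python
import asyncio

from core.database.postgres import PostgresDatabaseProvider


async def main(config, crypto_provider) -> None:
    # __aenter__ calls initialize(); __aexit__ calls close().
    async with PostgresDatabaseProvider(
        config=config,
        dimension=768,  # example embedding dimension
        crypto_provider=crypto_provider,
    ) as db:
        # Every handler (documents, chunks, prompts, limits, ...) is ready here.
        print(db.project_name)

# asyncio.run(main(config, crypto_provider))  # with config/crypto_provider built elsewhere
```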

+ 0 - 0
core/database/prompts/__init__.py


+ 27 - 0
core/database/prompts/chunk_enrichment.yaml

@@ -0,0 +1,27 @@
+chunk_enrichment:
+  template: >
+    ## Task:
+
+    Enrich and refine the given chunk of text using information from the provided context chunks. The goal is to make the chunk more precise and self-contained.
+
+    ## Context Chunks:
+    {context_chunks}
+
+    ## Chunk to Enrich:
+    {chunk}
+
+    ## Instructions:
+    1. Rewrite the chunk in third person.
+    2. Replace all common nouns with appropriate proper nouns. Use specific names, titles, or identifiers instead of general terms.
+    3. Use information from the context chunks to enhance the clarity and precision of the given chunk.
+    4. Ensure the enriched chunk remains independent and self-contained.
+    5. Do not incorporate specific information or details from other chunks into this one.
+    6. Focus on making the chunk more informative and precise within its own scope.
+    7. Maintain the original meaning and intent of the chunk while improving its clarity and usefulness.
+    8. Just output the enriched chunk. Do not include any other text.
+
+    ## Enriched Chunk:
+
+  input_types:
+    chunk: str
+    context_chunks: str
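Each of these YAML files defines one named prompt with a `template` and the `input_types` it expects; `PostgresPromptsHandler._load_prompts_from_yaml_directory` (shown later in this commit) reads them with `yaml.safe_load` and stores the template for formatting. A minimal sketch of that load-and-render step, using example input values:

```python
import yaml

# Path follows this commit's layout: core/database/prompts/chunk_enrichment.yaml
with open("core/database/prompts/chunk_enrichment.yaml") as f:
    data = yaml.safe_load(f)

prompt = data["chunk_enrichment"]
rendered = prompt["template"].format(
    chunk="R2R supports hybrid search.",
    context_chunks="[1] R2R is a RAG engine built on Postgres.",
)
print(rendered)
```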

+ 41 - 0
core/database/prompts/default_collection_summary.yaml

@@ -0,0 +1,41 @@
+default_collection_summary:
+  template: >
+    ## Task:
+
+    Generate a comprehensive collection-level summary that describes the overall content, themes, and relationships across multiple documents. The summary should provide a high-level understanding of what the collection contains and represents.
+
+    ### Input Documents:
+
+    Document Summaries:
+    {document_summaries}
+
+    ### Requirements:
+
+    1. SCOPE
+    - Synthesize key themes and patterns across all documents
+    - Identify common topics, entities, and relationships
+    - Capture the collection's overall purpose or domain
+
+    2. STRUCTURE
+    - Target length: Approximately 3-4 concise sentences
+    - Focus on collective insights rather than individual document details
+
+    3. CONTENT GUIDELINES
+    - Emphasize shared concepts and recurring elements
+    - Highlight any temporal or thematic progression
+    - Identify key stakeholders or entities that appear across documents
+    - Note any significant relationships between documents
+
+    4. INTEGRATION PRINCIPLES
+    - Connect related concepts across different documents
+    - Identify overarching narratives or frameworks
+    - Preserve important context from individual documents
+    - Balance breadth of coverage with depth of insight
+
+    ### Query:
+
+    Generate a collection-level summary following the above requirements. Focus on synthesizing the key themes and relationships across all documents while maintaining clarity and concision.
+
+    ## Response:
+  input_types:
+    document_summaries: str

+ 28 - 0
core/database/prompts/default_rag.yaml

@@ -0,0 +1,28 @@
+default_rag:
+  template: >
+    ## Task:
+
+    Answer the query given immediately below, using the context that follows it. Use line-item references like [1], [2], ... to refer to the specifically numbered items in the provided context. Pay close attention to the title of each given source to ensure it is consistent with the query.
+
+
+    ### Query:
+
+    {query}
+
+
+    ### Context:
+
+    {context}
+
+
+    ### Query:
+
+    {query}
+
+
+    REMINDER - Use line-item references like [1], [2], ... to refer to the specifically numbered items in the provided context.
+
+    ## Response:
+  input_types:
+    query: str
+    context: str
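The template asks the model to cite numbered context items, so the caller is expected to pass a `context` string whose entries already carry `[1]`, `[2]`, ... markers. A small illustrative helper (not part of this commit) that builds such a string from retrieved chunks:

```python
def build_numbered_context(chunks: list[str]) -> str:
    # Prefix each retrieved chunk with its 1-based index so the model
    # can cite it as [1], [2], ... in its answer.
    return "\n\n".join(
        f"[{i}] {chunk}" for i, chunk in enumerate(chunks, start=1)
    )


context = build_numbered_context([
    "Paris is the capital of France.",
    "France is in Western Europe.",
])
```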

+ 18 - 0
core/database/prompts/default_summary.yaml

@@ -0,0 +1,18 @@
+default_summary:
+  template: >
+    ## Task:
+
+    Your task is to generate a descriptive summary of the document that follows. Your objective is to return a summary that is roughly 10% of the input document size while retaining as many key points as possible. Your response should begin with `The document contains `.
+
+    ### Document:
+
+    {document}
+
+
+    ### Query:
+
+    Reminder: Your task is to generate a descriptive summary of the document that was given. Your objective is to return a summary that is roughly 10% of the input document size while retaining as many key points as possible. Your response should begin with `The document contains `.
+
+    ## Response:
+  input_types:
+    document: str

+ 3 - 0
core/database/prompts/default_system.yaml

@@ -0,0 +1,3 @@
+default_system:
+  template: You are a helpful agent.
+  input_types: {}

+ 109 - 0
core/database/prompts/graphrag_communities.yaml

@@ -0,0 +1,109 @@
+graphrag_communities:
+  template: |
+      You are an AI assistant that helps a human analyst perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network.
+
+      # Context
+      Collection Overview:
+      {collection_description}
+
+      # Goal
+      Write a comprehensive report of a community within this collection, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will inform decision-makers about information associated with the community and their potential impact within the broader context of the collection. The content includes an overview of the community's key entities and noteworthy claims.
+
+      # Report Structure
+      The report should include:
+
+      - NAME: A specific, concise community name representing its key entities
+      - SUMMARY: An executive summary that contextualizes the community within the broader collection, explaining its structure, relationships, and significant information
+      - IMPACT SEVERITY RATING: A float score (0-10) representing the community's IMPACT severity relative to the overall collection
+      - RATING EXPLANATION: A single sentence explaining the IMPACT severity rating in context of the broader collection
+      - DETAILED FINDINGS: 5-10 key insights about the community, incorporating relevant collection-level context where appropriate
+
+
+      Output Format:
+      ```json
+      {{
+          "name": <report_name>,
+          "summary": <executive_summary>,
+          "rating": <impact_severity_rating>,
+          "rating_explanation": <rating_explanation>,
+          "findings": [
+              "<finding1>",
+              "<finding2>",
+              "<finding3>",
+              "<finding4>",
+              "<finding5>"
+              // Additional findings...
+          ]
+      }}
+      ```
+
+      # Grounding Rules
+
+      Points supported by data should list their data references as follows:
+
+      "This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
+
+      Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+      For example:
+      "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23)."
+
+      where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+
+      Do not include information where the supporting evidence for it is not provided.
+
+      # Example Input
+      -----------
+      Text:
+
+      Entity: OpenAI
+      descriptions:
+        101,OpenAI is an AI research and deployment company.
+      relationships:
+        201,OpenAI,Stripe,OpenAI partnered with Stripe to integrate payment solutions.
+        203,Airbnb,OpenAI,Airbnb utilizes OpenAI's AI tools for customer service.
+        204,Stripe,OpenAI,Stripe invested in OpenAI's latest funding round.
+      Entity: Stripe
+      descriptions:
+        102,Stripe is a technology company that builds economic infrastructure for the internet.
+      relationships:
+        201,OpenAI,Stripe,OpenAI partnered with Stripe to integrate payment solutions.
+        202,Stripe,Airbnb,Stripe provides payment processing services to Airbnb.
+        204,Stripe,OpenAI,Stripe invested in OpenAI's latest funding round.
+        205,Airbnb,Stripe,Airbnb and Stripe collaborate on expanding global payment options.
+      Entity: Airbnb
+      descriptions:
+        103,Airbnb is an online marketplace for lodging and tourism experiences.
+      relationships:
+        203,Airbnb,OpenAI,Airbnb utilizes OpenAI's AI tools for customer service.
+        205,Airbnb,Stripe,Airbnb and Stripe collaborate on expanding global payment options.
+
+      Output:
+      {{
+          "name": "OpenAI, Stripe, and Airbnb",
+          "summary": "The comprises key startups like OpenAI, Stripe, and Airbnb, which are interconnected through strategic partnerships and investments. These relationships highlight a robust network focused on advancing AI technologies, payment infrastructure, and online marketplaces.",
+          "rating": 7.5,
+          "rating_explanation": "The impact severity rating is high due to the significant influence these startups have on technology, finance, and the global economy.",
+          "findings": [
+              "OpenAI stands out as a leader in artificial intelligence research and deployment within YCombinator. Its partnerships with companies like Stripe and Airbnb demonstrate its integral role in integrating AI solutions across various industries. OpenAI's influence is further amplified by its involvement in key projects that drive innovation and efficiency. [Data: Entities (101), Relationships (201, 203, 204, +more)]",
+              "Stripe serves as a critical financial infrastructure provider, facilitating payment processing for startups like Airbnb and partnering with OpenAI to enhance payment solutions. Its strategic investments and collaborations underscore its importance in the Y Combinator ecosystem, enabling seamless financial transactions and supporting startup growth. [Data: Entities (102), Relationships (201, 202, 204, 205, +more)]",
+              "Airbnb leverages OpenAI's artificial intelligence tools to enhance its customer service capabilities, showcasing the practical application of AI in improving user experience. This integration highlights Airbnb's commitment to innovation and efficiency, positioning it as a forward-thinking leader within the community. [Data: Entities (103), Relationships (203, 205, +more)]",
+              "Stripe's investment in OpenAI's latest funding round illustrates the strategic financial moves that drive growth and innovation. Such investments not only strengthen partnerships but also foster an environment of collaboration and shared success among startups. [Data: Relationships (204)]",
+              "The collaboration between Airbnb and Stripe to expand global payment options demonstrates a commitment to scalability and accessibility in the Y Combinator ecosystem. This initiative is pivotal in enabling startups to reach a broader international market, thereby increasing their impact and revenue potential. [Data: Relationships (205)]"
+          ]
+      }}
+
+      # Real Data
+
+      Use the following text for your answer. Do not make anything up in your answer.
+
+      Collection Context:
+      {collection_description}
+
+      Entity Data:
+      {input_text}
+
+      Output:
+  input_types:
+    collection_description: str
+    input_text: str
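The template pins down a JSON output shape (`name`, `summary`, `rating`, `rating_explanation`, `findings`). A hedged sketch of how a caller might parse a model response in that shape; the fence stripping and field handling follow the prompt text, not any handler in this commit:

```python
import json


def parse_community_report(raw: str) -> dict:
    # The prompt requests a ```json fenced block; strip the fence if present.
    text = raw.strip()
    if text.startswith("```"):
        text = text.strip("`").removeprefix("json").strip()
    report = json.loads(text)
    return {
        "name": report["name"],
        "summary": report["summary"],
        "rating": float(report["rating"]),
        "rating_explanation": report["rating_explanation"],
        "findings": list(report["findings"]),
    }
```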

+ 24 - 0
core/database/prompts/graphrag_entity_deduplication.yaml

@@ -0,0 +1,24 @@
+graphrag_entity_deduplication:
+  template: |
+    You are an expert at deduplicating entity descriptions. You are given a list of entity descriptions and you need to merge them into a single description.
+
+    Entity Name:
+    {entity_name}
+
+    Entity Descriptions:
+    {entity_descriptions}
+
+    Your summary should:
+    1. Clearly define the entity's core concept or purpose.
+    2. Integrate any relevant information from the existing description.
+    3. Maintain a neutral, factual tone.
+    4. Make sure that all information from the original descriptions is included, but not repeated.
+    5. Do not hallucinate any information; use only the information provided.
+
+    Return the summary in the following format. Do not output anything else.
+
+    $$<Entity Description>$$
+
+  input_types:
+    entity_name: str
+    entity_descriptions: str

+ 39 - 0
core/database/prompts/graphrag_entity_description.yaml

@@ -0,0 +1,39 @@
+graphrag_entity_description:
+  template: |
+    Given the following information about an entity:
+
+    Document Summary:
+    {document_summary}
+
+    Entity Information:
+    {entity_info}
+
+    Relationship Data:
+    {relationships_txt}
+
+    Generate a comprehensive entity description that:
+
+    1. Opens with a clear definition statement identifying the entity's primary classification and core function
+    2. Incorporates key data points from both the document summary and relationship information
+    3. Emphasizes the entity's role within its broader context or system
+    4. Highlights critical relationships, particularly those that:
+      - Demonstrate hierarchical connections
+      - Show functional dependencies
+      - Indicate primary use cases or applications
+
+    Format Requirements:
+    - Length: 2-3 sentences
+    - Style: Technical and precise
+    - Structure: Definition + Context + Key Relationships
+    - Tone: Objective and authoritative
+
+    Integration Guidelines:
+    - Prioritize information that appears in multiple sources
+    - Resolve any conflicting information by favoring the most specific source
+    - Include temporal context if relevant to the entity's current state or evolution
+
+    Output should reflect the entity's complete nature while maintaining concision and clarity.
+  input_types:
+    document_summary: str
+    entity_info: str
+    relationships_txt: str

+ 55 - 0
core/database/prompts/graphrag_map_system.yaml

@@ -0,0 +1,55 @@
+graphrag_map_system:
+  template: |
+    ---Role---
+    You are a helpful assistant responding to questions about data in the tables provided.
+    ---Goal---
+    Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables.
+    You should use the data provided in the data tables below as the primary context for generating the response.
+    If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
+    Each key point in the response should have the following element:
+    - Description: A comprehensive description of the point.
+    - Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0.
+    The response should be JSON formatted as follows:
+    {{
+        "points": [
+            {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}},
+            {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}}
+        ]
+    }}
+    The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
+    Points supported by data should list the relevant reports as references as follows:
+    "This is an example sentence supported by data references [Data: Reports (report ids)]"
+    **Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+    For example:
+    "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]"
+    where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables.
+    Do not include information where the supporting evidence for it is not provided.
+    ---Data tables---
+    {context_data}
+    ---Goal---
+    Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables.
+    You should use the data provided in the data tables below as the primary context for generating the response.
+    If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
+    Each key point in the response should have the following element:
+    - Description: A comprehensive description of the point.
+    - Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0.
+    The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
+    Points supported by data should list the relevant reports as references as follows:
+    "This is an example sentence supported by data references [Data: Reports (report ids)]"
+    **Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+    For example:
+    "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]"
+    where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables.
+    Do not include information where the supporting evidence for it is not provided.
+    The response should be JSON formatted as follows:
+    {{
+        "points": [
+            {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}},
+            {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}}
+        ]
+    }}
+    ---Input---
+    {input}
+  input_types:
+    context_data: str
+    input: str

+ 43 - 0
core/database/prompts/graphrag_reduce_system.yaml

@@ -0,0 +1,43 @@
+graphrag_reduce_system:
+  template: |
+    ---Role---
+    You are a helpful assistant responding to questions about a dataset by synthesizing perspectives from multiple analysts.
+    ---Goal---
+    Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset.
+    Note that the analysts' reports provided below are ranked in the **descending order of importance**.
+    If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up.
+    The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format.
+    Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
+    The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
+    The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process.
+    **Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+    For example:
+    "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]"
+    where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+    Do not include information where the supporting evidence for it is not provided.
+    ---Target response length and format---
+    {response_type}
+    ---Analyst Reports---
+    {report_data}
+    ---Goal---
+    Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset.
+    Note that the analysts' reports provided below are ranked in the **descending order of importance**.
+    If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up.
+    The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format.
+    The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
+    The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process.
+    **Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+    For example:
+    "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]"
+    where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+    Do not include information where the supporting evidence for it is not provided.
+    ---Target response length and format---
+    {response_type}
+    Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
+    -- Query --
+    {input}
+  input_types:
+    response_type: str
+    report_data: str
+    input: str
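The map prompt returns scored points, and the reduce prompt expects the analysts' reports ranked in descending order of importance, so the glue code in between has to sort by score before filling `{report_data}`. An illustrative sketch of that step (the function, its inputs, and the assumption that each map response is bare JSON are all hypothetical):

```python
import json


def build_report_data(map_responses: list[str]) -> str:
    # Each map response is JSON of the form {"points": [{"description": ..., "score": ...}]}.
    points = []
    for raw in map_responses:
        points.extend(json.loads(raw).get("points", []))
    # The reduce prompt expects reports in descending order of importance.
    points.sort(key=lambda p: p["score"], reverse=True)
    return "\n".join(f"({p['score']}) {p['description']}" for p in points)
```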

+ 134 - 0
core/database/prompts/graphrag_relationships_extraction_few_shot.yaml

@@ -0,0 +1,134 @@
+graphrag_relationships_extraction_few_shot:
+  template: >
+    -Goal-
+    Given both a document summary and full text, identify all entities and their entity types, along with all relationships among the identified entities.
+    Extract up to {max_knowledge_relationships} entity-relation relationships using both the summary context and full text.
+
+    -Context Summary-
+    {document_summary}
+
+    -Steps-
+    1. Identify all entities in the full text, grounding and contextualizing them with the summary. For each identified entity, extract:
+    - entity_name: Name of the entity, capitalized
+    - entity_type: Type of the entity (constrained to {entity_types} if provided, otherwise all types)
+    - entity_description: Comprehensive description incorporating context from both summary and full text
+
+    Format each entity as ("entity"$$$$<entity_name>$$$$<entity_type>$$$$<entity_description>)
+    Note: Generate additional entities from descriptions if they contain named entities for relationship mapping.
+
+    2. From the identified entities, identify all related entity pairs, using both summary and full text context:
+    - source_entity: name of the source entity
+    - target_entity: name of the target entity
+    - relation: relationship type (constrained to {relation_types} if provided)
+    - relationship_description: justification based on both summary and full text context
+    - relationship_weight: strength score 0-10
+
+    Format each relationship as ("relationship"$$$$<source_entity>$$$$<target_entity>$$$$<relation>$$$$<relationship_description>$$$$<relationship_weight>)
+
+    3. Coverage Requirements:
+    - Each entity must have at least one relationship
+    - Create intermediate entities if needed to establish relationships
+    - Verify relationships against both summary and full text
+    - Resolve any discrepancies between sources
+
+    Example 1:
+    If the list is empty, extract all entities and relations.
+    Entity_types:
+    Relation_types:
+    Text:
+    The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%.
+    ######################
+    Output:
+    ("entity"$$$$Central Institution$$$$Organization$$$$The central bank of Verdantis, responsible for monetary policy and setting interest rates)
+    ("entity"$$$$Martin Smith$$$$Person$$$$Chair of the Central Institution of Verdantis)
+    ("entity"$$$$Market Strategy Committee$$$$Organization$$$$Committee within the Central Institution that makes key decisions on monetary policy)
+    ("entity"$$$$Monday$$$$Time$$$$First meeting day of the Central Institution)
+    ("entity"$$$$Thursday$$$$Time$$$$Second meeting day of the Central Institution, when policy decisions are announced)
+    ("entity"$$$$1:30 PM PDT$$$$Time$$$$Scheduled time for the Central Institution's policy decision release on Thursday)
+    ("entity"$$$$Press Conference$$$$Event$$$$Media briefing held by the Central Institution following the policy decision release)
+    ("entity"$$$$Interest Rate$$$$Economic Concept$$$$Key monetary policy tool used by the Central Institution to influence the economy)
+    ("entity"$$$$3.5%-3.75%$$$$Economic Value$$$$Expected range for the benchmark interest rate)
+    ("relationship"$$$$Martin Smith$$$$Central Institution$$$$Chairs$$$$Martin Smith is the Chair of the Central Institution and will lead the press conference$$$$9)
+    ("relationship"$$$$Central Institution$$$$Press Conference$$$$Conducts$$$$The Central Institution conducts a press conference following its policy decision release$$$$9)
+    ("relationship"$$$$Market Strategy Committee$$$$Central Institution$$$$Part Of$$$$The Market Strategy Committee is a key decision-making body within the Central Institution$$$$9)
+    ("relationship"$$$$Market Strategy Committee$$$$Interest Rate$$$$Sets$$$$The Market Strategy Committee determines the benchmark interest rate$$$$9)
+    ("relationship"$$$$Central Institution$$$$Interest Rate$$$$Controls$$$$The Central Institution controls interest rates as part of its monetary policy$$$$9)
+    ("relationship"$$$$3.5%-3.75%$$$$Interest Rate$$$$Expected Range$$$$Investors anticipate the benchmark interest rate to remain within this range$$$$8)
+    ("relationship"$$$$Monday$$$$Central Institution$$$$Meeting Day$$$$The Central Institution holds its first meeting of the week on Monday$$$$7)
+    ("relationship"$$$$Thursday$$$$Central Institution$$$$Decision Day$$$$The Central Institution announces its policy decision on Thursday$$$$9)
+    ("relationship"$$$$1:30 PM PDT$$$$Central Institution$$$$Press Conference$$$$The policy decision release at 1:30 PM PDT is followed by the press conference$$$$8)
+
+    ######################
+    Example 2:
+    If the list is empty, extract all entities and relations.
+    Entity_types: Organization
+    Relation_types: Formerly Owned By
+
+    Text:
+    TechGlobal's (TG) stock skyrocketed in its opening day on the Global Exchange Thursday. But IPO experts warn that the semiconductor corporation's debut on the public markets isn't indicative of how other newly listed companies may perform.
+
+    TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones.
+    ######################
+    Output:
+    ("entity"$$$$TECHGLOBAL$$$$Organization$$$$TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones)
+    ("entity"$$$$VISION HOLDINGS$$$$Organization$$$$Vision Holdings is a firm that previously owned TechGlobal)
+    ("relationship"$$$$TECHGLOBAL$$$$VISION HOLDINGS$$$$Formerly Owned By$$$$Vision Holdings formerly owned TechGlobal from 2014 until present$$$$5)
+
+    ######################
+    Example 3:
+    If the list is empty, extract all entities and relations.
+    Entity_types: Organization,Geo,Person
+    Relation_types: ""
+    Text:
+    Five Aurelians jailed for 8 years in Firuzabad and widely regarded as hostages are on their way home to Aurelia.
+
+    The swap orchestrated by Quintara was finalized when $8bn of Firuzi funds were transferred to financial institutions in Krohaara, the capital of Quintara.
+
+    The exchange initiated in Firuzabad's capital, Tiruzia, led to the four men and one woman, who are also Firuzi nationals, boarding a chartered flight to Krohaara.
+
+    They were welcomed by senior Aurelian officials and are now on their way to Aurelia's capital, Cashion.
+
+    The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality.
+    ######################
+    Output:
+    ("entity"$$$$FIRUZABAD$$$$Geo$$$$Firuzabad held Aurelians as hostages)
+    ("entity"$$$$AURELIA$$$$Geo$$$$Country seeking to release hostages)
+    ("entity"$$$$QUINTARA$$$$Geo$$$$Country that negotiated a swap of money in exchange for hostages)
+    ("entity"$$$$TIRUZIA$$$$Geo$$$$Capital of Firuzabad where the Aurelians were being held)
+    ("entity"$$$$KROHAARA$$$$Geo$$$$Capital city in Quintara)
+    ("entity"$$$$CASHION$$$$Geo$$$$Capital city in Aurelia)
+    ("entity"$$$$SAMUEL NAMARA$$$$Person$$$$Aurelian who spent time in Tiruzia's Alhamia Prison)
+    ("entity"$$$$ALHAMIA PRISON$$$$Geo$$$$Prison in Tiruzia)
+    ("entity"$$$$DURKE BATAGLANI$$$$Person$$$$Aurelian journalist who was held hostage)
+    ("entity"$$$$MEGGIE TAZBAH$$$$Person$$$$Bratinas national and environmentalist who was held hostage)
+    ("relationship"$$$$FIRUZABAD$$$$AURELIA$$$$Negotiated Hostage Exchange$$$$Firuzabad negotiated a hostage exchange with Aurelia$$$$2)
+    ("relationship"$$$$QUINTARA$$$$AURELIA$$$$Negotiated Hostage Exchange$$$$Quintara brokered the hostage exchange between Firuzabad and Aurelia$$$$2)
+    ("relationship"$$$$QUINTARA$$$$FIRUZABAD$$$$Negotiated Hostage Exchange$$$$Quintara brokered the hostage exchange between Firuzabad and Aurelia$$$$2)
+    ("relationship"$$$$SAMUEL NAMARA$$$$ALHAMIA PRISON$$$$Held At Alhamia Prison$$$$Samuel Namara was a prisoner at Alhamia prison$$$$8)
+    ("relationship"$$$$SAMUEL NAMARA$$$$MEGGIE TAZBAH$$$$Exchanged Hostages$$$$Samuel Namara and Meggie Tazbah were exchanged in the same hostage release$$$$2)
+    ("relationship"$$$$SAMUEL NAMARA$$$$DURKE BATAGLANI$$$$Exchanged Hostages$$$$Samuel Namara and Durke Bataglani were exchanged in the same hostage release$$$$2)
+    ("relationship"$$$$MEGGIE TAZBAH$$$$DURKE BATAGLANI$$$$Exchanged Hostages$$$$Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release$$$$2)
+    ("relationship"$$$$SAMUEL NAMARA$$$$FIRUZABAD$$$$Held As Hostage$$$$Samuel Namara was a hostage in Firuzabad$$$$2)
+    ("relationship"$$$$MEGGIE TAZBAH$$$$FIRUZABAD$$$$Held As Hostage$$$$Meggie Tazbah was a hostage in Firuzabad$$$$2)
+    ("relationship"$$$$DURKE BATAGLANI$$$$FIRUZABAD$$$$Held As Hostage$$$$Durke Bataglani was a hostage in Firuzabad$$$$2)
+
+    -Real Data-
+    ######################
+    If the list is empty, extract all entities and relations.
+    Entity_types: {entity_types}
+    Relation_types: {relation_types}
+
+    Document Summary:
+    {document_summary}
+
+    Full Text:
+    {input}
+    ######################
+    Output:
+  input_types:
+    document_summary: str
+    max_knowledge_relationships: int
+    input: str
+    entity_types: list[str]
+    relation_types: list[str]
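The extraction output uses `$$$$` as a field delimiter inside `("entity"...)` and `("relationship"...)` tuples. A minimal parsing sketch for that format; the regexes and return shape are illustrative, not the parser used elsewhere in R2R:

```python
import re

ENTITY_RE = re.compile(r'\("entity"\$\$\$\$(.+?)\$\$\$\$(.+?)\$\$\$\$(.+?)\)')
REL_RE = re.compile(
    r'\("relationship"\$\$\$\$(.+?)\$\$\$\$(.+?)\$\$\$\$(.+?)\$\$\$\$(.+?)\$\$\$\$(.+?)\)'
)


def parse_extraction(output: str) -> tuple[list[dict], list[dict]]:
    # Entities: ("entity"$$$$name$$$$type$$$$description)
    entities = [
        {"name": n, "type": t, "description": d}
        for n, t, d in ENTITY_RE.findall(output)
    ]
    # Relationships: ("relationship"$$$$source$$$$target$$$$relation$$$$description$$$$weight)
    relationships = [
        {"source": s, "target": t, "relation": r, "description": d, "weight": float(w)}
        for s, t, r, d, w in REL_RE.findall(output)
    ]
    return entities, relationships
```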

+ 29 - 0
core/database/prompts/hyde.yaml

@@ -0,0 +1,29 @@
+hyde:
+  template: >
+    ### Instruction:
+
+    Given the query that follows, write a double-newline-separated list of {num_outputs} distinct, single-paragraph attempted answers to the given query.
+
+
+    DO NOT generate any single answer which is likely to require information from multiple distinct documents.
+
+    EACH single answer will be used to carry out a cosine similarity semantic search over distinct indexed documents, such as varied medical documents.
+
+
+    FOR EXAMPLE if asked `how do the key themes of Great Gatsby compare with 1984`, the two attempted answers would be
+
+    `The key themes of Great Gatsby are ... ANSWER_CONTINUED` and `The key themes of 1984 are ... ANSWER_CONTINUED`, where `ANSWER_CONTINUED` IS TO BE COMPLETED BY YOU in your response.
+
+
+    Here is the original user query to be transformed into answers:
+
+
+    ### Query:
+
+    {message}
+
+
+    ### Response:
+  input_types:
+    num_outputs: int
+    message: str
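Since the prompt asks for a double-newline-separated list of hypothetical answers, the caller splits the completion on blank lines and embeds each answer separately for the cosine-similarity search. A minimal sketch of the split step (the embedding call is omitted and the helper name is hypothetical):

```python
def split_hyde_answers(completion: str) -> list[str]:
    # Each hypothetical answer is separated by a blank line, per the prompt.
    return [part.strip() for part in completion.split("\n\n") if part.strip()]


answers = split_hyde_answers(
    "The key themes of Great Gatsby are ...\n\nThe key themes of 1984 are ..."
)
# Each answer would then be embedded and used for a separate vector search.
```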

+ 16 - 0
core/database/prompts/rag_agent.yaml

@@ -0,0 +1,16 @@
+rag_agent:
+  template: >
+    ### You are a helpful agent that can search for information.
+
+
+    When asked a question, perform a search to find relevant information and provide a response.
+
+
+    The response should contain line-item attributions to relevant search results, and be as informative as possible.
+
+    If no relevant results are found, then state that no results were found.
+
+    If no obvious question is present, then do not carry out a search, and instead ask for clarification.
+
+    REMINDER - Use line-item references like [1], [2], ... to refer to the specifically numbered items in the provided context.
+  input_types: {}

+ 23 - 0
core/database/prompts/rag_context.yaml

@@ -0,0 +1,23 @@
+rag_context:
+  template: >+
+    ### Instruction:
+
+
+    You are given a `query` and an associated `context`. Your task is to score each sentence in the context, in order, as either 1 or 0, based on its relevance to the given query. For instance, if the query is "What is the capital of France?", then the sentence "The capital of France is Paris" would receive a 1, whereas "The French enjoy wine" would receive a 0. Return your response as a tuple containing a list of 1s and 0s, where each value corresponds to the respective sentence in the context, followed by the fraction of 1s over the total number of sentences (e.g. '1/4'). NOTE - do not include ANY extra text other than the requested tuple.
+
+
+    Query:
+
+    {query}
+
+
+    Context:
+
+    {context}
+
+
+    ### Response:
+
+  input_types:
+    query: str
+    context: str
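The expected response is a tuple of per-sentence 0/1 scores followed by the fraction of relevant sentences, e.g. `([1, 0, 1, 0], '2/4')`. A hedged parsing sketch (not part of this commit) that uses `ast.literal_eval`, since the payload is a Python-style tuple rather than JSON:

```python
import ast
from fractions import Fraction


def parse_relevancy(response: str) -> tuple[list[int], Fraction]:
    # Expects something like "([1, 0, 1, 0], '2/4')" from the model.
    scores, fraction_str = ast.literal_eval(response.strip())
    return list(scores), Fraction(fraction_str)


scores, fraction = parse_relevancy("([1, 0, 1, 0], '2/4')")
```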

+ 27 - 0
core/database/prompts/rag_fusion.yaml

@@ -0,0 +1,27 @@
+rag_fusion:
+  template: >
+    ### Instruction:
+
+
+    Given the query that follows, write a double-newline-separated list of up to {num_outputs} queries meant to help answer the original query.
+
+    DO NOT generate any single query which is likely to require information from multiple distinct documents.
+
+    EACH single query will be used to carry out a cosine similarity semantic search over distinct indexed documents, such as varied medical documents.
+
+    FOR EXAMPLE if asked `how do the key themes of Great Gatsby compare with 1984`, the two queries would be
+
+    `What are the key themes of Great Gatsby?` and `What are the key themes of 1984?`.
+
+    Here is the original user query to be transformed into queries:
+
+
+    ### Query:
+
+    {message}
+
+
+    ### Response:
+  input_types:
+    num_outputs: int
+    message: str

+ 4 - 0
core/database/prompts/vision_img.yaml

@@ -0,0 +1,4 @@
+vision_img:
+  template: >
+    First, provide a title for the image, then explain everything that you see. Be very thorough in your analysis as a user will need to understand the image without seeing it. If it is possible to transcribe the image to text directly, then do so. The more detail you provide, the better the user will understand the image.
+  input_types: {}

+ 42 - 0
core/database/prompts/vision_pdf.yaml

@@ -0,0 +1,42 @@
+vision_pdf:
+  template: >
+    Convert this PDF page to markdown format, preserving all content and formatting. Follow these guidelines:
+
+    Text:
+    - Maintain the original text hierarchy (headings, paragraphs, lists)
+    - Preserve any special formatting (bold, italic, underline)
+    - Include all footnotes, citations, and references
+    - Keep text in its original reading order
+
+    Tables:
+    - Recreate tables using markdown table syntax
+    - Preserve all headers, rows, and columns
+    - Maintain alignment and formatting where possible
+    - Include any table captions or notes
+
+    Equations:
+    - Convert mathematical equations using LaTeX notation
+    - Preserve equation numbers if present
+    - Include any surrounding context or references
+
+    Images:
+    - Enclose image descriptions within [FIG] and [/FIG] tags
+    - Include detailed descriptions of:
+      * Main subject matter
+      * Text overlays or captions
+      * Charts, graphs, or diagrams
+      * Relevant colors, patterns, or visual elements
+    - Maintain image placement relative to surrounding text
+
+    Additional Elements:
+    - Include page numbers if visible
+    - Preserve headers and footers
+    - Maintain sidebars or callout boxes
+    - Keep any special symbols or characters
+
+    Quality Requirements:
+    - Ensure 100% content preservation
+    - Maintain logical document flow
+    - Verify all markdown syntax is valid
+    - Double-check completeness before submitting
+  input_types: {}

+ 639 - 0
core/database/prompts_handler.py

@@ -0,0 +1,639 @@
+import json
+import logging
+import os
+from abc import abstractmethod
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Generic, Optional, TypeVar
+
+import yaml
+
+from core.base import Handler, generate_default_prompt_id
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+@dataclass
+class CacheEntry(Generic[T]):
+    """Represents a cached item with metadata"""
+
+    value: T
+    created_at: datetime
+    last_accessed: datetime
+    access_count: int = 0
+
+
+class Cache(Generic[T]):
+    """A generic cache implementation with TTL and LRU-like features"""
+
+    def __init__(
+        self,
+        ttl: Optional[timedelta] = None,
+        max_size: Optional[int] = 1000,
+        cleanup_interval: timedelta = timedelta(hours=1),
+    ):
+        self._cache: dict[str, CacheEntry[T]] = {}
+        self._ttl = ttl
+        self._max_size = max_size
+        self._cleanup_interval = cleanup_interval
+        self._last_cleanup = datetime.now()
+
+    def get(self, key: str) -> Optional[T]:
+        """Retrieve an item from cache"""
+        self._maybe_cleanup()
+
+        if key not in self._cache:
+            return None
+
+        entry = self._cache[key]
+
+        if self._ttl and datetime.now() - entry.created_at > self._ttl:
+            del self._cache[key]
+            return None
+
+        entry.last_accessed = datetime.now()
+        entry.access_count += 1
+        return entry.value
+
+    def set(self, key: str, value: T) -> None:
+        """Store an item in cache"""
+        self._maybe_cleanup()
+
+        now = datetime.now()
+        self._cache[key] = CacheEntry(
+            value=value, created_at=now, last_accessed=now
+        )
+
+        if self._max_size and len(self._cache) > self._max_size:
+            self._evict_lru()
+
+    def invalidate(self, key: str) -> None:
+        """Remove an item from cache"""
+        self._cache.pop(key, None)
+
+    def clear(self) -> None:
+        """Clear all cached items"""
+        self._cache.clear()
+
+    def _maybe_cleanup(self) -> None:
+        """Periodically clean up expired entries"""
+        now = datetime.now()
+        if now - self._last_cleanup > self._cleanup_interval:
+            self._cleanup()
+            self._last_cleanup = now
+
+    def _cleanup(self) -> None:
+        """Remove expired entries"""
+        if not self._ttl:
+            return
+
+        now = datetime.now()
+        expired = [
+            k for k, v in self._cache.items() if now - v.created_at > self._ttl
+        ]
+        for k in expired:
+            del self._cache[k]
+
+    def _evict_lru(self) -> None:
+        """Remove least recently used item"""
+        if not self._cache:
+            return
+
+        lru_key = min(
+            self._cache.keys(), key=lambda k: self._cache[k].last_accessed
+        )
+        del self._cache[lru_key]
+
+
+class CacheablePromptHandler(Handler):
+    """Abstract base class that adds caching capabilities to prompt handlers"""
+
+    def __init__(
+        self,
+        cache_ttl: Optional[timedelta] = timedelta(hours=1),
+        max_cache_size: Optional[int] = 1000,
+    ):
+        self._prompt_cache = Cache[str](ttl=cache_ttl, max_size=max_cache_size)
+        self._template_cache = Cache[dict](
+            ttl=cache_ttl, max_size=max_cache_size
+        )
+
+    def _cache_key(
+        self, prompt_name: str, inputs: Optional[dict] = None
+    ) -> str:
+        """Generate a cache key for a prompt request"""
+        if inputs:
+            # Sort dict items for consistent keys
+            sorted_inputs = sorted(inputs.items())
+            return f"{prompt_name}:{sorted_inputs}"
+        return prompt_name
+
+    async def get_cached_prompt(
+        self,
+        prompt_name: str,
+        inputs: Optional[dict[str, Any]] = None,
+        prompt_override: Optional[str] = None,
+        bypass_cache: bool = False,
+    ) -> str:
+        """Get a prompt with caching support"""
+        if prompt_override:
+            if inputs:
+                try:
+                    return prompt_override.format(**inputs)
+                except KeyError:
+                    return prompt_override
+            return prompt_override
+
+        cache_key = self._cache_key(prompt_name, inputs)
+
+        if not bypass_cache:
+            cached = self._prompt_cache.get(cache_key)
+            if cached is not None:
+                logger.debug(f"Cache hit for prompt: {cache_key}")
+                return cached
+
+        result = await self._get_prompt_impl(prompt_name, inputs)
+        self._prompt_cache.set(cache_key, result)
+        return result
+
+    async def get_prompt(  # type: ignore
+        self,
+        name: str,
+        inputs: Optional[dict] = None,
+        prompt_override: Optional[str] = None,
+    ) -> dict:
+        query = f"""
+        SELECT id, name, template, input_types, created_at, updated_at
+        FROM {self._get_table_name("prompts")}
+        WHERE name = $1;
+        """
+        result = await self.connection_manager.fetchrow_query(query, [name])
+
+        if not result:
+            raise ValueError(f"Prompt template '{name}' not found")
+
+        input_types = result["input_types"]
+        if isinstance(input_types, str):
+            input_types = json.loads(input_types)
+
+        return {
+            "id": result["id"],
+            "name": result["name"],
+            "template": result["template"],
+            "input_types": input_types,
+            "created_at": result["created_at"],
+            "updated_at": result["updated_at"],
+        }
+
+    @abstractmethod
+    async def _get_prompt_impl(
+        self, prompt_name: str, inputs: Optional[dict[str, Any]] = None
+    ) -> str:
+        """Implementation of prompt retrieval logic"""
+        pass
+
+    async def update_prompt(
+        self,
+        name: str,
+        template: Optional[str] = None,
+        input_types: Optional[dict[str, str]] = None,
+    ) -> None:
+        """Public method to update a prompt with proper cache invalidation"""
+        # First invalidate all caches for this prompt
+        self._template_cache.invalidate(name)
+        cache_keys_to_invalidate = [
+            key
+            for key in self._prompt_cache._cache.keys()
+            if key.startswith(f"{name}:") or key == name
+        ]
+        for key in cache_keys_to_invalidate:
+            self._prompt_cache.invalidate(key)
+
+        # Perform the update
+        await self._update_prompt_impl(name, template, input_types)
+
+        # Force refresh template cache
+        template_info = await self._get_template_info(name)
+        if template_info:
+            self._template_cache.set(name, template_info)
+
+    @abstractmethod
+    async def _update_prompt_impl(
+        self,
+        name: str,
+        template: Optional[str] = None,
+        input_types: Optional[dict[str, str]] = None,
+    ) -> None:
+        """Implementation of prompt update logic"""
+        pass
+
+    @abstractmethod
+    async def _get_template_info(self, prompt_name: str) -> Optional[dict]:
+        """Get template info with caching"""
+        pass
+
+
+class PostgresPromptsHandler(CacheablePromptHandler):
+    """PostgreSQL implementation of the CacheablePromptHandler."""
+
+    def __init__(
+        self,
+        project_name: str,
+        connection_manager: PostgresConnectionManager,
+        prompt_directory: Optional[Path] = None,
+        **cache_options,
+    ):
+        super().__init__(**cache_options)
+        self.prompt_directory = (
+            prompt_directory or Path(os.path.dirname(__file__)) / "prompts"
+        )
+        self.connection_manager = connection_manager
+        self.project_name = project_name
+        self.prompts: dict[str, dict[str, str | dict[str, str]]] = {}
+
+    async def _load_prompts(self) -> None:
+        """Load prompts from both database and YAML files."""
+        # First load from database
+        await self._load_prompts_from_database()
+
+        # Then load from YAML files, potentially overriding unmodified database entries
+        await self._load_prompts_from_yaml_directory()
+
+    async def _load_prompts_from_database(self) -> None:
+        """Load prompts from the database."""
+        query = f"""
+        SELECT id, name, template, input_types, created_at, updated_at
+        FROM {self._get_table_name("prompts")};
+        """
+        try:
+            results = await self.connection_manager.fetch_query(query)
+            for row in results:
+                logger.info(f"Loading saved prompt: {row['name']}")
+
+                # Ensure input_types is a dictionary
+                input_types = row["input_types"]
+                if isinstance(input_types, str):
+                    input_types = json.loads(input_types)
+
+                self.prompts[row["name"]] = {
+                    "id": row["id"],
+                    "template": row["template"],
+                    "input_types": input_types,
+                    "created_at": row["created_at"],
+                    "updated_at": row["updated_at"],
+                }
+                # Pre-populate the template cache
+                self._template_cache.set(
+                    row["name"],
+                    {
+                        "id": row["id"],
+                        "template": row["template"],
+                        "input_types": input_types,
+                    },
+                )
+            logger.debug(f"Loaded {len(results)} prompts from database")
+        except Exception as e:
+            logger.error(f"Failed to load prompts from database: {e}")
+            raise
+
+    async def _load_prompts_from_yaml_directory(self) -> None:
+        """Load prompts from YAML files in the specified directory."""
+        if not self.prompt_directory.is_dir():
+            logger.warning(
+                f"Prompt directory not found: {self.prompt_directory}"
+            )
+            return
+
+        logger.info(f"Loading prompts from {self.prompt_directory}")
+        for yaml_file in self.prompt_directory.glob("*.yaml"):
+            logger.debug(f"Processing {yaml_file}")
+            try:
+                with open(yaml_file, "r") as file:
+                    data = yaml.safe_load(file)
+                    if not isinstance(data, dict):
+                        raise ValueError(
+                            f"Invalid format in YAML file {yaml_file}"
+                        )
+
+                    for name, prompt_data in data.items():
+                        should_modify = True
+                        if name in self.prompts:
+                            # Only modify if the prompt hasn't been updated since creation
+                            existing = self.prompts[name]
+                            should_modify = (
+                                existing["created_at"]
+                                == existing["updated_at"]
+                            )
+
+                        if should_modify:
+                            logger.info(f"Loading default prompt: {name}")
+                            await self.add_prompt(
+                                name=name,
+                                template=prompt_data["template"],
+                                input_types=prompt_data.get("input_types", {}),
+                                preserve_existing=False,
+                            )
+            except Exception as e:
+                logger.error(f"Error loading {yaml_file}: {e}")
+                continue
+
+    def _get_table_name(self, base_name: str) -> str:
+        """Get the fully qualified table name."""
+        return f"{self.project_name}.{base_name}"
+
+    # Implementation of abstract methods from CacheablePromptHandler
+    async def _get_prompt_impl(
+        self, prompt_name: str, inputs: Optional[dict[str, Any]] = None
+    ) -> str:
+        """Implementation of database prompt retrieval"""
+        template_info = await self._get_template_info(prompt_name)
+
+        if not template_info:
+            raise ValueError(f"Prompt template '{prompt_name}' not found")
+
+        template, input_types = (
+            template_info["template"],
+            template_info["input_types"],
+        )
+
+        if inputs:
+            # Validate that every supplied key is declared for this prompt
+            for key in inputs:
+                if key not in input_types:
+                    raise ValueError(
+                        f"Unexpected input key '{key}'; expected input keys: {list(input_types)}"
+                    )
+            return template.format(**inputs)
+
+        return template
+
+    async def _get_template_info(self, prompt_name: str) -> Optional[dict]:  # type: ignore
+        """Get template info with caching"""
+        cached = self._template_cache.get(prompt_name)
+        if cached is not None:
+            return cached
+
+        query = f"""
+        SELECT template, input_types
+        FROM {self._get_table_name("prompts")}
+        WHERE name = $1;
+        """
+
+        result = await self.connection_manager.fetchrow_query(
+            query, [prompt_name]
+        )
+
+        if result:
+            # Ensure input_types is a dictionary
+            input_types = result["input_types"]
+            if isinstance(input_types, str):
+                input_types = json.loads(input_types)
+
+            template_info = {
+                "template": result["template"],
+                "input_types": input_types,
+            }
+            self._template_cache.set(prompt_name, template_info)
+            return template_info
+
+        return None
+
+    async def _update_prompt_impl(
+        self,
+        name: str,
+        template: Optional[str] = None,
+        input_types: Optional[dict[str, str]] = None,
+    ) -> None:
+        """Implementation of database prompt update with proper connection handling"""
+        if not template and not input_types:
+            return
+
+        # Clear caches first
+        self._template_cache.invalidate(name)
+        for key in list(self._prompt_cache._cache.keys()):
+            if key.startswith(f"{name}:"):
+                self._prompt_cache.invalidate(key)
+
+        # Build update query
+        set_clauses = []
+        params = [name]  # First parameter is always the name
+        param_index = 2  # Start from 2 since $1 is name
+
+        if template:
+            set_clauses.append(f"template = ${param_index}")
+            params.append(template)
+            param_index += 1
+
+        if input_types:
+            set_clauses.append(f"input_types = ${param_index}")
+            params.append(json.dumps(input_types))
+            param_index += 1
+
+        set_clauses.append("updated_at = CURRENT_TIMESTAMP")
+
+        query = f"""
+        UPDATE {self._get_table_name("prompts")}
+        SET {', '.join(set_clauses)}
+        WHERE name = $1
+        RETURNING id, template, input_types;
+        """
+
+        try:
+            # Execute update and get returned values
+            result = await self.connection_manager.fetchrow_query(
+                query, params
+            )
+
+            if not result:
+                raise ValueError(f"Prompt template '{name}' not found")
+
+            # Update in-memory state
+            if name in self.prompts:
+                if template:
+                    self.prompts[name]["template"] = template
+                if input_types:
+                    self.prompts[name]["input_types"] = input_types
+                self.prompts[name]["updated_at"] = datetime.now().isoformat()
+
+        except Exception as e:
+            logger.error(f"Failed to update prompt {name}: {str(e)}")
+            raise
+
+    async def create_tables(self):
+        """Create the necessary tables for storing prompts."""
+        query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name("prompts")} (
+            id UUID PRIMARY KEY,
+            name VARCHAR(255) NOT NULL UNIQUE,
+            template TEXT NOT NULL,
+            input_types JSONB NOT NULL,
+            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
+        );
+
+        CREATE OR REPLACE FUNCTION {self.project_name}.update_updated_at_column()
+        RETURNS TRIGGER AS $$
+        BEGIN
+            NEW.updated_at = CURRENT_TIMESTAMP;
+            RETURN NEW;
+        END;
+        $$ language 'plpgsql';
+
+        DROP TRIGGER IF EXISTS update_prompts_updated_at
+        ON {self._get_table_name("prompts")};
+
+        CREATE TRIGGER update_prompts_updated_at
+            BEFORE UPDATE ON {self._get_table_name("prompts")}
+            FOR EACH ROW
+            EXECUTE FUNCTION {self.project_name}.update_updated_at_column();
+        """
+        await self.connection_manager.execute_query(query)
+        await self._load_prompts()
+
+    async def add_prompt(
+        self,
+        name: str,
+        template: str,
+        input_types: dict[str, str],
+        preserve_existing: bool = False,
+    ) -> None:
+        """Add or update a prompt."""
+        if preserve_existing and name in self.prompts:
+            return
+
+        id = generate_default_prompt_id(name)
+
+        # Ensure input_types is properly serialized
+        input_types_json = (
+            json.dumps(input_types)
+            if isinstance(input_types, dict)
+            else input_types
+        )
+
+        query = f"""
+        INSERT INTO {self._get_table_name("prompts")} (id, name, template, input_types)
+        VALUES ($1, $2, $3, $4)
+        ON CONFLICT (name) DO UPDATE
+        SET template = EXCLUDED.template,
+            input_types = EXCLUDED.input_types,
+            updated_at = CURRENT_TIMESTAMP
+        RETURNING id, created_at, updated_at;
+        """
+
+        result = await self.connection_manager.fetchrow_query(
+            query, [id, name, template, input_types_json]
+        )
+
+        self.prompts[name] = {
+            "id": result["id"],
+            "template": template,
+            "input_types": input_types,
+            "created_at": result["created_at"],
+            "updated_at": result["updated_at"],
+        }
+
+        # Update template cache
+        self._template_cache.set(
+            name,
+            {
+                "id": id,
+                "template": template,
+                "input_types": input_types,
+            },  # Store as dict in cache
+        )
+
+        # Invalidate any cached formatted prompts
+        for key in list(self._prompt_cache._cache.keys()):
+            if key.startswith(f"{name}:"):
+                self._prompt_cache.invalidate(key)
+
+    async def get_all_prompts(self) -> dict[str, Any]:
+        """Retrieve all stored prompts."""
+        query = f"""
+        SELECT id, name, template, input_types, created_at, updated_at, COUNT(*) OVER() AS total_entries
+        FROM {self._get_table_name("prompts")};
+        """
+        results = await self.connection_manager.fetch_query(query)
+
+        if not results:
+            return {"results": [], "total_entries": 0}
+
+        total_entries = results[0]["total_entries"] if results else 0
+
+        prompts = [
+            {
+                "name": row["name"],
+                "id": row["id"],
+                "template": row["template"],
+                "input_types": (
+                    json.loads(row["input_types"])
+                    if isinstance(row["input_types"], str)
+                    else row["input_types"]
+                ),
+                "created_at": row["created_at"],
+                "updated_at": row["updated_at"],
+            }
+            for row in results
+        ]
+
+        return {"results": prompts, "total_entries": total_entries}
+
+    async def delete_prompt(self, name: str) -> None:
+        """Delete a prompt template."""
+        query = f"""
+        DELETE FROM {self._get_table_name("prompts")}
+        WHERE name = $1;
+        """
+        result = await self.connection_manager.execute_query(query, [name])
+        if result == "DELETE 0":
+            raise ValueError(f"Prompt template '{name}' not found")
+
+        # Invalidate caches
+        self._template_cache.invalidate(name)
+        for key in list(self._prompt_cache._cache.keys()):
+            if key.startswith(f"{name}:"):
+                self._prompt_cache.invalidate(key)
+
+    async def get_message_payload(
+        self,
+        system_prompt_name: Optional[str] = None,
+        system_role: str = "system",
+        system_inputs: dict = {},
+        system_prompt_override: Optional[str] = None,
+        task_prompt_name: Optional[str] = None,
+        task_role: str = "user",
+        task_inputs: dict = {},
+        task_prompt_override: Optional[str] = None,
+    ) -> list[dict]:
+        """Create a message payload from system and task prompts."""
+        if system_prompt_override:
+            system_prompt = system_prompt_override
+        else:
+            system_prompt = await self.get_cached_prompt(
+                system_prompt_name or "default_system",
+                system_inputs,
+                prompt_override=system_prompt_override,
+            )
+
+        task_prompt = await self.get_cached_prompt(
+            task_prompt_name or "default_rag",
+            task_inputs,
+            prompt_override=task_prompt_override,
+        )
+
+        return [
+            {
+                "role": system_role,
+                "content": system_prompt,
+            },
+            {
+                "role": task_role,
+                "content": task_prompt,
+            },
+        ]
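
A short usage sketch may help tie the prompt-handler pieces together. It is illustrative only: the `connection_manager` argument, the "r2r_default" project name, and the "greeting" prompt are assumptions, and `PostgresPromptsHandler` is the class defined in the diff above (its import path depends on how the package is laid out).

import asyncio


async def demo_prompts(connection_manager) -> None:
    # `connection_manager` is assumed to be an already-connected
    # PostgresConnectionManager; its construction is not shown in this diff.
    handler = PostgresPromptsHandler(
        project_name="r2r_default",  # hypothetical schema/project name
        connection_manager=connection_manager,
    )
    await handler.create_tables()  # also loads prompts from DB and YAML

    await handler.add_prompt(
        name="greeting",
        template="Say hello to {user}.",
        input_types={"user": "str"},
    )

    # Build a two-message payload; "default_system" is assumed to be among
    # the default prompts loaded at startup.
    messages = await handler.get_message_payload(
        task_prompt_name="greeting",
        task_inputs={"user": "Alice"},
    )
    print(messages)


# asyncio.run(demo_prompts(connection_manager))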

+ 67 - 0
core/database/tokens.py

@@ -0,0 +1,67 @@
+from datetime import datetime, timedelta
+from typing import Optional
+
+from core.base import Handler
+
+from .base import PostgresConnectionManager
+
+
+class PostgresTokensHandler(Handler):
+    TABLE_NAME = "blacklisted_tokens"
+
+    def __init__(
+        self, project_name: str, connection_manager: PostgresConnectionManager
+    ):
+        super().__init__(project_name, connection_manager)
+
+    async def create_tables(self):
+        query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (
+            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+            token TEXT NOT NULL,
+            blacklisted_at TIMESTAMPTZ DEFAULT NOW()
+        );
+        CREATE INDEX IF NOT EXISTS idx_{self.project_name}_{PostgresTokensHandler.TABLE_NAME}_token
+        ON {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (token);
+        CREATE INDEX IF NOT EXISTS idx_{self.project_name}_{PostgresTokensHandler.TABLE_NAME}_blacklisted_at
+        ON {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (blacklisted_at);
+        """
+        await self.connection_manager.execute_query(query)
+
+    async def blacklist_token(
+        self, token: str, current_time: Optional[datetime] = None
+    ):
+        if current_time is None:
+            current_time = datetime.utcnow()
+
+        query = f"""
+        INSERT INTO {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (token, blacklisted_at)
+        VALUES ($1, $2)
+        """
+        await self.connection_manager.execute_query(
+            query, [token, current_time]
+        )
+
+    async def is_token_blacklisted(self, token: str) -> bool:
+        query = f"""
+        SELECT 1 FROM {self._get_table_name(PostgresTokensHandler.TABLE_NAME)}
+        WHERE token = $1
+        LIMIT 1
+        """
+        result = await self.connection_manager.fetchrow_query(query, [token])
+        return bool(result)
+
+    async def clean_expired_blacklisted_tokens(
+        self,
+        max_age_hours: int = 7 * 24,
+        current_time: Optional[datetime] = None,
+    ):
+        if current_time is None:
+            current_time = datetime.utcnow()
+        expiry_time = current_time - timedelta(hours=max_age_hours)
+
+        query = f"""
+        DELETE FROM {self._get_table_name(PostgresTokensHandler.TABLE_NAME)}
+        WHERE blacklisted_at < $1
+        """
+        await self.connection_manager.execute_query(query, [expiry_time])
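
The blacklist above is a small lookup table; below is a hedged sketch of how it might be wired into a logout and request-validation path. The import path matches the file added here, but the "r2r_default" project name, the `connection_manager`, and the surrounding auth flow are assumptions.

from core.database.tokens import PostgresTokensHandler


async def token_blacklist_demo(connection_manager, token: str) -> None:
    tokens = PostgresTokensHandler("r2r_default", connection_manager)
    await tokens.create_tables()

    # On logout: record the token so it can no longer be used.
    await tokens.blacklist_token(token)

    # On each authenticated request: reject blacklisted tokens.
    assert await tokens.is_token_blacklisted(token)

    # Periodic maintenance; default retention is 7 days (7 * 24 hours).
    await tokens.clean_expired_blacklisted_tokens()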

+ 660 - 0
core/database/users.py

@@ -0,0 +1,660 @@
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+from fastapi import HTTPException
+
+from core.base import CryptoProvider, Handler
+from core.base.abstractions import R2RException
+from core.utils import generate_user_id
+from shared.abstractions import User
+
+from .base import PostgresConnectionManager, QueryBuilder
+from .collections import PostgresCollectionsHandler
+
+
+class PostgresUserHandler(Handler):
+    TABLE_NAME = "users"
+
+    def __init__(
+        self,
+        project_name: str,
+        connection_manager: PostgresConnectionManager,
+        crypto_provider: CryptoProvider,
+    ):
+        super().__init__(project_name, connection_manager)
+        self.crypto_provider = crypto_provider
+
+    async def create_tables(self):
+        query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.TABLE_NAME)} (
+            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+            email TEXT UNIQUE NOT NULL,
+            hashed_password TEXT NOT NULL,
+            is_superuser BOOLEAN DEFAULT FALSE,
+            is_active BOOLEAN DEFAULT TRUE,
+            is_verified BOOLEAN DEFAULT FALSE,
+            verification_code TEXT,
+            verification_code_expiry TIMESTAMPTZ,
+            name TEXT,
+            bio TEXT,
+            profile_picture TEXT,
+            reset_token TEXT,
+            reset_token_expiry TIMESTAMPTZ,
+            collection_ids UUID[] NULL,
+            created_at TIMESTAMPTZ DEFAULT NOW(),
+            updated_at TIMESTAMPTZ DEFAULT NOW()
+        );
+        """
+        await self.connection_manager.execute_query(query)
+
+    async def get_user_by_id(self, id: UUID) -> User:
+        query, _ = (
+            QueryBuilder(self._get_table_name("users"))
+            .select(
+                [
+                    "id",
+                    "email",
+                    "hashed_password",
+                    "is_superuser",
+                    "is_active",
+                    "is_verified",
+                    "created_at",
+                    "updated_at",
+                    "name",
+                    "profile_picture",
+                    "bio",
+                    "collection_ids",
+                ]
+            )
+            .where("id = $1")
+            .build()
+        )
+        result = await self.connection_manager.fetchrow_query(query, [id])
+
+        if not result:
+            raise R2RException(status_code=404, message="User not found")
+
+        return User(
+            id=result["id"],
+            email=result["email"],
+            hashed_password=result["hashed_password"],
+            is_superuser=result["is_superuser"],
+            is_active=result["is_active"],
+            is_verified=result["is_verified"],
+            created_at=result["created_at"],
+            updated_at=result["updated_at"],
+            name=result["name"],
+            profile_picture=result["profile_picture"],
+            bio=result["bio"],
+            collection_ids=result["collection_ids"],
+        )
+
+    async def get_user_by_email(self, email: str) -> User:
+        query, params = (
+            QueryBuilder(self._get_table_name("users"))
+            .select(
+                [
+                    "id",
+                    "email",
+                    "hashed_password",
+                    "is_superuser",
+                    "is_active",
+                    "is_verified",
+                    "created_at",
+                    "updated_at",
+                    "name",
+                    "profile_picture",
+                    "bio",
+                    "collection_ids",
+                ]
+            )
+            .where("email = $1")
+            .build()
+        )
+        result = await self.connection_manager.fetchrow_query(query, [email])
+        if not result:
+            raise R2RException(status_code=404, message="User not found")
+
+        return User(
+            id=result["id"],
+            email=result["email"],
+            hashed_password=result["hashed_password"],
+            is_superuser=result["is_superuser"],
+            is_active=result["is_active"],
+            is_verified=result["is_verified"],
+            created_at=result["created_at"],
+            updated_at=result["updated_at"],
+            name=result["name"],
+            profile_picture=result["profile_picture"],
+            bio=result["bio"],
+            collection_ids=result["collection_ids"],
+        )
+
+    async def create_user(
+        self, email: str, password: str, is_superuser: bool = False
+    ) -> User:
+        try:
+            if await self.get_user_by_email(email):
+                raise R2RException(
+                    status_code=400,
+                    message="User with this email already exists",
+                )
+        except R2RException as e:
+            if e.status_code != 404:
+                raise e
+
+        hashed_password = self.crypto_provider.get_password_hash(password)  # type: ignore
+        query = f"""
+            INSERT INTO {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            (email, id, is_superuser, hashed_password, collection_ids)
+            VALUES ($1, $2, $3, $4, $5)
+            RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query,
+            [
+                email,
+                generate_user_id(email),
+                is_superuser,
+                hashed_password,
+                [],
+            ],
+        )
+
+        if not result:
+            raise HTTPException(
+                status_code=500,
+                detail="Failed to create user",
+            )
+
+        return User(
+            id=result["id"],
+            email=result["email"],
+            is_superuser=result["is_superuser"],
+            is_active=result["is_active"],
+            is_verified=result["is_verified"],
+            created_at=result["created_at"],
+            updated_at=result["updated_at"],
+            collection_ids=result["collection_ids"],
+            hashed_password=hashed_password,
+        )
+
+    async def update_user(self, user: User) -> User:
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET email = $1, is_superuser = $2, is_active = $3, is_verified = $4, updated_at = NOW(),
+                name = $5, profile_picture = $6, bio = $7, collection_ids = $8
+            WHERE id = $9
+            RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, name, profile_picture, bio, collection_ids
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query,
+            [
+                user.email,
+                user.is_superuser,
+                user.is_active,
+                user.is_verified,
+                user.name,
+                user.profile_picture,
+                user.bio,
+                user.collection_ids,
+                user.id,
+            ],
+        )
+
+        if not result:
+            raise HTTPException(
+                status_code=500,
+                detail="Failed to update user",
+            )
+
+        return User(
+            id=result["id"],
+            email=result["email"],
+            is_superuser=result["is_superuser"],
+            is_active=result["is_active"],
+            is_verified=result["is_verified"],
+            created_at=result["created_at"],
+            updated_at=result["updated_at"],
+            name=result["name"],
+            profile_picture=result["profile_picture"],
+            bio=result["bio"],
+            collection_ids=result["collection_ids"],
+        )
+
+    async def delete_user_relational(self, id: UUID) -> None:
+        # Get the collections the user belongs to
+        collection_query = f"""
+            SELECT collection_ids FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            WHERE id = $1
+        """
+        collection_result = await self.connection_manager.fetchrow_query(
+            collection_query, [id]
+        )
+
+        if not collection_result:
+            raise R2RException(status_code=404, message="User not found")
+
+        # Detach the user's documents (clear ownership) rather than deleting them
+        doc_update_query = f"""
+            UPDATE {self._get_table_name('documents')}
+            SET owner_id = NULL
+            WHERE owner_id = $1
+        """
+        await self.connection_manager.execute_query(doc_update_query, [id])
+
+        # Delete the user
+        delete_query = f"""
+            DELETE FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            WHERE id = $1
+            RETURNING id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            delete_query, [id]
+        )
+
+        if not result:
+            raise R2RException(status_code=404, message="User not found")
+
+    async def update_user_password(self, id: UUID, new_hashed_password: str):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET hashed_password = $1, updated_at = NOW()
+            WHERE id = $2
+        """
+        await self.connection_manager.execute_query(
+            query, [new_hashed_password, id]
+        )
+
+    async def get_all_users(self) -> list[User]:
+        query = f"""
+            SELECT id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
+            FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+        """
+        results = await self.connection_manager.fetch_query(query)
+
+        return [
+            User(
+                id=result["id"],
+                email=result["email"],
+                hashed_password="null",
+                is_superuser=result["is_superuser"],
+                is_active=result["is_active"],
+                is_verified=result["is_verified"],
+                created_at=result["created_at"],
+                updated_at=result["updated_at"],
+                collection_ids=result["collection_ids"],
+            )
+            for result in results
+        ]
+
+    async def store_verification_code(
+        self, id: UUID, verification_code: str, expiry: datetime
+    ):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET verification_code = $1, verification_code_expiry = $2
+            WHERE id = $3
+        """
+        await self.connection_manager.execute_query(
+            query, [verification_code, expiry, id]
+        )
+
+    async def verify_user(self, verification_code: str) -> None:
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
+            WHERE verification_code = $1 AND verification_code_expiry > NOW()
+            RETURNING id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [verification_code]
+        )
+
+        if not result:
+            raise R2RException(
+                status_code=400, message="Invalid or expired verification code"
+            )
+
+    async def remove_verification_code(self, verification_code: str):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET verification_code = NULL, verification_code_expiry = NULL
+            WHERE verification_code = $1
+        """
+        await self.connection_manager.execute_query(query, [verification_code])
+
+    async def expire_verification_code(self, id: UUID):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET verification_code_expiry = NOW() - INTERVAL '1 day'
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(query, [id])
+
+    async def store_reset_token(
+        self, id: UUID, reset_token: str, expiry: datetime
+    ):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET reset_token = $1, reset_token_expiry = $2
+            WHERE id = $3
+        """
+        await self.connection_manager.execute_query(
+            query, [reset_token, expiry, id]
+        )
+
+    async def get_user_id_by_reset_token(
+        self, reset_token: str
+    ) -> Optional[UUID]:
+        query = f"""
+            SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            WHERE reset_token = $1 AND reset_token_expiry > NOW()
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [reset_token]
+        )
+        return result["id"] if result else None
+
+    async def remove_reset_token(self, id: UUID):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET reset_token = NULL, reset_token_expiry = NULL
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(query, [id])
+
+    async def remove_user_from_all_collections(self, id: UUID):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET collection_ids = ARRAY[]::UUID[]
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(query, [id])
+
+    async def add_user_to_collection(
+        self, id: UUID, collection_id: UUID
+    ) -> bool:
+        # Check if the user exists
+        if not await self.get_user_by_id(id):
+            raise R2RException(status_code=404, message="User not found")
+
+        # Check if the collection exists
+        if not await self._collection_exists(collection_id):
+            raise R2RException(status_code=404, message="Collection not found")
+
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET collection_ids = array_append(collection_ids, $1)
+            WHERE id = $2 AND NOT ($1 = ANY(collection_ids))
+            RETURNING id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [collection_id, id]
+        )
+        if not result:
+            raise R2RException(
+                status_code=400, message="User already in collection"
+            )
+
+        update_collection_query = f"""
+            UPDATE {self._get_table_name('collections')}
+            SET user_count = user_count + 1
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(
+            query=update_collection_query,
+            params=[collection_id],
+        )
+
+        return True
+
+    async def remove_user_from_collection(
+        self, id: UUID, collection_id: UUID
+    ) -> bool:
+        if not await self.get_user_by_id(id):
+            raise R2RException(status_code=404, message="User not found")
+
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET collection_ids = array_remove(collection_ids, $1)
+            WHERE id = $2 AND $1 = ANY(collection_ids)
+            RETURNING id
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [collection_id, id]
+        )
+        if not result:
+            raise R2RException(
+                status_code=400,
+                message="User is not a member of the specified collection",
+            )
+        return True
+
+    async def get_users_in_collection(
+        self, collection_id: UUID, offset: int, limit: int
+    ) -> dict[str, list[User] | int]:
+        """
+        Get all users in a specific collection with pagination.
+
+        Args:
+            collection_id (UUID): The ID of the collection to get users from.
+            offset (int): The number of users to skip.
+            limit (int): The maximum number of users to return.
+
+        Returns:
+            dict: {"results": list[User], "total_entries": int} for the users in the collection.
+
+        Raises:
+            R2RException: If the collection doesn't exist.
+        """
+        if not await self._collection_exists(collection_id):  # type: ignore
+            raise R2RException(status_code=404, message="Collection not found")
+
+        query = f"""
+            SELECT u.id, u.email, u.is_active, u.is_superuser, u.created_at, u.updated_at,
+                u.is_verified, u.collection_ids, u.name, u.bio, u.profile_picture,
+                COUNT(*) OVER() AS total_entries
+            FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
+            WHERE $1 = ANY(u.collection_ids)
+            ORDER BY u.name
+            OFFSET $2
+        """
+
+        conditions = [collection_id, offset]
+        if limit != -1:
+            query += " LIMIT $3"
+            conditions.append(limit)
+
+        results = await self.connection_manager.fetch_query(query, conditions)
+
+        users = [
+            User(
+                id=row["id"],
+                email=row["email"],
+                is_active=row["is_active"],
+                is_superuser=row["is_superuser"],
+                created_at=row["created_at"],
+                updated_at=row["updated_at"],
+                is_verified=row["is_verified"],
+                collection_ids=row["collection_ids"],
+                name=row["name"],
+                bio=row["bio"],
+                profile_picture=row["profile_picture"],
+                hashed_password=None,
+                verification_code_expiry=None,
+            )
+            for row in results
+        ]
+
+        total_entries = results[0]["total_entries"] if results else 0
+
+        return {"results": users, "total_entries": total_entries}
+
+    async def mark_user_as_superuser(self, id: UUID):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET is_superuser = TRUE, is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(query, [id])
+
+    async def get_user_id_by_verification_code(
+        self, verification_code: str
+    ) -> Optional[UUID]:
+        query = f"""
+            SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            WHERE verification_code = $1 AND verification_code_expiry > NOW()
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [verification_code]
+        )
+
+        if not result:
+            raise R2RException(
+                status_code=400, message="Invalid or expired verification code"
+            )
+
+        return result["id"]
+
+    async def mark_user_as_verified(self, id: UUID):
+        query = f"""
+            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
+            SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(query, [id])
+
+    async def get_users_overview(
+        self,
+        offset: int,
+        limit: int,
+        user_ids: Optional[list[UUID]] = None,
+    ) -> dict[str, list[User] | int]:
+
+        query = f"""
+            WITH user_document_ids AS (
+                SELECT
+                    u.id as user_id,
+                    ARRAY_AGG(d.id) FILTER (WHERE d.id IS NOT NULL) AS doc_ids
+                FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
+                LEFT JOIN {self._get_table_name('documents')} d ON u.id = d.owner_id
+                GROUP BY u.id
+            ),
+            user_docs AS (
+                SELECT
+                    u.id,
+                    u.email,
+                    u.is_superuser,
+                    u.is_active,
+                    u.is_verified,
+                    u.created_at,
+                    u.updated_at,
+                    u.collection_ids,
+                    COUNT(d.id) AS num_files,
+                    COALESCE(SUM(d.size_in_bytes), 0) AS total_size_in_bytes,
+                    ud.doc_ids as document_ids
+                FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
+                LEFT JOIN {self._get_table_name('documents')} d ON u.id = d.owner_id
+                LEFT JOIN user_document_ids ud ON u.id = ud.user_id
+                {' WHERE u.id = ANY($2::uuid[])' if user_ids else ''}
+                GROUP BY u.id, u.email, u.is_superuser, u.is_active, u.is_verified,
+                         u.created_at, u.updated_at, u.collection_ids, ud.doc_ids
+            )
+            SELECT
+                user_docs.*,
+                COUNT(*) OVER() AS total_entries
+            FROM user_docs
+            ORDER BY email
+            OFFSET $1
+        """
+
+        params: list = [offset]
+
+        # Bind the optional user filter as $2 (matching the WHERE clause above),
+        # then compute the LIMIT placeholder from the current parameter count so
+        # the numbering stays correct whether or not a limit is applied.
+        if user_ids:
+            params.append(user_ids)
+
+        if limit != -1:
+            query += f" LIMIT ${len(params) + 1}"
+            params.append(limit)
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+        users = [
+            User(
+                id=row["id"],
+                email=row["email"],
+                is_superuser=row["is_superuser"],
+                is_active=row["is_active"],
+                is_verified=row["is_verified"],
+                created_at=row["created_at"],
+                updated_at=row["updated_at"],
+                collection_ids=row["collection_ids"] or [],
+                num_files=row["num_files"],
+                total_size_in_bytes=row["total_size_in_bytes"],
+                document_ids=(
+                    []
+                    if row["document_ids"] is None
+                    else [doc_id for doc_id in row["document_ids"]]
+                ),
+            )
+            for row in results
+        ]
+
+        if not users:
+            raise R2RException(status_code=404, message="No users found")
+
+        total_entries = results[0]["total_entries"]
+
+        return {"results": users, "total_entries": total_entries}
+
+    async def _collection_exists(self, collection_id: UUID) -> bool:
+        """Check if a collection exists."""
+        query = f"""
+            SELECT 1 FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
+            WHERE id = $1
+        """
+        result = await self.connection_manager.fetchrow_query(
+            query, [collection_id]
+        )
+        return result is not None
+
+    async def get_user_validation_data(
+        self,
+        user_id: UUID,
+    ) -> dict:
+        """
+        Get verification data for a specific user.
+        This method should be called after superuser authorization has been verified.
+        """
+        query = f"""
+            SELECT
+                verification_code,
+                verification_code_expiry,
+                reset_token,
+                reset_token_expiry
+            FROM {self._get_table_name("users")}
+            WHERE id = $1
+        """
+        result = await self.connection_manager.fetchrow_query(query, [user_id])
+
+        if not result:
+            raise R2RException(status_code=404, message="User not found")
+
+        return {
+            "verification_data": {
+                "verification_code": result["verification_code"],
+                "verification_code_expiry": (
+                    result["verification_code_expiry"].isoformat()
+                    if result["verification_code_expiry"]
+                    else None
+                ),
+                "reset_token": result["reset_token"],
+                "reset_token_expiry": (
+                    result["reset_token_expiry"].isoformat()
+                    if result["reset_token_expiry"]
+                    else None
+                ),
+            }
+        }
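
For orientation, here is a hedged sketch of a typical user lifecycle against this handler. The import path matches the file in this diff; the connection manager, crypto provider, project name, and credentials are placeholders.

from core.database.users import PostgresUserHandler


async def demo_users(connection_manager, crypto_provider) -> None:
    users = PostgresUserHandler(
        "r2r_default", connection_manager, crypto_provider
    )
    await users.create_tables()

    # Create and verify a user (email and password are illustrative).
    user = await users.create_user("alice@example.com", "s3cret-password")
    await users.mark_user_as_verified(user.id)

    # Paginated overview; returns {"results": [...], "total_entries": n}.
    overview = await users.get_users_overview(offset=0, limit=10)
    print(overview["total_entries"], [u.email for u in overview["results"]])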

+ 5 - 0
core/database/vecs/__init__.py

@@ -0,0 +1,5 @@
+from . import exc
+
+__all__ = [
+    "exc",
+]

+ 16 - 0
core/database/vecs/adapter/__init__.py

@@ -0,0 +1,16 @@
+from .base import Adapter, AdapterContext, AdapterStep
+from .markdown import MarkdownChunker
+from .noop import NoOp, Record
+from .text import ParagraphChunker, TextEmbedding, TextEmbeddingModel
+
+__all__ = [
+    "Adapter",
+    "AdapterContext",
+    "AdapterStep",
+    "NoOp",
+    "Record",
+    "ParagraphChunker",
+    "TextEmbedding",
+    "TextEmbeddingModel",
+    "MarkdownChunker",
+]

+ 126 - 0
core/database/vecs/adapter/base.py

@@ -0,0 +1,126 @@
+"""
+The `vecs.experimental.adapter.base` module provides abstract classes and utilities
+for creating and handling adapters in vecs. Adapters allow users to interact with
+a collection using media types other than vectors.
+
+All public classes, enums, and functions are re-exported by `vecs.adapters` module.
+"""
+
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Any, Generator, Iterable, Optional, Tuple, Union
+from uuid import UUID
+
+from vecs.exc import ArgError
+
+MetadataValues = Union[str, int, float, bool, list[str]]
+Metadata = dict[str, MetadataValues]
+Numeric = Union[int, float, complex]
+
+Record = Tuple[
+    UUID,
+    UUID,
+    UUID,
+    list[UUID],
+    Iterable[Numeric],
+    str,
+    Metadata,
+]
+
+
+class AdapterContext(str, Enum):
+    """
+    An enum representing the different contexts in which a Pipeline
+    will be invoked.
+
+    Attributes:
+        upsert (str): The Collection.upsert method
+        query (str): The Collection.query method
+    """
+
+    upsert = "upsert"
+    query = "query"
+
+
+class AdapterStep(ABC):
+    """
+    Abstract class representing a step in the adapter pipeline.
+
+    Each adapter step should adapt a user media into a tuple of:
+     - id (str)
+     - media (unknown type)
+     - metadata (dict)
+
+    If the user provides an id or metadata, those values override the defaults the step would otherwise produce.
+    """
+
+    @property
+    def exported_dimension(self) -> Optional[int]:
+        """
+        Property that should be overridden by subclasses to provide the output dimension
+        of the adapter step.
+        """
+        return None
+
+    @abstractmethod
+    def __call__(
+        self,
+        records: Iterable[Tuple[str, Any, Optional[dict]]],
+        adapter_context: AdapterContext,
+    ) -> Generator[Tuple[str, Any, dict], None, None]:
+        """
+        Abstract method that should be overridden by subclasses to handle each record.
+        """
+
+
+class Adapter:
+    """
+    Class representing a sequence of AdapterStep instances forming a pipeline.
+    """
+
+    def __init__(self, steps: list[AdapterStep]):
+        """
+        Initialize an Adapter instance with a list of AdapterStep instances.
+
+        Args:
+            steps: list of AdapterStep instances.
+
+        Raises:
+            ArgError: Raised if the steps list is empty.
+        """
+        self.steps = steps
+        if len(steps) < 1:
+            raise ArgError("Adapter must contain at least 1 step")
+
+    @property
+    def exported_dimension(self) -> Optional[int]:
+        """
+        The output dimension of the adapter. Returns the exported dimension of the last
+        AdapterStep that provides one (from end to start of the steps list).
+        """
+        for step in reversed(self.steps):
+            step_dim = step.exported_dimension
+            if step_dim is not None:
+                return step_dim
+        return None
+
+    def __call__(
+        self,
+        records: Iterable[Tuple[str, Any, Optional[dict]]],
+        adapter_context: AdapterContext,
+    ) -> Generator[Tuple[str, Any, dict], None, None]:
+        """
+        Invokes the adapter pipeline on an iterable of records.
+
+        Args:
+            records: Iterable of tuples each containing an id, a media and an optional dict.
+            adapter_context: Context of the adapter.
+
+        Yields:
+            Tuples each containing an id, a media and a dict.
+        """
+        pipeline = records
+        for step in self.steps:
+            pipeline = step(pipeline, adapter_context)
+
+        yield from pipeline  # type: ignore
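
To make the abstract interface concrete, here is a minimal custom step built only from the AdapterStep/AdapterContext definitions above; the class name and behavior are illustrative, not part of this commit.

from typing import Any, Generator, Iterable, Optional, Tuple


class LowercaseText(AdapterStep):
    """Example step: lower-cases string media and passes other media through."""

    def __call__(
        self,
        records: Iterable[Tuple[str, Any, Optional[dict]]],
        adapter_context: AdapterContext,
    ) -> Generator[Tuple[str, Any, dict], None, None]:
        for id, media, metadata in records:
            if isinstance(media, str):
                media = media.lower()
            yield (id, media, metadata or {})


# Compose into a pipeline: Adapter([LowercaseText(), ...further steps])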

+ 93 - 0
core/database/vecs/adapter/markdown.py

@@ -0,0 +1,93 @@
+import re
+from typing import Any, Generator, Iterable, Optional, Tuple
+
+from flupy import flu
+
+from .base import AdapterContext, AdapterStep
+
+
+class MarkdownChunker(AdapterStep):
+    """
+    MarkdownChunker is an AdapterStep that splits a markdown string into chunks where a heading signifies the start of a chunk, and yields each chunk as a separate record.
+    """
+
+    def __init__(self, *, skip_during_query: bool):
+        """
+        Initializes the MarkdownChunker adapter.
+
+        Args:
+            skip_during_query (bool): Whether to skip chunking during querying.
+        """
+        self.skip_during_query = skip_during_query
+
+    @staticmethod
+    def split_by_heading(
+        md: str, max_tokens: int
+    ) -> Generator[str, None, None]:
+        regex_split = r"^(#{1,6}\s+.+)$"
+        headings = [
+            match.span()[0]
+            for match in re.finditer(regex_split, md, flags=re.MULTILINE)
+        ]
+
+        if headings == [] or headings[0] != 0:
+            headings.insert(0, 0)
+
+        sections = [md[i:j] for i, j in zip(headings, headings[1:] + [None])]
+
+        for section in sections:
+            chunks = flu(section.split(" ")).chunk(max_tokens)
+
+            # Join each chunk back into a string and drop empty or
+            # whitespace-only results.
+            joined_chunks = (
+                joined
+                for joined in (" ".join(chunk) for chunk in chunks)
+                if joined not in ("", "\n")
+            )
+
+            yield from joined_chunks
+
+    def __call__(
+        self,
+        records: Iterable[Tuple[str, Any, Optional[dict]]],
+        adapter_context: AdapterContext,
+        max_tokens: int = 99999999,
+    ) -> Generator[Tuple[str, Any, dict], None, None]:
+        """
+        Splits each markdown string in the records into chunks where each heading starts a new chunk, and yields each chunk
+        as a separate record. If the `skip_during_query` attribute is set to True,
+        this step is skipped during querying.
+
+        Args:
+            records (Iterable[Tuple[str, Any, Optional[dict]]]): Iterable of tuples each containing an id, a markdown string and an optional dict.
+            adapter_context (AdapterContext): Context of the adapter.
+            max_tokens (int): The maximum number of tokens per chunk
+
+        Yields:
+            Tuple[str, Any, dict]: The id appended with chunk index, the chunk, and the metadata.
+        """
+        if max_tokens and max_tokens < 1:
+            raise ValueError("max_tokens must be a positive integer")
+
+        if (
+            adapter_context == AdapterContext("query")
+            and self.skip_during_query
+        ):
+            for id, markdown, metadata in records:
+                yield (id, markdown, metadata or {})
+        else:
+            for id, markdown, metadata in records:
+                headings = MarkdownChunker.split_by_heading(
+                    markdown, max_tokens
+                )
+                for heading_ix, heading in enumerate(headings):
+                    yield (
+                        f"{id}_head_{str(heading_ix).zfill(3)}",
+                        heading,
+                        metadata or {},
+                    )

Too many files were changed in this changeset, so some files are not shown.