jack, 3 months ago
commit fed6a49f29
98 files changed, 6175 insertions(+), 1211 deletions(-)
+148 -12   cli/command_group.py
 +72 -48   cli/commands/collections.py
+165 -0    cli/commands/config.py
 +56 -52   cli/commands/conversations.py
  +4 -6    cli/commands/database.py
+303 -127  cli/commands/documents.py
+220 -144  cli/commands/graphs.py
 +39 -41   cli/commands/indices.py
 +36 -23   cli/commands/prompts.py
 +52 -39   cli/commands/retrieval.py
 +40 -17   cli/commands/system.py
 +98 -53   cli/commands/users.py
 +76 -4    cli/main.py
  +0 -6    cli/utils/docker_utils.py
  +2 -1    cli/utils/timer.py
+124 -0    compose.yaml
  +0 -1    core/base/parsers/base_parser.py
  +0 -1    core/base/providers/ingestion.py
 +53 -10   core/database/base.py
  +0 -2    core/database/chunks.py
  +0 -1    core/database/collections.py
  +0 -1    core/database/documents.py
  +0 -13   core/database/graphs.py
 +89 -54   core/database/limits.py
+234 -116  core/database/users.py
  +1 -1    core/examples/hello_r2r.py
  +6 -6    core/main/abstractions.py
 +57 -104  core/main/api/v3/base_router.py
  +9 -9    core/main/api/v3/chunks_router.py
 +23 -23   core/main/api/v3/collections_router.py
 +14 -14   core/main/api/v3/conversations_router.py
 +38 -36   core/main/api/v3/documents_router.py
 +34 -36   core/main/api/v3/graph_router.py
  +9 -11   core/main/api/v3/indices_router.py
 +10 -10   core/main/api/v3/prompts_router.py
 +10 -11   core/main/api/v3/retrieval_router.py
  +6 -6    core/main/api/v3/system_router.py
 +49 -40   core/main/api/v3/users_router.py
  +0 -1    core/main/app.py
  +5 -12   core/main/assembly/builder.py
  +0 -2    core/main/assembly/factory.py
  +0 -1    core/main/orchestration/hatchet/ingestion_workflow.py
  +0 -8    core/main/orchestration/hatchet/kg_workflow.py
  +0 -3    core/main/orchestration/simple/ingestion_workflow.py
  +0 -7    core/main/orchestration/simple/kg_workflow.py
  +3 -0    core/main/services/auth_service.py
  +0 -10   core/main/services/graph_service.py
  +0 -1    core/main/services/ingestion_service.py
  +0 -1    core/main/services/management_service.py
  +7 -3    core/parsers/media/bmp_parser.py
 +25 -25   core/pipes/kg/community_summary.py
  +0 -1    core/pipes/kg/deduplication.py
  +0 -6    core/pipes/kg/deduplication_summary.py
  +0 -1    core/pipes/kg/extraction.py
  +0 -1    core/pipes/kg/storage.py
  +0 -1    core/pipes/retrieval/chunk_search_pipe.py
  +0 -1    core/pipes/retrieval/graph_search_pipe.py
  +0 -1    core/providers/auth/r2r_auth.py
  +1 -1    core/providers/crypto/bcrypt.py
 +26 -8    core/providers/crypto/nacl.py
  +0 -1    core/providers/embeddings/litellm.py
  +0 -1    core/providers/ingestion/r2r/base.py
  +0 -1    core/providers/ingestion/unstructured/base.py
  +0 -1    core/telemetry/telemetry_decorator.py
  +0 -2    migrations/versions/8077140e1e99_v3_api_database_revision.py
  +3 -0    pyproject.toml
  +1 -1    sdk/async_client.py
  +3 -3    sdk/base/base_client.py
  +0 -1    sdk/sync_client.py
  +1 -2    sdk/v2/management.py
  +3 -3    sdk/v3/graphs.py
  +6 -14   sdk/v3/users.py
  +0 -1    shared/abstractions/graph.py
  +2 -3    shared/abstractions/search.py
  +3 -0    shared/abstractions/user.py
 +56 -0    tests/cli/async_invoke.py
+162 -0    tests/cli/commands/test_collections_cli.py
  +7 -0    tests/cli/commands/test_config_cli.py
+168 -0    tests/cli/commands/test_conversations_cli.py
  +0 -0    tests/cli/commands/test_database_cli.py
+330 -0    tests/cli/commands/test_documents_cli.py
+312 -0    tests/cli/commands/test_graphs_cli.py
  +6 -0    tests/cli/commands/test_indices_cli.py
 +95 -0    tests/cli/commands/test_prompts_cli.py
+213 -0    tests/cli/commands/test_retrieval_cli.py
+338 -0    tests/cli/commands/test_system_cli.py
+143 -0    tests/cli/commands/test_users_cli.py
+118 -0    tests/cli/utils/test_timer.py
  +0 -2    tests/integration/test_conversations.py
+187 -0    tests/integration/test_filters.py
 +37 -1    tests/integration/test_users.py
+344 -0    tests/unit/conftest.py
+315 -0    tests/unit/test_chunks.py
+205 -0    tests/unit/test_collections.py
+132 -0    tests/unit/test_conversations.py
+143 -0    tests/unit/test_documents.py
+449 -0    tests/unit/test_graphs.py
+249 -0    tests/unit/test_limits.py

+ 148 - 12
cli/command_group.py

@@ -1,11 +1,47 @@
+# from .main import load_config
+import json
+import types
 from functools import wraps
+from pathlib import Path
+from typing import Any, Never
 
 import asyncclick as click
 from asyncclick import pass_context
 from asyncclick.exceptions import Exit
+from rich import box
+from rich.console import Console
+from rich.table import Table
 
 from sdk import R2RAsyncClient
 
+console = Console()
+
+CONFIG_DIR = Path.home() / ".r2r"
+CONFIG_FILE = CONFIG_DIR / "config.json"
+
+
+def load_config() -> dict[str, Any]:
+    """
+    Load the CLI config from ~/.r2r/config.json.
+    Returns an empty dict if the file doesn't exist or is invalid.
+    """
+    if not CONFIG_FILE.is_file():
+        return {}
+    try:
+        with open(CONFIG_FILE, "r", encoding="utf-8") as f:
+            data = json.load(f)
+            # Ensure we always have a dict
+            if not isinstance(data, dict):
+                return {}
+            return data
+    except (IOError, json.JSONDecodeError):
+        return {}
+
+
+def silent_exit(ctx, code=0):
+    if code != 0:
+        raise Exit(code)
+
 
 def deprecated_command(new_name):
     def decorator(f):
@@ -23,19 +59,119 @@ def deprecated_command(new_name):
     return decorator
 
 
-@click.group()
+def custom_help_formatter(commands):
+    """Create a nicely formatted help table using rich"""
+    table = Table(
+        box=box.ROUNDED,
+        border_style="blue",
+        pad_edge=False,
+        collapse_padding=True,
+    )
+
+    table.add_column("Command", style="cyan", no_wrap=True)
+    table.add_column("Description", style="white")
+
+    command_groups = {
+        "Document Management": [
+            ("documents", "Document ingestion and management commands"),
+            ("collections", "Collections management commands"),
+        ],
+        "Knowledge Graph": [
+            ("graphs", "Graph creation and management commands"),
+            ("prompts", "Prompt template management"),
+        ],
+        "Interaction": [
+            ("conversations", "Conversation management commands"),
+            ("retrieval", "Knowledge retrieval commands"),
+        ],
+        "System": [
+            ("configure", "Configuration management commands"),
+            ("users", "User management commands"),
+            ("indices", "Index management commands"),
+            ("system", "System administration commands"),
+        ],
+        "Database": [
+            ("db", "Database management commands"),
+            ("upgrade", "Upgrade database schema"),
+            ("downgrade", "Downgrade database schema"),
+            ("current", "Show current schema version"),
+            ("history", "View schema migration history"),
+        ],
+    }
+
+    for group_name, group_commands in command_groups.items():
+        table.add_row(
+            f"[bold yellow]{group_name}[/bold yellow]", "", style="dim"
+        )
+        for cmd_name, description in group_commands:
+            if cmd_name in commands:
+                table.add_row(f"  {cmd_name}", commands[cmd_name].help or "")
+        table.add_row("", "")  # Add spacing between groups
+
+    return table
+
+
+class CustomGroup(click.Group):
+    def format_help(self, ctx, formatter):
+        console.print("\n[bold blue]R2R Command Line Interface[/bold blue]")
+        console.print("The most advanced AI retrieval system\n")
+
+        if self.get_help_option(ctx) is not None:
+            console.print("[bold cyan]Usage:[/bold cyan]")
+            console.print("  r2r [OPTIONS] COMMAND [ARGS]...\n")
+
+        console.print("[bold cyan]Options:[/bold cyan]")
+        console.print(
+            "  --base-url TEXT  Base URL for the API [default: https://api.cloud.sciphi.ai]"
+        )
+        console.print("  --help           Show this message and exit.\n")
+
+        console.print("[bold cyan]Commands:[/bold cyan]")
+        console.print(custom_help_formatter(self.commands))
+        console.print(
+            "\nFor more details on a specific command, run: [bold]r2r COMMAND --help[/bold]\n"
+        )
+
+
+class CustomContext(click.Context):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.exit_func = types.MethodType(silent_exit, self)
+
+    def exit(self, code: int = 0) -> Never:
+        self.exit_func(code)
+        raise SystemExit(code)
+
+
+def initialize_client(base_url: str) -> R2RAsyncClient:
+    """Initialize R2R client with API key from config if available."""
+    client = R2RAsyncClient()
+
+    try:
+        config = load_config()
+        if api_key := config.get("api_key"):
+            client.set_api_key(api_key)
+            if not client.api_key:
+                console.print(
+                    "[yellow]Warning: API key not properly set in client[/yellow]"
+                )
+
+    except Exception as e:
+        console.print(
+            "[yellow]Warning: Failed to load API key from config[/yellow]"
+        )
+        console.print_exception()
+
+    return client
+
+
+@click.group(cls=CustomGroup)
 @click.option(
-    "--base-url", default="http://localhost:7272", help="Base URL for the API"
+    "--base-url",
+    default="https://cloud.sciphi.ai",
+    help="Base URL for the API",
 )
 @pass_context
-async def cli(ctx, base_url):
+async def cli(ctx: click.Context, base_url: str) -> None:
     """R2R CLI for all core operations."""
-
-    ctx.obj = R2RAsyncClient(base_url=base_url)
-
-    # Override the default exit behavior
-    def silent_exit(self, code=0):
-        if code != 0:
-            raise Exit(code)
-
-    ctx.exit = silent_exit.__get__(ctx)
+    ctx.obj = initialize_client(base_url)
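
A quick sanity check of the new config loader (a sketch, not part of the commit; it assumes the module is importable as cli.command_group, and it overwrites ~/.r2r/config.json, so run it only in a throwaway environment):

    import json

    from cli.command_group import CONFIG_DIR, CONFIG_FILE, load_config

    # A well-formed config round-trips.
    CONFIG_DIR.mkdir(exist_ok=True)
    CONFIG_FILE.write_text(json.dumps({"api_key": "sk-test"}), encoding="utf-8")
    assert load_config() == {"api_key": "sk-test"}

    # Corrupt JSON falls back to an empty dict instead of raising.
    CONFIG_FILE.write_text("{not json", encoding="utf-8")
    assert load_config() == {}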

+ 72 - 48
cli/commands/collections.py

@@ -4,7 +4,7 @@ import asyncclick as click
 from asyncclick import pass_context
 
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 @click.group()
@@ -17,17 +17,21 @@ def collections():
 @click.argument("name", required=True, type=str)
 @click.option("--description", type=str)
 @pass_context
-async def create(ctx, name, description):
+async def create(ctx: click.Context, name, description):
     """Create a collection."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.collections.create(
-            name=name,
-            description=description,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.collections.create(
+                name=name,
+                description=description,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @collections.command()
@@ -43,46 +47,57 @@ async def create(ctx, name, description):
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 @pass_context
-async def list(ctx, ids, offset, limit):
+async def list(ctx: click.Context, ids, offset, limit):
     """Get an overview of collections."""
     client: R2RAsyncClient = ctx.obj
     ids = list(ids) if ids else None
 
-    with timer():
-        response = await client.collections.list(
-            ids=ids,
-            offset=offset,
-            limit=limit,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    try:
+        with timer():
+            response = await client.collections.list(
+                ids=ids,
+                offset=offset,
+                limit=limit,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @collections.command()
 @click.argument("id", required=True, type=str)
 @pass_context
-async def retrieve(ctx, id):
+async def retrieve(ctx: click.Context, id):
     """Retrieve a collection by ID."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.collections.retrieve(id=id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.collections.retrieve(id=id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @collections.command()
 @click.argument("id", required=True, type=str)
 @pass_context
-async def delete(ctx, id):
+async def delete(ctx: click.Context, id):
     """Delete a collection by ID."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.collections.delete(id=id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.collections.delete(id=id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @collections.command()
@@ -98,19 +113,24 @@ async def delete(ctx, id):
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 @pass_context
-async def list_documents(ctx, id, offset, limit):
+async def list_documents(ctx: click.Context, id, offset, limit):
     """Get an overview of collections."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.collections.list_documents(
-            id=id,
-            offset=offset,
-            limit=limit,
-        )
+    try:
+        with timer():
+            response = await client.collections.list_documents(
+                id=id,
+                offset=offset,
+                limit=limit,
+            )
 
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+        for user in response["results"]:  # type: ignore
+            click.echo(json.dumps(user, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @collections.command()
@@ -126,16 +146,20 @@ async def list_documents(ctx, id, offset, limit):
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 @pass_context
-async def list_users(ctx, id, offset, limit):
+async def list_users(ctx: click.Context, id, offset, limit):
     """Get an overview of collections."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.collections.list_users(
-            id=id,
-            offset=offset,
-            limit=limit,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    try:
+        with timer():
+            response = await client.collections.list_users(
+                id=id,
+                offset=offset,
+                limit=limit,
+            )
+        for user in response["results"]:  # type: ignore
+            click.echo(json.dumps(user, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
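
Every command in this file (and in the command modules below) now follows the same envelope: time the SDK call, pretty-print the JSON response, and route R2RException (expected API errors) and any other Exception to stderr. A minimal self-contained sketch of that pattern; the timer and exception here are stand-ins, not the real cli.utils.timer or r2r.R2RException:

    import contextlib
    import json
    import time


    class R2RException(Exception):
        """Stand-in for r2r.R2RException (expected API errors)."""


    @contextlib.contextmanager
    def timer():
        # Stand-in for cli.utils.timer.timer; the real one also reports elapsed time.
        start = time.perf_counter()
        yield
        print(f"Time taken: {time.perf_counter() - start:.2f} seconds")


    def run(call):
        try:
            with timer():
                response = call()
            print(json.dumps(response, indent=2))
        except R2RException as e:
            print(str(e))  # the CLI sends this to stderr via err=True
        except Exception as e:
            print(f"An unexpected error occurred: {e}")


    def failing():
        raise R2RException("collection not found")


    run(lambda: {"results": []})  # success: timing line, then pretty-printed JSON
    run(failing)                  # expected failure: message only, no traceback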

+ 165 - 0
cli/commands/config.py

@@ -0,0 +1,165 @@
+import configparser
+from pathlib import Path
+
+import asyncclick as click
+from rich.box import ROUNDED
+from rich.console import Console
+from rich.table import Table
+
+console = Console()
+
+
+def get_config_dir():
+    """Create and return the config directory path."""
+    config_dir = Path.home() / ".r2r"
+    config_dir.mkdir(exist_ok=True)
+    return config_dir
+
+
+def get_config_file():
+    """Get the config file path."""
+    return get_config_dir() / "config.ini"
+
+
+class Config:
+    _instance = None
+    _config = configparser.ConfigParser()
+    _config_file = get_config_file()
+
+    @classmethod
+    def load(cls):
+        """Load the configuration file."""
+        if cls._config_file.exists():
+            cls._config.read(cls._config_file)
+
+    @classmethod
+    def save(cls):
+        """Save the configuration to file."""
+        with open(cls._config_file, "w") as f:
+            cls._config.write(f)
+
+    @classmethod
+    def get_credentials(cls, service):
+        """Get credentials for a specific service."""
+        cls.load()  # Ensure we have latest config
+        return dict(cls._config[service]) if service in cls._config else {}
+
+    @classmethod
+    def set_credentials(cls, service, credentials):
+        """Set credentials for a specific service."""
+        cls.load()  # Ensure we have latest config
+        if service not in cls._config:
+            cls._config[service] = {}
+        cls._config[service].update(credentials)
+        cls.save()
+
+
+@click.group()
+def configure():
+    """Configuration management commands."""
+    pass
+
+
+@configure.command()
+@click.confirmation_option(
+    prompt="Are you sure you want to reset all settings?"
+)
+async def reset():
+    """Reset all configuration to defaults."""
+    if Config._config_file.exists():
+        Config._config_file.unlink()  # Delete the config file
+    Config._config = configparser.ConfigParser()  # Reset the config in memory
+
+    # Set default values
+    Config.set_credentials(
+        "Base URL", {"base_url": "https://api.cloud.sciphi.ai"}
+    )
+
+    console.print(
+        "[green]Successfully reset configuration to defaults[/green]"
+    )
+
+
+@configure.command()
+@click.option(
+    "--api-key",
+    prompt="SciPhi API Key",
+    hide_input=True,
+    help="API key for SciPhi cloud",
+)
+async def key(api_key):
+    """Configure SciPhi cloud API credentials."""
+    Config.set_credentials("SciPhi", {"api_key": api_key})
+    console.print(
+        "[green]Successfully configured SciPhi cloud credentials[/green]"
+    )
+
+
+@configure.command()
+@click.option(
+    "--base-url",
+    prompt="R2R Base URL",
+    default="https://api.cloud.sciphi.ai",
+    hide_input=False,
+    help="Host URL for R2R",
+)
+async def host(base_url):
+    """Configure R2R host URL."""
+    Config.set_credentials("Host", {"R2R_HOST": base_url})
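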
+    console.print("[green]Successfully configured R2R host URL[/green]")
+
+
+@configure.command()
+async def view():
+    """View current configuration."""
+    Config.load()
+
+    table = Table(
+        title="[bold blue]R2R Settings[/bold blue]",
+        show_header=True,
+        header_style="bold white on blue",
+        border_style="blue",
+        box=ROUNDED,
+        pad_edge=False,
+        collapse_padding=True,
+    )
+
+    table.add_column(
+        "Section", justify="left", style="bright_yellow", no_wrap=True
+    )
+    table.add_column(
+        "Key", justify="left", style="bright_magenta", no_wrap=True
+    )
+    table.add_column(
+        "Value", justify="left", style="bright_green", no_wrap=True
+    )
+
+    # Group related configurations together
+    config_groups = {
+        "API Credentials": ["SciPhi"],
+        "Server Settings": ["Base URL", "Port"],
+    }
+
+    for group_name, sections in config_groups.items():
+        has_items = any(section in Config._config for section in sections)
+        if has_items:
+            table.add_row(
+                f"[bold]{group_name}[/bold]", "", "", style="bright_blue"
+            )
+
+            for section in sections:
+                if section in Config._config:
+                    for key, value in Config._config[section].items():
+                        # Mask API keys for security
+                        displayed_value = (
+                            f"****{value[-4:]}"
+                            if "api_key" in key.lower()
+                            else value
+                        )
+                        table.add_row(
+                            f"  {section}", key.lower(), displayed_value
+                        )
+
+    console.print("\n")
+    console.print(table)
+    console.print("\n")
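
Assuming the module is importable as cli.commands.config, the new Config singleton can be exercised directly (a sketch; it writes to ~/.r2r/config.ini):

    from cli.commands.config import Config, get_config_file

    # Persist credentials for a section, then read them back.
    Config.set_credentials("SciPhi", {"api_key": "sk-test-1234"})
    assert Config.get_credentials("SciPhi")["api_key"] == "sk-test-1234"

    # Unknown sections come back as an empty dict rather than raising.
    assert Config.get_credentials("DoesNotExist") == {}

    print(f"settings stored in {get_config_file()}")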

+ 56 - 52
cli/commands/conversations.py

@@ -4,7 +4,7 @@ import asyncclick as click
 from asyncclick import pass_context
 
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 @click.group()
@@ -15,14 +15,18 @@ def conversations():
 
 @conversations.command()
 @pass_context
-async def create(ctx):
+async def create(ctx: click.Context):
     """Create a conversation."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.conversations.create()
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.conversations.create()
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @conversations.command()
@@ -38,62 +42,58 @@ async def create(ctx):
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 @pass_context
-async def list(ctx, ids, offset, limit):
+async def list(ctx: click.Context, ids, offset, limit):
     """Get an overview of conversations."""
     client: R2RAsyncClient = ctx.obj
     ids = list(ids) if ids else None
 
-    with timer():
-        response = await client.conversations.list(
-            ids=ids,
-            offset=offset,
-            limit=limit,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    try:
+        with timer():
+            response = await client.conversations.list(
+                ids=ids,
+                offset=offset,
+                limit=limit,
+            )
+        for user in response["results"]:  # type: ignore
+            click.echo(json.dumps(user, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @conversations.command()
 @click.argument("id", required=True, type=str)
 @pass_context
-async def retrieve(ctx, id):
+async def retrieve(ctx: click.Context, id):
     """Retrieve a collection by ID."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.conversations.retrieve(id=id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.conversations.retrieve(id=id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @conversations.command()
 @click.argument("id", required=True, type=str)
 @pass_context
-async def delete(ctx, id):
+async def delete(ctx: click.Context, id):
     """Delete a collection by ID."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.conversations.delete(id=id)
-
-    click.echo(json.dumps(response, indent=2))
-
-
-@conversations.command()
-@click.argument("id", required=True, type=str)
-@pass_context
-async def list_branches(ctx, id):
-    """List all branches in a conversation."""
-    client: R2RAsyncClient = ctx.obj
-
-    with timer():
-        response = await client.conversations.list_branches(
-            id=id,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    try:
+        with timer():
+            response = await client.conversations.delete(id=id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @conversations.command()
@@ -109,16 +109,20 @@ async def list_branches(ctx, id):
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 @pass_context
-async def list_users(ctx, id, offset, limit):
+async def list_users(ctx: click.Context, id, offset, limit):
     """Get an overview of collections."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.collections.list_users(
-            id=id,
-            offset=offset,
-            limit=limit,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    try:
+        with timer():
+            response = await client.collections.list_users(
+                id=id,
+                offset=offset,
+                limit=limit,
+            )
+        for user in response["results"]:  # type: ignore
+            click.echo(json.dumps(user, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 4 - 6
cli/commands/database.py

@@ -82,7 +82,6 @@ async def upgrade(schema, revision):
         click.echo(
             f"Running database upgrade for schema {schema or 'default'}..."
         )
-        print(f"Upgrading revision = {revision}")
         command = f"upgrade {revision}" if revision else "upgrade"
         result = await run_alembic_command(command, schema_name=schema)
 
@@ -104,11 +103,10 @@ async def upgrade(schema, revision):
 @click.option("--revision", help="Downgrade to a specific revision")
 async def downgrade(schema, revision):
     """Downgrade database schema to the previous revision or a specific revision."""
-    if not revision:
-        if not click.confirm(
-            "No revision specified. This will downgrade the database by one revision. Continue?"
-        ):
-            return
+    if not revision and not click.confirm(
+        "No revision specified. This will downgrade the database by one revision. Continue?"
+    ):
+        return
 
     try:
         db_url = get_database_url_from_env(log=False)
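
run_alembic_command is defined elsewhere in the repo and not shown in this diff; for orientation, the operations it wraps look roughly like this through Alembic's own Python API (a sketch; the alembic.ini path and database URL are assumptions):

    from alembic import command
    from alembic.config import Config

    cfg = Config("alembic.ini")  # assumed location of the project's Alembic config
    cfg.set_main_option("sqlalchemy.url", "postgresql://user:pass@localhost/r2r")

    command.upgrade(cfg, "head")  # r2r db upgrade
    command.downgrade(cfg, "-1")  # r2r db downgrade with no --revision: one step back
    command.current(cfg)          # r2r db current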

+ 303 - 127
cli/commands/documents.py

@@ -2,15 +2,23 @@ import json
 import os
 import tempfile
 import uuid
+from builtins import list as _list
+from typing import Any, Optional, Sequence
 from urllib.parse import urlparse
+from uuid import UUID
 
 import asyncclick as click
 import requests
 from asyncclick import pass_context
+from rich.box import ROUNDED
+from rich.console import Console
+from rich.table import Table
 
 from cli.utils.param_types import JSON
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
+
+console = Console()
 
 
 @click.group()
@@ -31,15 +39,21 @@ def documents():
     "--run-without-orchestration", is_flag=True, help="Run with orchestration"
 )
 @pass_context
-async def create(ctx, file_paths, ids, metadatas, run_without_orchestration):
+async def create(
+    ctx: click.Context,
+    file_paths: tuple[str, ...],
+    ids: Optional[tuple[str, ...]] = None,
+    metadatas: Optional[Sequence[dict[str, Any]]] = None,
+    run_without_orchestration: bool = False,
+):
     """Ingest files into R2R."""
     client: R2RAsyncClient = ctx.obj
     run_with_orchestration = not run_without_orchestration
-    responses = []
+    responses: _list[dict[str, Any]] = []
 
     for idx, file_path in enumerate(file_paths):
         with timer():
-            current_id = [ids[idx]] if ids and idx < len(ids) else None
+            current_id = ids[idx] if ids and idx < len(ids) else None
             current_metadata = (
                 metadatas[idx] if metadatas and idx < len(metadatas) else None
             )
@@ -47,74 +61,181 @@ async def create(ctx, file_paths, ids, metadatas, run_without_orchestration):
             click.echo(
                 f"Processing file {idx + 1}/{len(file_paths)}: {file_path}"
             )
-            response = await client.documents.create(
-                file_path=file_path,
-                metadata=current_metadata,
-                id=current_id,
-                run_with_orchestration=run_with_orchestration,
-            )
-            responses.append(response)
-            click.echo(json.dumps(response, indent=2))
-            click.echo("-" * 40)
+            try:
+                response = await client.documents.create(
+                    file_path=file_path,
+                    metadata=current_metadata,
+                    id=current_id,
+                    run_with_orchestration=run_with_orchestration,
+                )
+                responses.append(response)  # type: ignore
+                click.echo(json.dumps(response, indent=2))
+                click.echo("-" * 40)
+            except R2RException as e:
+                click.echo(str(e), err=True)
+            except Exception as e:
+                click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
     click.echo(f"\nProcessed {len(responses)} files successfully.")
 
 
 @documents.command()
-@click.argument("file_path", required=True, type=click.Path(exists=True))
-@click.option("--id", required=True, help="Existing document ID to update")
+@click.option("--ids", multiple=True, help="Document IDs to fetch")
 @click.option(
-    "--metadata", type=JSON, help="Metadatas for ingestion as a JSON string"
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
 )
 @click.option(
-    "--run-without-orchestration", is_flag=True, help="Run with orchestration"
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
 )
 @pass_context
-async def update(ctx, file_path, id, metadata, run_without_orchestration):
-    """Update an existing file in R2R."""
+async def list(
+    ctx: click.Context,
+    ids: Optional[tuple[str, ...]] = None,
+    offset: int = 0,
+    limit: int = 100,
+) -> None:
+    """Get an overview of documents."""
+    ids = _list(ids) if ids else None  # plain list() would resolve to this command
     client: R2RAsyncClient = ctx.obj
-    run_with_orchestration = not run_without_orchestration
-    responses = []
 
-    with timer():
-        click.echo(f"Updating file {id}: {file_path}")
-        response = await client.documents.update(
-            file_path=file_path,
-            metadata=metadata,
-            id=id,
-            run_with_orchestration=run_with_orchestration,
+    try:
+        with timer():
+            response = await client.documents.list(
+                ids=ids,
+                offset=offset,
+                limit=limit,
+            )
+
+        table = Table(
+            title="[bold blue]Documents[/bold blue]",
+            show_header=True,
+            header_style="bold white on blue",
+            border_style="blue",
+            box=ROUNDED,
+            pad_edge=False,
+            collapse_padding=True,
+            show_lines=True,
         )
-        responses.append(response)
-        click.echo(json.dumps(response, indent=2))
-        click.echo("-" * 40)
 
-    click.echo(f"Updated file {id} file successfully.")
+        # Add columns based on your document structure
+        table.add_column("ID", style="bright_yellow", no_wrap=True)
+        table.add_column("Type", style="bright_magenta")
+        table.add_column("Title", style="bright_green")
+        table.add_column("Ingestion Status", style="bright_cyan")
+        table.add_column("Extraction Status", style="bright_cyan")
+        table.add_column("Summary", style="bright_white")
+        table.add_column("Created At", style="bright_white")
+
+        for document in response["results"]:  # type: ignore
+            table.add_row(
+                document.get("id", ""),
+                document.get("document_type", ""),
+                document.get("title", ""),
+                document.get("ingestion_status", ""),
+                document.get("extraction_status", ""),
+                document.get("summary", ""),
+                document.get("created_at", "")[:19],
+            )
+
+        console = Console()
+        console.print("\n")
+        console.print(table)
+        console.print(
+            f"\n[dim]Showing {len(response['results'])} documents (offset: {offset}, limit: {limit})[/dim]"  # type: ignore
+        )
+
+    except R2RException as e:
+        console.print(f"[bold red]Error:[/bold red] {str(e)}")
+    except Exception as e:
+        console.print(f"[bold red]Unexpected error:[/bold red] {str(e)}")
 
 
 @documents.command()
 @click.argument("id", required=True, type=str)
 @pass_context
-async def retrieve(ctx, id):
+async def retrieve(ctx: click.Context, id: UUID):
     """Retrieve a document by ID."""
     client: R2RAsyncClient = ctx.obj
+    console = Console()
 
-    with timer():
-        response = await client.documents.retrieve(id=id)
+    try:
+        with timer():
+            response = await client.documents.retrieve(id=id)
+
+        # Get the actual document data from the results
+        document = response["results"]  # type: ignore
+
+        metadata_table = Table(
+            show_header=True,
+            header_style="bold white on blue",
+            border_style="blue",
+            box=ROUNDED,
+            title="[bold blue]Document Details[/bold blue]",
+            show_lines=True,
+        )
 
-    click.echo(json.dumps(response, indent=2))
+        metadata_table.add_column("Field", style="bright_yellow")
+        metadata_table.add_column("Value", style="bright_white")
+
+        # Add core document information
+        core_fields = [
+            ("ID", document.get("id", "")),
+            ("Type", document.get("document_type", "")),
+            ("Title", document.get("title", "")),
+            ("Created At", document.get("created_at", "")[:19]),
+            ("Updated At", document.get("updated_at", "")[:19]),
+            ("Ingestion Status", document.get("ingestion_status", "")),
+            ("Extraction Status", document.get("extraction_status", "")),
+            ("Size", f"{document.get('size_in_bytes', 0):,} bytes"),
+        ]
+
+        for field, value in core_fields:
+            metadata_table.add_row(field, str(value))
+
+        # Add metadata section if it exists
+        if "metadata" in document:
+            metadata_table.add_row(
+                "[bold]Metadata[/bold]", "", style="bright_blue"
+            )
+            for key, value in document["metadata"].items():
+                metadata_table.add_row(f"  {key}", str(value))
+
+        # Add summary if it exists
+        if "summary" in document:
+            metadata_table.add_row(
+                "[bold]Summary[/bold]",
+                document["summary"],
+            )
+
+        console.print("\n")
+        console.print(metadata_table)
+        console.print("\n")
+
+    except R2RException as e:
+        console.print(f"[bold red]Error:[/bold red] {str(e)}")
+    except Exception as e:
+        console.print(f"[bold red]Unexpected error:[/bold red] {str(e)}")
 
 
 @documents.command()
 @click.argument("id", required=True, type=str)
 @pass_context
-async def delete(ctx, id):
+async def delete(ctx: click.Context, id):
     """Delete a document by ID."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.documents.delete(id=id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.documents.delete(id=id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @documents.command()
@@ -130,18 +251,53 @@ async def delete(ctx, id):
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 @pass_context
-async def list_chunks(ctx, id, offset, limit):
-    """List collections for a specific document."""
+async def list_chunks(ctx: click.Context, id, offset, limit):
+    """List chunks for a specific document."""
     client: R2RAsyncClient = ctx.obj
+    console = Console()
 
-    with timer():
-        response = await client.documents.list_chunks(
-            id=id,
-            offset=offset,
-            limit=limit,
+    try:
+        with timer():
+            response = await client.documents.list_chunks(
+                id=id,
+                offset=offset,
+                limit=limit,
+            )
+
+        table = Table(
+            title="[bold blue]Document Chunks[/bold blue]",
+            show_header=True,
+            header_style="bold white on blue",
+            border_style="blue",
+            box=ROUNDED,
+            pad_edge=False,
+            collapse_padding=True,
+            show_lines=True,
         )
 
-    click.echo(json.dumps(response, indent=2))
+        table.add_column("ID", style="bright_yellow", no_wrap=True)
+        table.add_column("Text", style="bright_white")
+
+        for chunk in response["results"]:  # type: ignore
+            table.add_row(
+                chunk.get("id", ""),
+                (
+                    chunk.get("text", "")[:200] + "..."
+                    if len(chunk.get("text", "")) > 200
+                    else chunk.get("text", "")
+                ),
+            )
+
+        console.print("\n")
+        console.print(table)
+        console.print(
+            f"\n[dim]Showing {len(response['results'])} chunks (offset: {offset}, limit: {limit})[/dim]"  # type: ignore
+        )
+
+    except R2RException as e:
+        console.print(f"[bold red]Error:[/bold red] {str(e)}")
+    except Exception as e:
+        console.print(f"[bold red]Unexpected error:[/bold red] {str(e)}")
 
 
 @documents.command()
@@ -157,18 +313,53 @@ async def list_chunks(ctx, id, offset, limit):
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 @pass_context
-async def list_collections(ctx, id, offset, limit):
+async def list_collections(ctx: click.Context, id, offset, limit):
     """List collections for a specific document."""
     client: R2RAsyncClient = ctx.obj
+    console = Console()
 
-    with timer():
-        response = await client.documents.list_collections(
-            id=id,
-            offset=offset,
-            limit=limit,
+    try:
+        with timer():
+            response = await client.documents.list_collections(
+                id=id,
+                offset=offset,
+                limit=limit,
+            )
+
+        table = Table(
+            title="[bold blue]Document Collections[/bold blue]",
+            show_header=True,
+            header_style="bold white on blue",
+            border_style="blue",
+            box=ROUNDED,
+            pad_edge=False,
+            collapse_padding=True,
+            show_lines=True,
         )
 
-    click.echo(json.dumps(response, indent=2))
+        table.add_column("ID", style="bright_yellow", no_wrap=True)
+        table.add_column("Name", style="bright_green")
+        table.add_column("Description", style="bright_white")
+        table.add_column("Created At", style="bright_white")
+
+        for collection in response["results"]:  # type: ignore
+            table.add_row(
+                collection.get("id", ""),
+                collection.get("name", ""),
+                collection.get("description", ""),
+                collection.get("created_at", "")[:19],
+            )
+
+        console.print("\n")
+        console.print(table)
+        console.print(
+            f"\n[dim]Showing {len(response['results'])} collections (offset: {offset}, limit: {limit})[/dim]"  # type: ignore
+        )
+
+    except R2RException as e:
+        console.print(f"[bold red]Error:[/bold red] {str(e)}")
+    except Exception as e:
+        console.print(f"[bold red]Unexpected error:[/bold red] {str(e)}")
 
 
 # TODO
@@ -228,7 +419,9 @@ async def ingest_files_from_urls(client, urls):
     help="Run without orchestration",
 )
 @pass_context
-async def extract(ctx, id, run_type, settings, run_without_orchestration):
+async def extract(
+    ctx: click.Context, id, run_type, settings, run_without_orchestration
+):
     """Extract entities and relationships from a document."""
     client: R2RAsyncClient = ctx.obj
     run_with_orchestration = not run_without_orchestration
@@ -262,19 +455,25 @@ async def extract(ctx, id, run_type, settings, run_without_orchestration):
     help="Include embeddings in response",
 )
 @pass_context
-async def list_entities(ctx, id, offset, limit, include_embeddings):
+async def list_entities(
+    ctx: click.Context, id, offset, limit, include_embeddings
+):
     """List entities extracted from a document."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.documents.list_entities(
-            id=id,
-            offset=offset,
-            limit=limit,
-            include_embeddings=include_embeddings,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.documents.list_entities(
+                id=id,
+                offset=offset,
+                limit=limit,
+                include_embeddings=include_embeddings,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @documents.command()
@@ -301,50 +500,52 @@ async def list_entities(ctx, id, offset, limit, include_embeddings):
 )
 @pass_context
 async def list_relationships(
-    ctx, id, offset, limit, entity_names, relationship_types
+    ctx: click.Context, id, offset, limit, entity_names, relationship_types
 ):
     """List relationships extracted from a document."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.documents.list_relationships(
-            id=id,
-            offset=offset,
-            limit=limit,
-            entity_names=list(entity_names) if entity_names else None,
-            relationship_types=(
-                list(relationship_types) if relationship_types else None
-            ),
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.documents.list_relationships(
+                id=id,
+                offset=offset,
+                limit=limit,
+                entity_names=list(entity_names) if entity_names else None,
+                relationship_types=(
+                    list(relationship_types) if relationship_types else None
+                ),
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @documents.command()
-@click.option(
-    "--v2", is_flag=True, help="use aristotle_v2.txt (a smaller file)"
-)
-@click.option(
-    "--v3", is_flag=True, help="use aristotle_v3.txt (a larger file)"
-)
 @pass_context
-async def create_sample(ctx, v2=True, v3=False):
+async def create_sample(ctx: click.Context) -> None:
     """Ingest the first sample file into R2R."""
-    sample_file_url = f"https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/aristotle.txt"
+    sample_file_url = "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/aristotle.txt"
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await ingest_files_from_urls(client, [sample_file_url])
-    click.echo(
-        f"Sample file ingestion completed. Ingest files response:\n\n{response}"
-    )
+    try:
+        with timer():
+            response = await ingest_files_from_urls(client, [sample_file_url])
+        click.echo(
+            f"Sample file ingestion completed. Ingest files response:\n\n{response}"
+        )
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @documents.command()
 @pass_context
-async def create_samples(ctx):
+async def create_samples(ctx: click.Context) -> None:
     """Ingest multiple sample files into R2R."""
-    client: R2RAsyncClient = ctx.obj
     urls = [
         "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_3.html",
         "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_4.html",
@@ -356,38 +557,13 @@ async def create_samples(ctx):
         "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_2.html",
         "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/aristotle.txt",
     ]
-    with timer():
-        response = await ingest_files_from_urls(client, urls)
-
-    click.echo(
-        f"Sample files ingestion completed. Ingest files response:\n\n{response}"
-    )
-
-
-@documents.command()
-@click.option("--ids", multiple=True, help="Document IDs to fetch")
-@click.option(
-    "--offset",
-    default=0,
-    help="The offset to start from. Defaults to 0.",
-)
-@click.option(
-    "--limit",
-    default=100,
-    help="The maximum number of nodes to return. Defaults to 100.",
-)
-@pass_context
-async def list(ctx, ids, offset, limit):
-    """Get an overview of documents."""
     client: R2RAsyncClient = ctx.obj
-    ids = list(ids) if ids else None
 
-    with timer():
-        response = await client.documents.list(
-            ids=ids,
-            offset=offset,
-            limit=limit,
+    try:
+        with timer():
+            response = await ingest_files_from_urls(client, urls)
+        click.echo(
+            f"Sample files ingestion completed. Ingest files response:\n\n{response}"
         )
-
-    for document in response["results"]:
-        click.echo(document)
+    except R2RException as e:
+        click.echo(str(e), err=True)

+ 220 - 144
cli/commands/graphs.py

@@ -5,7 +5,7 @@ from asyncclick import pass_context
 
 from cli.utils.param_types import JSON
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 @click.group()
@@ -29,45 +29,59 @@ def graphs():
     help="The maximum number of graphs to return. Defaults to 100.",
 )
 @pass_context
-async def list(ctx, collection_ids, offset, limit):
+async def list(ctx: click.Context, collection_ids, offset, limit):
     """List available graphs."""
-    client: R2RAsyncClient = ctx.obj
     collection_ids = list(collection_ids) if collection_ids else None
+    client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.list(
-            collection_ids=collection_ids,
-            offset=offset,
-            limit=limit,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.list(
+                collection_ids=collection_ids,
+                offset=offset,
+                limit=limit,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @pass_context
-async def retrieve(ctx, collection_id):
+async def retrieve(ctx: click.Context, collection_id):
     """Retrieve a specific graph by collection ID."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.retrieve(collection_id=collection_id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.retrieve(
+                collection_id=collection_id
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @pass_context
-async def reset(ctx, collection_id):
+async def reset(ctx: click.Context, collection_id):
     """Reset a graph, removing all its data."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.reset(collection_id=collection_id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.reset(collection_id=collection_id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
@@ -75,18 +89,22 @@ async def reset(ctx, collection_id):
 @click.option("--name", help="New name for the graph")
 @click.option("--description", help="New description for the graph")
 @pass_context
-async def update(ctx, collection_id, name, description):
+async def update(ctx: click.Context, collection_id, name, description):
     """Update graph information."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.update(
-            collection_id=collection_id,
-            name=name,
-            description=description,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.update(
+                collection_id=collection_id,
+                name=name,
+                description=description,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
@@ -102,52 +120,64 @@ async def update(ctx, collection_id, name, description):
     help="The maximum number of entities to return. Defaults to 100.",
 )
 @pass_context
-async def list_entities(ctx, collection_id, offset, limit):
+async def list_entities(ctx: click.Context, collection_id, offset, limit):
     """List entities in a graph."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.list_entities(
-            collection_id=collection_id,
-            offset=offset,
-            limit=limit,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.list_entities(
+                collection_id=collection_id,
+                offset=offset,
+                limit=limit,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("entity_id", required=True, type=str)
 @pass_context
-async def get_entity(ctx, collection_id, entity_id):
+async def get_entity(ctx: click.Context, collection_id, entity_id):
     """Get entity information from a graph."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.get_entity(
-            collection_id=collection_id,
-            entity_id=entity_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.get_entity(
+                collection_id=collection_id,
+                entity_id=entity_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("entity_id", required=True, type=str)
 @pass_context
-async def remove_entity(ctx, collection_id, entity_id):
+async def remove_entity(ctx: click.Context, collection_id, entity_id):
     """Remove an entity from a graph."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.remove_entity(
-            collection_id=collection_id,
-            entity_id=entity_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.remove_entity(
+                collection_id=collection_id,
+                entity_id=entity_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
@@ -163,52 +193,66 @@ async def remove_entity(ctx, collection_id, entity_id):
     help="The maximum number of relationships to return. Defaults to 100.",
 )
 @pass_context
-async def list_relationships(ctx, collection_id, offset, limit):
+async def list_relationships(ctx: click.Context, collection_id, offset, limit):
     """List relationships in a graph."""
     client: R2RAsyncClient = ctx.obj
-
-    with timer():
-        response = await client.graphs.list_relationships(
-            collection_id=collection_id,
-            offset=offset,
-            limit=limit,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.list_relationships(
+                collection_id=collection_id,
+                offset=offset,
+                limit=limit,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("relationship_id", required=True, type=str)
 @pass_context
-async def get_relationship(ctx, collection_id, relationship_id):
+async def get_relationship(ctx: click.Context, collection_id, relationship_id):
     """Get relationship information from a graph."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.get_relationship(
-            collection_id=collection_id,
-            relationship_id=relationship_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.get_relationship(
+                collection_id=collection_id,
+                relationship_id=relationship_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("relationship_id", required=True, type=str)
 @pass_context
-async def remove_relationship(ctx, collection_id, relationship_id):
+async def remove_relationship(
+    ctx: click.Context, collection_id, relationship_id
+):
     """Remove a relationship from a graph."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.remove_relationship(
-            collection_id=collection_id,
-            relationship_id=relationship_id,
-        )
+    try:
+        with timer():
+            response = await client.graphs.remove_relationship(
+                collection_id=collection_id,
+                relationship_id=relationship_id,
+            )
 
-    click.echo(json.dumps(response, indent=2))
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
@@ -224,21 +268,29 @@ async def remove_relationship(ctx, collection_id, relationship_id):
 )
 @pass_context
 async def build(
-    ctx, collection_id, settings, run_type, run_without_orchestration
+    ctx: click.Context,
+    collection_id,
+    settings,
+    run_type,
+    run_without_orchestration,
 ):
     """Build a graph with specified settings."""
-    client: R2RAsyncClient = ctx.obj
     run_with_orchestration = not run_without_orchestration
+    client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.build(
-            collection_id=collection_id,
-            settings=settings,
-            run_type=run_type,
-            run_with_orchestration=run_with_orchestration,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.build(
+                collection_id=collection_id,
+                settings=settings,
+                run_type=run_type,
+                run_with_orchestration=run_with_orchestration,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
@@ -254,35 +306,43 @@ async def build(
     help="The maximum number of communities to return. Defaults to 100.",
 )
 @pass_context
-async def list_communities(ctx, collection_id, offset, limit):
+async def list_communities(ctx: click.Context, collection_id, offset, limit):
     """List communities in a graph."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.list_communities(
-            collection_id=collection_id,
-            offset=offset,
-            limit=limit,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.list_communities(
+                collection_id=collection_id,
+                offset=offset,
+                limit=limit,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("community_id", required=True, type=str)
 @pass_context
-async def get_community(ctx, collection_id, community_id):
+async def get_community(ctx: click.Context, collection_id, community_id):
     """Get community information from a graph."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.get_community(
-            collection_id=collection_id,
-            community_id=community_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.get_community(
+                collection_id=collection_id,
+                community_id=community_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
@@ -305,7 +365,7 @@ async def get_community(ctx, collection_id, community_id):
 )
 @pass_context
 async def update_community(
-    ctx,
+    ctx: click.Context,
     collection_id,
     community_id,
     name,
@@ -319,64 +379,80 @@ async def update_community(
     """Update community information."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.update_community(
-            collection_id=collection_id,
-            community_id=community_id,
-            name=name,
-            summary=summary,
-            findings=findings,
-            rating=rating,
-            rating_explanation=rating_explanation,
-            level=level,
-            attributes=attributes,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.update_community(
+                collection_id=collection_id,
+                community_id=community_id,
+                name=name,
+                summary=summary,
+                findings=findings,
+                rating=rating,
+                rating_explanation=rating_explanation,
+                level=level,
+                attributes=attributes,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("community_id", required=True, type=str)
 @pass_context
-async def delete_community(ctx, collection_id, community_id):
+async def delete_community(ctx: click.Context, collection_id, community_id):
     """Delete a community from a graph."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.delete_community(
-            collection_id=collection_id,
-            community_id=community_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.delete_community(
+                collection_id=collection_id,
+                community_id=community_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @pass_context
-async def pull(ctx, collection_id):
+async def pull(ctx: click.Context, collection_id):
     """Pull documents into a graph."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.pull(collection_id=collection_id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.pull(collection_id=collection_id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("document_id", required=True, type=str)
 @pass_context
-async def remove_document(ctx, collection_id, document_id):
+async def remove_document(ctx: click.Context, collection_id, document_id):
     """Remove a document from a graph."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.graphs.remove_document(
-            collection_id=collection_id,
-            document_id=document_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.remove_document(
+                collection_id=collection_id,
+                document_id=document_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 39 - 41
cli/commands/indices.py

@@ -4,7 +4,7 @@ import asyncclick as click
 from asyncclick import pass_context
 
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 @click.group()
@@ -25,65 +25,63 @@ def indices():
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 @pass_context
-async def list(ctx, offset, limit):
+async def list(ctx: click.Context, offset, limit):
     """Get an overview of indices."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.indices.list(
-            offset=offset,
-            limit=limit,
-        )
+    try:
+        with timer():
+            response = await client.indices.list(
+                offset=offset,
+                limit=limit,
+            )
 
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+        for user in response["results"]:  # type: ignore
+            click.echo(json.dumps(user, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @indices.command()
 @click.argument("index_name", required=True, type=str)
 @click.argument("table_name", required=True, type=str)
 @pass_context
-async def retrieve(ctx, index_name, table_name):
+async def retrieve(ctx: click.Context, index_name, table_name):
     """Retrieve an index by name."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.indices.retrieve(
-            index_name=index_name,
-            table_name=table_name,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.indices.retrieve(
+                index_name=index_name,
+                table_name=table_name,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @indices.command()
 @click.argument("index_name", required=True, type=str)
 @click.argument("table_name", required=True, type=str)
 @pass_context
-async def delete(ctx, index_name, table_name):
+async def delete(ctx: click.Context, index_name, table_name):
     """Delete an index by name."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.indices.retrieve(
-            index_name=index_name,
-            table_name=table_name,
-        )
-
-    click.echo(json.dumps(response, indent=2))
-
-
-@indices.command()
-@click.argument("id", required=True, type=str)
-@pass_context
-async def list_branches(ctx, id):
-    """List all branches in a conversation."""
-    client: R2RAsyncClient = ctx.obj
-
-    with timer():
-        response = await client.indices.list_branches(
-            id=id,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    try:
+        with timer():
+            response = await client.indices.delete(
+                index_name=index_name,
+                table_name=table_name,
+            )
+
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 36 - 23
cli/commands/prompts.py

@@ -4,7 +4,7 @@ import asyncclick as click
 from asyncclick import pass_context
 
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 @click.group()
@@ -15,15 +15,20 @@ def prompts():
 
 @prompts.command()
 @pass_context
-async def list(ctx):
+async def list(ctx: click.Context):
     """Get an overview of prompts."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.prompts.list()
+    try:
+        with timer():
+            response = await client.prompts.list()
 
-    for prompt in response["results"]:
-        click.echo(json.dumps(prompt, indent=2))
+        for prompt in response["results"]:  # type: ignore
+            click.echo(json.dumps(prompt, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @prompts.command()
@@ -31,30 +36,38 @@ async def list(ctx):
 @click.option("--inputs", default=None, type=str)
 @click.option("--prompt-override", default=None, type=str)
 @pass_context
-async def retrieve(ctx, name, inputs, prompt_override):
+async def retrieve(ctx: click.Context, name, inputs, prompt_override):
     """Retrieve an prompts by name."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.prompts.retrieve(
-            name=name,
-            inputs=inputs,
-            prompt_override=prompt_override,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.prompts.retrieve(
+                name=name,
+                inputs=inputs,
+                prompt_override=prompt_override,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @prompts.command()
 @click.argument("name", required=True, type=str)
 @pass_context
-async def delete(ctx, name):
-    """Delete an index by name."""
+async def delete(ctx: click.Context, name):
+    """Delete a prompt by name."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.prompts.delete(
-            name=name,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.prompts.delete(
+                name=name,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 52 - 39
cli/commands/retrieval.py

@@ -5,7 +5,7 @@ from asyncclick import pass_context
 
 from cli.utils.param_types import JSON
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 @click.group()
@@ -52,9 +52,8 @@ def retrieval():
     help="Use search over document chunks?",
 )
 @pass_context
-async def search(ctx, query, **kwargs):
+async def search(ctx: click.Context, query, **kwargs):
     """Perform a search query."""
-    client: R2RAsyncClient = ctx.obj
     search_settings = {
         k: v
         for k, v in kwargs.items()
@@ -78,28 +77,36 @@ async def search(ctx, query, **kwargs):
     if chunk_search_enabled != None:
         search_settings["chunk_settings"] = {"enabled": chunk_search_enabled}
 
-    with timer():
-        results = await client.retrieval.search(
-            query,
-            "custom",
-            search_settings,
-        )
-
-        if isinstance(results, dict) and "results" in results:
-            results = results["results"]
-
-        if "chunk_search_results" in results:
-            click.echo("Vector search results:")
-            for result in results["chunk_search_results"]:
-                click.echo(json.dumps(result, indent=2))
+    client: R2RAsyncClient = ctx.obj
 
-        if (
-            "graph_search_results" in results
-            and results["graph_search_results"]
-        ):
-            click.echo("KG search results:")
-            for result in results["graph_search_results"]:
-                click.echo(json.dumps(result, indent=2))
+    print("client.base_url = ", client.base_url)
+    try:
+        with timer():
+            results = await client.retrieval.search(
+                query,
+                "custom",
+                search_settings,
+            )
+
+            if isinstance(results, dict) and "results" in results:
+                results = results["results"]
+
+            if "chunk_search_results" in results:  # type: ignore
+                click.echo("Vector search results:")
+                for result in results["chunk_search_results"]:  # type: ignore
+                    click.echo(json.dumps(result, indent=2))
+
+            if (
+                "graph_search_results" in results  # type: ignore
+                and results["graph_search_results"]  # type: ignore
+            ):
+                click.echo("KG search results:")
+                for result in results["graph_search_results"]:  # type: ignore
+                    click.echo(json.dumps(result, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @retrieval.command()
@@ -142,9 +149,8 @@ async def search(ctx, query, **kwargs):
 @click.option("--stream", is_flag=True, help="Stream the RAG response")
 @click.option("--rag-model", default=None, help="Model for RAG")
 @pass_context
-async def rag(ctx, query, **kwargs):
+async def rag(ctx: click.Context, query, **kwargs):
     """Perform a RAG query."""
-    client: R2RAsyncClient = ctx.obj
     rag_generation_config = {
         "stream": kwargs.get("stream", False),
     }
@@ -174,16 +180,23 @@ async def rag(ctx, query, **kwargs):
     if chunk_search_enabled != None:
         search_settings["chunk_settings"] = {"enabled": chunk_search_enabled}
 
-    with timer():
-        response = await client.retrieval.rag(
-            query=query,
-            rag_generation_config=rag_generation_config,
-            search_settings={**search_settings},
-        )
-
-        if rag_generation_config.get("stream"):
-            async for chunk in response:
-                click.echo(chunk, nl=False)
-            click.echo()
-        else:
-            click.echo(json.dumps(response["results"]["completion"], indent=2))
+    client: R2RAsyncClient = ctx.obj
+
+    try:
+        with timer():
+            response = await client.retrieval.rag(
+                query=query,
+                rag_generation_config=rag_generation_config,
+                search_settings={**search_settings},
+            )
+
+            if rag_generation_config.get("stream"):
+                async for chunk in response:  # type: ignore
+                    click.echo(chunk, nl=False)
+                click.echo()
+            else:
+                click.echo(json.dumps(response["results"]["completion"], indent=2))  # type: ignore
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 40 - 17
cli/commands/system.py

@@ -18,7 +18,7 @@ from cli.utils.docker_utils import (
     wait_for_container_health,
 )
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 @click.group()
@@ -29,35 +29,53 @@ def system():
 
 @cli.command()
 @pass_context
-async def health(ctx):
+async def health(ctx: click.Context):
     """Check the health of the server."""
     client: R2RAsyncClient = ctx.obj
-    with timer():
-        response = await client.system.health()
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.system.health()
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+        raise
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
+        raise
 
 
 @system.command()
 @pass_context
-async def settings(ctx):
+async def settings(ctx: click.Context):
     """Retrieve application settings."""
     client: R2RAsyncClient = ctx.obj
-    with timer():
-        response = await client.system.settings()
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.system.settings()
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+        raise
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
+        raise
 
 
 @system.command()
 @pass_context
-async def status(ctx):
+async def status(ctx: click.Context):
     """Get statistics about the server, including the start time, uptime, CPU usage, and memory usage."""
     client: R2RAsyncClient = ctx.obj
-    with timer():
-        response = await client.system.status()
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.system.status()
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+        raise
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
+        raise
 
 
 @cli.command()
@@ -400,4 +418,9 @@ def version():
     """Reports the SDK version."""
     from importlib.metadata import version
 
-    click.echo(json.dumps(version("r2r"), indent=2))
+    try:
+        r2r_version = version("r2r")
+        click.echo(json.dumps(r2r_version, indent=2))
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
+        raise

+ 98 - 53
cli/commands/users.py

@@ -1,10 +1,12 @@
 import json
+from builtins import list as _list
+from uuid import UUID
 
 import asyncclick as click
 from asyncclick import pass_context
 
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 @click.group()
@@ -17,68 +19,98 @@ def users():
 @click.argument("email", required=True, type=str)
 @click.argument("password", required=True, type=str)
 @pass_context
-async def create(ctx, email, password):
+async def create(ctx: click.Context, email: str, password: str):
     """Create a new user."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.users.create(email=email, password=password)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.users.create(
+                email=email,
+                password=password,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @users.command()
-@click.option("--ids", multiple=True, help="Document IDs to fetch")
+@click.option("--ids", multiple=True, help="Document IDs to fetch", type=str)
 @click.option(
     "--offset",
     default=0,
     help="The offset to start from. Defaults to 0.",
+    type=int,
 )
 @click.option(
     "--limit",
     default=100,
     help="The maximum number of nodes to return. Defaults to 100.",
+    type=int,
 )
 @pass_context
-async def list(ctx, ids, offset, limit):
+async def list(
+    ctx: click.Context,
+    ids: tuple[str, ...],
+    offset: int,
+    limit: int,
+):
     """Get an overview of users."""
-    client: R2RAsyncClient = ctx.obj
-    ids = list(ids) if ids else None
+    uuids: _list[str | UUID] | None = (
+        [UUID(id_) for id_ in ids] if ids else None
+    )
 
-    with timer():
-        response = await client.users.list(
-            ids=ids,
-            offset=offset,
-            limit=limit,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    client: R2RAsyncClient = ctx.obj
+    try:
+        with timer():
+            response = await client.users.list(
+                ids=uuids,
+                offset=offset,
+                limit=limit,
+            )
+
+        for user in response["results"]:  # type: ignore
+            click.echo(json.dumps(user, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @users.command()
 @click.argument("id", required=True, type=str)
 @pass_context
-async def retrieve(ctx, id):
+async def retrieve(ctx: click.Context, id):
     """Retrieve a user by ID."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.users.retrieve(id=id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.users.retrieve(id=id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @users.command()
 @pass_context
-async def me(ctx):
+async def me(ctx: click.Context):
     """Retrieve the current user."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.users.me()
+    try:
+        with timer():
+            response = await client.users.me()
 
-    click.echo(json.dumps(response, indent=2))
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @users.command()
@@ -94,50 +126,63 @@ async def me(ctx):
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 @pass_context
-async def list_collections(ctx, id, offset, limit):
+async def list_collections(ctx: click.Context, id, offset, limit):
     """List collections for a specific user."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.users.list_collections(
-            id=id,
-            offset=offset,
-            limit=limit,
-        )
+    try:
+        with timer():
+            response = await client.users.list_collections(
+                id=id,
+                offset=offset,
+                limit=limit,
+            )
 
-    for collection in response["results"]:
-        click.echo(json.dumps(collection, indent=2))
+        for collection in response["results"]:  # type: ignore
+            click.echo(json.dumps(collection, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @users.command()
 @click.argument("id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @pass_context
-async def add_to_collection(ctx, id, collection_id):
+async def add_to_collection(ctx: click.Context, id, collection_id):
     """Retrieve a user by ID."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.users.add_to_collection(
-            id=id,
-            collection_id=collection_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.users.add_to_collection(
+                id=id,
+                collection_id=collection_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 @users.command()
 @click.argument("id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @pass_context
-async def remove_from_collection(ctx, id, collection_id):
+async def remove_from_collection(ctx: click.Context, id, collection_id):
     """Retrieve a user by ID."""
     client: R2RAsyncClient = ctx.obj
 
-    with timer():
-        response = await client.users.remove_from_collection(
-            id=id,
-            collection_id=collection_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.users.remove_from_collection(
+                id=id,
+                collection_id=collection_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 76 - 4
cli/main.py

@@ -1,6 +1,13 @@
+import json
+from typing import Any, Dict
+
+import asyncclick as click
+from rich.console import Console
+
 from cli.command_group import cli
 from cli.commands import (
     collections,
+    config,
     conversations,
     database,
     documents,
@@ -12,6 +19,11 @@ from cli.commands import (
     users,
 )
 from cli.utils.telemetry import posthog, telemetry
+from r2r import R2RAsyncClient
+
+from .command_group import CONFIG_DIR, CONFIG_FILE, load_config
+
+console = Console()
 
 
 def add_command_with_telemetry(command):
@@ -39,22 +51,82 @@ add_command_with_telemetry(database.downgrade)
 add_command_with_telemetry(database.current)
 add_command_with_telemetry(database.history)
 
+add_command_with_telemetry(config.configure)
+
 
 def main():
     try:
         cli()
     except SystemExit:
-        # Silently exit without printing the traceback
         pass
     except Exception as e:
-        # Handle other exceptions if needed
-        raise e
+        console.print("[red]CLI error: An error occurred[/red]")
+        console.print_exception()
     finally:
-        # Ensure all events are flushed before exiting
         if posthog:
             posthog.flush()
             posthog.shutdown()
 
 
+def _ensure_config_dir_exists() -> None:
+    """Ensure that the ~/.r2r/ directory exists."""
+    CONFIG_DIR.mkdir(parents=True, exist_ok=True)
+
+
+def save_config(config_data: Dict[str, Any]) -> None:
+    """
+    Persist the given config data to ~/.r2r/config.json.
+    """
+    _ensure_config_dir_exists()
+    with open(CONFIG_FILE, "w", encoding="utf-8") as f:
+        json.dump(config_data, f, indent=2)
+
+
+@cli.command("set-api-key", short_help="Set your R2R API key")
+@click.argument("api_key", required=True, type=str)
+@click.pass_context
+async def set_api_key(ctx, api_key: str):
+    """
+    Store your R2R API key locally so you don’t have to pass it on every command.
+    Example usage:
+      r2r set-api-key sk-1234abcd
+    """
+    try:
+        # 1) Load existing config
+        config = load_config()
+
+        # 2) Overwrite or add the API key
+        config["api_key"] = api_key
+
+        # 3) Save changes
+        save_config(config)
+
+        console.print("[green]API key set successfully![/green]")
+    except Exception as e:
+        console.print("[red]Failed to set API key:[/red]", str(e))
+
+
+@cli.command("get-api", short_help="Get your stored R2R API key")
+@click.pass_context
+async def get_api(ctx):
+    """
+    Display your stored R2R API key.
+    Example usage:
+      r2r get-api
+    """
+    try:
+        config = load_config()
+        api_key = config.get("api_key")
+
+        if api_key:
+            console.print(f"API Key: {api_key}")
+        else:
+            console.print(
+                "[yellow]No API key found. Set one using 'r2r set-api <key>'[/yellow]"
+            )
+    except Exception as e:
+        console.print("[red]Failed to retrieve API key:[/red]", str(e))
+
+
 if __name__ == "__main__":
     main()
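
The new helpers round-trip a plain JSON file at ~/.r2r/config.json. A minimal usage sketch (load_config comes from command_group and is assumed to return an empty dict when the file is absent):

    from cli.main import load_config, save_config

    config = load_config()
    config["api_key"] = "sk-1234abcd"  # illustrative key from the docstring
    save_config(config)  # writes {"api_key": "sk-1234abcd"} to ~/.r2r/config.json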

+ 0 - 6
cli/utils/docker_utils.py

@@ -116,12 +116,6 @@ async def run_local_serve(
 
     await r2r_instance.orchestration_provider.start_worker()
 
-    # TODO: make this work with autoreload, currently due to hatchet, it causes a reload error
-    # import uvicorn
-    # uvicorn.run(
-    #     "core.main.app_entry:app", host=host, port=available_port, reload=False
-    # )
-
     await r2r_instance.serve(host, available_port)
 
 

+ 2 - 1
cli/utils/timer.py

@@ -13,4 +13,5 @@ def timer():
     start = time.time()
     yield
     end = time.time()
-    click.echo(f"Time taken: {end - start:.2f} seconds")
+    duration = max(0, end - start)
+    click.echo(f"Time taken: {duration:.2f} seconds")

+ 124 - 0
compose.yaml

@@ -0,0 +1,124 @@
+networks:
+  r2r-network:
+    driver: bridge
+    attachable: true
+    labels:
+      - "com.docker.compose.recreate=always"
+
+volumes:
+  postgres_data:
+    name: ${VOLUME_POSTGRES_DATA:-postgres_data}
+
+services:
+  postgres:
+    image: pgvector/pgvector:pg16
+    profiles: [postgres]
+    environment:
+      # The un-prefixed POSTGRES_* variables are deprecated fallbacks, kept for backwards compatibility
+      - POSTGRES_USER=${R2R_POSTGRES_USER:-${POSTGRES_USER:-postgres}}
+      - POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-${POSTGRES_PASSWORD:-postgres}}
+      - POSTGRES_HOST=${R2R_POSTGRES_HOST:-${POSTGRES_HOST:-postgres}}
+      - POSTGRES_PORT=${R2R_POSTGRES_PORT:-${POSTGRES_PORT:-5432}}
+      - POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-${POSTGRES_MAX_CONNECTIONS:-1024}}
+      - PGPORT=${R2R_POSTGRES_PORT:-5432}
+    volumes:
+      - postgres_data:/var/lib/postgresql/data
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_POSTGRES_PORT:-5432}:${R2R_POSTGRES_PORT:-5432}"
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -U ${R2R_POSTGRES_USER:-postgres}"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    command: >
+      postgres
+      -c max_connections=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+
+  r2r:
+    image: ${R2R_IMAGE:-ragtoriches/prod:latest}
+    build:
+      context: .
+      args:
+        PORT: ${R2R_PORT:-7272}
+        R2R_PORT: ${R2R_PORT:-7272}
+        HOST: ${R2R_HOST:-0.0.0.0}
+        R2R_HOST: ${R2R_HOST:-0.0.0.0}
+    ports:
+      - "${R2R_PORT:-7272}:${R2R_PORT:-7272}"
+    environment:
+      - PYTHONUNBUFFERED=1
+      - R2R_PORT=${R2R_PORT:-7272}
+      - R2R_HOST=${R2R_HOST:-0.0.0.0}
+
+      # R2R
+      - R2R_CONFIG_NAME=${R2R_CONFIG_NAME:-}
+      - R2R_CONFIG_PATH=${R2R_CONFIG_PATH:-}
+      - R2R_PROJECT_NAME=${R2R_PROJECT_NAME:-r2r_default}
+
+      # Postgres
+      - R2R_POSTGRES_USER=${R2R_POSTGRES_USER:-postgres}
+      - R2R_POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres}
+      - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres}
+      - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432}
+      - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-postgres}
+      - R2R_POSTGRES_PROJECT_NAME=${R2R_POSTGRES_PROJECT_NAME:-r2r_default}
+      - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+      - R2R_POSTGRES_STATEMENT_CACHE_SIZE=${R2R_POSTGRES_STATEMENT_CACHE_SIZE:-100}
+
+      # OpenAI
+      - OPENAI_API_KEY=${OPENAI_API_KEY:-}
+      - OPENAI_API_BASE=${OPENAI_API_BASE:-}
+
+      # Anthropic
+      - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
+
+      # Azure
+      - AZURE_API_KEY=${AZURE_API_KEY:-}
+      - AZURE_API_BASE=${AZURE_API_BASE:-}
+      - AZURE_API_VERSION=${AZURE_API_VERSION:-}
+
+      # Google Vertex AI
+      - GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS:-}
+      - VERTEX_PROJECT=${VERTEX_PROJECT:-}
+      - VERTEX_LOCATION=${VERTEX_LOCATION:-}
+
+      # AWS Bedrock
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
+      - AWS_REGION_NAME=${AWS_REGION_NAME:-}
+
+      # Groq
+      - GROQ_API_KEY=${GROQ_API_KEY:-}
+
+      # Cohere
+      - COHERE_API_KEY=${COHERE_API_KEY:-}
+
+      # Anyscale
+      - ANYSCALE_API_KEY=${ANYSCALE_API_KEY:-}
+
+      # Ollama
+      - OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434}
+
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:${R2R_PORT:-7272}/v3/health"]
+      interval: 6s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    volumes:
+      - ${R2R_CONFIG_PATH:-/}:${R2R_CONFIG_PATH:-/app/config}
+    extra_hosts:
+      - host.docker.internal:host-gateway
+
+  r2r-dashboard:
+    image: emrgntcmplxty/r2r-dashboard:latest
+    environment:
+      - NEXT_PUBLIC_R2R_DEPLOYMENT_URL=${R2R_DEPLOYMENT_URL:-http://localhost:7272}
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_DASHBOARD_PORT:-7273}:3000"

+ 0 - 1
core/base/parsers/base_parser.py

@@ -7,7 +7,6 @@ T = TypeVar("T")
 
 
 class AsyncParser(ABC, Generic[T]):
-
     @abstractmethod
     async def ingest(self, data: T, **kwargs) -> AsyncGenerator[str, None]:
         pass

+ 0 - 1
core/base/providers/ingestion.py

@@ -179,7 +179,6 @@ class IngestionConfig(ProviderConfig):
 
 
 class IngestionProvider(Provider, ABC):
-
     config: IngestionConfig
     database_provider: "PostgresDatabaseProvider"
     llm_provider: CompletionProvider

+ 53 - 10
core/database/base.py

@@ -54,11 +54,18 @@ class QueryBuilder:
     def __init__(self, table_name: str):
         self.table_name = table_name
         self.conditions: list[str] = []
-        self.params: dict = {}
+        self.params: list = []  # was a dict; now a list for PostgreSQL $1, $2 placeholders
         self.select_fields = "*"
         self.operation = "SELECT"
         self.limit_value: Optional[int] = None
+        self.offset_value: Optional[int] = None
+        self.order_by_fields: Optional[str] = None
+        self.returning_fields: Optional[list[str]] = None
         self.insert_data: Optional[dict] = None
+        self.update_data: Optional[dict] = None
+        self.param_counter = 1  # For generating $1, $2, etc.
 
     def select(self, fields: list[str]):
         self.select_fields = ", ".join(fields)
@@ -69,45 +76,81 @@ class QueryBuilder:
         self.insert_data = data
         return self
 
+    def update(self, data: dict):
+        self.operation = "UPDATE"
+        self.update_data = data
+        return self
+
     def delete(self):
         self.operation = "DELETE"
         return self
 
-    def where(self, condition: str, **kwargs):
+    def where(self, condition: str):
         self.conditions.append(condition)
-        self.params.update(kwargs)
         return self
 
-    def limit(self, value: int):
+    def limit(self, value: Optional[int]):
         self.limit_value = value
         return self
 
+    def offset(self, value: int):
+        self.offset_value = value
+        return self
+
+    def order_by(self, fields: str):
+        self.order_by_fields = fields
+        return self
+
+    def returning(self, fields: list[str]):
+        self.returning_fields = fields
+        return self
+
     def build(self):
         if self.operation == "SELECT":
             query = f"SELECT {self.select_fields} FROM {self.table_name}"
+
         elif self.operation == "INSERT":
             columns = ", ".join(self.insert_data.keys())
-            values = ", ".join(f":{key}" for key in self.insert_data.keys())
-            query = (
-                f"INSERT INTO {self.table_name} ({columns}) VALUES ({values})"
+            placeholders = ", ".join(
+                f"${i}" for i in range(1, len(self.insert_data) + 1)
             )
-            self.params.update(self.insert_data)
+            query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
+            self.params.extend(list(self.insert_data.values()))
+
+        elif self.operation == "UPDATE":
+            set_clauses = []
+            for i, (key, value) in enumerate(
+                self.update_data.items(), start=len(self.params) + 1
+            ):
+                set_clauses.append(f"{key} = ${i}")
+                self.params.append(value)
+            query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"
+
         elif self.operation == "DELETE":
             query = f"DELETE FROM {self.table_name}"
+
         else:
             raise ValueError(f"Unsupported operation: {self.operation}")
 
         if self.conditions:
             query += " WHERE " + " AND ".join(self.conditions)
 
-        if self.limit_value is not None and self.operation == "SELECT":
+        if self.order_by_fields and self.operation == "SELECT":
+            query += f" ORDER BY {self.order_by_fields}"
+
+        if self.offset_value is not None:
+            query += f" OFFSET {self.offset_value}"
+
+        if self.limit_value is not None:
             query += f" LIMIT {self.limit_value}"
 
+        if self.returning_fields:
+            query += f" RETURNING {', '.join(self.returning_fields)}"
+
         return query, self.params
 
 
 class PostgresConnectionManager(DatabaseConnectionManager):
-
     def __init__(self):
         self.pool: Optional[SemaphoreConnectionPool] = None
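
With the move to positional parameters, where() no longer collects values: conditions carry their own $n placeholders, while insert() and update() append their values to params in order. A usage sketch against the build() logic above (table and column names are illustrative):

    query, params = (
        QueryBuilder("users")
        .update({"bio": "hello"})
        .where("id = $2")
        .returning(["id", "email"])
        .build()
    )
    # query  -> "UPDATE users SET bio = $1 WHERE id = $2 RETURNING id, email"
    # params -> ["hello"]; the caller appends the id value when executing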
 

+ 0 - 2
core/database/chunks.py

@@ -423,7 +423,6 @@ class PostgresChunksHandler(Handler):
     async def full_text_search(
         self, query_text: str, search_settings: SearchSettings
     ) -> list[ChunkSearchResult]:
-
         conditions = []
         params: list[str | int | bytes] = [query_text]
 
@@ -1049,7 +1048,6 @@ class PostgresChunksHandler(Handler):
         id: UUID,
         similarity_threshold: float = 0.5,
     ) -> list[dict[str, Any]]:
-
         table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
         query = f"""
         WITH target_vector AS (

+ 0 - 1
core/database/collections.py

@@ -73,7 +73,6 @@ class PostgresCollectionsHandler(Handler):
         description: str = "",
         collection_id: Optional[UUID] = None,
     ) -> CollectionResponse:
-
         if not name and not collection_id:
             name = self.config.default_collection_name
             collection_id = generate_default_user_collection_id(owner_id)

+ 0 - 1
core/database/documents.py

@@ -152,7 +152,6 @@ class PostgresDocumentsHandler(Handler):
                                     document.id,
                                 )
                             else:
-
                                 insert_query = f"""
                                 INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
                                 (id, collection_ids, owner_id, type, metadata, title, version,

+ 0 - 13
core/database/graphs.py

@@ -776,7 +776,6 @@ class PostgresRelationshipsHandler(Handler):
 
 
 class PostgresCommunitiesHandler(Handler):
-
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.project_name: str = kwargs.get("project_name")  # type: ignore
         self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
@@ -784,7 +783,6 @@ class PostgresCommunitiesHandler(Handler):
         self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore
 
     async def create_tables(self) -> None:
-
         vector_column_str = _decorate_vector_type(
             f"({self.dimension})", self.quantization_type
         )
@@ -1072,7 +1070,6 @@ class PostgresGraphsHandler(Handler):
         *args: Any,
         **kwargs: Any,
     ) -> None:
-
         self.project_name: str = kwargs.get("project_name")  # type: ignore
         self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
         self.dimension: int = kwargs.get("dimension")  # type: ignore
@@ -1258,9 +1255,7 @@ class PostgresGraphsHandler(Handler):
     async def get(
         self, offset: int, limit: int, graph_id: Optional[UUID] = None
     ):
-
         if graph_id is None:
-
             params = [offset, limit]
 
             QUERY = f"""
@@ -1498,7 +1493,6 @@ class PostgresGraphsHandler(Handler):
     ):
         """Get the estimated cost and time for enriching a KG."""
         if collection_id is not None:
-
             document_ids = [
                 doc.id
                 for doc in (
@@ -1857,7 +1851,6 @@ class PostgresGraphsHandler(Handler):
         graph_id: UUID | None,
         document_ids: Optional[list[UUID]] = None,
     ) -> list[Relationship]:
-
         QUERY = f"""
             SELECT id, subject, predicate, weight, object, parent_id FROM {self._get_table_name("graphs_relationships")} WHERE parent_id = ANY($1)
         """
@@ -1969,7 +1962,6 @@ class PostgresGraphsHandler(Handler):
         return communities, count
 
     async def add_community(self, community: Community) -> None:
-
         # TODO: Fix in the short term.
         # we need to do this because postgres insert needs to be a string
         community.description_embedding = str(community.description_embedding)  # type: ignore[assignment]
@@ -1997,7 +1989,6 @@ class PostgresGraphsHandler(Handler):
 
     # async def delete(self, collection_id: UUID, cascade: bool = False) -> None:
     async def delete(self, collection_id: UUID) -> None:
-
         graphs = await self.get(graph_id=collection_id, offset=0, limit=-1)
 
         if len(graphs["results"]) == 0:
@@ -2168,7 +2159,6 @@ class PostgresGraphsHandler(Handler):
         collection_id: Optional[UUID] = None,
         clustering_mode: str = "local",
     ) -> Tuple[int, Any]:
-
         # clear if there is any old information
         conditions = []
         if collection_id is not None:
@@ -2248,7 +2238,6 @@ class PostgresGraphsHandler(Handler):
     async def get_entity_map(
         self, offset: int, limit: int, document_id: UUID
     ) -> dict[str, dict[str, list[dict[str, Any]]]]:
-
         QUERY1 = f"""
             WITH entities_list AS (
                 SELECT DISTINCT name
@@ -2555,7 +2544,6 @@ class PostgresGraphsHandler(Handler):
         distinct: bool = False,
         entity_table_name: str = "entity",
     ) -> int:
-
         if collection_id is None and document_id is None:
             raise ValueError(
                 "Either collection_id or document_id must be provided."
@@ -2576,7 +2564,6 @@ class PostgresGraphsHandler(Handler):
         ]
 
     async def update_entity_descriptions(self, entities: list[Entity]):
-
         query = f"""
             UPDATE {self._get_table_name("graphs_entities")}
             SET description = $3, description_embedding = $4

+ 89 - 54
core/database/limits.py

@@ -4,6 +4,7 @@ from typing import Optional
 from uuid import UUID
 
 from core.base import Handler
+from shared.abstractions import User
 
 from ..base.providers.database import DatabaseConfig, LimitSettings
 from .base import PostgresConnectionManager
@@ -20,8 +21,12 @@ class PostgresLimitsHandler(Handler):
         connection_manager: PostgresConnectionManager,
         config: DatabaseConfig,
     ):
+        """
+        :param config: The global DatabaseConfig with default rate limits.
+        """
         super().__init__(project_name, connection_manager)
-        self.config = config
+        self.config = config  # Contains e.g. self.config.limits for fallback
+
         logger.debug(
             f"Initialized PostgresLimitsHandler with project: {project_name}"
         )
@@ -38,8 +43,15 @@ class PostgresLimitsHandler(Handler):
         await self.connection_manager.execute_query(query)
 
     async def _count_requests(
-        self, user_id: UUID, route: Optional[str], since: datetime
+        self,
+        user_id: UUID,
+        route: Optional[str],
+        since: datetime,
     ) -> int:
+        """
+        Count how many requests a user (optionally for a specific route)
+        has made since the given datetime.
+        """
         if route:
             query = f"""
             SELECT COUNT(*)::int
@@ -49,7 +61,9 @@ class PostgresLimitsHandler(Handler):
               AND time >= $3
             """
             params = [user_id, route, since]
-            logger.debug(f"Counting requests for route {route}")
+            logger.debug(
+                f"Counting requests for user={user_id}, route={route}"
+            )
         else:
             query = f"""
             SELECT COUNT(*)::int
@@ -58,68 +72,86 @@ class PostgresLimitsHandler(Handler):
               AND time >= $2
             """
             params = [user_id, since]
-            logger.debug("Counting all requests")
+            logger.debug(f"Counting all requests for user={user_id}")
 
         result = await self.connection_manager.fetchrow_query(query, params)
-        count = result["count"] if result else 0
-
-        return count
+        return result["count"] if result else 0
 
     async def _count_monthly_requests(self, user_id: UUID) -> int:
+        """
+        Count the number of requests so far this month for a given user.
+        """
         now = datetime.now(timezone.utc)
         start_of_month = now.replace(
             day=1, hour=0, minute=0, second=0, microsecond=0
         )
+        return await self._count_requests(user_id, None, start_of_month)
 
-        count = await self._count_requests(
-            user_id, route=None, since=start_of_month
-        )
-        return count
-
-    def _determine_limits_for(
-        self, user_id: UUID, route: str
-    ) -> LimitSettings:
-        # Start with base limits
-        limits = self.config.limits
-
-        # Route-specific limits - directly override if present
-        route_limits = self.config.route_limits.get(route)
-        if route_limits:
-            # Only override non-None values from route_limits
-            if route_limits.global_per_min is not None:
-                limits.global_per_min = route_limits.global_per_min
-            if route_limits.route_per_min is not None:
-                limits.route_per_min = route_limits.route_per_min
-            if route_limits.monthly_limit is not None:
-                limits.monthly_limit = route_limits.monthly_limit
-
-        # User-specific limits - directly override if present
-        user_limits = self.config.user_limits.get(user_id)
-        if user_limits:
-            # Only override non-None values from user_limits
-            if user_limits.global_per_min is not None:
-                limits.global_per_min = user_limits.global_per_min
-            if user_limits.route_per_min is not None:
-                limits.route_per_min = user_limits.route_per_min
-            if user_limits.monthly_limit is not None:
-                limits.monthly_limit = user_limits.monthly_limit
-
-        return limits
-
-    async def check_limits(self, user_id: UUID, route: str):
-        # Determine final applicable limits
-        limits = self._determine_limits_for(user_id, route)
-        if not limits:
-            limits = self.config.default_limits
-
-        global_per_min = limits.global_per_min
-        route_per_min = limits.route_per_min
-        monthly_limit = limits.monthly_limit
+    async def check_limits(self, user: User, route: str):
+        """
+        Perform rate limit checks for a user on a specific route.
 
+        :param user: The fully-fetched User object with .limits_overrides, etc.
+        :param route: The route/path being accessed.
+        :raises ValueError: if any limit is exceeded.
+        """
+        user_id = user.id
         now = datetime.now(timezone.utc)
         one_min_ago = now - timedelta(minutes=1)
 
-        # Global per-minute check
+        # 1) First check route-specific configuration limits
+        route_config = self.config.route_limits.get(route)
+        if route_config:
+            # Check route-specific per-minute limit
+            if route_config.route_per_min is not None:
+                route_req_count = await self._count_requests(
+                    user_id, route, one_min_ago
+                )
+                if route_req_count > route_config.route_per_min:
+                    logger.warning(
+                        f"Per-route per-minute limit exceeded for user_id={user_id}, route={route}"
+                    )
+                    raise ValueError(
+                        "Per-route per-minute rate limit exceeded"
+                    )
+
+            # Check route-specific monthly limit
+            if route_config.monthly_limit is not None:
+                monthly_count = await self._count_monthly_requests(user_id)
+                if monthly_count > route_config.monthly_limit:
+                    logger.warning(
+                        f"Route monthly limit exceeded for user_id={user_id}, route={route}"
+                    )
+                    raise ValueError("Route monthly limit exceeded")
+
+        # 2) Get user overrides and base limits
+        user_overrides = user.limits_overrides or {}
+        base_limits = self.config.limits
+
+        # Extract user-level overrides
+        global_per_min = user_overrides.get(
+            "global_per_min", base_limits.global_per_min
+        )
+        monthly_limit = user_overrides.get(
+            "monthly_limit", base_limits.monthly_limit
+        )
+
+        # 3) Check route-specific overrides from user config
+        route_overrides = user_overrides.get("route_overrides", {})
+        specific_config = route_overrides.get(route, {})
+
+        # Apply route-specific overrides for per-minute limits
+        route_per_min = specific_config.get(
+            "route_per_min", base_limits.route_per_min
+        )
+
+        # If route specifically overrides global or monthly limits, apply them
+        if "global_per_min" in specific_config:
+            global_per_min = specific_config["global_per_min"]
+        if "monthly_limit" in specific_config:
+            monthly_limit = specific_config["monthly_limit"]
+
+        # 4) Check global per-minute limit
         if global_per_min is not None:
             user_req_count = await self._count_requests(
                 user_id, None, one_min_ago
@@ -130,7 +162,7 @@ class PostgresLimitsHandler(Handler):
                 )
                 raise ValueError("Global per-minute rate limit exceeded")
 
-        # Per-route per-minute check
+        # 5) Check user-specific route per-minute limit
         if route_per_min is not None:
             route_req_count = await self._count_requests(
                 user_id, route, one_min_ago
@@ -141,7 +173,7 @@ class PostgresLimitsHandler(Handler):
                 )
                 raise ValueError("Per-route per-minute rate limit exceeded")
 
-        # Monthly limit check
+        # 6) Check monthly limit
         if monthly_limit is not None:
             monthly_count = await self._count_monthly_requests(user_id)
             if monthly_count > monthly_limit:
@@ -151,6 +183,9 @@ class PostgresLimitsHandler(Handler):
                 raise ValueError("Monthly rate limit exceeded")
 
     async def log_request(self, user_id: UUID, route: str):
+        """
+        Log a successful request to the request_log table.
+        """
         query = f"""
         INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
         VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
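
The precedence implemented in check_limits (static route config first, then user-level overrides, then the user's per-route overrides; the most specific value wins) in a worked, self-contained example where the route and numbers are illustrative:

    base = {"global_per_min": 60, "route_per_min": 20, "monthly_limit": 1000}
    overrides = {
        "global_per_min": 120,  # user-level bump
        "route_overrides": {"/v3/retrieval/rag": {"route_per_min": 5}},
    }

    route = "/v3/retrieval/rag"
    specific = overrides.get("route_overrides", {}).get(route, {})
    global_per_min = specific.get(
        "global_per_min", overrides.get("global_per_min", base["global_per_min"])
    )
    route_per_min = specific.get("route_per_min", base["route_per_min"])
    # -> global_per_min == 120, route_per_min == 5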

+ 234 - 116
core/database/users.py

@@ -1,5 +1,6 @@
+import json
 from datetime import datetime
-from typing import Optional
+from typing import Any, Dict, List, Optional
 from uuid import UUID
 
 from fastapi import HTTPException
@@ -43,10 +44,12 @@ class PostgresUserHandler(Handler):
             reset_token TEXT,
             reset_token_expiry TIMESTAMPTZ,
             collection_ids UUID[] NULL,
+            limits_overrides JSONB,
             created_at TIMESTAMPTZ DEFAULT NOW(),
             updated_at TIMESTAMPTZ DEFAULT NOW()
         );
         """
+
         # API keys table with updated_at instead of last_used_at
         api_keys_table_query = f"""
         CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (
@@ -86,6 +89,7 @@ class PostgresUserHandler(Handler):
                     "profile_picture",
                     "bio",
                     "collection_ids",
+                    "limits_overrides",  # Fetch JSONB column
                 ]
             )
             .where("id = $1")
@@ -109,6 +113,8 @@ class PostgresUserHandler(Handler):
             profile_picture=result["profile_picture"],
             bio=result["bio"],
             collection_ids=result["collection_ids"],
+            # NULL limits_overrides deserializes to an empty dict
+            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
         )
 
     async def get_user_by_email(self, email: str) -> User:
@@ -128,6 +134,7 @@ class PostgresUserHandler(Handler):
                     "profile_picture",
                     "bio",
                     "collection_ids",
+                    "limits_overrides",
                 ]
             )
             .where("email = $1")
@@ -150,13 +157,16 @@ class PostgresUserHandler(Handler):
             profile_picture=result["profile_picture"],
             bio=result["bio"],
             collection_ids=result["collection_ids"],
+            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
         )
 
     async def create_user(
         self, email: str, password: str, is_superuser: bool = False
     ) -> User:
+        """Create a new user."""
         try:
-            if await self.get_user_by_email(email):
+            existing = await self.get_user_by_email(email)
+            if existing:
                 raise R2RException(
                     status_code=400,
                     message="User with this email already exists",
@@ -166,27 +176,39 @@ class PostgresUserHandler(Handler):
                 raise e
 
         hashed_password = self.crypto_provider.get_password_hash(password)  # type: ignore
-        query = f"""
-            INSERT INTO {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-            (email, id, is_superuser, hashed_password, collection_ids)
-            VALUES ($1, $2, $3, $4, $5)
-            RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
-        """
-        result = await self.connection_manager.fetchrow_query(
-            query,
-            [
-                email,
-                generate_user_id(email),
-                is_superuser,
-                hashed_password,
-                [],
-            ],
+        query, params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .insert(
+                {
+                    "email": email,
+                    "id": generate_user_id(email),
+                    "is_superuser": is_superuser,
+                    "hashed_password": hashed_password,
+                    "collection_ids": [],
+                    "limits_overrides": None,
+                }
+            )
+            .returning(
+                [
+                    "id",
+                    "email",
+                    "is_superuser",
+                    "is_active",
+                    "is_verified",
+                    "created_at",
+                    "updated_at",
+                    "collection_ids",
+                    "limits_overrides",
+                ]
+            )
+            .build()
         )
 
+        result = await self.connection_manager.fetchrow_query(query, params)
         if not result:
-            raise HTTPException(
+            raise R2RException(
                 status_code=500,
-                detail="Failed to create user",
+                message="Failed to create user",
             )
 
         return User(
@@ -197,17 +219,62 @@ class PostgresUserHandler(Handler):
             is_verified=result["is_verified"],
             created_at=result["created_at"],
             updated_at=result["updated_at"],
-            collection_ids=result["collection_ids"],
+            collection_ids=result["collection_ids"] or [],
             hashed_password=hashed_password,
+            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+            name=None,
+            bio=None,
+            profile_picture=None,
         )
 
-    async def update_user(self, user: User) -> User:
+    async def update_user(
+        self, user: User, merge_limits: bool = False
+    ) -> User:
+        """
+        Update user information including limits_overrides.
+
+        Args:
+            user: User object containing the updated information
+            merge_limits: If True, merge the existing limits_overrides with
+                the new ones; if False, overwrite them.
+
+        Returns:
+            Updated User object
+        """
+        # Fetch the current user so existing limits_overrides can be merged
+        try:
+            current_user = await self.get_user_by_id(user.id)
+        except R2RException:
+            raise R2RException(status_code=404, message="User not found")
+
+        # Merge or replace limits_overrides
+        final_limits = user.limits_overrides
+        if (
+            merge_limits
+            and current_user.limits_overrides
+            and user.limits_overrides
+        ):
+            final_limits = {
+                **current_user.limits_overrides,
+                **user.limits_overrides,
+            }
         query = f"""
             UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-            SET email = $1, is_superuser = $2, is_active = $3, is_verified = $4, updated_at = NOW(),
-                name = $5, profile_picture = $6, bio = $7, collection_ids = $8
-            WHERE id = $9
-            RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, name, profile_picture, bio, collection_ids
+            SET email = $1,
+                is_superuser = $2,
+                is_active = $3,
+                is_verified = $4,
+                updated_at = NOW(),
+                name = $5,
+                profile_picture = $6,
+                bio = $7,
+                collection_ids = $8,
+                limits_overrides = $9::jsonb
+            WHERE id = $10
+            RETURNING id, email, is_superuser, is_active, is_verified,
+                    created_at, updated_at, name, profile_picture, bio,
+                    collection_ids, limits_overrides, hashed_password
         """
         result = await self.connection_manager.fetchrow_query(
             query,
@@ -219,7 +286,8 @@ class PostgresUserHandler(Handler):
                 user.name,
                 user.profile_picture,
                 user.bio,
-                user.collection_ids,
+                user.collection_ids or [],  # Ensure null becomes empty array
+                json.dumps(final_limits),  # json.dumps(None) stores JSON null
                 user.id,
             ],
         )
@@ -233,6 +301,9 @@ class PostgresUserHandler(Handler):
         return User(
             id=result["id"],
             email=result["email"],
+            hashed_password=result["hashed_password"],
             is_superuser=result["is_superuser"],
             is_active=result["is_active"],
             is_verified=result["is_verified"],
@@ -241,15 +312,23 @@ class PostgresUserHandler(Handler):
             name=result["name"],
             profile_picture=result["profile_picture"],
             bio=result["bio"],
-            collection_ids=result["collection_ids"],
+            collection_ids=result["collection_ids"] or [],  # NULL -> []
+            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
         )
 
     async def delete_user_relational(self, id: UUID) -> None:
+        """Delete a user and update related records."""
         # Get the collections the user belongs to
-        collection_query = f"""
-            SELECT collection_ids FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-            WHERE id = $1
-        """
+        collection_query, params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .select(["collection_ids"])
+            .where("id = $1")
+            .build()
+        )
+
         collection_result = await self.connection_manager.fetchrow_query(
             collection_query, [id]
         )
@@ -257,20 +336,25 @@ class PostgresUserHandler(Handler):
         if not collection_result:
             raise R2RException(status_code=404, message="User not found")
 
-        # Remove user from documents
-        doc_update_query = f"""
-            UPDATE {self._get_table_name('documents')}
-            SET id = NULL
-            WHERE id = $1
-        """
+        # Update documents query
+        doc_update_query, doc_params = (
+            QueryBuilder(self._get_table_name("documents"))
+            .update({"id": None})
+            .where("id = $1")
+            .build()
+        )
+
         await self.connection_manager.execute_query(doc_update_query, [id])
 
-        # Delete the user
-        delete_query = f"""
-            DELETE FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-            WHERE id = $1
-            RETURNING id
-        """
+        # Delete user query
+        delete_query, del_params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .delete()
+            .where("id = $1")
+            .returning(["id"])
+            .build()
+        )
+
         result = await self.connection_manager.fetchrow_query(
             delete_query, [id]
         )
@@ -288,24 +372,48 @@ class PostgresUserHandler(Handler):
             query, [new_hashed_password, id]
         )
 
-    async def get_all_users(self) -> list[User]:
-        query = f"""
-            SELECT id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
-            FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-        """
-        results = await self.connection_manager.fetch_query(query)
+    async def get_all_users(self) -> List[User]:
+        """Get all users with minimal information."""
+        query, params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .select(
+                [
+                    "id",
+                    "email",
+                    "is_superuser",
+                    "is_active",
+                    "is_verified",
+                    "created_at",
+                    "updated_at",
+                    "collection_ids",
+                    "hashed_password",
+                    "limits_overrides",
+                    "name",
+                    "bio",
+                    "profile_picture",
+                ]
+            )
+            .build()
+        )
 
+        results = await self.connection_manager.fetch_query(query, params)
         return [
             User(
                 id=result["id"],
                 email=result["email"],
-                hashed_password="null",
+                hashed_password=result["hashed_password"],
                 is_superuser=result["is_superuser"],
                 is_active=result["is_active"],
                 is_verified=result["is_verified"],
                 created_at=result["created_at"],
                 updated_at=result["updated_at"],
-                collection_ids=result["collection_ids"],
+                collection_ids=result["collection_ids"] or [],
+                limits_overrides=json.loads(
+                    result["limits_overrides"] or "{}"
+                ),
+                name=result["name"],
+                bio=result["bio"],
+                profile_picture=result["profile_picture"],
             )
             for result in results
         ]
@@ -456,41 +564,44 @@ class PostgresUserHandler(Handler):
     async def get_users_in_collection(
         self, collection_id: UUID, offset: int, limit: int
     ) -> dict[str, list[User] | int]:
-        """
-        Get all users in a specific collection with pagination.
-
-        Args:
-            collection_id (UUID): The ID of the collection to get users from.
-            offset (int): The number of users to skip.
-            limit (int): The maximum number of users to return.
-
-        Returns:
-            List[User]: A list of User objects representing the users in the collection.
-
-        Raises:
-            R2RException: If the collection doesn't exist.
-        """
-        if not await self._collection_exists(collection_id):  # type: ignore
+        """Get all users in a specific collection with pagination."""
+        if not await self._collection_exists(collection_id):
             raise R2RException(status_code=404, message="Collection not found")
 
-        query = f"""
-            SELECT u.id, u.email, u.is_active, u.is_superuser, u.created_at, u.updated_at,
-                u.is_verified, u.collection_ids, u.name, u.bio, u.profile_picture,
-                COUNT(*) OVER() AS total_entries
-            FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
-            WHERE $1 = ANY(u.collection_ids)
-            ORDER BY u.name
-            OFFSET $2
-        """
+        query, params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .select(
+                [
+                    "id",
+                    "email",
+                    "is_active",
+                    "is_superuser",
+                    "created_at",
+                    "updated_at",
+                    "is_verified",
+                    "collection_ids",
+                    "name",
+                    "bio",
+                    "profile_picture",
+                    "hashed_password",
+                    "limits_overrides",
+                    "COUNT(*) OVER() AS total_entries",
+                ]
+            )
+            .where("$1 = ANY(collection_ids)")
+            .order_by("name")
+            .offset("$2")
+            .limit("$3" if limit != -1 else None)
+            .build()
+        )
 
         conditions = [collection_id, offset]
         if limit != -1:
-            query += " LIMIT $3"
             conditions.append(limit)
 
         results = await self.connection_manager.fetch_query(query, conditions)
 
-        users = [
+        users_list = [
             User(
                 id=row["id"],
                 email=row["email"],
@@ -499,24 +610,24 @@ class PostgresUserHandler(Handler):
                 created_at=row["created_at"],
                 updated_at=row["updated_at"],
                 is_verified=row["is_verified"],
-                collection_ids=row["collection_ids"],
+                collection_ids=row["collection_ids"] or [],
                 name=row["name"],
                 bio=row["bio"],
                 profile_picture=row["profile_picture"],
-                hashed_password=None,
-                verification_code_expiry=None,
+                hashed_password=row["hashed_password"],
+                limits_overrides=json.loads(row["limits_overrides"] or "{}"),
             )
             for row in results
         ]
 
         total_entries = results[0]["total_entries"] if results else 0
-
-        return {"results": users, "total_entries": total_entries}
+        return {"results": users_list, "total_entries": total_entries}
 
     async def mark_user_as_superuser(self, id: UUID):
         query = f"""
             UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-            SET is_superuser = TRUE, is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
+            SET is_superuser = TRUE, is_verified = TRUE,
+                verification_code = NULL, verification_code_expiry = NULL
             WHERE id = $1
         """
         await self.connection_manager.execute_query(query, [id])
@@ -542,7 +653,9 @@ class PostgresUserHandler(Handler):
     async def mark_user_as_verified(self, id: UUID):
         query = f"""
             UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-            SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
+            SET is_verified = TRUE,
+                verification_code = NULL,
+                verification_code_expiry = NULL
             WHERE id = $1
         """
         await self.connection_manager.execute_query(query, [id])
@@ -553,7 +666,9 @@ class PostgresUserHandler(Handler):
         limit: int,
         user_ids: Optional[list[UUID]] = None,
     ) -> dict[str, list[User] | int]:
-
+        """
+        Return users with document usage and total entries.
+        """
         query = f"""
             WITH user_document_ids AS (
                 SELECT
@@ -604,36 +719,36 @@ class PostgresUserHandler(Handler):
             params.append(user_ids)
 
         results = await self.connection_manager.fetch_query(query, params)
+        if not results:
+            raise R2RException(status_code=404, message="No users found")
 
-        users = [
-            User(
-                id=row["id"],
-                email=row["email"],
-                is_superuser=row["is_superuser"],
-                is_active=row["is_active"],
-                is_verified=row["is_verified"],
-                name=row["name"],
-                bio=row["bio"],
-                created_at=row["created_at"],
-                updated_at=row["updated_at"],
-                collection_ids=row["collection_ids"] or [],
-                num_files=row["num_files"],
-                total_size_in_bytes=row["total_size_in_bytes"],
-                document_ids=(
-                    []
-                    if row["document_ids"] is None
-                    else list(row["document_ids"])
-                ),
+        users_list = []
+        for row in results:
+            users_list.append(
+                User(
+                    id=row["id"],
+                    email=row["email"],
+                    is_superuser=row["is_superuser"],
+                    is_active=row["is_active"],
+                    is_verified=row["is_verified"],
+                    name=row["name"],
+                    bio=row["bio"],
+                    created_at=row["created_at"],
+                    updated_at=row["updated_at"],
+                    profile_picture=row["profile_picture"],
+                    collection_ids=row["collection_ids"] or [],
+                    num_files=row["num_files"],
+                    total_size_in_bytes=row["total_size_in_bytes"],
+                    document_ids=(
+                        list(row["document_ids"])
+                        if row["document_ids"]
+                        else []
+                    ),
+                )
             )
-            for row in results
-        ]
-
-        if not users:
-            raise R2RException(status_code=404, message="No users found")
 
         total_entries = results[0]["total_entries"]
-
-        return {"results": users, "total_entries": total_entries}
+        return {"results": users_list, "total_entries": total_entries}
 
     async def _collection_exists(self, collection_id: UUID) -> bool:
         """Check if a collection exists."""
@@ -693,7 +808,7 @@ class PostgresUserHandler(Handler):
         hashed_key: str,
         name: Optional[str] = None,
     ) -> UUID:
-        """Store a new API key for a user"""
+        """Store a new API key for a user."""
         query = f"""
             INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             (user_id, public_key, hashed_key, name)
@@ -710,7 +825,10 @@ class PostgresUserHandler(Handler):
         return result["id"]
 
     async def get_api_key_record(self, key_id: str) -> Optional[dict]:
-        """Get API key record and update updated_at"""
+        """
+        Get an API key record by public_key, setting updated_at to now.
+        Returns a dict with "user_id" and "hashed_key", or None if not found.
+        """
         query = f"""
             UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             SET updated_at = NOW()
@@ -726,7 +844,7 @@ class PostgresUserHandler(Handler):
         }
 
     async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
-        """Get all API keys for a user"""
+        """Get all API keys for a user."""
         query = f"""
             SELECT id, public_key, name, created_at, updated_at
             FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
@@ -745,7 +863,7 @@ class PostgresUserHandler(Handler):
         ]
 
     async def delete_api_key(self, user_id: UUID, key_id: UUID) -> dict:
-        """Delete a specific API key"""
+        """Delete a specific API key."""
         query = f"""
             DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             WHERE id = $1 AND user_id = $2
@@ -766,7 +884,7 @@ class PostgresUserHandler(Handler):
     async def update_api_key_name(
         self, user_id: UUID, key_id: UUID, name: str
     ) -> bool:
-        """Update the name of an API key"""
+        """Update the name of an existing API key."""
         query = f"""
             UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             SET name = $1, updated_at = NOW()
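
A behavioral note on update_user above: merge_limits performs a shallow dict merge, so any top-level key present in the incoming overrides replaces the stored value wholesale. A quick illustration with hypothetical values:

current = {
    "global_per_min": 30,
    "route_overrides": {"/v3/retrieval/rag": {"route_per_min": 5}},
}
incoming = {"global_per_min": 60, "route_overrides": {}}

# merge_limits=True -> {**current, **incoming}
merged = {**current, **incoming}
assert merged["global_per_min"] == 60
assert merged["route_overrides"] == {}  # nested dict replaced, not merged

# merge_limits=False -> incoming replaces the stored overrides entirely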

+ 1 - 1
core/examples/hello_r2r.py

@@ -1,6 +1,6 @@
 from r2r import R2RClient
 
-client = R2RClient("http://localhost:7272")
+client = R2RClient()
 
 with open("test.txt", "w") as file:
     file.write("John is a person that works at Google.")

+ 6 - 6
core/main/abstractions.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 
 from pydantic import BaseModel
 
@@ -107,8 +107,8 @@ class R2RAgents(BaseModel):
 
 @dataclass
 class R2RServices:
-    auth: Optional["AuthService"] = None
-    ingestion: Optional["IngestionService"] = None
-    management: Optional["ManagementService"] = None
-    retrieval: Optional["RetrievalService"] = None
-    graph: Optional["GraphService"] = None
+    auth: "AuthService"
+    ingestion: "IngestionService"
+    management: "ManagementService"
+    retrieval: "RetrievalService"
+    graph: "GraphService"

+ 57 - 104
core/main/api/v3/base_router.py

@@ -1,7 +1,7 @@
 import functools
 import logging
 from abc import abstractmethod
-from typing import Callable
+from typing import Callable, Optional
 
 from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
 from fastapi.responses import StreamingResponse
@@ -15,100 +15,19 @@ logger = logging.getLogger()
 
 class BaseRouterV3:
     def __init__(self, providers: R2RProviders, services: R2RServices):
+        """
+        :param providers: Provider container (auth, database, etc.).
+        :param services: Service container (auth, ingestion, management, retrieval, graph).
+        """
         self.providers = providers
         self.services = services
         self.router = APIRouter()
         self.openapi_extras = self._load_openapi_extras()
-        self._setup_routes()
-        self._register_workflows()
-
-    def get_router(self):
-        return self.router
-
-    def base_endpoint(self, func: Callable):
-        @functools.wraps(func)
-        async def wrapper(*args, **kwargs):
-            async with manage_run(
-                self.services.ingestion.run_manager, func.__name__
-            ) as run_id:
-                auth_user = kwargs.get("auth_user")
-                if auth_user:
-                    await self.services.ingestion.run_manager.log_run_info(  # TODO - this is a bit of a hack
-                        user=auth_user,
-                    )
-
-                try:
-                    func_result = await func(*args, **kwargs)
-                    if (
-                        isinstance(func_result, tuple)
-                        and len(func_result) == 2
-                    ):
-                        results, outer_kwargs = func_result
-                    else:
-                        results, outer_kwargs = func_result, {}
-
-                    if isinstance(results, StreamingResponse):
-                        return results
-                    return {"results": results, **outer_kwargs}
-
-                except R2RException:
-                    raise
-
-                except Exception as e:
 
-                    logger.error(
-                        f"Error in base endpoint {func.__name__}() - \n\n{str(e)}",
-                        exc_info=True,
-                    )
-
-                    raise HTTPException(
-                        status_code=500,
-                        detail={
-                            "message": f"An error '{e}' occurred during {func.__name__}",
-                            "error": str(e),
-                            "error_type": type(e).__name__,
-                        },
-                    ) from e
-
-        return wrapper
-
-    @classmethod
-    def build_router(cls, engine):
-        return cls(engine).router
-
-    def _register_workflows(self):
-        pass
-
-    def _load_openapi_extras(self):
-        return {}
-
-    @abstractmethod
-    def _setup_routes(self):
-        pass
-
-
-import functools
-import logging
-from abc import abstractmethod
-from typing import Callable, Optional
-
-from fastapi import APIRouter, Depends, HTTPException, Request
-from fastapi.responses import StreamingResponse
-
-from core.base import R2RException, manage_run
-
-from ...abstractions import R2RProviders, R2RServices
-
-logger = logging.getLogger()
-
-
-class BaseRouterV3:
-    def __init__(self, providers: R2RProviders, services: R2RServices):
-        self.providers = providers
-        self.services = services
-        self.router = APIRouter()
-        self.openapi_extras = self._load_openapi_extras()
+        # Add the rate-limiting dependency
         self.set_rate_limiting()
+
+        # Register routes and workflows
         self._setup_routes()
         self._register_workflows()
 
@@ -116,6 +35,13 @@ class BaseRouterV3:
         return self.router
 
     def base_endpoint(self, func: Callable):
+        """
+        A decorator to wrap endpoints in a standard pattern:
+          - manage_run context
+          - error handling
+          - response shaping
+        """
+
         @functools.wraps(func)
         async def wrapper(*args, **kwargs):
             async with manage_run(
@@ -123,6 +49,7 @@ class BaseRouterV3:
             ) as run_id:
                 auth_user = kwargs.get("auth_user")
                 if auth_user:
+                    # Optionally log run info with the user
                     await self.services.ingestion.run_manager.log_run_info(
                         user=auth_user,
                     )
@@ -143,13 +70,11 @@ class BaseRouterV3:
 
                 except R2RException:
                     raise
-
                 except Exception as e:
                     logger.error(
-                        f"Error in base endpoint {func.__name__}() - \n\n{str(e)}",
+                        f"Error in base endpoint {func.__name__}() - {str(e)}",
                         exc_info=True,
                     )
-
                     raise HTTPException(
                         status_code=500,
                         detail={
@@ -163,6 +88,9 @@ class BaseRouterV3:
 
     @classmethod
     def build_router(cls, engine):
+        """
+        Build a router instance directly from an engine.
+        """
         return cls(engine).router
 
     def _register_workflows(self):
@@ -173,48 +101,73 @@ class BaseRouterV3:
 
     @abstractmethod
     def _setup_routes(self):
+        """
+        Subclasses override this to define actual endpoints.
+        """
         pass
 
     def set_rate_limiting(self):
         """
-        Set up a yield dependency for rate limiting and logging.
+        Adds a yield-based dependency for rate limiting each request.
+        Checks the limits, then logs the request if the check passes.
         """
 
         async def rate_limit_dependency(
             request: Request,
             auth_user=Depends(self.providers.auth.auth_wrapper()),
         ):
+            """
+            1) Fetch the user (including .limits_overrides) from the DB.
+            2) Pass it to limits_handler.check_limits.
+            3) Execute the route.
+            4) After the endpoint completes, call limits_handler.log_request.
+            """
+            # If the user is a superuser, skip checks
+            if auth_user.is_superuser:
+                yield
+                return
+
             user_id = auth_user.id
             route = request.scope["path"]
-            # Check the limits before proceeding
+
+            # 1) Fetch the user from DB
+            user = await self.providers.database.users_handler.get_user_by_id(
+                user_id
+            )
+            if not user:
+                raise HTTPException(status_code=404, detail="User not found.")
+
+            # 2) Rate-limit check
             try:
-                if not auth_user.is_superuser:
-                    await self.providers.database.limits_handler.check_limits(
-                        user_id, route
-                    )
+                await self.providers.database.limits_handler.check_limits(
+                    user=user, route=route  # Pass the User object
+                )
             except ValueError as e:
+                # If check_limits raises ValueError -> 429 Too Many Requests
                 raise HTTPException(status_code=429, detail=str(e))
 
             request.state.user_id = user_id
             request.state.route = route
-            # Yield to run the route
+
+            # 3) Execute the route
             try:
                 yield
             finally:
-                # After the route completes successfully, log the request
+                # 4) Log the request afterwards
                 await self.providers.database.limits_handler.log_request(
                     user_id, route
                 )
 
-        async def websocket_rate_limit_dependency(
-            websocket: WebSocket,
-        ):
+        async def websocket_rate_limit_dependency(websocket: WebSocket):
+            # Placeholder: rate-limit websockets with the same pattern
            route = websocket.scope["path"]
+            # With a user or token available, run the same check here.
            try:
+                # e.g. await limits_handler.check_limits(user, route)
                 return True
-            except ValueError as e:
+            except ValueError:
                 await websocket.close(code=4429, reason="Rate limit exceeded")
                 return False
 
+        # Attach the dependencies so you can use them in your endpoints
         self.rate_limit_dependency = rate_limit_dependency
         self.websocket_rate_limit_dependency = websocket_rate_limit_dependency
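
A sketch of how a subclass might attach the new rate_limit_dependency to a route; the path and handler are hypothetical, but the Depends wiring is standard FastAPI and the attribute is set in __init__ before _setup_routes runs:

from fastapi import Depends

from core.main.api.v3.base_router import BaseRouterV3

class ExampleRouter(BaseRouterV3):
    def _setup_routes(self):
        @self.router.get(
            "/v3/example",
            dependencies=[Depends(self.rate_limit_dependency)],
        )
        @self.base_endpoint
        async def get_example(
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ):
            # Runs only after check_limits passes; log_request fires afterwards.
            return {"status": "ok"}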

+ 9 - 9
core/main/api/v3/chunks_router.py

@@ -55,7 +55,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             response = client.chunks.search(
                                 query="search query",
                                 search_settings={
@@ -110,7 +110,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             response = client.chunks.retrieve(
                                 id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                             )
@@ -123,7 +123,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.chunks.retrieve({
@@ -183,7 +183,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             response = client.chunks.update(
                                 {
                                     "id": "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
@@ -200,7 +200,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.chunks.update({
@@ -276,7 +276,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             response = client.chunks.delete(
                                 id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                             )
@@ -289,7 +289,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.chunks.delete({
@@ -347,7 +347,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             response = client.chunks.list(
                                 metadata_filter={"key": "value"},
                                 include_vectors=False,
@@ -363,7 +363,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.chunks.list({

+ 23 - 23
core/main/api/v3/collections_router.py

@@ -101,7 +101,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.collections.create(
@@ -117,7 +117,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.collections.create({
@@ -189,7 +189,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.collections.list(
@@ -205,7 +205,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.collections.list();
@@ -298,7 +298,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.collections.retrieve("123e4567-e89b-12d3-a456-426614174000")
@@ -311,7 +311,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.collections.retrieve({id: "123e4567-e89b-12d3-a456-426614174000"});
@@ -387,7 +387,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.collections.update(
@@ -404,7 +404,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.collections.update({
@@ -485,7 +485,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.collections.delete("123e4567-e89b-12d3-a456-426614174000")
@@ -498,7 +498,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.collections.delete({id: "123e4567-e89b-12d3-a456-426614174000"});
@@ -562,7 +562,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.collections.add_document(
@@ -578,7 +578,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.collections.addDocument({
@@ -634,7 +634,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.collections.list_documents(
@@ -651,7 +651,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.collections.listDocuments({id: "123e4567-e89b-12d3-a456-426614174000"});
@@ -733,7 +733,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.collections.remove_document(
@@ -749,7 +749,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.collections.removeDocument({
@@ -811,7 +811,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.collections.list_users(
@@ -828,7 +828,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.collections.listUsers({
@@ -912,7 +912,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.collections.add_user(
@@ -928,7 +928,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.collections.addUser({
@@ -990,7 +990,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.collections.remove_user(
@@ -1006,7 +1006,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.collections.removeUser({
@@ -1070,7 +1070,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.documents.extract(

+ 14 - 14
core/main/api/v3/conversations_router.py

@@ -42,7 +42,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.conversations.create()
@@ -55,7 +55,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.conversations.create();
@@ -116,7 +116,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.conversations.list(
@@ -132,7 +132,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.conversations.list();
@@ -218,7 +218,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.conversations.get(
@@ -233,7 +233,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.conversations.retrieve({
@@ -299,7 +299,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.conversations.update("123e4567-e89b-12d3-a456-426614174000", "new_name")
@@ -312,7 +312,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.conversations.update({
@@ -382,7 +382,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.conversations.delete("123e4567-e89b-12d3-a456-426614174000")
@@ -395,7 +395,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.conversations.delete({
@@ -462,7 +462,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.conversations.add_message(
@@ -481,7 +481,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.conversations.addMessage({
@@ -558,7 +558,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.conversations.update_message(
@@ -575,7 +575,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.conversations.updateMessage({

+ 38 - 36
core/main/api/v3/documents_router.py

@@ -198,7 +198,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.documents.create(
@@ -215,7 +215,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.documents.create({
@@ -558,7 +558,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.documents.list(
@@ -574,7 +574,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.documents.list({
@@ -680,7 +680,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.documents.retrieve(
@@ -695,7 +695,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.documents.retrieve({
@@ -776,7 +776,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.documents.list_chunks(
@@ -791,7 +791,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.documents.listChunks({
@@ -910,7 +910,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.documents.download(
@@ -925,7 +925,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.documents.download({
@@ -1053,7 +1053,7 @@ class DocumentsRouter(BaseRouterV3):
                         "source": textwrap.dedent(
                             """
                             from r2r import R2RClient
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             response = client.documents.delete_by_filter(
                                 filters={"document_type": {"$eq": "txt"}}
@@ -1105,7 +1105,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.documents.delete(
@@ -1120,7 +1120,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             function main() {
                                 const response = await client.documents.delete({
@@ -1186,7 +1186,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.documents.list_collections(
@@ -1201,7 +1201,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.documents.listCollections({
@@ -1291,7 +1291,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.documents.extract(
@@ -1403,7 +1403,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.documents.extract(
@@ -1477,14 +1477,15 @@ class DocumentsRouter(BaseRouterV3):
                 raise R2RException("Document not found.", 404)
 
             # Get all entities for this document from the document_entity table
-            entities, count = (
-                await self.providers.database.graphs_handler.entities.get(
-                    parent_id=id,
-                    store_type="documents",
-                    offset=offset,
-                    limit=limit,
-                    include_embeddings=include_embeddings,
-                )
+            (
+                entities,
+                count,
+            ) = await self.providers.database.graphs_handler.entities.get(
+                parent_id=id,
+                store_type="documents",
+                offset=offset,
+                limit=limit,
+                include_embeddings=include_embeddings,
             )
 
             return entities, {"total_entries": count}  # type: ignore
@@ -1501,7 +1502,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.documents.list_relationships(
@@ -1518,7 +1519,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.documents.listRelationships({
@@ -1618,15 +1619,16 @@ class DocumentsRouter(BaseRouterV3):
                 raise R2RException("Document not found.", 404)
 
             # Get relationships for this document
-            relationships, count = (
-                await self.providers.database.graphs_handler.relationships.get(
-                    parent_id=id,
-                    store_type="documents",
-                    entity_names=entity_names,
-                    relationship_types=relationship_types,
-                    offset=offset,
-                    limit=limit,
-                )
+            (
+                relationships,
+                count,
+            ) = await self.providers.database.graphs_handler.relationships.get(
+                parent_id=id,
+                store_type="documents",
+                entity_names=entity_names,
+                relationship_types=relationship_types,
+                offset=offset,
+                limit=limit,
             )
 
             return relationships, {"total_entries": count}  # type: ignore

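The constructor change repeated throughout these examples drops the hardcoded base URL in favor of a bare R2RClient(), presumably resolving the default inside the SDK (sdk/base/base_client.py is among the files this commit touches). A minimal sketch of the two construction styles, assuming the no-argument form still targets a local deployment by default:

    from r2r import R2RClient

    explicit_client = R2RClient("http://localhost:7272")  # old examples: hardcoded URL
    default_client = R2RClient()  # new examples: SDK-resolved default base URL
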
+ 34 - 36
core/main/api/v3/graph_router.py

@@ -40,7 +40,6 @@ class GraphRouter(BaseRouterV3):
         self._register_workflows()
 
     def _register_workflows(self):
-
         workflow_messages = {}
         if self.providers.orchestration.config.provider == "hatchet":
             workflow_messages["extract-triples"] = (
@@ -164,7 +163,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.list()
@@ -177,7 +176,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.list({});
@@ -247,7 +246,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.get(
@@ -261,7 +260,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.retrieve({
@@ -386,7 +385,6 @@ class GraphRouter(BaseRouterV3):
             }
 
             if run_with_orchestration:
-
                 return await self.providers.orchestration.run_workflow(  # type: ignore
                     "build-communities", {"request": workflow_input}, {}
                 )
@@ -413,7 +411,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.reset(
@@ -427,7 +425,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.reset({
@@ -493,7 +491,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.update(
@@ -511,7 +509,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.update({
@@ -579,10 +577,10 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
-                            response = client.graphs.get_entities(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7")
+                            response = client.graphs.list_entities(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7")
                             """
                         ),
                     },
@@ -592,10 +590,10 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
-                                const response = await client.graphs.get_entities({
+                                const response = await client.graphs.listEntities({
                                     collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                                 });
                             }
@@ -767,7 +765,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.get_entity(
@@ -783,7 +781,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.get_entity({
@@ -894,7 +892,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.remove_entity(
@@ -910,7 +908,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.removeEntity({
@@ -973,7 +971,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.list_relationships(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7")
@@ -986,7 +984,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.listRelationships({
@@ -1055,7 +1053,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.get_relationship(
@@ -1071,7 +1069,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.getRelationship({
@@ -1202,7 +1200,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.delete_relationship(
@@ -1218,7 +1216,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.deleteRelationship({
@@ -1280,7 +1278,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.create_community(
@@ -1300,7 +1298,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.createCommunity({
@@ -1389,7 +1387,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.list_communities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1")
@@ -1402,7 +1400,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.listCommunities({
@@ -1471,7 +1469,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.get_community(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1")
@@ -1484,7 +1482,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.getCommunity({
@@ -1549,7 +1547,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.delete_community(
@@ -1565,7 +1563,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.deleteCommunity({
@@ -1629,7 +1627,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.update_community(
@@ -1649,7 +1647,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.updateCommunity({
@@ -1724,7 +1722,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.graphs.pull(
@@ -1738,7 +1736,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.graphs.pull({

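Besides the constructor change, this file corrects the entity-listing examples to the actual SDK method names: get_entities becomes list_entities in Python and listEntities in JavaScript. A minimal Python sketch of the corrected call (the collection id is the placeholder used in the docstrings):

    from r2r import R2RClient

    client = R2RClient()
    # when using auth, do client.login(...)
    entities = client.graphs.list_entities(
        collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7"
    )
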
+ 9 - 11
core/main/api/v3/indices_router.py

@@ -23,7 +23,6 @@ logger = logging.getLogger()
 
 
 class IndicesRouter(BaseRouterV3):
-
     def __init__(
         self,
         providers: R2RProviders,
@@ -32,7 +31,6 @@ class IndicesRouter(BaseRouterV3):
         super().__init__(providers, services)
 
     def _setup_routes(self):
-
         ## TODO - Allow developer to pass the index id with the request
         @self.router.post(
             "/indices",
@@ -46,7 +44,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             # Create an HNSW index for efficient similarity search
@@ -91,7 +89,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.indicies.create({
@@ -246,7 +244,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
 
                             # List all indices
                             indices = client.indices.list(
@@ -262,7 +260,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.indicies.list({
@@ -350,7 +348,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
 
                             # Get detailed information about a specific index
                             index = client.indices.retrieve("index_1")
@@ -363,7 +361,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.indicies.retrieve({
@@ -454,7 +452,7 @@ class IndicesRouter(BaseRouterV3):
         #                         "source": """
         # from r2r import R2RClient
 
-        # client = R2RClient("http://localhost:7272")
+        # client = R2RClient()
 
         # # Update HNSW index parameters
         # result = client.indices.update(
@@ -514,7 +512,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
 
                             # Delete an index with orchestration for cleanup
                             result = client.indices.delete(
@@ -531,7 +529,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.indicies.delete({

+ 10 - 10
core/main/api/v3/prompts_router.py

@@ -38,7 +38,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.prompts.create(
@@ -55,7 +55,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.prompts.create({
@@ -122,7 +122,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.prompts.list()
@@ -135,7 +135,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.prompts.list();
@@ -202,7 +202,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.prompts.get(
@@ -219,7 +219,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.prompts.retrieve({
@@ -292,7 +292,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.prompts.update(
@@ -309,7 +309,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.prompts.update({
@@ -376,7 +376,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.prompts.delete("greeting_prompt")
@@ -389,7 +389,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.prompts.delete({

+ 10 - 11
core/main/api/v3/retrieval_router.py

@@ -82,7 +82,6 @@ class RetrievalRouterV3(BaseRouterV3):
         return effective_settings
 
     def _setup_routes(self):
-
         @self.router.post(
             "/retrieval/search",
             dependencies=[Depends(self.rate_limit_dependency)],
@@ -95,7 +94,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             # Basic mode, no overrides
@@ -135,7 +134,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.search({
@@ -278,7 +277,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.retrieval.rag(
@@ -309,7 +308,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.retrieval.rag({
@@ -464,7 +463,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                         from r2r import R2RClient
 
-                        client = R2RClient("http://localhost:7272")
+                        client = R2RClient()
                         # when using auth, do client.login(...)
 
                         response = client.retrieval.agent(
@@ -500,7 +499,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.retrieval.agent({
@@ -693,7 +692,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             response = client.completion(
@@ -719,7 +718,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.completion({
@@ -830,7 +829,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.retrieval.embedding(
@@ -845,7 +844,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.retrieval.embedding({

+ 6 - 6
core/main/api/v3/system_router.py

@@ -39,7 +39,7 @@ class SystemRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.system.health()
@@ -52,7 +52,7 @@ class SystemRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.system.health();
@@ -98,7 +98,7 @@ class SystemRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.system.settings()
@@ -111,7 +111,7 @@ class SystemRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.system.settings();
@@ -164,7 +164,7 @@ class SystemRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
 
                             result = client.system.status()
@@ -177,7 +177,7 @@ class SystemRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.system.status();

+ 49 - 40
core/main/api/v3/users_router.py

@@ -1,5 +1,5 @@
 import textwrap
-from typing import Optional
+from typing import Optional, Union
 from uuid import UUID
 
 from fastapi import Body, Depends, Path, Query
@@ -19,6 +19,7 @@ from core.base.api.models import (
     WrappedUserResponse,
     WrappedUsersResponse,
 )
+from core.base.providers.database import LimitSettings
 
 from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3
@@ -31,7 +32,6 @@ class UsersRouter(BaseRouterV3):
         super().__init__(providers, services)
 
     def _setup_routes(self):
-
         @self.router.post(
             "/users",
             # dependencies=[Depends(self.rate_limit_dependency)],
@@ -44,7 +44,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             new_user = client.users.create(
                                 email="jane.doe@example.com",
                                 password="secure_password123"
@@ -57,7 +57,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.create({
@@ -152,7 +152,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             new_user = client.users.register(
                                 email="jane.doe@example.com",
                                 password="secure_password123"
@@ -165,7 +165,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.register({
@@ -221,7 +221,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             tokens = client.users.verify_email(
                                 email="jane.doe@example.com",
                                 verification_code="1lklwal!awdclm"
@@ -234,7 +234,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.verifyEmail({
@@ -296,7 +296,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             tokens = client.users.login(
                                 email="jane.doe@example.com",
                                 password="secure_password123"
@@ -310,7 +310,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.login({
@@ -354,7 +354,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             result = client.users.logout()
                             """
@@ -366,7 +366,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.logout();
@@ -408,7 +408,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
 
                             new_tokens = client.users.refresh_token()
@@ -421,7 +421,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.refreshAccessToken();
@@ -467,7 +467,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
 
                             result = client.users.change_password(
@@ -482,7 +482,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.changePassword({
@@ -537,7 +537,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             result = client.users.request_password_reset(
                                 email="jane.doe@example.com"
                             )"""
@@ -549,7 +549,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.requestPasswordReset({
@@ -595,7 +595,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             result = client.users.reset_password(
                                 reset_token="reset_token_received_via_email",
                                 new_password="new_secure_password789"
@@ -608,7 +608,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.resetPassword({
@@ -659,7 +659,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
 
                             # List users with filters
@@ -676,7 +676,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.list();
@@ -766,7 +766,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
 
                             # Get user details
@@ -780,7 +780,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.retrieve();
@@ -831,7 +831,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
 
                             # Get user details
@@ -847,7 +847,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.retrieve({
@@ -918,7 +918,7 @@ class UsersRouter(BaseRouterV3):
                             """
                         from r2r import R2RClient
 
-                        client = R2RClient("http://localhost:7272")
+                        client = R2RClient()
                         # client.login(...)
 
                         # Delete user
@@ -932,7 +932,7 @@ class UsersRouter(BaseRouterV3):
                             """
                         const { r2rClient } = require("r2r-js");
 
-                        const client = new r2rClient("http://localhost:7272");
+                        const client = new r2rClient();
 
                         async function main() {
                             const response = await client.users.delete({
@@ -992,7 +992,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
 
                             # Get user collections
@@ -1010,7 +1010,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.listCollections({
@@ -1095,7 +1095,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
 
                             # Add user to collection
@@ -1112,7 +1112,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.addToCollection({
@@ -1179,7 +1179,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
 
                             # Remove user from collection
@@ -1196,7 +1196,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.removeFromCollection({
@@ -1267,7 +1267,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
 
                             # Update user
@@ -1284,7 +1284,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             const { r2rClient } = require("r2r-js");
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
                             async function main() {
                                 const response = await client.users.update({
@@ -1329,6 +1329,10 @@ class UsersRouter(BaseRouterV3):
             profile_picture: str | None = Body(
                 None, description="Updated profile picture URL"
             ),
+            limits_overrides: dict | None = Body(
+                None,
+                description="Updated limits overrides",
+            ),
             auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedUserResponse:
             """
@@ -1347,7 +1351,11 @@ class UsersRouter(BaseRouterV3):
                     "Only superusers can update other users' information",
                     403,
                 )
-
+            if not auth_user.is_superuser and limits_overrides is not None:
+                raise R2RException(
+                    "Only superusers can update limits overrides",
+                    403,
+                )
             return await self.services.auth.update_user(
                 user_id=id,
                 email=email,
@@ -1355,6 +1363,7 @@ class UsersRouter(BaseRouterV3):
                 name=name,
                 bio=bio,
                 profile_picture=profile_picture,
+                limits_overrides=limits_overrides,
             )
 
         @self.router.post(
@@ -1370,7 +1379,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
 
                             result = client.users.create_api_key(
@@ -1424,7 +1433,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             from r2r import R2RClient
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
 
                             keys = client.users.list_api_keys(
@@ -1482,7 +1491,7 @@ class UsersRouter(BaseRouterV3):
                             from r2r import R2RClient
                             from uuid import UUID
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
 
                             response = client.users.delete_api_key(

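The substantive change in this router is the new limits_overrides body field on the user-update route, accepted only from superusers. A hedged sketch of calling it through the Python SDK, assuming sdk/v3/users.py (also modified in this commit) forwards the new field; the user id, keyword names, and limit keys below are illustrative:

    from r2r import R2RClient

    client = R2RClient()
    # client.login(...) as a superuser first; the route returns 403 if a
    # non-superuser supplies limits_overrides at all
    updated = client.users.update(
        id="550e8400-e29b-41d4-a716-446655440000",  # hypothetical user id
        limits_overrides={"global_per_min": 10, "monthly_limit": 3000},
    )
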
+ 0 - 1
core/main/app.py

@@ -72,7 +72,6 @@ class R2RApp:
         self._apply_cors()
 
     def _setup_routes(self):
-
         self.app.include_router(self.chunks_router, prefix="/v3")
         self.app.include_router(self.collections_router, prefix="/v3")
         self.app.include_router(self.conversations_router, prefix="/v3")

+ 5 - 12
core/main/assembly/builder.py

@@ -72,21 +72,14 @@ class R2RBuilder:
         ).create_pipelines(*args, **kwargs)
 
     def _create_services(self, service_params: dict[str, Any]) -> R2RServices:
+        services = ["auth", "ingestion", "management", "retrieval", "graph"]
         service_instances = {}
-        for service_type, override in vars(R2RServices()).items():
-            logger.info(f"Creating {service_type} service")
+
+        for service_type in services:
             service_class = globals()[f"{service_type.capitalize()}Service"]
-            service_instances[service_type] = override or service_class(
-                **service_params
-            )
+            service_instances[service_type] = service_class(**service_params)
 
-        return R2RServices(
-            auth=service_instances["auth"],
-            ingestion=service_instances["ingestion"],
-            management=service_instances["management"],
-            retrieval=service_instances["retrieval"],
-            graph=service_instances["graph"],
-        )
+        return R2RServices(**service_instances)
 
     async def _create_providers(
         self, provider_factory: Type[R2RProviderFactory], *args, **kwargs

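The refactor replaces reflection over a default R2RServices instance (and its per-service override hook) with a fixed name list plus dict unpacking, so services are now always constructed fresh. A self-contained sketch of the lookup-and-unpack pattern, with stand-in classes:

    from dataclasses import dataclass

    class AuthService:
        def __init__(self, **params):
            self.params = params

    class GraphService:
        def __init__(self, **params):
            self.params = params

    @dataclass
    class Services:
        auth: AuthService
        graph: GraphService

    service_params = {"config": None}  # stand-in for the shared constructor kwargs
    names = ["auth", "graph"]
    instances = {
        n: globals()[f"{n.capitalize()}Service"](**service_params) for n in names
    }
    services = Services(**instances)
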
+ 0 - 2
core/main/assembly/factory.py

@@ -66,7 +66,6 @@ class R2RProviderFactory:
         **kwargs,
     ) -> R2RAuthProvider | SupabaseAuthProvider:
         if auth_config.provider == "r2r":
-
             r2r_auth = R2RAuthProvider(
                 auth_config, crypto_provider, database_provider, email_provider
             )
@@ -106,7 +105,6 @@ class R2RProviderFactory:
         *args,
         **kwargs,
     ) -> R2RIngestionProvider | UnstructuredIngestionProvider:
-
         config_dict = (
             ingestion_config.model_dump()
             if isinstance(ingestion_config, IngestionConfig)

+ 0 - 1
core/main/orchestration/hatchet/ingestion_workflow.py

@@ -216,7 +216,6 @@ def hatchet_ingestion_factory(
                     )
 
                 if chunk_enrichment_settings.enable_chunk_enrichment:
-
                     logger.info("Enriching document with contextual chunks")
 
                     # TODO: the status updating doesn't work because document_info doesn't contain information about collection IDs

+ 0 - 8
core/main/orchestration/hatchet/kg_workflow.py

@@ -23,7 +23,6 @@ if TYPE_CHECKING:
 def hatchet_kg_factory(
     orchestration_provider: OrchestrationProvider, service: GraphService
 ) -> dict[str, "Hatchet.Workflow"]:
-
     def convert_to_dict(input_data):
         """
         Converts input data back to a plain dictionary format, handling special cases like UUID and GenerationConfig.
@@ -72,7 +71,6 @@ def hatchet_kg_factory(
 
     def get_input_data_dict(input_data):
         for key, value in input_data.items():
-
             if value is None:
                 continue
 
@@ -212,7 +210,6 @@ def hatchet_kg_factory(
             retries=1, timeout="360m", parents=["kg_extract"]
         )
         async def kg_entity_description(self, context: Context) -> dict:
-
             input_data = get_input_data_dict(
                 context.workflow_input()["request"]
             )
@@ -259,7 +256,6 @@ def hatchet_kg_factory(
 
     @orchestration_provider.workflow(name="extract-triples", timeout="600m")
     class CreateGraphWorkflow:
-
         @orchestration_provider.concurrency(  # type: ignore
             max_runs=orchestration_provider.config.kg_concurrency_limit,  # type: ignore
             limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
@@ -319,7 +315,6 @@ def hatchet_kg_factory(
                 }
 
             else:
-
                 # Extract relationships and store them
                 extractions = []
                 async for extraction in self.kg_service.kg_extraction(
@@ -399,7 +394,6 @@ def hatchet_kg_factory(
         async def kg_entity_deduplication_setup(
             self, context: Context
         ) -> dict:
-
             input_data = get_input_data_dict(
                 context.workflow_input()["request"]
             )
@@ -467,7 +461,6 @@ def hatchet_kg_factory(
         async def kg_entity_deduplication_summary(
             self, context: Context
         ) -> dict:
-
             logger.info(
                 f"Running KG Entity Deduplication Summary for input data: {context.workflow_input()['request']}"
             )
@@ -660,7 +653,6 @@ def hatchet_kg_factory(
 
         @orchestration_provider.step(retries=1, timeout="360m")
         async def kg_community_summary(self, context: Context) -> dict:
-
             start_time = time.time()
 
             logger.info

+ 0 - 3
core/main/orchestration/simple/ingestion_workflow.py

@@ -101,7 +101,6 @@ def simple_ingestion_factory(service: IngestionService):
                         status=KGEnrichmentStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
                     )
                 else:
-
                     for collection_id in collection_ids:
                         try:
                             # FIXME: Right now we just throw a warning if the collection already exists, but we should probably handle this more gracefully
@@ -332,7 +331,6 @@ def simple_ingestion_factory(service: IngestionService):
                 else:
                     for collection_id in collection_ids:
                         try:
-
                             name = document_info.title or "N/A"
                             description = ""
                             result = await service.providers.database.collections_handler.create_collection(
@@ -419,7 +417,6 @@ def simple_ingestion_factory(service: IngestionService):
             )
 
     async def create_vector_index(input_data):
-
         try:
             from core.main import IngestionServiceAdapter
 

+ 0 - 7
core/main/orchestration/simple/kg_workflow.py

@@ -12,10 +12,8 @@ logger = logging.getLogger()
 
 
 def simple_kg_factory(service: GraphService):
-
     def get_input_data_dict(input_data):
         for key, value in input_data.items():
-
             if type(value) == uuid.UUID:
                 continue
 
@@ -41,7 +39,6 @@ def simple_kg_factory(service: GraphService):
         return input_data
 
     async def extract_triples(input_data):
-
         input_data = get_input_data_dict(input_data)
 
         if input_data.get("document_id"):
@@ -105,7 +102,6 @@ def simple_kg_factory(service: GraphService):
                 raise e
 
     async def enrich_graph(input_data):
-
         input_data = get_input_data_dict(input_data)
         workflow_status = await service.providers.database.documents_handler.get_workflow_status(
             id=input_data.get("collection_id", None),
@@ -157,7 +153,6 @@ def simple_kg_factory(service: GraphService):
             )
 
         except Exception as e:
-
             await service.providers.database.documents_handler.set_workflow_status(
                 id=input_data.get("collection_id", None),
                 status_type="graph_cluster_status",
@@ -167,7 +162,6 @@ def simple_kg_factory(service: GraphService):
             raise e
 
     async def kg_community_summary(input_data):
-
         logger.info(
             f"Running kg community summary for offset: {input_data['offset']}, limit: {input_data['limit']}"
         )
@@ -181,7 +175,6 @@ def simple_kg_factory(service: GraphService):
         )
 
     async def entity_deduplication_workflow(input_data):
-
         # TODO: We should determine how we want to handle the input here and synchronize it across all simple orchestration methods
         if isinstance(input_data["graph_entity_deduplication_settings"], str):
             input_data["graph_entity_deduplication_settings"] = json.loads(

+ 3 - 0
core/main/services/auth_service.py

@@ -127,6 +127,7 @@ class AuthService(Service):
         name: Optional[str] = None,
         bio: Optional[str] = None,
         profile_picture: Optional[str] = None,
+        limits_overrides: Optional[dict] = None,
     ) -> User:
         user: User = (
             await self.providers.database.users_handler.get_user_by_id(user_id)
@@ -143,6 +144,8 @@ class AuthService(Service):
             user.bio = bio
         if profile_picture is not None:
             user.profile_picture = profile_picture
+        if limits_overrides is not None:
+            user.limits_overrides = limits_overrides
         return await self.providers.database.users_handler.update_user(user)
 
     @telemetry_event("DeleteUserAccount")
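
A note on the new limits_overrides parameter: because the assignment is guarded by "is not None", passing None leaves existing overrides untouched, while passing an empty dict explicitly clears them. A minimal sketch of that distinction, with the guard pattern lifted from the hunk above (the "per_min" key is illustrative):

    # Same guard as in the service method; None means "leave unchanged".
    def apply_override(current: dict | None, incoming: dict | None) -> dict | None:
        if incoming is not None:
            current = incoming
        return current

    assert apply_override({"per_min": 5}, None) == {"per_min": 5}   # left unchanged
    assert apply_override({"per_min": 5}, {}) == {}                 # explicitly cleared
    assert apply_override(None, {"per_min": 10}) == {"per_min": 10}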

+ 0 - 10
core/main/services/graph_service.py

@@ -79,7 +79,6 @@ class GraphService(Service):
         **kwargs,
     ):
         try:
-
             logger.info(
                 f"KGService: Processing document {document_id} for KG extraction"
             )
@@ -138,7 +137,6 @@ class GraphService(Service):
         category: Optional[str] = None,
         metadata: Optional[dict] = None,
     ) -> Entity:
-
         description_embedding = str(
             await self.providers.embedding.async_get_embedding(description)
         )
@@ -162,7 +160,6 @@ class GraphService(Service):
         category: Optional[str] = None,
         metadata: Optional[dict] = None,
     ) -> Entity:
-
         description_embedding = None
         if description is not None:
             description_embedding = str(
@@ -272,7 +269,6 @@ class GraphService(Service):
         weight: Optional[float] = None,
         metadata: Optional[dict[str, Any] | str] = None,
     ) -> Relationship:
-
         description_embedding = None
         if description is not None:
             description_embedding = str(
@@ -471,7 +467,6 @@ class GraphService(Service):
         force_kg_creation: bool = False,
         **kwargs,
     ):
-
         document_status_filter = [
             KGExtractionStatus.PENDING,
             KGExtractionStatus.FAILED,
@@ -494,7 +489,6 @@ class GraphService(Service):
         max_description_input_length: int,
         **kwargs,
     ):
-
         start_time = time.time()
 
         logger.info(
@@ -568,7 +562,6 @@ class GraphService(Service):
         leiden_params: dict,
         **kwargs,
     ):
-
         logger.info(
             f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}"
         )
@@ -670,7 +663,6 @@ class GraphService(Service):
         graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
         **kwargs,
     ):
-
         if graph_id is None and collection_id is None:
             raise ValueError(
                 "Either graph_id or collection_id must be provided"
@@ -731,7 +723,6 @@ class GraphService(Service):
         generation_config: GenerationConfig,
         **kwargs,
     ):
-
         logger.info(
             f"Running kg_entity_deduplication_summary for collection {collection_id} with settings {kwargs}"
         )
@@ -1064,7 +1055,6 @@ class GraphService(Service):
                 entities_id_map[entity.name] = result.id
 
             if extraction.relationships:
-
                 for relationship in extraction.relationships:
                     await self.providers.database.graphs_handler.relationships.create(
                         subject=relationship.subject,

+ 0 - 1
core/main/services/ingestion_service.py

@@ -300,7 +300,6 @@ class IngestionService(Service):
     async def finalize_ingestion(
         self, document_info: DocumentResponse
     ) -> None:
-
         async def empty_generator():
             yield document_info
 

+ 0 - 1
core/main/services/management_service.py

@@ -455,7 +455,6 @@ class ManagementService(Service):
     async def remove_user_from_collection(
         self, user_id: UUID, collection_id: UUID
     ) -> bool:
-
         x = await self.providers.database.users_handler.remove_user_from_collection(
             user_id, collection_id
         )

+ 7 - 3
core/parsers/media/bmp_parser.py

@@ -34,9 +34,13 @@ class BMPParser(AsyncParser[str | bytes]):
             header_size = self.struct.calcsize(header_format)
 
             # Unpack header data
-            signature, file_size, reserved, reserved2, data_offset = (
-                self.struct.unpack(header_format, data[:header_size])
-            )
+            (
+                signature,
+                file_size,
+                reserved,
+                reserved2,
+                data_offset,
+            ) = self.struct.unpack(header_format, data[:header_size])
 
             # DIB header
             dib_format = "<IiiHHIIiiII"
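
The reformatted unpack above reads the standard 14-byte BMP file header. The header_format string itself sits outside this hunk; "<2sIHHI" below is the conventional little-endian layout matching the five unpacked fields, so treat it as an assumption. A standalone sketch:

    import struct

    # 14-byte BMP file header: 2-byte signature, 4-byte file size,
    # two 2-byte reserved fields, 4-byte offset to the pixel data.
    BMP_HEADER = "<2sIHHI"  # assumed; matches the five fields unpacked above

    def parse_bmp_header(data: bytes) -> tuple[bytes, int, int]:
        size = struct.calcsize(BMP_HEADER)  # 14 bytes
        signature, file_size, _res1, _res2, data_offset = struct.unpack(
            BMP_HEADER, data[:size]
        )
        return signature, file_size, data_offset

    sample = struct.pack(BMP_HEADER, b"BM", 64, 0, 0, 54)
    assert parse_bmp_header(sample) == (b"BM", 64, 54)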

+ 25 - 25
core/pipes/kg/community_summary.py

@@ -52,7 +52,6 @@ class GraphCommunitySummaryPipe(AsyncPipe):
         relationships: list[Relationship],
         max_summary_input_length: int,
     ):
-
         entity_map: dict[str, dict[str, list[Any]]] = {}
         for entity in entities:
             if not entity.name in entity_map:
@@ -172,7 +171,6 @@ class GraphCommunitySummaryPipe(AsyncPipe):
         )
 
         for attempt in range(3):
-
             description = (
                 (
                     await self.llm_provider.aget_completion(
@@ -268,34 +266,37 @@ class GraphCommunitySummaryPipe(AsyncPipe):
             f"GraphCommunitySummaryPipe: Checking if community summaries exist for communities {offset} to {offset + limit}"
         )
 
-        all_entities, _ = (
-            await self.database_provider.graphs_handler.get_entities(
-                parent_id=collection_id,
-                offset=0,
-                limit=-1,
-                include_embeddings=False,
-            )
+        (
+            all_entities,
+            _,
+        ) = await self.database_provider.graphs_handler.get_entities(
+            parent_id=collection_id,
+            offset=0,
+            limit=-1,
+            include_embeddings=False,
         )
 
-        all_relationships, _ = (
-            await self.database_provider.graphs_handler.get_relationships(
-                parent_id=collection_id,
-                offset=0,
-                limit=-1,
-                include_embeddings=False,
-            )
+        (
+            all_relationships,
+            _,
+        ) = await self.database_provider.graphs_handler.get_relationships(
+            parent_id=collection_id,
+            offset=0,
+            limit=-1,
+            include_embeddings=False,
         )
 
         # Perform clustering
         leiden_params = input.message.get("leiden_params", {})
-        _, community_clusters = (
-            await self.database_provider.graphs_handler._cluster_and_add_community_info(
-                relationships=all_relationships,
-                relationship_ids_cache={},
-                leiden_params=leiden_params,
-                collection_id=collection_id,
-                clustering_mode=clustering_mode,
-            )
+        (
+            _,
+            community_clusters,
+        ) = await self.database_provider.graphs_handler._cluster_and_add_community_info(
+            relationships=all_relationships,
+            relationship_ids_cache={},
+            leiden_params=leiden_params,
+            collection_id=collection_id,
+            clustering_mode=clustering_mode,
         )
 
         # Organize clusters
@@ -330,7 +331,6 @@ class GraphCommunitySummaryPipe(AsyncPipe):
         total_errors = 0
         completed_community_summary_jobs = 0
         for community_summary in asyncio.as_completed(community_summary_jobs):
-
             summary = await community_summary
             completed_community_summary_jobs += 1
             if completed_community_summary_jobs % 50 == 0:

+ 0 - 1
core/pipes/kg/deduplication.py

@@ -63,7 +63,6 @@ class GraphDeduplicationPipe(AsyncPipe):
     async def kg_named_entity_deduplication(
         self, graph_id: UUID | None, collection_id: UUID | None, **kwargs
     ):
-
         import numpy as np
 
         entities = await self._get_entities(graph_id, collection_id)

+ 0 - 6
core/pipes/kg/deduplication_summary.py

@@ -19,7 +19,6 @@ logger = logging.getLogger()
 
 
 class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
-
     class Input(AsyncPipe.Input):
         message: dict
 
@@ -48,7 +47,6 @@ class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
         entity_descriptions: list[str],
         generation_config: GenerationConfig,
     ) -> Entity:
-
         # find the index until the length is less than 1024
         index = 0
         description_length = 0
@@ -89,7 +87,6 @@ class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
         entity_descriptions: list[str],
         generation_config: GenerationConfig,
     ) -> Entity:
-
         # TODO: Expose this as a hyperparameter
         if len(entity_descriptions) <= 5:
             return Entity(
@@ -103,7 +100,6 @@ class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
     async def _prepare_and_upsert_entities(
         self, entities_batch: list[Entity], graph_id: UUID
     ) -> Any:
-
         embeddings = await self.embedding_provider.async_get_embeddings(
             [entity.description or "" for entity in entities_batch]
         )
@@ -135,7 +131,6 @@ class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
         limit: int,
         level,
     ):
-
         if graph_id is not None:
             return await self.database_provider.graphs_handler.entities.get(
                 parent_id=graph_id,
@@ -235,7 +230,6 @@ class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
                 tasks = []
 
         if tasks:
-
             entities_batch = await asyncio.gather(*tasks)
             for entity in entities_batch:
                 yield entity

+ 0 - 1
core/pipes/kg/extraction.py

@@ -212,7 +212,6 @@ class GraphExtractionPipe(AsyncPipe[dict]):
         *args: Any,
         **kwargs: Any,
     ) -> AsyncGenerator[Union[KGExtraction, R2RDocumentProcessingError], None]:
-
         start_time = time.time()
 
         document_id = input.message["document_id"]

+ 0 - 1
core/pipes/kg/storage.py

@@ -49,7 +49,6 @@ class GraphStoragePipe(AsyncPipe):
         total_entities, total_relationships = 0, 0
 
         for extraction in kg_extractions:
-
             total_entities, total_relationships = (
                 total_entities + len(extraction.entities),
                 total_relationships + len(extraction.relationships),

+ 0 - 1
core/pipes/retrieval/chunk_search_pipe.py

@@ -65,7 +65,6 @@ class VectorSearchPipe(SearchPipe):
             search_settings.use_fulltext_search
             and search_settings.use_semantic_search
         ) or search_settings.use_hybrid_search:
-
             search_results = (
                 await self.database_provider.chunks_handler.hybrid_search(
                     query_vector=query_vector,

+ 0 - 1
core/pipes/retrieval/graph_search_pipe.py

@@ -262,6 +262,5 @@ class GraphSearchSearchPipe(GeneratorPipe):
         *args: Any,
         **kwargs: Any,
     ) -> AsyncGenerator[GraphSearchResult, None]:
-
         async for result in self.search(input, state, run_id, search_settings):
             yield result

+ 0 - 1
core/providers/auth/r2r_auth.py

@@ -361,7 +361,6 @@ class R2RAuthProvider(AuthProvider):
 
     async def request_password_reset(self, email: str) -> dict[str, str]:
         try:
-
             user = (
                 await self.database_provider.users_handler.get_user_by_email(
                     email=email

+ 1 - 1
core/providers/crypto/bcrypt.py

@@ -116,7 +116,7 @@ class BCryptCryptoProvider(CryptoProvider, ABC):
 
         # Generate unique key_id
         key_entropy = nacl.utils.random(16)
-        key_id = f"key_{base64.urlsafe_b64encode(key_entropy).decode()}"
+        key_id = f"sk_{base64.urlsafe_b64encode(key_entropy).decode()}"
 
         private_key = base64.b64encode(bytes(signing_key)).decode()
         public_key = base64.b64encode(bytes(verify_key)).decode()

+ 26 - 8
core/providers/crypto/nacl.py

@@ -19,14 +19,26 @@ from core.base import CryptoConfig, CryptoProvider
 DEFAULT_NACL_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM"  # Replace or load from env or secrets manager
 
 
+def encode_bytes_readable(random_bytes: bytes, chars: str) -> str:
+    """Convert random bytes to a readable string using the given character set."""
+    # Each byte gives us 8 bits of randomness
+    # We use modulo to map each byte to our character set
+    result = []
+    for byte in random_bytes:
+        # Use modulo to map the byte (0-255) to our character set length
+        idx = byte % len(chars)
+        result.append(chars[idx])
+    return "".join(result)
+
+
 class NaClCryptoConfig(CryptoConfig):
     provider: str = "nacl"
     # Interactive parameters for password ops (fast)
-    ops_limit: int = argon2i.OPSLIMIT_INTERACTIVE
-    mem_limit: int = argon2i.MEMLIMIT_INTERACTIVE
+    ops_limit: int = argon2i.OPSLIMIT_MIN
+    mem_limit: int = argon2i.MEMLIMIT_MIN
     # Sensitive parameters for API key generation (slow but more secure)
-    api_ops_limit: int = argon2i.OPSLIMIT_SENSITIVE
-    api_mem_limit: int = argon2i.MEMLIMIT_SENSITIVE
+    api_ops_limit: int = argon2i.OPSLIMIT_INTERACTIVE
+    api_mem_limit: int = argon2i.MEMLIMIT_INTERACTIVE
     api_key_bytes: int = 32
     secret_key: Optional[str] = None
 
@@ -72,14 +84,20 @@ class NaClCryptoProvider(CryptoProvider):
         return base64.urlsafe_b64encode(random_bytes)[:length].decode("utf-8")
 
     def generate_api_key(self) -> Tuple[str, str]:
+
+        # Define our character set (excluding ambiguous characters)
+        chars = string.ascii_letters.replace("l", "").replace("I", "").replace(
+            "O", ""
+        ) + string.digits.replace("0", "").replace("1", "")
+
         # Generate a unique key_id
         key_id_bytes = nacl.utils.random(16)  # 16 random bytes
-        key_id = f"key_{base64.urlsafe_b64encode(key_id_bytes).decode()}"
+        key_id = f"sk_{encode_bytes_readable(key_id_bytes, chars)}"
 
         # Generate a high-entropy API key
-        raw_api_key = base64.urlsafe_b64encode(
-            nacl.utils.random(self.config.api_key_bytes)
-        ).decode()
+        raw_api_key = encode_bytes_readable(
+            nacl.utils.random(self.config.api_key_bytes), chars
+        )
 
         # The caller will store the hashed version in the database
         return key_id, raw_api_key
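
encode_bytes_readable maps each random byte onto the reduced alphabet, so key IDs and raw keys avoid the look-alike characters l, I, O, 0 and 1; the new sk_ prefix also matches the one the bcrypt provider now emits. A usage sketch against the function defined above:

    import string

    import nacl.utils

    from core.providers.crypto.nacl import encode_bytes_readable  # path per this diff

    # Same 57-symbol charset construction as in generate_api_key above.
    chars = string.ascii_letters.replace("l", "").replace("I", "").replace(
        "O", ""
    ) + string.digits.replace("0", "").replace("1", "")

    token = encode_bytes_readable(nacl.utils.random(32), chars)
    assert len(token) == 32 and all(c in chars for c in token)

Since 256 is not a multiple of 57, the byte % len(chars) mapping is slightly biased toward the first 28 characters of the set; at 32 symbols (about 5.8 bits each, roughly 186 bits in total) the keys remain comfortably high-entropy, but strictly uniform output would require rejection sampling.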

+ 0 - 1
core/providers/embeddings/litellm.py

@@ -43,7 +43,6 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
 
         self.rerank_url = None
         if config.rerank_model:
-
             if "huggingface" not in config.rerank_model:
                 raise ValueError(
                     "LiteLLMEmbeddingProvider only supports re-ranking via the HuggingFace text-embeddings-inference API"

+ 0 - 1
core/providers/ingestion/r2r/base.py

@@ -186,7 +186,6 @@ class R2RIngestionProvider(IngestionProvider):
         parsed_document: str | DocumentChunk,
         ingestion_config_override: dict,
     ) -> AsyncGenerator[Any, None]:
-
         text_spliiter = self.text_splitter
         if ingestion_config_override:
             text_spliiter = self._build_text_splitter(

+ 0 - 1
core/providers/ingestion/unstructured/base.py

@@ -220,7 +220,6 @@ class UnstructuredIngestionProvider(IngestionProvider):
         document: Document,
         ingestion_config_override: dict,
     ) -> AsyncGenerator[DocumentChunk, None]:
-
         ingestion_config = copy(
             {
                 **self.config.to_ingestion_request(),

+ 0 - 1
core/telemetry/telemetry_decorator.py

@@ -80,7 +80,6 @@ if os.getenv("TELEMETRY_ENABLED", "true").lower() in ("true", "1"):
 def telemetry_event(event_name):
     def decorator(func):
         def log_telemetry(event_type, user_id, metadata, error_message=None):
-
             if telemetry_thread_pool is None:
                 return
 

+ 0 - 2
migrations/versions/8077140e1e99_v3_api_database_revision.py

@@ -34,7 +34,6 @@ if (
 
 
 def upgrade() -> None:
-
     # Collections table migration
     op.alter_column(
         "collections",
@@ -195,7 +194,6 @@ def upgrade() -> None:
 
 
 def downgrade() -> None:
-
     # Collections table migration
     op.alter_column(
         "collections",

+ 3 - 0
pyproject.toml

@@ -52,6 +52,7 @@ unstructured-client = "0.25.5"
 psycopg-binary = "^3.2.3"
 aiosmtplib = "^3.0.2"
 types-aiofiles = "^24.1.0.20240626"
+rich = "^13.9.4"
 aiohttp = "^3.10.10"
 typing-extensions = "^4.12.2"
 
@@ -81,9 +82,11 @@ sendgrid = { version = "^6.11.0", optional = true }
 sqlalchemy = { version = "^2.0.30", optional = true }
 supabase = { version = "^2.7.4", optional = true }
 tokenizers = { version = "0.19", optional = true }
+unstructured-client = { version = "0.25.5", optional = true }
 uvicorn = { version = "^0.27.0.post1", optional = true }
 vecs = { version = "^0.4.0", optional = true }
 
+
 # R2R Ingestion
 aiofiles = { version = "^24.1.0", optional = true }
 aioshutil = { version = "^1.5", optional = true }

+ 1 - 1
sdk/async_client.py

@@ -44,7 +44,7 @@ class R2RAsyncClient(
 
     def __init__(
         self,
-        base_url: str = "http://localhost:7272",
+        base_url: str = "https://api.cloud.sciphi.ai",
         prefix: str = "/v2",
         custom_client=None,
         timeout: float = 300.0,

+ 3 - 3
sdk/base/base_client.py

@@ -1,5 +1,6 @@
 import asyncio
 import contextlib
+import os
 from functools import wraps
 from typing import Optional
 
@@ -34,7 +35,7 @@ def sync_generator_wrapper(async_gen_func):
 class BaseClient:
     def __init__(
         self,
-        base_url: str = "http://localhost:7272",
+        base_url: str = "https://api.cloud.sciphi.ai",
         prefix: str = "/v2",
         timeout: float = 300.0,
     ):
@@ -43,7 +44,7 @@ class BaseClient:
         self.timeout = timeout
         self.access_token: Optional[str] = None
         self._refresh_token: Optional[str] = None
-        self.api_key: Optional[str] = None
+        self.api_key: Optional[str] = os.getenv("R2R_API_KEY", None)
 
     def _get_auth_header(self) -> dict[str, str]:
         if self.access_token and self.api_key:
@@ -69,7 +70,6 @@ class BaseClient:
         return f"{self.base_url}/{version}/{endpoint}"
 
     def _prepare_request_args(self, endpoint: str, **kwargs) -> dict:
-
         headers = kwargs.pop("headers", {})
         if (self.access_token or self.api_key) and endpoint not in [
             "register",

+ 0 - 1
sdk/sync_client.py

@@ -104,7 +104,6 @@ class R2RClient(R2RAsyncClient):
     def _make_sync_method(
         self, async_method: Callable[..., Coroutine[Any, Any, T]]
     ) -> Callable[..., T]:
-
         @functools.wraps(async_method)
         def wrapped(*args, **kwargs):
             return self._loop.run_until_complete(async_method(*args, **kwargs))

+ 1 - 2
sdk/v2/management.py

@@ -711,9 +711,8 @@ class ManagementMixins:
         Returns:
             dict: The conversation data.
         """
-        query_params = f"?branch_id={branch_id}" if branch_id else ""
         return await self._make_request(  # type: ignore
-            "GET", f"get_conversation/{str(conversation_id)}{query_params}"
+            "GET", f"get_conversation/{str(conversation_id)}"
         )
 
     @deprecated("Use client.conversations.create() instead")

+ 3 - 3
sdk/v3/graphs.py

@@ -497,7 +497,7 @@ class GraphsSDK:
         Returns:
             dict: Created entity information
         """
-        data = {
+        data: dict[str, Any] = {
             "name": name,
             "description": description,
         }
@@ -542,7 +542,7 @@ class GraphsSDK:
         Returns:
             dict: Created relationship information
         """
-        data = {
+        data: dict[str, Any] = {
             "subject": subject,
             "subject_id": str(subject_id),
             "predicate": predicate,
@@ -585,7 +585,7 @@ class GraphsSDK:
         Returns:
             dict: Created community information
         """
-        data = {
+        data: dict[str, Any] = {
             "name": name,
             "summary": summary,
         }

+ 6 - 14
sdk/v3/users.py

@@ -1,4 +1,5 @@
 from __future__ import annotations  # for Python 3.10+
+import json
 
 from typing import Optional
 from uuid import UUID
@@ -125,19 +126,6 @@ class UsersSDK:
             version="v3",
         )
 
-    # async def set_api_key(self, api_key: str) -> None:
-    #     """
-    #     Set the API key for the client.
-
-    #     Args:
-    #         api_key (str): API key to set
-    #     """
-    #     if self.client.access_token:
-    #         raise ValueError(
-    #             "Cannot set an API key after logging in, please log out first"
-    #         )
-    #     self.client.set_api_key(api_key)
-
     async def login(self, email: str, password: str) -> dict[str, Token]:
         """
         Log in a user.
@@ -194,7 +182,7 @@ class UsersSDK:
             self.client._refresh_token = None
             raise ValueError("Invalid token provided")
 
-    async def logout(self) -> WrappedGenericMessageResponse:
+    async def logout(self) -> WrappedGenericMessageResponse | None:
         """Log out the current user."""
         if self.client.access_token:
             response = await self.client._make_request(
@@ -209,6 +197,7 @@ class UsersSDK:
 
         self.client.access_token = None
         self.client._refresh_token = None
+        return None
 
     async def refresh_token(self) -> WrappedTokenResponse:
         """Refresh the access token using the refresh token."""
@@ -361,6 +350,7 @@ class UsersSDK:
         name: Optional[str] = None,
         bio: Optional[str] = None,
         profile_picture: Optional[str] = None,
+        limits_overrides: dict | None = None,
     ) -> WrappedUserResponse:
         """
         Update user information.
@@ -387,6 +377,8 @@ class UsersSDK:
             data["bio"] = bio
         if profile_picture is not None:
             data["profile_picture"] = profile_picture
+        if limits_overrides is not None:
+            data["limits_overrides"] = limits_overrides
 
         return await self.client._make_request(
             "POST",

+ 0 - 1
shared/abstractions/graph.py

@@ -62,7 +62,6 @@ class Relationship(R2RSerializable):
 
 @dataclass
 class Community(R2RSerializable):
-
     name: str = ""
     summary: str = ""
 

+ 2 - 3
shared/abstractions/search.py

@@ -468,7 +468,6 @@ def select_search_filters(
     auth_user: Any,
     search_settings: SearchSettings,
 ) -> dict[str, Any]:
-
     filters = copy(search_settings.filters)
     selected_collections = None
     if not auth_user.is_superuser:
@@ -493,7 +492,7 @@ def select_search_filters(
         }
 
         filters.pop("collection_ids", None)
-
-        filters = {"$and": [collection_filters, filters]}  # type: ignore
+        if filters != {}:
+            filters = {"$and": [collection_filters, filters]}  # type: ignore
 
     return filters
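
The new guard keeps an empty user filter from being wrapped in a redundant $and; note that an empty dict now passes through as-is instead of being combined with the collection restriction. A small sketch of the two shapes (filter payloads illustrative):

    # Collection restriction shape is illustrative, not the exact production filter.
    collection_filters = {"collection_ids": {"$overlap": ["c1"]}}

    def combine(filters: dict) -> dict:
        if filters != {}:
            filters = {"$and": [collection_filters, filters]}
        return filters

    assert combine({"title": {"$eq": "report"}}) == {
        "$and": [collection_filters, {"title": {"$eq": "report"}}]
    }
    assert combine({}) == {}  # empty filters are returned unwrapped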

+ 3 - 0
shared/abstractions/user.py

@@ -53,6 +53,9 @@ class User(R2RSerializable):
     graph_ids: list[UUID] = []
     document_ids: list[UUID] = []
 
+    # Add the new limits_overrides field
+    limits_overrides: Optional[dict] = None
+
     # Optional fields (to update or set at creation)
     hashed_password: Optional[str] = None
     verification_code_expiry: Optional[datetime] = None

+ 56 - 0
tests/cli/async_invoke.py

@@ -0,0 +1,56 @@
+import asyncio
+from types import TracebackType
+from typing import Any, Tuple, Type, cast
+
+import asyncclick as click
+from click import Abort
+from click.testing import CliRunner, Result
+
+
+async def async_invoke(
+    runner: CliRunner, cmd: click.Command, *args: str, **kwargs: Any
+) -> Result:
+    """Helper function to invoke async Click commands in tests."""
+    exit_code = 0
+    exception: BaseException | None = None
+    exc_info: (
+        Tuple[Type[BaseException], BaseException, TracebackType] | None
+    ) = None
+
+    with runner.isolation() as out_err:
+        stdout, stderr = out_err
+        try:
+            # Use the running loop the test already executes on
+            loop = asyncio.get_running_loop()
+
+            # Run the command using create_task
+            task = loop.create_task(
+                cmd.main(args=args, standalone_mode=False, **kwargs)
+            )
+            return_value = await task
+
+        except Abort as e:
+            exit_code = 1
+            exception = cast(BaseException, e)
+            if e.__traceback__:
+                exc_info = (type(e), exception, e.__traceback__)
+            return_value = None
+        except Exception as e:
+            exit_code = 1
+            exception = cast(BaseException, e)
+            if e.__traceback__:
+                exc_info = (type(e), exception, e.__traceback__)
+            return_value = None
+
+        stdout_bytes = stdout.getvalue() or b""
+        stderr_bytes = stderr.getvalue() if stderr else b""
+
+        return Result(
+            runner=runner,
+            stdout_bytes=stdout_bytes,
+            stderr_bytes=stderr_bytes,
+            return_value=return_value,
+            exit_code=exit_code,
+            exception=exception,
+            exc_info=exc_info,
+        )

+ 162 - 0
tests/cli/commands/test_collections_cli.py

@@ -0,0 +1,162 @@
+"""
+Tests for the collection commands in the CLI.
+    - create
+    - list
+    - retrieve
+    - delete
+    - list-documents
+    - list-users
+"""
+
+import json
+import uuid
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.collections import (
+    create,
+    delete,
+    list,
+    list_documents,
+    list_users,
+    retrieve,
+)
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
+def extract_json_block(output: str) -> dict:
+    """Extract and parse the first valid JSON object found in the output."""
+    start = output.find("{")
+    if start == -1:
+        raise ValueError("No JSON object start found in output")
+
+    brace_count = 0
+    for i, char in enumerate(output[start:], start=start):
+        if char == "{":
+            brace_count += 1
+        elif char == "}":
+            brace_count -= 1
+            if brace_count == 0:
+                json_str = output[start : i + 1].strip()
+                return json.loads(json_str)
+    raise ValueError("No complete JSON object found in output")
+
+
+@pytest.mark.asyncio
+async def test_collection_lifecycle():
+    """Test the complete lifecycle of a collection: create, retrieve, delete."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    collection_name = f"test-collection-{uuid.uuid4()}"
+    description = "Test collection description"
+
+    # Create collection
+    create_result = await async_invoke(
+        runner,
+        create,
+        collection_name,
+        "--description",
+        description,
+        obj=client,
+    )
+    assert create_result.exit_code == 0, create_result.stdout_bytes.decode()
+
+    output = create_result.stdout_bytes.decode()
+    create_response = extract_json_block(output)
+    collection_id = create_response["results"]["id"]
+
+    try:
+        # Retrieve collection
+        retrieve_result = await async_invoke(
+            runner, retrieve, collection_id, obj=client
+        )
+        assert retrieve_result.exit_code == 0
+        retrieve_output = retrieve_result.stdout_bytes.decode()
+        assert collection_id in retrieve_output
+
+        # List documents in collection
+        list_docs_result = await async_invoke(
+            runner, list_documents, collection_id, obj=client
+        )
+        assert list_docs_result.exit_code == 0
+
+        # List users in collection
+        list_users_result = await async_invoke(
+            runner, list_users, collection_id, obj=client
+        )
+        assert list_users_result.exit_code == 0
+    finally:
+        # Delete collection
+        delete_result = await async_invoke(
+            runner, delete, collection_id, obj=client
+        )
+        assert delete_result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_list_collections():
+    """Test listing collections with various parameters."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Create test collection first
+    create_result = await async_invoke(
+        runner, create, f"test-collection-{uuid.uuid4()}", obj=client
+    )
+    response = extract_json_block(create_result.stdout_bytes.decode())
+    collection_id = response["results"]["id"]
+
+    try:
+        # Test basic list
+        list_result = await async_invoke(runner, list, obj=client)
+        assert list_result.exit_code == 0
+
+        # Get paginated results just to verify they exist
+        list_paginated = await async_invoke(
+            runner, list, "--offset", "0", "--limit", "2", obj=client
+        )
+        assert list_paginated.exit_code == 0
+
+    finally:
+        # Cleanup
+        await async_invoke(runner, delete, collection_id, obj=client)
+
+
+@pytest.mark.asyncio
+async def test_nonexistent_collection():
+    """Test operations on a nonexistent collection."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    nonexistent_id = str(uuid.uuid4())
+
+    # Test retrieve
+    retrieve_result = await async_invoke(
+        runner, retrieve, nonexistent_id, obj=client
+    )
+    # Updated assertion to match actual error message
+    assert (
+        "the specified collection does not exist."
+        in retrieve_result.stderr_bytes.decode().lower()
+    )
+
+    # Test list_documents
+    list_docs_result = await async_invoke(
+        runner, list_documents, nonexistent_id, obj=client
+    )
+    assert (
+        "collection not found"
+        in list_docs_result.stderr_bytes.decode().lower()
+    )
+
+    # Test list_users
+    list_users_result = await async_invoke(
+        runner, list_users, nonexistent_id, obj=client
+    )
+    assert (
+        "collection not found"
+        in list_users_result.stderr_bytes.decode().lower()
+    )

+ 7 - 0
tests/cli/commands/test_config_cli.py

@@ -0,0 +1,7 @@
+"""
+Tests for the config commands in the CLI.
+    x reset
+    x key
+    x host
+    x view
+"""

+ 168 - 0
tests/cli/commands/test_conversations_cli.py

@@ -0,0 +1,168 @@
+"""
+Tests for the conversations commands in the CLI.
+    - create
+    - list
+    - retrieve
+    - delete
+    - list-users
+"""
+
+import json
+import uuid
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.conversations import (
+    create,
+    delete,
+    list,
+    list_users,
+    retrieve,
+)
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
+def extract_json_block(output: str) -> dict:
+    """Extract and parse the first valid JSON object found in the output."""
+    start = output.find("{")
+    if start == -1:
+        raise ValueError("No JSON object start found in output")
+
+    brace_count = 0
+    for i, char in enumerate(output[start:], start=start):
+        if char == "{":
+            brace_count += 1
+        elif char == "}":
+            brace_count -= 1
+            if brace_count == 0:
+                json_str = output[start : i + 1].strip()
+                return json.loads(json_str)
+    raise ValueError("No complete JSON object found in output")
+
+
+@pytest.mark.asyncio
+async def test_conversation_lifecycle():
+    """Test the complete lifecycle of a conversation: create, retrieve, delete."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Create conversation
+    create_result = await async_invoke(
+        runner,
+        create,
+        obj=client,
+    )
+    assert create_result.exit_code == 0, create_result.stdout_bytes.decode()
+
+    output = create_result.stdout_bytes.decode()
+    create_response = extract_json_block(output)
+    conversation_id = create_response["results"]["id"]
+
+    try:
+        # Retrieve conversation
+        retrieve_result = await async_invoke(
+            runner, retrieve, conversation_id, obj=client
+        )
+        assert retrieve_result.exit_code == 0
+        retrieve_output = retrieve_result.stdout_bytes.decode()
+        # FIXME: This assertion fails, we need to sync Conversation and ConversationResponse
+        # assert conversation_id in retrieve_output
+        assert "results" in retrieve_output
+
+        # List users in conversation
+        list_users_result = await async_invoke(
+            runner, list_users, conversation_id, obj=client
+        )
+        assert list_users_result.exit_code == 0
+    finally:
+        # Delete conversation
+        delete_result = await async_invoke(
+            runner, delete, conversation_id, obj=client
+        )
+        assert delete_result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_list_conversations():
+    """Test listing conversations with various parameters."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Create test conversation first
+    create_result = await async_invoke(runner, create, obj=client)
+    response = extract_json_block(create_result.stdout_bytes.decode())
+    conversation_id = response["results"]["id"]
+
+    try:
+        # Test basic list
+        list_result = await async_invoke(runner, list, obj=client)
+        assert list_result.exit_code == 0
+
+        # Test paginated results
+        list_paginated = await async_invoke(
+            runner, list, "--offset", "0", "--limit", "2", obj=client
+        )
+        assert list_paginated.exit_code == 0
+
+    finally:
+        # Cleanup
+        await async_invoke(runner, delete, conversation_id, obj=client)
+
+
+@pytest.mark.asyncio
+async def test_nonexistent_conversation():
+    """Test operations on a nonexistent conversation."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    nonexistent_id = str(uuid.uuid4())
+
+    # Test retrieve
+    retrieve_result = await async_invoke(
+        runner, retrieve, nonexistent_id, obj=client
+    )
+    assert "not found" in retrieve_result.stderr_bytes.decode().lower()
+
+    # Test list_users
+    list_users_result = await async_invoke(
+        runner, list_users, nonexistent_id, obj=client
+    )
+    assert "not found" in list_users_result.stderr_bytes.decode().lower()
+
+
+@pytest.mark.asyncio
+async def test_list_conversations_pagination():
+    """Test pagination functionality of list conversations."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Create multiple conversations
+    conversation_ids = []
+    for _ in range(3):
+        create_result = await async_invoke(runner, create, obj=client)
+        response = extract_json_block(create_result.stdout_bytes.decode())
+        conversation_ids.append(response["results"]["id"])
+
+    try:
+        # Test with different pagination parameters
+        list_first_page = await async_invoke(
+            runner, list, "--offset", "0", "--limit", "2", obj=client
+        )
+        assert list_first_page.exit_code == 0
+        first_page_output = list_first_page.stdout_bytes.decode()
+
+        list_second_page = await async_invoke(
+            runner, list, "--offset", "2", "--limit", "2", obj=client
+        )
+        assert list_second_page.exit_code == 0
+        second_page_output = list_second_page.stdout_bytes.decode()
+
+        # Verify different results on different pages
+        assert first_page_output != second_page_output
+
+    finally:
+        # Cleanup
+        for conversation_id in conversation_ids:
+            await async_invoke(runner, delete, conversation_id, obj=client)

+ 0 - 0
tests/cli/commands/test_database_cli.py


+ 330 - 0
tests/cli/commands/test_documents_cli.py

@@ -0,0 +1,330 @@
+"""
+Tests for the document commands in the CLI.
+    - create
+    - retrieve
+    - list
+    - delete
+    - list-chunks
+    - list-collections
+    x ingest-files-from-url
+    x extract
+    x list-entities
+    x list-relationships
+    x create-sample
+    x create-samples
+"""
+
+import contextlib
+import json
+import os
+import tempfile
+import uuid
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.documents import (
+    create,
+    delete,
+    list,
+    list_chunks,
+    list_collections,
+    retrieve,
+)
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
+@pytest.fixture
+def temp_text_file():
+    """Create a temporary text file for testing."""
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".txt", delete=False
+    ) as f:
+        f.write("This is test content for document testing.")
+        temp_path = f.name
+
+    yield temp_path
+
+    # Cleanup temp file
+    if os.path.exists(temp_path):
+        os.unlink(temp_path)
+
+
+@pytest.fixture
+def temp_json_file():
+    """Create a temporary JSON file for testing."""
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".json", delete=False
+    ) as f:
+        json.dump({"test": "content", "for": "document testing"}, f)
+        temp_path = f.name
+
+    yield temp_path
+
+    # Cleanup temp file
+    if os.path.exists(temp_path):
+        os.unlink(temp_path)
+
+
+def extract_json_block(output: str) -> dict:
+    """Extract and parse the first valid JSON object found in the output."""
+    # We assume the output contains at least one JSON object printed with json.dumps(indent=2).
+    # We'll find the first '{' and the matching closing '}' that forms a valid JSON object.
+    start = output.find("{")
+    if start == -1:
+        raise ValueError("No JSON object start found in output")
+
+    # Track braces to find the matching '}'
+    brace_count = 0
+    for i, char in enumerate(output[start:], start=start):
+        if char == "{":
+            brace_count += 1
+        elif char == "}":
+            brace_count -= 1
+
+            if brace_count == 0:
+                # Found the matching closing brace
+                json_str = output[start : i + 1].strip()
+                return json.loads(json_str)
+    raise ValueError("No complete JSON object found in output")
+
+
+@pytest.mark.asyncio
+async def test_document_lifecycle(temp_text_file):
+    """Test the complete lifecycle of a document: create, retrieve, delete."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Create document
+    create_result = await async_invoke(
+        runner, create, temp_text_file, obj=client
+    )
+    assert create_result.exit_code == 0, create_result.stdout_bytes.decode()
+
+    output = create_result.stdout_bytes.decode()
+    create_response = extract_json_block(output)
+    document_id = create_response["results"]["document_id"]
+
+    try:
+        # Retrieve document
+        retrieve_result = await async_invoke(
+            runner, retrieve, document_id, obj=client
+        )
+        assert (
+            retrieve_result.exit_code == 0
+        ), retrieve_result.stdout_bytes.decode()
+
+        # Instead of parsing JSON, verify the ID appears in the table output
+        retrieve_output = retrieve_result.stdout_bytes.decode()
+        assert document_id in retrieve_output
+
+        # List chunks
+        list_chunks_result = await async_invoke(
+            runner, list_chunks, document_id, obj=client
+        )
+        assert (
+            list_chunks_result.exit_code == 0
+        ), list_chunks_result.stdout_bytes.decode()
+
+        # List collections
+        list_collections_result = await async_invoke(
+            runner, list_collections, document_id, obj=client
+        )
+        assert (
+            list_collections_result.exit_code == 0
+        ), list_collections_result.stdout_bytes.decode()
+    finally:
+        # Delete document
+        delete_result = await async_invoke(
+            runner, delete, document_id, obj=client
+        )
+        assert (
+            delete_result.exit_code == 0
+        ), delete_result.stdout_bytes.decode()
+
+
+@pytest.mark.asyncio
+async def test_create_multiple_documents(temp_text_file, temp_json_file):
+    """Test creating multiple documents with metadata."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    metadatas = json.dumps(
+        [
+            {"description": "Test document 1"},
+            {"description": "Test document 2"},
+        ]
+    )
+
+    create_result = await async_invoke(
+        runner,
+        create,
+        temp_text_file,
+        temp_json_file,
+        "--metadatas",
+        metadatas,
+        obj=client,
+    )
+    assert create_result.exit_code == 0, create_result.stdout_bytes.decode()
+
+    output = create_result.stdout_bytes.decode()
+    # The command may print multiple JSON objects separated by dashed
+    # separator lines and status output. Split on the separator and parse
+    # each block that contains a JSON object; this assumes each block
+    # between "----------" lines holds exactly one JSON object.
+    blocks = output.split("-" * 40)
+    json_objects = []
+    for block in blocks:
+        block = block.strip()
+        if '"results"' in block and "{" in block and "}" in block:
+            with contextlib.suppress(ValueError):
+                json_objects.append(extract_json_block(block))
+
+    assert (
+        len(json_objects) == 2
+    ), f"Expected 2 JSON objects, got {len(json_objects)}: {output}"
+
+    document_ids = [obj["results"]["document_id"] for obj in json_objects]
+
+    try:
+        # List all documents
+        list_result = await async_invoke(runner, list, obj=client)
+        assert list_result.exit_code == 0, list_result.stdout_bytes.decode()
+
+        # Verify both documents were created
+        for doc_id in document_ids:
+            retrieve_result = await async_invoke(
+                runner, retrieve, doc_id, obj=client
+            )
+            assert (
+                retrieve_result.exit_code == 0
+            ), retrieve_result.stdout_bytes.decode()
+    finally:
+        # Cleanup - delete all created documents
+        for doc_id in document_ids:
+            delete_result = await async_invoke(
+                runner, delete, doc_id, obj=client
+            )
+            assert (
+                delete_result.exit_code == 0
+            ), delete_result.stdout_bytes.decode()
+
+
+@pytest.mark.asyncio
+async def test_create_with_custom_id():
+    """Test creating a document with a custom ID."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    custom_id = str(uuid.uuid4())
+
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".txt", delete=False
+    ) as f:
+        f.write("Test content")
+        temp_path = f.name
+
+    try:
+        create_result = await async_invoke(
+            runner, create, temp_path, "--ids", custom_id, obj=client
+        )
+        assert (
+            create_result.exit_code == 0
+        ), create_result.stdout_bytes.decode()
+
+        output = create_result.stdout_bytes.decode()
+        create_response = extract_json_block(output)
+        assert create_response["results"]["document_id"] == custom_id
+    finally:
+        if os.path.exists(temp_path):
+            os.unlink(temp_path)
+
+        await async_invoke(runner, delete, custom_id, obj=client)
+
+
+@pytest.mark.asyncio
+async def test_retrieve_nonexistent_document():
+    """Test retrieving a document that doesn't exist."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    nonexistent_id = str(uuid.uuid4())
+    result = await async_invoke(runner, retrieve, nonexistent_id, obj=client)
+
+    stderr = result.stderr_bytes.decode()
+    assert (
+        "Document not found" in stderr
+        or "Document not found" in result.stdout_bytes.decode()
+    )
+
+
+@pytest.mark.asyncio
+async def test_list_chunks_nonexistent_document():
+    """Test listing chunks for a document that doesn't exist."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    nonexistent_id = str(uuid.uuid4())
+    result = await async_invoke(
+        runner, list_chunks, nonexistent_id, obj=client
+    )
+
+    stderr = result.stderr_bytes.decode()
+    assert (
+        "No chunks found for the given document ID." in stderr
+        or "No chunks found for the given document ID."
+        in result.stdout_bytes.decode()
+    )
+
+
+# FIXME: This should be returning 'Document not found' but returns an empty list instead.
+# @pytest.mark.asyncio
+# async def test_list_collections_nonexistent_document():
+#     """Test listing collections for a document that doesn't exist."""
+#     client = R2RAsyncClient(base_url="http://localhost:7272")
+#     runner = CliRunner(mix_stderr=False)
+
+#     nonexistent_id = str(uuid.uuid4())
+#     result = await async_invoke(
+#         runner, list_collections, nonexistent_id, obj=client
+#     )
+
+#     stderr = result.stderr_bytes.decode()
+#     assert (
+#         "Document not found" in stderr
+#         or "Document not found" in result.stdout_bytes.decode()
+#     )
+
+
+@pytest.mark.asyncio
+async def test_delete_nonexistent_document():
+    """Test deleting a document that doesn't exist."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    nonexistent_id = str(uuid.uuid4())
+    result = await async_invoke(runner, delete, nonexistent_id, obj=client)
+
+    stderr = result.stderr_bytes.decode()
+    assert (
+        "No entries found for deletion" in stderr
+        or "No entries found for deletion" in result.stdout_bytes.decode()
+    )

+ 312 - 0
tests/cli/commands/test_graphs_cli.py

@@ -0,0 +1,312 @@
+"""
+Tests for the graphs commands in the CLI.
+    - list
+    - retrieve
+    - reset
+    - update
+    - list-entities
+    - get-entity
+    x remove-entity
+    - list-relationships
+    - get-relationship
+    x remove-relationship
+    - build
+    - list-communities
+    - get-community
+    x update-community
+    x delete-community
+    - pull
+    - remove-document
+"""
+
+import json
+import uuid
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.collections import create as create_collection
+from cli.commands.graphs import (
+    build,
+    delete_community,
+    get_community,
+    get_entity,
+    get_relationship,
+    list,
+    list_communities,
+    list_entities,
+    list_relationships,
+    pull,
+    remove_document,
+    remove_entity,
+    remove_relationship,
+    reset,
+    retrieve,
+    update,
+    update_community,
+)
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
+def extract_json_block(output: str) -> dict:
+    """Extract and parse the first valid JSON object found in the output."""
+    start = output.find("{")
+    if start == -1:
+        raise ValueError("No JSON object start found in output")
+
+    brace_count = 0
+    for i, char in enumerate(output[start:], start=start):
+        if char == "{":
+            brace_count += 1
+        elif char == "}":
+            brace_count -= 1
+            if brace_count == 0:
+                json_str = output[start : i + 1].strip()
+                return json.loads(json_str)
+    raise ValueError("No complete JSON object found in output")
+
+
+async def create_test_collection(
+    runner: CliRunner, client: R2RAsyncClient
+) -> str:
+    """Helper function to create a test collection and return its ID."""
+    collection_name = f"test-collection-{uuid.uuid4()}"
+    create_result = await async_invoke(
+        runner, create_collection, collection_name, obj=client
+    )
+    response = extract_json_block(create_result.stdout_bytes.decode())
+    return response["results"]["id"]
+
+
+@pytest.mark.asyncio
+async def test_graph_basic_operations():
+    """Test basic graph operations: retrieve, reset, update."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    collection_id = await create_test_collection(runner, client)
+
+    try:
+        # Retrieve graph
+        retrieve_result = await async_invoke(
+            runner, retrieve, collection_id, obj=client
+        )
+        assert retrieve_result.exit_code == 0
+        assert collection_id in retrieve_result.stdout_bytes.decode()
+
+        # Update graph
+        new_name = "Updated Graph Name"
+        new_description = "Updated description"
+        update_result = await async_invoke(
+            runner,
+            update,
+            collection_id,
+            "--name",
+            new_name,
+            "--description",
+            new_description,
+            obj=client,
+        )
+        assert update_result.exit_code == 0
+
+        # Reset graph
+        reset_result = await async_invoke(
+            runner, reset, collection_id, obj=client
+        )
+        assert reset_result.exit_code == 0
+
+    finally:
+        # Cleanup will be handled by collection deletion
+        pass
+
+
+@pytest.mark.asyncio
+async def test_graph_entity_operations():
+    """Test entity-related operations in a graph."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    collection_id = await create_test_collection(runner, client)
+
+    try:
+        # List entities (empty initially)
+        list_entities_result = await async_invoke(
+            runner, list_entities, collection_id, obj=client
+        )
+        assert list_entities_result.exit_code == 0
+
+        # Test with pagination
+        paginated_result = await async_invoke(
+            runner,
+            list_entities,
+            collection_id,
+            "--offset",
+            "0",
+            "--limit",
+            "2",
+            obj=client,
+        )
+        assert paginated_result.exit_code == 0
+
+        # Test nonexistent entity operations
+        nonexistent_entity_id = str(uuid.uuid4())
+        get_entity_result = await async_invoke(
+            runner,
+            get_entity,
+            collection_id,
+            nonexistent_entity_id,
+            obj=client,
+        )
+        assert "not found" in get_entity_result.stderr_bytes.decode().lower()
+
+    finally:
+        # Cleanup will be handled by collection deletion
+        pass
+
+
+@pytest.mark.asyncio
+async def test_graph_relationship_operations():
+    """Test relationship-related operations in a graph."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    collection_id = await create_test_collection(runner, client)
+
+    try:
+        # List relationships
+        list_rel_result = await async_invoke(
+            runner, list_relationships, collection_id, obj=client
+        )
+        assert list_rel_result.exit_code == 0
+
+        # Test with pagination
+        paginated_result = await async_invoke(
+            runner,
+            list_relationships,
+            collection_id,
+            "--offset",
+            "0",
+            "--limit",
+            "2",
+            obj=client,
+        )
+        assert paginated_result.exit_code == 0
+
+        # Test nonexistent relationship operations
+        nonexistent_rel_id = str(uuid.uuid4())
+        get_rel_result = await async_invoke(
+            runner,
+            get_relationship,
+            collection_id,
+            nonexistent_rel_id,
+            obj=client,
+        )
+        assert "not found" in get_rel_result.stderr_bytes.decode().lower()
+
+    finally:
+        # Cleanup will be handled by collection deletion
+        pass
+
+
+@pytest.mark.asyncio
+async def test_graph_community_operations():
+    """Test community-related operations in a graph."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    collection_id = await create_test_collection(runner, client)
+
+    try:
+        # List communities
+        list_comm_result = await async_invoke(
+            runner, list_communities, collection_id, obj=client
+        )
+        assert list_comm_result.exit_code == 0
+
+        # Test with pagination
+        paginated_result = await async_invoke(
+            runner,
+            list_communities,
+            collection_id,
+            "--offset",
+            "0",
+            "--limit",
+            "2",
+            obj=client,
+        )
+        assert paginated_result.exit_code == 0
+
+        # Test nonexistent community operations
+        nonexistent_comm_id = str(uuid.uuid4())
+        get_comm_result = await async_invoke(
+            runner,
+            get_community,
+            collection_id,
+            nonexistent_comm_id,
+            obj=client,
+        )
+        assert "not found" in get_comm_result.stderr_bytes.decode().lower()
+
+    finally:
+        # Cleanup will be handled by collection deletion
+        pass
+
+
+@pytest.mark.asyncio
+async def test_graph_build_and_pull():
+    """Test graph building and document pull operations."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    collection_id = await create_test_collection(runner, client)
+
+    try:
+        # Test build with minimal settings
+        settings = {"some_setting": "value"}
+        build_result = await async_invoke(
+            runner,
+            build,
+            collection_id,
+            "--settings",
+            json.dumps(settings),
+            obj=client,
+        )
+        assert build_result.exit_code == 0
+
+        # Test pull documents
+        pull_result = await async_invoke(
+            runner, pull, collection_id, obj=client
+        )
+        assert pull_result.exit_code == 0
+
+        # Test remove document (with nonexistent document)
+        nonexistent_doc_id = str(uuid.uuid4())
+        remove_doc_result = await async_invoke(
+            runner,
+            remove_document,
+            collection_id,
+            nonexistent_doc_id,
+            obj=client,
+        )
+        assert "not found" in remove_doc_result.stderr_bytes.decode().lower()
+
+    finally:
+        # Cleanup will be handled by collection deletion
+        pass
+
+
+@pytest.mark.asyncio
+async def test_list_graphs():
+    """Test listing graphs."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    try:
+        # Test basic list
+        list_result = await async_invoke(runner, list, obj=client)
+        assert list_result.exit_code == 0
+
+    finally:
+        # No collection is created by this test, so there is nothing to clean up
+        pass

+ 6 - 0
tests/cli/commands/test_indices_cli.py

@@ -0,0 +1,6 @@
+"""
+Tests for the indices commands in the CLI.
+    x list
+    x retrieve
+    x delete
+"""

+ 95 - 0
tests/cli/commands/test_prompts_cli.py

@@ -0,0 +1,95 @@
+"""
+Tests for the prompts commands in the CLI.
+    - list
+    - retrieve
+    x delete
+"""
+
+import json
+import uuid
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.prompts import list, retrieve
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
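+# async_invoke (tests/cli/async_invoke.py) is assumed to drive an asyncclick
+# command through Click's CliRunner and return the finished Result.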
+
+
+def extract_json_block(output: str) -> dict:
+    """Extract and parse the first valid JSON object found in the output."""
+    start = output.find("{")
+    if start == -1:
+        raise ValueError("No JSON object start found in output")
+
+    brace_count = 0
+    for i, char in enumerate(output[start:], start=start):
+        if char == "{":
+            brace_count += 1
+        elif char == "}":
+            brace_count -= 1
+            if brace_count == 0:
+                json_str = output[start : i + 1].strip()
+                return json.loads(json_str)
+    raise ValueError("No complete JSON object found in output")
+
+
+@pytest.mark.asyncio
+async def test_prompts_list():
+    """Test listing prompts."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Test basic list
+    list_result = await async_invoke(runner, list, obj=client)
+    assert list_result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_prompts_retrieve():
+    """Test retrieving prompts with various parameters."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Test retrieve with just name
+    name = "hyde"
+    retrieve_result = await async_invoke(runner, retrieve, name, obj=client)
+    assert retrieve_result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_nonexistent_prompt():
+    """Test operations on a nonexistent prompt."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    nonexistent_name = f"nonexistent-{uuid.uuid4()}"
+
+    # Test retrieve
+    retrieve_result = await async_invoke(
+        runner, retrieve, nonexistent_name, obj=client
+    )
+    assert "not found" in retrieve_result.stderr_bytes.decode().lower()
+
+
+@pytest.mark.asyncio
+async def test_prompt_retrieve_with_all_options():
+    """Test retrieving a prompt with all options combined."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    name = "example-prompt"
+    inputs = "input1,input2"
+    override = "custom prompt text"
+
+    retrieve_result = await async_invoke(
+        runner,
+        retrieve,
+        name,
+        "--inputs",
+        inputs,
+        "--prompt-override",
+        override,
+        obj=client,
+    )
+    assert retrieve_result.exit_code == 0

+ 213 - 0
tests/cli/commands/test_retrieval_cli.py

@@ -0,0 +1,213 @@
+"""
+Tests for the retrieval commands in the CLI.
+    - search
+    - rag
+"""
+
+import json
+import tempfile
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.documents import create as create_document
+from cli.commands.retrieval import rag, search
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
+def extract_json_block(output: str) -> dict:
+    """Extract and parse the first valid JSON object found in the output."""
+    start = output.find("{")
+    if start == -1:
+        raise ValueError("No JSON object start found in output")
+
+    brace_count = 0
+    for i, char in enumerate(output[start:], start=start):
+        if char == "{":
+            brace_count += 1
+        elif char == "}":
+            brace_count -= 1
+            if brace_count == 0:
+                json_str = output[start : i + 1].strip()
+                return json.loads(json_str)
+    raise ValueError("No complete JSON object found in output")
+
+
+async def create_test_document(
+    runner: CliRunner, client: R2RAsyncClient
+) -> str:
+    """Helper function to create a test document and return its ID."""
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".txt", delete=False
+    ) as f:
+        f.write(
+            "This is a test document about artificial intelligence and machine learning. "
+            "AI systems can be trained on large datasets to perform various tasks."
+        )
+        temp_path = f.name
+
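+    # The create command is expected to print a timing line followed by a JSON
+    # body containing results.document_id, which extract_json_block parses below.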
+    create_result = await async_invoke(
+        runner, create_document, temp_path, obj=client
+    )
+    response = extract_json_block(create_result.stdout_bytes.decode())
+    return response["results"]["document_id"]
+
+
+@pytest.mark.asyncio
+async def test_basic_search():
+    """Test basic search functionality."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Create test document first
+    document_id = await create_test_document(runner, client)
+
+    try:
+        # Test basic search
+        search_result = await async_invoke(
+            runner,
+            search,
+            "--query",
+            "artificial intelligence",
+            "--limit",
+            "5",
+            obj=client,
+        )
+        assert search_result.exit_code == 0
+        assert "Vector search results:" in search_result.stdout_bytes.decode()
+
+    finally:
+        # No cleanup here; a full implementation would delete the test document
+        pass
+
+
+@pytest.mark.asyncio
+async def test_search_with_filters():
+    """Test search with filters."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    document_id = await create_test_document(runner, client)
+
+    try:
+        filters = json.dumps({"document_id": {"$in": [document_id]}})
+        search_result = await async_invoke(
+            runner,
+            search,
+            "--query",
+            "machine learning",
+            "--filters",
+            filters,
+            "--limit",
+            "5",
+            obj=client,
+        )
+        assert search_result.exit_code == 0
+        output = search_result.stdout_bytes.decode()
+        assert "Vector search results:" in output
+        assert document_id in output
+
+    finally:
+        pass
+
+
+@pytest.mark.asyncio
+async def test_search_with_advanced_options():
+    """Test search with advanced options."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    document_id = await create_test_document(runner, client)
+
+    try:
+        search_result = await async_invoke(
+            runner,
+            search,
+            "--query",
+            "AI systems",
+            "--use-hybrid-search",
+            "true",
+            "--search-strategy",
+            "vanilla",
+            "--graph-search-enabled",
+            "true",
+            "--chunk-search-enabled",
+            "true",
+            obj=client,
+        )
+        assert search_result.exit_code == 0
+        output = search_result.stdout_bytes.decode()
+        assert "Vector search results:" in output
+
+    finally:
+        pass
+
+
+@pytest.mark.asyncio
+async def test_basic_rag():
+    """Test basic RAG functionality."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    document_id = await create_test_document(runner, client)
+
+    try:
+        rag_result = await async_invoke(
+            runner,
+            rag,
+            "--query",
+            "What is this document about?",
+            obj=client,
+        )
+        assert rag_result.exit_code == 0
+
+    finally:
+        pass
+
+
+@pytest.mark.asyncio
+async def test_rag_with_streaming():
+    """Test RAG with streaming enabled."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    document_id = await create_test_document(runner, client)
+
+    try:
+        rag_result = await async_invoke(
+            runner,
+            rag,
+            "--query",
+            "What is this document about?",
+            "--stream",
+            obj=client,
+        )
+        assert rag_result.exit_code == 0
+
+    finally:
+        pass
+
+
+@pytest.mark.asyncio
+async def test_rag_with_model_specification():
+    """Test RAG with specific model."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    document_id = await create_test_document(runner, client)
+
+    try:
+        rag_result = await async_invoke(
+            runner,
+            rag,
+            "--query",
+            "What is this document about?",
+            "--rag-model",
+            "azure/gpt-4o-mini",
+            obj=client,
+        )
+        assert rag_result.exit_code == 0
+
+    finally:
+        pass

+ 338 - 0
tests/cli/commands/test_system_cli.py

@@ -0,0 +1,338 @@
+"""
+Tests for the system commands in the CLI.
+    - health
+    - settings
+    - status
+    x serve
+    x image-exists
+    x docker-down
+    x generate-report
+    x update
+    - version
+"""
+
+import json
+from importlib.metadata import version as get_version
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.system import health, settings, status, version
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
+@pytest.mark.asyncio
+async def test_health_against_server():
+    """Test health check against a real server."""
+    # Create real client
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+
+    # Run command
+    runner = CliRunner(mix_stderr=False)
+    result = await async_invoke(runner, health, obj=client)
+
+    # Extract just the JSON part (everything after the "Time taken" line)
+    output = result.stdout_bytes.decode()
+    json_str = output.split("\n", 1)[1]
+
+    # Basic validation
+    response_data = json.loads(json_str)
+    assert "results" in response_data
+    assert "message" in response_data["results"]
+    assert response_data["results"]["message"] == "ok"
+    assert result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_health_server_down():
+    """Test health check when server is unreachable."""
+    client = R2RAsyncClient(base_url="http://localhost:54321")  # Invalid port
+    runner = CliRunner(mix_stderr=False)
+
+    result = await async_invoke(runner, health, obj=client)
+    assert result.exit_code != 0
+    assert (
+        "Request failed: All connection attempts failed"
+        in result.stderr_bytes.decode()
+    )
+
+
+@pytest.mark.asyncio
+async def test_health_invalid_url():
+    """Test health check with invalid URL."""
+    client = R2RAsyncClient(base_url="http://invalid.localhost")
+    runner = CliRunner(mix_stderr=False)
+
+    result = await async_invoke(runner, health, obj=client)
+    assert result.exit_code != 0
+    assert "Request failed" in result.stderr_bytes.decode()
+
+
+@pytest.mark.asyncio
+async def test_settings_against_server():
+    """Test settings retrieval against a real server."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    result = await async_invoke(runner, settings, obj=client)
+
+    # Extract JSON part after "Time taken" line
+    output = result.stdout_bytes.decode()
+    json_str = output.split("\n", 1)[1]
+
+    # Validate response structure
+    response_data = json.loads(json_str)
+    assert "results" in response_data
+    assert "config" in response_data["results"]
+    assert "prompts" in response_data["results"]
+
+    # Validate key configuration sections
+    config = response_data["results"]["config"]
+    assert "completion" in config
+    assert "database" in config
+    assert "embedding" in config
+    assert "ingestion" in config
+
+    assert result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_settings_server_down():
+    """Test settings retrieval when server is unreachable."""
+    client = R2RAsyncClient(base_url="http://localhost:54321")  # Invalid port
+    runner = CliRunner(mix_stderr=False)
+
+    result = await async_invoke(runner, settings, obj=client)
+    assert result.exit_code != 0
+    assert (
+        "Request failed: All connection attempts failed"
+        in result.stderr_bytes.decode()
+    )
+
+
+@pytest.mark.asyncio
+async def test_settings_invalid_url():
+    """Test settings retrieval with invalid URL."""
+    client = R2RAsyncClient(base_url="http://invalid.localhost")
+    runner = CliRunner(mix_stderr=False)
+
+    result = await async_invoke(runner, settings, obj=client)
+    assert result.exit_code != 0
+    assert "Request failed" in result.stderr_bytes.decode()
+
+
+@pytest.mark.asyncio
+async def test_settings_response_structure():
+    """Test detailed structure of settings response."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    result = await async_invoke(runner, settings, obj=client)
+    output = result.stdout_bytes.decode()
+    json_str = output.split("\n", 1)[1]
+    response_data = json.loads(json_str)
+
+    # Validate prompts structure
+    prompts = response_data["results"]["prompts"]
+    assert "results" in prompts
+    assert "total_entries" in prompts
+    assert isinstance(prompts["results"], list)
+
+    # Validate prompt entries
+    for prompt in prompts["results"]:
+        assert "name" in prompt
+        assert "id" in prompt
+        assert "template" in prompt
+        assert "input_types" in prompt
+        assert "created_at" in prompt
+        assert "updated_at" in prompt
+
+    assert result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_settings_config_validation():
+    """Test specific configuration values in settings response."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    result = await async_invoke(runner, settings, obj=client)
+    output = result.stdout_bytes.decode()
+    json_str = output.split("\n", 1)[1]
+    response_data = json.loads(json_str)
+
+    config = response_data["results"]["config"]
+
+    # Validate completion config
+    completion = config["completion"]
+    assert "provider" in completion
+    assert "concurrent_request_limit" in completion
+    assert "generation_config" in completion
+
+    # Validate database config
+    database = config["database"]
+    assert "provider" in database
+    assert "default_collection_name" in database
+    assert "limits" in database
+
+    assert result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_status_against_server():
+    """Test status check against a real server."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    result = await async_invoke(runner, status, obj=client)
+
+    # Extract JSON part after "Time taken" line
+    output = result.stdout_bytes.decode()
+    json_str = output.split("\n", 1)[1]
+
+    # Validate response structure
+    response_data = json.loads(json_str)
+    assert "results" in response_data
+
+    # Validate specific fields
+    results = response_data["results"]
+    assert "start_time" in results
+    assert "uptime_seconds" in results
+    assert "cpu_usage" in results
+    assert "memory_usage" in results
+
+    # Validate data types
+    assert isinstance(results["uptime_seconds"], (int, float))
+    assert isinstance(results["cpu_usage"], (int, float))
+    assert isinstance(results["memory_usage"], (int, float))
+
+    assert result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_status_server_down():
+    """Test status check when server is unreachable."""
+    client = R2RAsyncClient(base_url="http://localhost:54321")  # Invalid port
+    runner = CliRunner(mix_stderr=False)
+
+    result = await async_invoke(runner, status, obj=client)
+    assert result.exit_code != 0
+    assert (
+        "Request failed: All connection attempts failed"
+        in result.stderr_bytes.decode()
+    )
+
+
+@pytest.mark.asyncio
+async def test_status_invalid_url():
+    """Test status check with invalid URL."""
+    client = R2RAsyncClient(base_url="http://invalid.localhost")
+    runner = CliRunner(mix_stderr=False)
+
+    result = await async_invoke(runner, status, obj=client)
+    assert result.exit_code != 0
+    assert "Request failed" in result.stderr_bytes.decode()
+
+
+@pytest.mark.asyncio
+async def test_status_value_ranges():
+    """Test that status values are within expected ranges."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    result = await async_invoke(runner, status, obj=client)
+    output = result.stdout_bytes.decode()
+    json_str = output.split("\n", 1)[1]
+    response_data = json.loads(json_str)
+
+    results = response_data["results"]
+
+    # CPU usage should be between 0 and 100
+    assert 0 <= results["cpu_usage"] <= 100
+
+    # Memory usage should be between 0 and 100
+    assert 0 <= results["memory_usage"] <= 100
+
+    # Uptime should be positive
+    assert results["uptime_seconds"] > 0
+
+    assert result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_status_start_time_format():
+    """Test that start_time is in correct ISO format."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    result = await async_invoke(runner, status, obj=client)
+    output = result.stdout_bytes.decode()
+    json_str = output.split("\n", 1)[1]
+    response_data = json.loads(json_str)
+
+    from datetime import datetime
+
+    # Verify start_time is valid ISO format
+    start_time = response_data["results"]["start_time"]
+    try:
+        datetime.fromisoformat(start_time.replace("Z", "+00:00"))
+    except ValueError:
+        pytest.fail("start_time is not in valid ISO format")
+
+    assert result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_version_command():
+    """Test basic version command functionality."""
+    runner = CliRunner()
+    result = await async_invoke(runner, version)
+
+    # Verify command succeeded
+    assert result.exit_code == 0
+
+    # Verify output is valid JSON and matches actual package version
+    expected_version = get_version("r2r")
+    actual_version = json.loads(result.stdout_bytes.decode())
+    assert actual_version == expected_version
+
+
+@pytest.mark.asyncio
+async def test_version_output_format():
+    """Test that version output is properly formatted JSON."""
+    runner = CliRunner()
+    result = await async_invoke(runner, version)
+
+    # Verify output is valid JSON
+    try:
+        output = result.stdout_bytes.decode()
+        parsed = json.loads(output)
+        assert isinstance(parsed, str)  # Version should be a string
+    except json.JSONDecodeError:
+        pytest.fail("Version output is not valid JSON")
+
+    # Should be non-empty output ending with newline
+    assert output.strip()
+    assert output.endswith("\n")
+
+
+@pytest.mark.asyncio
+async def test_version_error_handling(monkeypatch):
+    """Test error handling when version import fails."""
+
+    def mock_version(_):
+        raise ImportError("Package not found")
+
+    # Mock the version function to raise an error
+    monkeypatch.setattr("importlib.metadata.version", mock_version)
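+    # Assumes the version command resolves the package version lazily through
+    # importlib.metadata.version, so patching the module attribute takes effect.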
+
+    runner = CliRunner(mix_stderr=False)
+    result = await async_invoke(runner, version)
+
+    # Verify command failed with exception
+    assert result.exit_code == 1
+    error_output = result.stderr_bytes.decode()
+    assert "An unexpected error occurred" in error_output
+    assert "Package not found" in error_output

+ 143 - 0
tests/cli/commands/test_users_cli.py

@@ -0,0 +1,143 @@
+"""
+Tests for the user commands in the CLI.
+    - create
+    - list
+    - retrieve
+    - me
+    x list-collections
+    x add-to-collection
+    x remove-from-collection
+"""
+
+import json
+import uuid
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.users import (
+    add_to_collection,
+    create,
+    list,
+    list_collections,
+    me,
+    remove_from_collection,
+    retrieve,
+)
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
+def extract_json_block(output: str) -> dict:
+    """Extract and parse the first valid JSON object found in the output."""
+    start = output.find("{")
+    if start == -1:
+        raise ValueError("No JSON object start found in output")
+
+    brace_count = 0
+    for i, char in enumerate(output[start:], start=start):
+        if char == "{":
+            brace_count += 1
+        elif char == "}":
+            brace_count -= 1
+            if brace_count == 0:
+                json_str = output[start : i + 1].strip()
+                return json.loads(json_str)
+    raise ValueError("No complete JSON object found in output")
+
+
+@pytest.mark.asyncio
+async def test_user_lifecycle():
+    """Test the complete lifecycle of a user: create, retrieve, list, collections."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Create test user with random email
+    test_email = f"test_{uuid.uuid4()}@example.com"
+    test_password = "TestPassword123!"
+
+    # Create user
+    create_result = await async_invoke(
+        runner, create, test_email, test_password, obj=client
+    )
+    assert create_result.exit_code == 0, create_result.stdout_bytes.decode()
+
+    output = create_result.stdout_bytes.decode()
+    create_response = extract_json_block(output)
+    user_id = create_response["results"]["id"]
+
+    try:
+        # List users and verify our new user is included
+        list_result = await async_invoke(runner, list, obj=client)
+        assert list_result.exit_code == 0, list_result.stdout_bytes.decode()
+        list_output = list_result.stdout_bytes.decode()
+        assert test_email in list_output
+
+        # Retrieve specific user
+        retrieve_result = await async_invoke(
+            runner, retrieve, user_id, obj=client
+        )
+        assert (
+            retrieve_result.exit_code == 0
+        ), retrieve_result.stdout_bytes.decode()
+        retrieve_output = retrieve_result.stdout_bytes.decode()
+        retrieve_response = extract_json_block(retrieve_output)
+        assert retrieve_response["results"]["email"] == test_email
+
+        # Test me endpoint
+        me_result = await async_invoke(runner, me, obj=client)
+        assert me_result.exit_code == 0, me_result.stdout_bytes.decode()
+
+        # List collections for user
+        collections_result = await async_invoke(
+            runner, list_collections, user_id, obj=client
+        )
+        assert (
+            collections_result.exit_code == 0
+        ), collections_result.stdout_bytes.decode()
+
+    finally:
+        # We don't delete the user since there's no delete command
+        pass
+
+
+# FIXME: This should be returning 'User not found' but returns an empty list instead.
+# @pytest.mark.asyncio
+# async def test_retrieve_nonexistent_user():
+#     """Test retrieving a user that doesn't exist."""
+#     client = R2RAsyncClient(base_url="http://localhost:7272")
+#     runner = CliRunner(mix_stderr=False)
+
+#     nonexistent_id = str(uuid.uuid4())
+#     result = await async_invoke(runner, retrieve, nonexistent_id, obj=client)
+
+#     assert result.exit_code != 0
+#     error_output = result.stderr_bytes.decode()
+#     assert "User not found" in error_output
+
+
+# FIXME: This is returning with a status of 0 but has a 400 on the server side?
+# @pytest.mark.asyncio
+# async def test_create_duplicate_user():
+#     """Test creating a user with an email that already exists."""
+#     client = R2RAsyncClient(base_url="http://localhost:7272")
+#     runner = CliRunner(mix_stderr=False)
+
+#     test_email = f"test_{uuid.uuid4()}@example.com"
+#     test_password = "TestPassword123!"
+
+#     # Create first user
+#     first_result = await async_invoke(
+#         runner, create, test_email, test_password, obj=client
+#     )
+#     assert first_result.exit_code == 0
+
+#     # Try to create second user with same email
+#     second_result = await async_invoke(
+#         runner, create, test_email, test_password, obj=client
+#     )
+#     print(f"SECOND RESULT: {second_result}")
+#     assert second_result.exit_code != 0
+#     error_output = second_result.stderr_bytes.decode()
+#     assert "already exists" in error_output.lower()

+ 118 - 0
tests/cli/utils/test_timer.py

@@ -0,0 +1,118 @@
+import asyncio
+import time
+from unittest.mock import patch
+
+import asyncclick as click
+import pytest
+from click.testing import CliRunner
+
+from cli.utils.timer import timer
+from tests.cli.async_invoke import async_invoke
+
+
+@click.command()
+async def timed_command():
+    # Not prefixed with "test_" so pytest does not try to collect the click Command object.
+    with timer():
+        time.sleep(0.1)
+
+
+@pytest.mark.asyncio
+async def test_timer_measures_time():
+    runner = CliRunner()
+    result = await async_invoke(runner, timed_command)
+    output = result.stdout_bytes.decode()
+    assert "Time taken:" in output
+    assert "seconds" in output
+    measured_time = float(output.split(":")[1].split()[0])
+    assert 0.1 <= measured_time <= 0.2
+
+
+@click.command()
+async def zero_duration_command():
+    with timer():
+        pass
+
+
+@pytest.mark.asyncio
+async def test_timer_zero_duration():
+    runner = CliRunner()
+    result = await async_invoke(runner, zero_duration_command)
+    output = result.stdout_bytes.decode()
+    measured_time = float(output.split(":")[1].split()[0])
+    assert measured_time >= 0
+    assert measured_time < 0.1
+
+
+@click.command()
+async def exception_command():
+    with timer():
+        raise ValueError("Test exception")
+
+
+@pytest.mark.asyncio
+async def test_timer_with_exception():
+    runner = CliRunner()
+    result = await async_invoke(runner, exception_command)
+    assert result.exit_code != 0
+    assert isinstance(result.exception, ValueError)
+
+
+@click.command()
+async def async_command():
+    with timer():
+        await asyncio.sleep(0.1)
+
+
+@pytest.mark.asyncio
+async def test_timer_with_async_code():
+    runner = CliRunner()
+    result = await async_invoke(runner, async_command)
+    output = result.stdout_bytes.decode()
+    measured_time = float(output.split(":")[1].split()[0])
+    assert 0.1 <= measured_time <= 0.2
+
+
+@click.command()
+async def nested_command():
+    with timer():
+        time.sleep(0.1)
+        with timer():
+            time.sleep(0.1)
+
+
+@pytest.mark.asyncio
+async def test_timer_multiple_nested():
+    runner = CliRunner()
+    result = await async_invoke(runner, nested_command)
+    output = result.stdout_bytes.decode()
+    assert output.count("Time taken:") == 2
+
+
+@click.command()
+async def mock_time_command():
+    with timer():
+        pass
+
+
+@pytest.mark.asyncio
+@patch("time.time")
+async def test_timer_with_mock_time(mock_time):
+    mock_time.side_effect = [0, 1]  # Start and end times
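+    # Assumes timer() calls time.time() exactly twice (start and end); any extra
+    # call would exhaust the two-value side_effect.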
+    runner = CliRunner()
+    result = await async_invoke(runner, mock_time_command)
+    output = result.stdout_bytes.decode()
+    assert "Time taken: 1.00 seconds" in output
+
+
+@click.command()
+async def precision_command():
+    with timer():
+        time.sleep(0.1)
+
+
+@pytest.mark.asyncio
+async def test_timer_precision():
+    runner = CliRunner()
+    result = await async_invoke(runner, precision_command)
+    output = result.stdout_bytes.decode()
+    assert len(output.split(":")[1].split()[0].split(".")[1]) == 2

+ 0 - 2
tests/integration/test_conversations.py

@@ -86,7 +86,6 @@ def test_delete_non_existent_conversation(client):
     bad_id = str(uuid.uuid4())
     with pytest.raises(R2RException) as exc_info:
         result = client.conversations.delete(id=bad_id)
-        print(result)
     assert (
         exc_info.value.status_code == 404
     ), "Wrong error code for delete non-existent"
@@ -122,7 +121,6 @@ def test_update_message(client, test_conversation):
         content="Updated content",
         metadata={"new_key": "new_value"},
     )["results"]
-    print(update_resp)
     # new_branch_id = update_resp["new_branch_id"]
 
     assert update_resp["message"], "No message returned after update"

+ 187 - 0
tests/integration/test_filters.py

@@ -0,0 +1,187 @@
+import uuid
+
+import pytest
+
+from r2r import R2RException
+
+
+@pytest.fixture
+def setup_docs_with_collections(client):
+    # Create some test collections
+
+    random_suffix = str(uuid.uuid4())[:8]
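+    # Random suffix presumably keeps each run's raw_text unique, avoiding
+    # collisions with documents left over from earlier runs.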
+    coll_ids = []
+    for i in range(3):
+        resp = client.collections.create(name=f"TestColl{i}")["results"]
+        coll_ids.append(resp["id"])
+
+    # Create documents with different collection arrangements:
+    # doc1: [coll1]
+    doc1 = client.documents.create(
+        raw_text="Doc in coll1" + random_suffix, run_with_orchestration=False
+    )["results"]["document_id"]
+    client.collections.add_document(coll_ids[0], doc1)
+
+    # doc2: [coll1, coll2]
+    doc2 = client.documents.create(
+        raw_text="Doc in coll1 and coll2" + random_suffix,
+        run_with_orchestration=False,
+    )["results"]["document_id"]
+    client.collections.add_document(coll_ids[0], doc2)
+    client.collections.add_document(coll_ids[1], doc2)
+
+    # doc3: no collections
+    doc3 = client.documents.create(
+        raw_text="Doc in no collections" + random_suffix,
+        run_with_orchestration=False,
+    )["results"]["document_id"]
+
+    # doc4: [coll3]
+    doc4 = client.documents.create(
+        raw_text="Doc in coll3" + random_suffix, run_with_orchestration=False
+    )["results"]["document_id"]
+    client.collections.add_document(coll_ids[2], doc4)
+
+    yield {"coll_ids": coll_ids, "doc_ids": [doc1, doc2, doc3, doc4]}
+
+    # Cleanup
+    for d_id in [doc1, doc2, doc3, doc4]:
+        try:
+            client.documents.delete(id=d_id)
+        except R2RException:
+            pass
+    for c_id in coll_ids:
+        try:
+            client.collections.delete(c_id)
+        except R2RException:
+            pass
+
+
+def test_collection_id_eq_filter(client, setup_docs_with_collections):
+    coll_ids = setup_docs_with_collections["coll_ids"]
+    doc_ids = setup_docs_with_collections["doc_ids"]
+    doc1, doc2, doc3, doc4 = doc_ids
+
+    # collection_id = coll_ids[0] should match doc1 and doc2 only
+    filters = {"collection_id": {"$eq": coll_ids[0]}}
+    listed = client.retrieval.search(
+        query="whoami", search_settings={"filters": filters}
+    )["results"]["chunk_search_results"]
+    found_ids = {d["document_id"] for d in listed}
+    assert {
+        doc1,
+        doc2,
+    } == found_ids, f"Expected doc1 and doc2, got {found_ids}"
+
+
+def test_collection_id_ne_filter(client, setup_docs_with_collections):
+    coll_ids = setup_docs_with_collections["coll_ids"]
+    doc_ids = setup_docs_with_collections["doc_ids"]
+    doc1, doc2, doc3, doc4 = doc_ids
+
+    # collection_id != coll_ids[0] means docs that are NOT in coll0
+    # Those are doc3 (no collections) and doc4 (in coll3 only)
+    filters = {"collection_id": {"$ne": coll_ids[0]}}
+    # listed = client.documents.list(limit=10, offset=0, filters=filters)["results"]
+    listed = client.retrieval.search(
+        query="whoami", search_settings={"filters": filters}
+    )["results"]["chunk_search_results"]
+    found_ids = {d["document_id"] for d in listed}
+    assert {
+        doc3,
+        doc4,
+    } == found_ids, f"Expected doc3 and doc4, got {found_ids}"
+
+
+def test_collection_id_in_filter(client, setup_docs_with_collections):
+    coll_ids = setup_docs_with_collections["coll_ids"]
+    doc_ids = setup_docs_with_collections["doc_ids"]
+    doc1, doc2, doc3, doc4 = doc_ids
+
+    # collection_id in [coll_ids[0], coll_ids[2]] means docs in either coll0 or coll2
+    # doc1 in coll0, doc2 in coll0, doc4 in coll2
+    # doc3 is in none
+    filters = {"collection_id": {"$in": [coll_ids[0], coll_ids[2]]}}
+    listed = client.retrieval.search(
+        query="whoami", search_settings={"filters": filters}
+    )["results"]["chunk_search_results"]
+    found_ids = {d["document_id"] for d in listed}
+    assert {
+        doc1,
+        doc2,
+        doc4,
+    } == found_ids, f"Expected doc1, doc2, doc4, got {found_ids}"
+
+
+def test_collection_id_nin_filter(client, setup_docs_with_collections):
+    coll_ids = setup_docs_with_collections["coll_ids"]
+    doc_ids = setup_docs_with_collections["doc_ids"]
+    doc1, doc2, doc3, doc4 = doc_ids
+
+    # collection_id nin [coll_ids[1]] means docs that do NOT belong to coll1
+    # doc2 belongs to coll1, so exclude doc2
+    # doc1, doc3, doc4 remain
+    filters = {"collection_id": {"$nin": [coll_ids[1]]}}
+    listed = client.retrieval.search(
+        query="whoami", search_settings={"filters": filters}
+    )["results"]["chunk_search_results"]
+    found_ids = {d["document_id"] for d in listed}
+    assert {
+        doc1,
+        doc3,
+        doc4,
+    } == found_ids, f"Expected doc1, doc3, doc4, got {found_ids}"
+
+
+def test_collection_id_contains_filter(client, setup_docs_with_collections):
+    coll_ids = setup_docs_with_collections["coll_ids"]
+    doc_ids = setup_docs_with_collections["doc_ids"]
+    doc1, doc2, doc3, doc4 = doc_ids
+
+    # $contains: For a single collection_id, we interpret as arrays that must contain the given UUID.
+    # If collection_id {"$contains": "coll_ids[0]"}, docs must have coll0 in their array
+    # That would be doc1 and doc2 only
+    filters = {"collection_id": {"$contains": coll_ids[0]}}
+    listed = client.retrieval.search(
+        query="whoami", search_settings={"filters": filters}
+    )["results"]["chunk_search_results"]
+    found_ids = {d["document_id"] for d in listed}
+    assert {
+        doc1,
+        doc2,
+    } == found_ids, f"Expected doc1 and doc2, got {found_ids}"
+
+
+def test_collection_id_contains_multiple(client, setup_docs_with_collections):
+    coll_ids = setup_docs_with_collections["coll_ids"]
+    doc_ids = setup_docs_with_collections["doc_ids"]
+    doc1, doc2, doc3, doc4 = doc_ids
+
+    # If we allow $contains with a list, e.g., {"$contains": [coll_ids[0], coll_ids[1]]},
+    # this should mean the doc's collection_ids contain ALL of these.
+    # Only doc2 has coll0 AND coll1. doc1 only has coll0, doc3 no collections, doc4 only coll3.
+    filters = {"collection_id": {"$contains": [coll_ids[0], coll_ids[1]]}}
+    listed = client.retrieval.search(
+        query="whoami", search_settings={"filters": filters}
+    )["results"]["chunk_search_results"]
+    found_ids = {d["document_id"] for d in listed}
+    assert {doc2} == found_ids, f"Expected doc2 only, got {found_ids}"
+
+
+def test_delete_by_collection_id_eq(client, setup_docs_with_collections):
+    coll_ids = setup_docs_with_collections["coll_ids"]
+    doc1, doc2, doc3, doc4 = setup_docs_with_collections["doc_ids"]
+
+    # Delete documents in coll0
+    filters = {"collection_id": {"$eq": coll_ids[0]}}
+    del_resp = client.documents.delete_by_filter(filters)["results"]
+    assert del_resp["success"], "Failed to delete by collection_id $eq filter"
+
+    # doc1 and doc2 should be deleted, doc3 and doc4 remain
+    for d_id in [doc1, doc2]:
+        with pytest.raises(R2RException) as exc:
+            client.documents.retrieve(d_id)
+        assert exc.value.status_code == 404, f"Doc {d_id} still exists!"
+    # Check doc3 and doc4 still exist
+    assert client.documents.retrieve(doc3)
+    assert client.documents.retrieve(doc4)

+ 37 - 1
tests/integration/test_users.py

@@ -2,7 +2,9 @@ import uuid
 
 import pytest
 
+from core.database.postgres import PostgresUserHandler
 from r2r import R2RClient, R2RException
+from shared.abstractions import User
 
 
 @pytest.fixture(scope="session")
@@ -334,7 +336,6 @@ def test_non_owner_delete_collection(client):
     client.users.login(non_owner_email, non_owner_password)
     with pytest.raises(R2RException) as exc_info:
         result = client.collections.delete(coll_id)
-        print("result = ", result)
     assert (
         exc_info.value.status_code == 403
     ), "Wrong error code for non-owner deletion attempt"
@@ -599,3 +600,38 @@ def test_multiple_api_keys(client):
         ), f"Key {key_id} still exists after deletion"
 
     client.users.logout()
+
+
+def test_update_user_limits_overrides(client: R2RClient):
+    # 1) Create user
+    user_email = f"test_{uuid.uuid4()}@example.com"
+    client.users.register(user_email, "SomePassword123!")
+    client.users.login(user_email, "SomePassword123!")
+
+    # 2) Confirm the default overrides are empty
+    fetched_user = client.users.me()["results"]
+    client.users.logout()
+
+    assert len(fetched_user["limits_overrides"]) == 0
+
+    # 3) Update the overrides
+    overrides = {
+        "global_per_min": 10,
+        "monthly_limit": 3000,
+        "route_overrides": {
+            "/some-route": {"route_per_min": 5},
+        },
+    }
+    client.users.update(id=fetched_user["id"], limits_overrides=overrides)
+
+    # 4) Fetch user again, check
+    client.users.login(user_email, "SomePassword123!")
+    updated_user = client.users.me()["results"]
+    assert len(updated_user["limits_overrides"]) != 0
+    assert updated_user["limits_overrides"]["global_per_min"] == 10
+    assert (
+        updated_user["limits_overrides"]["route_overrides"]["/some-route"][
+            "route_per_min"
+        ]
+        == 5
+    )

+ 344 - 0
tests/unit/conftest.py

@@ -0,0 +1,344 @@
+# tests/unit/conftest.py
+import os
+from uuid import uuid4
+
+import pytest
+
+from core.base import AppConfig, DatabaseConfig, VectorQuantizationType
+from core.database.postgres import (
+    PostgresChunksHandler,
+    PostgresCollectionsHandler,
+    PostgresConnectionManager,
+    PostgresConversationsHandler,
+    PostgresDatabaseProvider,
+    PostgresDocumentsHandler,
+    PostgresGraphsHandler,
+    PostgresLimitsHandler,
+)
+from core.database.users import PostgresUserHandler
+from core.providers import NaClCryptoConfig, NaClCryptoProvider
+from core.utils import generate_user_id
+
+TEST_DB_CONNECTION_STRING = os.environ.get(
+    "TEST_DB_CONNECTION_STRING",
+    "postgresql://postgres:postgres@localhost:5432/test_db",
+)
+
+
+@pytest.fixture
+async def db_provider():
+    crypto_provider = NaClCryptoProvider(NaClCryptoConfig(app={}))
+    db_config = DatabaseConfig(
+        app=AppConfig(project_name="test_project"),
+        provider="postgres",
+        connection_string=TEST_DB_CONNECTION_STRING,
+        postgres_configuration_settings={
+            "max_connections": 10,
+            "statement_cache_size": 100,
+        },
+        project_name="test_project",
+    )
+
+    dimension = 4
+    quantization_type = VectorQuantizationType.FP32
+
+    db_provider = PostgresDatabaseProvider(
+        db_config, dimension, crypto_provider, quantization_type
+    )
+
+    await db_provider.initialize()
+    yield db_provider
+    # Teardown: close the provider's connections
+    await db_provider.close()
+
+
+@pytest.fixture
+def crypto_provider():
+    # Provide a crypto provider fixture if needed separately
+    return NaClCryptoProvider(NaClCryptoConfig(app={}))
+
+
+@pytest.fixture
+async def chunks_handler(db_provider):
+    dimension = db_provider.dimension
+    quantization_type = db_provider.quantization_type
+    project_name = db_provider.project_name
+    connection_manager = db_provider.connection_manager
+
+    handler = PostgresChunksHandler(
+        project_name=project_name,
+        connection_manager=connection_manager,
+        dimension=dimension,
+        quantization_type=quantization_type,
+    )
+    await handler.create_tables()
+    return handler
+
+
+@pytest.fixture
+async def collections_handler(db_provider):
+    project_name = db_provider.project_name
+    connection_manager = db_provider.connection_manager
+    config = db_provider.config
+
+    handler = PostgresCollectionsHandler(
+        project_name=project_name,
+        connection_manager=connection_manager,
+        config=config,
+    )
+    await handler.create_tables()
+    return handler
+
+
+@pytest.fixture
+async def conversations_handler(db_provider):
+    project_name = db_provider.project_name
+    connection_manager = db_provider.connection_manager
+
+    handler = PostgresConversationsHandler(project_name, connection_manager)
+    await handler.create_tables()
+    return handler
+
+
+@pytest.fixture
+async def documents_handler(db_provider):
+    dimension = db_provider.dimension
+    project_name = db_provider.project_name
+    connection_manager = db_provider.connection_manager
+
+    handler = PostgresDocumentsHandler(
+        project_name=project_name,
+        connection_manager=connection_manager,
+        dimension=dimension,
+    )
+    await handler.create_tables()
+    return handler
+
+
+@pytest.fixture
+async def graphs_handler(db_provider):
+    project_name = db_provider.project_name
+    connection_manager = db_provider.connection_manager
+    dimension = db_provider.dimension
+    quantization_type = db_provider.quantization_type
+
+    # If collections_handler is needed, you can depend on the collections_handler fixture
+    # or pass None if it's optional.
+    handler = PostgresGraphsHandler(
+        project_name=project_name,
+        connection_manager=connection_manager,
+        dimension=dimension,
+        quantization_type=quantization_type,
+        collections_handler=None,  # if needed, or await collections_handler fixture
+    )
+    await handler.create_tables()
+    return handler
+
+
+@pytest.fixture
+async def limits_handler(db_provider):
+    project_name = db_provider.project_name
+    connection_manager = db_provider.connection_manager
+    config = db_provider.config
+
+    handler = PostgresLimitsHandler(
+        project_name=project_name,
+        connection_manager=connection_manager,
+        config=config,
+    )
+    await handler.create_tables()
+    # Optionally truncate
+    await connection_manager.execute_query(
+        f"TRUNCATE {handler._get_table_name('request_log')};"
+    )
+    return handler
+
+
+@pytest.fixture
+async def users_handler(db_provider, crypto_provider):
+    project_name = db_provider.project_name
+    connection_manager = db_provider.connection_manager
+
+    handler = PostgresUserHandler(
+        project_name=project_name,
+        connection_manager=connection_manager,
+        crypto_provider=crypto_provider,
+    )
+    await handler.create_tables()
+
+    # Optionally clean up users table before each test
+    await connection_manager.execute_query(
+        f"TRUNCATE {handler._get_table_name('users')} CASCADE;"
+    )
+    await connection_manager.execute_query(
+        f"TRUNCATE {handler._get_table_name('users_api_keys')} CASCADE;"
+    )
+
+    return handler

+ 315 - 0
tests/unit/test_chunks.py

@@ -0,0 +1,315 @@
+# tests/unit/test_chunks.py
+import asyncio
+import uuid
+from typing import AsyncGenerator, Optional, Tuple
+
+import pytest
+
+from r2r import R2RAsyncClient, R2RException
+
+
+class AsyncR2RTestClient:
+    """Wrapper to ensure async operations use the correct event loop"""
+
+    def __init__(self, base_url: str = "http://localhost:7272"):
+        self.client = R2RAsyncClient(base_url)
+
+    async def create_document(
+        self, chunks: list[str], run_with_orchestration: bool = False
+    ) -> Tuple[str, list[dict]]:
+        response = await self.client.documents.create(
+            chunks=chunks, run_with_orchestration=run_with_orchestration
+        )
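+        # Chunk payloads are fetched separately via list_chunks; the second
+        # tuple element is just a placeholder.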
+        return response["results"]["document_id"], []
+
+    async def delete_document(self, doc_id: str) -> None:
+        await self.client.documents.delete(id=doc_id)
+
+    async def list_chunks(self, doc_id: str) -> list[dict]:
+        response = await self.client.documents.list_chunks(id=doc_id)
+        return response["results"]
+
+    async def retrieve_chunk(self, chunk_id: str) -> dict:
+        response = await self.client.chunks.retrieve(id=chunk_id)
+        return response["results"]
+
+    async def update_chunk(
+        self, chunk_id: str, text: str, metadata: Optional[dict] = None
+    ) -> dict:
+        response = await self.client.chunks.update(
+            {"id": chunk_id, "text": text, "metadata": metadata or {}}
+        )
+        return response["results"]
+
+    async def delete_chunk(self, chunk_id: str) -> dict:
+        response = await self.client.chunks.delete(id=chunk_id)
+        return response["results"]
+
+    async def search_chunks(self, query: str, limit: int = 5) -> list[dict]:
+        response = await self.client.chunks.search(
+            query=query, search_settings={"limit": limit}
+        )
+        return response["results"]
+
+    async def register_user(self, email: str, password: str) -> None:
+        await self.client.users.register(email, password)
+
+    async def login_user(self, email: str, password: str) -> None:
+        await self.client.users.login(email, password)
+
+    async def logout_user(self) -> None:
+        await self.client.users.logout()
+
+
+@pytest.fixture
+async def test_client() -> AsyncGenerator[AsyncR2RTestClient, None]:
+    """Create a test client."""
+    client = AsyncR2RTestClient()
+    yield client
+
+
+@pytest.fixture
+async def test_document(
+    test_client: AsyncR2RTestClient,
+) -> AsyncGenerator[Tuple[str, list[dict]], None]:
+    """Create a test document with chunks."""
+    doc_id, _ = await test_client.create_document(
+        ["Test chunk 1", "Test chunk 2"]
+    )
+    await asyncio.sleep(1)  # Wait for ingestion
+    chunks = await test_client.list_chunks(doc_id)
+    yield doc_id, chunks
+    try:
+        await test_client.delete_document(doc_id)
+    except R2RException:
+        pass
+
+
+class TestChunks:
+    @pytest.mark.asyncio
+    async def test_create_and_list_chunks(
+        self, test_client: AsyncR2RTestClient
+    ):
+        # Create document with chunks
+        doc_id, _ = await test_client.create_document(
+            ["Hello chunk", "World chunk"]
+        )
+        await asyncio.sleep(1)  # Wait for ingestion
+
+        # List and verify chunks
+        chunks = await test_client.list_chunks(doc_id)
+        assert len(chunks) == 2, "Expected 2 chunks in the document"
+
+        # Cleanup
+        await test_client.delete_document(doc_id)
+
+    @pytest.mark.asyncio
+    async def test_retrieve_chunk(
+        self, test_client: AsyncR2RTestClient, test_document
+    ):
+        doc_id, chunks = test_document
+        chunk_id = chunks[0]["id"]
+
+        retrieved = await test_client.retrieve_chunk(chunk_id)
+        assert retrieved["id"] == chunk_id, "Retrieved wrong chunk ID"
+        assert retrieved["text"] == "Test chunk 1", "Chunk text mismatch"
+
+    @pytest.mark.asyncio
+    async def test_update_chunk(
+        self, test_client: AsyncR2RTestClient, test_document
+    ):
+        doc_id, chunks = test_document
+        chunk_id = chunks[0]["id"]
+
+        # Update chunk
+        updated = await test_client.update_chunk(
+            chunk_id, "Updated text", {"version": 2}
+        )
+        assert updated["text"] == "Updated text", "Chunk text not updated"
+        assert updated["metadata"]["version"] == 2, "Metadata not updated"
+
+    @pytest.mark.asyncio
+    async def test_delete_chunk(
+        self, test_client: AsyncR2RTestClient, test_document
+    ):
+        doc_id, chunks = test_document
+        chunk_id = chunks[0]["id"]
+
+        # Delete and verify
+        result = await test_client.delete_chunk(chunk_id)
+        assert result["success"], "Chunk deletion failed"
+
+        # Verify deletion
+        with pytest.raises(R2RException) as exc_info:
+            await test_client.retrieve_chunk(chunk_id)
+        assert exc_info.value.status_code == 404
+
+    @pytest.mark.asyncio
+    async def test_search_chunks(self, test_client: AsyncR2RTestClient):
+        # Create searchable document
+        doc_id, _ = await test_client.create_document(
+            ["Aristotle reference", "Another piece of text"]
+        )
+        await asyncio.sleep(1)  # Wait for indexing
+
+        # Search
+        results = await test_client.search_chunks("Aristotle")
+        assert len(results) > 0, "No search results found"
+
+        # Cleanup
+        await test_client.delete_document(doc_id)
+
+    @pytest.mark.asyncio
+    async def test_unauthorized_chunk_access(
+        self, test_client: AsyncR2RTestClient, test_document
+    ):
+        doc_id, chunks = test_document
+        chunk_id = chunks[0]["id"]
+
+        # Create and login as different user
+        non_owner_client = AsyncR2RTestClient()
+        email = f"test_{uuid.uuid4()}@example.com"
+        await non_owner_client.register_user(email, "password123")
+        await non_owner_client.login_user(email, "password123")
+
+        # Attempt unauthorized access
+        with pytest.raises(R2RException) as exc_info:
+            await non_owner_client.retrieve_chunk(chunk_id)
+        assert exc_info.value.status_code == 403
+
+    @pytest.mark.asyncio
+    async def test_list_chunks_with_filters(
+        self, test_client: AsyncR2RTestClient
+    ):
+        """Test listing chunks with owner_id filter."""
+        # Create and login as temporary user
+        temp_email = f"{uuid.uuid4()}@example.com"
+        await test_client.register_user(temp_email, "password123")
+        await test_client.login_user(temp_email, "password123")
+
+        doc_id = None
+        try:
+            # Create a document with chunks
+            doc_id, _ = await test_client.create_document(
+                ["Test chunk 1", "Test chunk 2"]
+            )
+            await asyncio.sleep(1)  # Wait for ingestion
+
+            # Test listing chunks (filters automatically applied on server)
+            response = await test_client.client.chunks.list(offset=0, limit=1)
+
+            assert "results" in response, "Expected 'results' in response"
+            # assert "page_info" in response, "Expected 'page_info' in response"
+            assert (
+                len(response["results"]) <= 1
+            ), "Expected at most 1 result due to limit"
+
+            if len(response["results"]) > 0:
+                # Verify we only get chunks owned by our temp user
+                chunk = response["results"][0]
+                chunks = await test_client.list_chunks(doc_id)
+                assert chunk["owner_id"] in [
+                    c["owner_id"] for c in chunks
+                ], "Got chunk from wrong owner"
+
+        finally:
+            # Cleanup
+            if doc_id:
+                try:
+                    await test_client.delete_document(doc_id)
+                except R2RException:
+                    pass
+            await test_client.logout_user()
+
+    @pytest.mark.asyncio
+    async def test_list_chunks_pagination(
+        self, test_client: AsyncR2RTestClient
+    ):
+        """Test chunk listing with pagination."""
+        # Create and login as temporary user
+        temp_email = f"{uuid.uuid4()}@example.com"
+        await test_client.register_user(temp_email, "password123")
+        await test_client.login_user(temp_email, "password123")
+
+        doc_id = None
+        try:
+            # Create a document with multiple chunks
+            chunks = [f"Test chunk {i}" for i in range(5)]
+            doc_id, _ = await test_client.create_document(chunks)
+            await asyncio.sleep(1)  # Wait for ingestion
+
+            # Test first page
+            response1 = await test_client.client.chunks.list(offset=0, limit=2)
+
+            assert (
+                len(response1["results"]) == 2
+            ), "Expected 2 results on first page"
+            # assert response1["page_info"]["has_next"], "Expected more pages"
+
+            # Test second page
+            response2 = await test_client.client.chunks.list(offset=2, limit=2)
+
+            assert (
+                len(response2["results"]) == 2
+            ), "Expected 2 results on second page"
+
+            # Verify no duplicate results
+            ids_page1 = {chunk["id"] for chunk in response1["results"]}
+            ids_page2 = {chunk["id"] for chunk in response2["results"]}
+            assert not ids_page1.intersection(
+                ids_page2
+            ), "Found duplicate chunks across pages"
+
+        finally:
+            # Cleanup
+            if doc_id:
+                try:
+                    await test_client.delete_document(doc_id)
+                except R2RException:
+                    pass
+            await test_client.logout_user()
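+
+    # A hedged sketch of exhaustive pagination built on the same list API
+    # (offsets advance until a short page signals the end):
+    #
+    #     async def iter_all_chunks(client, page_size=100):
+    #         offset = 0
+    #         while True:
+    #             page = await client.client.chunks.list(
+    #                 offset=offset, limit=page_size
+    #             )
+    #             for chunk in page["results"]:
+    #                 yield chunk
+    #             if len(page["results"]) < page_size:
+    #                 break
+    #             offset += page_size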
+
+    @pytest.mark.asyncio
+    async def test_list_chunks_with_multiple_documents(
+        self, test_client: AsyncR2RTestClient
+    ):
+        """Test listing chunks across multiple documents."""
+        # Create and login as temporary user
+        temp_email = f"{uuid.uuid4()}@example.com"
+        await test_client.register_user(temp_email, "password123")
+        await test_client.login_user(temp_email, "password123")
+
+        doc_ids = []
+        try:
+            # Create multiple documents
+            for i in range(2):
+                doc_id, _ = await test_client.create_document(
+                    [f"Doc {i} chunk 1", f"Doc {i} chunk 2"]
+                )
+                doc_ids.append(doc_id)
+
+            await asyncio.sleep(1)  # Wait for ingestion
+
+            # List all chunks
+            response = await test_client.client.chunks.list(offset=0, limit=10)
+
+            assert len(response["results"]) == 4, "Expected 4 total chunks"
+
+            # Verify all chunks belong to our documents
+            chunk_doc_ids = {
+                chunk["document_id"] for chunk in response["results"]
+            }
+            assert all(
+                str(doc_id) in chunk_doc_ids for doc_id in doc_ids
+            ), "Expected chunks from every created document"
+
+        finally:
+            # Cleanup
+            for doc_id in doc_ids:
+                try:
+                    await test_client.delete_document(doc_id)
+                except R2RException:
+                    pass
+            await test_client.logout_user()
+
+
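+# Running this module directly is equivalent to something like
+# `pytest <this file> -v --asyncio-mode=auto`; the concrete path depends on
+# where the suite lives in the repository.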
+if __name__ == "__main__":
+    pytest.main(["-v", "--asyncio-mode=auto"])

+ 205 - 0
tests/unit/test_collections.py

@@ -0,0 +1,205 @@
+import pytest
+import uuid
+from uuid import UUID
+from core.base.api.models import CollectionResponse
+from core.base import R2RException
+
+
+@pytest.mark.asyncio
+async def test_create_collection(collections_handler):
+    owner_id = uuid.uuid4()
+    resp = await collections_handler.create_collection(
+        owner_id=owner_id,
+        name="Test Collection",
+        description="A test collection",
+    )
+    assert isinstance(resp, CollectionResponse)
+    assert resp.name == "Test Collection"
+    assert resp.owner_id == owner_id
+    assert resp.description == "A test collection"
+
+
+@pytest.mark.asyncio
+async def test_create_collection_default_name(collections_handler):
+    owner_id = uuid.uuid4()
+    # If no name provided, should use default_collection_name from config
+    resp = await collections_handler.create_collection(owner_id=owner_id)
+    assert isinstance(resp, CollectionResponse)
+    assert resp.name is not None  # default collection name should be set
+    assert resp.owner_id == owner_id
+
+
+@pytest.mark.asyncio
+async def test_update_collection(collections_handler):
+    owner_id = uuid.uuid4()
+    coll = await collections_handler.create_collection(
+        owner_id=owner_id, name="Original Name", description="Original Desc"
+    )
+
+    updated = await collections_handler.update_collection(
+        collection_id=coll.id,
+        name="Updated Name",
+        description="New Description",
+    )
+    assert updated.name == "Updated Name"
+    assert updated.description == "New Description"
+    # user_count and document_count should be integers
+    assert isinstance(updated.user_count, int)
+    assert isinstance(updated.document_count, int)
+
+
+@pytest.mark.asyncio
+async def test_update_collection_no_fields(collections_handler):
+    owner_id = uuid.uuid4()
+    coll = await collections_handler.create_collection(
+        owner_id=owner_id, name="NoUpdate", description="No Update"
+    )
+
+    with pytest.raises(R2RException) as exc:
+        await collections_handler.update_collection(collection_id=coll.id)
+    assert exc.value.status_code == 400
+
+
+@pytest.mark.asyncio
+async def test_delete_collection_relational(collections_handler):
+    owner_id = uuid.uuid4()
+    coll = await collections_handler.create_collection(
+        owner_id=owner_id, name="ToDelete"
+    )
+
+    # Confirm existence
+    exists = await collections_handler.collection_exists(coll.id)
+    assert exists is True
+
+    await collections_handler.delete_collection_relational(coll.id)
+
+    exists = await collections_handler.collection_exists(coll.id)
+    assert exists is False
+
+
+@pytest.mark.asyncio
+async def test_collection_exists(collections_handler):
+    owner_id = uuid.uuid4()
+    coll = await collections_handler.create_collection(owner_id=owner_id)
+    assert await collections_handler.collection_exists(coll.id) is True
+
+
+@pytest.mark.asyncio
+async def test_documents_in_collection(collections_handler, db_provider):
+    # Create a collection
+    owner_id = uuid.uuid4()
+    coll = await collections_handler.create_collection(
+        owner_id=owner_id, name="DocCollection"
+    )
+
+    # Insert some documents related to this collection
+    # We'll directly insert into the documents table for simplicity
+    doc_id = uuid.uuid4()
+    insert_doc_query = f"""
+        INSERT INTO {db_provider.project_name}.documents (id, collection_ids, owner_id, type, metadata, title, version, size_in_bytes, ingestion_status, extraction_status)
+        VALUES ($1, $2, $3, 'txt', '{{}}', 'Test Doc', 'v1', 1234, 'pending', 'pending')
+    """
+    await db_provider.connection_manager.execute_query(
+        insert_doc_query, [doc_id, [coll.id], owner_id]
+    )
+
+    # Now fetch documents in collection
+    res = await collections_handler.documents_in_collection(
+        coll.id, offset=0, limit=10
+    )
+    assert len(res["results"]) == 1
+    assert res["total_entries"] == 1
+    assert res["results"][0].id == doc_id
+    assert res["results"][0].title == "Test Doc"
+
+
+@pytest.mark.asyncio
+async def test_get_collections_overview(collections_handler, db_provider):
+    owner_id = uuid.uuid4()
+    coll1 = await collections_handler.create_collection(
+        owner_id=owner_id, name="Overview1"
+    )
+    coll2 = await collections_handler.create_collection(
+        owner_id=owner_id, name="Overview2"
+    )
+
+    overview = await collections_handler.get_collections_overview(
+        offset=0, limit=10
+    )
+    # There should be at least these two
+    ids = [c.id for c in overview["results"]]
+    assert coll1.id in ids
+    assert coll2.id in ids
+
+
+@pytest.mark.asyncio
+async def test_assign_document_to_collection_relational(
+    collections_handler, db_provider
+):
+    owner_id = uuid.uuid4()
+    coll = await collections_handler.create_collection(
+        owner_id=owner_id, name="Assign"
+    )
+
+    # Insert a doc
+    doc_id = uuid.uuid4()
+    insert_doc_query = f"""
+        INSERT INTO {db_provider.project_name}.documents (id, owner_id, type, metadata, title, version, size_in_bytes, ingestion_status, extraction_status, collection_ids)
+        VALUES ($1, $2, 'txt', '{{}}', 'Standalone Doc', 'v1', 10, 'pending', 'pending', ARRAY[]::uuid[])
+    """
+    await db_provider.connection_manager.execute_query(
+        insert_doc_query, [doc_id, owner_id]
+    )
+
+    # Assign this doc to the collection
+    await collections_handler.assign_document_to_collection_relational(
+        doc_id, coll.id
+    )
+
+    # Verify doc is now in collection
+    docs = await collections_handler.documents_in_collection(
+        coll.id, offset=0, limit=10
+    )
+    assert len(docs["results"]) == 1
+    assert docs["results"][0].id == doc_id
+
+
+@pytest.mark.asyncio
+async def test_remove_document_from_collection_relational(
+    collections_handler, db_provider
+):
+    owner_id = uuid.uuid4()
+    coll = await collections_handler.create_collection(
+        owner_id=owner_id, name="RemoveDoc"
+    )
+
+    # Insert a doc already in collection
+    doc_id = uuid.uuid4()
+    insert_doc_query = f"""
+        INSERT INTO {db_provider.project_name}.documents
+        (id, owner_id, type, metadata, title, version, size_in_bytes, ingestion_status, extraction_status, collection_ids)
+        VALUES ($1, $2, 'txt', '{{}}'::jsonb, 'Another Doc', 'v1', 10, 'pending', 'pending', $3)
+    """
+    await db_provider.connection_manager.execute_query(
+        insert_doc_query, [doc_id, owner_id, [coll.id]]
+    )
+
+    # Remove it
+    await collections_handler.remove_document_from_collection_relational(
+        doc_id, coll.id
+    )
+
+    docs = await collections_handler.documents_in_collection(
+        coll.id, offset=0, limit=10
+    )
+    assert len(docs["results"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_delete_nonexistent_collection(collections_handler):
+    non_existent_id = uuid.uuid4()
+    with pytest.raises(R2RException) as exc:
+        await collections_handler.delete_collection_relational(non_existent_id)
+    assert (
+        exc.value.status_code == 404
+    ), "Should raise 404 for non-existing collection"

+ 132 - 0
tests/unit/test_conversations.py

@@ -0,0 +1,132 @@
+import json
+import uuid
+from uuid import UUID
+
+import pytest
+
+from core.base import Message, R2RException
+from shared.api.models.management.responses import (
+    ConversationResponse,
+    MessageResponse,
+)
+
+
+@pytest.mark.asyncio
+async def test_create_conversation(conversations_handler):
+    resp = await conversations_handler.create_conversation()
+    assert isinstance(resp, ConversationResponse)
+    assert resp.id is not None
+    assert resp.created_at is not None
+
+
+@pytest.mark.asyncio
+async def test_create_conversation_with_user_and_name(conversations_handler):
+    user_id = uuid.uuid4()
+    resp = await conversations_handler.create_conversation(
+        user_id=user_id, name="Test Conv"
+    )
+    assert resp.id is not None
+    assert resp.created_at is not None
+    # ConversationResponse carries no user_id field, so ownership cannot be
+    # asserted directly from the response; see the sketch below.
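+    # A hedged sketch of how ownership could be asserted if the handler
+    # grew an overview query (method name and parameter are hypothetical):
+    #
+    #     overview = await conversations_handler.get_conversations_overview(
+    #         offset=0, limit=10, user_ids=[user_id]
+    #     )
+    #     assert any(c.id == resp.id for c in overview["results"])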
+
+
+@pytest.mark.asyncio
+async def test_add_message(conversations_handler):
+    conv = await conversations_handler.create_conversation()
+    conv_id = conv.id
+
+    msg = Message(role="user", content="Hello!")
+    resp = await conversations_handler.add_message(conv_id, msg)
+    assert isinstance(resp, MessageResponse)
+    assert resp.id is not None
+    assert resp.message.content == "Hello!"
+
+
+@pytest.mark.asyncio
+async def test_add_message_with_parent(conversations_handler):
+    conv = await conversations_handler.create_conversation()
+    conv_id = conv.id
+
+    parent_msg = Message(role="user", content="Parent message")
+    parent_resp = await conversations_handler.add_message(conv_id, parent_msg)
+    parent_id = parent_resp.id
+
+    child_msg = Message(role="assistant", content="Child reply")
+    child_resp = await conversations_handler.add_message(
+        conv_id, child_msg, parent_id=parent_id
+    )
+    assert child_resp.id is not None
+    assert child_resp.message.content == "Child reply"
+
+
+@pytest.mark.asyncio
+async def test_edit_message(conversations_handler):
+    conv = await conversations_handler.create_conversation()
+    conv_id = conv.id
+
+    original_msg = Message(role="user", content="Original")
+    resp = await conversations_handler.add_message(conv_id, original_msg)
+    msg_id = resp.id
+
+    updated = await conversations_handler.edit_message(
+        msg_id, "Edited content"
+    )
+    assert updated["message"].content == "Edited content"
+    assert updated["metadata"]["edited"] is True
+
+
+@pytest.mark.asyncio
+async def test_update_message_metadata(conversations_handler):
+    conv = await conversations_handler.create_conversation()
+    conv_id = conv.id
+
+    msg = Message(role="user", content="Meta-test")
+    resp = await conversations_handler.add_message(conv_id, msg)
+    msg_id = resp.id
+
+    await conversations_handler.update_message_metadata(
+        msg_id, {"test_key": "test_value"}
+    )
+
+    # Verify the metadata was updated; fail loudly if the message is missing
+    full_conversation = await conversations_handler.get_conversation(conv_id)
+    for m in full_conversation:
+        if m.id == str(msg_id):
+            assert m.metadata["test_key"] == "test_value"
+            break
+    else:
+        pytest.fail(f"Message {msg_id} not found in conversation")
+
+
+@pytest.mark.asyncio
+async def test_get_conversation(conversations_handler):
+    conv = await conversations_handler.create_conversation()
+    conv_id = conv.id
+
+    msg1 = Message(role="user", content="Msg1")
+    msg2 = Message(role="assistant", content="Msg2")
+
+    await conversations_handler.add_message(conv_id, msg1)
+    await conversations_handler.add_message(conv_id, msg2)
+
+    messages = await conversations_handler.get_conversation(conv_id)
+    assert len(messages) == 2
+    assert messages[0].message.content == "Msg1"
+    assert messages[1].message.content == "Msg2"
+
+
+@pytest.mark.asyncio
+async def test_delete_conversation(conversations_handler):
+    conv = await conversations_handler.create_conversation()
+    conv_id = conv.id
+
+    msg = Message(role="user", content="To be deleted")
+    await conversations_handler.add_message(conv_id, msg)
+
+    await conversations_handler.delete_conversation(conv_id)
+
+    with pytest.raises(R2RException) as exc:
+        await conversations_handler.get_conversation(conv_id)
+    assert (
+        exc.value.status_code == 404
+    ), "Conversation should be deleted and not found"

+ 143 - 0
tests/unit/test_documents.py

@@ -0,0 +1,143 @@
+import json
+import uuid
+from uuid import UUID
+
+import pytest
+
+from core.base import (
+    DocumentResponse,
+    DocumentType,
+    IngestionStatus,
+    KGExtractionStatus,
+    R2RException,
+    SearchSettings,
+)
+
+
+def make_db_entry(doc: DocumentResponse):
+    # Simulates the row dict that the documents handler writes to the documents table.
+    return {
+        "id": doc.id,
+        "collection_ids": doc.collection_ids,
+        "owner_id": doc.owner_id,
+        "document_type": doc.document_type.value,
+        "metadata": json.dumps(doc.metadata),
+        "title": doc.title,
+        "version": doc.version,
+        "size_in_bytes": doc.size_in_bytes,
+        "ingestion_status": doc.ingestion_status.value,
+        "extraction_status": doc.extraction_status.value,
+        "created_at": doc.created_at,
+        "updated_at": doc.updated_at,
+        "ingestion_attempt_number": 0,
+        "summary": doc.summary,
+        # If summary_embedding is a list, we can store it as a string here if needed
+        "summary_embedding": (
+            str(doc.summary_embedding)
+            if doc.summary_embedding is not None
+            else None
+        ),
+    }
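+
+# `make_db_entry` is illustrative; none of the tests below call it. A hedged
+# usage sketch, assuming a raw insert through the same connection manager the
+# other unit tests use (the SQL string and fixture names are assumptions):
+#
+#     entry = make_db_entry(doc)
+#     await db_provider.connection_manager.execute_query(
+#         insert_sql, list(entry.values())
+#     )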
+
+
+@pytest.mark.asyncio
+async def test_upsert_documents_overview_insert(documents_handler):
+    doc_id = uuid.uuid4()
+    doc = DocumentResponse(
+        id=doc_id,
+        collection_ids=[],
+        owner_id=uuid.uuid4(),
+        document_type=DocumentType.TXT,
+        metadata={"description": "A test document"},
+        title="Test Doc",
+        version="v1",
+        size_in_bytes=1234,
+        ingestion_status=IngestionStatus.PENDING,
+        extraction_status=KGExtractionStatus.PENDING,
+        created_at=None,
+        updated_at=None,
+        summary=None,
+        summary_embedding=None,
+    )
+
+    # Upsert the document overview; the handler accepts a list of documents
+    await documents_handler.upsert_documents_overview([doc])
+
+    # Verify
+    res = await documents_handler.get_documents_overview(
+        offset=0, limit=10, filter_document_ids=[doc_id]
+    )
+    assert res["total_entries"] == 1
+    fetched_doc = res["results"][0]
+    assert fetched_doc.id == doc_id
+    assert fetched_doc.title == "Test Doc"
+    assert fetched_doc.metadata["description"] == "A test document"
+
+
+@pytest.mark.asyncio
+async def test_upsert_documents_overview_update(documents_handler):
+    doc_id = uuid.uuid4()
+    owner_id = uuid.uuid4()
+    doc = DocumentResponse(
+        id=doc_id,
+        collection_ids=[],
+        owner_id=owner_id,
+        document_type=DocumentType.TXT,
+        metadata={"note": "initial"},
+        title="Initial Title",
+        version="v1",
+        size_in_bytes=100,
+        ingestion_status=IngestionStatus.PENDING,
+        extraction_status=KGExtractionStatus.PENDING,
+        created_at=None,
+        updated_at=None,
+        summary=None,
+        summary_embedding=None,
+    )
+
+    await documents_handler.upsert_documents_overview([doc])
+
+    # Update document
+    doc.title = "Updated Title"
+    doc.metadata["note"] = "updated"
+
+    await documents_handler.upsert_documents_overview([doc])
+
+    # Verify update
+    res = await documents_handler.get_documents_overview(
+        offset=0, limit=10, filter_document_ids=[doc_id]
+    )
+    fetched_doc = res["results"][0]
+    assert fetched_doc.title == "Updated Title"
+    assert fetched_doc.metadata["note"] == "updated"
+
+
+@pytest.mark.asyncio
+async def test_delete_document(documents_handler):
+    doc_id = uuid.uuid4()
+    doc = DocumentResponse(
+        id=doc_id,
+        collection_ids=[],
+        owner_id=uuid.uuid4(),
+        document_type=DocumentType.TXT,
+        metadata={},
+        title="ToDelete",
+        version="v1",
+        size_in_bytes=100,
+        ingestion_status=IngestionStatus.PENDING,
+        extraction_status=KGExtractionStatus.PENDING,
+        created_at=None,
+        updated_at=None,
+        summary=None,
+        summary_embedding=None,
+    )
+
+    await documents_handler.upsert_documents_overview([doc])
+    await documents_handler.delete(doc_id)
+    res = await documents_handler.get_documents_overview(
+        offset=0, limit=10, filter_document_ids=[doc_id]
+    )
+    assert res["total_entries"] == 0

+ 449 - 0
tests/unit/test_graphs.py

@@ -0,0 +1,449 @@
+import pytest
+import uuid
+from uuid import UUID
+
+from enum import Enum
+from core.base.abstractions import Entity, Relationship, Community
+from core.base.api.models import GraphResponse
+
+
+class StoreType(str, Enum):
+    GRAPHS = "graphs"
+    DOCUMENTS = "documents"
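+
+# NOTE: StoreType is declared locally so the module stays self-contained;
+# only the string values matter, since the handlers below are always called
+# with `StoreType.GRAPHS.value`. The core codebase presumably exposes an
+# equivalent enum.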
+
+
+@pytest.mark.asyncio
+async def test_create_graph(graphs_handler):
+    coll_id = uuid.uuid4()
+    resp = await graphs_handler.create(
+        collection_id=coll_id, name="My Graph", description="Test Graph"
+    )
+    assert isinstance(resp, GraphResponse)
+    assert resp.name == "My Graph"
+    assert resp.collection_id == coll_id
+
+
+@pytest.mark.asyncio
+async def test_add_entities_and_relationships(graphs_handler):
+    # Create a graph
+    coll_id = uuid.uuid4()
+    graph_resp = await graphs_handler.create(
+        collection_id=coll_id, name="TestGraph"
+    )
+    graph_id = graph_resp.id
+
+    # Add an entity
+    entity = await graphs_handler.entities.create(
+        parent_id=graph_id,
+        store_type=StoreType.GRAPHS.value,
+        name="TestEntity",
+        category="Person",
+        description="A test entity",
+    )
+    assert entity.name == "TestEntity"
+
+    # Add another entity
+    entity2 = await graphs_handler.entities.create(
+        parent_id=graph_id,
+        store_type=StoreType.GRAPHS.value,
+        name="AnotherEntity",
+        category="Place",
+        description="A test place",
+    )
+
+    # Add a relationship between them
+    rel = await graphs_handler.relationships.create(
+        subject="TestEntity",
+        subject_id=entity.id,
+        predicate="lives_in",
+        object="AnotherEntity",
+        object_id=entity2.id,
+        parent_id=graph_id,
+        store_type=StoreType.GRAPHS.value,
+        description="Entity lives in AnotherEntity",
+    )
+    assert rel.predicate == "lives_in"
+
+    # Verify entities retrieval
+    ents, total_ents = await graphs_handler.get_entities(
+        parent_id=graph_id, offset=0, limit=10
+    )
+    assert total_ents == 2
+    names = [e.name for e in ents]
+    assert "TestEntity" in names and "AnotherEntity" in names
+
+    # Verify relationships retrieval
+    rels, total_rels = await graphs_handler.get_relationships(
+        parent_id=graph_id, offset=0, limit=10
+    )
+    assert total_rels == 1
+    assert rels[0].predicate == "lives_in"
+
+
+@pytest.mark.asyncio
+async def test_delete_entities_and_relationships(graphs_handler):
+    # Create another graph
+    coll_id = uuid.uuid4()
+    graph_resp = await graphs_handler.create(
+        collection_id=coll_id, name="DeletableGraph"
+    )
+    graph_id = graph_resp.id
+
+    # Add entities
+    e1 = await graphs_handler.entities.create(
+        parent_id=graph_id,
+        store_type=StoreType.GRAPHS.value,
+        name="DeleteMe",
+    )
+    e2 = await graphs_handler.entities.create(
+        parent_id=graph_id,
+        store_type=StoreType.GRAPHS.value,
+        name="DeleteMeToo",
+    )
+
+    # Add relationship
+    rel = await graphs_handler.relationships.create(
+        subject="DeleteMe",
+        subject_id=e1.id,
+        predicate="related_to",
+        object="DeleteMeToo",
+        object_id=e2.id,
+        parent_id=graph_id,
+        store_type=StoreType.GRAPHS.value,
+    )
+
+    # Delete one entity
+    await graphs_handler.entities.delete(
+        parent_id=graph_id,
+        entity_ids=[e1.id],
+        store_type=StoreType.GRAPHS.value,
+    )
+    ents, count = await graphs_handler.get_entities(
+        parent_id=graph_id, offset=0, limit=10
+    )
+    assert count == 1
+    assert ents[0].id == e2.id
+
+    # Delete the relationship
+    await graphs_handler.relationships.delete(
+        parent_id=graph_id,
+        relationship_ids=[rel.id],
+        store_type=StoreType.GRAPHS.value,
+    )
+    rels, rel_count = await graphs_handler.get_relationships(
+        parent_id=graph_id, offset=0, limit=10
+    )
+    assert rel_count == 0
+
+
+@pytest.mark.asyncio
+async def test_communities(graphs_handler):
+    # Insert a community for a collection_id (not strictly related to a graph_id)
+    coll_id = uuid.uuid4()
+    await graphs_handler.communities.create(
+        parent_id=coll_id,
+        store_type=StoreType.GRAPHS.value,
+        name="CommunityOne",
+        summary="Test community",
+        findings=["finding1", "finding2"],
+        rating=4.5,
+        rating_explanation="Excellent",
+        description_embedding=[0.1, 0.2, 0.3, 0.4],
+    )
+
+    comms, count = await graphs_handler.communities.get(
+        parent_id=coll_id,
+        store_type=StoreType.GRAPHS.value,
+        offset=0,
+        limit=10,
+    )
+    assert count == 1
+    assert comms[0].name == "CommunityOne"
+
+
+# TODO - Fix code such that these tests pass
+# # @pytest.mark.asyncio
+# # async def test_delete_graph(graphs_handler):
+# #     # Create a graph and then delete it
+# #     coll_id = uuid.uuid4()
+# #     graph_resp = await graphs_handler.create(collection_id=coll_id, name="TempGraph")
+# #     graph_id = graph_resp.id
+
+# #     # reset or delete calls are complicated in the code. We'll just call `reset` and `delete`
+# #     await graphs_handler.reset(graph_id)
+# #     # This should remove all entities & relationships from the graph_id
+
+# #     # Now delete the graph itself
+# #     # The `delete` method seems to be tied to collection_id rather than graph_id
+# #     await graphs_handler.delete(collection_id=graph_id, cascade=False)
+# #     # If the code is structured so that delete requires a collection_id,
+# #     # ensure `graph_id == collection_id` or adapt the code accordingly.
+
+# #     # Try fetching the graph
+# #     overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[graph_id])
+# #     assert overview["total_entries"] == 0, "Graph should be deleted"
+
+
+# @pytest.mark.asyncio
+# async def test_delete_graph(graphs_handler):
+#     # Create a graph and then delete it
+#     coll_id = uuid.uuid4()
+#     graph_resp = await graphs_handler.create(collection_id=coll_id, name="TempGraph")
+#     graph_id = graph_resp.id
+
+#     # Reset the graph (remove entities, relationships, communities)
+#     await graphs_handler.reset(graph_id)
+
+#     # Now delete the graph using collection_id (which equals graph_id in this code)
+#     await graphs_handler.delete(collection_id=coll_id)
+
+#     # Verify the graph is deleted
+#     overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[coll_id])
+#     assert overview["total_entries"] == 0, "Graph should be deleted"
+
+
+@pytest.mark.asyncio
+async def test_create_graph_defaults(graphs_handler):
+    # Create a graph without specifying name or description
+    coll_id = uuid.uuid4()
+    resp = await graphs_handler.create(collection_id=coll_id)
+    assert resp.collection_id == coll_id
+    # The code sets a default name, which should be "Graph {coll_id}"
+    assert resp.name == f"Graph {coll_id}"
+    # Default description should be empty string as per code
+    assert resp.description == ""
+
+
+# @pytest.mark.asyncio
+# async def test_list_multiple_graphs(graphs_handler):
+#     # Create multiple graphs
+#     coll_id1 = uuid.uuid4()
+#     coll_id2 = uuid.uuid4()
+#     graph_resp1 = await graphs_handler.create(collection_id=coll_id1, name="Graph1")
+#     graph_resp2 = await graphs_handler.create(collection_id=coll_id2, name="Graph2")
+#     graph_resp3 = await graphs_handler.create(collection_id=coll_id2, name="Graph3")
+
+#     # List all graphs without filters
+#     overview = await graphs_handler.list_graphs(offset=0, limit=10)
+#     # Ensure at least these three are in there
+#     found_ids = [g.id for g in overview["results"]]
+#     assert graph_resp1.id in found_ids
+#     assert graph_resp2.id in found_ids
+#     assert graph_resp3.id in found_ids
+
+#     # Filter by collection_id = coll_id2 should return Graph2 and Graph3 (the most recent one first if same collection)
+#     overview_coll2 = await graphs_handler.list_graphs(offset=0, limit=10, filter_collection_id=coll_id2)
+#     returned_ids = [g.id for g in overview_coll2["results"]]
+#     # According to the code, we only see the "most recent" graph per collection. Verify this logic.
+#     # If your code is returning only the most recent graph per collection, we should see only one graph per collection_id here.
+#     # Adjust test according to actual logic you desire.
+#     # For this example, let's assume we should only get the latest graph per collection. Graph3 should be newer than Graph2.
+#     assert len(returned_ids) == 1
+#     assert graph_resp3.id in returned_ids
+
+
+@pytest.mark.asyncio
+async def test_update_graph(graphs_handler):
+    coll_id = uuid.uuid4()
+    graph_resp = await graphs_handler.create(
+        collection_id=coll_id, name="OldName", description="OldDescription"
+    )
+    graph_id = graph_resp.id
+
+    # Update name and description
+    updated_resp = await graphs_handler.update(
+        collection_id=graph_id, name="NewName", description="NewDescription"
+    )
+    assert updated_resp.name == "NewName"
+    assert updated_resp.description == "NewDescription"
+
+    # Retrieve and verify
+    overview = await graphs_handler.list_graphs(
+        offset=0, limit=10, filter_graph_ids=[graph_id]
+    )
+    assert overview["total_entries"] == 1
+    fetched_graph = overview["results"][0]
+    assert fetched_graph.name == "NewName"
+    assert fetched_graph.description == "NewDescription"
+
+
+@pytest.mark.asyncio
+async def test_bulk_entities(graphs_handler):
+    coll_id = uuid.uuid4()
+    graph_resp = await graphs_handler.create(
+        collection_id=coll_id, name="BulkEntities"
+    )
+    graph_id = graph_resp.id
+
+    # Add multiple entities
+    entities_to_add = [
+        {"name": "EntityA", "category": "CategoryA", "description": "DescA"},
+        {"name": "EntityB", "category": "CategoryB", "description": "DescB"},
+        {"name": "EntityC", "category": "CategoryC", "description": "DescC"},
+    ]
+    for ent in entities_to_add:
+        await graphs_handler.entities.create(
+            parent_id=graph_id,
+            store_type=StoreType.GRAPHS.value,
+            name=ent["name"],
+            category=ent["category"],
+            description=ent["description"],
+        )
+
+    ents, total = await graphs_handler.get_entities(
+        parent_id=graph_id, offset=0, limit=10
+    )
+    assert total == 3
+    fetched_names = [e.name for e in ents]
+    for ent in entities_to_add:
+        assert ent["name"] in fetched_names
+
+
+@pytest.mark.asyncio
+async def test_relationship_filtering(graphs_handler):
+    coll_id = uuid.uuid4()
+    graph_resp = await graphs_handler.create(
+        collection_id=coll_id, name="RelFilteringGraph"
+    )
+    graph_id = graph_resp.id
+
+    # Add entities
+    e1 = await graphs_handler.entities.create(
+        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="Node1"
+    )
+    e2 = await graphs_handler.entities.create(
+        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="Node2"
+    )
+    e3 = await graphs_handler.entities.create(
+        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="Node3"
+    )
+
+    # Add different relationships
+    await graphs_handler.relationships.create(
+        subject="Node1",
+        subject_id=e1.id,
+        predicate="connected_to",
+        object="Node2",
+        object_id=e2.id,
+        parent_id=graph_id,
+        store_type=StoreType.GRAPHS.value,
+    )
+
+    await graphs_handler.relationships.create(
+        subject="Node2",
+        subject_id=e2.id,
+        predicate="linked_with",
+        object="Node3",
+        object_id=e3.id,
+        parent_id=graph_id,
+        store_type=StoreType.GRAPHS.value,
+    )
+
+    # Get all relationships
+    all_rels, all_count = await graphs_handler.get_relationships(
+        parent_id=graph_id, offset=0, limit=10
+    )
+    assert all_count == 2
+
+    # Filter by relationship_type = ["connected_to"]
+    filtered_rels, filt_count = await graphs_handler.get_relationships(
+        parent_id=graph_id,
+        offset=0,
+        limit=10,
+        relationship_types=["connected_to"],
+    )
+    assert filt_count == 1
+    assert filtered_rels[0].predicate == "connected_to"
+
+
+@pytest.mark.asyncio
+async def test_delete_all_entities(graphs_handler):
+    coll_id = uuid.uuid4()
+    graph_resp = await graphs_handler.create(
+        collection_id=coll_id, name="DeleteAllEntities"
+    )
+    graph_id = graph_resp.id
+
+    # Add some entities
+    await graphs_handler.entities.create(
+        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="E1"
+    )
+    await graphs_handler.entities.create(
+        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="E2"
+    )
+
+    # Delete all entities without specifying IDs
+    await graphs_handler.entities.delete(
+        parent_id=graph_id, store_type=StoreType.GRAPHS.value
+    )
+    ents, count = await graphs_handler.get_entities(
+        parent_id=graph_id, offset=0, limit=10
+    )
+    assert count == 0
+
+
+@pytest.mark.asyncio
+async def test_delete_all_relationships(graphs_handler):
+    coll_id = uuid.uuid4()
+    graph_resp = await graphs_handler.create(
+        collection_id=coll_id, name="DeleteAllRels"
+    )
+    graph_id = graph_resp.id
+
+    # Add two entities and a relationship
+    e1 = await graphs_handler.entities.create(
+        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="E1"
+    )
+    e2 = await graphs_handler.entities.create(
+        parent_id=graph_id, store_type=StoreType.GRAPHS.value, name="E2"
+    )
+    await graphs_handler.relationships.create(
+        subject="E1",
+        subject_id=e1.id,
+        predicate="connected",
+        object="E2",
+        object_id=e2.id,
+        parent_id=graph_id,
+        store_type=StoreType.GRAPHS.value,
+    )
+
+    # Delete all relationships
+    await graphs_handler.relationships.delete(
+        parent_id=graph_id, store_type=StoreType.GRAPHS.value
+    )
+    rels, rel_count = await graphs_handler.get_relationships(
+        parent_id=graph_id, offset=0, limit=10
+    )
+    assert rel_count == 0
+
+
+@pytest.mark.asyncio
+async def test_error_handling_invalid_graph_id(graphs_handler):
+    # Attempt to get a non-existent graph
+    non_existent_id = uuid.uuid4()
+    overview = await graphs_handler.list_graphs(
+        offset=0, limit=10, filter_graph_ids=[non_existent_id]
+    )
+    assert overview["total_entries"] == 0
+
+    # Attempt to delete a non-existent graph
+    with pytest.raises(Exception):
+        await graphs_handler.delete(collection_id=non_existent_id)
+    # An R2RException (or HTTPException, depending on the API layer) is
+    # acceptable here; only the failure itself is asserted.
+
+
+# TODO - Fix code to pass this test.
+# @pytest.mark.asyncio
+# async def test_delete_graph_cascade(graphs_handler):
+#     coll_id = uuid.uuid4()
+#     graph_resp = await graphs_handler.create(collection_id=coll_id, name="CascadeGraph")
+#     graph_id = graph_resp.id
+
+#     # Add entities/relationships here if you have documents attached
+#     # This test would verify that cascade=True behavior is correct
+#     # For now, just call delete with cascade=True
+#     # Depending on your implementation, you might need documents associated with the collection to test fully.
+#     await graphs_handler.delete(collection_id=coll_id)
+#     overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[graph_id])
+#     assert overview["total_entries"] == 0

+ 249 - 0
tests/unit/test_limits.py

@@ -0,0 +1,249 @@
+import uuid
+from datetime import datetime, timedelta, timezone
+from uuid import UUID
+
+import pytest
+
+from core.base import LimitSettings
+from core.database.postgres import PostgresLimitsHandler
+from shared.abstractions import User
+
+
+@pytest.mark.asyncio
+async def test_log_request_and_count(limits_handler):
+    """
+    Test that when we log requests, the count increments, and rate-limits are enforced.
+    Route-specific test using the /v3/retrieval/search endpoint limits.
+    """
+    # Clear existing logs first
+    clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
+    await limits_handler.connection_manager.execute_query(clear_query)
+
+    user_id = uuid.uuid4()
+    route = "/v3/retrieval/search"  # Using actual route from config
+    test_user = User(
+        id=user_id,
+        email="test@example.com",
+        is_active=True,
+        is_verified=True,
+        is_superuser=False,
+        limits_overrides=None,
+    )
+
+    # Set route limit to match config: 5 requests per minute
+    old_route_limits = limits_handler.config.route_limits
+    new_route_limits = {
+        route: LimitSettings(route_per_min=5, monthly_limit=10)
+    }
+    limits_handler.config.route_limits = new_route_limits
+
+    print(f"\nTesting with route limits: {new_route_limits}")
+    print(f"Route settings: {limits_handler.config.route_limits[route]}")
+
+    try:
+        # Initial check should pass (no requests yet)
+        await limits_handler.check_limits(test_user, route)
+        print("Initial check passed (no requests)")
+
+        # Log 5 requests (exactly at limit)
+        for i in range(5):
+            await limits_handler.log_request(user_id, route)
+            now = datetime.now(timezone.utc)
+            one_min_ago = now - timedelta(minutes=1)
+            route_count = await limits_handler._count_requests(
+                user_id, route, one_min_ago
+            )
+            print(f"Route count after request {i+1}: {route_count}")
+
+            # This should pass for all 5 requests
+            await limits_handler.check_limits(test_user, route)
+            print(f"Check limits passed after request {i+1}")
+
+        # Log the 6th request (over the limit)
+        await limits_handler.log_request(user_id, route)
+        one_min_ago = datetime.now(timezone.utc) - timedelta(minutes=1)
+        route_count = await limits_handler._count_requests(
+            user_id, route, one_min_ago
+        )
+        print(f"Route count after request 6: {route_count}")
+
+        # This check should fail as we've exceeded route_per_min=5
+        with pytest.raises(
+            ValueError, match="Per-route per-minute rate limit exceeded"
+        ):
+            await limits_handler.check_limits(test_user, route)
+
+    finally:
+        limits_handler.config.route_limits = old_route_limits
+
+
+@pytest.mark.asyncio
+async def test_global_limit(limits_handler):
+    """
+    Test global limit using the configured limit of 10 requests per minute
+    """
+    # Clear existing logs
+    clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
+    await limits_handler.connection_manager.execute_query(clear_query)
+
+    user_id = uuid.uuid4()
+    route = "/global-test"
+    test_user = User(
+        id=user_id,
+        email="globaltest@example.com",
+        is_active=True,
+        is_verified=True,
+        is_superuser=False,
+        limits_overrides=None,
+    )
+
+    # Set global limit to match config: 10 requests per minute
+    old_limits = limits_handler.config.limits
+    limits_handler.config.limits = LimitSettings(
+        global_per_min=10, monthly_limit=20
+    )
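+    # `/global-test` is absent from route_limits, so only the global and
+    # monthly ceilings apply, which is exactly what this test exercises.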
+
+    try:
+        # Initial check should pass (no requests)
+        await limits_handler.check_limits(test_user, route)
+        print("Initial global check passed (no requests)")
+
+        # Log 11 requests (one past the global_per_min=10 limit)
+        for _ in range(11):
+            await limits_handler.log_request(user_id, route)
+
+        # Debug counts
+        now = datetime.now(timezone.utc)
+        one_min_ago = now - timedelta(minutes=1)
+        global_count = await limits_handler._count_requests(
+            user_id, None, one_min_ago
+        )
+        print(f"Global count after 10 requests: {global_count}")
+
+        # This should fail as we've hit global_per_min=10
+        with pytest.raises(
+            ValueError, match="Global per-minute rate limit exceeded"
+        ):
+            await limits_handler.check_limits(test_user, route)
+
+    finally:
+        limits_handler.config.limits = old_limits
+
+
+@pytest.mark.asyncio
+async def test_monthly_limit(limits_handler):
+    """
+    Test monthly limit using the configured limit of 20 requests per month
+    """
+    # Clear existing logs
+    clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
+    await limits_handler.connection_manager.execute_query(clear_query)
+
+    user_id = uuid.uuid4()
+    route = "/monthly-test"
+    test_user = User(
+        id=user_id,
+        email="monthly@example.com",
+        is_active=True,
+        is_verified=True,
+        is_superuser=False,
+        limits_overrides=None,
+    )
+
+    old_limits = limits_handler.config.limits
+    limits_handler.config.limits = LimitSettings(monthly_limit=20)
+
+    try:
+        # Initial check should pass (no requests)
+        await limits_handler.check_limits(test_user, route)
+        print("Initial monthly check passed (no requests)")
+
+        # Log 21 requests (one past the monthly_limit=20)
+        for _ in range(21):
+            await limits_handler.log_request(user_id, route)
+
+        # Get current month's count
+        now = datetime.now(timezone.utc)
+        first_of_month = now.replace(
+            day=1, hour=0, minute=0, second=0, microsecond=0
+        )
+        monthly_count = await limits_handler._count_requests(
+            user_id, None, first_of_month
+        )
+        print(f"Monthly count after 20 requests: {monthly_count}")
+
+        # This should fail as we've hit monthly_limit=20
+        with pytest.raises(ValueError, match="Monthly rate limit exceeded"):
+            await limits_handler.check_limits(test_user, route)
+
+    finally:
+        limits_handler.config.limits = old_limits
+
+
+@pytest.mark.asyncio
+async def test_user_level_override(limits_handler):
+    """
+    Test user-specific override limits with debug logging
+    """
+    user_id = UUID("47e53676-b478-5b3f-a409-234ca2164de5")
+    route = "/test-route"
+
+    # Clear existing logs first
+    clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
+    await limits_handler.connection_manager.execute_query(clear_query)
+
+    test_user = User(
+        id=user_id,
+        email="override@example.com",
+        is_active=True,
+        is_verified=True,
+        is_superuser=False,
+        limits_overrides={
+            "global_per_min": 2,
+            "route_per_min": 1,
+            "route_overrides": {"/test-route": {"route_per_min": 1}},
+        },
+    )
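+    # The expectation below assumes route_overrides["/test-route"] takes
+    # precedence over the user-level route_per_min, which in turn overrides
+    # the config defaults; that ordering is inferred from the handler's
+    # naming, not spelled out in this diff.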
+
+    # Set default limits that should be overridden
+    old_limits = limits_handler.config.limits
+    limits_handler.config.limits = LimitSettings(
+        global_per_min=10, monthly_limit=20
+    )
+
+    # Debug: Print current limits
+    print(f"\nDefault limits: {limits_handler.config.limits}")
+    print(f"User overrides: {test_user.limits_overrides}")
+
+    try:
+        # First check limits (should pass as no requests yet)
+        await limits_handler.check_limits(test_user, route)
+        print("Initial check passed (no requests yet)")
+
+        # Log first request
+        await limits_handler.log_request(user_id, route)
+
+        # Debug: Get current counts
+        now = datetime.now(timezone.utc)
+        one_min_ago = now - timedelta(minutes=1)
+        global_count = await limits_handler._count_requests(
+            user_id, None, one_min_ago
+        )
+        route_count = await limits_handler._count_requests(
+            user_id, route, one_min_ago
+        )
+        print(f"\nAfter first request:")
+        print(f"Global count: {global_count}")
+        print(f"Route count: {route_count}")
+
+        # Log second request
+        await limits_handler.log_request(user_id, route)
+
+        # This check should fail as we've hit route_per_min=1
+        with pytest.raises(
+            ValueError, match="Per-route per-minute rate limit exceeded"
+        ):
+            await limits_handler.check_limits(test_user, route)
+
+    finally:
+        # Cleanup
+        limits_handler.config.limits = old_limits
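+
+
+# The DELETE statement at the top of each test is repeated boilerplate; a
+# hedged sketch of an autouse fixture (built only from calls already used
+# above) that would clear the request log once per test:
+#
+#     @pytest.fixture(autouse=True)
+#     async def clean_request_log(limits_handler):
+#         table = limits_handler._get_table_name(
+#             PostgresLimitsHandler.TABLE_NAME
+#         )
+#         await limits_handler.connection_manager.execute_query(
+#             f"DELETE FROM {table}"
+#         )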