jack 3 mesiacov pred
rodič
commit
fed6a49f29
98 zmenil súbory, kde vykonal 6175 pridanie a 1211 odobranie
  1. 148 12
      cli/command_group.py
  2. 72 48
      cli/commands/collections.py
  3. 165 0
      cli/commands/config.py
  4. 56 52
      cli/commands/conversations.py
  5. 4 6
      cli/commands/database.py
  6. 303 127
      cli/commands/documents.py
  7. 220 144
      cli/commands/graphs.py
  8. 39 41
      cli/commands/indices.py
  9. 36 23
      cli/commands/prompts.py
  10. 52 39
      cli/commands/retrieval.py
  11. 40 17
      cli/commands/system.py
  12. 98 53
      cli/commands/users.py
  13. 76 4
      cli/main.py
  14. 0 6
      cli/utils/docker_utils.py
  15. 2 1
      cli/utils/timer.py
  16. 124 0
      compose.yaml
  17. 0 1
      core/base/parsers/base_parser.py
  18. 0 1
      core/base/providers/ingestion.py
  19. 53 10
      core/database/base.py
  20. 0 2
      core/database/chunks.py
  21. 0 1
      core/database/collections.py
  22. 0 1
      core/database/documents.py
  23. 0 13
      core/database/graphs.py
  24. 89 54
      core/database/limits.py
  25. 234 116
      core/database/users.py
  26. 1 1
      core/examples/hello_r2r.py
  27. 6 6
      core/main/abstractions.py
  28. 57 104
      core/main/api/v3/base_router.py
  29. 9 9
      core/main/api/v3/chunks_router.py
  30. 23 23
      core/main/api/v3/collections_router.py
  31. 14 14
      core/main/api/v3/conversations_router.py
  32. 38 36
      core/main/api/v3/documents_router.py
  33. 34 36
      core/main/api/v3/graph_router.py
  34. 9 11
      core/main/api/v3/indices_router.py
  35. 10 10
      core/main/api/v3/prompts_router.py
  36. 10 11
      core/main/api/v3/retrieval_router.py
  37. 6 6
      core/main/api/v3/system_router.py
  38. 49 40
      core/main/api/v3/users_router.py
  39. 0 1
      core/main/app.py
  40. 5 12
      core/main/assembly/builder.py
  41. 0 2
      core/main/assembly/factory.py
  42. 0 1
      core/main/orchestration/hatchet/ingestion_workflow.py
  43. 0 8
      core/main/orchestration/hatchet/kg_workflow.py
  44. 0 3
      core/main/orchestration/simple/ingestion_workflow.py
  45. 0 7
      core/main/orchestration/simple/kg_workflow.py
  46. 3 0
      core/main/services/auth_service.py
  47. 0 10
      core/main/services/graph_service.py
  48. 0 1
      core/main/services/ingestion_service.py
  49. 0 1
      core/main/services/management_service.py
  50. 7 3
      core/parsers/media/bmp_parser.py
  51. 25 25
      core/pipes/kg/community_summary.py
  52. 0 1
      core/pipes/kg/deduplication.py
  53. 0 6
      core/pipes/kg/deduplication_summary.py
  54. 0 1
      core/pipes/kg/extraction.py
  55. 0 1
      core/pipes/kg/storage.py
  56. 0 1
      core/pipes/retrieval/chunk_search_pipe.py
  57. 0 1
      core/pipes/retrieval/graph_search_pipe.py
  58. 0 1
      core/providers/auth/r2r_auth.py
  59. 1 1
      core/providers/crypto/bcrypt.py
  60. 26 8
      core/providers/crypto/nacl.py
  61. 0 1
      core/providers/embeddings/litellm.py
  62. 0 1
      core/providers/ingestion/r2r/base.py
  63. 0 1
      core/providers/ingestion/unstructured/base.py
  64. 0 1
      core/telemetry/telemetry_decorator.py
  65. 0 2
      migrations/versions/8077140e1e99_v3_api_database_revision.py
  66. 3 0
      pyproject.toml
  67. 1 1
      sdk/async_client.py
  68. 3 3
      sdk/base/base_client.py
  69. 0 1
      sdk/sync_client.py
  70. 1 2
      sdk/v2/management.py
  71. 3 3
      sdk/v3/graphs.py
  72. 6 14
      sdk/v3/users.py
  73. 0 1
      shared/abstractions/graph.py
  74. 2 3
      shared/abstractions/search.py
  75. 3 0
      shared/abstractions/user.py
  76. 56 0
      tests/cli/async_invoke.py
  77. 162 0
      tests/cli/commands/test_collections_cli.py
  78. 7 0
      tests/cli/commands/test_config_cli.py
  79. 168 0
      tests/cli/commands/test_conversations_cli.py
  80. 0 0
      tests/cli/commands/test_database_cli.py
  81. 330 0
      tests/cli/commands/test_documents_cli.py
  82. 312 0
      tests/cli/commands/test_graphs_cli.py
  83. 6 0
      tests/cli/commands/test_indices_cli.py
  84. 95 0
      tests/cli/commands/test_prompts_cli.py
  85. 213 0
      tests/cli/commands/test_retrieval_cli.py
  86. 338 0
      tests/cli/commands/test_system_cli.py
  87. 143 0
      tests/cli/commands/test_users_cli.py
  88. 118 0
      tests/cli/utils/test_timer.py
  89. 0 2
      tests/integration/test_conversations.py
  90. 187 0
      tests/integration/test_filters.py
  91. 37 1
      tests/integration/test_users.py
  92. 344 0
      tests/unit/conftest.py
  93. 315 0
      tests/unit/test_chunks.py
  94. 205 0
      tests/unit/test_collections.py
  95. 132 0
      tests/unit/test_conversations.py
  96. 143 0
      tests/unit/test_documents.py
  97. 449 0
      tests/unit/test_graphs.py
  98. 249 0
      tests/unit/test_limits.py

+ 148 - 12
cli/command_group.py

@@ -1,11 +1,47 @@
+# from .main import load_config
+import json
+import types
 from functools import wraps
 from functools import wraps
+from pathlib import Path
+from typing import Any, Never
 
 
 import asyncclick as click
 import asyncclick as click
 from asyncclick import pass_context
 from asyncclick import pass_context
 from asyncclick.exceptions import Exit
 from asyncclick.exceptions import Exit
+from rich import box
+from rich.console import Console
+from rich.table import Table
 
 
 from sdk import R2RAsyncClient
 from sdk import R2RAsyncClient
 
 
+console = Console()
+
+CONFIG_DIR = Path.home() / ".r2r"
+CONFIG_FILE = CONFIG_DIR / "config.json"
+
+
+def load_config() -> dict[str, Any]:
+    """
+    Load the CLI config from ~/.r2r/config.json.
+    Returns an empty dict if the file doesn't exist or is invalid.
+    """
+    if not CONFIG_FILE.is_file():
+        return {}
+    try:
+        with open(CONFIG_FILE, "r", encoding="utf-8") as f:
+            data = json.load(f)
+            # Ensure we always have a dict
+            if not isinstance(data, dict):
+                return {}
+            return data
+    except (IOError, json.JSONDecodeError):
+        return {}
+
+
+def silent_exit(ctx, code=0):
+    if code != 0:
+        raise Exit(code)
+
 
 
 def deprecated_command(new_name):
 def deprecated_command(new_name):
     def decorator(f):
     def decorator(f):
@@ -23,19 +59,119 @@ def deprecated_command(new_name):
     return decorator
     return decorator
 
 
 
 
-@click.group()
+def custom_help_formatter(commands):
+    """Create a nicely formatted help table using rich"""
+    table = Table(
+        box=box.ROUNDED,
+        border_style="blue",
+        pad_edge=False,
+        collapse_padding=True,
+    )
+
+    table.add_column("Command", style="cyan", no_wrap=True)
+    table.add_column("Description", style="white")
+
+    command_groups = {
+        "Document Management": [
+            ("documents", "Document ingestion and management commands"),
+            ("collections", "Collections management commands"),
+        ],
+        "Knowledge Graph": [
+            ("graphs", "Graph creation and management commands"),
+            ("prompts", "Prompt template management"),
+        ],
+        "Interaction": [
+            ("conversations", "Conversation management commands"),
+            ("retrieval", "Knowledge retrieval commands"),
+        ],
+        "System": [
+            ("configure", "Configuration management commands"),
+            ("users", "User management commands"),
+            ("indices", "Index management commands"),
+            ("system", "System administration commands"),
+        ],
+        "Database": [
+            ("db", "Database management commands"),
+            ("upgrade", "Upgrade database schema"),
+            ("downgrade", "Downgrade database schema"),
+            ("current", "Show current schema version"),
+            ("history", "View schema migration history"),
+        ],
+    }
+
+    for group_name, group_commands in command_groups.items():
+        table.add_row(
+            f"[bold yellow]{group_name}[/bold yellow]", "", style="dim"
+        )
+        for cmd_name, description in group_commands:
+            if cmd_name in commands:
+                table.add_row(f"  {cmd_name}", commands[cmd_name].help or "")
+        table.add_row("", "")  # Add spacing between groups
+
+    return table
+
+
+class CustomGroup(click.Group):
+    def format_help(self, ctx, formatter):
+        console.print("\n[bold blue]R2R Command Line Interface[/bold blue]")
+        console.print("The most advanced AI retrieval system\n")
+
+        if self.get_help_option(ctx) is not None:
+            console.print("[bold cyan]Usage:[/bold cyan]")
+            console.print("  r2r [OPTIONS] COMMAND [ARGS]...\n")
+
+        console.print("[bold cyan]Options:[/bold cyan]")
+        console.print(
+            "  --base-url TEXT  Base URL for the API [default: https://api.cloud.sciphi.ai]"
+        )
+        console.print("  --help           Show this message and exit.\n")
+
+        console.print("[bold cyan]Commands:[/bold cyan]")
+        console.print(custom_help_formatter(self.commands))
+        console.print(
+            "\nFor more details on a specific command, run: [bold]r2r COMMAND --help[/bold]\n"
+        )
+
+
+class CustomContext(click.Context):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.exit_func = types.MethodType(silent_exit, self)
+
+    def exit(self, code: int = 0) -> Never:
+        self.exit_func(code)
+        raise SystemExit(code)
+
+
+def initialize_client(base_url: str) -> R2RAsyncClient:
+    """Initialize R2R client with API key from config if available."""
+    client = R2RAsyncClient()
+
+    try:
+        config = load_config()
+        if api_key := config.get("api_key"):
+            client.set_api_key(api_key)
+            if not client.api_key:
+                console.print(
+                    "[yellow]Warning: API key not properly set in client[/yellow]"
+                )
+
+    except Exception as e:
+        console.print(
+            "[yellow]Warning: Failed to load API key from config[/yellow]"
+        )
+        console.print_exception()
+
+    return client
+
+
+@click.group(cls=CustomGroup)
 @click.option(
 @click.option(
-    "--base-url", default="http://localhost:7272", help="Base URL for the API"
+    "--base-url",
+    default="https://cloud.sciphi.ai",
+    help="Base URL for the API",
 )
 )
 @pass_context
 @pass_context
-async def cli(ctx, base_url):
+async def cli(ctx: click.Context, base_url: str) -> None:
     """R2R CLI for all core operations."""
     """R2R CLI for all core operations."""
-
-    ctx.obj = R2RAsyncClient(base_url=base_url)
-
-    # Override the default exit behavior
-    def silent_exit(self, code=0):
-        if code != 0:
-            raise Exit(code)
-
-    ctx.exit = silent_exit.__get__(ctx)
+    ctx.obj = initialize_client(base_url)

+ 72 - 48
cli/commands/collections.py

@@ -4,7 +4,7 @@ import asyncclick as click
 from asyncclick import pass_context
 from asyncclick import pass_context
 
 
 from cli.utils.timer import timer
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 
 
 @click.group()
 @click.group()
@@ -17,17 +17,21 @@ def collections():
 @click.argument("name", required=True, type=str)
 @click.argument("name", required=True, type=str)
 @click.option("--description", type=str)
 @click.option("--description", type=str)
 @pass_context
 @pass_context
-async def create(ctx, name, description):
+async def create(ctx: click.Context, name, description):
     """Create a collection."""
     """Create a collection."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.collections.create(
-            name=name,
-            description=description,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.collections.create(
+                name=name,
+                description=description,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @collections.command()
 @collections.command()
@@ -43,46 +47,57 @@ async def create(ctx, name, description):
     help="The maximum number of nodes to return. Defaults to 100.",
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def list(ctx, ids, offset, limit):
+async def list(ctx: click.Context, ids, offset, limit):
     """Get an overview of collections."""
     """Get an overview of collections."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
     ids = list(ids) if ids else None
     ids = list(ids) if ids else None
 
 
-    with timer():
-        response = await client.collections.list(
-            ids=ids,
-            offset=offset,
-            limit=limit,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    try:
+        with timer():
+            response = await client.collections.list(
+                ids=ids,
+                offset=offset,
+                limit=limit,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @collections.command()
 @collections.command()
 @click.argument("id", required=True, type=str)
 @click.argument("id", required=True, type=str)
 @pass_context
 @pass_context
-async def retrieve(ctx, id):
+async def retrieve(ctx: click.Context, id):
     """Retrieve a collection by ID."""
     """Retrieve a collection by ID."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.collections.retrieve(id=id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.collections.retrieve(id=id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @collections.command()
 @collections.command()
 @click.argument("id", required=True, type=str)
 @click.argument("id", required=True, type=str)
 @pass_context
 @pass_context
-async def delete(ctx, id):
+async def delete(ctx: click.Context, id):
     """Delete a collection by ID."""
     """Delete a collection by ID."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.collections.delete(id=id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.collections.delete(id=id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @collections.command()
 @collections.command()
@@ -98,19 +113,24 @@ async def delete(ctx, id):
     help="The maximum number of nodes to return. Defaults to 100.",
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def list_documents(ctx, id, offset, limit):
+async def list_documents(ctx: click.Context, id, offset, limit):
     """Get an overview of collections."""
     """Get an overview of collections."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.collections.list_documents(
-            id=id,
-            offset=offset,
-            limit=limit,
-        )
+    try:
+        with timer():
+            response = await client.collections.list_documents(
+                id=id,
+                offset=offset,
+                limit=limit,
+            )
 
 
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+        for user in response["results"]:  # type: ignore
+            click.echo(json.dumps(user, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @collections.command()
 @collections.command()
@@ -126,16 +146,20 @@ async def list_documents(ctx, id, offset, limit):
     help="The maximum number of nodes to return. Defaults to 100.",
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def list_users(ctx, id, offset, limit):
+async def list_users(ctx: click.Context, id, offset, limit):
     """Get an overview of collections."""
     """Get an overview of collections."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.collections.list_users(
-            id=id,
-            offset=offset,
-            limit=limit,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    try:
+        with timer():
+            response = await client.collections.list_users(
+                id=id,
+                offset=offset,
+                limit=limit,
+            )
+        for user in response["results"]:  # type: ignore
+            click.echo(json.dumps(user, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 165 - 0
cli/commands/config.py

@@ -0,0 +1,165 @@
+import configparser
+from pathlib import Path
+
+import asyncclick as click
+from rich.box import ROUNDED
+from rich.console import Console
+from rich.table import Table
+
+console = Console()
+
+
+def get_config_dir():
+    """Create and return the config directory path."""
+    config_dir = Path.home() / ".r2r"
+    config_dir.mkdir(exist_ok=True)
+    return config_dir
+
+
+def get_config_file():
+    """Get the config file path."""
+    return get_config_dir() / "config.ini"
+
+
+class Config:
+    _instance = None
+    _config = configparser.ConfigParser()
+    _config_file = get_config_file()
+
+    @classmethod
+    def load(cls):
+        """Load the configuration file."""
+        if cls._config_file.exists():
+            cls._config.read(cls._config_file)
+
+    @classmethod
+    def save(cls):
+        """Save the configuration to file."""
+        with open(cls._config_file, "w") as f:
+            cls._config.write(f)
+
+    @classmethod
+    def get_credentials(cls, service):
+        """Get credentials for a specific service."""
+        cls.load()  # Ensure we have latest config
+        return dict(cls._config[service]) if service in cls._config else {}
+
+    @classmethod
+    def set_credentials(cls, service, credentials):
+        """Set credentials for a specific service."""
+        cls.load()  # Ensure we have latest config
+        if service not in cls._config:
+            cls._config[service] = {}
+        cls._config[service].update(credentials)
+        cls.save()
+
+
+@click.group()
+def configure():
+    """Configuration management commands."""
+    pass
+
+
+@configure.command()
+@click.confirmation_option(
+    prompt="Are you sure you want to reset all settings?"
+)
+async def reset():
+    """Reset all configuration to defaults."""
+    if Config._config_file.exists():
+        Config._config_file.unlink()  # Delete the config file
+    Config._config = configparser.ConfigParser()  # Reset the config in memory
+
+    # Set default values
+    Config.set_credentials(
+        "Base URL", {"base_url": "https://api.cloud.sciphi.ai"}
+    )
+
+    console.print(
+        "[green]Successfully reset configuration to defaults[/green]"
+    )
+
+
+@configure.command()
+@click.option(
+    "--api-key",
+    prompt="SciPhi API Key",
+    hide_input=True,
+    help="API key for SciPhi cloud",
+)
+async def key(api_key):
+    """Configure SciPhi cloud API credentials."""
+    Config.set_credentials("SciPhi", {"api_key": api_key})
+    console.print(
+        "[green]Successfully configured SciPhi cloud credentials[/green]"
+    )
+
+
+@configure.command()
+@click.option(
+    "--base-url",
+    prompt="R2R Base URL",
+    default="https://api.cloud.sciphi.ai",
+    hide_input=False,
+    help="Host URL for R2R",
+)
+async def host(host):
+    """Configure R2R host URL."""
+    Config.set_credentials("Host", {"R2R_HOST": host})
+    console.print("[green]Successfully configured R2R host URL[/green]")
+
+
+@configure.command()
+async def view():
+    """View current configuration."""
+    Config.load()
+
+    table = Table(
+        title="[bold blue]R2R Settings[/bold blue]",
+        show_header=True,
+        header_style="bold white on blue",
+        border_style="blue",
+        box=ROUNDED,
+        pad_edge=False,
+        collapse_padding=True,
+    )
+
+    table.add_column(
+        "Section", justify="left", style="bright_yellow", no_wrap=True
+    )
+    table.add_column(
+        "Key", justify="left", style="bright_magenta", no_wrap=True
+    )
+    table.add_column(
+        "Value", justify="left", style="bright_green", no_wrap=True
+    )
+
+    # Group related configurations together
+    config_groups = {
+        "API Credentials": ["SciPhi"],
+        "Server Settings": ["Base URL", "Port"],
+    }
+
+    for group_name, sections in config_groups.items():
+        has_items = any(section in Config._config for section in sections)
+        if has_items:
+            table.add_row(
+                f"[bold]{group_name}[/bold]", "", "", style="bright_blue"
+            )
+
+            for section in sections:
+                if section in Config._config:
+                    for key, value in Config._config[section].items():
+                        # Mask API keys for security
+                        displayed_value = (
+                            f"****{value[-4:]}"
+                            if "api_key" in key.lower()
+                            else value
+                        )
+                        table.add_row(
+                            f"  {section}", key.lower(), displayed_value
+                        )
+
+    console.print("\n")
+    console.print(table)
+    console.print("\n")

+ 56 - 52
cli/commands/conversations.py

@@ -4,7 +4,7 @@ import asyncclick as click
 from asyncclick import pass_context
 from asyncclick import pass_context
 
 
 from cli.utils.timer import timer
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 
 
 @click.group()
 @click.group()
@@ -15,14 +15,18 @@ def conversations():
 
 
 @conversations.command()
 @conversations.command()
 @pass_context
 @pass_context
-async def create(ctx):
+async def create(ctx: click.Context):
     """Create a conversation."""
     """Create a conversation."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.conversations.create()
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.conversations.create()
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @conversations.command()
 @conversations.command()
@@ -38,62 +42,58 @@ async def create(ctx):
     help="The maximum number of nodes to return. Defaults to 100.",
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def list(ctx, ids, offset, limit):
+async def list(ctx: click.Context, ids, offset, limit):
     """Get an overview of conversations."""
     """Get an overview of conversations."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
     ids = list(ids) if ids else None
     ids = list(ids) if ids else None
 
 
-    with timer():
-        response = await client.conversations.list(
-            ids=ids,
-            offset=offset,
-            limit=limit,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    try:
+        with timer():
+            response = await client.conversations.list(
+                ids=ids,
+                offset=offset,
+                limit=limit,
+            )
+        for user in response["results"]:  # type: ignore
+            click.echo(json.dumps(user, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @conversations.command()
 @conversations.command()
 @click.argument("id", required=True, type=str)
 @click.argument("id", required=True, type=str)
 @pass_context
 @pass_context
-async def retrieve(ctx, id):
+async def retrieve(ctx: click.Context, id):
     """Retrieve a collection by ID."""
     """Retrieve a collection by ID."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.conversations.retrieve(id=id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.conversations.retrieve(id=id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @conversations.command()
 @conversations.command()
 @click.argument("id", required=True, type=str)
 @click.argument("id", required=True, type=str)
 @pass_context
 @pass_context
-async def delete(ctx, id):
+async def delete(ctx: click.Context, id):
     """Delete a collection by ID."""
     """Delete a collection by ID."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.conversations.delete(id=id)
-
-    click.echo(json.dumps(response, indent=2))
-
-
-@conversations.command()
-@click.argument("id", required=True, type=str)
-@pass_context
-async def list_branches(ctx, id):
-    """List all branches in a conversation."""
-    client: R2RAsyncClient = ctx.obj
-
-    with timer():
-        response = await client.conversations.list_branches(
-            id=id,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    try:
+        with timer():
+            response = await client.conversations.delete(id=id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @conversations.command()
 @conversations.command()
@@ -109,16 +109,20 @@ async def list_branches(ctx, id):
     help="The maximum number of nodes to return. Defaults to 100.",
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def list_users(ctx, id, offset, limit):
+async def list_users(ctx: click.Context, id, offset, limit):
     """Get an overview of collections."""
     """Get an overview of collections."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.collections.list_users(
-            id=id,
-            offset=offset,
-            limit=limit,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    try:
+        with timer():
+            response = await client.collections.list_users(
+                id=id,
+                offset=offset,
+                limit=limit,
+            )
+        for user in response["results"]:  # type: ignore
+            click.echo(json.dumps(user, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 4 - 6
cli/commands/database.py

@@ -82,7 +82,6 @@ async def upgrade(schema, revision):
         click.echo(
         click.echo(
             f"Running database upgrade for schema {schema or 'default'}..."
             f"Running database upgrade for schema {schema or 'default'}..."
         )
         )
-        print(f"Upgrading revision = {revision}")
         command = f"upgrade {revision}" if revision else "upgrade"
         command = f"upgrade {revision}" if revision else "upgrade"
         result = await run_alembic_command(command, schema_name=schema)
         result = await run_alembic_command(command, schema_name=schema)
 
 
@@ -104,11 +103,10 @@ async def upgrade(schema, revision):
 @click.option("--revision", help="Downgrade to a specific revision")
 @click.option("--revision", help="Downgrade to a specific revision")
 async def downgrade(schema, revision):
 async def downgrade(schema, revision):
     """Downgrade database schema to the previous revision or a specific revision."""
     """Downgrade database schema to the previous revision or a specific revision."""
-    if not revision:
-        if not click.confirm(
-            "No revision specified. This will downgrade the database by one revision. Continue?"
-        ):
-            return
+    if not revision and not click.confirm(
+        "No revision specified. This will downgrade the database by one revision. Continue?"
+    ):
+        return
 
 
     try:
     try:
         db_url = get_database_url_from_env(log=False)
         db_url = get_database_url_from_env(log=False)

+ 303 - 127
cli/commands/documents.py

@@ -2,15 +2,23 @@ import json
 import os
 import os
 import tempfile
 import tempfile
 import uuid
 import uuid
+from builtins import list as _list
+from typing import Any, Optional, Sequence
 from urllib.parse import urlparse
 from urllib.parse import urlparse
+from uuid import UUID
 
 
 import asyncclick as click
 import asyncclick as click
 import requests
 import requests
 from asyncclick import pass_context
 from asyncclick import pass_context
+from rich.box import ROUNDED
+from rich.console import Console
+from rich.table import Table
 
 
 from cli.utils.param_types import JSON
 from cli.utils.param_types import JSON
 from cli.utils.timer import timer
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
+
+console = Console()
 
 
 
 
 @click.group()
 @click.group()
@@ -31,15 +39,21 @@ def documents():
     "--run-without-orchestration", is_flag=True, help="Run with orchestration"
     "--run-without-orchestration", is_flag=True, help="Run with orchestration"
 )
 )
 @pass_context
 @pass_context
-async def create(ctx, file_paths, ids, metadatas, run_without_orchestration):
+async def create(
+    ctx: click.Context,
+    file_paths: tuple[str, ...],
+    ids: Optional[tuple[str, ...]] = None,
+    metadatas: Optional[Sequence[dict[str, Any]]] = None,
+    run_without_orchestration: bool = False,
+):
     """Ingest files into R2R."""
     """Ingest files into R2R."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
     run_with_orchestration = not run_without_orchestration
     run_with_orchestration = not run_without_orchestration
-    responses = []
+    responses: _list[dict[str, Any]] = []
 
 
     for idx, file_path in enumerate(file_paths):
     for idx, file_path in enumerate(file_paths):
         with timer():
         with timer():
-            current_id = [ids[idx]] if ids and idx < len(ids) else None
+            current_id = ids[idx] if ids and idx < len(ids) else None
             current_metadata = (
             current_metadata = (
                 metadatas[idx] if metadatas and idx < len(metadatas) else None
                 metadatas[idx] if metadatas and idx < len(metadatas) else None
             )
             )
@@ -47,74 +61,181 @@ async def create(ctx, file_paths, ids, metadatas, run_without_orchestration):
             click.echo(
             click.echo(
                 f"Processing file {idx + 1}/{len(file_paths)}: {file_path}"
                 f"Processing file {idx + 1}/{len(file_paths)}: {file_path}"
             )
             )
-            response = await client.documents.create(
-                file_path=file_path,
-                metadata=current_metadata,
-                id=current_id,
-                run_with_orchestration=run_with_orchestration,
-            )
-            responses.append(response)
-            click.echo(json.dumps(response, indent=2))
-            click.echo("-" * 40)
+            try:
+                response = await client.documents.create(
+                    file_path=file_path,
+                    metadata=current_metadata,
+                    id=current_id,
+                    run_with_orchestration=run_with_orchestration,
+                )
+                responses.append(response)  # type: ignore
+                click.echo(json.dumps(response, indent=2))
+                click.echo("-" * 40)
+            except R2RException as e:
+                click.echo(str(e), err=True)
+            except Exception as e:
+                click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
     click.echo(f"\nProcessed {len(responses)} files successfully.")
     click.echo(f"\nProcessed {len(responses)} files successfully.")
 
 
 
 
 @documents.command()
 @documents.command()
-@click.argument("file_path", required=True, type=click.Path(exists=True))
-@click.option("--id", required=True, help="Existing document ID to update")
+@click.option("--ids", multiple=True, help="Document IDs to fetch")
 @click.option(
 @click.option(
-    "--metadata", type=JSON, help="Metadatas for ingestion as a JSON string"
+    "--offset",
+    default=0,
+    help="The offset to start from. Defaults to 0.",
 )
 )
 @click.option(
 @click.option(
-    "--run-without-orchestration", is_flag=True, help="Run with orchestration"
+    "--limit",
+    default=100,
+    help="The maximum number of nodes to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def update(ctx, file_path, id, metadata, run_without_orchestration):
-    """Update an existing file in R2R."""
+async def list(
+    ctx: click.Context,
+    ids: Optional[tuple[str, ...]] = None,
+    offset: int = 0,
+    limit: int = 100,
+) -> None:
+    """Get an overview of documents."""
+    ids = list(ids) if ids else None
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
-    run_with_orchestration = not run_without_orchestration
-    responses = []
 
 
-    with timer():
-        click.echo(f"Updating file {id}: {file_path}")
-        response = await client.documents.update(
-            file_path=file_path,
-            metadata=metadata,
-            id=id,
-            run_with_orchestration=run_with_orchestration,
+    try:
+        with timer():
+            response = await client.documents.list(
+                ids=ids,
+                offset=offset,
+                limit=limit,
+            )
+
+        table = Table(
+            title="[bold blue]Documents[/bold blue]",
+            show_header=True,
+            header_style="bold white on blue",
+            border_style="blue",
+            box=ROUNDED,
+            pad_edge=False,
+            collapse_padding=True,
+            show_lines=True,
         )
         )
-        responses.append(response)
-        click.echo(json.dumps(response, indent=2))
-        click.echo("-" * 40)
 
 
-    click.echo(f"Updated file {id} file successfully.")
+        # Add columns based on your document structure
+        table.add_column("ID", style="bright_yellow", no_wrap=True)
+        table.add_column("Type", style="bright_magenta")
+        table.add_column("Title", style="bright_green")
+        table.add_column("Ingestion Status", style="bright_cyan")
+        table.add_column("Extraction Status", style="bright_cyan")
+        table.add_column("Summary", style="bright_white")
+        table.add_column("Created At", style="bright_white")
+
+        for document in response["results"]:  # type: ignore
+            table.add_row(
+                document.get("id", ""),
+                document.get("document_type", ""),
+                document.get("title", ""),
+                document.get("ingestion_status", ""),
+                document.get("extraction_status", ""),
+                document.get("summary", ""),
+                document.get("created_at", "")[:19],
+            )
+
+        console = Console()
+        console.print("\n")
+        console.print(table)
+        console.print(
+            f"\n[dim]Showing {len(response['results'])} documents (offset: {offset}, limit: {limit})[/dim]"  # type: ignore
+        )
+
+    except R2RException as e:
+        console.print(f"[bold red]Error:[/bold red] {str(e)}")
+    except Exception as e:
+        console.print(f"[bold red]Unexpected error:[/bold red] {str(e)}")
 
 
 
 
 @documents.command()
 @documents.command()
 @click.argument("id", required=True, type=str)
 @click.argument("id", required=True, type=str)
 @pass_context
 @pass_context
-async def retrieve(ctx, id):
+async def retrieve(ctx: click.Context, id: UUID):
     """Retrieve a document by ID."""
     """Retrieve a document by ID."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
+    console = Console()
 
 
-    with timer():
-        response = await client.documents.retrieve(id=id)
+    try:
+        with timer():
+            response = await client.documents.retrieve(id=id)
+
+        # Get the actual document data from the results
+        document = response["results"]  # type: ignore
+
+        metadata_table = Table(
+            show_header=True,
+            header_style="bold white on blue",
+            border_style="blue",
+            box=ROUNDED,
+            title="[bold blue]Document Details[/bold blue]",
+            show_lines=True,
+        )
 
 
-    click.echo(json.dumps(response, indent=2))
+        metadata_table.add_column("Field", style="bright_yellow")
+        metadata_table.add_column("Value", style="bright_white")
+
+        # Add core document information
+        core_fields = [
+            ("ID", document.get("id", "")),
+            ("Type", document.get("document_type", "")),
+            ("Title", document.get("title", "")),
+            ("Created At", document.get("created_at", "")[:19]),
+            ("Updated At", document.get("updated_at", "")[:19]),
+            ("Ingestion Status", document.get("ingestion_status", "")),
+            ("Extraction Status", document.get("extraction_status", "")),
+            ("Size", f"{document.get('size_in_bytes', 0):,} bytes"),
+        ]
+
+        for field, value in core_fields:
+            metadata_table.add_row(field, str(value))
+
+        # Add metadata section if it exists
+        if "metadata" in document:
+            metadata_table.add_row(
+                "[bold]Metadata[/bold]", "", style="bright_blue"
+            )
+            for key, value in document["metadata"].items():
+                metadata_table.add_row(f"  {key}", str(value))
+
+        # Add summary if it exists
+        if "summary" in document:
+            metadata_table.add_row(
+                "[bold]Summary[/bold]",
+                document["summary"],
+            )
+
+        console.print("\n")
+        console.print(metadata_table)
+        console.print("\n")
+
+    except R2RException as e:
+        console.print(f"[bold red]Error:[/bold red] {str(e)}")
+    except Exception as e:
+        console.print(f"[bold red]Unexpected error:[/bold red] {str(e)}")
 
 
 
 
 @documents.command()
 @documents.command()
 @click.argument("id", required=True, type=str)
 @click.argument("id", required=True, type=str)
 @pass_context
 @pass_context
-async def delete(ctx, id):
+async def delete(ctx: click.Context, id):
     """Delete a document by ID."""
     """Delete a document by ID."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.documents.delete(id=id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.documents.delete(id=id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @documents.command()
 @documents.command()
@@ -130,18 +251,53 @@ async def delete(ctx, id):
     help="The maximum number of nodes to return. Defaults to 100.",
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def list_chunks(ctx, id, offset, limit):
-    """List collections for a specific document."""
+async def list_chunks(ctx: click.Context, id, offset, limit):
+    """List chunks for a specific document."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
+    console = Console()
 
 
-    with timer():
-        response = await client.documents.list_chunks(
-            id=id,
-            offset=offset,
-            limit=limit,
+    try:
+        with timer():
+            response = await client.documents.list_chunks(
+                id=id,
+                offset=offset,
+                limit=limit,
+            )
+
+        table = Table(
+            title="[bold blue]Document Chunks[/bold blue]",
+            show_header=True,
+            header_style="bold white on blue",
+            border_style="blue",
+            box=ROUNDED,
+            pad_edge=False,
+            collapse_padding=True,
+            show_lines=True,
         )
         )
 
 
-    click.echo(json.dumps(response, indent=2))
+        table.add_column("ID", style="bright_yellow", no_wrap=True)
+        table.add_column("Text", style="bright_white")
+
+        for chunk in response["results"]:  # type: ignore
+            table.add_row(
+                chunk.get("id", ""),
+                (
+                    chunk.get("text", "")[:200] + "..."
+                    if len(chunk.get("text", "")) > 200
+                    else chunk.get("text", "")
+                ),
+            )
+
+        console.print("\n")
+        console.print(table)
+        console.print(
+            f"\n[dim]Showing {len(response['results'])} chunks (offset: {offset}, limit: {limit})[/dim]"  # type: ignore
+        )
+
+    except R2RException as e:
+        console.print(f"[bold red]Error:[/bold red] {str(e)}")
+    except Exception as e:
+        console.print(f"[bold red]Unexpected error:[/bold red] {str(e)}")
 
 
 
 
 @documents.command()
 @documents.command()
@@ -157,18 +313,53 @@ async def list_chunks(ctx, id, offset, limit):
     help="The maximum number of nodes to return. Defaults to 100.",
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def list_collections(ctx, id, offset, limit):
+async def list_collections(ctx: click.Context, id, offset, limit):
     """List collections for a specific document."""
     """List collections for a specific document."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
+    console = Console()
 
 
-    with timer():
-        response = await client.documents.list_collections(
-            id=id,
-            offset=offset,
-            limit=limit,
+    try:
+        with timer():
+            response = await client.documents.list_collections(
+                id=id,
+                offset=offset,
+                limit=limit,
+            )
+
+        table = Table(
+            title="[bold blue]Document Collections[/bold blue]",
+            show_header=True,
+            header_style="bold white on blue",
+            border_style="blue",
+            box=ROUNDED,
+            pad_edge=False,
+            collapse_padding=True,
+            show_lines=True,
         )
         )
 
 
-    click.echo(json.dumps(response, indent=2))
+        table.add_column("ID", style="bright_yellow", no_wrap=True)
+        table.add_column("Name", style="bright_green")
+        table.add_column("Description", style="bright_white")
+        table.add_column("Created At", style="bright_white")
+
+        for collection in response["results"]:  # type: ignore
+            table.add_row(
+                collection.get("id", ""),
+                collection.get("name", ""),
+                collection.get("description", ""),
+                collection.get("created_at", "")[:19],
+            )
+
+        console.print("\n")
+        console.print(table)
+        console.print(
+            f"\n[dim]Showing {len(response['results'])} collections (offset: {offset}, limit: {limit})[/dim]"  # type: ignore
+        )
+
+    except R2RException as e:
+        console.print(f"[bold red]Error:[/bold red] {str(e)}")
+    except Exception as e:
+        console.print(f"[bold red]Unexpected error:[/bold red] {str(e)}")
 
 
 
 
 # TODO
 # TODO
@@ -228,7 +419,9 @@ async def ingest_files_from_urls(client, urls):
     help="Run without orchestration",
     help="Run without orchestration",
 )
 )
 @pass_context
 @pass_context
-async def extract(ctx, id, run_type, settings, run_without_orchestration):
+async def extract(
+    ctx: click.Context, id, run_type, settings, run_without_orchestration
+):
     """Extract entities and relationships from a document."""
     """Extract entities and relationships from a document."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
     run_with_orchestration = not run_without_orchestration
     run_with_orchestration = not run_without_orchestration
@@ -262,19 +455,25 @@ async def extract(ctx, id, run_type, settings, run_without_orchestration):
     help="Include embeddings in response",
     help="Include embeddings in response",
 )
 )
 @pass_context
 @pass_context
-async def list_entities(ctx, id, offset, limit, include_embeddings):
+async def list_entities(
+    ctx: click.Context, id, offset, limit, include_embeddings
+):
     """List entities extracted from a document."""
     """List entities extracted from a document."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.documents.list_entities(
-            id=id,
-            offset=offset,
-            limit=limit,
-            include_embeddings=include_embeddings,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.documents.list_entities(
+                id=id,
+                offset=offset,
+                limit=limit,
+                include_embeddings=include_embeddings,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @documents.command()
 @documents.command()
@@ -301,50 +500,52 @@ async def list_entities(ctx, id, offset, limit, include_embeddings):
 )
 )
 @pass_context
 @pass_context
 async def list_relationships(
 async def list_relationships(
-    ctx, id, offset, limit, entity_names, relationship_types
+    ctx: click.Context, id, offset, limit, entity_names, relationship_types
 ):
 ):
     """List relationships extracted from a document."""
     """List relationships extracted from a document."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.documents.list_relationships(
-            id=id,
-            offset=offset,
-            limit=limit,
-            entity_names=list(entity_names) if entity_names else None,
-            relationship_types=(
-                list(relationship_types) if relationship_types else None
-            ),
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.documents.list_relationships(
+                id=id,
+                offset=offset,
+                limit=limit,
+                entity_names=list(entity_names) if entity_names else None,
+                relationship_types=(
+                    list(relationship_types) if relationship_types else None
+                ),
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @documents.command()
 @documents.command()
-@click.option(
-    "--v2", is_flag=True, help="use aristotle_v2.txt (a smaller file)"
-)
-@click.option(
-    "--v3", is_flag=True, help="use aristotle_v3.txt (a larger file)"
-)
 @pass_context
 @pass_context
-async def create_sample(ctx, v2=True, v3=False):
+async def create_sample(ctx: click.Context) -> None:
     """Ingest the first sample file into R2R."""
     """Ingest the first sample file into R2R."""
-    sample_file_url = f"https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/aristotle.txt"
+    sample_file_url = "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/aristotle.txt"
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await ingest_files_from_urls(client, [sample_file_url])
-    click.echo(
-        f"Sample file ingestion completed. Ingest files response:\n\n{response}"
-    )
+    try:
+        with timer():
+            response = await ingest_files_from_urls(client, [sample_file_url])
+        click.echo(
+            f"Sample file ingestion completed. Ingest files response:\n\n{response}"
+        )
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @documents.command()
 @documents.command()
 @pass_context
 @pass_context
-async def create_samples(ctx):
+async def create_samples(ctx: click.Context) -> None:
     """Ingest multiple sample files into R2R."""
     """Ingest multiple sample files into R2R."""
-    client: R2RAsyncClient = ctx.obj
     urls = [
     urls = [
         "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_3.html",
         "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_3.html",
         "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_4.html",
         "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_4.html",
@@ -356,38 +557,13 @@ async def create_samples(ctx):
         "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_2.html",
         "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_2.html",
         "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/aristotle.txt",
         "https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/aristotle.txt",
     ]
     ]
-    with timer():
-        response = await ingest_files_from_urls(client, urls)
-
-    click.echo(
-        f"Sample files ingestion completed. Ingest files response:\n\n{response}"
-    )
-
-
-@documents.command()
-@click.option("--ids", multiple=True, help="Document IDs to fetch")
-@click.option(
-    "--offset",
-    default=0,
-    help="The offset to start from. Defaults to 0.",
-)
-@click.option(
-    "--limit",
-    default=100,
-    help="The maximum number of nodes to return. Defaults to 100.",
-)
-@pass_context
-async def list(ctx, ids, offset, limit):
-    """Get an overview of documents."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
-    ids = list(ids) if ids else None
 
 
-    with timer():
-        response = await client.documents.list(
-            ids=ids,
-            offset=offset,
-            limit=limit,
+    try:
+        with timer():
+            response = await ingest_files_from_urls(client, urls)
+        click.echo(
+            f"Sample files ingestion completed. Ingest files response:\n\n{response}"
         )
         )
-
-    for document in response["results"]:
-        click.echo(document)
+    except R2RException as e:
+        click.echo(str(e), err=True)

+ 220 - 144
cli/commands/graphs.py

@@ -5,7 +5,7 @@ from asyncclick import pass_context
 
 
 from cli.utils.param_types import JSON
 from cli.utils.param_types import JSON
 from cli.utils.timer import timer
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 
 
 @click.group()
 @click.group()
@@ -29,45 +29,59 @@ def graphs():
     help="The maximum number of graphs to return. Defaults to 100.",
     help="The maximum number of graphs to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def list(ctx, collection_ids, offset, limit):
+async def list(ctx: click.Context, collection_ids, offset, limit):
     """List available graphs."""
     """List available graphs."""
-    client: R2RAsyncClient = ctx.obj
     collection_ids = list(collection_ids) if collection_ids else None
     collection_ids = list(collection_ids) if collection_ids else None
+    client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.list(
-            collection_ids=collection_ids,
-            offset=offset,
-            limit=limit,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.list(
+                collection_ids=collection_ids,
+                offset=offset,
+                limit=limit,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @pass_context
 @pass_context
-async def retrieve(ctx, collection_id):
+async def retrieve(ctx: click.Context, collection_id):
     """Retrieve a specific graph by collection ID."""
     """Retrieve a specific graph by collection ID."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.retrieve(collection_id=collection_id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.retrieve(
+                collection_id=collection_id
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @pass_context
 @pass_context
-async def reset(ctx, collection_id):
+async def reset(ctx: click.Context, collection_id):
     """Reset a graph, removing all its data."""
     """Reset a graph, removing all its data."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.reset(collection_id=collection_id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.reset(collection_id=collection_id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
@@ -75,18 +89,22 @@ async def reset(ctx, collection_id):
 @click.option("--name", help="New name for the graph")
 @click.option("--name", help="New name for the graph")
 @click.option("--description", help="New description for the graph")
 @click.option("--description", help="New description for the graph")
 @pass_context
 @pass_context
-async def update(ctx, collection_id, name, description):
+async def update(ctx: click.Context, collection_id, name, description):
     """Update graph information."""
     """Update graph information."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.update(
-            collection_id=collection_id,
-            name=name,
-            description=description,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.update(
+                collection_id=collection_id,
+                name=name,
+                description=description,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
@@ -102,52 +120,64 @@ async def update(ctx, collection_id, name, description):
     help="The maximum number of entities to return. Defaults to 100.",
     help="The maximum number of entities to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def list_entities(ctx, collection_id, offset, limit):
+async def list_entities(ctx: click.Context, collection_id, offset, limit):
     """List entities in a graph."""
     """List entities in a graph."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.list_entities(
-            collection_id=collection_id,
-            offset=offset,
-            limit=limit,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.list_entities(
+                collection_id=collection_id,
+                offset=offset,
+                limit=limit,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @click.argument("entity_id", required=True, type=str)
 @click.argument("entity_id", required=True, type=str)
 @pass_context
 @pass_context
-async def get_entity(ctx, collection_id, entity_id):
+async def get_entity(ctx: click.Context, collection_id, entity_id):
     """Get entity information from a graph."""
     """Get entity information from a graph."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.get_entity(
-            collection_id=collection_id,
-            entity_id=entity_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.get_entity(
+                collection_id=collection_id,
+                entity_id=entity_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @click.argument("entity_id", required=True, type=str)
 @click.argument("entity_id", required=True, type=str)
 @pass_context
 @pass_context
-async def remove_entity(ctx, collection_id, entity_id):
+async def remove_entity(ctx: click.Context, collection_id, entity_id):
     """Remove an entity from a graph."""
     """Remove an entity from a graph."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.remove_entity(
-            collection_id=collection_id,
-            entity_id=entity_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.remove_entity(
+                collection_id=collection_id,
+                entity_id=entity_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
@@ -163,52 +193,66 @@ async def remove_entity(ctx, collection_id, entity_id):
     help="The maximum number of relationships to return. Defaults to 100.",
     help="The maximum number of relationships to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def list_relationships(ctx, collection_id, offset, limit):
+async def list_relationships(ctx: click.Context, collection_id, offset, limit):
     """List relationships in a graph."""
     """List relationships in a graph."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
-
-    with timer():
-        response = await client.graphs.list_relationships(
-            collection_id=collection_id,
-            offset=offset,
-            limit=limit,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.list_relationships(
+                collection_id=collection_id,
+                offset=offset,
+                limit=limit,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @click.argument("relationship_id", required=True, type=str)
 @click.argument("relationship_id", required=True, type=str)
 @pass_context
 @pass_context
-async def get_relationship(ctx, collection_id, relationship_id):
+async def get_relationship(ctx: click.Context, collection_id, relationship_id):
     """Get relationship information from a graph."""
     """Get relationship information from a graph."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.get_relationship(
-            collection_id=collection_id,
-            relationship_id=relationship_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.get_relationship(
+                collection_id=collection_id,
+                relationship_id=relationship_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @click.argument("relationship_id", required=True, type=str)
 @click.argument("relationship_id", required=True, type=str)
 @pass_context
 @pass_context
-async def remove_relationship(ctx, collection_id, relationship_id):
+async def remove_relationship(
+    ctx: click.Context, collection_id, relationship_id
+):
     """Remove a relationship from a graph."""
     """Remove a relationship from a graph."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.remove_relationship(
-            collection_id=collection_id,
-            relationship_id=relationship_id,
-        )
+    try:
+        with timer():
+            response = await client.graphs.remove_relationship(
+                collection_id=collection_id,
+                relationship_id=relationship_id,
+            )
 
 
-    click.echo(json.dumps(response, indent=2))
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
@@ -224,21 +268,29 @@ async def remove_relationship(ctx, collection_id, relationship_id):
 )
 )
 @pass_context
 @pass_context
 async def build(
 async def build(
-    ctx, collection_id, settings, run_type, run_without_orchestration
+    ctx: click.Context,
+    collection_id,
+    settings,
+    run_type,
+    run_without_orchestration,
 ):
 ):
     """Build a graph with specified settings."""
     """Build a graph with specified settings."""
-    client: R2RAsyncClient = ctx.obj
     run_with_orchestration = not run_without_orchestration
     run_with_orchestration = not run_without_orchestration
+    client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.build(
-            collection_id=collection_id,
-            settings=settings,
-            run_type=run_type,
-            run_with_orchestration=run_with_orchestration,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.build(
+                collection_id=collection_id,
+                settings=settings,
+                run_type=run_type,
+                run_with_orchestration=run_with_orchestration,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
@@ -254,35 +306,43 @@ async def build(
     help="The maximum number of communities to return. Defaults to 100.",
     help="The maximum number of communities to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def list_communities(ctx, collection_id, offset, limit):
+async def list_communities(ctx: click.Context, collection_id, offset, limit):
     """List communities in a graph."""
     """List communities in a graph."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.list_communities(
-            collection_id=collection_id,
-            offset=offset,
-            limit=limit,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.list_communities(
+                collection_id=collection_id,
+                offset=offset,
+                limit=limit,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @click.argument("community_id", required=True, type=str)
 @click.argument("community_id", required=True, type=str)
 @pass_context
 @pass_context
-async def get_community(ctx, collection_id, community_id):
+async def get_community(ctx: click.Context, collection_id, community_id):
     """Get community information from a graph."""
     """Get community information from a graph."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.get_community(
-            collection_id=collection_id,
-            community_id=community_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.get_community(
+                collection_id=collection_id,
+                community_id=community_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
@@ -305,7 +365,7 @@ async def get_community(ctx, collection_id, community_id):
 )
 )
 @pass_context
 @pass_context
 async def update_community(
 async def update_community(
-    ctx,
+    ctx: click.Context,
     collection_id,
     collection_id,
     community_id,
     community_id,
     name,
     name,
@@ -319,64 +379,80 @@ async def update_community(
     """Update community information."""
     """Update community information."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.update_community(
-            collection_id=collection_id,
-            community_id=community_id,
-            name=name,
-            summary=summary,
-            findings=findings,
-            rating=rating,
-            rating_explanation=rating_explanation,
-            level=level,
-            attributes=attributes,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.update_community(
+                collection_id=collection_id,
+                community_id=community_id,
+                name=name,
+                summary=summary,
+                findings=findings,
+                rating=rating,
+                rating_explanation=rating_explanation,
+                level=level,
+                attributes=attributes,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @click.argument("community_id", required=True, type=str)
 @click.argument("community_id", required=True, type=str)
 @pass_context
 @pass_context
-async def delete_community(ctx, collection_id, community_id):
+async def delete_community(ctx: click.Context, collection_id, community_id):
     """Delete a community from a graph."""
     """Delete a community from a graph."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.delete_community(
-            collection_id=collection_id,
-            community_id=community_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.delete_community(
+                collection_id=collection_id,
+                community_id=community_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @pass_context
 @pass_context
-async def pull(ctx, collection_id):
+async def pull(ctx: click.Context, collection_id):
     """Pull documents into a graph."""
     """Pull documents into a graph."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.pull(collection_id=collection_id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.pull(collection_id=collection_id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @graphs.command()
 @graphs.command()
 @click.argument("collection_id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @click.argument("document_id", required=True, type=str)
 @click.argument("document_id", required=True, type=str)
 @pass_context
 @pass_context
-async def remove_document(ctx, collection_id, document_id):
+async def remove_document(ctx: click.Context, collection_id, document_id):
     """Remove a document from a graph."""
     """Remove a document from a graph."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.graphs.remove_document(
-            collection_id=collection_id,
-            document_id=document_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.graphs.remove_document(
+                collection_id=collection_id,
+                document_id=document_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 39 - 41
cli/commands/indices.py

@@ -4,7 +4,7 @@ import asyncclick as click
 from asyncclick import pass_context
 from asyncclick import pass_context
 
 
 from cli.utils.timer import timer
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 
 
 @click.group()
 @click.group()
@@ -25,65 +25,63 @@ def indices():
     help="The maximum number of nodes to return. Defaults to 100.",
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
-async def list(ctx, offset, limit):
+async def list(ctx: click.Context, offset, limit):
     """Get an overview of indices."""
     """Get an overview of indices."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.indices.list(
-            offset=offset,
-            limit=limit,
-        )
+    try:
+        with timer():
+            response = await client.indices.list(
+                offset=offset,
+                limit=limit,
+            )
 
 
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+        for user in response["results"]:  # type: ignore
+            click.echo(json.dumps(user, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @indices.command()
 @indices.command()
 @click.argument("index_name", required=True, type=str)
 @click.argument("index_name", required=True, type=str)
 @click.argument("table_name", required=True, type=str)
 @click.argument("table_name", required=True, type=str)
 @pass_context
 @pass_context
-async def retrieve(ctx, index_name, table_name):
+async def retrieve(ctx: click.Context, index_name, table_name):
     """Retrieve an index by name."""
     """Retrieve an index by name."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.indices.retrieve(
-            index_name=index_name,
-            table_name=table_name,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.indices.retrieve(
+                index_name=index_name,
+                table_name=table_name,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @indices.command()
 @indices.command()
 @click.argument("index_name", required=True, type=str)
 @click.argument("index_name", required=True, type=str)
 @click.argument("table_name", required=True, type=str)
 @click.argument("table_name", required=True, type=str)
 @pass_context
 @pass_context
-async def delete(ctx, index_name, table_name):
+async def delete(ctx: click.Context, index_name, table_name):
     """Delete an index by name."""
     """Delete an index by name."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.indices.retrieve(
-            index_name=index_name,
-            table_name=table_name,
-        )
-
-    click.echo(json.dumps(response, indent=2))
-
-
-@indices.command()
-@click.argument("id", required=True, type=str)
-@pass_context
-async def list_branches(ctx, id):
-    """List all branches in a conversation."""
-    client: R2RAsyncClient = ctx.obj
-
-    with timer():
-        response = await client.indices.list_branches(
-            id=id,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    try:
+        with timer():
+            response = await client.indices.retrieve(
+                index_name=index_name,
+                table_name=table_name,
+            )
+
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 36 - 23
cli/commands/prompts.py

@@ -4,7 +4,7 @@ import asyncclick as click
 from asyncclick import pass_context
 from asyncclick import pass_context
 
 
 from cli.utils.timer import timer
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 
 
 @click.group()
 @click.group()
@@ -15,15 +15,20 @@ def prompts():
 
 
 @prompts.command()
 @prompts.command()
 @pass_context
 @pass_context
-async def list(ctx):
+async def list(ctx: click.Context):
     """Get an overview of prompts."""
     """Get an overview of prompts."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.prompts.list()
+    try:
+        with timer():
+            response = await client.prompts.list()
 
 
-    for prompt in response["results"]:
-        click.echo(json.dumps(prompt, indent=2))
+        for prompt in response["results"]:  # type: ignore
+            click.echo(json.dumps(prompt, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @prompts.command()
 @prompts.command()
@@ -31,30 +36,38 @@ async def list(ctx):
 @click.option("--inputs", default=None, type=str)
 @click.option("--inputs", default=None, type=str)
 @click.option("--prompt-override", default=None, type=str)
 @click.option("--prompt-override", default=None, type=str)
 @pass_context
 @pass_context
-async def retrieve(ctx, name, inputs, prompt_override):
+async def retrieve(ctx: click.Context, name, inputs, prompt_override):
     """Retrieve an prompts by name."""
     """Retrieve an prompts by name."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.prompts.retrieve(
-            name=name,
-            inputs=inputs,
-            prompt_override=prompt_override,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.prompts.retrieve(
+                name=name,
+                inputs=inputs,
+                prompt_override=prompt_override,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @prompts.command()
 @prompts.command()
 @click.argument("name", required=True, type=str)
 @click.argument("name", required=True, type=str)
 @pass_context
 @pass_context
-async def delete(ctx, name):
-    """Delete an index by name."""
+async def delete(ctx: click.Context, name):
+    """Delete a prompt by name."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.prompts.delete(
-            name=name,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.prompts.delete(
+                name=name,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 52 - 39
cli/commands/retrieval.py

@@ -5,7 +5,7 @@ from asyncclick import pass_context
 
 
 from cli.utils.param_types import JSON
 from cli.utils.param_types import JSON
 from cli.utils.timer import timer
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 
 
 @click.group()
 @click.group()
@@ -52,9 +52,8 @@ def retrieval():
     help="Use search over document chunks?",
     help="Use search over document chunks?",
 )
 )
 @pass_context
 @pass_context
-async def search(ctx, query, **kwargs):
+async def search(ctx: click.Context, query, **kwargs):
     """Perform a search query."""
     """Perform a search query."""
-    client: R2RAsyncClient = ctx.obj
     search_settings = {
     search_settings = {
         k: v
         k: v
         for k, v in kwargs.items()
         for k, v in kwargs.items()
@@ -78,28 +77,36 @@ async def search(ctx, query, **kwargs):
     if chunk_search_enabled != None:
     if chunk_search_enabled != None:
         search_settings["chunk_settings"] = {"enabled": chunk_search_enabled}
         search_settings["chunk_settings"] = {"enabled": chunk_search_enabled}
 
 
-    with timer():
-        results = await client.retrieval.search(
-            query,
-            "custom",
-            search_settings,
-        )
-
-        if isinstance(results, dict) and "results" in results:
-            results = results["results"]
-
-        if "chunk_search_results" in results:
-            click.echo("Vector search results:")
-            for result in results["chunk_search_results"]:
-                click.echo(json.dumps(result, indent=2))
+    client: R2RAsyncClient = ctx.obj
 
 
-        if (
-            "graph_search_results" in results
-            and results["graph_search_results"]
-        ):
-            click.echo("KG search results:")
-            for result in results["graph_search_results"]:
-                click.echo(json.dumps(result, indent=2))
+    print("client.base_url = ", client.base_url)
+    try:
+        with timer():
+            results = await client.retrieval.search(
+                query,
+                "custom",
+                search_settings,
+            )
+
+            if isinstance(results, dict) and "results" in results:
+                results = results["results"]
+
+            if "chunk_search_results" in results:  # type: ignore
+                click.echo("Vector search results:")
+                for result in results["chunk_search_results"]:  # type: ignore
+                    click.echo(json.dumps(result, indent=2))
+
+            if (
+                "graph_search_results" in results  # type: ignore
+                and results["graph_search_results"]  # type: ignore
+            ):
+                click.echo("KG search results:")
+                for result in results["graph_search_results"]:  # type: ignore
+                    click.echo(json.dumps(result, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @retrieval.command()
 @retrieval.command()
@@ -142,9 +149,8 @@ async def search(ctx, query, **kwargs):
 @click.option("--stream", is_flag=True, help="Stream the RAG response")
 @click.option("--stream", is_flag=True, help="Stream the RAG response")
 @click.option("--rag-model", default=None, help="Model for RAG")
 @click.option("--rag-model", default=None, help="Model for RAG")
 @pass_context
 @pass_context
-async def rag(ctx, query, **kwargs):
+async def rag(ctx: click.Context, query, **kwargs):
     """Perform a RAG query."""
     """Perform a RAG query."""
-    client: R2RAsyncClient = ctx.obj
     rag_generation_config = {
     rag_generation_config = {
         "stream": kwargs.get("stream", False),
         "stream": kwargs.get("stream", False),
     }
     }
@@ -174,16 +180,23 @@ async def rag(ctx, query, **kwargs):
     if chunk_search_enabled != None:
     if chunk_search_enabled != None:
         search_settings["chunk_settings"] = {"enabled": chunk_search_enabled}
         search_settings["chunk_settings"] = {"enabled": chunk_search_enabled}
 
 
-    with timer():
-        response = await client.retrieval.rag(
-            query=query,
-            rag_generation_config=rag_generation_config,
-            search_settings={**search_settings},
-        )
-
-        if rag_generation_config.get("stream"):
-            async for chunk in response:
-                click.echo(chunk, nl=False)
-            click.echo()
-        else:
-            click.echo(json.dumps(response["results"]["completion"], indent=2))
+    client: R2RAsyncClient = ctx.obj
+
+    try:
+        with timer():
+            response = await client.retrieval.rag(
+                query=query,
+                rag_generation_config=rag_generation_config,
+                search_settings={**search_settings},
+            )
+
+            if rag_generation_config.get("stream"):
+                async for chunk in response:  # type: ignore
+                    click.echo(chunk, nl=False)
+                click.echo()
+            else:
+                click.echo(json.dumps(response["results"]["completion"], indent=2))  # type: ignore
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 40 - 17
cli/commands/system.py

@@ -18,7 +18,7 @@ from cli.utils.docker_utils import (
     wait_for_container_health,
     wait_for_container_health,
 )
 )
 from cli.utils.timer import timer
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 
 
 @click.group()
 @click.group()
@@ -29,35 +29,53 @@ def system():
 
 
 @cli.command()
 @cli.command()
 @pass_context
 @pass_context
-async def health(ctx):
+async def health(ctx: click.Context):
     """Check the health of the server."""
     """Check the health of the server."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
-    with timer():
-        response = await client.system.health()
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.system.health()
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+        raise
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
+        raise
 
 
 
 
 @system.command()
 @system.command()
 @pass_context
 @pass_context
-async def settings(ctx):
+async def settings(ctx: click.Context):
     """Retrieve application settings."""
     """Retrieve application settings."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
-    with timer():
-        response = await client.system.settings()
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.system.settings()
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+        raise
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
+        raise
 
 
 
 
 @system.command()
 @system.command()
 @pass_context
 @pass_context
-async def status(ctx):
+async def status(ctx: click.Context):
     """Get statistics about the server, including the start time, uptime, CPU usage, and memory usage."""
     """Get statistics about the server, including the start time, uptime, CPU usage, and memory usage."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
-    with timer():
-        response = await client.system.status()
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.system.status()
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+        raise
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
+        raise
 
 
 
 
 @cli.command()
 @cli.command()
@@ -400,4 +418,9 @@ def version():
     """Reports the SDK version."""
     """Reports the SDK version."""
     from importlib.metadata import version
     from importlib.metadata import version
 
 
-    click.echo(json.dumps(version("r2r"), indent=2))
+    try:
+        r2r_version = version("r2r")
+        click.echo(json.dumps(r2r_version, indent=2))
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
+        raise

+ 98 - 53
cli/commands/users.py

@@ -1,10 +1,12 @@
 import json
 import json
+from builtins import list as _list
+from uuid import UUID
 
 
 import asyncclick as click
 import asyncclick as click
 from asyncclick import pass_context
 from asyncclick import pass_context
 
 
 from cli.utils.timer import timer
 from cli.utils.timer import timer
-from r2r import R2RAsyncClient
+from r2r import R2RAsyncClient, R2RException
 
 
 
 
 @click.group()
 @click.group()
@@ -17,68 +19,98 @@ def users():
 @click.argument("email", required=True, type=str)
 @click.argument("email", required=True, type=str)
 @click.argument("password", required=True, type=str)
 @click.argument("password", required=True, type=str)
 @pass_context
 @pass_context
-async def create(ctx, email, password):
+async def create(ctx: click.Context, email: str, password: str):
     """Create a new user."""
     """Create a new user."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.users.create(email=email, password=password)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.users.create(
+                email=email,
+                password=password,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @users.command()
 @users.command()
-@click.option("--ids", multiple=True, help="Document IDs to fetch")
+@click.option("--ids", multiple=True, help="Document IDs to fetch", type=str)
 @click.option(
 @click.option(
     "--offset",
     "--offset",
     default=0,
     default=0,
     help="The offset to start from. Defaults to 0.",
     help="The offset to start from. Defaults to 0.",
+    type=int,
 )
 )
 @click.option(
 @click.option(
     "--limit",
     "--limit",
     default=100,
     default=100,
     help="The maximum number of nodes to return. Defaults to 100.",
     help="The maximum number of nodes to return. Defaults to 100.",
+    type=int,
 )
 )
 @pass_context
 @pass_context
-async def list(ctx, ids, offset, limit):
+async def list(
+    ctx: click.Context,
+    ids: tuple[str, ...],
+    offset: int,
+    limit: int,
+):
     """Get an overview of users."""
     """Get an overview of users."""
-    client: R2RAsyncClient = ctx.obj
-    ids = list(ids) if ids else None
+    uuids: _list[str | UUID] | None = (
+        [UUID(id_) for id_ in ids] if ids else None
+    )
 
 
-    with timer():
-        response = await client.users.list(
-            ids=ids,
-            offset=offset,
-            limit=limit,
-        )
-
-    for user in response["results"]:
-        click.echo(json.dumps(user, indent=2))
+    client: R2RAsyncClient = ctx.obj
+    try:
+        with timer():
+            response = await client.users.list(
+                ids=uuids,
+                offset=offset,
+                limit=limit,
+            )
+
+        for user in response["results"]:  # type: ignore
+            click.echo(json.dumps(user, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @users.command()
 @users.command()
 @click.argument("id", required=True, type=str)
 @click.argument("id", required=True, type=str)
 @pass_context
 @pass_context
-async def retrieve(ctx, id):
+async def retrieve(ctx: click.Context, id):
     """Retrieve a user by ID."""
     """Retrieve a user by ID."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.users.retrieve(id=id)
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.users.retrieve(id=id)
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @users.command()
 @users.command()
 @pass_context
 @pass_context
-async def me(ctx):
+async def me(ctx: click.Context):
     """Retrieve the current user."""
     """Retrieve the current user."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.users.me()
+    try:
+        with timer():
+            response = await client.users.me()
 
 
-    click.echo(json.dumps(response, indent=2))
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @users.command()
 @users.command()
@@ -94,50 +126,63 @@ async def me(ctx):
     help="The maximum number of nodes to return. Defaults to 100.",
     help="The maximum number of nodes to return. Defaults to 100.",
 )
 )
 @pass_context
 @pass_context
async def list_collections(ctx: click.Context, id, offset, limit):
    """List collections for a specific user.

    Prints each collection in the paginated result as indented JSON,
    or the error message to stderr on failure.
    """
    client: R2RAsyncClient = ctx.obj

    try:
        with timer():
            response = await client.users.list_collections(
                id=id,
                offset=offset,
                limit=limit,
            )
        for collection in response["results"]:  # type: ignore
            click.echo(json.dumps(collection, indent=2))
    except R2RException as e:
        click.echo(str(e), err=True)
    except Exception as e:
        # f-string already yields a str; the redundant str(...) wrapper is removed.
        click.echo(f"An unexpected error occurred: {e}", err=True)
 
 
 
 
 @users.command()
 @users.command()
 @click.argument("id", required=True, type=str)
 @click.argument("id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @pass_context
 @pass_context
-async def add_to_collection(ctx, id, collection_id):
+async def add_to_collection(ctx: click.Context, id, collection_id):
     """Retrieve a user by ID."""
     """Retrieve a user by ID."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.users.add_to_collection(
-            id=id,
-            collection_id=collection_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.users.add_to_collection(
+                id=id,
+                collection_id=collection_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)
 
 
 
 
 @users.command()
 @users.command()
 @click.argument("id", required=True, type=str)
 @click.argument("id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @click.argument("collection_id", required=True, type=str)
 @pass_context
 @pass_context
-async def remove_from_collection(ctx, id, collection_id):
+async def remove_from_collection(ctx: click.Context, id, collection_id):
     """Retrieve a user by ID."""
     """Retrieve a user by ID."""
     client: R2RAsyncClient = ctx.obj
     client: R2RAsyncClient = ctx.obj
 
 
-    with timer():
-        response = await client.users.remove_from_collection(
-            id=id,
-            collection_id=collection_id,
-        )
-
-    click.echo(json.dumps(response, indent=2))
+    try:
+        with timer():
+            response = await client.users.remove_from_collection(
+                id=id,
+                collection_id=collection_id,
+            )
+        click.echo(json.dumps(response, indent=2))
+    except R2RException as e:
+        click.echo(str(e), err=True)
+    except Exception as e:
+        click.echo(str(f"An unexpected error occurred: {e}"), err=True)

+ 76 - 4
cli/main.py

@@ -1,6 +1,13 @@
+import json
+from typing import Any, Dict
+
+import asyncclick as click
+from rich.console import Console
+
 from cli.command_group import cli
 from cli.command_group import cli
 from cli.commands import (
 from cli.commands import (
     collections,
     collections,
+    config,
     conversations,
     conversations,
     database,
     database,
     documents,
     documents,
@@ -12,6 +19,11 @@ from cli.commands import (
     users,
     users,
 )
 )
 from cli.utils.telemetry import posthog, telemetry
 from cli.utils.telemetry import posthog, telemetry
+from r2r import R2RAsyncClient
+
+from .command_group import CONFIG_DIR, CONFIG_FILE, load_config
+
+console = Console()
 
 
 
 
 def add_command_with_telemetry(command):
 def add_command_with_telemetry(command):
@@ -39,22 +51,82 @@ add_command_with_telemetry(database.downgrade)
 add_command_with_telemetry(database.current)
 add_command_with_telemetry(database.current)
 add_command_with_telemetry(database.history)
 add_command_with_telemetry(database.history)
 
 
+add_command_with_telemetry(config.configure)
+
 
 
 def main():
 def main():
     try:
     try:
         cli()
         cli()
     except SystemExit:
     except SystemExit:
-        # Silently exit without printing the traceback
         pass
         pass
     except Exception as e:
     except Exception as e:
-        # Handle other exceptions if needed
-        raise e
+        console.print("[red]CLI error: An error occurred[/red]")
+        console.print_exception()
     finally:
     finally:
-        # Ensure all events are flushed before exiting
         if posthog:
         if posthog:
             posthog.flush()
             posthog.flush()
             posthog.shutdown()
             posthog.shutdown()
 
 
 
 
def _ensure_config_dir_exists() -> None:
    """Create the ~/.r2r/ config directory (and parents) if it is missing."""
    CONFIG_DIR.mkdir(parents=True, exist_ok=True)
+
+
def save_config(config_data: Dict[str, Any]) -> None:
    """Write *config_data* as pretty-printed JSON to ~/.r2r/config.json."""
    _ensure_config_dir_exists()
    serialized = json.dumps(config_data, indent=2)
    with open(CONFIG_FILE, "w", encoding="utf-8") as handle:
        handle.write(serialized)
+
+
@cli.command("set-api-key", short_help="Set your R2R API key")
@click.argument("api_key", required=True, type=str)
@click.pass_context
async def set_api_key(ctx, api_key: str):
    """
    Store your R2R API key locally so you don't have to pass it on every command.
    Example usage:
      r2r set-api-key sk-1234abcd
    """
    # Docstring example fixed: the command is registered as "set-api-key",
    # not "set-api".
    try:
        # 1) Load existing config
        config = load_config()

        # 2) Overwrite or add the API key
        config["api_key"] = api_key

        # 3) Save changes
        save_config(config)

        console.print("[green]API key set successfully![/green]")
    except Exception as e:
        console.print("[red]Failed to set API key:[/red]", str(e))
+
+
@cli.command("get-api", short_help="Get your stored R2R API key")
@click.pass_context
async def get_api(ctx):
    """
    Display your stored R2R API key.
    Example usage:
      r2r get-api
    """
    try:
        config = load_config()
        api_key = config.get("api_key")

        if api_key:
            console.print(f"API Key: {api_key}")
        else:
            # Hint fixed: the setter command is "set-api-key", not "set-api".
            console.print(
                "[yellow]No API key found. Set one using 'r2r set-api-key <key>'[/yellow]"
            )
    except Exception as e:
        console.print("[red]Failed to retrieve API key:[/red]", str(e))
+
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     main()
     main()

+ 0 - 6
cli/utils/docker_utils.py

@@ -116,12 +116,6 @@ async def run_local_serve(
 
 
     await r2r_instance.orchestration_provider.start_worker()
     await r2r_instance.orchestration_provider.start_worker()
 
 
-    # TODO: make this work with autoreload, currently due to hatchet, it causes a reload error
-    # import uvicorn
-    # uvicorn.run(
-    #     "core.main.app_entry:app", host=host, port=available_port, reload=False
-    # )
-
     await r2r_instance.serve(host, available_port)
     await r2r_instance.serve(host, available_port)
 
 
 
 

+ 2 - 1
cli/utils/timer.py

@@ -13,4 +13,5 @@ def timer():
     start = time.time()
     start = time.time()
     yield
     yield
     end = time.time()
     end = time.time()
-    click.echo(f"Time taken: {end - start:.2f} seconds")
+    duration = max(0, end - start)
+    click.echo(f"Time taken: {duration:.2f} seconds")

+ 124 - 0
compose.yaml

@@ -0,0 +1,124 @@
+networks:
+  r2r-network:
+    driver: bridge
+    attachable: true
+    labels:
+      - "com.docker.compose.recreate=always"
+
+volumes:
+  postgres_data:
+    name: ${VOLUME_POSTGRES_DATA:-postgres_data}
+
+services:
+  postgres:
+    image: pgvector/pgvector:pg16
+    profiles: [postgres]
+    environment:
+      - POSTGRES_USER=${R2R_POSTGRES_USER:-${POSTGRES_USER:-postgres}} # Eventually get rid of POSTGRES_USER, but for now keep it for backwards compatibility
+      - POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-${POSTGRES_PASSWORD:-postgres}} # Eventually get rid of POSTGRES_PASSWORD, but for now keep it for backwards compatibility
+      - POSTGRES_HOST=${R2R_POSTGRES_HOST:-${POSTGRES_HOST:-postgres}} # Eventually get rid of POSTGRES_HOST, but for now keep it for backwards compatibility
+      - POSTGRES_PORT=${R2R_POSTGRES_PORT:-${POSTGRES_PORT:-5432}} # Eventually get rid of POSTGRES_PORT, but for now keep it for backwards compatibility
+      - POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-${POSTGRES_MAX_CONNECTIONS:-1024}} # Eventually get rid of POSTGRES_MAX_CONNECTIONS, but for now keep it for backwards compatibility
+      - PGPORT=${R2R_POSTGRES_PORT:-5432}
+    volumes:
+      - postgres_data:/var/lib/postgresql/data
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_POSTGRES_PORT:-5432}:${R2R_POSTGRES_PORT:-5432}"
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -U ${R2R_POSTGRES_USER:-postgres}"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    command: >
+      postgres
+      -c max_connections=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+
+  r2r:
+    image: ${R2R_IMAGE:-ragtoriches/prod:latest}
+    build:
+      context: .
+      args:
+        PORT: ${R2R_PORT:-7272}
+        R2R_PORT: ${R2R_PORT:-7272}
+        HOST: ${R2R_HOST:-0.0.0.0}
+        R2R_HOST: ${R2R_HOST:-0.0.0.0}
+    ports:
+      - "${R2R_PORT:-7272}:${R2R_PORT:-7272}"
+    environment:
+      - PYTHONUNBUFFERED=1
+      - R2R_PORT=${R2R_PORT:-7272}
+      - R2R_HOST=${R2R_HOST:-0.0.0.0}
+
+      # R2R
+      - R2R_CONFIG_NAME=${R2R_CONFIG_NAME:-}
+      - R2R_CONFIG_PATH=${R2R_CONFIG_PATH:-}
+      - R2R_PROJECT_NAME=${R2R_PROJECT_NAME:-r2r_default}
+
+      # Postgres
+      - R2R_POSTGRES_USER=${R2R_POSTGRES_USER:-postgres}
+      - R2R_POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres}
+      - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres}
+      - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432}
+      - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-postgres}
+      - R2R_POSTGRES_PROJECT_NAME=${R2R_POSTGRES_PROJECT_NAME:-r2r_default}
+      - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024}
+      - R2R_POSTGRES_STATEMENT_CACHE_SIZE=${R2R_POSTGRES_STATEMENT_CACHE_SIZE:-100}
+
+      # OpenAI
+      - OPENAI_API_KEY=${OPENAI_API_KEY:-}
+      - OPENAI_API_BASE=${OPENAI_API_BASE:-}
+
+      # Anthropic
+      - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
+
+      # Azure
+      - AZURE_API_KEY=${AZURE_API_KEY:-}
+      - AZURE_API_BASE=${AZURE_API_BASE:-}
+      - AZURE_API_VERSION=${AZURE_API_VERSION:-}
+
+      # Google Vertex AI
+      - GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS:-}
+      - VERTEX_PROJECT=${VERTEX_PROJECT:-}
+      - VERTEX_LOCATION=${VERTEX_LOCATION:-}
+
+      # AWS Bedrock
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
+      - AWS_REGION_NAME=${AWS_REGION_NAME:-}
+
+      # Groq
+      - GROQ_API_KEY=${GROQ_API_KEY:-}
+
+      # Cohere
+      - COHERE_API_KEY=${COHERE_API_KEY:-}
+
+      # Anyscale
+      - ANYSCALE_API_KEY=${ANYSCALE_API_KEY:-}
+
+      # Ollama
+      - OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434}
+
+    networks:
+      - r2r-network
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:${R2R_PORT:-7272}/v3/health"]
+      interval: 6s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    volumes:
+      # NOTE(review): when R2R_CONFIG_PATH is unset, this default mounts the
+      # host root filesystem ("/") into the container — confirm this is intended.
+      - ${R2R_CONFIG_PATH:-/}:${R2R_CONFIG_PATH:-/app/config}
+    extra_hosts:
+      - host.docker.internal:host-gateway
+
+  r2r-dashboard:
+    image: emrgntcmplxty/r2r-dashboard:latest
+    environment:
+      - NEXT_PUBLIC_R2R_DEPLOYMENT_URL=${R2R_DEPLOYMENT_URL:-http://localhost:7272}
+    networks:
+      - r2r-network
+    ports:
+      - "${R2R_DASHBOARD_PORT:-7273}:3000"

+ 0 - 1
core/base/parsers/base_parser.py

@@ -7,7 +7,6 @@ T = TypeVar("T")
 
 
 
 
 class AsyncParser(ABC, Generic[T]):
 class AsyncParser(ABC, Generic[T]):
-
     @abstractmethod
     @abstractmethod
     async def ingest(self, data: T, **kwargs) -> AsyncGenerator[str, None]:
     async def ingest(self, data: T, **kwargs) -> AsyncGenerator[str, None]:
         pass
         pass

+ 0 - 1
core/base/providers/ingestion.py

@@ -179,7 +179,6 @@ class IngestionConfig(ProviderConfig):
 
 
 
 
 class IngestionProvider(Provider, ABC):
 class IngestionProvider(Provider, ABC):
-
     config: IngestionConfig
     config: IngestionConfig
     database_provider: "PostgresDatabaseProvider"
     database_provider: "PostgresDatabaseProvider"
     llm_provider: CompletionProvider
     llm_provider: CompletionProvider

+ 53 - 10
core/database/base.py

@@ -54,11 +54,18 @@ class QueryBuilder:
     def __init__(self, table_name: str):
     def __init__(self, table_name: str):
         self.table_name = table_name
         self.table_name = table_name
         self.conditions: list[str] = []
         self.conditions: list[str] = []
-        self.params: dict = {}
+        self.params: list = (
+            []
+        )  # Changed from dict to list for PostgreSQL $1, $2 style
         self.select_fields = "*"
         self.select_fields = "*"
         self.operation = "SELECT"
         self.operation = "SELECT"
         self.limit_value: Optional[int] = None
         self.limit_value: Optional[int] = None
+        self.offset_value: Optional[int] = None
+        self.order_by_fields: Optional[str] = None
+        self.returning_fields: Optional[list[str]] = None
         self.insert_data: Optional[dict] = None
         self.insert_data: Optional[dict] = None
+        self.update_data: Optional[dict] = None
+        self.param_counter = 1  # For generating $1, $2, etc.
 
 
     def select(self, fields: list[str]):
     def select(self, fields: list[str]):
         self.select_fields = ", ".join(fields)
         self.select_fields = ", ".join(fields)
@@ -69,45 +76,81 @@ class QueryBuilder:
         self.insert_data = data
         self.insert_data = data
         return self
         return self
 
 
+    def update(self, data: dict):
+        self.operation = "UPDATE"
+        self.update_data = data
+        return self
+
     def delete(self):
     def delete(self):
         self.operation = "DELETE"
         self.operation = "DELETE"
         return self
         return self
 
 
-    def where(self, condition: str, **kwargs):
+    def where(self, condition: str):
         self.conditions.append(condition)
         self.conditions.append(condition)
-        self.params.update(kwargs)
         return self
         return self
 
 
-    def limit(self, value: int):
+    def limit(self, value: Optional[str]):
         self.limit_value = value
         self.limit_value = value
         return self
         return self
 
 
+    def offset(self, value: str):
+        self.offset_value = value
+        return self
+
+    def order_by(self, fields: str):
+        self.order_by_fields = fields
+        return self
+
+    def returning(self, fields: list[str]):
+        self.returning_fields = fields
+        return self
+
     def build(self):
     def build(self):
         if self.operation == "SELECT":
         if self.operation == "SELECT":
             query = f"SELECT {self.select_fields} FROM {self.table_name}"
             query = f"SELECT {self.select_fields} FROM {self.table_name}"
+
         elif self.operation == "INSERT":
         elif self.operation == "INSERT":
             columns = ", ".join(self.insert_data.keys())
             columns = ", ".join(self.insert_data.keys())
-            values = ", ".join(f":{key}" for key in self.insert_data.keys())
-            query = (
-                f"INSERT INTO {self.table_name} ({columns}) VALUES ({values})"
+            placeholders = ", ".join(
+                f"${i}" for i in range(1, len(self.insert_data) + 1)
             )
             )
-            self.params.update(self.insert_data)
+            query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
+            self.params.extend(list(self.insert_data.values()))
+
+        elif self.operation == "UPDATE":
+            set_clauses = []
+            for i, (key, value) in enumerate(
+                self.update_data.items(), start=len(self.params) + 1
+            ):
+                set_clauses.append(f"{key} = ${i}")
+                self.params.append(value)
+            query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"
+
         elif self.operation == "DELETE":
         elif self.operation == "DELETE":
             query = f"DELETE FROM {self.table_name}"
             query = f"DELETE FROM {self.table_name}"
+
         else:
         else:
             raise ValueError(f"Unsupported operation: {self.operation}")
             raise ValueError(f"Unsupported operation: {self.operation}")
 
 
         if self.conditions:
         if self.conditions:
             query += " WHERE " + " AND ".join(self.conditions)
             query += " WHERE " + " AND ".join(self.conditions)
 
 
-        if self.limit_value is not None and self.operation == "SELECT":
+        if self.order_by_fields and self.operation == "SELECT":
+            query += f" ORDER BY {self.order_by_fields}"
+
+        if self.offset_value is not None:
+            query += f" OFFSET {self.offset_value}"
+
+        if self.limit_value is not None:
             query += f" LIMIT {self.limit_value}"
             query += f" LIMIT {self.limit_value}"
 
 
+        if self.returning_fields:
+            query += f" RETURNING {', '.join(self.returning_fields)}"
+
         return query, self.params
         return query, self.params
 
 
 
 
 class PostgresConnectionManager(DatabaseConnectionManager):
 class PostgresConnectionManager(DatabaseConnectionManager):
-
     def __init__(self):
     def __init__(self):
         self.pool: Optional[SemaphoreConnectionPool] = None
         self.pool: Optional[SemaphoreConnectionPool] = None
 
 

+ 0 - 2
core/database/chunks.py

@@ -423,7 +423,6 @@ class PostgresChunksHandler(Handler):
     async def full_text_search(
     async def full_text_search(
         self, query_text: str, search_settings: SearchSettings
         self, query_text: str, search_settings: SearchSettings
     ) -> list[ChunkSearchResult]:
     ) -> list[ChunkSearchResult]:
-
         conditions = []
         conditions = []
         params: list[str | int | bytes] = [query_text]
         params: list[str | int | bytes] = [query_text]
 
 
@@ -1049,7 +1048,6 @@ class PostgresChunksHandler(Handler):
         id: UUID,
         id: UUID,
         similarity_threshold: float = 0.5,
         similarity_threshold: float = 0.5,
     ) -> list[dict[str, Any]]:
     ) -> list[dict[str, Any]]:
-
         table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
         table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
         query = f"""
         query = f"""
         WITH target_vector AS (
         WITH target_vector AS (

+ 0 - 1
core/database/collections.py

@@ -73,7 +73,6 @@ class PostgresCollectionsHandler(Handler):
         description: str = "",
         description: str = "",
         collection_id: Optional[UUID] = None,
         collection_id: Optional[UUID] = None,
     ) -> CollectionResponse:
     ) -> CollectionResponse:
-
         if not name and not collection_id:
         if not name and not collection_id:
             name = self.config.default_collection_name
             name = self.config.default_collection_name
             collection_id = generate_default_user_collection_id(owner_id)
             collection_id = generate_default_user_collection_id(owner_id)

+ 0 - 1
core/database/documents.py

@@ -152,7 +152,6 @@ class PostgresDocumentsHandler(Handler):
                                     document.id,
                                     document.id,
                                 )
                                 )
                             else:
                             else:
-
                                 insert_query = f"""
                                 insert_query = f"""
                                 INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
                                 INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
                                 (id, collection_ids, owner_id, type, metadata, title, version,
                                 (id, collection_ids, owner_id, type, metadata, title, version,

+ 0 - 13
core/database/graphs.py

@@ -776,7 +776,6 @@ class PostgresRelationshipsHandler(Handler):
 
 
 
 
 class PostgresCommunitiesHandler(Handler):
 class PostgresCommunitiesHandler(Handler):
-
     def __init__(self, *args: Any, **kwargs: Any) -> None:
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.project_name: str = kwargs.get("project_name")  # type: ignore
         self.project_name: str = kwargs.get("project_name")  # type: ignore
         self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
         self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
@@ -784,7 +783,6 @@ class PostgresCommunitiesHandler(Handler):
         self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore
         self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore
 
 
     async def create_tables(self) -> None:
     async def create_tables(self) -> None:
-
         vector_column_str = _decorate_vector_type(
         vector_column_str = _decorate_vector_type(
             f"({self.dimension})", self.quantization_type
             f"({self.dimension})", self.quantization_type
         )
         )
@@ -1072,7 +1070,6 @@ class PostgresGraphsHandler(Handler):
         *args: Any,
         *args: Any,
         **kwargs: Any,
         **kwargs: Any,
     ) -> None:
     ) -> None:
-
         self.project_name: str = kwargs.get("project_name")  # type: ignore
         self.project_name: str = kwargs.get("project_name")  # type: ignore
         self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
         self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
         self.dimension: int = kwargs.get("dimension")  # type: ignore
         self.dimension: int = kwargs.get("dimension")  # type: ignore
@@ -1258,9 +1255,7 @@ class PostgresGraphsHandler(Handler):
     async def get(
     async def get(
         self, offset: int, limit: int, graph_id: Optional[UUID] = None
         self, offset: int, limit: int, graph_id: Optional[UUID] = None
     ):
     ):
-
         if graph_id is None:
         if graph_id is None:
-
             params = [offset, limit]
             params = [offset, limit]
 
 
             QUERY = f"""
             QUERY = f"""
@@ -1498,7 +1493,6 @@ class PostgresGraphsHandler(Handler):
     ):
     ):
         """Get the estimated cost and time for enriching a KG."""
         """Get the estimated cost and time for enriching a KG."""
         if collection_id is not None:
         if collection_id is not None:
-
             document_ids = [
             document_ids = [
                 doc.id
                 doc.id
                 for doc in (
                 for doc in (
@@ -1857,7 +1851,6 @@ class PostgresGraphsHandler(Handler):
         graph_id: UUID | None,
         graph_id: UUID | None,
         document_ids: Optional[list[UUID]] = None,
         document_ids: Optional[list[UUID]] = None,
     ) -> list[Relationship]:
     ) -> list[Relationship]:
-
         QUERY = f"""
         QUERY = f"""
             SELECT id, subject, predicate, weight, object, parent_id FROM {self._get_table_name("graphs_relationships")} WHERE parent_id = ANY($1)
             SELECT id, subject, predicate, weight, object, parent_id FROM {self._get_table_name("graphs_relationships")} WHERE parent_id = ANY($1)
         """
         """
@@ -1969,7 +1962,6 @@ class PostgresGraphsHandler(Handler):
         return communities, count
         return communities, count
 
 
     async def add_community(self, community: Community) -> None:
     async def add_community(self, community: Community) -> None:
-
         # TODO: Fix in the short term.
         # TODO: Fix in the short term.
         # we need to do this because postgres insert needs to be a string
         # we need to do this because postgres insert needs to be a string
         community.description_embedding = str(community.description_embedding)  # type: ignore[assignment]
         community.description_embedding = str(community.description_embedding)  # type: ignore[assignment]
@@ -1997,7 +1989,6 @@ class PostgresGraphsHandler(Handler):
 
 
     # async def delete(self, collection_id: UUID, cascade: bool = False) -> None:
     # async def delete(self, collection_id: UUID, cascade: bool = False) -> None:
     async def delete(self, collection_id: UUID) -> None:
     async def delete(self, collection_id: UUID) -> None:
-
         graphs = await self.get(graph_id=collection_id, offset=0, limit=-1)
         graphs = await self.get(graph_id=collection_id, offset=0, limit=-1)
 
 
         if len(graphs["results"]) == 0:
         if len(graphs["results"]) == 0:
@@ -2168,7 +2159,6 @@ class PostgresGraphsHandler(Handler):
         collection_id: Optional[UUID] = None,
         collection_id: Optional[UUID] = None,
         clustering_mode: str = "local",
         clustering_mode: str = "local",
     ) -> Tuple[int, Any]:
     ) -> Tuple[int, Any]:
-
         # clear if there is any old information
         # clear if there is any old information
         conditions = []
         conditions = []
         if collection_id is not None:
         if collection_id is not None:
@@ -2248,7 +2238,6 @@ class PostgresGraphsHandler(Handler):
     async def get_entity_map(
     async def get_entity_map(
         self, offset: int, limit: int, document_id: UUID
         self, offset: int, limit: int, document_id: UUID
     ) -> dict[str, dict[str, list[dict[str, Any]]]]:
     ) -> dict[str, dict[str, list[dict[str, Any]]]]:
-
         QUERY1 = f"""
         QUERY1 = f"""
             WITH entities_list AS (
             WITH entities_list AS (
                 SELECT DISTINCT name
                 SELECT DISTINCT name
@@ -2555,7 +2544,6 @@ class PostgresGraphsHandler(Handler):
         distinct: bool = False,
         distinct: bool = False,
         entity_table_name: str = "entity",
         entity_table_name: str = "entity",
     ) -> int:
     ) -> int:
-
         if collection_id is None and document_id is None:
         if collection_id is None and document_id is None:
             raise ValueError(
             raise ValueError(
                 "Either collection_id or document_id must be provided."
                 "Either collection_id or document_id must be provided."
@@ -2576,7 +2564,6 @@ class PostgresGraphsHandler(Handler):
         ]
         ]
 
 
     async def update_entity_descriptions(self, entities: list[Entity]):
     async def update_entity_descriptions(self, entities: list[Entity]):
-
         query = f"""
         query = f"""
             UPDATE {self._get_table_name("graphs_entities")}
             UPDATE {self._get_table_name("graphs_entities")}
             SET description = $3, description_embedding = $4
             SET description = $3, description_embedding = $4

+ 89 - 54
core/database/limits.py

@@ -4,6 +4,7 @@ from typing import Optional
 from uuid import UUID
 from uuid import UUID
 
 
 from core.base import Handler
 from core.base import Handler
+from shared.abstractions import User  # your domain user model
 
 
 from ..base.providers.database import DatabaseConfig, LimitSettings
 from ..base.providers.database import DatabaseConfig, LimitSettings
 from .base import PostgresConnectionManager
 from .base import PostgresConnectionManager
@@ -20,8 +21,12 @@ class PostgresLimitsHandler(Handler):
         connection_manager: PostgresConnectionManager,
         connection_manager: PostgresConnectionManager,
         config: DatabaseConfig,
         config: DatabaseConfig,
     ):
     ):
+        """
+        :param config: The global DatabaseConfig with default rate limits.
+        """
         super().__init__(project_name, connection_manager)
         super().__init__(project_name, connection_manager)
-        self.config = config
+        self.config = config  # Contains e.g. self.config.limits for fallback
+
         logger.debug(
         logger.debug(
             f"Initialized PostgresLimitsHandler with project: {project_name}"
             f"Initialized PostgresLimitsHandler with project: {project_name}"
         )
         )
@@ -38,8 +43,15 @@ class PostgresLimitsHandler(Handler):
         await self.connection_manager.execute_query(query)
         await self.connection_manager.execute_query(query)
 
 
     async def _count_requests(
     async def _count_requests(
-        self, user_id: UUID, route: Optional[str], since: datetime
+        self,
+        user_id: UUID,
+        route: Optional[str],
+        since: datetime,
     ) -> int:
     ) -> int:
+        """
+        Count how many requests a user (optionally for a specific route)
+        has made since the given datetime.
+        """
         if route:
         if route:
             query = f"""
             query = f"""
             SELECT COUNT(*)::int
             SELECT COUNT(*)::int
@@ -49,7 +61,9 @@ class PostgresLimitsHandler(Handler):
               AND time >= $3
               AND time >= $3
             """
             """
             params = [user_id, route, since]
             params = [user_id, route, since]
-            logger.debug(f"Counting requests for route {route}")
+            logger.debug(
+                f"Counting requests for user={user_id}, route={route}"
+            )
         else:
         else:
             query = f"""
             query = f"""
             SELECT COUNT(*)::int
             SELECT COUNT(*)::int
@@ -58,68 +72,86 @@ class PostgresLimitsHandler(Handler):
               AND time >= $2
               AND time >= $2
             """
             """
             params = [user_id, since]
             params = [user_id, since]
-            logger.debug("Counting all requests")
+            logger.debug(f"Counting all requests for user={user_id}")
 
 
         result = await self.connection_manager.fetchrow_query(query, params)
         result = await self.connection_manager.fetchrow_query(query, params)
-        count = result["count"] if result else 0
-
-        return count
+        return result["count"] if result else 0
 
 
     async def _count_monthly_requests(self, user_id: UUID) -> int:
     async def _count_monthly_requests(self, user_id: UUID) -> int:
+        """
+        Count the number of requests so far this month for a given user.
+        """
         now = datetime.now(timezone.utc)
         now = datetime.now(timezone.utc)
         start_of_month = now.replace(
         start_of_month = now.replace(
             day=1, hour=0, minute=0, second=0, microsecond=0
             day=1, hour=0, minute=0, second=0, microsecond=0
         )
         )
+        return await self._count_requests(user_id, None, start_of_month)
 
 
-        count = await self._count_requests(
-            user_id, route=None, since=start_of_month
-        )
-        return count
-
-    def _determine_limits_for(
-        self, user_id: UUID, route: str
-    ) -> LimitSettings:
-        # Start with base limits
-        limits = self.config.limits
-
-        # Route-specific limits - directly override if present
-        route_limits = self.config.route_limits.get(route)
-        if route_limits:
-            # Only override non-None values from route_limits
-            if route_limits.global_per_min is not None:
-                limits.global_per_min = route_limits.global_per_min
-            if route_limits.route_per_min is not None:
-                limits.route_per_min = route_limits.route_per_min
-            if route_limits.monthly_limit is not None:
-                limits.monthly_limit = route_limits.monthly_limit
-
-        # User-specific limits - directly override if present
-        user_limits = self.config.user_limits.get(user_id)
-        if user_limits:
-            # Only override non-None values from user_limits
-            if user_limits.global_per_min is not None:
-                limits.global_per_min = user_limits.global_per_min
-            if user_limits.route_per_min is not None:
-                limits.route_per_min = user_limits.route_per_min
-            if user_limits.monthly_limit is not None:
-                limits.monthly_limit = user_limits.monthly_limit
-
-        return limits
-
-    async def check_limits(self, user_id: UUID, route: str):
-        # Determine final applicable limits
-        limits = self._determine_limits_for(user_id, route)
-        if not limits:
-            limits = self.config.default_limits
-
-        global_per_min = limits.global_per_min
-        route_per_min = limits.route_per_min
-        monthly_limit = limits.monthly_limit
+    async def check_limits(self, user: User, route: str):
+        """
+        Perform rate limit checks for a user on a specific route.
 
 
+        :param user: The fully-fetched User object with .limits_overrides, etc.
+        :param route: The route/path being accessed.
+        :raises ValueError: if any limit is exceeded.
+        """
+        user_id = user.id
         now = datetime.now(timezone.utc)
         now = datetime.now(timezone.utc)
         one_min_ago = now - timedelta(minutes=1)
         one_min_ago = now - timedelta(minutes=1)
 
 
-        # Global per-minute check
+        # 1) First check route-specific configuration limits
+        route_config = self.config.route_limits.get(route)
+        if route_config:
+            # Check route-specific per-minute limit
+            if route_config.route_per_min is not None:
+                route_req_count = await self._count_requests(
+                    user_id, route, one_min_ago
+                )
+                if route_req_count > route_config.route_per_min:
+                    logger.warning(
+                        f"Per-route per-minute limit exceeded for user_id={user_id}, route={route}"
+                    )
+                    raise ValueError(
+                        "Per-route per-minute rate limit exceeded"
+                    )
+
+            # Check route-specific monthly limit
+            if route_config.monthly_limit is not None:
+                monthly_count = await self._count_monthly_requests(user_id)
+                if monthly_count > route_config.monthly_limit:
+                    logger.warning(
+                        f"Route monthly limit exceeded for user_id={user_id}, route={route}"
+                    )
+                    raise ValueError("Route monthly limit exceeded")
+
+        # 2) Get user overrides and base limits
+        user_overrides = user.limits_overrides or {}
+        base_limits = self.config.limits
+
+        # Extract user-level overrides
+        global_per_min = user_overrides.get(
+            "global_per_min", base_limits.global_per_min
+        )
+        monthly_limit = user_overrides.get(
+            "monthly_limit", base_limits.monthly_limit
+        )
+
+        # 3) Check route-specific overrides from user config
+        route_overrides = user_overrides.get("route_overrides", {})
+        specific_config = route_overrides.get(route, {})
+
+        # Apply route-specific overrides for per-minute limits
+        route_per_min = specific_config.get(
+            "route_per_min", base_limits.route_per_min
+        )
+
+        # If route specifically overrides global or monthly limits, apply them
+        if "global_per_min" in specific_config:
+            global_per_min = specific_config["global_per_min"]
+        if "monthly_limit" in specific_config:
+            monthly_limit = specific_config["monthly_limit"]
+
+        # 4) Check global per-minute limit
         if global_per_min is not None:
         if global_per_min is not None:
             user_req_count = await self._count_requests(
             user_req_count = await self._count_requests(
                 user_id, None, one_min_ago
                 user_id, None, one_min_ago
@@ -130,7 +162,7 @@ class PostgresLimitsHandler(Handler):
                 )
                 )
                 raise ValueError("Global per-minute rate limit exceeded")
                 raise ValueError("Global per-minute rate limit exceeded")
 
 
-        # Per-route per-minute check
+        # 5) Check user-specific route per-minute limit
         if route_per_min is not None:
         if route_per_min is not None:
             route_req_count = await self._count_requests(
             route_req_count = await self._count_requests(
                 user_id, route, one_min_ago
                 user_id, route, one_min_ago
@@ -141,7 +173,7 @@ class PostgresLimitsHandler(Handler):
                 )
                 )
                 raise ValueError("Per-route per-minute rate limit exceeded")
                 raise ValueError("Per-route per-minute rate limit exceeded")
 
 
-        # Monthly limit check
+        # 6) Check monthly limit
         if monthly_limit is not None:
         if monthly_limit is not None:
             monthly_count = await self._count_monthly_requests(user_id)
             monthly_count = await self._count_monthly_requests(user_id)
             if monthly_count > monthly_limit:
             if monthly_count > monthly_limit:
@@ -151,6 +183,9 @@ class PostgresLimitsHandler(Handler):
                 raise ValueError("Monthly rate limit exceeded")
                 raise ValueError("Monthly rate limit exceeded")
 
 
     async def log_request(self, user_id: UUID, route: str):
     async def log_request(self, user_id: UUID, route: str):
+        """
+        Log a successful request to the request_log table.
+        """
         query = f"""
         query = f"""
         INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
         INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
         VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
         VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)

+ 234 - 116
core/database/users.py

@@ -1,5 +1,6 @@
+import json
 from datetime import datetime
 from datetime import datetime
-from typing import Optional
+from typing import Any, Dict, List, Optional
 from uuid import UUID
 from uuid import UUID
 
 
 from fastapi import HTTPException
 from fastapi import HTTPException
@@ -43,10 +44,12 @@ class PostgresUserHandler(Handler):
             reset_token TEXT,
             reset_token TEXT,
             reset_token_expiry TIMESTAMPTZ,
             reset_token_expiry TIMESTAMPTZ,
             collection_ids UUID[] NULL,
             collection_ids UUID[] NULL,
+            limits_overrides JSONB,
             created_at TIMESTAMPTZ DEFAULT NOW(),
             created_at TIMESTAMPTZ DEFAULT NOW(),
             updated_at TIMESTAMPTZ DEFAULT NOW()
             updated_at TIMESTAMPTZ DEFAULT NOW()
         );
         );
         """
         """
+
         # API keys table with updated_at instead of last_used_at
         # API keys table with updated_at instead of last_used_at
         api_keys_table_query = f"""
         api_keys_table_query = f"""
         CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (
         CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (
@@ -86,6 +89,7 @@ class PostgresUserHandler(Handler):
                     "profile_picture",
                     "profile_picture",
                     "bio",
                     "bio",
                     "collection_ids",
                     "collection_ids",
+                    "limits_overrides",  # Fetch JSONB column
                 ]
                 ]
             )
             )
             .where("id = $1")
             .where("id = $1")
@@ -109,6 +113,8 @@ class PostgresUserHandler(Handler):
             profile_picture=result["profile_picture"],
             profile_picture=result["profile_picture"],
             bio=result["bio"],
             bio=result["bio"],
             collection_ids=result["collection_ids"],
             collection_ids=result["collection_ids"],
+            # Add the new field
+            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
         )
         )
 
 
     async def get_user_by_email(self, email: str) -> User:
     async def get_user_by_email(self, email: str) -> User:
@@ -128,6 +134,7 @@ class PostgresUserHandler(Handler):
                     "profile_picture",
                     "profile_picture",
                     "bio",
                     "bio",
                     "collection_ids",
                     "collection_ids",
+                    "limits_overrides",
                 ]
                 ]
             )
             )
             .where("email = $1")
             .where("email = $1")
@@ -150,13 +157,16 @@ class PostgresUserHandler(Handler):
             profile_picture=result["profile_picture"],
             profile_picture=result["profile_picture"],
             bio=result["bio"],
             bio=result["bio"],
             collection_ids=result["collection_ids"],
             collection_ids=result["collection_ids"],
+            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
         )
         )
 
 
     async def create_user(
     async def create_user(
         self, email: str, password: str, is_superuser: bool = False
         self, email: str, password: str, is_superuser: bool = False
     ) -> User:
     ) -> User:
+        """Create a new user."""
         try:
         try:
-            if await self.get_user_by_email(email):
+            existing = await self.get_user_by_email(email)
+            if existing:
                 raise R2RException(
                 raise R2RException(
                     status_code=400,
                     status_code=400,
                     message="User with this email already exists",
                     message="User with this email already exists",
@@ -166,27 +176,39 @@ class PostgresUserHandler(Handler):
                 raise e
                 raise e
 
 
         hashed_password = self.crypto_provider.get_password_hash(password)  # type: ignore
         hashed_password = self.crypto_provider.get_password_hash(password)  # type: ignore
-        query = f"""
-            INSERT INTO {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-            (email, id, is_superuser, hashed_password, collection_ids)
-            VALUES ($1, $2, $3, $4, $5)
-            RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
-        """
-        result = await self.connection_manager.fetchrow_query(
-            query,
-            [
-                email,
-                generate_user_id(email),
-                is_superuser,
-                hashed_password,
-                [],
-            ],
+        query, params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .insert(
+                {
+                    "email": email,
+                    "id": generate_user_id(email),
+                    "is_superuser": is_superuser,
+                    "hashed_password": hashed_password,
+                    "collection_ids": [],
+                    "limits_overrides": None,
+                }
+            )
+            .returning(
+                [
+                    "id",
+                    "email",
+                    "is_superuser",
+                    "is_active",
+                    "is_verified",
+                    "created_at",
+                    "updated_at",
+                    "collection_ids",
+                    "limits_overrides",
+                ]
+            )
+            .build()
         )
         )
 
 
+        result = await self.connection_manager.fetchrow_query(query, params)
         if not result:
         if not result:
-            raise HTTPException(
+            raise R2RException(
                 status_code=500,
                 status_code=500,
-                detail="Failed to create user",
+                message="Failed to create user",
             )
             )
 
 
         return User(
         return User(
@@ -197,17 +219,62 @@ class PostgresUserHandler(Handler):
             is_verified=result["is_verified"],
             is_verified=result["is_verified"],
             created_at=result["created_at"],
             created_at=result["created_at"],
             updated_at=result["updated_at"],
             updated_at=result["updated_at"],
-            collection_ids=result["collection_ids"],
+            collection_ids=result["collection_ids"] or [],
             hashed_password=hashed_password,
             hashed_password=hashed_password,
+            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
+            name=None,
+            bio=None,
+            profile_picture=None,
         )
         )
 
 
-    async def update_user(self, user: User) -> User:
+    async def update_user(
+        self, user: User, merge_limits: bool = False
+    ) -> User:
+        """
+        Update user information including limits_overrides.
+
+        Args:
+            user: User object containing updated information
+            merge_limits: If True, will merge existing limits_overrides with new ones.
+                        If False, will overwrite existing limits_overrides.
+
+        Returns:
+            Updated User object
+        """
+        # Get current user if we need to merge limits or get hashed password
+        current_user = None
+        try:
+            current_user = await self.get_user_by_id(user.id)
+        except R2RException:
+            raise R2RException(status_code=404, message="User not found")
+
+        # Merge or replace limits_overrides
+        final_limits = user.limits_overrides
+        if (
+            merge_limits
+            and current_user.limits_overrides
+            and user.limits_overrides
+        ):
+            final_limits = {
+                **current_user.limits_overrides,
+                **user.limits_overrides,
+            }
         query = f"""
         query = f"""
             UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
             UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-            SET email = $1, is_superuser = $2, is_active = $3, is_verified = $4, updated_at = NOW(),
-                name = $5, profile_picture = $6, bio = $7, collection_ids = $8
-            WHERE id = $9
-            RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, name, profile_picture, bio, collection_ids
+            SET email = $1,
+                is_superuser = $2,
+                is_active = $3,
+                is_verified = $4,
+                updated_at = NOW(),
+                name = $5,
+                profile_picture = $6,
+                bio = $7,
+                collection_ids = $8,
+                limits_overrides = $9::jsonb
+            WHERE id = $10
+            RETURNING id, email, is_superuser, is_active, is_verified,
+                    created_at, updated_at, name, profile_picture, bio,
+                    collection_ids, limits_overrides, hashed_password
         """
         """
         result = await self.connection_manager.fetchrow_query(
         result = await self.connection_manager.fetchrow_query(
             query,
             query,
@@ -219,7 +286,8 @@ class PostgresUserHandler(Handler):
                 user.name,
                 user.name,
                 user.profile_picture,
                 user.profile_picture,
                 user.bio,
                 user.bio,
-                user.collection_ids,
+                user.collection_ids or [],  # Ensure null becomes empty array
+                json.dumps(final_limits),  # Already handled null case
                 user.id,
                 user.id,
             ],
             ],
         )
         )
@@ -233,6 +301,9 @@ class PostgresUserHandler(Handler):
         return User(
         return User(
             id=result["id"],
             id=result["id"],
             email=result["email"],
             email=result["email"],
+            hashed_password=result[
+                "hashed_password"
+            ],  # Include hashed_password
             is_superuser=result["is_superuser"],
             is_superuser=result["is_superuser"],
             is_active=result["is_active"],
             is_active=result["is_active"],
             is_verified=result["is_verified"],
             is_verified=result["is_verified"],
@@ -241,15 +312,23 @@ class PostgresUserHandler(Handler):
             name=result["name"],
             name=result["name"],
             profile_picture=result["profile_picture"],
             profile_picture=result["profile_picture"],
             bio=result["bio"],
             bio=result["bio"],
-            collection_ids=result["collection_ids"],
+            collection_ids=result["collection_ids"]
+            or [],  # Ensure null becomes empty array
+            limits_overrides=json.loads(
+                result["limits_overrides"] or "{}"
+            ),  # Can be null
         )
         )
 
 
     async def delete_user_relational(self, id: UUID) -> None:
     async def delete_user_relational(self, id: UUID) -> None:
+        """Delete a user and update related records."""
         # Get the collections the user belongs to
         # Get the collections the user belongs to
-        collection_query = f"""
-            SELECT collection_ids FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-            WHERE id = $1
-        """
+        collection_query, params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .select(["collection_ids"])
+            .where("id = $1")
+            .build()
+        )
+
         collection_result = await self.connection_manager.fetchrow_query(
         collection_result = await self.connection_manager.fetchrow_query(
             collection_query, [id]
             collection_query, [id]
         )
         )
@@ -257,20 +336,25 @@ class PostgresUserHandler(Handler):
         if not collection_result:
         if not collection_result:
             raise R2RException(status_code=404, message="User not found")
             raise R2RException(status_code=404, message="User not found")
 
 
-        # Remove user from documents
-        doc_update_query = f"""
-            UPDATE {self._get_table_name('documents')}
-            SET id = NULL
-            WHERE id = $1
-        """
+        # Update documents query
+        doc_update_query, doc_params = (
+            QueryBuilder(self._get_table_name("documents"))
+            .update({"id": None})
+            .where("id = $1")
+            .build()
+        )
+
         await self.connection_manager.execute_query(doc_update_query, [id])
         await self.connection_manager.execute_query(doc_update_query, [id])
 
 
-        # Delete the user
-        delete_query = f"""
-            DELETE FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-            WHERE id = $1
-            RETURNING id
-        """
+        # Delete user query
+        delete_query, del_params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .delete()
+            .where("id = $1")
+            .returning(["id"])
+            .build()
+        )
+
         result = await self.connection_manager.fetchrow_query(
         result = await self.connection_manager.fetchrow_query(
             delete_query, [id]
             delete_query, [id]
         )
         )
@@ -288,24 +372,48 @@ class PostgresUserHandler(Handler):
             query, [new_hashed_password, id]
             query, [new_hashed_password, id]
         )
         )
 
 
-    async def get_all_users(self) -> list[User]:
-        query = f"""
-            SELECT id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
-            FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-        """
-        results = await self.connection_manager.fetch_query(query)
+    async def get_all_users(self) -> List[User]:
+        """Get all users with minimal information."""
+        query, params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .select(
+                [
+                    "id",
+                    "email",
+                    "is_superuser",
+                    "is_active",
+                    "is_verified",
+                    "created_at",
+                    "updated_at",
+                    "collection_ids",
+                    "hashed_password",
+                    "limits_overrides",
+                    "name",
+                    "bio",
+                    "profile_picture",
+                ]
+            )
+            .build()
+        )
 
 
+        results = await self.connection_manager.fetch_query(query, params)
         return [
         return [
             User(
             User(
                 id=result["id"],
                 id=result["id"],
                 email=result["email"],
                 email=result["email"],
-                hashed_password="null",
+                hashed_password=result["hashed_password"],
                 is_superuser=result["is_superuser"],
                 is_superuser=result["is_superuser"],
                 is_active=result["is_active"],
                 is_active=result["is_active"],
                 is_verified=result["is_verified"],
                 is_verified=result["is_verified"],
                 created_at=result["created_at"],
                 created_at=result["created_at"],
                 updated_at=result["updated_at"],
                 updated_at=result["updated_at"],
-                collection_ids=result["collection_ids"],
+                collection_ids=result["collection_ids"] or [],
+                limits_overrides=json.loads(
+                    result["limits_overrides"] or "{}"
+                ),
+                name=result["name"],
+                bio=result["bio"],
+                profile_picture=result["profile_picture"],
             )
             )
             for result in results
             for result in results
         ]
         ]
@@ -456,41 +564,44 @@ class PostgresUserHandler(Handler):
     async def get_users_in_collection(
     async def get_users_in_collection(
         self, collection_id: UUID, offset: int, limit: int
         self, collection_id: UUID, offset: int, limit: int
     ) -> dict[str, list[User] | int]:
     ) -> dict[str, list[User] | int]:
-        """
-        Get all users in a specific collection with pagination.
-
-        Args:
-            collection_id (UUID): The ID of the collection to get users from.
-            offset (int): The number of users to skip.
-            limit (int): The maximum number of users to return.
-
-        Returns:
-            List[User]: A list of User objects representing the users in the collection.
-
-        Raises:
-            R2RException: If the collection doesn't exist.
-        """
-        if not await self._collection_exists(collection_id):  # type: ignore
+        """Get all users in a specific collection with pagination."""
+        if not await self._collection_exists(collection_id):
             raise R2RException(status_code=404, message="Collection not found")
             raise R2RException(status_code=404, message="Collection not found")
 
 
-        query = f"""
-            SELECT u.id, u.email, u.is_active, u.is_superuser, u.created_at, u.updated_at,
-                u.is_verified, u.collection_ids, u.name, u.bio, u.profile_picture,
-                COUNT(*) OVER() AS total_entries
-            FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
-            WHERE $1 = ANY(u.collection_ids)
-            ORDER BY u.name
-            OFFSET $2
-        """
+        query, params = (
+            QueryBuilder(self._get_table_name(self.TABLE_NAME))
+            .select(
+                [
+                    "id",
+                    "email",
+                    "is_active",
+                    "is_superuser",
+                    "created_at",
+                    "updated_at",
+                    "is_verified",
+                    "collection_ids",
+                    "name",
+                    "bio",
+                    "profile_picture",
+                    "hashed_password",
+                    "limits_overrides",
+                    "COUNT(*) OVER() AS total_entries",
+                ]
+            )
+            .where("$1 = ANY(collection_ids)")
+            .order_by("name")
+            .offset("$2")
+            .limit("$3" if limit != -1 else None)
+            .build()
+        )
 
 
         conditions = [collection_id, offset]
         conditions = [collection_id, offset]
         if limit != -1:
         if limit != -1:
-            query += " LIMIT $3"
             conditions.append(limit)
             conditions.append(limit)
 
 
         results = await self.connection_manager.fetch_query(query, conditions)
         results = await self.connection_manager.fetch_query(query, conditions)
 
 
-        users = [
+        users_list = [
             User(
             User(
                 id=row["id"],
                 id=row["id"],
                 email=row["email"],
                 email=row["email"],
@@ -499,24 +610,24 @@ class PostgresUserHandler(Handler):
                 created_at=row["created_at"],
                 created_at=row["created_at"],
                 updated_at=row["updated_at"],
                 updated_at=row["updated_at"],
                 is_verified=row["is_verified"],
                 is_verified=row["is_verified"],
-                collection_ids=row["collection_ids"],
+                collection_ids=row["collection_ids"] or [],
                 name=row["name"],
                 name=row["name"],
                 bio=row["bio"],
                 bio=row["bio"],
                 profile_picture=row["profile_picture"],
                 profile_picture=row["profile_picture"],
-                hashed_password=None,
-                verification_code_expiry=None,
+                hashed_password=row["hashed_password"],
+                limits_overrides=json.loads(row["limits_overrides"] or "{}"),
             )
             )
             for row in results
             for row in results
         ]
         ]
 
 
         total_entries = results[0]["total_entries"] if results else 0
         total_entries = results[0]["total_entries"] if results else 0
-
-        return {"results": users, "total_entries": total_entries}
+        return {"results": users_list, "total_entries": total_entries}
 
 
     async def mark_user_as_superuser(self, id: UUID):
     async def mark_user_as_superuser(self, id: UUID):
         query = f"""
         query = f"""
             UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
             UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-            SET is_superuser = TRUE, is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
+            SET is_superuser = TRUE, is_verified = TRUE,
+                verification_code = NULL, verification_code_expiry = NULL
             WHERE id = $1
             WHERE id = $1
         """
         """
         await self.connection_manager.execute_query(query, [id])
         await self.connection_manager.execute_query(query, [id])
@@ -542,7 +653,9 @@ class PostgresUserHandler(Handler):
     async def mark_user_as_verified(self, id: UUID):
     async def mark_user_as_verified(self, id: UUID):
         query = f"""
         query = f"""
             UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
             UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
-            SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
+            SET is_verified = TRUE,
+                verification_code = NULL,
+                verification_code_expiry = NULL
             WHERE id = $1
             WHERE id = $1
         """
         """
         await self.connection_manager.execute_query(query, [id])
         await self.connection_manager.execute_query(query, [id])
@@ -553,7 +666,9 @@ class PostgresUserHandler(Handler):
         limit: int,
         limit: int,
         user_ids: Optional[list[UUID]] = None,
         user_ids: Optional[list[UUID]] = None,
     ) -> dict[str, list[User] | int]:
     ) -> dict[str, list[User] | int]:
-
+        """
+        Return users with document usage and total entries.
+        """
         query = f"""
         query = f"""
             WITH user_document_ids AS (
             WITH user_document_ids AS (
                 SELECT
                 SELECT
@@ -604,36 +719,36 @@ class PostgresUserHandler(Handler):
             params.append(user_ids)
             params.append(user_ids)
 
 
         results = await self.connection_manager.fetch_query(query, params)
         results = await self.connection_manager.fetch_query(query, params)
+        if not results:
+            raise R2RException(status_code=404, message="No users found")
 
 
-        users = [
-            User(
-                id=row["id"],
-                email=row["email"],
-                is_superuser=row["is_superuser"],
-                is_active=row["is_active"],
-                is_verified=row["is_verified"],
-                name=row["name"],
-                bio=row["bio"],
-                created_at=row["created_at"],
-                updated_at=row["updated_at"],
-                collection_ids=row["collection_ids"] or [],
-                num_files=row["num_files"],
-                total_size_in_bytes=row["total_size_in_bytes"],
-                document_ids=(
-                    []
-                    if row["document_ids"] is None
-                    else list(row["document_ids"])
-                ),
+        users_list = []
+        for row in results:
+            users_list.append(
+                User(
+                    id=row["id"],
+                    email=row["email"],
+                    is_superuser=row["is_superuser"],
+                    is_active=row["is_active"],
+                    is_verified=row["is_verified"],
+                    name=row["name"],
+                    bio=row["bio"],
+                    created_at=row["created_at"],
+                    updated_at=row["updated_at"],
+                    profile_picture=row["profile_picture"],
+                    collection_ids=row["collection_ids"] or [],
+                    num_files=row["num_files"],
+                    total_size_in_bytes=row["total_size_in_bytes"],
+                    document_ids=(
+                        list(row["document_ids"])
+                        if row["document_ids"]
+                        else []
+                    ),
+                )
             )
             )
-            for row in results
-        ]
-
-        if not users:
-            raise R2RException(status_code=404, message="No users found")
 
 
         total_entries = results[0]["total_entries"]
         total_entries = results[0]["total_entries"]
-
-        return {"results": users, "total_entries": total_entries}
+        return {"results": users_list, "total_entries": total_entries}
 
 
     async def _collection_exists(self, collection_id: UUID) -> bool:
     async def _collection_exists(self, collection_id: UUID) -> bool:
         """Check if a collection exists."""
         """Check if a collection exists."""
@@ -693,7 +808,7 @@ class PostgresUserHandler(Handler):
         hashed_key: str,
         hashed_key: str,
         name: Optional[str] = None,
         name: Optional[str] = None,
     ) -> UUID:
     ) -> UUID:
-        """Store a new API key for a user"""
+        """Store a new API key for a user."""
         query = f"""
         query = f"""
             INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             (user_id, public_key, hashed_key, name)
             (user_id, public_key, hashed_key, name)
@@ -710,7 +825,10 @@ class PostgresUserHandler(Handler):
         return result["id"]
         return result["id"]
 
 
     async def get_api_key_record(self, key_id: str) -> Optional[dict]:
     async def get_api_key_record(self, key_id: str) -> Optional[dict]:
-        """Get API key record and update updated_at"""
+        """
+        Get API key record by 'public_key' and update 'updated_at' to now.
+        Returns { "user_id", "hashed_key" } or None if not found.
+        """
         query = f"""
         query = f"""
             UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             SET updated_at = NOW()
             SET updated_at = NOW()
@@ -726,7 +844,7 @@ class PostgresUserHandler(Handler):
         }
         }
 
 
     async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
     async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
-        """Get all API keys for a user"""
+        """Get all API keys for a user."""
         query = f"""
         query = f"""
             SELECT id, public_key, name, created_at, updated_at
             SELECT id, public_key, name, created_at, updated_at
             FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
@@ -745,7 +863,7 @@ class PostgresUserHandler(Handler):
         ]
         ]
 
 
     async def delete_api_key(self, user_id: UUID, key_id: UUID) -> dict:
     async def delete_api_key(self, user_id: UUID, key_id: UUID) -> dict:
-        """Delete a specific API key"""
+        """Delete a specific API key."""
         query = f"""
         query = f"""
             DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             WHERE id = $1 AND user_id = $2
             WHERE id = $1 AND user_id = $2
@@ -766,7 +884,7 @@ class PostgresUserHandler(Handler):
     async def update_api_key_name(
     async def update_api_key_name(
         self, user_id: UUID, key_id: UUID, name: str
         self, user_id: UUID, key_id: UUID, name: str
     ) -> bool:
     ) -> bool:
-        """Update the name of an API key"""
+        """Update the name of an existing API key."""
         query = f"""
         query = f"""
             UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
             SET name = $1, updated_at = NOW()
             SET name = $1, updated_at = NOW()

+ 1 - 1
core/examples/hello_r2r.py

@@ -1,6 +1,6 @@
 from r2r import R2RClient
 from r2r import R2RClient
 
 
-client = R2RClient("http://localhost:7272")
+client = R2RClient()
 
 
 with open("test.txt", "w") as file:
 with open("test.txt", "w") as file:
     file.write("John is a person that works at Google.")
     file.write("John is a person that works at Google.")

+ 6 - 6
core/main/abstractions.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
 
 
@@ -107,8 +107,8 @@ class R2RAgents(BaseModel):
 
 
 @dataclass
 @dataclass
 class R2RServices:
 class R2RServices:
-    auth: Optional["AuthService"] = None
-    ingestion: Optional["IngestionService"] = None
-    management: Optional["ManagementService"] = None
-    retrieval: Optional["RetrievalService"] = None
-    graph: Optional["GraphService"] = None
+    auth: "AuthService"
+    ingestion: "IngestionService"
+    management: "ManagementService"
+    retrieval: "RetrievalService"
+    graph: "GraphService"

+ 57 - 104
core/main/api/v3/base_router.py

@@ -1,7 +1,7 @@
 import functools
 import functools
 import logging
 import logging
 from abc import abstractmethod
 from abc import abstractmethod
-from typing import Callable
+from typing import Callable, Optional
 
 
 from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
 from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
 from fastapi.responses import StreamingResponse
 from fastapi.responses import StreamingResponse
@@ -15,100 +15,19 @@ logger = logging.getLogger()
 
 
 class BaseRouterV3:
 class BaseRouterV3:
     def __init__(self, providers: R2RProviders, services: R2RServices):
     def __init__(self, providers: R2RProviders, services: R2RServices):
+        """
+        :param providers: Typically includes auth, database, etc.
+        :param services: Additional service references (ingestion, run_manager, etc).
+        """
         self.providers = providers
         self.providers = providers
         self.services = services
         self.services = services
         self.router = APIRouter()
         self.router = APIRouter()
         self.openapi_extras = self._load_openapi_extras()
         self.openapi_extras = self._load_openapi_extras()
-        self._setup_routes()
-        self._register_workflows()
-
-    def get_router(self):
-        return self.router
-
-    def base_endpoint(self, func: Callable):
-        @functools.wraps(func)
-        async def wrapper(*args, **kwargs):
-            async with manage_run(
-                self.services.ingestion.run_manager, func.__name__
-            ) as run_id:
-                auth_user = kwargs.get("auth_user")
-                if auth_user:
-                    await self.services.ingestion.run_manager.log_run_info(  # TODO - this is a bit of a hack
-                        user=auth_user,
-                    )
-
-                try:
-                    func_result = await func(*args, **kwargs)
-                    if (
-                        isinstance(func_result, tuple)
-                        and len(func_result) == 2
-                    ):
-                        results, outer_kwargs = func_result
-                    else:
-                        results, outer_kwargs = func_result, {}
-
-                    if isinstance(results, StreamingResponse):
-                        return results
-                    return {"results": results, **outer_kwargs}
-
-                except R2RException:
-                    raise
-
-                except Exception as e:
 
 
-                    logger.error(
-                        f"Error in base endpoint {func.__name__}() - \n\n{str(e)}",
-                        exc_info=True,
-                    )
-
-                    raise HTTPException(
-                        status_code=500,
-                        detail={
-                            "message": f"An error '{e}' occurred during {func.__name__}",
-                            "error": str(e),
-                            "error_type": type(e).__name__,
-                        },
-                    ) from e
-
-        return wrapper
-
-    @classmethod
-    def build_router(cls, engine):
-        return cls(engine).router
-
-    def _register_workflows(self):
-        pass
-
-    def _load_openapi_extras(self):
-        return {}
-
-    @abstractmethod
-    def _setup_routes(self):
-        pass
-
-
-import functools
-import logging
-from abc import abstractmethod
-from typing import Callable, Optional
-
-from fastapi import APIRouter, Depends, HTTPException, Request
-from fastapi.responses import StreamingResponse
-
-from core.base import R2RException, manage_run
-
-from ...abstractions import R2RProviders, R2RServices
-
-logger = logging.getLogger()
-
-
-class BaseRouterV3:
-    def __init__(self, providers: R2RProviders, services: R2RServices):
-        self.providers = providers
-        self.services = services
-        self.router = APIRouter()
-        self.openapi_extras = self._load_openapi_extras()
+        # Add the rate-limiting dependency
         self.set_rate_limiting()
         self.set_rate_limiting()
+
+        # Initialize any routes
         self._setup_routes()
         self._setup_routes()
         self._register_workflows()
         self._register_workflows()
 
 
@@ -116,6 +35,13 @@ class BaseRouterV3:
         return self.router
         return self.router
 
 
     def base_endpoint(self, func: Callable):
     def base_endpoint(self, func: Callable):
+        """
+        A decorator to wrap endpoints in a standard pattern:
+         - manage_run context
+         - error handling
+         - response shaping
+        """
+
         @functools.wraps(func)
         @functools.wraps(func)
         async def wrapper(*args, **kwargs):
         async def wrapper(*args, **kwargs):
             async with manage_run(
             async with manage_run(
@@ -123,6 +49,7 @@ class BaseRouterV3:
             ) as run_id:
             ) as run_id:
                 auth_user = kwargs.get("auth_user")
                 auth_user = kwargs.get("auth_user")
                 if auth_user:
                 if auth_user:
+                    # Optionally log run info with the user
                     await self.services.ingestion.run_manager.log_run_info(
                     await self.services.ingestion.run_manager.log_run_info(
                         user=auth_user,
                         user=auth_user,
                     )
                     )
@@ -143,13 +70,11 @@ class BaseRouterV3:
 
 
                 except R2RException:
                 except R2RException:
                     raise
                     raise
-
                 except Exception as e:
                 except Exception as e:
                     logger.error(
                     logger.error(
-                        f"Error in base endpoint {func.__name__}() - \n\n{str(e)}",
+                        f"Error in base endpoint {func.__name__}() - {str(e)}",
                         exc_info=True,
                         exc_info=True,
                     )
                     )
-
                     raise HTTPException(
                     raise HTTPException(
                         status_code=500,
                         status_code=500,
                         detail={
                         detail={
@@ -163,6 +88,9 @@ class BaseRouterV3:
 
 
     @classmethod
     @classmethod
     def build_router(cls, engine):
     def build_router(cls, engine):
+        """
+        Class method for building a router instance (if you have a standard pattern).
+        """
         return cls(engine).router
         return cls(engine).router
 
 
     def _register_workflows(self):
     def _register_workflows(self):
@@ -173,48 +101,73 @@ class BaseRouterV3:
 
 
     @abstractmethod
     @abstractmethod
     def _setup_routes(self):
     def _setup_routes(self):
+        """
+        Subclasses override this to define actual endpoints.
+        """
         pass
         pass
 
 
     def set_rate_limiting(self):
     def set_rate_limiting(self):
         """
         """
-        Set up a yield dependency for rate limiting and logging.
+        Adds a yield-based dependency for rate limiting each request.
+        Checks the limits, then logs the request if the check passes.
         """
         """
 
 
         async def rate_limit_dependency(
         async def rate_limit_dependency(
             request: Request,
             request: Request,
             auth_user=Depends(self.providers.auth.auth_wrapper()),
             auth_user=Depends(self.providers.auth.auth_wrapper()),
         ):
         ):
+            """
+            1) Fetch the user from the DB (including .limits_overrides).
+            2) Pass it to limits_handler.check_limits.
+            3) After the endpoint completes, call limits_handler.log_request.
+            """
+            # If the user is superuser, skip checks
+            if auth_user.is_superuser:
+                yield
+                return
+
             user_id = auth_user.id
             user_id = auth_user.id
             route = request.scope["path"]
             route = request.scope["path"]
-            # Check the limits before proceeding
+
+            # 1) Fetch the user from DB
+            user = await self.providers.database.users_handler.get_user_by_id(
+                user_id
+            )
+            if not user:
+                raise HTTPException(status_code=404, detail="User not found.")
+
+            # 2) Rate-limit check
             try:
             try:
-                if not auth_user.is_superuser:
-                    await self.providers.database.limits_handler.check_limits(
-                        user_id, route
-                    )
+                await self.providers.database.limits_handler.check_limits(
+                    user=user, route=route  # Pass the User object
+                )
             except ValueError as e:
             except ValueError as e:
+                # If check_limits raises ValueError -> 429 Too Many Requests
                 raise HTTPException(status_code=429, detail=str(e))
                 raise HTTPException(status_code=429, detail=str(e))
 
 
             request.state.user_id = user_id
             request.state.user_id = user_id
             request.state.route = route
             request.state.route = route
-            # Yield to run the route
+
+            # 3) Execute the route
             try:
             try:
                 yield
                 yield
             finally:
             finally:
-                # After the route completes successfully, log the request
+                # 4) Log the request afterwards
                 await self.providers.database.limits_handler.log_request(
                 await self.providers.database.limits_handler.log_request(
                     user_id, route
                     user_id, route
                 )
                 )
 
 
-        async def websocket_rate_limit_dependency(
-            websocket: WebSocket,
-        ):
+        async def websocket_rate_limit_dependency(websocket: WebSocket):
+            # Example: if you want to rate-limit websockets similarly
             route = websocket.scope["path"]
             route = websocket.scope["path"]
+            # If you had a user or token, you'd do the same check.
             try:
             try:
+                # e.g. check_limits(user_id, route)
                 return True
                 return True
-            except ValueError as e:
+            except ValueError:
                 await websocket.close(code=4429, reason="Rate limit exceeded")
                 await websocket.close(code=4429, reason="Rate limit exceeded")
                 return False
                 return False
 
 
+        # Attach the dependencies so you can use them in your endpoints
         self.rate_limit_dependency = rate_limit_dependency
         self.rate_limit_dependency = rate_limit_dependency
         self.websocket_rate_limit_dependency = websocket_rate_limit_dependency
         self.websocket_rate_limit_dependency = websocket_rate_limit_dependency

+ 9 - 9
core/main/api/v3/chunks_router.py

@@ -55,7 +55,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             response = client.chunks.search(
                             response = client.chunks.search(
                                 query="search query",
                                 query="search query",
                                 search_settings={
                                 search_settings={
@@ -110,7 +110,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             response = client.chunks.retrieve(
                             response = client.chunks.retrieve(
                                 id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                                 id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                             )
                             )
@@ -123,7 +123,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.chunks.retrieve({
                                 const response = await client.chunks.retrieve({
@@ -183,7 +183,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             response = client.chunks.update(
                             response = client.chunks.update(
                                 {
                                 {
                                     "id": "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
                                     "id": "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa",
@@ -200,7 +200,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.chunks.update({
                                 const response = await client.chunks.update({
@@ -276,7 +276,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             response = client.chunks.delete(
                             response = client.chunks.delete(
                                 id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                                 id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
                             )
                             )
@@ -289,7 +289,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.chunks.delete({
                                 const response = await client.chunks.delete({
@@ -347,7 +347,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             response = client.chunks.list(
                             response = client.chunks.list(
                                 metadata_filter={"key": "value"},
                                 metadata_filter={"key": "value"},
                                 include_vectors=False,
                                 include_vectors=False,
@@ -363,7 +363,7 @@ class ChunksRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.chunks.list({
                                 const response = await client.chunks.list({

+ 23 - 23
core/main/api/v3/collections_router.py

@@ -101,7 +101,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.collections.create(
                             result = client.collections.create(
@@ -117,7 +117,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.collections.create({
                                 const response = await client.collections.create({
@@ -189,7 +189,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.collections.list(
                             result = client.collections.list(
@@ -205,7 +205,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.collections.list();
                                 const response = await client.collections.list();
@@ -298,7 +298,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.collections.retrieve("123e4567-e89b-12d3-a456-426614174000")
                             result = client.collections.retrieve("123e4567-e89b-12d3-a456-426614174000")
@@ -311,7 +311,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.collections.retrieve({id: "123e4567-e89b-12d3-a456-426614174000"});
                                 const response = await client.collections.retrieve({id: "123e4567-e89b-12d3-a456-426614174000"});
@@ -387,7 +387,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.collections.update(
                             result = client.collections.update(
@@ -404,7 +404,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.collections.update({
                                 const response = await client.collections.update({
@@ -485,7 +485,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.collections.delete("123e4567-e89b-12d3-a456-426614174000")
                             result = client.collections.delete("123e4567-e89b-12d3-a456-426614174000")
@@ -498,7 +498,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.collections.delete({id: "123e4567-e89b-12d3-a456-426614174000"});
                                 const response = await client.collections.delete({id: "123e4567-e89b-12d3-a456-426614174000"});
@@ -562,7 +562,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.collections.add_document(
                             result = client.collections.add_document(
@@ -578,7 +578,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.collections.addDocument({
                                 const response = await client.collections.addDocument({
@@ -634,7 +634,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.collections.list_documents(
                             result = client.collections.list_documents(
@@ -651,7 +651,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.collections.listDocuments({id: "123e4567-e89b-12d3-a456-426614174000"});
                                 const response = await client.collections.listDocuments({id: "123e4567-e89b-12d3-a456-426614174000"});
@@ -733,7 +733,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.collections.remove_document(
                             result = client.collections.remove_document(
@@ -749,7 +749,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.collections.removeDocument({
                                 const response = await client.collections.removeDocument({
@@ -811,7 +811,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.collections.list_users(
                             result = client.collections.list_users(
@@ -828,7 +828,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.collections.listUsers({
                                 const response = await client.collections.listUsers({
@@ -912,7 +912,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.collections.add_user(
                             result = client.collections.add_user(
@@ -928,7 +928,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.collections.addUser({
                                 const response = await client.collections.addUser({
@@ -990,7 +990,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.collections.remove_user(
                             result = client.collections.remove_user(
@@ -1006,7 +1006,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.collections.removeUser({
                                 const response = await client.collections.removeUser({
@@ -1070,7 +1070,7 @@ class CollectionsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.documents.extract(
                             result = client.documents.extract(

+ 14 - 14
core/main/api/v3/conversations_router.py

@@ -42,7 +42,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.conversations.create()
                             result = client.conversations.create()
@@ -55,7 +55,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.conversations.create();
                                 const response = await client.conversations.create();
@@ -116,7 +116,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.conversations.list(
                             result = client.conversations.list(
@@ -132,7 +132,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.conversations.list();
                                 const response = await client.conversations.list();
@@ -218,7 +218,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.conversations.get(
                             result = client.conversations.get(
@@ -233,7 +233,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.conversations.retrieve({
                                 const response = await client.conversations.retrieve({
@@ -299,7 +299,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.conversations.update("123e4567-e89b-12d3-a456-426614174000", "new_name")
                             result = client.conversations.update("123e4567-e89b-12d3-a456-426614174000", "new_name")
@@ -312,7 +312,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.conversations.update({
                                 const response = await client.conversations.update({
@@ -382,7 +382,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.conversations.delete("123e4567-e89b-12d3-a456-426614174000")
                             result = client.conversations.delete("123e4567-e89b-12d3-a456-426614174000")
@@ -395,7 +395,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.conversations.delete({
                                 const response = await client.conversations.delete({
@@ -462,7 +462,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.conversations.add_message(
                             result = client.conversations.add_message(
@@ -481,7 +481,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.conversations.addMessage({
                                 const response = await client.conversations.addMessage({
@@ -558,7 +558,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.conversations.update_message(
                             result = client.conversations.update_message(
@@ -575,7 +575,7 @@ class ConversationsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.conversations.updateMessage({
                                 const response = await client.conversations.updateMessage({

+ 38 - 36
core/main/api/v3/documents_router.py

@@ -198,7 +198,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.documents.create(
                             response = client.documents.create(
@@ -215,7 +215,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.documents.create({
                                 const response = await client.documents.create({
@@ -558,7 +558,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.documents.list(
                             response = client.documents.list(
@@ -574,7 +574,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.documents.list({
                                 const response = await client.documents.list({
@@ -680,7 +680,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.documents.retrieve(
                             response = client.documents.retrieve(
@@ -695,7 +695,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.documents.retrieve({
                                 const response = await client.documents.retrieve({
@@ -776,7 +776,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.documents.list_chunks(
                             response = client.documents.list_chunks(
@@ -791,7 +791,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.documents.listChunks({
                                 const response = await client.documents.listChunks({
@@ -910,7 +910,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.documents.download(
                             response = client.documents.download(
@@ -925,7 +925,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.documents.download({
                                 const response = await client.documents.download({
@@ -1053,7 +1053,7 @@ class DocumentsRouter(BaseRouterV3):
                         "source": textwrap.dedent(
                         "source": textwrap.dedent(
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
                             response = client.documents.delete_by_filter(
                             response = client.documents.delete_by_filter(
                                 filters={"document_type": {"$eq": "txt"}}
                                 filters={"document_type": {"$eq": "txt"}}
@@ -1105,7 +1105,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.documents.delete(
                             response = client.documents.delete(
@@ -1120,7 +1120,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.documents.delete({
                                 const response = await client.documents.delete({
@@ -1186,7 +1186,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.documents.list_collections(
                             response = client.documents.list_collections(
@@ -1201,7 +1201,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.documents.listCollections({
                                 const response = await client.documents.listCollections({
@@ -1291,7 +1291,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.documents.extract(
                             response = client.documents.extract(
@@ -1403,7 +1403,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.documents.extract(
                             response = client.documents.extract(
@@ -1477,14 +1477,15 @@ class DocumentsRouter(BaseRouterV3):
                 raise R2RException("Document not found.", 404)
                 raise R2RException("Document not found.", 404)
 
 
             # Get all entities for this document from the document_entity table
             # Get all entities for this document from the document_entity table
-            entities, count = (
-                await self.providers.database.graphs_handler.entities.get(
-                    parent_id=id,
-                    store_type="documents",
-                    offset=offset,
-                    limit=limit,
-                    include_embeddings=include_embeddings,
-                )
+            (
+                entities,
+                count,
+            ) = await self.providers.database.graphs_handler.entities.get(
+                parent_id=id,
+                store_type="documents",
+                offset=offset,
+                limit=limit,
+                include_embeddings=include_embeddings,
             )
             )
 
 
             return entities, {"total_entries": count}  # type: ignore
             return entities, {"total_entries": count}  # type: ignore
@@ -1501,7 +1502,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.documents.list_relationships(
                             response = client.documents.list_relationships(
@@ -1518,7 +1519,7 @@ class DocumentsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.documents.listRelationships({
                                 const response = await client.documents.listRelationships({
@@ -1618,15 +1619,16 @@ class DocumentsRouter(BaseRouterV3):
                 raise R2RException("Document not found.", 404)
                 raise R2RException("Document not found.", 404)
 
 
             # Get relationships for this document
             # Get relationships for this document
-            relationships, count = (
-                await self.providers.database.graphs_handler.relationships.get(
-                    parent_id=id,
-                    store_type="documents",
-                    entity_names=entity_names,
-                    relationship_types=relationship_types,
-                    offset=offset,
-                    limit=limit,
-                )
+            (
+                relationships,
+                count,
+            ) = await self.providers.database.graphs_handler.relationships.get(
+                parent_id=id,
+                store_type="documents",
+                entity_names=entity_names,
+                relationship_types=relationship_types,
+                offset=offset,
+                limit=limit,
             )
             )
 
 
             return relationships, {"total_entries": count}  # type: ignore
             return relationships, {"total_entries": count}  # type: ignore

+ 34 - 36
core/main/api/v3/graph_router.py

@@ -40,7 +40,6 @@ class GraphRouter(BaseRouterV3):
         self._register_workflows()
         self._register_workflows()
 
 
     def _register_workflows(self):
     def _register_workflows(self):
-
         workflow_messages = {}
         workflow_messages = {}
         if self.providers.orchestration.config.provider == "hatchet":
         if self.providers.orchestration.config.provider == "hatchet":
             workflow_messages["extract-triples"] = (
             workflow_messages["extract-triples"] = (
@@ -164,7 +163,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.list()
                             response = client.graphs.list()
@@ -177,7 +176,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.list({});
                                 const response = await client.graphs.list({});
@@ -247,7 +246,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.get(
                             response = client.graphs.get(
@@ -261,7 +260,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.retrieve({
                                 const response = await client.graphs.retrieve({
@@ -386,7 +385,6 @@ class GraphRouter(BaseRouterV3):
             }
             }
 
 
             if run_with_orchestration:
             if run_with_orchestration:
-
                 return await self.providers.orchestration.run_workflow(  # type: ignore
                 return await self.providers.orchestration.run_workflow(  # type: ignore
                     "build-communities", {"request": workflow_input}, {}
                     "build-communities", {"request": workflow_input}, {}
                 )
                 )
@@ -413,7 +411,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.reset(
                             response = client.graphs.reset(
@@ -427,7 +425,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.reset({
                                 const response = await client.graphs.reset({
@@ -493,7 +491,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.update(
                             response = client.graphs.update(
@@ -511,7 +509,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.update({
                                 const response = await client.graphs.update({
@@ -579,10 +577,10 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
-                            response = client.graphs.get_entities(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7")
+                            response = client.graphs.list_entities(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7")
                             """
                             """
                         ),
                         ),
                     },
                     },
@@ -592,10 +590,10 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
-                                const response = await client.graphs.get_entities({
+                                const response = await client.graphs.listEntities({
                                     collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                                     collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                                 });
                                 });
                             }
                             }
@@ -767,7 +765,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.get_entity(
                             response = client.graphs.get_entity(
@@ -783,7 +781,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.get_entity({
                                 const response = await client.graphs.get_entity({
@@ -894,7 +892,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.remove_entity(
                             response = client.graphs.remove_entity(
@@ -910,7 +908,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.removeEntity({
                                 const response = await client.graphs.removeEntity({
@@ -973,7 +971,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.list_relationships(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7")
                             response = client.graphs.list_relationships(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7")
@@ -986,7 +984,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.listRelationships({
                                 const response = await client.graphs.listRelationships({
@@ -1055,7 +1053,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.get_relationship(
                             response = client.graphs.get_relationship(
@@ -1071,7 +1069,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.getRelationship({
                                 const response = await client.graphs.getRelationship({
@@ -1202,7 +1200,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.delete_relationship(
                             response = client.graphs.delete_relationship(
@@ -1218,7 +1216,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.deleteRelationship({
                                 const response = await client.graphs.deleteRelationship({
@@ -1280,7 +1278,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.create_community(
                             response = client.graphs.create_community(
@@ -1300,7 +1298,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.createCommunity({
                                 const response = await client.graphs.createCommunity({
@@ -1389,7 +1387,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.list_communities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1")
                             response = client.graphs.list_communities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1")
@@ -1402,7 +1400,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.listCommunities({
                                 const response = await client.graphs.listCommunities({
@@ -1471,7 +1469,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.get_community(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1")
                             response = client.graphs.get_community(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1")
@@ -1484,7 +1482,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.getCommunity({
                                 const response = await client.graphs.getCommunity({
@@ -1549,7 +1547,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.delete_community(
                             response = client.graphs.delete_community(
@@ -1565,7 +1563,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.graphs.deleteCommunity({
                                 const response = await client.graphs.deleteCommunity({
@@ -1629,7 +1627,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.update_community(
                             response = client.graphs.update_community(
@@ -1649,7 +1647,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             async function main() {
                             async function main() {
                                 const response = await client.graphs.updateCommunity({
                                 const response = await client.graphs.updateCommunity({
@@ -1724,7 +1722,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response = client.graphs.pull(
                             response = client.graphs.pull(
@@ -1738,7 +1736,7 @@ class GraphRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             async function main() {
                             async function main() {
                                 const response = await client.graphs.pull({
                                 const response = await client.graphs.pull({

+ 9 - 11
core/main/api/v3/indices_router.py

@@ -23,7 +23,6 @@ logger = logging.getLogger()
 
 
 
 
 class IndicesRouter(BaseRouterV3):
 class IndicesRouter(BaseRouterV3):
-
     def __init__(
     def __init__(
         self,
         self,
         providers: R2RProviders,
         providers: R2RProviders,
@@ -32,7 +31,6 @@ class IndicesRouter(BaseRouterV3):
         super().__init__(providers, services)
         super().__init__(providers, services)
 
 
     def _setup_routes(self):
     def _setup_routes(self):
-
         ## TODO - Allow developer to pass the index id with the request
         ## TODO - Allow developer to pass the index id with the request
         @self.router.post(
         @self.router.post(
             "/indices",
             "/indices",
@@ -46,7 +44,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             # Create an HNSW index for efficient similarity search
                             # Create an HNSW index for efficient similarity search
@@ -91,7 +89,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.indicies.create({
                                 const response = await client.indicies.create({
@@ -246,7 +244,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
 
 
                             # List all indices
                             # List all indices
                             indices = client.indices.list(
                             indices = client.indices.list(
@@ -262,7 +260,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.indicies.list({
                                 const response = await client.indicies.list({
@@ -350,7 +348,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
 
 
                             # Get detailed information about a specific index
                             # Get detailed information about a specific index
                             index = client.indices.retrieve("index_1")
                             index = client.indices.retrieve("index_1")
@@ -363,7 +361,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.indicies.retrieve({
                                 const response = await client.indicies.retrieve({
@@ -454,7 +452,7 @@ class IndicesRouter(BaseRouterV3):
         #                         "source": """
         #                         "source": """
         # from r2r import R2RClient
         # from r2r import R2RClient
 
 
-        # client = R2RClient("http://localhost:7272")
+        # client = R2RClient()
 
 
         # # Update HNSW index parameters
         # # Update HNSW index parameters
         # result = client.indices.update(
         # result = client.indices.update(
@@ -514,7 +512,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
 
 
                             # Delete an index with orchestration for cleanup
                             # Delete an index with orchestration for cleanup
                             result = client.indices.delete(
                             result = client.indices.delete(
@@ -531,7 +529,7 @@ class IndicesRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.indicies.delete({
                                 const response = await client.indicies.delete({

+ 10 - 10
core/main/api/v3/prompts_router.py

@@ -38,7 +38,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.prompts.create(
                             result = client.prompts.create(
@@ -55,7 +55,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.prompts.create({
                                 const response = await client.prompts.create({
@@ -122,7 +122,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.prompts.list()
                             result = client.prompts.list()
@@ -135,7 +135,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.prompts.list();
                                 const response = await client.prompts.list();
@@ -202,7 +202,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.prompts.get(
                             result = client.prompts.get(
@@ -219,7 +219,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.prompts.retrieve({
                                 const response = await client.prompts.retrieve({
@@ -292,7 +292,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.prompts.update(
                             result = client.prompts.update(
@@ -309,7 +309,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.prompts.update({
                                 const response = await client.prompts.update({
@@ -376,7 +376,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.prompts.delete("greeting_prompt")
                             result = client.prompts.delete("greeting_prompt")
@@ -389,7 +389,7 @@ class PromptsRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.prompts.delete({
                                 const response = await client.prompts.delete({

+ 10 - 11
core/main/api/v3/retrieval_router.py

@@ -82,7 +82,6 @@ class RetrievalRouterV3(BaseRouterV3):
         return effective_settings
         return effective_settings
 
 
     def _setup_routes(self):
     def _setup_routes(self):
-
         @self.router.post(
         @self.router.post(
             "/retrieval/search",
             "/retrieval/search",
             dependencies=[Depends(self.rate_limit_dependency)],
             dependencies=[Depends(self.rate_limit_dependency)],
@@ -95,7 +94,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # if using auth, do client.login(...)
                             # if using auth, do client.login(...)
 
 
                             # Basic mode, no overrides
                             # Basic mode, no overrides
@@ -135,7 +134,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.search({
                                 const response = await client.search({
@@ -278,7 +277,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response =client.retrieval.rag(
                             response =client.retrieval.rag(
@@ -309,7 +308,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.retrieval.rag({
                                 const response = await client.retrieval.rag({
@@ -464,7 +463,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             """
                         from r2r import R2RClient
                         from r2r import R2RClient
 
 
-                        client = R2RClient("http://localhost:7272")
+                        client = R2RClient()
                         # when using auth, do client.login(...)
                         # when using auth, do client.login(...)
 
 
                         response =client.retrieval.agent(
                         response =client.retrieval.agent(
@@ -500,7 +499,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.retrieval.agent({
                                 const response = await client.retrieval.agent({
@@ -693,7 +692,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             response =client.completion(
                             response =client.completion(
@@ -719,7 +718,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.completion({
                                 const response = await client.completion({
@@ -830,7 +829,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.retrieval.embedding(
                             result = client.retrieval.embedding(
@@ -845,7 +844,7 @@ class RetrievalRouterV3(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.retrieval.embedding({
                                 const response = await client.retrieval.embedding({

+ 6 - 6
core/main/api/v3/system_router.py

@@ -39,7 +39,7 @@ class SystemRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.system.health()
                             result = client.system.health()
@@ -52,7 +52,7 @@ class SystemRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.system.health();
                                 const response = await client.system.health();
@@ -98,7 +98,7 @@ class SystemRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.system.settings()
                             result = client.system.settings()
@@ -111,7 +111,7 @@ class SystemRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.system.settings();
                                 const response = await client.system.settings();
@@ -164,7 +164,7 @@ class SystemRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # when using auth, do client.login(...)
                             # when using auth, do client.login(...)
 
 
                             result = client.system.status()
                             result = client.system.status()
@@ -177,7 +177,7 @@ class SystemRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.system.status();
                                 const response = await client.system.status();

+ 49 - 40
core/main/api/v3/users_router.py

@@ -1,5 +1,5 @@
 import textwrap
 import textwrap
-from typing import Optional
+from typing import Optional, Union
 from uuid import UUID
 from uuid import UUID
 
 
 from fastapi import Body, Depends, Path, Query
 from fastapi import Body, Depends, Path, Query
@@ -19,6 +19,7 @@ from core.base.api.models import (
     WrappedUserResponse,
     WrappedUserResponse,
     WrappedUsersResponse,
     WrappedUsersResponse,
 )
 )
+from core.base.providers.database import LimitSettings
 
 
 from ...abstractions import R2RProviders, R2RServices
 from ...abstractions import R2RProviders, R2RServices
 from .base_router import BaseRouterV3
 from .base_router import BaseRouterV3
@@ -31,7 +32,6 @@ class UsersRouter(BaseRouterV3):
         super().__init__(providers, services)
         super().__init__(providers, services)
 
 
     def _setup_routes(self):
     def _setup_routes(self):
-
         @self.router.post(
         @self.router.post(
             "/users",
             "/users",
             # dependencies=[Depends(self.rate_limit_dependency)],
             # dependencies=[Depends(self.rate_limit_dependency)],
@@ -44,7 +44,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             new_user = client.users.create(
                             new_user = client.users.create(
                                 email="jane.doe@example.com",
                                 email="jane.doe@example.com",
                                 password="secure_password123"
                                 password="secure_password123"
@@ -57,7 +57,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.create({
                                 const response = await client.users.create({
@@ -152,7 +152,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             new_user = client.users.register(
                             new_user = client.users.register(
                                 email="jane.doe@example.com",
                                 email="jane.doe@example.com",
                                 password="secure_password123"
                                 password="secure_password123"
@@ -165,7 +165,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.register({
                                 const response = await client.users.register({
@@ -221,7 +221,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             tokens = client.users.verify_email(
                             tokens = client.users.verify_email(
                                 email="jane.doe@example.com",
                                 email="jane.doe@example.com",
                                 verification_code="1lklwal!awdclm"
                                 verification_code="1lklwal!awdclm"
@@ -234,7 +234,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.verifyEmail({
                                 const response = await client.users.verifyEmail({
@@ -296,7 +296,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             tokens = client.users.login(
                             tokens = client.users.login(
                                 email="jane.doe@example.com",
                                 email="jane.doe@example.com",
                                 password="secure_password123"
                                 password="secure_password123"
@@ -310,7 +310,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.login({
                                 const response = await client.users.login({
@@ -354,7 +354,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
                             result = client.users.logout()
                             result = client.users.logout()
                             """
                             """
@@ -366,7 +366,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.logout();
                                 const response = await client.users.logout();
@@ -408,7 +408,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
 
 
                             new_tokens = client.users.refresh_token()
                             new_tokens = client.users.refresh_token()
@@ -421,7 +421,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.refreshAccessToken();
                                 const response = await client.users.refreshAccessToken();
@@ -467,7 +467,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
 
 
                             result = client.users.change_password(
                             result = client.users.change_password(
@@ -482,7 +482,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.changePassword({
                                 const response = await client.users.changePassword({
@@ -537,7 +537,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             result = client.users.request_password_reset(
                             result = client.users.request_password_reset(
                                 email="jane.doe@example.com"
                                 email="jane.doe@example.com"
                             )"""
                             )"""
@@ -549,7 +549,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.requestPasswordReset({
                                 const response = await client.users.requestPasswordReset({
@@ -595,7 +595,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             result = client.users.reset_password(
                             result = client.users.reset_password(
                                 reset_token="reset_token_received_via_email",
                                 reset_token="reset_token_received_via_email",
                                 new_password="new_secure_password789"
                                 new_password="new_secure_password789"
@@ -608,7 +608,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.resetPassword({
                                 const response = await client.users.resetPassword({
@@ -659,7 +659,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
 
 
                             # List users with filters
                             # List users with filters
@@ -676,7 +676,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.list();
                                 const response = await client.users.list();
@@ -766,7 +766,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
 
 
                             # Get user details
                             # Get user details
@@ -780,7 +780,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.retrieve();
                                 const response = await client.users.retrieve();
@@ -831,7 +831,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
 
 
                             # Get user details
                             # Get user details
@@ -847,7 +847,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.retrieve({
                                 const response = await client.users.retrieve({
@@ -918,7 +918,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                         from r2r import R2RClient
                         from r2r import R2RClient
 
 
-                        client = R2RClient("http://localhost:7272")
+                        client = R2RClient()
                         # client.login(...)
                         # client.login(...)
 
 
                         # Delete user
                         # Delete user
@@ -932,7 +932,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                         const { r2rClient } = require("r2r-js");
                         const { r2rClient } = require("r2r-js");
 
 
-                        const client = new r2rClient("http://localhost:7272");
+                        const client = new r2rClient();
 
 
                         function main() {
                         function main() {
                             const response = await client.users.delete({
                             const response = await client.users.delete({
@@ -992,7 +992,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
 
 
                             # Get user collections
                             # Get user collections
@@ -1010,7 +1010,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.listCollections({
                                 const response = await client.users.listCollections({
@@ -1095,7 +1095,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
 
 
                             # Add user to collection
                             # Add user to collection
@@ -1112,7 +1112,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.addToCollection({
                                 const response = await client.users.addToCollection({
@@ -1179,7 +1179,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
 
 
                             # Remove user from collection
                             # Remove user from collection
@@ -1196,7 +1196,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.removeFromCollection({
                                 const response = await client.users.removeFromCollection({
@@ -1267,7 +1267,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
 
 
                             # Update user
                             # Update user
@@ -1284,7 +1284,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             const { r2rClient } = require("r2r-js");
                             const { r2rClient } = require("r2r-js");
 
 
-                            const client = new r2rClient("http://localhost:7272");
+                            const client = new r2rClient();
 
 
                             function main() {
                             function main() {
                                 const response = await client.users.update({
                                 const response = await client.users.update({
@@ -1329,6 +1329,10 @@ class UsersRouter(BaseRouterV3):
             profile_picture: str | None = Body(
             profile_picture: str | None = Body(
                 None, description="Updated profile picture URL"
                 None, description="Updated profile picture URL"
             ),
             ),
+            limits_overrides: dict = Body(
+                None,
+                description="Updated limits overrides",
+            ),
             auth_user=Depends(self.providers.auth.auth_wrapper()),
             auth_user=Depends(self.providers.auth.auth_wrapper()),
         ) -> WrappedUserResponse:
         ) -> WrappedUserResponse:
             """
             """
@@ -1347,7 +1351,11 @@ class UsersRouter(BaseRouterV3):
                     "Only superusers can update other users' information",
                     "Only superusers can update other users' information",
                     403,
                     403,
                 )
                 )
-
+            if not auth_user.is_superuser and limits_overrides is not None:
+                raise R2RException(
+                    "Only superusers can update other users' limits overrides",
+                    403,
+                )
             return await self.services.auth.update_user(
             return await self.services.auth.update_user(
                 user_id=id,
                 user_id=id,
                 email=email,
                 email=email,
@@ -1355,6 +1363,7 @@ class UsersRouter(BaseRouterV3):
                 name=name,
                 name=name,
                 bio=bio,
                 bio=bio,
                 profile_picture=profile_picture,
                 profile_picture=profile_picture,
+                limits_overrides=limits_overrides,
             )
             )
 
 
         @self.router.post(
         @self.router.post(
@@ -1370,7 +1379,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
 
 
                             result = client.users.create_api_key(
                             result = client.users.create_api_key(
@@ -1424,7 +1433,7 @@ class UsersRouter(BaseRouterV3):
                             """
                             """
                             from r2r import R2RClient
                             from r2r import R2RClient
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
 
 
                             keys = client.users.list_api_keys(
                             keys = client.users.list_api_keys(
@@ -1482,7 +1491,7 @@ class UsersRouter(BaseRouterV3):
                             from r2r import R2RClient
                             from r2r import R2RClient
                             from uuid import UUID
                             from uuid import UUID
 
 
-                            client = R2RClient("http://localhost:7272")
+                            client = R2RClient()
                             # client.login(...)
                             # client.login(...)
 
 
                             response = client.users.delete_api_key(
                             response = client.users.delete_api_key(

+ 0 - 1
core/main/app.py

@@ -72,7 +72,6 @@ class R2RApp:
         self._apply_cors()
         self._apply_cors()
 
 
     def _setup_routes(self):
     def _setup_routes(self):
-
         self.app.include_router(self.chunks_router, prefix="/v3")
         self.app.include_router(self.chunks_router, prefix="/v3")
         self.app.include_router(self.collections_router, prefix="/v3")
         self.app.include_router(self.collections_router, prefix="/v3")
         self.app.include_router(self.conversations_router, prefix="/v3")
         self.app.include_router(self.conversations_router, prefix="/v3")

+ 5 - 12
core/main/assembly/builder.py

@@ -72,21 +72,14 @@ class R2RBuilder:
         ).create_pipelines(*args, **kwargs)
         ).create_pipelines(*args, **kwargs)
 
 
     def _create_services(self, service_params: dict[str, Any]) -> R2RServices:
     def _create_services(self, service_params: dict[str, Any]) -> R2RServices:
+        services = ["auth", "ingestion", "management", "retrieval", "graph"]
         service_instances = {}
         service_instances = {}
-        for service_type, override in vars(R2RServices()).items():
-            logger.info(f"Creating {service_type} service")
+
+        for service_type in services:
             service_class = globals()[f"{service_type.capitalize()}Service"]
             service_class = globals()[f"{service_type.capitalize()}Service"]
-            service_instances[service_type] = override or service_class(
-                **service_params
-            )
+            service_instances[service_type] = service_class(**service_params)
 
 
-        return R2RServices(
-            auth=service_instances["auth"],
-            ingestion=service_instances["ingestion"],
-            management=service_instances["management"],
-            retrieval=service_instances["retrieval"],
-            graph=service_instances["graph"],
-        )
+        return R2RServices(**service_instances)
 
 
     async def _create_providers(
     async def _create_providers(
         self, provider_factory: Type[R2RProviderFactory], *args, **kwargs
         self, provider_factory: Type[R2RProviderFactory], *args, **kwargs

+ 0 - 2
core/main/assembly/factory.py

@@ -66,7 +66,6 @@ class R2RProviderFactory:
         **kwargs,
         **kwargs,
     ) -> R2RAuthProvider | SupabaseAuthProvider:
     ) -> R2RAuthProvider | SupabaseAuthProvider:
         if auth_config.provider == "r2r":
         if auth_config.provider == "r2r":
-
             r2r_auth = R2RAuthProvider(
             r2r_auth = R2RAuthProvider(
                 auth_config, crypto_provider, database_provider, email_provider
                 auth_config, crypto_provider, database_provider, email_provider
             )
             )
@@ -106,7 +105,6 @@ class R2RProviderFactory:
         *args,
         *args,
         **kwargs,
         **kwargs,
     ) -> R2RIngestionProvider | UnstructuredIngestionProvider:
     ) -> R2RIngestionProvider | UnstructuredIngestionProvider:
-
         config_dict = (
         config_dict = (
             ingestion_config.model_dump()
             ingestion_config.model_dump()
             if isinstance(ingestion_config, IngestionConfig)
             if isinstance(ingestion_config, IngestionConfig)

+ 0 - 1
core/main/orchestration/hatchet/ingestion_workflow.py

@@ -216,7 +216,6 @@ def hatchet_ingestion_factory(
                     )
                     )
 
 
                 if chunk_enrichment_settings.enable_chunk_enrichment:
                 if chunk_enrichment_settings.enable_chunk_enrichment:
-
                     logger.info("Enriching document with contextual chunks")
                     logger.info("Enriching document with contextual chunks")
 
 
                     # TODO: the status updating doesn't work because document_info doesn't contain information about collection IDs
                     # TODO: the status updating doesn't work because document_info doesn't contain information about collection IDs

+ 0 - 8
core/main/orchestration/hatchet/kg_workflow.py

@@ -23,7 +23,6 @@ if TYPE_CHECKING:
 def hatchet_kg_factory(
 def hatchet_kg_factory(
     orchestration_provider: OrchestrationProvider, service: GraphService
     orchestration_provider: OrchestrationProvider, service: GraphService
 ) -> dict[str, "Hatchet.Workflow"]:
 ) -> dict[str, "Hatchet.Workflow"]:
-
     def convert_to_dict(input_data):
     def convert_to_dict(input_data):
         """
         """
         Converts input data back to a plain dictionary format, handling special cases like UUID and GenerationConfig.
         Converts input data back to a plain dictionary format, handling special cases like UUID and GenerationConfig.
@@ -72,7 +71,6 @@ def hatchet_kg_factory(
 
 
     def get_input_data_dict(input_data):
     def get_input_data_dict(input_data):
         for key, value in input_data.items():
         for key, value in input_data.items():
-
             if value is None:
             if value is None:
                 continue
                 continue
 
 
@@ -212,7 +210,6 @@ def hatchet_kg_factory(
             retries=1, timeout="360m", parents=["kg_extract"]
             retries=1, timeout="360m", parents=["kg_extract"]
         )
         )
         async def kg_entity_description(self, context: Context) -> dict:
         async def kg_entity_description(self, context: Context) -> dict:
-
             input_data = get_input_data_dict(
             input_data = get_input_data_dict(
                 context.workflow_input()["request"]
                 context.workflow_input()["request"]
             )
             )
@@ -259,7 +256,6 @@ def hatchet_kg_factory(
 
 
     @orchestration_provider.workflow(name="extract-triples", timeout="600m")
     @orchestration_provider.workflow(name="extract-triples", timeout="600m")
     class CreateGraphWorkflow:
     class CreateGraphWorkflow:
-
         @orchestration_provider.concurrency(  # type: ignore
         @orchestration_provider.concurrency(  # type: ignore
             max_runs=orchestration_provider.config.kg_concurrency_limit,  # type: ignore
             max_runs=orchestration_provider.config.kg_concurrency_limit,  # type: ignore
             limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
             limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
@@ -319,7 +315,6 @@ def hatchet_kg_factory(
                 }
                 }
 
 
             else:
             else:
-
                 # Extract relationships and store them
                 # Extract relationships and store them
                 extractions = []
                 extractions = []
                 async for extraction in self.kg_service.kg_extraction(
                 async for extraction in self.kg_service.kg_extraction(
@@ -399,7 +394,6 @@ def hatchet_kg_factory(
         async def kg_entity_deduplication_setup(
         async def kg_entity_deduplication_setup(
             self, context: Context
             self, context: Context
         ) -> dict:
         ) -> dict:
-
             input_data = get_input_data_dict(
             input_data = get_input_data_dict(
                 context.workflow_input()["request"]
                 context.workflow_input()["request"]
             )
             )
@@ -467,7 +461,6 @@ def hatchet_kg_factory(
         async def kg_entity_deduplication_summary(
         async def kg_entity_deduplication_summary(
             self, context: Context
             self, context: Context
         ) -> dict:
         ) -> dict:
-
             logger.info(
             logger.info(
                 f"Running KG Entity Deduplication Summary for input data: {context.workflow_input()['request']}"
                 f"Running KG Entity Deduplication Summary for input data: {context.workflow_input()['request']}"
             )
             )
@@ -660,7 +653,6 @@ def hatchet_kg_factory(
 
 
         @orchestration_provider.step(retries=1, timeout="360m")
         @orchestration_provider.step(retries=1, timeout="360m")
         async def kg_community_summary(self, context: Context) -> dict:
         async def kg_community_summary(self, context: Context) -> dict:
-
             start_time = time.time()
             start_time = time.time()
 
 
             logger.info
             logger.info

+ 0 - 3
core/main/orchestration/simple/ingestion_workflow.py

@@ -101,7 +101,6 @@ def simple_ingestion_factory(service: IngestionService):
                         status=KGEnrichmentStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
                         status=KGEnrichmentStatus.OUTDATED,  # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still
                     )
                     )
                 else:
                 else:
-
                     for collection_id in collection_ids:
                     for collection_id in collection_ids:
                         try:
                         try:
                             # FIXME: Right now we just throw a warning if the collection already exists, but we should probably handle this more gracefully
                             # FIXME: Right now we just throw a warning if the collection already exists, but we should probably handle this more gracefully
@@ -332,7 +331,6 @@ def simple_ingestion_factory(service: IngestionService):
                 else:
                 else:
                     for collection_id in collection_ids:
                     for collection_id in collection_ids:
                         try:
                         try:
-
                             name = document_info.title or "N/A"
                             name = document_info.title or "N/A"
                             description = ""
                             description = ""
                             result = await service.providers.database.collections_handler.create_collection(
                             result = await service.providers.database.collections_handler.create_collection(
@@ -419,7 +417,6 @@ def simple_ingestion_factory(service: IngestionService):
             )
             )
 
 
     async def create_vector_index(input_data):
     async def create_vector_index(input_data):
-
         try:
         try:
             from core.main import IngestionServiceAdapter
             from core.main import IngestionServiceAdapter
 
 

+ 0 - 7
core/main/orchestration/simple/kg_workflow.py

@@ -12,10 +12,8 @@ logger = logging.getLogger()
 
 
 
 
 def simple_kg_factory(service: GraphService):
 def simple_kg_factory(service: GraphService):
-
     def get_input_data_dict(input_data):
     def get_input_data_dict(input_data):
         for key, value in input_data.items():
         for key, value in input_data.items():
-
             if type(value) == uuid.UUID:
             if type(value) == uuid.UUID:
                 continue
                 continue
 
 
@@ -41,7 +39,6 @@ def simple_kg_factory(service: GraphService):
         return input_data
         return input_data
 
 
     async def extract_triples(input_data):
     async def extract_triples(input_data):
-
         input_data = get_input_data_dict(input_data)
         input_data = get_input_data_dict(input_data)
 
 
         if input_data.get("document_id"):
         if input_data.get("document_id"):
@@ -105,7 +102,6 @@ def simple_kg_factory(service: GraphService):
                 raise e
                 raise e
 
 
     async def enrich_graph(input_data):
     async def enrich_graph(input_data):
-
         input_data = get_input_data_dict(input_data)
         input_data = get_input_data_dict(input_data)
         workflow_status = await service.providers.database.documents_handler.get_workflow_status(
         workflow_status = await service.providers.database.documents_handler.get_workflow_status(
             id=input_data.get("collection_id", None),
             id=input_data.get("collection_id", None),
@@ -157,7 +153,6 @@ def simple_kg_factory(service: GraphService):
             )
             )
 
 
         except Exception as e:
         except Exception as e:
-
             await service.providers.database.documents_handler.set_workflow_status(
             await service.providers.database.documents_handler.set_workflow_status(
                 id=input_data.get("collection_id", None),
                 id=input_data.get("collection_id", None),
                 status_type="graph_cluster_status",
                 status_type="graph_cluster_status",
@@ -167,7 +162,6 @@ def simple_kg_factory(service: GraphService):
             raise e
             raise e
 
 
     async def kg_community_summary(input_data):
     async def kg_community_summary(input_data):
-
         logger.info(
         logger.info(
             f"Running kg community summary for offset: {input_data['offset']}, limit: {input_data['limit']}"
             f"Running kg community summary for offset: {input_data['offset']}, limit: {input_data['limit']}"
         )
         )
@@ -181,7 +175,6 @@ def simple_kg_factory(service: GraphService):
         )
         )
 
 
     async def entity_deduplication_workflow(input_data):
     async def entity_deduplication_workflow(input_data):
-
         # TODO: We should determine how we want to handle the input here and syncronize it across all simple orchestration methods
         # TODO: We should determine how we want to handle the input here and syncronize it across all simple orchestration methods
         if isinstance(input_data["graph_entity_deduplication_settings"], str):
         if isinstance(input_data["graph_entity_deduplication_settings"], str):
             input_data["graph_entity_deduplication_settings"] = json.loads(
             input_data["graph_entity_deduplication_settings"] = json.loads(

+ 3 - 0
core/main/services/auth_service.py

@@ -127,6 +127,7 @@ class AuthService(Service):
         name: Optional[str] = None,
         name: Optional[str] = None,
         bio: Optional[str] = None,
         bio: Optional[str] = None,
         profile_picture: Optional[str] = None,
         profile_picture: Optional[str] = None,
+        limits_overrides: Optional[dict] = None,
     ) -> User:
     ) -> User:
         user: User = (
         user: User = (
             await self.providers.database.users_handler.get_user_by_id(user_id)
             await self.providers.database.users_handler.get_user_by_id(user_id)
@@ -143,6 +144,8 @@ class AuthService(Service):
             user.bio = bio
             user.bio = bio
         if profile_picture is not None:
         if profile_picture is not None:
             user.profile_picture = profile_picture
             user.profile_picture = profile_picture
+        if limits_overrides is not None:
+            user.limits_overrides = limits_overrides
         return await self.providers.database.users_handler.update_user(user)
         return await self.providers.database.users_handler.update_user(user)
 
 
     @telemetry_event("DeleteUserAccount")
     @telemetry_event("DeleteUserAccount")

+ 0 - 10
core/main/services/graph_service.py

@@ -79,7 +79,6 @@ class GraphService(Service):
         **kwargs,
         **kwargs,
     ):
     ):
         try:
         try:
-
             logger.info(
             logger.info(
                 f"KGService: Processing document {document_id} for KG extraction"
                 f"KGService: Processing document {document_id} for KG extraction"
             )
             )
@@ -138,7 +137,6 @@ class GraphService(Service):
         category: Optional[str] = None,
         category: Optional[str] = None,
         metadata: Optional[dict] = None,
         metadata: Optional[dict] = None,
     ) -> Entity:
     ) -> Entity:
-
         description_embedding = str(
         description_embedding = str(
             await self.providers.embedding.async_get_embedding(description)
             await self.providers.embedding.async_get_embedding(description)
         )
         )
@@ -162,7 +160,6 @@ class GraphService(Service):
         category: Optional[str] = None,
         category: Optional[str] = None,
         metadata: Optional[dict] = None,
         metadata: Optional[dict] = None,
     ) -> Entity:
     ) -> Entity:
-
         description_embedding = None
         description_embedding = None
         if description is not None:
         if description is not None:
             description_embedding = str(
             description_embedding = str(
@@ -272,7 +269,6 @@ class GraphService(Service):
         weight: Optional[float] = None,
         weight: Optional[float] = None,
         metadata: Optional[dict[str, Any] | str] = None,
         metadata: Optional[dict[str, Any] | str] = None,
     ) -> Relationship:
     ) -> Relationship:
-
         description_embedding = None
         description_embedding = None
         if description is not None:
         if description is not None:
             description_embedding = str(
             description_embedding = str(
@@ -471,7 +467,6 @@ class GraphService(Service):
         force_kg_creation: bool = False,
         force_kg_creation: bool = False,
         **kwargs,
         **kwargs,
     ):
     ):
-
         document_status_filter = [
         document_status_filter = [
             KGExtractionStatus.PENDING,
             KGExtractionStatus.PENDING,
             KGExtractionStatus.FAILED,
             KGExtractionStatus.FAILED,
@@ -494,7 +489,6 @@ class GraphService(Service):
         max_description_input_length: int,
         max_description_input_length: int,
         **kwargs,
         **kwargs,
     ):
     ):
-
         start_time = time.time()
         start_time = time.time()
 
 
         logger.info(
         logger.info(
@@ -568,7 +562,6 @@ class GraphService(Service):
         leiden_params: dict,
         leiden_params: dict,
         **kwargs,
         **kwargs,
     ):
     ):
-
         logger.info(
         logger.info(
             f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}"
             f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}"
         )
         )
@@ -670,7 +663,6 @@ class GraphService(Service):
         graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
         graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
         **kwargs,
         **kwargs,
     ):
     ):
-
         if graph_id is None and collection_id is None:
         if graph_id is None and collection_id is None:
             raise ValueError(
             raise ValueError(
                 "Either graph_id or collection_id must be provided"
                 "Either graph_id or collection_id must be provided"
@@ -731,7 +723,6 @@ class GraphService(Service):
         generation_config: GenerationConfig,
         generation_config: GenerationConfig,
         **kwargs,
         **kwargs,
     ):
     ):
-
         logger.info(
         logger.info(
             f"Running kg_entity_deduplication_summary for collection {collection_id} with settings {kwargs}"
             f"Running kg_entity_deduplication_summary for collection {collection_id} with settings {kwargs}"
         )
         )
@@ -1064,7 +1055,6 @@ class GraphService(Service):
                 entities_id_map[entity.name] = result.id
                 entities_id_map[entity.name] = result.id
 
 
             if extraction.relationships:
             if extraction.relationships:
-
                 for relationship in extraction.relationships:
                 for relationship in extraction.relationships:
                     await self.providers.database.graphs_handler.relationships.create(
                     await self.providers.database.graphs_handler.relationships.create(
                         subject=relationship.subject,
                         subject=relationship.subject,

+ 0 - 1
core/main/services/ingestion_service.py

@@ -300,7 +300,6 @@ class IngestionService(Service):
     async def finalize_ingestion(
     async def finalize_ingestion(
         self, document_info: DocumentResponse
         self, document_info: DocumentResponse
     ) -> None:
     ) -> None:
-
         async def empty_generator():
         async def empty_generator():
             yield document_info
             yield document_info
 
 

+ 0 - 1
core/main/services/management_service.py

@@ -455,7 +455,6 @@ class ManagementService(Service):
     async def remove_user_from_collection(
     async def remove_user_from_collection(
         self, user_id: UUID, collection_id: UUID
         self, user_id: UUID, collection_id: UUID
     ) -> bool:
     ) -> bool:
-
         x = await self.providers.database.users_handler.remove_user_from_collection(
         x = await self.providers.database.users_handler.remove_user_from_collection(
             user_id, collection_id
             user_id, collection_id
         )
         )

+ 7 - 3
core/parsers/media/bmp_parser.py

@@ -34,9 +34,13 @@ class BMPParser(AsyncParser[str | bytes]):
             header_size = self.struct.calcsize(header_format)
             header_size = self.struct.calcsize(header_format)
 
 
             # Unpack header data
             # Unpack header data
-            signature, file_size, reserved, reserved2, data_offset = (
-                self.struct.unpack(header_format, data[:header_size])
-            )
+            (
+                signature,
+                file_size,
+                reserved,
+                reserved2,
+                data_offset,
+            ) = self.struct.unpack(header_format, data[:header_size])
 
 
             # DIB header
             # DIB header
             dib_format = "<IiiHHIIiiII"
             dib_format = "<IiiHHIIiiII"

+ 25 - 25
core/pipes/kg/community_summary.py

@@ -52,7 +52,6 @@ class GraphCommunitySummaryPipe(AsyncPipe):
         relationships: list[Relationship],
         relationships: list[Relationship],
         max_summary_input_length: int,
         max_summary_input_length: int,
     ):
     ):
-
         entity_map: dict[str, dict[str, list[Any]]] = {}
         entity_map: dict[str, dict[str, list[Any]]] = {}
         for entity in entities:
         for entity in entities:
             if not entity.name in entity_map:
             if not entity.name in entity_map:
@@ -172,7 +171,6 @@ class GraphCommunitySummaryPipe(AsyncPipe):
         )
         )
 
 
         for attempt in range(3):
         for attempt in range(3):
-
             description = (
             description = (
                 (
                 (
                     await self.llm_provider.aget_completion(
                     await self.llm_provider.aget_completion(
@@ -268,34 +266,37 @@ class GraphCommunitySummaryPipe(AsyncPipe):
             f"GraphCommunitySummaryPipe: Checking if community summaries exist for communities {offset} to {offset + limit}"
             f"GraphCommunitySummaryPipe: Checking if community summaries exist for communities {offset} to {offset + limit}"
         )
         )
 
 
-        all_entities, _ = (
-            await self.database_provider.graphs_handler.get_entities(
-                parent_id=collection_id,
-                offset=0,
-                limit=-1,
-                include_embeddings=False,
-            )
+        (
+            all_entities,
+            _,
+        ) = await self.database_provider.graphs_handler.get_entities(
+            parent_id=collection_id,
+            offset=0,
+            limit=-1,
+            include_embeddings=False,
         )
         )
 
 
-        all_relationships, _ = (
-            await self.database_provider.graphs_handler.get_relationships(
-                parent_id=collection_id,
-                offset=0,
-                limit=-1,
-                include_embeddings=False,
-            )
+        (
+            all_relationships,
+            _,
+        ) = await self.database_provider.graphs_handler.get_relationships(
+            parent_id=collection_id,
+            offset=0,
+            limit=-1,
+            include_embeddings=False,
         )
         )
 
 
         # Perform clustering
         # Perform clustering
         leiden_params = input.message.get("leiden_params", {})
         leiden_params = input.message.get("leiden_params", {})
-        _, community_clusters = (
-            await self.database_provider.graphs_handler._cluster_and_add_community_info(
-                relationships=all_relationships,
-                relationship_ids_cache={},
-                leiden_params=leiden_params,
-                collection_id=collection_id,
-                clustering_mode=clustering_mode,
-            )
+        (
+            _,
+            community_clusters,
+        ) = await self.database_provider.graphs_handler._cluster_and_add_community_info(
+            relationships=all_relationships,
+            relationship_ids_cache={},
+            leiden_params=leiden_params,
+            collection_id=collection_id,
+            clustering_mode=clustering_mode,
         )
         )
 
 
         # Organize clusters
         # Organize clusters
@@ -330,7 +331,6 @@ class GraphCommunitySummaryPipe(AsyncPipe):
         total_errors = 0
         total_errors = 0
         completed_community_summary_jobs = 0
         completed_community_summary_jobs = 0
         for community_summary in asyncio.as_completed(community_summary_jobs):
         for community_summary in asyncio.as_completed(community_summary_jobs):
-
             summary = await community_summary
             summary = await community_summary
             completed_community_summary_jobs += 1
             completed_community_summary_jobs += 1
             if completed_community_summary_jobs % 50 == 0:
             if completed_community_summary_jobs % 50 == 0:

+ 0 - 1
core/pipes/kg/deduplication.py

@@ -63,7 +63,6 @@ class GraphDeduplicationPipe(AsyncPipe):
     async def kg_named_entity_deduplication(
     async def kg_named_entity_deduplication(
         self, graph_id: UUID | None, collection_id: UUID | None, **kwargs
         self, graph_id: UUID | None, collection_id: UUID | None, **kwargs
     ):
     ):
-
         import numpy as np
         import numpy as np
 
 
         entities = await self._get_entities(graph_id, collection_id)
         entities = await self._get_entities(graph_id, collection_id)

+ 0 - 6
core/pipes/kg/deduplication_summary.py

@@ -19,7 +19,6 @@ logger = logging.getLogger()
 
 
 
 
 class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
 class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
-
     class Input(AsyncPipe.Input):
     class Input(AsyncPipe.Input):
         message: dict
         message: dict
 
 
@@ -48,7 +47,6 @@ class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
         entity_descriptions: list[str],
         entity_descriptions: list[str],
         generation_config: GenerationConfig,
         generation_config: GenerationConfig,
     ) -> Entity:
     ) -> Entity:
-
         # find the index until the length is less than 1024
         # find the index until the length is less than 1024
         index = 0
         index = 0
         description_length = 0
         description_length = 0
@@ -89,7 +87,6 @@ class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
         entity_descriptions: list[str],
         entity_descriptions: list[str],
         generation_config: GenerationConfig,
         generation_config: GenerationConfig,
     ) -> Entity:
     ) -> Entity:
-
         # TODO: Expose this as a hyperparameter
         # TODO: Expose this as a hyperparameter
         if len(entity_descriptions) <= 5:
         if len(entity_descriptions) <= 5:
             return Entity(
             return Entity(
@@ -103,7 +100,6 @@ class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
     async def _prepare_and_upsert_entities(
     async def _prepare_and_upsert_entities(
         self, entities_batch: list[Entity], graph_id: UUID
         self, entities_batch: list[Entity], graph_id: UUID
     ) -> Any:
     ) -> Any:
-
         embeddings = await self.embedding_provider.async_get_embeddings(
         embeddings = await self.embedding_provider.async_get_embeddings(
             [entity.description or "" for entity in entities_batch]
             [entity.description or "" for entity in entities_batch]
         )
         )
@@ -135,7 +131,6 @@ class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
         limit: int,
         limit: int,
         level,
         level,
     ):
     ):
-
         if graph_id is not None:
         if graph_id is not None:
             return await self.database_provider.graphs_handler.entities.get(
             return await self.database_provider.graphs_handler.entities.get(
                 parent_id=graph_id,
                 parent_id=graph_id,
@@ -235,7 +230,6 @@ class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
                 tasks = []
                 tasks = []
 
 
         if tasks:
         if tasks:
-
             entities_batch = await asyncio.gather(*tasks)
             entities_batch = await asyncio.gather(*tasks)
             for entity in entities_batch:
             for entity in entities_batch:
                 yield entity
                 yield entity

+ 0 - 1
core/pipes/kg/extraction.py

@@ -212,7 +212,6 @@ class GraphExtractionPipe(AsyncPipe[dict]):
         *args: Any,
         *args: Any,
         **kwargs: Any,
         **kwargs: Any,
     ) -> AsyncGenerator[Union[KGExtraction, R2RDocumentProcessingError], None]:
     ) -> AsyncGenerator[Union[KGExtraction, R2RDocumentProcessingError], None]:
-
         start_time = time.time()
         start_time = time.time()
 
 
         document_id = input.message["document_id"]
         document_id = input.message["document_id"]

+ 0 - 1
core/pipes/kg/storage.py

@@ -49,7 +49,6 @@ class GraphStoragePipe(AsyncPipe):
         total_entities, total_relationships = 0, 0
         total_entities, total_relationships = 0, 0
 
 
         for extraction in kg_extractions:
         for extraction in kg_extractions:
-
             total_entities, total_relationships = (
             total_entities, total_relationships = (
                 total_entities + len(extraction.entities),
                 total_entities + len(extraction.entities),
                 total_relationships + len(extraction.relationships),
                 total_relationships + len(extraction.relationships),

+ 0 - 1
core/pipes/retrieval/chunk_search_pipe.py

@@ -65,7 +65,6 @@ class VectorSearchPipe(SearchPipe):
             search_settings.use_fulltext_search
             search_settings.use_fulltext_search
             and search_settings.use_semantic_search
             and search_settings.use_semantic_search
         ) or search_settings.use_hybrid_search:
         ) or search_settings.use_hybrid_search:
-
             search_results = (
             search_results = (
                 await self.database_provider.chunks_handler.hybrid_search(
                 await self.database_provider.chunks_handler.hybrid_search(
                     query_vector=query_vector,
                     query_vector=query_vector,

+ 0 - 1
core/pipes/retrieval/graph_search_pipe.py

@@ -262,6 +262,5 @@ class GraphSearchSearchPipe(GeneratorPipe):
         *args: Any,
         *args: Any,
         **kwargs: Any,
         **kwargs: Any,
     ) -> AsyncGenerator[GraphSearchResult, None]:
     ) -> AsyncGenerator[GraphSearchResult, None]:
-
         async for result in self.search(input, state, run_id, search_settings):
         async for result in self.search(input, state, run_id, search_settings):
             yield result
             yield result

+ 0 - 1
core/providers/auth/r2r_auth.py

@@ -361,7 +361,6 @@ class R2RAuthProvider(AuthProvider):
 
 
     async def request_password_reset(self, email: str) -> dict[str, str]:
     async def request_password_reset(self, email: str) -> dict[str, str]:
         try:
         try:
-
             user = (
             user = (
                 await self.database_provider.users_handler.get_user_by_email(
                 await self.database_provider.users_handler.get_user_by_email(
                     email=email
                     email=email

+ 1 - 1
core/providers/crypto/bcrypt.py

@@ -116,7 +116,7 @@ class BCryptCryptoProvider(CryptoProvider, ABC):
 
 
         # Generate unique key_id
         # Generate unique key_id
         key_entropy = nacl.utils.random(16)
         key_entropy = nacl.utils.random(16)
-        key_id = f"key_{base64.urlsafe_b64encode(key_entropy).decode()}"
+        key_id = f"sk_{base64.urlsafe_b64encode(key_entropy).decode()}"
 
 
         private_key = base64.b64encode(bytes(signing_key)).decode()
         private_key = base64.b64encode(bytes(signing_key)).decode()
         public_key = base64.b64encode(bytes(verify_key)).decode()
         public_key = base64.b64encode(bytes(verify_key)).decode()

+ 26 - 8
core/providers/crypto/nacl.py

@@ -19,14 +19,26 @@ from core.base import CryptoConfig, CryptoProvider
 DEFAULT_NACL_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM"  # Replace or load from env or secrets manager
 DEFAULT_NACL_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM"  # Replace or load from env or secrets manager
 
 
 
 
+def encode_bytes_readable(random_bytes: bytes, chars: str) -> str:
+    """Convert random bytes to a readable string using the given character set."""
+    # Each byte gives us 8 bits of randomness
+    # We use modulo to map each byte to our character set
+    result = []
+    for byte in random_bytes:
+        # Use modulo to map the byte (0-255) to our character set length
+        idx = byte % len(chars)
+        result.append(chars[idx])
+    return "".join(result)
+
+
 class NaClCryptoConfig(CryptoConfig):
 class NaClCryptoConfig(CryptoConfig):
     provider: str = "nacl"
     provider: str = "nacl"
     # Interactive parameters for password ops (fast)
     # Interactive parameters for password ops (fast)
-    ops_limit: int = argon2i.OPSLIMIT_INTERACTIVE
-    mem_limit: int = argon2i.MEMLIMIT_INTERACTIVE
+    ops_limit: int = argon2i.OPSLIMIT_MIN
+    mem_limit: int = argon2i.MEMLIMIT_MIN
     # Sensitive parameters for API key generation (slow but more secure)
     # Sensitive parameters for API key generation (slow but more secure)
-    api_ops_limit: int = argon2i.OPSLIMIT_SENSITIVE
-    api_mem_limit: int = argon2i.MEMLIMIT_SENSITIVE
+    api_ops_limit: int = argon2i.OPSLIMIT_INTERACTIVE
+    api_mem_limit: int = argon2i.MEMLIMIT_INTERACTIVE
     api_key_bytes: int = 32
     api_key_bytes: int = 32
     secret_key: Optional[str] = None
     secret_key: Optional[str] = None
 
 
@@ -72,14 +84,20 @@ class NaClCryptoProvider(CryptoProvider):
         return base64.urlsafe_b64encode(random_bytes)[:length].decode("utf-8")
         return base64.urlsafe_b64encode(random_bytes)[:length].decode("utf-8")
 
 
     def generate_api_key(self) -> Tuple[str, str]:
     def generate_api_key(self) -> Tuple[str, str]:
+
+        # Define our character set (excluding ambiguous characters)
+        chars = string.ascii_letters.replace("l", "").replace("I", "").replace(
+            "O", ""
+        ) + string.digits.replace("0", "").replace("1", "")
+
         # Generate a unique key_id
         # Generate a unique key_id
         key_id_bytes = nacl.utils.random(16)  # 16 random bytes
         key_id_bytes = nacl.utils.random(16)  # 16 random bytes
-        key_id = f"key_{base64.urlsafe_b64encode(key_id_bytes).decode()}"
+        key_id = f"sk_{encode_bytes_readable(key_id_bytes, chars)}"
 
 
         # Generate a high-entropy API key
         # Generate a high-entropy API key
-        raw_api_key = base64.urlsafe_b64encode(
-            nacl.utils.random(self.config.api_key_bytes)
-        ).decode()
+        raw_api_key = encode_bytes_readable(
+            nacl.utils.random(self.config.api_key_bytes), chars
+        )
 
 
         # The caller will store the hashed version in the database
         # The caller will store the hashed version in the database
         return key_id, raw_api_key
         return key_id, raw_api_key

+ 0 - 1
core/providers/embeddings/litellm.py

@@ -43,7 +43,6 @@ class LiteLLMEmbeddingProvider(EmbeddingProvider):
 
 
         self.rerank_url = None
         self.rerank_url = None
         if config.rerank_model:
         if config.rerank_model:
-
             if "huggingface" not in config.rerank_model:
             if "huggingface" not in config.rerank_model:
                 raise ValueError(
                 raise ValueError(
                     "LiteLLMEmbeddingProvider only supports re-ranking via the HuggingFace text-embeddings-inference API"
                     "LiteLLMEmbeddingProvider only supports re-ranking via the HuggingFace text-embeddings-inference API"

+ 0 - 1
core/providers/ingestion/r2r/base.py

@@ -186,7 +186,6 @@ class R2RIngestionProvider(IngestionProvider):
         parsed_document: str | DocumentChunk,
         parsed_document: str | DocumentChunk,
         ingestion_config_override: dict,
         ingestion_config_override: dict,
     ) -> AsyncGenerator[Any, None]:
     ) -> AsyncGenerator[Any, None]:
-
         text_spliiter = self.text_splitter
         text_spliiter = self.text_splitter
         if ingestion_config_override:
         if ingestion_config_override:
             text_spliiter = self._build_text_splitter(
             text_spliiter = self._build_text_splitter(

+ 0 - 1
core/providers/ingestion/unstructured/base.py

@@ -220,7 +220,6 @@ class UnstructuredIngestionProvider(IngestionProvider):
         document: Document,
         document: Document,
         ingestion_config_override: dict,
         ingestion_config_override: dict,
     ) -> AsyncGenerator[DocumentChunk, None]:
     ) -> AsyncGenerator[DocumentChunk, None]:
-
         ingestion_config = copy(
         ingestion_config = copy(
             {
             {
                 **self.config.to_ingestion_request(),
                 **self.config.to_ingestion_request(),

+ 0 - 1
core/telemetry/telemetry_decorator.py

@@ -80,7 +80,6 @@ if os.getenv("TELEMETRY_ENABLED", "true").lower() in ("true", "1"):
 def telemetry_event(event_name):
 def telemetry_event(event_name):
     def decorator(func):
     def decorator(func):
         def log_telemetry(event_type, user_id, metadata, error_message=None):
         def log_telemetry(event_type, user_id, metadata, error_message=None):
-
             if telemetry_thread_pool is None:
             if telemetry_thread_pool is None:
                 return
                 return
 
 

+ 0 - 2
migrations/versions/8077140e1e99_v3_api_database_revision.py

@@ -34,7 +34,6 @@ if (
 
 
 
 
 def upgrade() -> None:
 def upgrade() -> None:
-
     # Collections table migration
     # Collections table migration
     op.alter_column(
     op.alter_column(
         "collections",
         "collections",
@@ -195,7 +194,6 @@ def upgrade() -> None:
 
 
 
 
 def downgrade() -> None:
 def downgrade() -> None:
-
     # Collections table migration
     # Collections table migration
     op.alter_column(
     op.alter_column(
         "collections",
         "collections",

+ 3 - 0
pyproject.toml

@@ -52,6 +52,7 @@ unstructured-client = "0.25.5"
 psycopg-binary = "^3.2.3"
 psycopg-binary = "^3.2.3"
 aiosmtplib = "^3.0.2"
 aiosmtplib = "^3.0.2"
 types-aiofiles = "^24.1.0.20240626"
 types-aiofiles = "^24.1.0.20240626"
+rich = "^13.9.4"
 aiohttp = "^3.10.10"
 aiohttp = "^3.10.10"
 typing-extensions = "^4.12.2"
 typing-extensions = "^4.12.2"
 
 
@@ -81,9 +82,11 @@ sendgrid = { version = "^6.11.0", optional = true }
 sqlalchemy = { version = "^2.0.30", optional = true }
 sqlalchemy = { version = "^2.0.30", optional = true }
 supabase = { version = "^2.7.4", optional = true }
 supabase = { version = "^2.7.4", optional = true }
 tokenizers = { version = "0.19", optional = true }
 tokenizers = { version = "0.19", optional = true }
+unstructured-client = { version = "0.25.5", optional = true }
 uvicorn = { version = "^0.27.0.post1", optional = true }
 uvicorn = { version = "^0.27.0.post1", optional = true }
 vecs = { version = "^0.4.0", optional = true }
 vecs = { version = "^0.4.0", optional = true }
 
 
+
 # R2R Ingestion
 # R2R Ingestion
 aiofiles = { version = "^24.1.0", optional = true }
 aiofiles = { version = "^24.1.0", optional = true }
 aioshutil = { version = "^1.5", optional = true }
 aioshutil = { version = "^1.5", optional = true }

+ 1 - 1
sdk/async_client.py

@@ -44,7 +44,7 @@ class R2RAsyncClient(
 
 
     def __init__(
     def __init__(
         self,
         self,
-        base_url: str = "http://localhost:7272",
+        base_url: str = "https://api.cloud.sciphi.ai",
         prefix: str = "/v2",
         prefix: str = "/v2",
         custom_client=None,
         custom_client=None,
         timeout: float = 300.0,
         timeout: float = 300.0,

+ 3 - 3
sdk/base/base_client.py

@@ -1,5 +1,6 @@
 import asyncio
 import asyncio
 import contextlib
 import contextlib
+import os
 from functools import wraps
 from functools import wraps
 from typing import Optional
 from typing import Optional
 
 
@@ -34,7 +35,7 @@ def sync_generator_wrapper(async_gen_func):
 class BaseClient:
 class BaseClient:
     def __init__(
     def __init__(
         self,
         self,
-        base_url: str = "http://localhost:7272",
+        base_url: str = "https://api.cloud.sciphi.ai",
         prefix: str = "/v2",
         prefix: str = "/v2",
         timeout: float = 300.0,
         timeout: float = 300.0,
     ):
     ):
@@ -43,7 +44,7 @@ class BaseClient:
         self.timeout = timeout
         self.timeout = timeout
         self.access_token: Optional[str] = None
         self.access_token: Optional[str] = None
         self._refresh_token: Optional[str] = None
         self._refresh_token: Optional[str] = None
-        self.api_key: Optional[str] = None
+        self.api_key: Optional[str] = os.getenv("R2R_API_KEY", None)
 
 
     def _get_auth_header(self) -> dict[str, str]:
     def _get_auth_header(self) -> dict[str, str]:
         if self.access_token and self.api_key:
         if self.access_token and self.api_key:
@@ -69,7 +70,6 @@ class BaseClient:
         return f"{self.base_url}/{version}/{endpoint}"
         return f"{self.base_url}/{version}/{endpoint}"
 
 
     def _prepare_request_args(self, endpoint: str, **kwargs) -> dict:
     def _prepare_request_args(self, endpoint: str, **kwargs) -> dict:
-
         headers = kwargs.pop("headers", {})
         headers = kwargs.pop("headers", {})
         if (self.access_token or self.api_key) and endpoint not in [
         if (self.access_token or self.api_key) and endpoint not in [
             "register",
             "register",

+ 0 - 1
sdk/sync_client.py

@@ -104,7 +104,6 @@ class R2RClient(R2RAsyncClient):
     def _make_sync_method(
     def _make_sync_method(
         self, async_method: Callable[..., Coroutine[Any, Any, T]]
         self, async_method: Callable[..., Coroutine[Any, Any, T]]
     ) -> Callable[..., T]:
     ) -> Callable[..., T]:
-
         @functools.wraps(async_method)
         @functools.wraps(async_method)
         def wrapped(*args, **kwargs):
         def wrapped(*args, **kwargs):
             return self._loop.run_until_complete(async_method(*args, **kwargs))
             return self._loop.run_until_complete(async_method(*args, **kwargs))

+ 1 - 2
sdk/v2/management.py

@@ -711,9 +711,8 @@ class ManagementMixins:
         Returns:
         Returns:
             dict: The conversation data.
             dict: The conversation data.
         """
         """
-        query_params = f"?branch_id={branch_id}" if branch_id else ""
         return await self._make_request(  # type: ignore
         return await self._make_request(  # type: ignore
-            "GET", f"get_conversation/{str(conversation_id)}{query_params}"
+            "GET", f"get_conversation/{str(conversation_id)}"
         )
         )
 
 
     @deprecated("Use client.conversations.create() instead")
     @deprecated("Use client.conversations.create() instead")

+ 3 - 3
sdk/v3/graphs.py

@@ -497,7 +497,7 @@ class GraphsSDK:
         Returns:
         Returns:
             dict: Created entity information
             dict: Created entity information
         """
         """
-        data = {
+        data: dict[str, Any] = {
             "name": name,
             "name": name,
             "description": description,
             "description": description,
         }
         }
@@ -542,7 +542,7 @@ class GraphsSDK:
         Returns:
         Returns:
             dict: Created relationship information
             dict: Created relationship information
         """
         """
-        data = {
+        data: dict[str, Any] = {
             "subject": subject,
             "subject": subject,
             "subject_id": str(subject_id),
             "subject_id": str(subject_id),
             "predicate": predicate,
             "predicate": predicate,
@@ -585,7 +585,7 @@ class GraphsSDK:
         Returns:
         Returns:
             dict: Created community information
             dict: Created community information
         """
         """
-        data = {
+        data: dict[str, Any] = {
             "name": name,
             "name": name,
             "summary": summary,
             "summary": summary,
         }
         }

+ 6 - 14
sdk/v3/users.py

@@ -1,4 +1,5 @@
 from __future__ import annotations  # for Python 3.10+
 from __future__ import annotations  # for Python 3.10+
+import json
 
 
 from typing import Optional
 from typing import Optional
 from uuid import UUID
 from uuid import UUID
@@ -125,19 +126,6 @@ class UsersSDK:
             version="v3",
             version="v3",
         )
         )
 
 
-    # async def set_api_key(self, api_key: str) -> None:
-    #     """
-    #     Set the API key for the client.
-
-    #     Args:
-    #         api_key (str): API key to set
-    #     """
-    #     if self.client.access_token:
-    #         raise ValueError(
-    #             "Cannot set an API key after logging in, please log out first"
-    #         )
-    #     self.client.set_api_key(api_key)
-
     async def login(self, email: str, password: str) -> dict[str, Token]:
     async def login(self, email: str, password: str) -> dict[str, Token]:
         """
         """
         Log in a user.
         Log in a user.
@@ -194,7 +182,7 @@ class UsersSDK:
             self.client._refresh_token = None
             self.client._refresh_token = None
             raise ValueError("Invalid token provided")
             raise ValueError("Invalid token provided")
 
 
-    async def logout(self) -> WrappedGenericMessageResponse:
+    async def logout(self) -> WrappedGenericMessageResponse | None:
         """Log out the current user."""
         """Log out the current user."""
         if self.client.access_token:
         if self.client.access_token:
             response = await self.client._make_request(
             response = await self.client._make_request(
@@ -209,6 +197,7 @@ class UsersSDK:
 
 
         self.client.access_token = None
         self.client.access_token = None
         self.client._refresh_token = None
         self.client._refresh_token = None
+        return None
 
 
     async def refresh_token(self) -> WrappedTokenResponse:
     async def refresh_token(self) -> WrappedTokenResponse:
         """Refresh the access token using the refresh token."""
         """Refresh the access token using the refresh token."""
@@ -361,6 +350,7 @@ class UsersSDK:
         name: Optional[str] = None,
         name: Optional[str] = None,
         bio: Optional[str] = None,
         bio: Optional[str] = None,
         profile_picture: Optional[str] = None,
         profile_picture: Optional[str] = None,
+        limits_overrides: dict | None = None,
     ) -> WrappedUserResponse:
     ) -> WrappedUserResponse:
         """
         """
         Update user information.
         Update user information.
@@ -387,6 +377,8 @@ class UsersSDK:
             data["bio"] = bio
             data["bio"] = bio
         if profile_picture is not None:
         if profile_picture is not None:
             data["profile_picture"] = profile_picture
             data["profile_picture"] = profile_picture
+        if limits_overrides is not None:
+            data["limits_overrides"] = limits_overrides
 
 
         return await self.client._make_request(
         return await self.client._make_request(
             "POST",
             "POST",

+ 0 - 1
shared/abstractions/graph.py

@@ -62,7 +62,6 @@ class Relationship(R2RSerializable):
 
 
 @dataclass
 @dataclass
 class Community(R2RSerializable):
 class Community(R2RSerializable):
-
     name: str = ""
     name: str = ""
     summary: str = ""
     summary: str = ""
 
 

+ 2 - 3
shared/abstractions/search.py

@@ -468,7 +468,6 @@ def select_search_filters(
     auth_user: Any,
     auth_user: Any,
     search_settings: SearchSettings,
     search_settings: SearchSettings,
 ) -> dict[str, Any]:
 ) -> dict[str, Any]:
-
     filters = copy(search_settings.filters)
     filters = copy(search_settings.filters)
     selected_collections = None
     selected_collections = None
     if not auth_user.is_superuser:
     if not auth_user.is_superuser:
@@ -493,7 +492,7 @@ def select_search_filters(
         }
         }
 
 
         filters.pop("collection_ids", None)
         filters.pop("collection_ids", None)
-
-        filters = {"$and": [collection_filters, filters]}  # type: ignore
+        if filters != {}:
+            filters = {"$and": [collection_filters, filters]}  # type: ignore
 
 
     return filters
     return filters

+ 3 - 0
shared/abstractions/user.py

@@ -53,6 +53,9 @@ class User(R2RSerializable):
     graph_ids: list[UUID] = []
     graph_ids: list[UUID] = []
     document_ids: list[UUID] = []
     document_ids: list[UUID] = []
 
 
+    # Add the new limits_overrides field
+    limits_overrides: Optional[dict] = None
+
     # Optional fields (to update or set at creation)
     # Optional fields (to update or set at creation)
     hashed_password: Optional[str] = None
     hashed_password: Optional[str] = None
     verification_code_expiry: Optional[datetime] = None
     verification_code_expiry: Optional[datetime] = None

+ 56 - 0
tests/cli/async_invoke.py

@@ -0,0 +1,56 @@
+import asyncio
+from types import TracebackType
+from typing import Any, Tuple, Type, cast
+
+import asyncclick as click
+from click import Abort
+from click.testing import CliRunner, Result
+
+
+async def async_invoke(
+    runner: CliRunner, cmd: click.Command, *args: str, **kwargs: Any
+) -> Result:
+    """Helper function to invoke async Click commands in tests."""
+    exit_code = 0
+    exception: BaseException | None = None
+    exc_info: (
+        Tuple[Type[BaseException], BaseException, TracebackType] | None
+    ) = None
+
+    with runner.isolation() as out_err:
+        stdout, stderr = out_err
+        try:
+            # Get current event loop instead of creating new one
+            loop = asyncio.get_event_loop()
+
+            # Run the command using create_task
+            task = loop.create_task(
+                cmd.main(args=args, standalone_mode=False, **kwargs)
+            )
+            return_value = await task
+
+        except Abort as e:
+            exit_code = 1
+            exception = cast(BaseException, e)
+            if e.__traceback__:
+                exc_info = (BaseException, exception, e.__traceback__)
+            return_value = None
+        except Exception as e:
+            exit_code = 1
+            exception = cast(BaseException, e)
+            if e.__traceback__:
+                exc_info = (BaseException, exception, e.__traceback__)
+            return_value = None
+
+        stdout_bytes = stdout.getvalue() or b""
+        stderr_bytes = stderr.getvalue() if stderr else b""
+
+        return Result(
+            runner=runner,
+            stdout_bytes=stdout_bytes,
+            stderr_bytes=stderr_bytes,
+            return_value=return_value,
+            exit_code=exit_code,
+            exception=exception,
+            exc_info=exc_info,
+        )

+ 162 - 0
tests/cli/commands/test_collections_cli.py

@@ -0,0 +1,162 @@
+"""
+Tests for the collection commands in the CLI.
+    - create
+    - list
+    - retrieve
+    - delete
+    - list-documents
+    - list-users
+"""
+
+import json
+import uuid
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.collections import (
+    create,
+    delete,
+    list,
+    list_documents,
+    list_users,
+    retrieve,
+)
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
+def extract_json_block(output: str) -> dict:
+    """Extract and parse the first valid JSON object found in the output."""
+    start = output.find("{")
+    if start == -1:
+        raise ValueError("No JSON object start found in output")
+
+    brace_count = 0
+    for i, char in enumerate(output[start:], start=start):
+        if char == "{":
+            brace_count += 1
+        elif char == "}":
+            brace_count -= 1
+            if brace_count == 0:
+                json_str = output[start : i + 1].strip()
+                return json.loads(json_str)
+    raise ValueError("No complete JSON object found in output")
+
+
+@pytest.mark.asyncio
+async def test_collection_lifecycle():
+    """Test the complete lifecycle of a collection: create, retrieve, delete."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    collection_name = f"test-collection-{uuid.uuid4()}"
+    description = "Test collection description"
+
+    # Create collection
+    create_result = await async_invoke(
+        runner,
+        create,
+        collection_name,
+        "--description",
+        description,
+        obj=client,
+    )
+    assert create_result.exit_code == 0, create_result.stdout_bytes.decode()
+
+    output = create_result.stdout_bytes.decode()
+    create_response = extract_json_block(output)
+    collection_id = create_response["results"]["id"]
+
+    try:
+        # Retrieve collection
+        retrieve_result = await async_invoke(
+            runner, retrieve, collection_id, obj=client
+        )
+        assert retrieve_result.exit_code == 0
+        retrieve_output = retrieve_result.stdout_bytes.decode()
+        assert collection_id in retrieve_output
+
+        # List documents in collection
+        list_docs_result = await async_invoke(
+            runner, list_documents, collection_id, obj=client
+        )
+        assert list_docs_result.exit_code == 0
+
+        # List users in collection
+        list_users_result = await async_invoke(
+            runner, list_users, collection_id, obj=client
+        )
+        assert list_users_result.exit_code == 0
+    finally:
+        # Delete collection
+        delete_result = await async_invoke(
+            runner, delete, collection_id, obj=client
+        )
+        assert delete_result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_list_collections():
+    """Test listing collections with various parameters."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Create test collection first
+    create_result = await async_invoke(
+        runner, create, f"test-collection-{uuid.uuid4()}", obj=client
+    )
+    response = extract_json_block(create_result.stdout_bytes.decode())
+    collection_id = response["results"]["id"]
+
+    try:
+        # Test basic list
+        list_result = await async_invoke(runner, list, obj=client)
+        assert list_result.exit_code == 0
+
+        # Get paginated results just to verify they exist
+        list_paginated = await async_invoke(
+            runner, list, "--offset", "0", "--limit", "2", obj=client
+        )
+        assert list_paginated.exit_code == 0
+
+    finally:
+        # Cleanup
+        await async_invoke(runner, delete, collection_id, obj=client)
+
+
+@pytest.mark.asyncio
+async def test_nonexistent_collection():
+    """Test operations on a nonexistent collection."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    nonexistent_id = str(uuid.uuid4())
+
+    # Test retrieve
+    retrieve_result = await async_invoke(
+        runner, retrieve, nonexistent_id, obj=client
+    )
+    # Updated assertion to match actual error message
+    assert (
+        "the specified collection does not exist."
+        in retrieve_result.stderr_bytes.decode().lower()
+    )
+
+    # Test list_documents
+    list_docs_result = await async_invoke(
+        runner, list_documents, nonexistent_id, obj=client
+    )
+    assert (
+        "collection not found"
+        in list_docs_result.stderr_bytes.decode().lower()
+    )
+
+    # Test list_users
+    list_users_result = await async_invoke(
+        runner, list_users, nonexistent_id, obj=client
+    )
+    assert (
+        "collection not found"
+        in list_users_result.stderr_bytes.decode().lower()
+    )

+ 7 - 0
tests/cli/commands/test_config_cli.py

@@ -0,0 +1,7 @@
+"""
+Tests for the conversations commands in the CLI.
+    x reset
+    x key
+    x host
+    x view
+"""

+ 168 - 0
tests/cli/commands/test_conversations_cli.py

@@ -0,0 +1,168 @@
+"""
+Tests for the conversations commands in the CLI.
+    - create
+    - list
+    - retrieve
+    - delete
+    - list-users
+"""
+
+import json
+import uuid
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.conversations import (
+    create,
+    delete,
+    list,
+    list_users,
+    retrieve,
+)
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
+def extract_json_block(output: str) -> dict:
+    """Extract and parse the first valid JSON object found in the output."""
+    start = output.find("{")
+    if start == -1:
+        raise ValueError("No JSON object start found in output")
+
+    brace_count = 0
+    for i, char in enumerate(output[start:], start=start):
+        if char == "{":
+            brace_count += 1
+        elif char == "}":
+            brace_count -= 1
+            if brace_count == 0:
+                json_str = output[start : i + 1].strip()
+                return json.loads(json_str)
+    raise ValueError("No complete JSON object found in output")
+
+
+@pytest.mark.asyncio
+async def test_conversation_lifecycle():
+    """Test the complete lifecycle of a conversation: create, retrieve, delete."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Create conversation
+    create_result = await async_invoke(
+        runner,
+        create,
+        obj=client,
+    )
+    assert create_result.exit_code == 0, create_result.stdout_bytes.decode()
+
+    output = create_result.stdout_bytes.decode()
+    create_response = extract_json_block(output)
+    conversation_id = create_response["results"]["id"]
+
+    try:
+        # Retrieve conversation
+        retrieve_result = await async_invoke(
+            runner, retrieve, conversation_id, obj=client
+        )
+        assert retrieve_result.exit_code == 0
+        retrieve_output = retrieve_result.stdout_bytes.decode()
+        # FIXME: This assertion fails, we need to sync Conversation and ConversationResponse
+        # assert conversation_id in retrieve_output
+        assert "results" in retrieve_output
+
+        # List users in conversation
+        list_users_result = await async_invoke(
+            runner, list_users, conversation_id, obj=client
+        )
+        assert list_users_result.exit_code == 0
+    finally:
+        # Delete conversation
+        delete_result = await async_invoke(
+            runner, delete, conversation_id, obj=client
+        )
+        assert delete_result.exit_code == 0
+
+
+@pytest.mark.asyncio
+async def test_list_conversations():
+    """Test listing conversations with various parameters."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Create test conversation first
+    create_result = await async_invoke(runner, create, obj=client)
+    response = extract_json_block(create_result.stdout_bytes.decode())
+    conversation_id = response["results"]["id"]
+
+    try:
+        # Test basic list
+        list_result = await async_invoke(runner, list, obj=client)
+        assert list_result.exit_code == 0
+
+        # Test paginated results
+        list_paginated = await async_invoke(
+            runner, list, "--offset", "0", "--limit", "2", obj=client
+        )
+        assert list_paginated.exit_code == 0
+
+    finally:
+        # Cleanup
+        await async_invoke(runner, delete, conversation_id, obj=client)
+
+
+@pytest.mark.asyncio
+async def test_nonexistent_conversation():
+    """Test operations on a nonexistent conversation."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    nonexistent_id = str(uuid.uuid4())
+
+    # Test retrieve
+    retrieve_result = await async_invoke(
+        runner, retrieve, nonexistent_id, obj=client
+    )
+    assert "not found" in retrieve_result.stderr_bytes.decode().lower()
+
+    # Test list_users
+    list_users_result = await async_invoke(
+        runner, list_users, nonexistent_id, obj=client
+    )
+    assert "not found" in list_users_result.stderr_bytes.decode().lower()
+
+
+@pytest.mark.asyncio
+async def test_list_conversations_pagination():
+    """Test pagination functionality of list conversations."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Create multiple conversations
+    conversation_ids = []
+    for _ in range(3):
+        create_result = await async_invoke(runner, create, obj=client)
+        response = extract_json_block(create_result.stdout_bytes.decode())
+        conversation_ids.append(response["results"]["id"])
+
+    try:
+        # Test with different pagination parameters
+        list_first_page = await async_invoke(
+            runner, list, "--offset", "0", "--limit", "2", obj=client
+        )
+        assert list_first_page.exit_code == 0
+        first_page_output = list_first_page.stdout_bytes.decode()
+
+        list_second_page = await async_invoke(
+            runner, list, "--offset", "2", "--limit", "2", obj=client
+        )
+        assert list_second_page.exit_code == 0
+        second_page_output = list_second_page.stdout_bytes.decode()
+
+        # Verify different results on different pages
+        assert first_page_output != second_page_output
+
+    finally:
+        # Cleanup
+        for conversation_id in conversation_ids:
+            await async_invoke(runner, delete, conversation_id, obj=client)

+ 0 - 0
tests/cli/commands/test_database_cli.py


+ 330 - 0
tests/cli/commands/test_documents_cli.py

@@ -0,0 +1,330 @@
+"""
+Tests for the document commands in the CLI.
+    - create
+    - retrieve
+    - list
+    - delete
+    - list-chunks
+    - list-collections
+    x ingest-files-from-url
+    x extract
+    x list-entities
+    x list-relationships
+    x create-sample
+    x create-samples
+"""
+
+import contextlib
+import json
+import os
+import tempfile
+import uuid
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.documents import (
+    create,
+    delete,
+    list,
+    list_chunks,
+    list_collections,
+    retrieve,
+)
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
+@pytest.fixture
+def temp_text_file():
+    """Create a temporary text file for testing."""
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".txt", delete=False
+    ) as f:
+        f.write("This is test content for document testing.")
+        temp_path = f.name
+
+    yield temp_path
+
+    # Cleanup temp file
+    if os.path.exists(temp_path):
+        os.unlink(temp_path)
+
+
+@pytest.fixture
+def temp_json_file():
+    """Create a temporary JSON file for testing."""
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".json", delete=False
+    ) as f:
+        json.dump({"test": "content", "for": "document testing"}, f)
+        temp_path = f.name
+
+    yield temp_path
+
+    # Cleanup temp file
+    if os.path.exists(temp_path):
+        os.unlink(temp_path)
+
+
+def extract_json_block(output: str) -> dict:
+    """Extract and parse the first valid JSON object found in the output."""
+    # We assume the output contains at least one JSON object printed with json.dumps(indent=2).
+    # We'll find the first '{' and the matching closing '}' that forms a valid JSON object.
+    start = output.find("{")
+    if start == -1:
+        raise ValueError("No JSON object start found in output")
+
+    # Track braces to find the matching '}'
+    brace_count = 0
+    for i, char in enumerate(output[start:], start=start):
+        if char == "{":
+            brace_count += 1
+        elif char == "}":
+            brace_count -= 1
+
+            if brace_count == 0:
+                # Found the matching closing brace
+                json_str = output[start : i + 1].strip()
+                return json.loads(json_str)
+    raise ValueError("No complete JSON object found in output")
+
+
+@pytest.mark.asyncio
+async def test_document_lifecycle(temp_text_file):
+    """Test the complete lifecycle of a document: create, retrieve, delete."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    # Create document
+    create_result = await async_invoke(
+        runner, create, temp_text_file, obj=client
+    )
+    assert create_result.exit_code == 0, create_result.stdout_bytes.decode()
+
+    output = create_result.stdout_bytes.decode()
+    create_response = extract_json_block(output)
+    document_id = create_response["results"]["document_id"]
+
+    try:
+        # Retrieve document
+        retrieve_result = await async_invoke(
+            runner, retrieve, document_id, obj=client
+        )
+        assert (
+            retrieve_result.exit_code == 0
+        ), retrieve_result.stdout_bytes.decode()
+
+        # Instead of parsing JSON, verify the ID appears in the table output
+        retrieve_output = retrieve_result.stdout_bytes.decode()
+        assert document_id in retrieve_output
+
+        # List chunks
+        list_chunks_result = await async_invoke(
+            runner, list_chunks, document_id, obj=client
+        )
+        assert (
+            list_chunks_result.exit_code == 0
+        ), list_chunks_result.stdout_bytes.decode()
+
+        # List collections
+        list_collections_result = await async_invoke(
+            runner, list_collections, document_id, obj=client
+        )
+        assert (
+            list_collections_result.exit_code == 0
+        ), list_collections_result.stdout_bytes.decode()
+    finally:
+        # Delete document
+        delete_result = await async_invoke(
+            runner, delete, document_id, obj=client
+        )
+        assert (
+            delete_result.exit_code == 0
+        ), delete_result.stdout_bytes.decode()
+
+
+@pytest.mark.asyncio
+async def test_create_multiple_documents(temp_text_file, temp_json_file):
+    """Test creating multiple documents with metadata."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    metadatas = json.dumps(
+        [
+            {"description": "Test document 1"},
+            {"description": "Test document 2"},
+        ]
+    )
+
+    create_result = await async_invoke(
+        runner,
+        create,
+        temp_text_file,
+        temp_json_file,
+        "--metadatas",
+        metadatas,
+        obj=client,
+    )
+    assert create_result.exit_code == 0, create_result.stdout_bytes.decode()
+
+    output = create_result.stdout_bytes.decode()
+    # The command may print multiple JSON objects separated by dashes and status lines.
+    # Extract all JSON objects.
+    json_objects = []
+    start_idx = 0
+    while True:
+        try:
+            # Attempt to extract a JSON object from output[start_idx:]
+            block = extract_json_block(output[start_idx:])
+            json_objects.append(block)
+            # Move start_idx beyond this block to find the next one
+            next_start = output[start_idx:].find("{")
+            start_idx += output[start_idx:].find("{") + 1
+            # Move past the first '{' we found
+            # Actually, let's break after one extraction to avoid infinite loops if the output is large.
+            # Instead, we find multiple objects by splitting on the line of dashes:
+            break
+        except ValueError:
+            break
+
+    # Alternatively, if multiple objects are separated by "----------", we can split and parse each:
+    # This assumes each block between "----------" lines contains exactly one JSON object.
+    blocks = output.split("-" * 40)
+    json_objects = []
+    for block in blocks:
+        block = block.strip()
+        if '"results"' in block and "{" in block and "}" in block:
+            with contextlib.suppress(ValueError):
+                json_objects.append(extract_json_block(block))
+
+    assert (
+        len(json_objects) == 2
+    ), f"Expected 2 JSON objects, got {len(json_objects)}: {output}"
+
+    document_ids = [obj["results"]["document_id"] for obj in json_objects]
+
+    try:
+        # List all documents
+        list_result = await async_invoke(runner, list, obj=client)
+        assert list_result.exit_code == 0, list_result.stdout_bytes.decode()
+
+        # Verify both documents were created
+        for doc_id in document_ids:
+            retrieve_result = await async_invoke(
+                runner, retrieve, doc_id, obj=client
+            )
+            assert (
+                retrieve_result.exit_code == 0
+            ), retrieve_result.stdout_bytes.decode()
+    finally:
+        # Cleanup - delete all created documents
+        for doc_id in document_ids:
+            delete_result = await async_invoke(
+                runner, delete, doc_id, obj=client
+            )
+            assert (
+                delete_result.exit_code == 0
+            ), delete_result.stdout_bytes.decode()
+
+
+@pytest.mark.asyncio
+async def test_create_with_custom_id():
+    """Test creating a document with a custom ID."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    custom_id = str(uuid.uuid4())
+
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".txt", delete=False
+    ) as f:
+        f.write("Test content")
+        temp_path = f.name
+
+    try:
+        create_result = await async_invoke(
+            runner, create, temp_path, "--ids", custom_id, obj=client
+        )
+        assert (
+            create_result.exit_code == 0
+        ), create_result.stdout_bytes.decode()
+
+        output = create_result.stdout_bytes.decode()
+        create_response = extract_json_block(output)
+        assert create_response["results"]["document_id"] == custom_id
+    finally:
+        if os.path.exists(temp_path):
+            os.unlink(temp_path)
+
+        await async_invoke(runner, delete, custom_id, obj=client)
+
+
+@pytest.mark.asyncio
+async def test_retrieve_nonexistent_document():
+    """Test retrieving a document that doesn't exist."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    nonexistent_id = str(uuid.uuid4())
+    result = await async_invoke(runner, retrieve, nonexistent_id, obj=client)
+
+    stderr = result.stderr_bytes.decode()
+    assert (
+        "Document not found" in stderr
+        or "Document not found" in result.stdout_bytes.decode()
+    )
+
+
+@pytest.mark.asyncio
+async def test_list_chunks_nonexistent_document():
+    """Test listing chunks for a document that doesn't exist."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    nonexistent_id = str(uuid.uuid4())
+    result = await async_invoke(
+        runner, list_chunks, nonexistent_id, obj=client
+    )
+
+    stderr = result.stderr_bytes.decode()
+    assert (
+        "No chunks found for the given document ID." in stderr
+        or "No chunks found for the given document ID."
+        in result.stdout_bytes.decode()
+    )
+
+
+# FIXME: This should be returning 'Document not found' but returns an empty list instead.
+# @pytest.mark.asyncio
+# async def test_list_collections_nonexistent_document():
+#     """Test listing collections for a document that doesn't exist."""
+#     client = R2RAsyncClient(base_url="http://localhost:7272")
+#     runner = CliRunner(mix_stderr=False)
+
+#     nonexistent_id = str(uuid.uuid4())
+#     result = await async_invoke(
+#         runner, list_collections, nonexistent_id, obj=client
+#     )
+
+#     stderr = result.stderr_bytes.decode()
+#     assert (
+#         "Document not found" in stderr
+#         or "Document not found" in result.stdout_bytes.decode()
+#     )
+
+
+@pytest.mark.asyncio
+async def test_delete_nonexistent_document():
+    """Test deleting a document that doesn't exist."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    nonexistent_id = str(uuid.uuid4())
+    result = await async_invoke(runner, delete, nonexistent_id, obj=client)
+
+    stderr = result.stderr_bytes.decode()
+    assert (
+        "No entries found for deletion" in stderr
+        or "No entries found for deletion" in result.stdout_bytes.decode()
+    )

+ 312 - 0
tests/cli/commands/test_graphs_cli.py

@@ -0,0 +1,312 @@
+"""
+Tests for the graphs commands in the CLI.
+    - list
+    - retrieve
+    - reset
+    - update
+    - list-entities
+    - get-entity
+    x remove-entity
+    - list-relationships
+    - get-relationship
+    x remove-relationship
+    - build
+    - list-communities
+    - get-community
+    x update-community
+    x delete-community
+    - pull
+    - remove-document
+"""
+
+import json
+import uuid
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.collections import create as create_collection
+from cli.commands.graphs import (
+    build,
+    delete_community,
+    get_community,
+    get_entity,
+    get_relationship,
+    list,
+    list_communities,
+    list_entities,
+    list_relationships,
+    pull,
+    remove_document,
+    remove_entity,
+    remove_relationship,
+    reset,
+    retrieve,
+    update,
+    update_community,
+)
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
+def extract_json_block(output: str) -> dict:
+    """Extract and parse the first valid JSON object found in the output."""
+    start = output.find("{")
+    if start == -1:
+        raise ValueError("No JSON object start found in output")
+
+    brace_count = 0
+    for i, char in enumerate(output[start:], start=start):
+        if char == "{":
+            brace_count += 1
+        elif char == "}":
+            brace_count -= 1
+            if brace_count == 0:
+                json_str = output[start : i + 1].strip()
+                return json.loads(json_str)
+    raise ValueError("No complete JSON object found in output")
+
+
+async def create_test_collection(
+    runner: CliRunner, client: R2RAsyncClient
+) -> str:
+    """Helper function to create a test collection and return its ID."""
+    collection_name = f"test-collection-{uuid.uuid4()}"
+    create_result = await async_invoke(
+        runner, create_collection, collection_name, obj=client
+    )
+    response = extract_json_block(create_result.stdout_bytes.decode())
+    return response["results"]["id"]
+
+
+@pytest.mark.asyncio
+async def test_graph_basic_operations():
+    """Test basic graph operations: retrieve, reset, update."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    collection_id = await create_test_collection(runner, client)
+
+    try:
+        # Retrieve graph
+        retrieve_result = await async_invoke(
+            runner, retrieve, collection_id, obj=client
+        )
+        assert retrieve_result.exit_code == 0
+        assert collection_id in retrieve_result.stdout_bytes.decode()
+
+        # Update graph
+        new_name = "Updated Graph Name"
+        new_description = "Updated description"
+        update_result = await async_invoke(
+            runner,
+            update,
+            collection_id,
+            "--name",
+            new_name,
+            "--description",
+            new_description,
+            obj=client,
+        )
+        assert update_result.exit_code == 0
+
+        # Reset graph
+        reset_result = await async_invoke(
+            runner, reset, collection_id, obj=client
+        )
+        assert reset_result.exit_code == 0
+
+    finally:
+        # Cleanup will be handled by collection deletion
+        pass
+
+
+@pytest.mark.asyncio
+async def test_graph_entity_operations():
+    """Test entity-related operations in a graph."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    collection_id = await create_test_collection(runner, client)
+
+    try:
+        # List entities (empty initially)
+        list_entities_result = await async_invoke(
+            runner, list_entities, collection_id, obj=client
+        )
+        assert list_entities_result.exit_code == 0
+
+        # Test with pagination
+        paginated_result = await async_invoke(
+            runner,
+            list_entities,
+            collection_id,
+            "--offset",
+            "0",
+            "--limit",
+            "2",
+            obj=client,
+        )
+        assert paginated_result.exit_code == 0
+
+        # Test nonexistent entity operations
+        nonexistent_entity_id = str(uuid.uuid4())
+        get_entity_result = await async_invoke(
+            runner,
+            get_entity,
+            collection_id,
+            nonexistent_entity_id,
+            obj=client,
+        )
+        assert "not found" in get_entity_result.stderr_bytes.decode().lower()
+
+    finally:
+        # Cleanup will be handled by collection deletion
+        pass
+
+
+@pytest.mark.asyncio
+async def test_graph_relationship_operations():
+    """Test relationship-related operations in a graph."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    collection_id = await create_test_collection(runner, client)
+
+    try:
+        # List relationships
+        list_rel_result = await async_invoke(
+            runner, list_relationships, collection_id, obj=client
+        )
+        assert list_rel_result.exit_code == 0
+
+        # Test with pagination
+        paginated_result = await async_invoke(
+            runner,
+            list_relationships,
+            collection_id,
+            "--offset",
+            "0",
+            "--limit",
+            "2",
+            obj=client,
+        )
+        assert paginated_result.exit_code == 0
+
+        # Test nonexistent relationship operations
+        nonexistent_rel_id = str(uuid.uuid4())
+        get_rel_result = await async_invoke(
+            runner,
+            get_relationship,
+            collection_id,
+            nonexistent_rel_id,
+            obj=client,
+        )
+        assert "not found" in get_rel_result.stderr_bytes.decode().lower()
+
+    finally:
+        # Cleanup will be handled by collection deletion
+        pass
+
+
+@pytest.mark.asyncio
+async def test_graph_community_operations():
+    """Test community-related operations in a graph."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    collection_id = await create_test_collection(runner, client)
+
+    try:
+        # List communities
+        list_comm_result = await async_invoke(
+            runner, list_communities, collection_id, obj=client
+        )
+        assert list_comm_result.exit_code == 0
+
+        # Test with pagination
+        paginated_result = await async_invoke(
+            runner,
+            list_communities,
+            collection_id,
+            "--offset",
+            "0",
+            "--limit",
+            "2",
+            obj=client,
+        )
+        assert paginated_result.exit_code == 0
+
+        # Test nonexistent community operations
+        nonexistent_comm_id = str(uuid.uuid4())
+        get_comm_result = await async_invoke(
+            runner,
+            get_community,
+            collection_id,
+            nonexistent_comm_id,
+            obj=client,
+        )
+        assert "not found" in get_comm_result.stderr_bytes.decode().lower()
+
+    finally:
+        # Cleanup will be handled by collection deletion
+        pass
+
+
+@pytest.mark.asyncio
+async def test_graph_build_and_pull():
+    """Test graph building and document pull operations."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    collection_id = await create_test_collection(runner, client)
+
+    try:
+        # Test build with minimal settings
+        settings = {"some_setting": "value"}
+        build_result = await async_invoke(
+            runner,
+            build,
+            collection_id,
+            "--settings",
+            json.dumps(settings),
+            obj=client,
+        )
+        assert build_result.exit_code == 0
+
+        # Test pull documents
+        pull_result = await async_invoke(
+            runner, pull, collection_id, obj=client
+        )
+        assert pull_result.exit_code == 0
+
+        # Test remove document (with nonexistent document)
+        nonexistent_doc_id = str(uuid.uuid4())
+        remove_doc_result = await async_invoke(
+            runner,
+            remove_document,
+            collection_id,
+            nonexistent_doc_id,
+            obj=client,
+        )
+        assert "not found" in remove_doc_result.stderr_bytes.decode().lower()
+
+    finally:
+        # Cleanup will be handled by collection deletion
+        pass
+
+
+@pytest.mark.asyncio
+async def test_list_graphs():
+    """Test listing graphs."""
+    client = R2RAsyncClient(base_url="http://localhost:7272")
+    runner = CliRunner(mix_stderr=False)
+
+    try:
+        # Test basic list
+        list_result = await async_invoke(runner, list, obj=client)
+        assert list_result.exit_code == 0
+
+    finally:
+        # Cleanup will be handled by collection deletion
+        pass

+ 6 - 0
tests/cli/commands/test_indices_cli.py

@@ -0,0 +1,6 @@
+"""
+Tests for the indices commands in the CLI.
+    x list
+    x retrieve
+    x delete
+"""

+ 95 - 0
tests/cli/commands/test_prompts_cli.py

@@ -0,0 +1,95 @@
+"""
+Tests for the prompts commands in the CLI.
+    - list
+    - retrieve
+    x delete
+"""
+
+import json
+import uuid
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.prompts import list, retrieve
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
def extract_json_block(output: str) -> dict:
    """Extract and parse the first valid JSON object found in the output.

    Scans forward from each "{" and lets ``json.JSONDecoder.raw_decode``
    locate the object's true end.  Unlike naive brace counting, this is
    correct even when string values themselves contain "{" or "}".

    Raises:
        ValueError: if the output contains no complete JSON object
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    index = output.find("{")
    if index == -1:
        raise ValueError("No JSON object start found in output")

    decoder = json.JSONDecoder()
    while index != -1:
        try:
            parsed, _ = decoder.raw_decode(output, index)
            return parsed
        except json.JSONDecodeError:
            # This "{" did not begin a valid object; try the next one.
            index = output.find("{", index + 1)
    raise ValueError("No complete JSON object found in output")
+
+
@pytest.mark.asyncio
async def test_prompts_list():
    """Listing prompts should succeed with a zero exit code."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, list, obj=client)
    assert outcome.exit_code == 0
+
+
@pytest.mark.asyncio
async def test_prompts_retrieve():
    """Retrieving a known prompt by name should succeed."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    # "hyde" ships with the server, so this lookup should always work.
    outcome = await async_invoke(runner, retrieve, "hyde", obj=client)
    assert outcome.exit_code == 0
+
+
@pytest.mark.asyncio
async def test_nonexistent_prompt():
    """Retrieving a prompt that does not exist should report 'not found'."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    missing_name = f"nonexistent-{uuid.uuid4()}"
    outcome = await async_invoke(runner, retrieve, missing_name, obj=client)
    assert "not found" in outcome.stderr_bytes.decode().lower()
+
+
@pytest.mark.asyncio
async def test_prompt_retrieve_with_all_options():
    """Retrieving a prompt with inputs and a template override should succeed."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(
        runner,
        retrieve,
        "example-prompt",
        "--inputs",
        "input1,input2",
        "--prompt-override",
        "custom prompt text",
        obj=client,
    )
    assert outcome.exit_code == 0

+ 213 - 0
tests/cli/commands/test_retrieval_cli.py

@@ -0,0 +1,213 @@
+"""
+Tests for the retrieval commands in the CLI.
+    - search
+    - rag
+"""
+
+import json
+import tempfile
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.documents import create as create_document
+from cli.commands.retrieval import rag, search
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
def extract_json_block(output: str) -> dict:
    """Extract and parse the first valid JSON object found in the output.

    Scans forward from each "{" and lets ``json.JSONDecoder.raw_decode``
    locate the object's true end.  Unlike naive brace counting, this is
    correct even when string values themselves contain "{" or "}".

    Raises:
        ValueError: if the output contains no complete JSON object
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    index = output.find("{")
    if index == -1:
        raise ValueError("No JSON object start found in output")

    decoder = json.JSONDecoder()
    while index != -1:
        try:
            parsed, _ = decoder.raw_decode(output, index)
            return parsed
        except json.JSONDecodeError:
            # This "{" did not begin a valid object; try the next one.
            index = output.find("{", index + 1)
    raise ValueError("No complete JSON object found in output")
+
+
async def create_test_document(
    runner: CliRunner, client: R2RAsyncClient
) -> str:
    """Create a test document via the CLI and return its document ID.

    Writes a small text file to a temporary path, ingests it with the
    ``documents create`` command, and removes the temp file afterwards.
    (The original used ``delete=False`` without ever unlinking the file,
    leaking one temp file per call.)
    """
    import os

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False
    ) as f:
        f.write(
            "This is a test document about artificial intelligence and machine learning. "
            "AI systems can be trained on large datasets to perform various tasks."
        )
        temp_path = f.name

    try:
        create_result = await async_invoke(
            runner, create_document, temp_path, obj=client
        )
        response = extract_json_block(create_result.stdout_bytes.decode())
        return response["results"]["document_id"]
    finally:
        # The file has been ingested (or the attempt failed); either way
        # it is no longer needed on disk.
        os.unlink(temp_path)
+
+
@pytest.mark.asyncio
async def test_basic_search():
    """A simple vector search should succeed and print result output."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    # A document must exist for the search to have anything to hit.
    await create_test_document(runner, client)

    try:
        outcome = await async_invoke(
            runner,
            search,
            "--query",
            "artificial intelligence",
            "--limit",
            "5",
            obj=client,
        )
        assert outcome.exit_code == 0
        assert "Vector search results:" in outcome.stdout_bytes.decode()

    finally:
        # Document cleanup would go here once deletion is wired up.
        pass
+
+
@pytest.mark.asyncio
async def test_search_with_filters():
    """Search restricted by a document_id filter should return that document."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    document_id = await create_test_document(runner, client)

    try:
        doc_filter = json.dumps({"document_id": {"$in": [document_id]}})
        outcome = await async_invoke(
            runner,
            search,
            "--query",
            "machine learning",
            "--filters",
            doc_filter,
            "--limit",
            "5",
            obj=client,
        )
        assert outcome.exit_code == 0

        printed = outcome.stdout_bytes.decode()
        assert "Vector search results:" in printed
        # The filtered document must appear in the output.
        assert document_id in printed

    finally:
        pass
+
+
@pytest.mark.asyncio
async def test_search_with_advanced_options():
    """Search should accept hybrid/graph/chunk options together."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    await create_test_document(runner, client)

    try:
        cli_args = [
            "--query",
            "AI systems",
            "--use-hybrid-search",
            "true",
            "--search-strategy",
            "vanilla",
            "--graph-search-enabled",
            "true",
            "--chunk-search-enabled",
            "true",
        ]
        outcome = await async_invoke(runner, search, *cli_args, obj=client)
        assert outcome.exit_code == 0
        assert "Vector search results:" in outcome.stdout_bytes.decode()

    finally:
        pass
+
+
@pytest.mark.asyncio
async def test_basic_rag():
    """A basic RAG query over an ingested document should succeed."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    await create_test_document(runner, client)

    try:
        outcome = await async_invoke(
            runner,
            rag,
            "--query",
            "What is this document about?",
            obj=client,
        )
        assert outcome.exit_code == 0

    finally:
        pass
+
+
@pytest.mark.asyncio
async def test_rag_with_streaming():
    """RAG should also succeed when streaming is requested."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    await create_test_document(runner, client)

    try:
        outcome = await async_invoke(
            runner,
            rag,
            "--query",
            "What is this document about?",
            "--stream",
            obj=client,
        )
        assert outcome.exit_code == 0

    finally:
        pass
+
+
@pytest.mark.asyncio
async def test_rag_with_model_specification():
    """RAG should accept an explicit model selection."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    await create_test_document(runner, client)

    try:
        outcome = await async_invoke(
            runner,
            rag,
            "--query",
            "What is this document about?",
            "--rag-model",
            "azure/gpt-4o-mini",
            obj=client,
        )
        assert outcome.exit_code == 0

    finally:
        pass

+ 338 - 0
tests/cli/commands/test_system_cli.py

@@ -0,0 +1,338 @@
+"""
+Tests for the system commands in the CLI.
+    - health
+    - settings
+    - status
+    x serve
+    x image-exists
+    x docker-down
+    x generate-report
+    x update
+    - version
+
+"""
+
+import json
+from importlib.metadata import version as get_version
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.system import health, settings, status, version
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
@pytest.mark.asyncio
async def test_health_against_server():
    """Health check against a live server returns an 'ok' message."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, health, obj=client)

    # The first stdout line is the timer output; the JSON body follows.
    payload = outcome.stdout_bytes.decode().split("\n", 1)[1]
    body = json.loads(payload)

    assert "results" in body
    assert "message" in body["results"]
    assert body["results"]["message"] == "ok"
    assert outcome.exit_code == 0
+
+
@pytest.mark.asyncio
async def test_health_server_down():
    """Health check fails cleanly when nothing listens on the port."""
    client = R2RAsyncClient(base_url="http://localhost:54321")  # Invalid port
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, health, obj=client)
    stderr_text = outcome.stderr_bytes.decode()

    assert outcome.exit_code != 0
    assert "Request failed: All connection attempts failed" in stderr_text
+
+
@pytest.mark.asyncio
async def test_health_invalid_url():
    """Health check fails cleanly against an unresolvable host."""
    client = R2RAsyncClient(base_url="http://invalid.localhost")
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, health, obj=client)
    assert outcome.exit_code != 0
    assert "Request failed" in outcome.stderr_bytes.decode()
+
+
@pytest.mark.asyncio
async def test_settings_against_server():
    """Settings retrieval returns config and prompts sections."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, settings, obj=client)

    # Skip the "Time taken" line; the JSON body follows it.
    body = json.loads(outcome.stdout_bytes.decode().split("\n", 1)[1])
    assert "results" in body

    results = body["results"]
    assert "config" in results
    assert "prompts" in results

    # All core configuration sections must be present.
    for section in ("completion", "database", "embedding", "ingestion"):
        assert section in results["config"]

    assert outcome.exit_code == 0
+
+
@pytest.mark.asyncio
async def test_settings_server_down():
    """Settings retrieval fails cleanly when nothing listens on the port."""
    client = R2RAsyncClient(base_url="http://localhost:54321")  # Invalid port
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, settings, obj=client)
    stderr_text = outcome.stderr_bytes.decode()

    assert outcome.exit_code != 0
    assert "Request failed: All connection attempts failed" in stderr_text
+
+
@pytest.mark.asyncio
async def test_settings_invalid_url():
    """Settings retrieval fails cleanly against an unresolvable host."""
    client = R2RAsyncClient(base_url="http://invalid.localhost")
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, settings, obj=client)
    assert outcome.exit_code != 0
    assert "Request failed" in outcome.stderr_bytes.decode()
+
+
@pytest.mark.asyncio
async def test_settings_response_structure():
    """Each prompt entry in the settings payload carries the expected fields."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, settings, obj=client)
    body = json.loads(outcome.stdout_bytes.decode().split("\n", 1)[1])

    prompts = body["results"]["prompts"]
    assert "results" in prompts
    assert "total_entries" in prompts
    assert isinstance(prompts["results"], list)

    # Every prompt entry must expose the full metadata set.
    expected_keys = (
        "name",
        "id",
        "template",
        "input_types",
        "created_at",
        "updated_at",
    )
    for entry in prompts["results"]:
        for key in expected_keys:
            assert key in entry

    assert outcome.exit_code == 0
+
+
@pytest.mark.asyncio
async def test_settings_config_validation():
    """Completion and database configs expose their expected keys."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, settings, obj=client)
    body = json.loads(outcome.stdout_bytes.decode().split("\n", 1)[1])
    config = body["results"]["config"]

    # Completion provider configuration.
    for key in ("provider", "concurrent_request_limit", "generation_config"):
        assert key in config["completion"]

    # Database provider configuration.
    for key in ("provider", "default_collection_name", "limits"):
        assert key in config["database"]

    assert outcome.exit_code == 0
+
+
@pytest.mark.asyncio
async def test_status_against_server():
    """Status returns uptime and resource-usage metrics with numeric types."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, status, obj=client)

    # Skip the timer line and parse the JSON body.
    body = json.loads(outcome.stdout_bytes.decode().split("\n", 1)[1])
    assert "results" in body
    metrics = body["results"]

    assert "start_time" in metrics
    # Numeric metrics must be present and be int or float.
    for field in ("uptime_seconds", "cpu_usage", "memory_usage"):
        assert field in metrics
        assert isinstance(metrics[field], (int, float))

    assert outcome.exit_code == 0
+
+
@pytest.mark.asyncio
async def test_status_server_down():
    """Status check fails cleanly when nothing listens on the port."""
    client = R2RAsyncClient(base_url="http://localhost:54321")  # Invalid port
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, status, obj=client)
    stderr_text = outcome.stderr_bytes.decode()

    assert outcome.exit_code != 0
    assert "Request failed: All connection attempts failed" in stderr_text
+
+
@pytest.mark.asyncio
async def test_status_invalid_url():
    """Status check fails cleanly against an unresolvable host."""
    client = R2RAsyncClient(base_url="http://invalid.localhost")
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, status, obj=client)
    assert outcome.exit_code != 0
    assert "Request failed" in outcome.stderr_bytes.decode()
+
+
@pytest.mark.asyncio
async def test_status_value_ranges():
    """CPU/memory usage are percentages and uptime is strictly positive."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, status, obj=client)
    payload = outcome.stdout_bytes.decode().split("\n", 1)[1]
    metrics = json.loads(payload)["results"]

    # Usage metrics are percentages in [0, 100].
    assert 0 <= metrics["cpu_usage"] <= 100
    assert 0 <= metrics["memory_usage"] <= 100

    # A running server must have a positive uptime.
    assert metrics["uptime_seconds"] > 0

    assert outcome.exit_code == 0
+
+
@pytest.mark.asyncio
async def test_status_start_time_format():
    """start_time must parse as an ISO-8601 timestamp."""
    from datetime import datetime

    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    outcome = await async_invoke(runner, status, obj=client)
    body = json.loads(outcome.stdout_bytes.decode().split("\n", 1)[1])

    start_time = body["results"]["start_time"]
    try:
        # fromisoformat (pre-3.11) rejects a trailing 'Z'; normalize it.
        datetime.fromisoformat(start_time.replace("Z", "+00:00"))
    except ValueError:
        pytest.fail("start_time is not in valid ISO format")

    assert outcome.exit_code == 0
+
+
@pytest.mark.asyncio
async def test_version_command():
    """Version command prints the installed r2r package version as JSON."""
    runner = CliRunner()
    outcome = await async_invoke(runner, version)

    assert outcome.exit_code == 0

    # Output must be the package version, JSON-encoded.
    reported = json.loads(outcome.stdout_bytes.decode())
    assert reported == get_version("r2r")
+
+
@pytest.mark.asyncio
async def test_version_output_format():
    """Test that version output is properly formatted JSON.

    The decode and the assertions are kept outside the ``try`` so that
    only ``json.loads`` is guarded (previously ``output`` was assigned
    inside the ``try`` yet read after it, and an assertion failure inside
    the block could be conflated with a JSON parse failure).
    """
    runner = CliRunner(mix_stderr=False)
    result = await async_invoke(runner, version)

    output = result.stdout_bytes.decode()
    try:
        parsed = json.loads(output)
    except json.JSONDecodeError:
        pytest.fail("Version output is not valid JSON")
    else:
        assert isinstance(parsed, str)  # Version should be a string

    # Should be non-empty output ending with newline
    assert output.strip()
    assert output.endswith("\n")
+
+
@pytest.mark.asyncio
async def test_version_error_handling(monkeypatch):
    """A failing version lookup surfaces as an unexpected-error message."""

    def broken_version(_):
        raise ImportError("Package not found")

    # Force importlib.metadata.version to fail inside the command.
    monkeypatch.setattr("importlib.metadata.version", broken_version)

    runner = CliRunner(mix_stderr=False)
    outcome = await async_invoke(runner, version)

    assert outcome.exit_code == 1
    stderr_text = outcome.stderr_bytes.decode()
    assert "An unexpected error occurred" in stderr_text
    assert "Package not found" in stderr_text

+ 143 - 0
tests/cli/commands/test_users_cli.py

@@ -0,0 +1,143 @@
+"""
+Tests for the user commands in the CLI.
+    - create
+    - list
+    - retrieve
+    - me
+    x list-collections
+    x add-to-collection
+    x remove-from-collection
+"""
+
+import json
+import uuid
+
+import pytest
+from click.testing import CliRunner
+
+from cli.commands.users import (
+    add_to_collection,
+    create,
+    list,
+    list_collections,
+    me,
+    remove_from_collection,
+    retrieve,
+)
+from r2r import R2RAsyncClient
+from tests.cli.async_invoke import async_invoke
+
+
def extract_json_block(output: str) -> dict:
    """Extract and parse the first valid JSON object found in the output.

    Scans forward from each "{" and lets ``json.JSONDecoder.raw_decode``
    locate the object's true end.  Unlike naive brace counting, this is
    correct even when string values themselves contain "{" or "}".

    Raises:
        ValueError: if the output contains no complete JSON object
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    index = output.find("{")
    if index == -1:
        raise ValueError("No JSON object start found in output")

    decoder = json.JSONDecoder()
    while index != -1:
        try:
            parsed, _ = decoder.raw_decode(output, index)
            return parsed
        except json.JSONDecodeError:
            # This "{" did not begin a valid object; try the next one.
            index = output.find("{", index + 1)
    raise ValueError("No complete JSON object found in output")
+
+
@pytest.mark.asyncio
async def test_user_lifecycle():
    """Create a user, then exercise list, retrieve, me, and list-collections."""
    client = R2RAsyncClient(base_url="http://localhost:7272")
    runner = CliRunner(mix_stderr=False)

    # A random email keeps repeated runs from colliding.
    email = f"test_{uuid.uuid4()}@example.com"
    password = "TestPassword123!"

    create_result = await async_invoke(
        runner, create, email, password, obj=client
    )
    assert create_result.exit_code == 0, create_result.stdout_bytes.decode()
    user_id = extract_json_block(create_result.stdout_bytes.decode())[
        "results"
    ]["id"]

    try:
        # The new user should appear in the listing.
        list_result = await async_invoke(runner, list, obj=client)
        assert list_result.exit_code == 0, list_result.stdout_bytes.decode()
        assert email in list_result.stdout_bytes.decode()

        # Retrieval by ID should round-trip the email address.
        retrieve_result = await async_invoke(
            runner, retrieve, user_id, obj=client
        )
        assert (
            retrieve_result.exit_code == 0
        ), retrieve_result.stdout_bytes.decode()
        retrieved = extract_json_block(retrieve_result.stdout_bytes.decode())
        assert retrieved["results"]["email"] == email

        # The "me" endpoint should work with the current credentials.
        me_result = await async_invoke(runner, me, obj=client)
        assert me_result.exit_code == 0, me_result.stdout_bytes.decode()

        # Listing the user's collections should also succeed.
        collections_result = await async_invoke(
            runner, list_collections, user_id, obj=client
        )
        assert (
            collections_result.exit_code == 0
        ), collections_result.stdout_bytes.decode()

    finally:
        # No delete command exists yet, so the user is left in place.
        pass
+
+
+# FIXME: This should be returning 'User not found' but returns an empty list instead.
+# @pytest.mark.asyncio
+# async def test_retrieve_nonexistent_user():
+#     """Test retrieving a user that doesn't exist."""
+#     client = R2RAsyncClient(base_url="http://localhost:7272")
+#     runner = CliRunner(mix_stderr=False)
+
+#     nonexistent_id = str(uuid.uuid4())
+#     result = await async_invoke(runner, retrieve, nonexistent_id, obj=client)
+
+#     assert result.exit_code != 0
+#     error_output = result.stderr_bytes.decode()
+#     assert "User not found" in error_output
+
+
+# FIXME: This is returning with a status of 0 but has a 400 on the server side?
+# @pytest.mark.asyncio
+# async def test_create_duplicate_user():
+#     """Test creating a user with an email that already exists."""
+#     client = R2RAsyncClient(base_url="http://localhost:7272")
+#     runner = CliRunner(mix_stderr=False)
+
+#     test_email = f"test_{uuid.uuid4()}@example.com"
+#     test_password = "TestPassword123!"
+
+#     # Create first user
+#     first_result = await async_invoke(
+#         runner, create, test_email, test_password, obj=client
+#     )
+#     assert first_result.exit_code == 0
+
+#     # Try to create second user with same email
+#     second_result = await async_invoke(
+#         runner, create, test_email, test_password, obj=client
+#     )
+#     print(f"SECOND RESULT: {second_result}")
+#     assert second_result.exit_code != 0
+#     error_output = second_result.stderr_bytes.decode()
+#     assert "already exists" in error_output.lower()

+ 118 - 0
tests/cli/utils/test_timer.py

@@ -0,0 +1,118 @@
+import asyncio
+import time
+from unittest.mock import patch
+
+import asyncclick as click
+import pytest
+from click.testing import CliRunner
+
+from cli.utils.timer import timer
+from tests.cli.async_invoke import async_invoke
+
+
@click.command()
async def test_command():
    # Fixture command: a timed 0.1 s sleep for the duration tests.
    # NOTE(review): the test_ prefix makes pytest attempt to collect this
    # click.Command; consider renaming to avoid collection warnings.
    with timer():
        time.sleep(0.1)
+
+
@pytest.mark.asyncio
async def test_timer_measures_time():
    """Timer output reports a duration close to the 0.1 s sleep."""
    runner = CliRunner()
    outcome = await async_invoke(runner, test_command)
    printed = outcome.stdout_bytes.decode()

    assert "Time taken:" in printed
    assert "seconds" in printed

    # Parse the number out of "Time taken: X.XX seconds".
    elapsed = float(printed.split(":")[1].split()[0])
    assert 0.1 <= elapsed <= 0.2
+
+
@click.command()
async def zero_duration_command():
    # Fixture command: an empty timed block (near-zero duration).
    with timer():
        pass
+
+
@pytest.mark.asyncio
async def test_timer_zero_duration():
    """An empty timed block reports a near-zero, non-negative duration."""
    runner = CliRunner()
    outcome = await async_invoke(runner, zero_duration_command)
    elapsed = float(outcome.stdout_bytes.decode().split(":")[1].split()[0])
    assert 0 <= elapsed < 0.1
+
+
@click.command()
async def exception_command():
    # Fixture command: raises inside the timed block to test propagation.
    with timer():
        raise ValueError("Test exception")
+
+
@pytest.mark.asyncio
async def test_timer_with_exception():
    """Exceptions raised inside a timed block propagate out of the command."""
    runner = CliRunner()
    outcome = await async_invoke(runner, exception_command)

    assert outcome.exit_code != 0
    assert isinstance(outcome.exception, ValueError)
+
+
@click.command()
async def async_command():
    # Fixture command: awaits an async sleep inside the timed block.
    with timer():
        await asyncio.sleep(0.1)
+
+
@pytest.mark.asyncio
async def test_timer_with_async_code():
    """Timer also measures awaited (async) work inside its block."""
    runner = CliRunner()
    outcome = await async_invoke(runner, async_command)
    elapsed = float(outcome.stdout_bytes.decode().split(":")[1].split()[0])
    assert 0.1 <= elapsed <= 0.2
+
+
@click.command()
async def nested_command():
    # Fixture command: two nested timers, so two "Time taken" lines.
    with timer():
        time.sleep(0.1)
        with timer():
            time.sleep(0.1)
+
+
@pytest.mark.asyncio
async def test_timer_multiple_nested():
    """Nested timers each emit their own 'Time taken' line."""
    runner = CliRunner()
    outcome = await async_invoke(runner, nested_command)
    assert outcome.stdout_bytes.decode().count("Time taken:") == 2
+
+
@click.command()
async def mock_time_command():
    # Fixture command: empty timed block, exercised with a patched
    # time.time in test_timer_with_mock_time.
    with timer():
        pass
+
+
@pytest.mark.asyncio
@patch("time.time")
async def test_timer_with_mock_time(mock_time):
    # NOTE(review): assumes timer() calls time.time exactly twice (start
    # and end); any extra call would exhaust side_effect and raise
    # StopIteration — confirm against cli/utils/timer.py.
    mock_time.side_effect = [0, 1]  # Start and end times
    runner = CliRunner()
    result = await async_invoke(runner, mock_time_command)
    output = result.stdout_bytes.decode()
    assert "Time taken: 1.00 seconds" in output
+
+
@click.command()
async def precision_command():
    # Fixture command: a timed sleep for checking output precision.
    with timer():
        time.sleep(0.1)
+
+
@pytest.mark.asyncio
async def test_timer_precision():
    """Reported duration is formatted with exactly two decimal places."""
    runner = CliRunner()
    outcome = await async_invoke(runner, precision_command)

    # Isolate the fractional digits of "Time taken: X.XX seconds".
    duration_text = outcome.stdout_bytes.decode().split(":")[1].split()[0]
    assert len(duration_text.split(".")[1]) == 2

+ 0 - 2
tests/integration/test_conversations.py

@@ -86,7 +86,6 @@ def test_delete_non_existent_conversation(client):
     bad_id = str(uuid.uuid4())
     bad_id = str(uuid.uuid4())
     with pytest.raises(R2RException) as exc_info:
     with pytest.raises(R2RException) as exc_info:
         result = client.conversations.delete(id=bad_id)
         result = client.conversations.delete(id=bad_id)
-        print(result)
     assert (
     assert (
         exc_info.value.status_code == 404
         exc_info.value.status_code == 404
     ), "Wrong error code for delete non-existent"
     ), "Wrong error code for delete non-existent"
@@ -122,7 +121,6 @@ def test_update_message(client, test_conversation):
         content="Updated content",
         content="Updated content",
         metadata={"new_key": "new_value"},
         metadata={"new_key": "new_value"},
     )["results"]
     )["results"]
-    print(update_resp)
     # /new_branch_id = update_resp["new_branch_id"]
     # /new_branch_id = update_resp["new_branch_id"]
 
 
     assert update_resp["message"], "No message returned after update"
     assert update_resp["message"], "No message returned after update"

+ 187 - 0
tests/integration/test_filters.py

@@ -0,0 +1,187 @@
+import uuid
+
+import pytest
+
+from r2r import R2RException
+
+
@pytest.fixture
def setup_docs_with_collections(client):
    """Create three collections and four documents with varied membership.

    Arrangement:
        doc1 -> [coll0]
        doc2 -> [coll0, coll1]
        doc3 -> []  (no collections)
        doc4 -> [coll2]

    Yields a dict with ``coll_ids`` and ``doc_ids``; deletes everything
    (best-effort) on teardown.
    """
    suffix = str(uuid.uuid4())[:8]
    coll_ids = [
        client.collections.create(name=f"TestColl{i}")["results"]["id"]
        for i in range(3)
    ]

    def _make_doc(text: str) -> str:
        # Suffix makes the raw text unique per test session.
        return client.documents.create(
            raw_text=text + suffix, run_with_orchestration=False
        )["results"]["document_id"]

    doc1 = _make_doc("Doc in coll1")
    client.collections.add_document(coll_ids[0], doc1)

    doc2 = _make_doc("Doc in coll1 and coll2")
    client.collections.add_document(coll_ids[0], doc2)
    client.collections.add_document(coll_ids[1], doc2)

    doc3 = _make_doc("Doc in no collections")

    doc4 = _make_doc("Doc in coll3")
    client.collections.add_document(coll_ids[2], doc4)

    yield {"coll_ids": coll_ids, "doc_ids": [doc1, doc2, doc3, doc4]}

    # Teardown: ignore already-deleted resources.
    for d_id in (doc1, doc2, doc3, doc4):
        try:
            client.documents.delete(id=d_id)
        except R2RException:
            pass
    for c_id in coll_ids:
        try:
            client.collections.delete(c_id)
        except R2RException:
            pass
+
+
def test_collection_id_eq_filter(client, setup_docs_with_collections):
    """$eq on collection_id must match only docs in that collection (doc1, doc2)."""
    coll_ids = setup_docs_with_collections["coll_ids"]
    doc1, doc2, _doc3, _doc4 = setup_docs_with_collections["doc_ids"]

    hits = client.retrieval.search(
        query="whoami",
        search_settings={
            "filters": {"collection_id": {"$eq": coll_ids[0]}}
        },
    )["results"]["chunk_search_results"]
    found_ids = {hit["document_id"] for hit in hits}
    assert found_ids == {doc1, doc2}, f"Expected doc1 and doc2, got {found_ids}"
+
+
def test_collection_id_ne_filter(client, setup_docs_with_collections):
    """$ne on collection_id must return docs NOT in that collection.

    Excluding coll0 leaves doc3 (no collections) and doc4 (coll2 only).
    """
    coll_ids = setup_docs_with_collections["coll_ids"]
    _doc1, _doc2, doc3, doc4 = setup_docs_with_collections["doc_ids"]

    hits = client.retrieval.search(
        query="whoami",
        search_settings={
            "filters": {"collection_id": {"$ne": coll_ids[0]}}
        },
    )["results"]["chunk_search_results"]
    found_ids = {hit["document_id"] for hit in hits}
    assert found_ids == {doc3, doc4}, f"Expected doc3 and doc4, got {found_ids}"
+
+
def test_collection_id_in_filter(client, setup_docs_with_collections):
    """$in must match docs belonging to ANY listed collection.

    [coll0, coll2] covers doc1 (coll0), doc2 (coll0) and doc4 (coll2);
    doc3 belongs to none and must be absent.
    """
    coll_ids = setup_docs_with_collections["coll_ids"]
    doc1, doc2, _doc3, doc4 = setup_docs_with_collections["doc_ids"]

    hits = client.retrieval.search(
        query="whoami",
        search_settings={
            "filters": {
                "collection_id": {"$in": [coll_ids[0], coll_ids[2]]}
            }
        },
    )["results"]["chunk_search_results"]
    found_ids = {hit["document_id"] for hit in hits}
    assert found_ids == {
        doc1,
        doc2,
        doc4,
    }, f"Expected doc1, doc2, doc4, got {found_ids}"
+
+
def test_collection_id_nin_filter(client, setup_docs_with_collections):
    """$nin must exclude docs in any listed collection.

    Only doc2 belongs to coll1, so doc1, doc3 and doc4 remain.
    """
    coll_ids = setup_docs_with_collections["coll_ids"]
    doc1, _doc2, doc3, doc4 = setup_docs_with_collections["doc_ids"]

    hits = client.retrieval.search(
        query="whoami",
        search_settings={
            "filters": {"collection_id": {"$nin": [coll_ids[1]]}}
        },
    )["results"]["chunk_search_results"]
    found_ids = {hit["document_id"] for hit in hits}
    assert found_ids == {
        doc1,
        doc3,
        doc4,
    }, f"Expected doc1, doc3, doc4, got {found_ids}"
+
+
def test_collection_id_contains_filter(client, setup_docs_with_collections):
    """$contains with a single UUID must match docs whose collection array holds it.

    coll0 appears in doc1's and doc2's collection arrays only.
    """
    coll_ids = setup_docs_with_collections["coll_ids"]
    doc1, doc2, _doc3, _doc4 = setup_docs_with_collections["doc_ids"]

    hits = client.retrieval.search(
        query="whoami",
        search_settings={
            "filters": {"collection_id": {"$contains": coll_ids[0]}}
        },
    )["results"]["chunk_search_results"]
    found_ids = {hit["document_id"] for hit in hits}
    assert found_ids == {doc1, doc2}, f"Expected doc1 and doc2, got {found_ids}"
+
+
def test_collection_id_contains_multiple(client, setup_docs_with_collections):
    """$contains with a list means the doc must belong to ALL listed collections.

    Only doc2 is in both coll0 and coll1.
    """
    coll_ids = setup_docs_with_collections["coll_ids"]
    _doc1, doc2, _doc3, _doc4 = setup_docs_with_collections["doc_ids"]

    hits = client.retrieval.search(
        query="whoami",
        search_settings={
            "filters": {
                "collection_id": {"$contains": [coll_ids[0], coll_ids[1]]}
            }
        },
    )["results"]["chunk_search_results"]
    found_ids = {hit["document_id"] for hit in hits}
    assert found_ids == {doc2}, f"Expected doc2 only, got {found_ids}"
+
+
def test_delete_by_collection_id_eq(client, setup_docs_with_collections):
    """delete_by_filter with $eq on collection_id removes exactly that collection's docs."""
    coll_ids = setup_docs_with_collections["coll_ids"]
    doc1, doc2, doc3, doc4 = setup_docs_with_collections["doc_ids"]

    # Delete everything in coll0 (doc1 and doc2).
    del_resp = client.documents.delete_by_filter(
        {"collection_id": {"$eq": coll_ids[0]}}
    )["results"]
    assert del_resp["success"], "Failed to delete by collection_id $eq filter"

    # doc1/doc2 must now 404; doc3/doc4 must survive.
    for gone_id in (doc1, doc2):
        with pytest.raises(R2RException) as exc:
            client.documents.retrieve(gone_id)
        assert exc.value.status_code == 404, f"Doc {gone_id} still exists!"
    assert client.documents.retrieve(doc3)
    assert client.documents.retrieve(doc4)

+ 37 - 1
tests/integration/test_users.py

@@ -2,7 +2,9 @@ import uuid
 
 
 import pytest
 import pytest
 
 
+from core.database.postgres import PostgresUserHandler
 from r2r import R2RClient, R2RException
 from r2r import R2RClient, R2RException
+from shared.abstractions import User
 
 
 
 
 @pytest.fixture(scope="session")
 @pytest.fixture(scope="session")
@@ -334,7 +336,6 @@ def test_non_owner_delete_collection(client):
     client.users.login(non_owner_email, non_owner_password)
     client.users.login(non_owner_email, non_owner_password)
     with pytest.raises(R2RException) as exc_info:
     with pytest.raises(R2RException) as exc_info:
         result = client.collections.delete(coll_id)
         result = client.collections.delete(coll_id)
-        print("result = ", result)
     assert (
     assert (
         exc_info.value.status_code == 403
         exc_info.value.status_code == 403
     ), "Wrong error code for non-owner deletion attempt"
     ), "Wrong error code for non-owner deletion attempt"
@@ -599,3 +600,38 @@ def test_multiple_api_keys(client):
         ), f"Key {key_id} still exists after deletion"
         ), f"Key {key_id} still exists after deletion"
 
 
     client.users.logout()
     client.users.logout()
+
+
def test_update_user_limits_overrides(client: R2RClient):
    """A user's limits_overrides should start empty and persist after update."""
    # 1) Register and log in a fresh user.
    user_email = f"test_{uuid.uuid4()}@example.com"
    client.users.register(user_email, "SomePassword123!")
    client.users.login(user_email, "SomePassword123!")

    # 2) Freshly created users carry no overrides.
    fetched_user = client.users.me()["results"]
    client.users.logout()
    assert len(fetched_user["limits_overrides"]) == 0

    # 3) Apply global, monthly, and per-route overrides.
    overrides = {
        "global_per_min": 10,
        "monthly_limit": 3000,
        "route_overrides": {
            "/some-route": {"route_per_min": 5},
        },
    }
    client.users.update(id=fetched_user["id"], limits_overrides=overrides)

    # 4) Re-login and confirm the overrides round-tripped.
    client.users.login(user_email, "SomePassword123!")
    updated_user = client.users.me()["results"]
    assert len(updated_user["limits_overrides"]) != 0
    assert updated_user["limits_overrides"]["global_per_min"] == 10
    route_cfg = updated_user["limits_overrides"]["route_overrides"]
    assert route_cfg["/some-route"]["route_per_min"] == 5

+ 344 - 0
tests/unit/conftest.py

@@ -0,0 +1,344 @@
+# tests/conftest.py
+import os
+from uuid import uuid4
+
+import pytest
+
+from core.base import AppConfig, DatabaseConfig, VectorQuantizationType
+from core.database.postgres import (
+    PostgresChunksHandler,
+    PostgresCollectionsHandler,
+    PostgresConnectionManager,
+    PostgresConversationsHandler,
+    PostgresDatabaseProvider,
+    PostgresDocumentsHandler,
+    PostgresGraphsHandler,
+    PostgresLimitsHandler,
+)
+from core.database.users import (  # Make sure this import is correct
+    PostgresUserHandler,
+)
+from core.providers import NaClCryptoConfig, NaClCryptoProvider
+from core.utils import generate_user_id
+
# Connection string for the disposable test database.
# Override with the TEST_DB_CONNECTION_STRING environment variable in CI.
TEST_DB_CONNECTION_STRING = os.environ.get(
    "TEST_DB_CONNECTION_STRING",
    "postgresql://postgres:postgres@localhost:5432/test_db",
)
+
+
@pytest.fixture
async def db_provider():
    """Yield an initialized PostgresDatabaseProvider bound to the test DB.

    Uses a 4-dimensional FP32 vector space to keep fixtures tiny; closes
    the provider on teardown.
    """
    crypto = NaClCryptoProvider(NaClCryptoConfig(app={}))
    config = DatabaseConfig(
        app=AppConfig(project_name="test_project"),
        provider="postgres",
        connection_string=TEST_DB_CONNECTION_STRING,
        postgres_configuration_settings={
            "max_connections": 10,
            "statement_cache_size": 100,
        },
        project_name="test_project",
    )

    provider = PostgresDatabaseProvider(
        config,
        4,  # dimension
        crypto,
        VectorQuantizationType.FP32,
    )
    await provider.initialize()
    yield provider
    await provider.close()
+
+
@pytest.fixture
def crypto_provider():
    """Standalone NaCl crypto provider for handlers that need one directly."""
    return NaClCryptoProvider(NaClCryptoConfig(app={}))
+
+
@pytest.fixture
async def chunks_handler(db_provider):
    """PostgresChunksHandler wired to the shared provider, tables created."""
    handler = PostgresChunksHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        dimension=db_provider.dimension,
        quantization_type=db_provider.quantization_type,
    )
    await handler.create_tables()
    return handler
+
+
@pytest.fixture
async def collections_handler(db_provider):
    """PostgresCollectionsHandler wired to the shared provider, tables created."""
    handler = PostgresCollectionsHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        config=db_provider.config,
    )
    await handler.create_tables()
    return handler
+
+
@pytest.fixture
async def conversations_handler(db_provider):
    """PostgresConversationsHandler wired to the shared provider, tables created."""
    handler = PostgresConversationsHandler(
        db_provider.project_name, db_provider.connection_manager
    )
    await handler.create_tables()
    return handler
+
+
@pytest.fixture
async def documents_handler(db_provider):
    """PostgresDocumentsHandler wired to the shared provider, tables created."""
    handler = PostgresDocumentsHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        dimension=db_provider.dimension,
    )
    await handler.create_tables()
    return handler
+
+
@pytest.fixture
async def graphs_handler(db_provider):
    """PostgresGraphsHandler wired to the shared provider, tables created.

    collections_handler is passed as None; if a test needs collection-aware
    graph behavior, depend on the collections_handler fixture instead.
    """
    handler = PostgresGraphsHandler(
        project_name=db_provider.project_name,
        connection_manager=db_provider.connection_manager,
        dimension=db_provider.dimension,
        quantization_type=db_provider.quantization_type,
        collections_handler=None,
    )
    await handler.create_tables()
    return handler
+
+
@pytest.fixture
async def limits_handler(db_provider):
    """PostgresLimitsHandler with a freshly truncated request_log table."""
    connection_manager = db_provider.connection_manager
    handler = PostgresLimitsHandler(
        project_name=db_provider.project_name,
        connection_manager=connection_manager,
        config=db_provider.config,
    )
    await handler.create_tables()
    # Start each test from an empty rate-limit log.
    await connection_manager.execute_query(
        f"TRUNCATE {handler._get_table_name('request_log')};"
    )
    return handler
+
+
@pytest.fixture
async def users_handler(db_provider, crypto_provider):
    """PostgresUserHandler with users/users_api_keys tables emptied per test."""
    connection_manager = db_provider.connection_manager
    handler = PostgresUserHandler(
        project_name=db_provider.project_name,
        connection_manager=connection_manager,
        crypto_provider=crypto_provider,
    )
    await handler.create_tables()

    # Wipe user state so tests never see each other's accounts or keys.
    for table in ("users", "users_api_keys"):
        await connection_manager.execute_query(
            f"TRUNCATE {handler._get_table_name(table)} CASCADE;"
        )

    return handler
+
+
+# # tests/conftest.py
+# import pytest
+# import os
+
+# from core.database.postgres import (
+#     PostgresChunksHandler,
+#     PostgresConnectionManager,
+#     PostgresDatabaseProvider,
+#     PostgresCollectionsHandler,
+#     PostgresConversationsHandler,
+#     PostgresDocumentsHandler,
+#     PostgresGraphsHandler,
+#     PostgresLimitsHandler,
+#     PostgresUserHandler
+# )
+# from core.providers import NaClCryptoConfig, NaClCryptoProvider
+# from core.base import  DatabaseConfig, VectorQuantizationType
+
+
+# TEST_DB_CONNECTION_STRING = os.environ.get(
+#     "TEST_DB_CONNECTION_STRING",
+#     "postgresql://postgres:postgres@localhost:5432/test_db",
+# )
+
+# @pytest.fixture
+# async def db_provider():
+#     # Example: a crypto provider needed by the database
+#     crypto_provider = NaClCryptoProvider(NaClCryptoConfig(app={}))
+
+#     db_config = DatabaseConfig(
+#         app={},
+#         provider="postgres",
+#         connection_string=TEST_DB_CONNECTION_STRING,
+#         # Set these values as appropriate
+#         postgres_configuration_settings={
+#             "max_connections": 10,
+#             "statement_cache_size": 100,
+#         },
+#     )
+
+#     dimension = 4
+#     quantization_type = VectorQuantizationType.FP32
+
+#     db_provider = PostgresDatabaseProvider(
+#         db_config, dimension, crypto_provider, quantization_type
+#     )
+#     await db_provider.initialize()
+#     yield db_provider
+
+#     # Teardown logic if needed: close pools, drop tables, etc.
+#     await db_provider.close()
+
+
+# @pytest.fixture
+# async def chunks_handler(db_provider):
+#     # Assuming project_name and dimension are retrieved from db_provider
+#     dimension = db_provider.dimension
+#     quantization_type = db_provider.quantization_type
+#     project_name = db_provider.project_name
+#     connection_manager = (
+#         db_provider.connection_manager
+#     )  # type: PostgresConnectionManager
+
+#     handler = PostgresChunksHandler(
+#         project_name=project_name,
+#         connection_manager=connection_manager,
+#         dimension=dimension,
+#         quantization_type=quantization_type,
+#     )
+#     await handler.create_tables()
+#     return handler
+
+
+# @pytest.fixture
+# async def collections_handler(db_provider):
+#     project_name = db_provider.project_name
+#     connection_manager = db_provider.connection_manager
+#     config = db_provider.config
+
+#     handler = PostgresCollectionsHandler(
+#         project_name=project_name,
+#         connection_manager=connection_manager,
+#         config=config
+#     )
+#     await handler.create_tables()
+#     return handler
+
+# @pytest.fixture
+# async def conversations_handler(db_provider):
+#     project_name = db_provider.project_name
+#     connection_manager = db_provider.connection_manager
+
+#     handler = PostgresConversationsHandler(project_name, connection_manager)
+#     await handler.create_tables()
+#     return handler
+
+# @pytest.fixture
+# async def documents_handler(db_provider):
+#     dimension = db_provider.dimension
+#     project_name = db_provider.project_name
+#     connection_manager = db_provider.connection_manager
+
+#     handler = PostgresDocumentsHandler(
+#         project_name=project_name,
+#         connection_manager=connection_manager,
+#         dimension=dimension,
+#     )
+#     await handler.create_tables()
+#     return handler
+
+# @pytest.fixture
+# async def graphs_handler(db_provider):
+#     project_name = db_provider.project_name
+#     connection_manager = db_provider.connection_manager
+#     dimension = db_provider.dimension
+#     quantization_type = db_provider.quantization_type
+
+#     # Constructing graphs handler with required args
+#     handler = PostgresGraphsHandler(
+#         project_name=project_name,
+#         connection_manager=connection_manager,
+#         dimension=dimension,
+#         quantization_type=quantization_type,
+#         collections_handler=None  # If needed, you can mock or create a collections_handler
+#     )
+#     await handler.create_tables()
+#     return handler
+
+# @pytest.fixture
+# async def limits_handler(db_provider):
+#     project_name = db_provider.project_name
+#     connection_manager = db_provider.connection_manager
+#     config = db_provider.config  # This has default limits
+
+#     handler = PostgresLimitsHandler(
+#         project_name=project_name,
+#         connection_manager=connection_manager,
+#         config=config,
+#     )
+#     await handler.create_tables()
+#     # Optionally truncate after creation to ensure clean state
+#     await connection_manager.execute_query(f"TRUNCATE {handler._get_table_name('request_log')};")
+
+#     return handler
+
+
+# @pytest.fixture
+# async def users_handler(db_provider, crypto_provider):
+#     project_name = db_provider.project_name
+#     connection_manager = db_provider.connection_manager
+
+#     handler = PostgresUserHandler(
+#         project_name=project_name,
+#         connection_manager=connection_manager,
+#         crypto_provider=crypto_provider,
+#     )
+#     await handler.create_tables()
+
+#     # Optionally clean up users table before each test
+#     await connection_manager.execute_query(f"TRUNCATE {handler._get_table_name('users')} CASCADE;")
+#     await connection_manager.execute_query(f"TRUNCATE {handler._get_table_name('users_api_keys')} CASCADE;")
+
+#     return handler

+ 315 - 0
tests/unit/test_chunks.py

@@ -0,0 +1,315 @@
+# tests/integration/test_chunks.py
+import asyncio
+import uuid
+from typing import AsyncGenerator, Optional, Tuple
+
+import pytest
+
+from r2r import R2RAsyncClient, R2RException
+
+
class AsyncR2RTestClient:
    """Thin async facade over R2RAsyncClient for the chunk tests.

    Unwraps the ``results`` envelope on every call so tests work with plain
    dicts/lists, and keeps all operations on a single event loop.
    """

    def __init__(self, base_url: str = "http://localhost:7272"):
        self.client = R2RAsyncClient(base_url)

    async def create_document(
        self, chunks: list[str], run_with_orchestration: bool = False
    ) -> Tuple[str, list[dict]]:
        """Create a document from raw chunks; returns (document_id, [])."""
        resp = await self.client.documents.create(
            chunks=chunks, run_with_orchestration=run_with_orchestration
        )
        return resp["results"]["document_id"], []

    async def delete_document(self, doc_id: str) -> None:
        """Delete a document by id."""
        await self.client.documents.delete(id=doc_id)

    async def list_chunks(self, doc_id: str) -> list[dict]:
        """List all chunks of a document."""
        return (await self.client.documents.list_chunks(id=doc_id))["results"]

    async def retrieve_chunk(self, chunk_id: str) -> dict:
        """Fetch a single chunk by id."""
        return (await self.client.chunks.retrieve(id=chunk_id))["results"]

    async def update_chunk(
        self, chunk_id: str, text: str, metadata: Optional[dict] = None
    ) -> dict:
        """Update a chunk's text and (optionally) metadata."""
        resp = await self.client.chunks.update(
            {"id": chunk_id, "text": text, "metadata": metadata or {}}
        )
        return resp["results"]

    async def delete_chunk(self, chunk_id: str) -> dict:
        """Delete a chunk by id; returns the deletion result payload."""
        return (await self.client.chunks.delete(id=chunk_id))["results"]

    async def search_chunks(self, query: str, limit: int = 5) -> list[dict]:
        """Run a chunk search limited to ``limit`` hits."""
        resp = await self.client.chunks.search(
            query=query, search_settings={"limit": limit}
        )
        return resp["results"]

    async def register_user(self, email: str, password: str) -> None:
        """Register a new user account."""
        await self.client.users.register(email, password)

    async def login_user(self, email: str, password: str) -> None:
        """Log in as an existing user."""
        await self.client.users.login(email, password)

    async def logout_user(self) -> None:
        """Log out the current user."""
        await self.client.users.logout()
+
+
@pytest.fixture
async def test_client() -> AsyncGenerator[AsyncR2RTestClient, None]:
    """Yield a fresh AsyncR2RTestClient pointed at the local server."""
    yield AsyncR2RTestClient()
+
+
@pytest.fixture
async def test_document(
    test_client: AsyncR2RTestClient,
) -> AsyncGenerator[Tuple[str, list[dict]], None]:
    """Yield (document_id, chunks) for a freshly ingested two-chunk document."""
    doc_id, _ = await test_client.create_document(
        ["Test chunk 1", "Test chunk 2"]
    )
    await asyncio.sleep(1)  # give ingestion time to complete
    chunks = await test_client.list_chunks(doc_id)
    yield doc_id, chunks
    # Best-effort teardown: the test may have deleted the document already.
    try:
        await test_client.delete_document(doc_id)
    except R2RException:
        pass
+
+
class TestChunks:
    """End-to-end tests for chunk CRUD, search, auth, and listing."""

    @pytest.mark.asyncio
    async def test_create_and_list_chunks(
        self, test_client: AsyncR2RTestClient
    ):
        """Creating a two-chunk document yields exactly two listed chunks."""
        doc_id, _ = await test_client.create_document(
            ["Hello chunk", "World chunk"]
        )
        await asyncio.sleep(1)  # Wait for ingestion

        chunks = await test_client.list_chunks(doc_id)
        assert len(chunks) == 2, "Expected 2 chunks in the document"

        await test_client.delete_document(doc_id)

    @pytest.mark.asyncio
    async def test_retrieve_chunk(
        self, test_client: AsyncR2RTestClient, test_document
    ):
        """A chunk fetched by id round-trips its id and text."""
        doc_id, chunks = test_document
        chunk_id = chunks[0]["id"]

        retrieved = await test_client.retrieve_chunk(chunk_id)
        assert retrieved["id"] == chunk_id, "Retrieved wrong chunk ID"
        assert retrieved["text"] == "Test chunk 1", "Chunk text mismatch"

    @pytest.mark.asyncio
    async def test_update_chunk(
        self, test_client: AsyncR2RTestClient, test_document
    ):
        """Updating a chunk persists both new text and new metadata."""
        doc_id, chunks = test_document
        chunk_id = chunks[0]["id"]

        updated = await test_client.update_chunk(
            chunk_id, "Updated text", {"version": 2}
        )
        assert updated["text"] == "Updated text", "Chunk text not updated"
        assert updated["metadata"]["version"] == 2, "Metadata not updated"

    @pytest.mark.asyncio
    async def test_delete_chunk(
        self, test_client: AsyncR2RTestClient, test_document
    ):
        """Deleting a chunk succeeds and later retrieval returns 404."""
        doc_id, chunks = test_document
        chunk_id = chunks[0]["id"]

        result = await test_client.delete_chunk(chunk_id)
        assert result["success"], "Chunk deletion failed"

        with pytest.raises(R2RException) as exc_info:
            await test_client.retrieve_chunk(chunk_id)
        assert exc_info.value.status_code == 404

    @pytest.mark.asyncio
    async def test_search_chunks(self, test_client: AsyncR2RTestClient):
        """Searching for an ingested term returns at least one hit."""
        doc_id, _ = await test_client.create_document(
            ["Aristotle reference", "Another piece of text"]
        )
        await asyncio.sleep(1)  # Wait for indexing

        results = await test_client.search_chunks("Aristotle")
        assert len(results) > 0, "No search results found"

        await test_client.delete_document(doc_id)

    @pytest.mark.asyncio
    async def test_unauthorized_chunk_access(
        self, test_client: AsyncR2RTestClient, test_document
    ):
        """A different user gets 403 when reading someone else's chunk."""
        doc_id, chunks = test_document
        chunk_id = chunks[0]["id"]

        # Create and log in as a second, unrelated user.
        non_owner_client = AsyncR2RTestClient()
        email = f"test_{uuid.uuid4()}@example.com"
        await non_owner_client.register_user(email, "password123")
        await non_owner_client.login_user(email, "password123")

        with pytest.raises(R2RException) as exc_info:
            await non_owner_client.retrieve_chunk(chunk_id)
        assert exc_info.value.status_code == 403

    @pytest.mark.asyncio
    async def test_list_chunks_with_filters(
        self, test_client: AsyncR2RTestClient
    ):
        """Listing chunks applies the owner filter server-side."""
        temp_email = f"{uuid.uuid4()}@example.com"
        await test_client.register_user(temp_email, "password123")
        await test_client.login_user(temp_email, "password123")

        # BUGFIX: bind doc_id before the try block — previously it was first
        # assigned inside try, so a failure in create_document made the
        # finally clause raise NameError instead of the real error.
        doc_id = None
        try:
            doc_id, _ = await test_client.create_document(
                ["Test chunk 1", "Test chunk 2"]
            )
            await asyncio.sleep(1)  # Wait for ingestion

            # Owner filters are applied automatically on the server.
            response = await test_client.client.chunks.list(offset=0, limit=1)

            assert "results" in response, "Expected 'results' in response"
            assert (
                len(response["results"]) <= 1
            ), "Expected at most 1 result due to limit"

            if len(response["results"]) > 0:
                # Verify we only get chunks owned by our temp user.
                chunk = response["results"][0]
                chunks = await test_client.list_chunks(doc_id)
                assert chunk["owner_id"] in [
                    c["owner_id"] for c in chunks
                ], "Got chunk from wrong owner"

        finally:
            if doc_id:
                try:
                    await test_client.delete_document(doc_id)
                except Exception:  # was bare except; cleanup is best-effort
                    pass
            await test_client.logout_user()

    @pytest.mark.asyncio
    async def test_list_chunks_pagination(
        self, test_client: AsyncR2RTestClient
    ):
        """Paging through chunks returns full pages with no duplicates."""
        temp_email = f"{uuid.uuid4()}@example.com"
        await test_client.register_user(temp_email, "password123")
        await test_client.login_user(temp_email, "password123")

        doc_id = None
        try:
            chunks = [f"Test chunk {i}" for i in range(5)]
            doc_id, _ = await test_client.create_document(chunks)
            await asyncio.sleep(1)  # Wait for ingestion

            response1 = await test_client.client.chunks.list(offset=0, limit=2)
            assert (
                len(response1["results"]) == 2
            ), "Expected 2 results on first page"

            response2 = await test_client.client.chunks.list(offset=2, limit=2)
            assert (
                len(response2["results"]) == 2
            ), "Expected 2 results on second page"

            # Pages must not overlap.
            ids_page1 = {chunk["id"] for chunk in response1["results"]}
            ids_page2 = {chunk["id"] for chunk in response2["results"]}
            assert not ids_page1.intersection(
                ids_page2
            ), "Found duplicate chunks across pages"

        finally:
            if doc_id:
                try:
                    await test_client.delete_document(doc_id)
                except Exception:  # was bare except; cleanup is best-effort
                    pass
            await test_client.logout_user()

    @pytest.mark.asyncio
    async def test_list_chunks_with_multiple_documents(
        self, test_client: AsyncR2RTestClient
    ):
        """Listing chunks spans every document owned by the user."""
        temp_email = f"{uuid.uuid4()}@example.com"
        await test_client.register_user(temp_email, "password123")
        await test_client.login_user(temp_email, "password123")

        doc_ids = []
        try:
            for i in range(2):
                doc_id, _ = await test_client.create_document(
                    [f"Doc {i} chunk 1", f"Doc {i} chunk 2"]
                )
                doc_ids.append(doc_id)

            await asyncio.sleep(1)  # Wait for ingestion

            response = await test_client.client.chunks.list(offset=0, limit=10)
            assert len(response["results"]) == 4, "Expected 4 total chunks"

            # Every created document must be represented in the listing.
            chunk_doc_ids = {
                chunk["document_id"] for chunk in response["results"]
            }
            assert all(
                str(doc_id) in chunk_doc_ids for doc_id in doc_ids
            ), "Got chunks from wrong documents"

        finally:
            for doc_id in doc_ids:
                try:
                    await test_client.delete_document(doc_id)
                except Exception:  # was bare except; cleanup is best-effort
                    pass
            await test_client.logout_user()
+
+
+if __name__ == "__main__":
+    pytest.main(["-v", "--asyncio-mode=auto"])

+ 205 - 0
tests/unit/test_collections.py

@@ -0,0 +1,205 @@
+import pytest
+import uuid
+from uuid import UUID
+from core.base.api.models import CollectionResponse
+from core.base import R2RException
+
+
@pytest.mark.asyncio
async def test_create_collection(collections_handler):
    """Creating a collection echoes back name, description, and owner."""
    owner = uuid.uuid4()
    created = await collections_handler.create_collection(
        owner_id=owner,
        name="Test Collection",
        description="A test collection",
    )
    assert isinstance(created, CollectionResponse)
    assert (created.name, created.description) == (
        "Test Collection",
        "A test collection",
    )
    assert created.owner_id == owner
+
+
@pytest.mark.asyncio
async def test_create_collection_default_name(collections_handler):
    """Omitting `name` falls back to the configured default collection name."""
    owner = uuid.uuid4()
    created = await collections_handler.create_collection(owner_id=owner)
    assert isinstance(created, CollectionResponse)
    # The default_collection_name from config should have been applied.
    assert created.name is not None
    assert created.owner_id == owner
+
+
@pytest.mark.asyncio
async def test_update_collection(collections_handler):
    """update_collection persists a new name/description and reports counts."""
    original = await collections_handler.create_collection(
        owner_id=uuid.uuid4(),
        name="Original Name",
        description="Original Desc",
    )

    updated = await collections_handler.update_collection(
        collection_id=original.id,
        name="Updated Name",
        description="New Description",
    )
    assert updated.name == "Updated Name"
    assert updated.description == "New Description"
    # The overview counters must come back as plain integers.
    for counter in (updated.user_count, updated.document_count):
        assert isinstance(counter, int)
+
+
@pytest.mark.asyncio
async def test_update_collection_no_fields(collections_handler):
    """Calling update with nothing to change is rejected with HTTP 400."""
    untouched = await collections_handler.create_collection(
        owner_id=uuid.uuid4(), name="NoUpdate", description="No Update"
    )

    with pytest.raises(R2RException) as exc:
        await collections_handler.update_collection(
            collection_id=untouched.id
        )
    assert exc.value.status_code == 400
+
+
@pytest.mark.asyncio
async def test_delete_collection_relational(collections_handler):
    """A deleted collection no longer reports as existing."""
    doomed = await collections_handler.create_collection(
        owner_id=uuid.uuid4(), name="ToDelete"
    )

    # It must exist before deletion...
    assert await collections_handler.collection_exists(doomed.id) is True

    await collections_handler.delete_collection_relational(doomed.id)

    # ...and be gone afterwards.
    assert await collections_handler.collection_exists(doomed.id) is False
+
+
@pytest.mark.asyncio
async def test_collection_exists(collections_handler):
    """collection_exists returns True for a freshly created collection."""
    created = await collections_handler.create_collection(
        owner_id=uuid.uuid4()
    )
    assert await collections_handler.collection_exists(created.id) is True
+
+
@pytest.mark.asyncio
async def test_documents_in_collection(collections_handler, db_provider):
    """documents_in_collection returns docs whose collection_ids include the
    collection.

    The document row is inserted directly into the documents table
    (bypassing the ingestion pipeline) so the test exercises only the
    retrieval query.
    """
    # Create a collection
    owner_id = uuid.uuid4()
    coll = await collections_handler.create_collection(
        owner_id=owner_id, name="DocCollection"
    )

    # Insert some documents related to this collection
    # We'll directly insert into the documents table for simplicity
    doc_id = uuid.uuid4()
    insert_doc_query = f"""
        INSERT INTO {db_provider.project_name}.documents (id, collection_ids, owner_id, type, metadata, title, version, size_in_bytes, ingestion_status, extraction_status)
        VALUES ($1, $2, $3, 'txt', '{{}}', 'Test Doc', 'v1', 1234, 'pending', 'pending')
    """
    await db_provider.connection_manager.execute_query(
        insert_doc_query, [doc_id, [coll.id], owner_id]
    )

    # Now fetch documents in collection
    res = await collections_handler.documents_in_collection(
        coll.id, offset=0, limit=10
    )
    assert len(res["results"]) == 1
    assert res["total_entries"] == 1
    assert res["results"][0].id == doc_id
    assert res["results"][0].title == "Test Doc"
+
+
@pytest.mark.asyncio
async def test_get_collections_overview(collections_handler, db_provider):
    """The overview listing includes every collection just created."""
    owner = uuid.uuid4()
    first = await collections_handler.create_collection(
        owner_id=owner, name="Overview1"
    )
    second = await collections_handler.create_collection(
        owner_id=owner, name="Overview2"
    )

    overview = await collections_handler.get_collections_overview(
        offset=0, limit=10
    )
    # Other tests may have created collections too; only check inclusion.
    listed = {c.id for c in overview["results"]}
    assert {first.id, second.id} <= listed
+
+
@pytest.mark.asyncio
async def test_assign_document_to_collection_relational(
    collections_handler, db_provider
):
    """A document with an empty collection_ids array can be assigned to a
    collection, after which it appears in that collection's listing.
    """
    owner_id = uuid.uuid4()
    coll = await collections_handler.create_collection(
        owner_id=owner_id, name="Assign"
    )

    # Insert a doc
    doc_id = uuid.uuid4()
    insert_doc_query = f"""
        INSERT INTO {db_provider.project_name}.documents (id, owner_id, type, metadata, title, version, size_in_bytes, ingestion_status, extraction_status, collection_ids)
        VALUES ($1, $2, 'txt', '{{}}', 'Standalone Doc', 'v1', 10, 'pending', 'pending', ARRAY[]::uuid[])
    """
    await db_provider.connection_manager.execute_query(
        insert_doc_query, [doc_id, owner_id]
    )

    # Assign this doc to the collection
    await collections_handler.assign_document_to_collection_relational(
        doc_id, coll.id
    )

    # Verify doc is now in collection
    docs = await collections_handler.documents_in_collection(
        coll.id, offset=0, limit=10
    )
    assert len(docs["results"]) == 1
    assert docs["results"][0].id == doc_id
+
+
@pytest.mark.asyncio
async def test_remove_document_from_collection_relational(
    collections_handler, db_provider
):
    """Removing a document from a collection empties the collection's
    document listing.

    The document is seeded directly with the collection id already in its
    collection_ids array.
    """
    owner_id = uuid.uuid4()
    coll = await collections_handler.create_collection(
        owner_id=owner_id, name="RemoveDoc"
    )

    # Insert a doc already in collection
    doc_id = uuid.uuid4()
    insert_doc_query = f"""
        INSERT INTO {db_provider.project_name}.documents
        (id, owner_id, type, metadata, title, version, size_in_bytes, ingestion_status, extraction_status, collection_ids)
        VALUES ($1, $2, 'txt', '{{}}'::jsonb, 'Another Doc', 'v1', 10, 'pending', 'pending', $3)
    """
    await db_provider.connection_manager.execute_query(
        insert_doc_query, [doc_id, owner_id, [coll.id]]
    )

    # Remove it
    await collections_handler.remove_document_from_collection_relational(
        doc_id, coll.id
    )

    docs = await collections_handler.documents_in_collection(
        coll.id, offset=0, limit=10
    )
    assert len(docs["results"]) == 0
+
+
@pytest.mark.asyncio
async def test_delete_nonexistent_collection(collections_handler):
    """Deleting an unknown collection id surfaces a 404 R2RException."""
    missing_id = uuid.uuid4()
    with pytest.raises(R2RException) as exc:
        await collections_handler.delete_collection_relational(missing_id)
    assert (
        exc.value.status_code == 404
    ), "Should raise 404 for non-existing collection"

+ 132 - 0
tests/unit/test_conversations.py

@@ -0,0 +1,132 @@
+import json
+import uuid
+from uuid import UUID
+
+import pytest
+
+from core.base import Message, R2RException
+from shared.api.models.management.responses import (
+    ConversationResponse,
+    MessageResponse,
+)
+
+
@pytest.mark.asyncio
async def test_create_conversation(conversations_handler):
    """A bare create returns a ConversationResponse with id and timestamp."""
    conversation = await conversations_handler.create_conversation()
    assert isinstance(conversation, ConversationResponse)
    assert conversation.id is not None
    assert conversation.created_at is not None
+
+
@pytest.mark.asyncio
async def test_create_conversation_with_user_and_name(conversations_handler):
    """Creating with an owner and name still yields a well-formed response.

    ConversationResponse carries no user_id field, so ownership cannot be
    asserted directly here; only the response shape is checked.
    """
    owner = uuid.uuid4()
    conversation = await conversations_handler.create_conversation(
        user_id=owner, name="Test Conv"
    )
    assert conversation.id is not None
    assert conversation.created_at is not None
+
+
@pytest.mark.asyncio
async def test_add_message(conversations_handler):
    """Adding a message returns a MessageResponse echoing the content."""
    conversation = await conversations_handler.create_conversation()

    stored = await conversations_handler.add_message(
        conversation.id, Message(role="user", content="Hello!")
    )
    assert isinstance(stored, MessageResponse)
    assert stored.id is not None
    assert stored.message.content == "Hello!"
+
+
@pytest.mark.asyncio
async def test_add_message_with_parent(conversations_handler):
    """A reply can be threaded onto an earlier message via parent_id."""
    conversation = await conversations_handler.create_conversation()

    parent = await conversations_handler.add_message(
        conversation.id, Message(role="user", content="Parent message")
    )

    child = await conversations_handler.add_message(
        conversation.id,
        Message(role="assistant", content="Child reply"),
        parent_id=parent.id,
    )
    assert child.id is not None
    assert child.message.content == "Child reply"
+
+
@pytest.mark.asyncio
async def test_edit_message(conversations_handler):
    """Editing replaces the content and flags the message as edited."""
    conversation = await conversations_handler.create_conversation()

    stored = await conversations_handler.add_message(
        conversation.id, Message(role="user", content="Original")
    )

    edited = await conversations_handler.edit_message(
        stored.id, "Edited content"
    )
    assert edited["message"].content == "Edited content"
    assert edited["metadata"]["edited"] is True
+
+
@pytest.mark.asyncio
async def test_update_message_metadata(conversations_handler):
    """update_message_metadata merges new keys into a message's metadata.

    The previous version passed vacuously when the edited message was not
    present in the fetched conversation (the loop simply never asserted);
    the for/else now fails explicitly in that case.
    """
    conv = await conversations_handler.create_conversation()
    conv_id = conv.id

    msg = Message(role="user", content="Meta-test")
    resp = await conversations_handler.add_message(conv_id, msg)
    msg_id = resp.id

    await conversations_handler.update_message_metadata(
        msg_id, {"test_key": "test_value"}
    )

    # Verify metadata updated; fail loudly if the message went missing.
    full_conversation = await conversations_handler.get_conversation(conv_id)
    for m in full_conversation:
        if m.id == str(msg_id):
            assert m.metadata["test_key"] == "test_value"
            break
    else:
        pytest.fail("Updated message not found in conversation")
+
+
@pytest.mark.asyncio
async def test_get_conversation(conversations_handler):
    """Messages come back from get_conversation in insertion order."""
    conversation = await conversations_handler.create_conversation()

    for role, content in (("user", "Msg1"), ("assistant", "Msg2")):
        await conversations_handler.add_message(
            conversation.id, Message(role=role, content=content)
        )

    messages = await conversations_handler.get_conversation(conversation.id)
    assert [m.message.content for m in messages] == ["Msg1", "Msg2"]
+
+
@pytest.mark.asyncio
async def test_delete_conversation(conversations_handler):
    """After deletion, fetching the conversation raises a 404."""
    conversation = await conversations_handler.create_conversation()

    await conversations_handler.add_message(
        conversation.id, Message(role="user", content="To be deleted")
    )

    await conversations_handler.delete_conversation(conversation.id)

    with pytest.raises(R2RException) as exc:
        await conversations_handler.get_conversation(conversation.id)
    assert (
        exc.value.status_code == 404
    ), "Conversation should be deleted and not found"

+ 143 - 0
tests/unit/test_documents.py

@@ -0,0 +1,143 @@
+import json
+import uuid
+from uuid import UUID
+
+import pytest
+
+from core.base import (
+    DocumentResponse,
+    DocumentType,
+    IngestionStatus,
+    KGExtractionStatus,
+    R2RException,
+    SearchSettings,
+)
+
+
def make_db_entry(doc: DocumentResponse):
    """Flatten a DocumentResponse into the dict shape stored in the DB.

    Enum fields are stored by value, metadata is JSON-encoded, and the
    optional summary embedding is stringified when present.
    """
    embedding = doc.summary_embedding
    return {
        "id": doc.id,
        "collection_ids": doc.collection_ids,
        "owner_id": doc.owner_id,
        "document_type": doc.document_type.value,
        "metadata": json.dumps(doc.metadata),
        "title": doc.title,
        "version": doc.version,
        "size_in_bytes": doc.size_in_bytes,
        "ingestion_status": doc.ingestion_status.value,
        "extraction_status": doc.extraction_status.value,
        "created_at": doc.created_at,
        "updated_at": doc.updated_at,
        # New rows always start with attempt 0.
        "ingestion_attempt_number": 0,
        "summary": doc.summary,
        # Embeddings are persisted via their string representation, if any.
        "summary_embedding": None if embedding is None else str(embedding),
    }
+
+
@pytest.mark.asyncio
async def test_upsert_documents_overview_insert(documents_handler):
    """Upserting a brand-new document makes it retrievable via overview."""
    doc_id = uuid.uuid4()
    new_doc = DocumentResponse(
        id=doc_id,
        collection_ids=[],
        owner_id=uuid.uuid4(),
        document_type=DocumentType.TXT,
        metadata={"description": "A test document"},
        title="Test Doc",
        version="v1",
        size_in_bytes=1234,
        ingestion_status=IngestionStatus.PENDING,
        extraction_status=KGExtractionStatus.PENDING,
        created_at=None,
        updated_at=None,
        summary=None,
        summary_embedding=None,
    )

    # The handler accepts a list of documents to upsert.
    await documents_handler.upsert_documents_overview([new_doc])

    # Read it back through the overview, filtered to just this id.
    res = await documents_handler.get_documents_overview(
        offset=0, limit=10, filter_document_ids=[doc_id]
    )
    assert res["total_entries"] == 1
    fetched = res["results"][0]
    assert fetched.id == doc_id
    assert fetched.title == "Test Doc"
    assert fetched.metadata["description"] == "A test document"
+
+
@pytest.mark.asyncio
async def test_upsert_documents_overview_update(documents_handler):
    """Upserting an existing id overwrites title and metadata in place."""
    doc_id = uuid.uuid4()
    owner = uuid.uuid4()
    doc = DocumentResponse(
        id=doc_id,
        collection_ids=[],
        owner_id=owner,
        document_type=DocumentType.TXT,
        metadata={"note": "initial"},
        title="Initial Title",
        version="v1",
        size_in_bytes=100,
        ingestion_status=IngestionStatus.PENDING,
        extraction_status=KGExtractionStatus.PENDING,
        created_at=None,
        updated_at=None,
        summary=None,
        summary_embedding=None,
    )
    await documents_handler.upsert_documents_overview([doc])

    # Mutate and upsert again: same id, new title/metadata.
    doc.title = "Updated Title"
    doc.metadata["note"] = "updated"
    await documents_handler.upsert_documents_overview([doc])

    # The second upsert must have replaced, not duplicated, the row.
    res = await documents_handler.get_documents_overview(
        offset=0, limit=10, filter_document_ids=[doc_id]
    )
    refreshed = res["results"][0]
    assert refreshed.title == "Updated Title"
    assert refreshed.metadata["note"] == "updated"
+
+
@pytest.mark.asyncio
async def test_delete_document(documents_handler):
    """delete() removes the document from the overview listing."""
    doc_id = uuid.uuid4()
    victim = DocumentResponse(
        id=doc_id,
        collection_ids=[],
        owner_id=uuid.uuid4(),
        document_type=DocumentType.TXT,
        metadata={},
        title="ToDelete",
        version="v1",
        size_in_bytes=100,
        ingestion_status=IngestionStatus.PENDING,
        extraction_status=KGExtractionStatus.PENDING,
        created_at=None,
        updated_at=None,
        summary=None,
        summary_embedding=None,
    )

    await documents_handler.upsert_documents_overview([victim])
    await documents_handler.delete(doc_id)

    res = await documents_handler.get_documents_overview(
        offset=0, limit=10, filter_document_ids=[doc_id]
    )
    assert res["total_entries"] == 0

+ 449 - 0
tests/unit/test_graphs.py

@@ -0,0 +1,449 @@
+import pytest
+import uuid
+from uuid import UUID
+
+from enum import Enum
+from core.base.abstractions import Entity, Relationship, Community
+from core.base.api.models import GraphResponse
+
+
class StoreType(str, Enum):
    """Test-local store discriminator passed as `store_type` to the
    graphs_handler sub-APIs; presumably mirrors the handler-side enum —
    TODO confirm values stay in sync with the production definition."""

    GRAPHS = "graphs"
    DOCUMENTS = "documents"
+
+
@pytest.mark.asyncio
async def test_create_graph(graphs_handler):
    """Creating a graph returns a GraphResponse bound to the collection."""
    collection = uuid.uuid4()
    graph = await graphs_handler.create(
        collection_id=collection, name="My Graph", description="Test Graph"
    )
    assert isinstance(graph, GraphResponse)
    assert graph.name == "My Graph"
    assert graph.collection_id == collection
+
+
@pytest.mark.asyncio
async def test_add_entities_and_relationships(graphs_handler):
    """Entities and a relationship added to a graph can be read back."""
    graph = await graphs_handler.create(
        collection_id=uuid.uuid4(), name="TestGraph"
    )

    # Two entities of different categories.
    person = await graphs_handler.entities.create(
        parent_id=graph.id,
        store_type=StoreType.GRAPHS.value,
        name="TestEntity",
        category="Person",
        description="A test entity",
    )
    assert person.name == "TestEntity"

    place = await graphs_handler.entities.create(
        parent_id=graph.id,
        store_type=StoreType.GRAPHS.value,
        name="AnotherEntity",
        category="Place",
        description="A test place",
    )

    # Link them with a directed relationship.
    relation = await graphs_handler.relationships.create(
        subject="TestEntity",
        subject_id=person.id,
        predicate="lives_in",
        object="AnotherEntity",
        object_id=place.id,
        parent_id=graph.id,
        store_type=StoreType.GRAPHS.value,
        description="Entity lives in AnotherEntity",
    )
    assert relation.predicate == "lives_in"

    # Both entities should be listed under the graph.
    entities, entity_total = await graphs_handler.get_entities(
        parent_id=graph.id, offset=0, limit=10
    )
    assert entity_total == 2
    assert {"TestEntity", "AnotherEntity"} <= {e.name for e in entities}

    # And the single relationship should be retrievable too.
    relationships, rel_total = await graphs_handler.get_relationships(
        parent_id=graph.id, offset=0, limit=10
    )
    assert rel_total == 1
    assert relationships[0].predicate == "lives_in"
+
+
@pytest.mark.asyncio
async def test_delete_entities_and_relationships(graphs_handler):
    """Targeted deletes remove exactly the chosen entity/relationship."""
    graph = await graphs_handler.create(
        collection_id=uuid.uuid4(), name="DeletableGraph"
    )

    first = await graphs_handler.entities.create(
        parent_id=graph.id,
        store_type=StoreType.GRAPHS.value,
        name="DeleteMe",
    )
    second = await graphs_handler.entities.create(
        parent_id=graph.id,
        store_type=StoreType.GRAPHS.value,
        name="DeleteMeToo",
    )

    link = await graphs_handler.relationships.create(
        subject="DeleteMe",
        subject_id=first.id,
        predicate="related_to",
        object="DeleteMeToo",
        object_id=second.id,
        parent_id=graph.id,
        store_type=StoreType.GRAPHS.value,
    )

    # Removing the first entity leaves only the second behind.
    await graphs_handler.entities.delete(
        parent_id=graph.id,
        entity_ids=[first.id],
        store_type=StoreType.GRAPHS.value,
    )
    remaining, remaining_count = await graphs_handler.get_entities(
        parent_id=graph.id, offset=0, limit=10
    )
    assert remaining_count == 1
    assert remaining[0].id == second.id

    # Removing the relationship empties the relationship listing.
    await graphs_handler.relationships.delete(
        parent_id=graph.id,
        relationship_ids=[link.id],
        store_type=StoreType.GRAPHS.value,
    )
    _, rel_count = await graphs_handler.get_relationships(
        parent_id=graph.id, offset=0, limit=10
    )
    assert rel_count == 0
+
+
@pytest.mark.asyncio
async def test_communities(graphs_handler):
    """A community created under a collection id can be listed back."""
    # Parent is a bare collection id; no graph row is required.
    parent = uuid.uuid4()
    await graphs_handler.communities.create(
        parent_id=parent,
        store_type=StoreType.GRAPHS.value,
        name="CommunityOne",
        summary="Test community",
        findings=["finding1", "finding2"],
        rating=4.5,
        rating_explanation="Excellent",
        description_embedding=[0.1, 0.2, 0.3, 0.4],
    )

    communities, total = await graphs_handler.communities.get(
        parent_id=parent,
        store_type=StoreType.GRAPHS.value,
        offset=0,
        limit=10,
    )
    assert total == 1
    assert communities[0].name == "CommunityOne"
+
+
+# TODO - Fix code such that these tests pass
+# # @pytest.mark.asyncio
+# # async def test_delete_graph(graphs_handler):
+# #     # Create a graph and then delete it
+# #     coll_id = uuid.uuid4()
+# #     graph_resp = await graphs_handler.create(collection_id=coll_id, name="TempGraph")
+# #     graph_id = graph_resp.id
+
+# #     # reset or delete calls are complicated in the code. We'll just call `reset` and `delete`
+# #     await graphs_handler.reset(graph_id)
+# #     # This should remove all entities & relationships from the graph_id
+
+# #     # Now delete the graph itself
+# #     # The `delete` method seems to be tied to collection_id rather than graph_id
+# #     await graphs_handler.delete(collection_id=graph_id, cascade=False)
+# #     # If the code is structured so that delete requires a collection_id,
+# #     # ensure `graph_id == collection_id` or adapt the code accordingly.
+
+# #     # Try fetching the graph
+# #     overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[graph_id])
+# #     assert overview["total_entries"] == 0, "Graph should be deleted"
+
+
+# @pytest.mark.asyncio
+# async def test_delete_graph(graphs_handler):
+#     # Create a graph and then delete it
+#     coll_id = uuid.uuid4()
+#     graph_resp = await graphs_handler.create(collection_id=coll_id, name="TempGraph")
+#     graph_id = graph_resp.id
+
+#     # Reset the graph (remove entities, relationships, communities)
+#     await graphs_handler.reset(graph_id)
+
+#     # Now delete the graph using collection_id (which equals graph_id in this code)
+#     await graphs_handler.delete(collection_id=coll_id)
+
+#     # Verify the graph is deleted
+#     overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[coll_id])
+#     assert overview["total_entries"] == 0, "Graph should be deleted"
+
+
@pytest.mark.asyncio
async def test_create_graph_defaults(graphs_handler):
    """Without a name/description, the handler applies its defaults."""
    collection = uuid.uuid4()
    graph = await graphs_handler.create(collection_id=collection)
    assert graph.collection_id == collection
    # The default name is derived from the collection id...
    assert graph.name == f"Graph {collection}"
    # ...and the default description is the empty string.
    assert graph.description == ""
+
+
+# @pytest.mark.asyncio
+# async def test_list_multiple_graphs(graphs_handler):
+#     # Create multiple graphs
+#     coll_id1 = uuid.uuid4()
+#     coll_id2 = uuid.uuid4()
+#     graph_resp1 = await graphs_handler.create(collection_id=coll_id1, name="Graph1")
+#     graph_resp2 = await graphs_handler.create(collection_id=coll_id2, name="Graph2")
+#     graph_resp3 = await graphs_handler.create(collection_id=coll_id2, name="Graph3")
+
+#     # List all graphs without filters
+#     overview = await graphs_handler.list_graphs(offset=0, limit=10)
+#     # Ensure at least these three are in there
+#     found_ids = [g.id for g in overview["results"]]
+#     assert graph_resp1.id in found_ids
+#     assert graph_resp2.id in found_ids
+#     assert graph_resp3.id in found_ids
+
+#     # Filter by collection_id = coll_id2 should return Graph2 and Graph3 (the most recent one first if same collection)
+#     overview_coll2 = await graphs_handler.list_graphs(offset=0, limit=10, filter_collection_id=coll_id2)
+#     returned_ids = [g.id for g in overview_coll2["results"]]
+#     # According to the code, we only see the "most recent" graph per collection. Verify this logic.
+#     # If your code is returning only the most recent graph per collection, we should see only one graph per collection_id here.
+#     # Adjust test according to actual logic you desire.
+#     # For this example, let's assume we should only get the latest graph per collection. Graph3 should be newer than Graph2.
+#     assert len(returned_ids) == 1
+#     assert graph_resp3.id in returned_ids
+
+
@pytest.mark.asyncio
async def test_update_graph(graphs_handler):
    """update() changes name/description and the change persists."""
    graph = await graphs_handler.create(
        collection_id=uuid.uuid4(),
        name="OldName",
        description="OldDescription",
    )

    renamed = await graphs_handler.update(
        collection_id=graph.id, name="NewName", description="NewDescription"
    )
    assert (renamed.name, renamed.description) == (
        "NewName",
        "NewDescription",
    )

    # Re-read through list_graphs to confirm the update was persisted.
    listing = await graphs_handler.list_graphs(
        offset=0, limit=10, filter_graph_ids=[graph.id]
    )
    assert listing["total_entries"] == 1
    persisted = listing["results"][0]
    assert persisted.name == "NewName"
    assert persisted.description == "NewDescription"
+
+
@pytest.mark.asyncio
async def test_bulk_entities(graphs_handler):
    """Several entities created in a row are all retrievable."""
    graph = await graphs_handler.create(
        collection_id=uuid.uuid4(), name="BulkEntities"
    )

    seed = [
        ("EntityA", "CategoryA", "DescA"),
        ("EntityB", "CategoryB", "DescB"),
        ("EntityC", "CategoryC", "DescC"),
    ]
    for entity_name, category, description in seed:
        await graphs_handler.entities.create(
            parent_id=graph.id,
            store_type=StoreType.GRAPHS.value,
            name=entity_name,
            category=category,
            description=description,
        )

    entities, total = await graphs_handler.get_entities(
        parent_id=graph.id, offset=0, limit=10
    )
    assert total == 3
    stored_names = {e.name for e in entities}
    assert stored_names >= {name for name, _, _ in seed}
+
+
@pytest.mark.asyncio
async def test_relationship_filtering(graphs_handler):
    """get_relationships can filter by predicate via relationship_types."""
    graph = await graphs_handler.create(
        collection_id=uuid.uuid4(), name="RelFilteringGraph"
    )

    node1 = await graphs_handler.entities.create(
        parent_id=graph.id, store_type=StoreType.GRAPHS.value, name="Node1"
    )
    node2 = await graphs_handler.entities.create(
        parent_id=graph.id, store_type=StoreType.GRAPHS.value, name="Node2"
    )
    node3 = await graphs_handler.entities.create(
        parent_id=graph.id, store_type=StoreType.GRAPHS.value, name="Node3"
    )

    # Two edges with distinct predicates.
    await graphs_handler.relationships.create(
        subject="Node1",
        subject_id=node1.id,
        predicate="connected_to",
        object="Node2",
        object_id=node2.id,
        parent_id=graph.id,
        store_type=StoreType.GRAPHS.value,
    )
    await graphs_handler.relationships.create(
        subject="Node2",
        subject_id=node2.id,
        predicate="linked_with",
        object="Node3",
        object_id=node3.id,
        parent_id=graph.id,
        store_type=StoreType.GRAPHS.value,
    )

    # Unfiltered: both edges come back.
    _, unfiltered_count = await graphs_handler.get_relationships(
        parent_id=graph.id, offset=0, limit=10
    )
    assert unfiltered_count == 2

    # Filtered to a single predicate: exactly one edge.
    matches, match_count = await graphs_handler.get_relationships(
        parent_id=graph.id,
        offset=0,
        limit=10,
        relationship_types=["connected_to"],
    )
    assert match_count == 1
    assert matches[0].predicate == "connected_to"
+
+
@pytest.mark.asyncio
async def test_delete_all_entities(graphs_handler):
    """Deleting without entity_ids wipes every entity in the graph."""
    graph = await graphs_handler.create(
        collection_id=uuid.uuid4(), name="DeleteAllEntities"
    )

    for entity_name in ("E1", "E2"):
        await graphs_handler.entities.create(
            parent_id=graph.id,
            store_type=StoreType.GRAPHS.value,
            name=entity_name,
        )

    # No explicit ids -> delete everything under this graph.
    await graphs_handler.entities.delete(
        parent_id=graph.id, store_type=StoreType.GRAPHS.value
    )
    _, count = await graphs_handler.get_entities(
        parent_id=graph.id, offset=0, limit=10
    )
    assert count == 0
+
+
@pytest.mark.asyncio
async def test_delete_all_relationships(graphs_handler):
    """Deleting without relationship_ids wipes every edge in the graph."""
    graph = await graphs_handler.create(
        collection_id=uuid.uuid4(), name="DeleteAllRels"
    )

    left = await graphs_handler.entities.create(
        parent_id=graph.id, store_type=StoreType.GRAPHS.value, name="E1"
    )
    right = await graphs_handler.entities.create(
        parent_id=graph.id, store_type=StoreType.GRAPHS.value, name="E2"
    )
    await graphs_handler.relationships.create(
        subject="E1",
        subject_id=left.id,
        predicate="connected",
        object="E2",
        object_id=right.id,
        parent_id=graph.id,
        store_type=StoreType.GRAPHS.value,
    )

    # No explicit ids -> delete every relationship under this graph.
    await graphs_handler.relationships.delete(
        parent_id=graph.id, store_type=StoreType.GRAPHS.value
    )
    _, rel_count = await graphs_handler.get_relationships(
        parent_id=graph.id, offset=0, limit=10
    )
    assert rel_count == 0
+
+
@pytest.mark.asyncio
async def test_error_handling_invalid_graph_id(graphs_handler):
    """Unknown graph ids: listing returns zero entries, delete raises.

    The previous version bound the exception as `exc_info` but never used
    it; the unused binding is dropped.
    """
    missing_id = uuid.uuid4()

    # Filtering on a non-existent id yields an empty overview.
    overview = await graphs_handler.list_graphs(
        offset=0, limit=10, filter_graph_ids=[missing_id]
    )
    assert overview["total_entries"] == 0

    # Deleting a non-existent graph raises (R2RException or HTTPException,
    # depending on the handler implementation).
    with pytest.raises(Exception):
        await graphs_handler.delete(collection_id=missing_id)
+
+
+# TODO - Fix code to pass this test.
+# @pytest.mark.asyncio
+# async def test_delete_graph_cascade(graphs_handler):
+#     coll_id = uuid.uuid4()
+#     graph_resp = await graphs_handler.create(collection_id=coll_id, name="CascadeGraph")
+#     graph_id = graph_resp.id
+
+#     # Add entities/relationships here if you have documents attached
+#     # This test would verify that cascade=True behavior is correct
+#     # For now, just call delete with cascade=True
+#     # Depending on your implementation, you might need documents associated with the collection to test fully.
+#     await graphs_handler.delete(collection_id=coll_id)
+#     overview = await graphs_handler.list_graphs(offset=0, limit=10, filter_graph_ids=[graph_id])
+#     assert overview["total_entries"] == 0

+ 249 - 0
tests/unit/test_limits.py

@@ -0,0 +1,249 @@
+import uuid
+from datetime import datetime, timedelta, timezone
+from uuid import UUID
+
+import pytest
+
+from core.base import LimitSettings
+from core.database.postgres import PostgresLimitsHandler
+from shared.abstractions import User
+
+
+@pytest.mark.asyncio
+async def test_log_request_and_count(limits_handler):
+    """
+    Test that when we log requests, the count increments, and rate-limits are enforced.
+    Route-specific test using the /v3/retrieval/search endpoint limits.
+    """
+    # Clear existing logs first
+    clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
+    await limits_handler.connection_manager.execute_query(clear_query)
+
+    user_id = uuid.uuid4()
+    route = "/v3/retrieval/search"  # Using actual route from config
+    test_user = User(
+        id=user_id,
+        email="test@example.com",
+        is_active=True,
+        is_verified=True,
+        is_superuser=False,
+        limits_overrides=None,
+    )
+
+    # Set route limit to match config: 5 requests per minute
+    old_route_limits = limits_handler.config.route_limits
+    new_route_limits = {
+        route: LimitSettings(route_per_min=5, monthly_limit=10)
+    }
+    limits_handler.config.route_limits = new_route_limits
+
+    print(f"\nTesting with route limits: {new_route_limits}")
+    print(f"Route settings: {limits_handler.config.route_limits[route]}")
+
+    try:
+        # Initial check should pass (no requests yet)
+        await limits_handler.check_limits(test_user, route)
+        print("Initial check passed (no requests)")
+
+        # Log 5 requests (exactly at limit)
+        for i in range(5):
+            await limits_handler.log_request(user_id, route)
+            now = datetime.now(timezone.utc)
+            one_min_ago = now - timedelta(minutes=1)
+            route_count = await limits_handler._count_requests(
+                user_id, route, one_min_ago
+            )
+            print(f"Route count after request {i+1}: {route_count}")
+
+            # This should pass for all 5 requests
+            await limits_handler.check_limits(test_user, route)
+            print(f"Check limits passed after request {i+1}")
+
+        # Log the 6th request (over limit)
+        await limits_handler.log_request(user_id, route)
+        route_count = await limits_handler._count_requests(
+            user_id, route, one_min_ago
+        )
+        print(f"Route count after request 6: {route_count}")
+
+        # This check should fail as we've exceeded route_per_min=5
+        with pytest.raises(
+            ValueError, match="Per-route per-minute rate limit exceeded"
+        ):
+            await limits_handler.check_limits(test_user, route)
+
+    finally:
+        limits_handler.config.route_limits = old_route_limits
+
+
+@pytest.mark.asyncio
+async def test_global_limit(limits_handler):
+    """
+    Test global limit using the configured limit of 10 requests per minute
+    """
+    # Clear existing logs
+    clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
+    await limits_handler.connection_manager.execute_query(clear_query)
+
+    user_id = uuid.uuid4()
+    route = "/global-test"
+    test_user = User(
+        id=user_id,
+        email="globaltest@example.com",
+        is_active=True,
+        is_verified=True,
+        is_superuser=False,
+        limits_overrides=None,
+    )
+
+    # Set global limit to match config: 10 requests per minute
+    old_limits = limits_handler.config.limits
+    limits_handler.config.limits = LimitSettings(
+        global_per_min=10, monthly_limit=20
+    )
+
+    try:
+        # Initial check should pass (no requests)
+        await limits_handler.check_limits(test_user, route)
+        print("Initial global check passed (no requests)")
+
+        # Log 10 requests (hits the limit)
+        for i in range(11):
+            await limits_handler.log_request(user_id, route)
+
+        # Debug counts
+        now = datetime.now(timezone.utc)
+        one_min_ago = now - timedelta(minutes=1)
+        global_count = await limits_handler._count_requests(
+            user_id, None, one_min_ago
+        )
+        print(f"Global count after 10 requests: {global_count}")
+
+        # This should fail as we've hit global_per_min=10
+        with pytest.raises(
+            ValueError, match="Global per-minute rate limit exceeded"
+        ):
+            await limits_handler.check_limits(test_user, route)
+
+    finally:
+        limits_handler.config.limits = old_limits
+
+
+@pytest.mark.asyncio
+async def test_monthly_limit(limits_handler):
+    """
+    Test monthly limit using the configured limit of 20 requests per month
+    """
+    # Clear existing logs
+    clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
+    await limits_handler.connection_manager.execute_query(clear_query)
+
+    user_id = uuid.uuid4()
+    route = "/monthly-test"
+    test_user = User(
+        id=user_id,
+        email="monthly@example.com",
+        is_active=True,
+        is_verified=True,
+        is_superuser=False,
+        limits_overrides=None,
+    )
+
+    old_limits = limits_handler.config.limits
+    limits_handler.config.limits = LimitSettings(monthly_limit=20)
+
+    try:
+        # Initial check should pass (no requests)
+        await limits_handler.check_limits(test_user, route)
+        print("Initial monthly check passed (no requests)")
+
+        # Log 20 requests (hits the monthly limit)
+        for i in range(21):
+            await limits_handler.log_request(user_id, route)
+
+        # Get current month's count
+        now = datetime.now(timezone.utc)
+        first_of_month = now.replace(
+            day=1, hour=0, minute=0, second=0, microsecond=0
+        )
+        monthly_count = await limits_handler._count_requests(
+            user_id, None, first_of_month
+        )
+        print(f"Monthly count after 20 requests: {monthly_count}")
+
+        # This should fail as we've hit monthly_limit=20
+        with pytest.raises(ValueError, match="Monthly rate limit exceeded"):
+            await limits_handler.check_limits(test_user, route)
+
+    finally:
+        limits_handler.config.limits = old_limits
+
+
+@pytest.mark.asyncio
+async def test_user_level_override(limits_handler):
+    """
+    Test user-specific override limits with debug logging
+    """
+    user_id = UUID("47e53676-b478-5b3f-a409-234ca2164de5")
+    route = "/test-route"
+
+    # Clear existing logs first
+    clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
+    await limits_handler.connection_manager.execute_query(clear_query)
+
+    test_user = User(
+        id=user_id,
+        email="override@example.com",
+        is_active=True,
+        is_verified=True,
+        is_superuser=False,
+        limits_overrides={
+            "global_per_min": 2,
+            "route_per_min": 1,
+            "route_overrides": {"/test-route": {"route_per_min": 1}},
+        },
+    )
+
+    # Set default limits that should be overridden
+    old_limits = limits_handler.config.limits
+    limits_handler.config.limits = LimitSettings(
+        global_per_min=10, monthly_limit=20
+    )
+
+    # Debug: Print current limits
+    print(f"\nDefault limits: {limits_handler.config.limits}")
+    print(f"User overrides: {test_user.limits_overrides}")
+
+    try:
+        # First check limits (should pass as no requests yet)
+        await limits_handler.check_limits(test_user, route)
+        print("Initial check passed (no requests yet)")
+
+        # Log first request
+        await limits_handler.log_request(user_id, route)
+
+        # Debug: Get current counts
+        now = datetime.now(timezone.utc)
+        one_min_ago = now - timedelta(minutes=1)
+        global_count = await limits_handler._count_requests(
+            user_id, None, one_min_ago
+        )
+        route_count = await limits_handler._count_requests(
+            user_id, route, one_min_ago
+        )
+        print(f"\nAfter first request:")
+        print(f"Global count: {global_count}")
+        print(f"Route count: {route_count}")
+
+        # Log second request
+        await limits_handler.log_request(user_id, route)
+
+        # This check should fail as we've hit route_per_min=1
+        with pytest.raises(
+            ValueError, match="Per-route per-minute rate limit exceeded"
+        ):
+            await limits_handler.check_limits(test_user, route)
+
+    finally:
+        # Cleanup
+        limits_handler.config.limits = old_limits