- import logging
- import os
- import zipfile
- from datetime import datetime
- from io import BytesIO
- from typing import BinaryIO, Optional
- from uuid import UUID
- import boto3
- from botocore.exceptions import ClientError
- from core.base import FileConfig, FileProvider, R2RException
- from core.providers.file.postgres import PostgresFileProvider
- logger = logging.getLogger(__name__)
- class S3FileProvider(FileProvider):
- """S3 implementation of the FileProvider."""
-
- coco_postgres_file_provider: Optional[PostgresFileProvider] = None
- def __init__(self, config: FileConfig, postgres_file_provider: PostgresFileProvider):
- super().__init__(config)
- self.bucket_name = self.config.bucket_name or os.getenv("S3_BUCKET_NAME")
- aws_access_key_id = self.config.aws_access_key_id or os.getenv(
- "AWS_ACCESS_KEY_ID"
- )
- aws_secret_access_key = self.config.aws_secret_access_key or os.getenv(
- "AWS_SECRET_ACCESS_KEY"
- )
- region_name = self.config.region_name or os.getenv("AWS_REGION")
- endpoint_url = self.config.endpoint_url or os.getenv("S3_ENDPOINT_URL")
- # Initialize S3 client
- self.s3_client = boto3.client(
- "s3",
- aws_access_key_id=aws_access_key_id,
- aws_secret_access_key=aws_secret_access_key,
- region_name=region_name,
- endpoint_url=endpoint_url,
- )
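- # Keep the Postgres provider as a fallback: retrieve_file, delete_file, and
- # get_files_overview fall back to it when the S3 lookup fails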
- self.coco_postgres_file_provider = postgres_file_provider
- def _get_s3_key(self, document_id: UUID) -> str:
- """Generate a unique S3 key for a document."""
- return f"documents/{document_id}"
- async def initialize(self) -> None:
- """Initialize S3 bucket."""
- try:
- self.s3_client.head_bucket(Bucket=self.bucket_name)
- logger.info(f"Using existing S3 bucket: {self.bucket_name}")
- except ClientError as e:
- error_code = e.response.get("Error", {}).get("Code")
- if error_code == "404":
- logger.info(f"Creating S3 bucket: {self.bucket_name}")
- self.s3_client.create_bucket(Bucket=self.bucket_name)
- else:
- logger.error(f"Error accessing S3 bucket: {e}")
- raise R2RException(
- status_code=500,
- message=f"Failed to initialize S3 bucket: {e}",
- ) from e
- async def store_file(
- self,
- document_id: UUID,
- file_name: str,
- file_content: BytesIO,
- file_type: Optional[str] = None,
- ) -> None:
- """Store a file in S3."""
- try:
- # Generate S3 key
- s3_key = self._get_s3_key(document_id)
- # Upload to S3
- file_content.seek(0) # Reset pointer to beginning
- logger.debug(f"Storing file '{file_name}' for document {document_id} in S3")
- self.s3_client.upload_fileobj(
- file_content,
- self.bucket_name,
- s3_key,
- ExtraArgs={
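- # "public-read" makes every uploaded object publicly readable; tighten this
- # ACL (or drop it) if stored documents are sensitive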
- "ACL": "public-read",
- "ContentType": file_type or "application/octet-stream",
- "Metadata": {
- "filename": file_name.encode(
- "ascii", "backslashreplace"
- ).decode(),
- "document_id": str(document_id),
- },
- },
- )
- except Exception as e:
- logger.error(f"Error storing file in S3: {e}")
- raise R2RException(
- status_code=500, message=f"Failed to store file in S3: {e}"
- ) from e
- async def retrieve_file(
- self, document_id: UUID
- ) -> Optional[tuple[str, BinaryIO, int]]:
- """Retrieve a file from S3."""
- s3_key = self._get_s3_key(document_id)
- try:
- # Get file metadata from S3
- response = self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)
- file_name = response.get("Metadata", {}).get(
- "filename", f"file-{document_id}"
- )
- file_size = response.get("ContentLength", 0)
- # Download file from S3
- file_content = BytesIO()
- self.s3_client.download_fileobj(self.bucket_name, s3_key, file_content)
- file_content.seek(0) # Reset pointer to beginning
- return (
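- # Undo the backslashreplace escaping applied in store_file to recover the original filename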
- file_name.encode("ascii").decode("unicode-escape"),
- file_content,
- file_size,
- )
- except ClientError as e:
- # S3 lookup failed; fall back to the Postgres-backed provider before reporting the file missing
- logger.warning(f"Could not retrieve document {document_id} from S3 ({e}); falling back to Postgres")
- try:
- return await self.coco_postgres_file_provider.retrieve_file(document_id)
- except Exception as fallback_error:
- raise R2RException(
- status_code=404,
- message=f"File for document {document_id} not found",
- ) from fallback_error
- async def retrieve_files_as_zip(
- self,
- document_ids: Optional[list[UUID]] = None,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
- ) -> tuple[str, BinaryIO, int]:
- """Retrieve multiple files from S3 and return them as a zip file."""
- if not document_ids:
- raise R2RException(
- status_code=400,
- message="Document IDs must be provided for S3 file retrieval",
- )
- zip_buffer = BytesIO()
- with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
- for doc_id in document_ids:
- try:
- # Get file information - note that retrieve_file won't return None here
- # since any errors will raise exceptions
- result = await self.retrieve_file(doc_id)
- if result:
- file_name, file_content, _ = result
- # Read the content into a bytes object
- if hasattr(file_content, "getvalue"):
- content_bytes = file_content.getvalue()
- else:
- # For BinaryIO objects that don't have getvalue()
- file_content.seek(0)
- content_bytes = file_content.read()
- # Add file to zip
- zip_file.writestr(file_name, content_bytes)
- except R2RException as e:
- if e.status_code == 404:
- # Skip files that don't exist
- logger.warning(
- f"File for document {doc_id} not found, skipping"
- )
- continue
- else:
- raise
- zip_buffer.seek(0)
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- zip_filename = f"files_export_{timestamp}.zip"
- zip_size = zip_buffer.getbuffer().nbytes
- if zip_size == 0:
- raise R2RException(
- status_code=404,
- message="No files found for the specified document IDs",
- )
- return zip_filename, zip_buffer, zip_size
- async def delete_file(self, document_id: UUID) -> bool:
- """Delete a file from S3."""
- s3_key = self._get_s3_key(document_id)
- try:
- # Check if file exists first
- self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)
- # Delete from S3
- self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
- return True
- except ClientError as e:
- # S3 delete failed (missing object or access error); try the Postgres-backed provider instead
- logger.warning(f"Could not delete document {document_id} from S3 ({e}); falling back to Postgres")
- try:
- return await self.coco_postgres_file_provider.delete_file(document_id)
- except Exception as fallback_error:
- raise R2RException(
- status_code=404,
- message=f"File for document {document_id} not found",
- ) from fallback_error
- async def get_files_overview(
- self,
- offset: int,
- limit: int,
- filter_document_ids: Optional[list[UUID]] = None,
- filter_file_names: Optional[list[str]] = None,
- ) -> list[dict]:
- """
- Get an overview of stored files.
- Note: Since S3 doesn't have native query capabilities like a database,
- this implementation works best when document IDs are provided.
- """
- results = []
- if filter_document_ids:
- # We can efficiently get specific files by document ID
- for doc_id in filter_document_ids:
- s3_key = self._get_s3_key(doc_id)
- try:
- # Get metadata for this file
- response = self.s3_client.head_object(
- Bucket=self.bucket_name, Key=s3_key
- )
- file_info = {
- "document_id": doc_id,
- "file_name": response.get("Metadata", {}).get(
- "filename", f"file-{doc_id}"
- ),
- "file_key": s3_key,
- "file_size": response.get("ContentLength", 0),
- "file_type": response.get("ContentType"),
- "created_at": response.get("LastModified"),
- "updated_at": response.get("LastModified"),
- }
- results.append(file_info)
- except ClientError:
- # Skip files that don't exist
- continue
- else:
- # This is a list operation on the bucket, which is less efficient
- # We list objects with the documents/ prefix
- try:
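- # list_objects_v2 returns at most 1000 keys per call and this slice does not
- # follow continuation tokens, so very large buckets are only partially listed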
- response = self.s3_client.list_objects_v2(
- Bucket=self.bucket_name,
- Prefix="documents/",
- )
- if "Contents" in response:
- # Apply pagination manually
- page_items = response["Contents"][offset : offset + limit]
- for item in page_items:
- # Extract document ID from the key
- key = item["Key"]
- doc_id_str = key.split("/")[-1]
- try:
- doc_id = UUID(doc_id_str)
- # Get detailed metadata
- obj_response = self.s3_client.head_object(
- Bucket=self.bucket_name, Key=key
- )
- file_name = obj_response.get("Metadata", {}).get(
- "filename", f"file-{doc_id}"
- )
- # Apply filename filter if provided
- if filter_file_names and file_name not in filter_file_names:
- continue
- file_info = {
- "document_id": doc_id,
- "file_name": file_name,
- "file_key": key,
- "file_size": item.get("Size", 0),
- "file_type": obj_response.get("ContentType"),
- "created_at": item.get("LastModified"),
- "updated_at": item.get("LastModified"),
- }
- results.append(file_info)
- except ValueError:
- # Skip if the key doesn't contain a valid UUID
- continue
- except ClientError as e:
- # S3 listing failed; fall back to the Postgres-backed provider
- logger.warning(f"Could not list files from S3 ({e}); falling back to Postgres")
- try:
- return await self.coco_postgres_file_provider.get_files_overview(
- offset, limit, filter_document_ids, filter_file_names
- )
- except Exception as fallback_error:
- raise R2RException(
- status_code=500,
- message=f"Failed to list files: {fallback_error}",
- ) from fallback_error
- if not results:
- raise R2RException(
- status_code=404,
- message="No files found with the given filters",
- )
- return results
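- # --- Illustrative usage sketch (assumptions, not part of the provider) ---
- # A minimal example of how this provider might be wired up. The FileConfig
- # fields and the PostgresFileProvider constructor shown here are assumptions
- # for illustration only; consult core.base and the Postgres provider for the
- # real signatures.
- #
- # config = FileConfig(provider="s3", bucket_name="my-bucket")
- # pg_files = PostgresFileProvider(config, database_provider)  # hypothetical wiring
- # s3_files = S3FileProvider(config, pg_files)
- # await s3_files.initialize()
- # await s3_files.store_file(document_id, "report.pdf", BytesIO(b"..."), "application/pdf")
- # file_name, stream, size = await s3_files.retrieve_file(document_id)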