# s3.py

import logging
import os
import zipfile
from datetime import datetime
from io import BytesIO
from typing import BinaryIO, Optional
from uuid import UUID

import boto3
from botocore.exceptions import ClientError

from core.base import FileConfig, FileProvider, R2RException
from core.providers.file.postgres import PostgresFileProvider

logger = logging.getLogger()


class S3FileProvider(FileProvider):
    """S3 implementation of the FileProvider."""

    coco_postgres_file_provider: Optional[PostgresFileProvider] = None

    def __init__(
        self, config: FileConfig, postgres_file_provider: PostgresFileProvider
    ):
        super().__init__(config)

        self.bucket_name = self.config.bucket_name or os.getenv("S3_BUCKET_NAME")
        aws_access_key_id = self.config.aws_access_key_id or os.getenv(
            "AWS_ACCESS_KEY_ID"
        )
        aws_secret_access_key = self.config.aws_secret_access_key or os.getenv(
            "AWS_SECRET_ACCESS_KEY"
        )
        region_name = self.config.region_name or os.getenv("AWS_REGION")
        endpoint_url = self.config.endpoint_url or os.getenv("S3_ENDPOINT_URL")

        # Initialize S3 client
        self.s3_client = boto3.client(
            "s3",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
            endpoint_url=endpoint_url,
        )

        # Fallback provider for files stored in Postgres rather than S3
        self.coco_postgres_file_provider = postgres_file_provider

    def _get_s3_key(self, document_id: UUID) -> str:
        """Generate a unique S3 key for a document."""
        return f"documents/{document_id}"
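
    # Illustration (example id only): a document with id
    # 9fbe403b-c11c-5aae-8ade-ef22980c3ad1 is stored under the object key
    # "documents/9fbe403b-c11c-5aae-8ade-ef22980c3ad1".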

    async def initialize(self) -> None:
        """Initialize S3 bucket."""
        try:
            self.s3_client.head_bucket(Bucket=self.bucket_name)
            logger.info(f"Using existing S3 bucket: {self.bucket_name}")
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code")
            if error_code == "404":
                logger.info(f"Creating S3 bucket: {self.bucket_name}")
                self.s3_client.create_bucket(Bucket=self.bucket_name)
            else:
                logger.error(f"Error accessing S3 bucket: {e}")
                raise R2RException(
                    status_code=500,
                    message=f"Failed to initialize S3 bucket: {e}",
                ) from e

    async def store_file(
        self,
        document_id: UUID,
        file_name: str,
        file_content: BytesIO,
        file_type: Optional[str] = None,
    ) -> None:
        """Store a file in S3."""
        try:
            # Generate S3 key
            s3_key = self._get_s3_key(document_id)

            # Upload to S3
            file_content.seek(0)  # Reset pointer to beginning
            self.s3_client.upload_fileobj(
                file_content,
                self.bucket_name,
                s3_key,
                ExtraArgs={
                    "ACL": "public-read",
                    "ContentType": file_type or "application/octet-stream",
                    "Metadata": {
                        # S3 user metadata must be ASCII, so escape any
                        # non-ASCII characters in the filename
                        "filename": file_name.encode(
                            "ascii", "backslashreplace"
                        ).decode(),
                        "document_id": str(document_id),
                    },
                },
            )
        except Exception as e:
            logger.error(f"Error storing file in S3: {e}")
            raise R2RException(
                status_code=500, message=f"Failed to store file in S3: {e}"
            ) from e

    async def retrieve_file(
        self, document_id: UUID
    ) -> Optional[tuple[str, BinaryIO, int]]:
        """Retrieve a file from S3."""
        s3_key = self._get_s3_key(document_id)
        try:
            # Get file metadata from S3
            response = self.s3_client.head_object(
                Bucket=self.bucket_name, Key=s3_key
            )
            file_name = response.get("Metadata", {}).get(
                "filename", f"file-{document_id}"
            )
            file_size = response.get("ContentLength", 0)

            # Download file from S3
            file_content = BytesIO()
            self.s3_client.download_fileobj(self.bucket_name, s3_key, file_content)
            file_content.seek(0)  # Reset pointer to beginning

            return (
                # Undo the ASCII escaping applied in store_file
                file_name.encode("ascii").decode("unicode-escape"),
                file_content,
                file_size,
            )
        except ClientError as e:
            # Fall back to the Postgres file provider for files not in S3
            try:
                return await self.coco_postgres_file_provider.retrieve_file(
                    document_id
                )
            except Exception as e:
                raise R2RException(
                    status_code=404,
                    message=f"File for document {document_id} not found",
                ) from e
            # Previous S3-only error handling, kept for reference:
            '''
            error_code = e.response.get("Error", {}).get("Code")
            if error_code in ["NoSuchKey", "404"]:
                raise R2RException(
                    status_code=404,
                    message=f"File for document {document_id} not found",
                ) from e
            else:
                raise R2RException(
                    status_code=500,
                    message=f"Error retrieving file from S3: {e}",
                ) from e
            '''
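
    # Filename round trip used by store_file/retrieve_file (illustrative values,
    # standard-library behaviour only): S3 user metadata must be ASCII, so the
    # name is escaped on upload and unescaped on download, e.g.
    #   "résumé.pdf".encode("ascii", "backslashreplace").decode()
    #       -> r"r\xe9sum\xe9.pdf"   (value stored in S3 metadata)
    #   r"r\xe9sum\xe9.pdf".encode("ascii").decode("unicode-escape")
    #       -> "résumé.pdf"          (value returned to the caller)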

    async def retrieve_files_as_zip(
        self,
        document_ids: Optional[list[UUID]] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> tuple[str, BinaryIO, int]:
        """Retrieve multiple files from S3 and return them as a zip file."""
        if not document_ids:
            raise R2RException(
                status_code=400,
                message="Document IDs must be provided for S3 file retrieval",
            )

        zip_buffer = BytesIO()
        files_added = 0
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for doc_id in document_ids:
                try:
                    # Get file information - retrieve_file won't return None
                    # here since any errors raise exceptions
                    result = await self.retrieve_file(doc_id)
                    if result:
                        file_name, file_content, _ = result

                        # Read the content into a bytes object
                        if hasattr(file_content, "getvalue"):
                            content_bytes = file_content.getvalue()
                        else:
                            # For BinaryIO objects that don't have getvalue()
                            file_content.seek(0)
                            content_bytes = file_content.read()

                        # Add file to zip
                        zip_file.writestr(file_name, content_bytes)
                        files_added += 1
                except R2RException as e:
                    if e.status_code == 404:
                        # Skip files that don't exist
                        logger.warning(
                            f"File for document {doc_id} not found, skipping"
                        )
                        continue
                    else:
                        raise

        zip_buffer.seek(0)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        zip_filename = f"files_export_{timestamp}.zip"
        zip_size = zip_buffer.getbuffer().nbytes

        # An empty archive still has a non-zero size (its end-of-central-directory
        # record), so count added entries instead of checking for zero bytes
        if files_added == 0:
            raise R2RException(
                status_code=404,
                message="No files found for the specified document IDs",
            )

        return zip_filename, zip_buffer, zip_size
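
    # Usage sketch (names `provider` and `doc_id` are illustrative): the returned
    # buffer is an ordinary in-memory zip, so a caller or test could do:
    #   name, buffer, size = await provider.retrieve_files_as_zip([doc_id])
    #   with zipfile.ZipFile(buffer) as archive:
    #       print(archive.namelist())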

    async def delete_file(self, document_id: UUID) -> bool:
        """Delete a file from S3."""
        s3_key = self._get_s3_key(document_id)
        try:
            # Check if the file exists first; delete_object alone succeeds
            # even when the key is missing
            self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)

            # Delete from S3
            self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
            return True
        except ClientError as e:
            # Fall back to the Postgres file provider for files not in S3
            try:
                return await self.coco_postgres_file_provider.delete_file(document_id)
            except Exception as e:
                raise R2RException(
                    status_code=404,
                    message=f"File for document {document_id} not found",
                ) from e
            # Previous S3-only error handling, kept for reference:
            '''
            error_code = e.response.get("Error", {}).get("Code")
            if error_code in ["NoSuchKey", "404"]:
                raise R2RException(
                    status_code=404,
                    message=f"File for document {document_id} not found",
                ) from e
            logger.error(f"Error deleting file from S3: {e}")
            raise R2RException(
                status_code=500, message=f"Failed to delete file from S3: {e}"
            ) from e
            '''

    async def get_files_overview(
        self,
        offset: int,
        limit: int,
        filter_document_ids: Optional[list[UUID]] = None,
        filter_file_names: Optional[list[str]] = None,
    ) -> list[dict]:
        """Get an overview of stored files.

        Note: since S3 doesn't have native query capabilities like a database,
        this implementation works best when document IDs are provided.
        """
        results = []

        if filter_document_ids:
            # We can efficiently get specific files by document ID
            for doc_id in filter_document_ids:
                s3_key = self._get_s3_key(doc_id)
                try:
                    # Get metadata for this file
                    response = self.s3_client.head_object(
                        Bucket=self.bucket_name, Key=s3_key
                    )
                    file_info = {
                        "document_id": doc_id,
                        "file_name": response.get("Metadata", {}).get(
                            "filename", f"file-{doc_id}"
                        ),
                        "file_key": s3_key,
                        "file_size": response.get("ContentLength", 0),
                        "file_type": response.get("ContentType"),
                        "created_at": response.get("LastModified"),
                        "updated_at": response.get("LastModified"),
                    }
                    results.append(file_info)
                except ClientError:
                    # Skip files that don't exist
                    continue
        else:
            # This is a list operation on the bucket, which is less efficient.
            # We list objects with the documents/ prefix
            try:
                response = self.s3_client.list_objects_v2(
                    Bucket=self.bucket_name,
                    Prefix="documents/",
                )
                if "Contents" in response:
                    # Apply pagination manually
                    page_items = response["Contents"][offset : offset + limit]
                    for item in page_items:
                        # Extract document ID from the key
                        key = item["Key"]
                        doc_id_str = key.split("/")[-1]
                        try:
                            doc_id = UUID(doc_id_str)

                            # Get detailed metadata
                            obj_response = self.s3_client.head_object(
                                Bucket=self.bucket_name, Key=key
                            )
                            file_name = obj_response.get("Metadata", {}).get(
                                "filename", f"file-{doc_id}"
                            )

                            # Apply filename filter if provided
                            if (
                                filter_file_names
                                and file_name not in filter_file_names
                            ):
                                continue

                            file_info = {
                                "document_id": doc_id,
                                "file_name": file_name,
                                "file_key": key,
                                "file_size": item.get("Size", 0),
                                "file_type": obj_response.get("ContentType"),
                                "created_at": item.get("LastModified"),
                                "updated_at": item.get("LastModified"),
                            }
                            results.append(file_info)
                        except ValueError:
                            # Skip if the key doesn't contain a valid UUID
                            continue
            except ClientError as e:
                # Fall back to the Postgres file provider
                try:
                    return await self.coco_postgres_file_provider.get_files_overview(
                        offset, limit, filter_document_ids, filter_file_names
                    )
                except Exception as e:
                    raise R2RException(
                        status_code=500,
                        message=f"Failed to list files from S3: {e}",
                    ) from e
                # Previous S3-only error handling, kept for reference:
                '''
                logger.error(f"Error listing files in S3 bucket: {e}")
                raise R2RException(
                    status_code=500,
                    message=f"Failed to list files from S3: {e}",
                ) from e
                '''

        if not results:
            raise R2RException(
                status_code=404,
                message="No files found with the given filters",
            )

        return results
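

# Usage sketch (assumptions: `config`, `pg_provider`, and `document_id` are
# illustrative names for a FileConfig, an already-initialized
# PostgresFileProvider, and a document UUID; credentials come either from the
# config or from the environment variables read in __init__):
#
#   provider = S3FileProvider(config, pg_provider)
#   await provider.initialize()
#   await provider.store_file(document_id, "report.pdf", BytesIO(b"..."), "application/pdf")
#   file_name, content, size = await provider.retrieve_file(document_id)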