# file.py (2.8 KB)
  1. import logging
  2. import os
  3. from abc import ABC, abstractmethod
  4. from datetime import datetime
  5. from io import BytesIO
  6. from typing import BinaryIO, Optional
  7. from uuid import UUID
  8. from .base import Provider, ProviderConfig
  9. logger = logging.getLogger()
  10. class FileConfig(ProviderConfig):
  11. """
  12. Configuration for file storage providers.
  13. """
  14. provider: Optional[str] = None
  15. # S3-specific configuration
  16. bucket_name: Optional[str] = None
  17. aws_access_key_id: Optional[str] = None
  18. aws_secret_access_key: Optional[str] = None
  19. region_name: Optional[str] = None
  20. endpoint_url: Optional[str] = None
  21. @property
  22. def supported_providers(self) -> list[str]:
  23. """
  24. List of supported file storage providers.
  25. """
  26. return [
  27. "postgres",
  28. "s3",
  29. ]
  30. def validate_config(self) -> None:
  31. if self.provider not in self.supported_providers:
  32. raise ValueError(f"Unsupported file provider: {self.provider}")
  33. if self.provider == "s3" and (
  34. not self.bucket_name and not os.getenv("S3_BUCKET_NAME")
  35. ):
  36. raise ValueError(
  37. "S3 bucket name is required when using S3 provider"
  38. )
  39. class FileProvider(Provider, ABC):
  40. """
  41. Base abstract class for file storage providers.
  42. """
  43. def __init__(self, config: FileConfig):
  44. if not isinstance(config, FileConfig):
  45. raise ValueError(
  46. "FileProvider must be initialized with a `FileConfig`."
  47. )
  48. super().__init__(config)
  49. self.config: FileConfig = config
  50. @abstractmethod
  51. async def initialize(self) -> None:
  52. """Initialize the file provider."""
  53. pass
  54. @abstractmethod
  55. async def store_file(
  56. self,
  57. document_id: UUID,
  58. file_name: str,
  59. file_content: BytesIO,
  60. file_type: Optional[str] = None,
  61. ) -> None:
  62. """Store a file."""
  63. pass
  64. @abstractmethod
  65. async def retrieve_file(
  66. self, document_id: UUID
  67. ) -> Optional[tuple[str, BinaryIO, int]]:
  68. """Retrieve a file."""
  69. pass
  70. @abstractmethod
  71. async def retrieve_files_as_zip(
  72. self,
  73. document_ids: Optional[list[UUID]] = None,
  74. start_date: Optional[datetime] = None,
  75. end_date: Optional[datetime] = None,
  76. ) -> tuple[str, BinaryIO, int]:
  77. """Retrieve multiple files as a zip."""
  78. pass
  79. @abstractmethod
  80. async def delete_file(self, document_id: UUID) -> bool:
  81. """Delete a file."""
  82. pass
  83. @abstractmethod
  84. async def get_files_overview(
  85. self,
  86. offset: int,
  87. limit: int,
  88. filter_document_ids: Optional[list[UUID]] = None,
  89. filter_file_names: Optional[list[str]] = None,
  90. ) -> list[dict]:
  91. """Get an overview of stored files."""
  92. pass