authentication.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import logging
  2. from enum import Enum
  3. from typing import Optional, Dict
  4. from pydantic import BaseModel, Field, model_validator
  5. from app.utils import aes_encrypt, aes_decrypt
  6. logger = logging.getLogger(__name__)
  7. __all__ = ["Authentication", "AuthenticationType"]
  8. # This class utilizes code from the Open Source Project TaskingAI.
  9. # The original code can be found at: https://github.com/TaskingAI/TaskingAI
  10. class AuthenticationType(str, Enum):
  11. bearer = "bearer"
  12. basic = "basic"
  13. custom = "custom"
  14. none = "none"
  15. # This function code from the Open Source Project TaskingAI.
  16. # The original code can be found at: https://github.com/TaskingAI/TaskingAI
  17. def validate_authentication_data(data: Dict):
  18. if not isinstance(data, dict):
  19. raise ValueError("Authentication should be a dict.")
  20. if "type" not in data or not data.get("type"):
  21. raise ValueError("Type is required for authentication.")
  22. if data["type"] == AuthenticationType.custom:
  23. if "content" not in data or data["content"] is None:
  24. raise ValueError("Content is required for custom authentication.")
  25. elif data["type"] == AuthenticationType.bearer:
  26. if "secret" not in data or data["secret"] is None:
  27. raise ValueError(f'Secret is required for {data["type"]} authentication.')
  28. elif data["type"] == AuthenticationType.basic:
  29. if "secret" not in data or data["secret"] is None:
  30. raise ValueError(f'Secret is required for {data["type"]} authentication.')
  31. # assume the secret is a base64 encoded string
  32. elif data["type"] == AuthenticationType.none:
  33. data["secret"] = None
  34. data["content"] = None
  35. return data
  36. # This class utilizes code from the Open Source Project TaskingAI.
  37. # The original code can be found at: https://github.com/TaskingAI/TaskingAI
  38. class Authentication(BaseModel):
  39. encrypted: bool = Field(False)
  40. type: AuthenticationType = Field(...)
  41. secret: Optional[str] = Field(None, min_length=1, max_length=1024)
  42. content: Optional[Dict] = Field(None)
  43. @model_validator(mode="before")
  44. def validate_all_fields_at_the_same_time(cls, data: Dict):
  45. data = validate_authentication_data(data)
  46. return data
  47. def is_encrypted(self):
  48. return self.encrypted or self.type == AuthenticationType.none
  49. def encrypt(self):
  50. if self.encrypted or self.type == AuthenticationType.none:
  51. return
  52. if self.secret is not None:
  53. self.secret = aes_encrypt(self.secret)
  54. if self.content is not None:
  55. for key in self.content:
  56. self.content[key] = aes_encrypt(self.content[key])
  57. self.encrypted = True
  58. def decrypt(self):
  59. if not self.encrypted or self.type == AuthenticationType.none:
  60. return
  61. if self.secret is not None:
  62. self.secret = aes_decrypt(self.secret)
  63. if self.content is not None:
  64. for key in self.content:
  65. self.content[key] = aes_decrypt(self.content[key])
  66. self.encrypted = False