jack 4 mēneši atpakaļ
vecāks
revīzija
6e403e8693
1 mainītis faili ar 57 papildinājumiem un 18 dzēšanām
  1. 57 18
      app/providers/auth_provider.py

+ 57 - 18
app/providers/auth_provider.py

@@ -19,17 +19,23 @@ class AuthPolicy(object):
         enable auth policy
         """
 
-    def insert_token_rel(self, session, token_id: str, relation_type: RelationType, relation_id: str):
+    def insert_token_rel(
+        self, session, token_id: str, relation_type: RelationType, relation_id: str
+    ):
         """
         insert a token relation to database when enable token auth policy
         """
 
-    async def delete_token_rel(self, session, relation_type: RelationType, relation_id: str):
+    async def delete_token_rel(
+        self, session, relation_type: RelationType, relation_id: str
+    ):
         """
         delete token relation when enable token auth policy
         """
 
-    def token_filter(self, statement, field, relation_type: RelationType, token_id: str):
+    def token_filter(
+        self, statement, field, relation_type: RelationType, token_id: str
+    ):
         """
         add token filter clause when enable token auth policy
         """
@@ -48,23 +54,32 @@ class SimpleTokenAuthPolicy(AuthPolicy):
         from app.api.v1 import assistant, assistant_file, thread, message, runs, action
 
         verify_assistant_depends = Depends(
-            verify_token_relation(relation_type=RelationType.Assistant, name="assistant_id")
+            verify_token_relation(
+                relation_type=RelationType.Assistant, name="assistant_id"
+            )
         )
         # assistant router
         for route in assistant.router.routes:
-            if route.name == assistant.create_assistant.__name__ or route.name == assistant.list_assistants.__name__:
+            if (
+                route.name == assistant.create_assistant.__name__
+                or route.name == assistant.list_assistants.__name__
+            ):
                 route.dependencies.append(Depends(verfiy_token))
             else:
                 route.dependencies.append(verify_assistant_depends)
 
         # thread router
-        verify_thread_depends = Depends(verify_token_relation(relation_type=RelationType.Thread, name="thread_id"))
+        verify_thread_depends = Depends(
+            verify_token_relation(relation_type=RelationType.Thread, name="thread_id")
+        )
         for route in thread.router.routes:
             if route.name == thread.create_thread.__name__:
                 route.dependencies.append(
                     Depends(
                         verify_token_relation(
-                            relation_type=RelationType.Thread, name="thread_id", ignore_none_relation_id=True
+                            relation_type=RelationType.Thread,
+                            name="thread_id",
+                            ignore_none_relation_id=True,
                         )
                     )
                 )
@@ -72,30 +87,52 @@ class SimpleTokenAuthPolicy(AuthPolicy):
                 route.dependencies.append(verify_thread_depends)
 
         # action router
-        verify_action_depends = Depends(verify_token_relation(relation_type=RelationType.Action, name="action_id"))
+        verify_action_depends = Depends(
+            verify_token_relation(relation_type=RelationType.Action, name="action_id")
+        )
         for route in action.router.routes:
-            if route.name == action.create_actions.__name__ or route.name == action.list_actions.__name__:
+            if (
+                route.name == action.create_actions.__name__
+                or route.name == action.list_actions.__name__
+            ):
                 route.dependencies.append(Depends(verfiy_token))
             else:
                 route.dependencies.append(verify_action_depends)
 
-        self.__append_deps_for_all_routes(assistant_file.router, verify_assistant_depends)
+        self.__append_deps_for_all_routes(
+            assistant_file.router, verify_assistant_depends
+        )
         self.__append_deps_for_all_routes(message.router, verify_thread_depends)
         self.__append_deps_for_all_routes(runs.router, verify_thread_depends)
 
-    def insert_token_rel(self, session, token_id: str, relation_type: RelationType, relation_id: str):
+    def insert_token_rel(
+        self, session, token_id: str, relation_type: RelationType, relation_id: str
+    ):
         if token_id:
-            relation = TokenRelation(token_id=token_id, relation_type=relation_type, relation_id=str(relation_id))
+            relation = TokenRelation(
+                token_id=token_id,
+                relation_type=relation_type,
+                relation_id=str(relation_id),
+            )
             session.add(relation)
 
-    async def delete_token_rel(self, session, relation_type: RelationType, relation_id: str):
-        to_delete = TokenRelationDelete(relation_type=relation_type, relation_id=relation_id)
-        relation = await TokenRelationService.get_relation_to_delete(session=session, delete=to_delete)
+    async def delete_token_rel(
+        self, session, relation_type: RelationType, relation_id: str
+    ):
+        to_delete = TokenRelationDelete(
+            relation_type=relation_type, relation_id=relation_id
+        )
+        relation = await TokenRelationService.get_relation_to_delete(
+            session=session, delete=to_delete
+        )
         await session.delete(relation)
 
-    def token_filter(self, statement, field, relation_type: RelationType, token_id: str):
+    def token_filter(
+        self, statement, field, relation_type: RelationType, token_id: str
+    ):
         id_subquery = select(TokenRelation.relation_id).where(
-            TokenRelation.relation_type == relation_type, TokenRelation.token_id == token_id
+            TokenRelation.relation_type == relation_type,
+            TokenRelation.token_id == token_id,
         )
         return statement.where(field.in_(id_subquery))
 
@@ -104,7 +141,9 @@ class SimpleTokenAuthPolicy(AuthPolicy):
             route.dependencies.append(depends)
 
 
-auth_policy: AuthPolicy = SimpleTokenAuthPolicy() if settings.AUTH_ENABLE else AuthPolicy()
+auth_policy: AuthPolicy = (
+    AuthPolicy() if settings.AUTH_ENABLE else AuthPolicy()
+)  ##SimpleTokenAuthPolicy()
 
 
 def register(app):