action.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. import copy
  2. from typing import Dict, List
  3. from sqlmodel import select
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from app.exceptions.exception import ResourceNotFoundError, ValidateFailedError
  6. from app.models.action import Action
  7. from app.models.token_relation import RelationType
  8. from app.providers.auth_provider import auth_policy
  9. from app.schemas.common import DeleteResponse
  10. from app.schemas.tool.action import (
  11. ActionBulkCreateRequest,
  12. ActionUpdateRequest,
  13. ActionMethod,
  14. ActionBodyType,
  15. )
  16. from app.schemas.tool.authentication import Authentication
  17. from app.services.tool.openapi_call import call_action_api
  18. from app.services.tool.openapi_utils import (
  19. split_openapi_schema,
  20. replace_openapi_refs,
  21. function_name,
  22. extract_params,
  23. build_function_def,
  24. action_param_schema_to_dict,
  25. action_param_dict_to_schema,
  26. )
  27. class ActionService:
  28. @staticmethod
  29. async def create_actions(
  30. *, session: AsyncSession, body: ActionBulkCreateRequest, token_id: str = None
  31. ) -> List[Action]:
  32. openapi_schema = replace_openapi_refs(body.openapi_schema)
  33. schemas = split_openapi_schema(openapi_schema)
  34. if not schemas:
  35. raise ValidateFailedError("Failed to parse OpenAPI schema")
  36. if not body.authentication.is_encrypted():
  37. raise Exception("Authentication must be encrypted")
  38. actions = []
  39. for schema in schemas:
  40. action = ActionService.build_action_struct(schema)
  41. action.authentication = body.authentication.dict()
  42. action.use_for_everyone = body.use_for_everyone
  43. actions.append(action)
  44. auth_policy.insert_token_rel(
  45. session=session, token_id=token_id, relation_type=RelationType.Action, relation_id=str(action.id)
  46. )
  47. session.add_all(actions)
  48. await session.commit()
  49. for action in actions:
  50. await session.refresh(action)
  51. return actions
  52. @staticmethod
  53. def create_actions_sync(
  54. *, session: AsyncSession, body: ActionBulkCreateRequest, token_id: str = None
  55. ) -> List[Action]:
  56. openapi_schema = replace_openapi_refs(body.openapi_schema)
  57. schemas = split_openapi_schema(openapi_schema)
  58. if not schemas:
  59. raise ValidateFailedError("Failed to parse OpenAPI schema")
  60. if not body.authentication.is_encrypted():
  61. raise Exception("Authentication must be encrypted")
  62. actions = []
  63. for schema in schemas:
  64. action = ActionService.build_action_struct(schema)
  65. action.authentication = body.authentication.dict()
  66. action.use_for_everyone = body.use_for_everyone
  67. actions.append(action)
  68. auth_policy.insert_token_rel(
  69. session=session, token_id=token_id, relation_type=RelationType.Action, relation_id=str(action.id)
  70. )
  71. session.add_all(actions)
  72. session.commit()
  73. for action in actions:
  74. session.refresh(action)
  75. return actions
  76. @staticmethod
  77. async def modify_action(*, session: AsyncSession, action_id: str, body: ActionUpdateRequest) -> Action:
  78. db_action = await ActionService.get_action(session=session, action_id=action_id)
  79. update_dict = {}
  80. if body.openapi_schema is not None:
  81. openapi_schema = replace_openapi_refs(body.openapi_schema)
  82. action: Action = ActionService.build_action_struct(openapi_schema)
  83. update_dict["openapi_schema"] = action.openapi_schema
  84. update_dict["name"] = action.name
  85. update_dict["description"] = action.description
  86. update_dict["operation_id"] = action.operation_id
  87. update_dict["url"] = action.url
  88. update_dict["method"] = action.method
  89. update_dict["path_param_schema"] = action.path_param_schema
  90. update_dict["query_param_schema"] = action.query_param_schema
  91. update_dict["body_type"] = action.body_type
  92. update_dict["body_param_schema"] = action.body_param_schema
  93. update_dict["function_def"] = action.function_def
  94. if body.authentication is not None:
  95. update_dict["authentication"] = body.authentication.dict()
  96. if body.use_for_everyone is not None:
  97. update_dict["use_for_everyone"] = body.use_for_everyone
  98. for key, value in update_dict.items():
  99. setattr(db_action, key, value)
  100. session.add(db_action)
  101. await session.commit()
  102. await session.refresh(db_action)
  103. return db_action
  104. @staticmethod
  105. async def delete_action(*, session: AsyncSession, action_id: str) -> DeleteResponse:
  106. action = await ActionService.get_action(session=session, action_id=action_id)
  107. await session.delete(action)
  108. await auth_policy.delete_token_rel(session=session, relation_type=RelationType.Action, relation_id=action_id)
  109. await session.commit()
  110. return DeleteResponse(id=action_id, object="action.deleted", deleted=True)
  111. @staticmethod
  112. async def get_action(*, session: AsyncSession, action_id: str) -> Action:
  113. statement = select(Action).where(Action.id == action_id)
  114. result = await session.execute(statement)
  115. action = result.scalars().one_or_none()
  116. if action is None:
  117. raise ResourceNotFoundError(message="action not found")
  118. return action
  119. @staticmethod
  120. def build_action_struct(
  121. openapi_schema: Dict,
  122. ) -> Action:
  123. """
  124. Extract action components from OpenAPI schema.
  125. :param openapi_schema: a dict of OpenAPI schema
  126. :return: an Action including all the components of an action
  127. """
  128. # copy openapi_schema to avoid modifying the original
  129. openapi_dict = copy.deepcopy(openapi_schema)
  130. # extract the first path and method
  131. path, path_info = next(iter(openapi_dict["paths"].items()))
  132. method, method_info = next(iter(path_info.items()))
  133. # check operationId
  134. operation_id = method_info.get("operationId", None)
  135. # get function name
  136. name = function_name(method, path, operation_id)
  137. method = ActionMethod(method.upper())
  138. # extract description
  139. description = method_info.get("description", "")
  140. if not description:
  141. # use other fields to generate description
  142. summary = method_info.get("summary", "")
  143. description = f"{method.upper()} {path}: {summary}"
  144. # build function parameters schema
  145. url, path_param_schema, query_param_schema, body_type, body_param_schema = extract_params(
  146. openapi_dict, method, path
  147. )
  148. # build function definition
  149. function_def = build_function_def(
  150. name=name,
  151. description=description,
  152. path_param_schema=path_param_schema,
  153. query_param_schema=query_param_schema,
  154. body_param_schema=body_param_schema,
  155. )
  156. return Action.model_validate(
  157. {
  158. "name": name,
  159. "description": description,
  160. "operation_id": operation_id,
  161. "url": url,
  162. "method": method,
  163. "path_param_schema": action_param_schema_to_dict(path_param_schema),
  164. "query_param_schema": action_param_schema_to_dict(query_param_schema),
  165. "body_type": body_type,
  166. "body_param_schema": action_param_schema_to_dict(body_param_schema),
  167. "function_def": function_def.dict(exclude_none=True),
  168. "openapi_schema": openapi_dict,
  169. }
  170. )
  171. @staticmethod
  172. async def run_action(
  173. *,
  174. session: AsyncSession,
  175. action_id: str,
  176. parameters: Dict,
  177. headers: Dict,
  178. ) -> Dict:
  179. """
  180. Run an action
  181. :param action_id: the action ID
  182. :param parameters: the parameters for the API call
  183. :param headers: the headers for the API call
  184. :return: the response of the API call
  185. """
  186. action: Action = await ActionService.get_action(session=session, action_id=action_id)
  187. response = call_action_api(
  188. url=action.url,
  189. method=ActionMethod(action.method),
  190. path_param_schema=action_param_dict_to_schema(action.path_param_schema),
  191. query_param_schema=action_param_dict_to_schema(action.query_param_schema),
  192. body_param_schema=action_param_dict_to_schema(action.body_param_schema),
  193. body_type=ActionBodyType(action.body_type),
  194. parameters=parameters,
  195. headers=headers,
  196. authentication=Authentication(**action.authentication),
  197. )
  198. return response