paginate.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import logging
  2. from typing import TypeVar, Any, Optional, Generic, List, Sequence
  3. from fastapi import Query
  4. from fastapi_pagination.bases import AbstractPage, AbstractParams, CursorRawParams
  5. from fastapi_pagination.cursor import encode_cursor
  6. from fastapi_pagination.ext.sqlmodel import paginate
  7. from fastapi_pagination.types import Cursor
  8. from fastapi_pagination.utils import verify_params, create_pydantic_model
  9. from sqlmodel import asc, desc
  10. from sqlalchemy.ext.asyncio import AsyncSession
  11. from app.models.base_model import BaseModel
# Type variable for the items carried by a page. Bound to the app's BaseModel;
# items are expected to expose an ``.id`` attribute (read by CommonPage.create).
ModelType = TypeVar("ModelType", bound=BaseModel)
  13. class CursorParams(BaseModel, AbstractParams):
  14. limit: int = Query(20, ge=1, le=100, description="Page offset")
  15. order: str = Query(default="desc", description="Sort order")
  16. after: Optional[str] = Query(None, description="Page after")
  17. before: Optional[str] = Query(None, description="Page before")
  18. def to_raw_params(self) -> CursorRawParams:
  19. return CursorRawParams(cursor=None, size=self.limit, include_total=True)
  20. class CommonPage(AbstractPage[ModelType], Generic[ModelType]):
  21. __params_type__ = CursorParams
  22. object: str = "list"
  23. data: List[ModelType] = []
  24. first_id: Optional[str] = ""
  25. last_id: Optional[str] = ""
  26. has_more: bool = False
  27. @classmethod
  28. def create(
  29. cls,
  30. items: Sequence[ModelType],
  31. params: CursorParams,
  32. *,
  33. current: Optional[Cursor] = None,
  34. current_backwards: Optional[Cursor] = None,
  35. next_: Optional[Cursor] = None,
  36. previous: Optional[Cursor] = None,
  37. **kwargs: Any,
  38. ) -> AbstractPage[ModelType]:
  39. next_page = encode_cursor(next_)
  40. return create_pydantic_model(
  41. CommonPage,
  42. next_page=next_page,
  43. first_id=items[0].id if items else None,
  44. last_id=items[len(items) - 1].id if items else None,
  45. has_more=False if next_page is None else True,
  46. data=list(items),
  47. )
  48. async def cursor_page(query: Any, db: AsyncSession) -> CommonPage[ModelType]:
  49. params, _ = verify_params(None, "cursor")
  50. model = query._propagate_attrs["plugin_subject"].class_
  51. logging.debug(
  52. f"Page model={model}, sort={params.order}, filter_parameters=before:{params.before}, after:{params.after}",
  53. )
  54. if params.before is not None:
  55. if params.order.upper() == "DESC":
  56. query = query.where(model.id > params.before)
  57. else:
  58. query = query.where(model.id < params.before)
  59. if params.after is not None:
  60. if params.order.upper() == "DESC":
  61. query = query.where(model.id < params.after)
  62. else:
  63. query = query.where(model.id > params.after)
  64. if params.order.upper() == "DESC":
  65. query = query.order_by(desc(model.__dict__["created_at"]))
  66. else:
  67. query = query.order_by(asc(model.__dict__["created_at"]))
  68. return await paginate(db, query)