1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
import asyncio
import logging
from typing import Any, Optional

import nest_asyncio
import pytest
from r2r import R2RClient

from app.libs.util import verify_jwt_expiration
from config.llm import tool_settings
# Apply nest_asyncio so an event loop that is already running can be re-entered
# (nested event loops). This runs at import time as a module-level side effect;
# the R2R class below calls asyncio.run() from synchronous code.
nest_asyncio.apply()
class R2R:
    """Synchronous wrapper around an ``R2RClient`` that (re-)logs in lazily.

    Call :meth:`init` once before use. When both ``tool_settings.R2R_USERNAME``
    and ``tool_settings.R2R_PASSWORD`` are set, the wrapper logs in and repeats
    the login whenever the current access token fails
    ``verify_jwt_expiration``; otherwise the client is used unauthenticated.
    """

    _logger = logging.getLogger(__name__)

    # None until the first login routine has run (see _login).
    client: Optional[R2RClient] = None
    # Login task scheduled on an already-running loop. Holding a strong reference
    # keeps it from being garbage-collected and lets concurrent checks reuse it.
    _login_task: Optional[asyncio.Task] = None

    def init(self):
        """Reset the wrapper state and run the initial login.

        Returns:
            ``None`` when the login ran to completion synchronously, or the
            scheduled ``asyncio.Task`` when called while an event loop is
            already running in the current thread.
        """
        self.auth_enabled = bool(
            tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
        )
        self.client = None
        self._login_task = None
        return self._run_login()

    def ingest_file(self, file_path: str, metadata: Optional[dict] = None):
        """Upload ``file_path`` to R2R as a new document.

        Args:
            file_path: Path of the file to ingest.
            metadata: Optional document metadata; an empty dict is sent as ``None``.

        Returns:
            The ``"results"`` entry of the ingestion response (``None`` if absent).

        Raises:
            RuntimeError: if the client is still not available after the login check.
        """
        client = self._ready_client()
        ingest_response = client.documents.create(
            file_path=file_path, metadata=metadata or None, id=None
        )
        return ingest_response.get("results")

    def search(self, query: str, filters: dict[str, Any]):
        """Run a retrieval search and return its chunk results.

        Args:
            query: Search text.
            filters: R2R filter expression passed through unchanged.

        Returns:
            ``results["chunk_search_results"]`` from the search response.

        Raises:
            RuntimeError: if the client is still not available after the login check.
        """
        client = self._ready_client()
        limit = tool_settings.R2R_SEARCH_LIMIT
        self._logger.debug("R2R search filters=%s limit=%s", filters, limit)
        search_response = client.retrieval.search(
            query=query,
            search_settings={
                "filters": filters,
                "limit": limit,
            },
        )
        self._logger.debug("R2R search response: %s", search_response)
        return search_response.get("results").get("chunk_search_results")

    def _ready_client(self):
        """Run the login check, then return the client or fail with a clear error."""
        self._check_login()
        if self.client is None:
            # Only reachable when the login had to be scheduled on a running loop
            # and has not executed yet.
            raise RuntimeError("R2R client is not ready yet; login is still in progress")
        return self.client

    async def _login(self):
        """Create the client if needed and, when auth is enabled, log in.

        The client is created before the auth check so that the unauthenticated
        mode also ends up with a usable client.
        """
        if self.client is None:
            self.client = R2RClient(tool_settings.R2R_BASE_URL)
        if not self.auth_enabled:
            return
        self._logger.debug("Logging in to R2R with client %r", self.client)
        # NOTE(review): users.login is awaited while the other client calls are
        # synchronous; this assumes login is a coroutine in the installed SDK.
        await self.client.users.login(
            tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
        )

    def _run_login(self):
        """Run :meth:`_login` now, or schedule it if a loop is already running.

        Returns:
            ``None`` after a synchronous login, otherwise the pending task.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: drive the coroutine to completion.
            return asyncio.run(self._login())
        # A loop is running here, so block-free scheduling is used instead.
        # Reuse an in-flight login on the same loop rather than starting another.
        task = self._login_task
        if task is None or task.done() or task.get_loop() is not loop:
            task = loop.create_task(self._login())
            self._login_task = task
        return task

    def _check_login(self):
        """Ensure a client exists and, when auth is enabled, that its token is valid.

        Returns:
            ``None`` when nothing had to be done or the login completed
            synchronously, otherwise the pending login task.
        """
        if self.client is None:
            return self._run_login()
        if not self.auth_enabled:
            return None
        # NOTE(review): assumes verify_jwt_expiration() is truthy while the token
        # is still valid — confirm against app.libs.util.
        if verify_jwt_expiration(self.client.access_token):
            return None
        return self._run_login()
# Shared module-level instance, created and initialised at import time.
r2r = R2R()
# Run the login routine now (or schedule it if an event loop is already running).
r2r.init()
|