123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- from typing import Optional, Any
- import pytest
- from r2r import R2RClient
- from app.libs.util import verify_jwt_expiration
- from config.llm import tool_settings
- import nest_asyncio
- import asyncio
- nest_asyncio.apply()
class R2R:
    """Thin wrapper around ``R2RClient`` for document ingestion and retrieval.

    Await ``init()`` once before calling ``ingest_file`` or ``search``.
    Authentication is enabled only when both ``tool_settings.R2R_USERNAME`` and
    ``tool_settings.R2R_PASSWORD`` are non-empty. In that case every call first
    re-checks the access token and logs in again if the check fails.
    """

    # Quoted so the class body does not evaluate the name at definition time.
    client: "R2RClient"

    async def init(self):
        """Create the client, logging in when credentials are configured.

        Kept as a coroutine for compatibility with existing
        ``asyncio.run(r2r.init())`` callers; nothing inside needs awaiting.
        """
        self.auth_enabled = bool(
            tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
        )
        if self.auth_enabled:
            self._login()
        else:
            # No login needed, but search/ingest still require a client;
            # previously ``self.client`` was never assigned on this path.
            self.client = R2RClient(tool_settings.R2R_BASE_URL)

    def ingest_file(self, file_path: str, metadata: Optional[dict] = None):
        """Upload ``file_path`` as a new document.

        Args:
            file_path: Path of the file to ingest.
            metadata: Optional document metadata. Falsy values (including an
                empty dict) are sent as ``None``.

        Returns:
            The ``"results"`` entry of the ingestion response.
        """
        self._check_login()
        ingest_response = self.client.documents.create(
            file_path=file_path, metadata=metadata if metadata else None, id=None
        )
        return ingest_response.get("results")

    def search(self, query: str, filters: dict[str, Any], limit: int = 10):
        """Run a retrieval search and return the chunk-level hits.

        Args:
            query: Search text.
            filters: Filter expression, passed through unchanged as
                ``search_settings["filters"]``.
            limit: Maximum number of results to request (default 10).

        Returns:
            ``response["results"]["chunk_search_results"]``.
        """
        self._check_login()
        search_response = self.client.retrieval.search(
            query=query,
            search_settings={"filters": filters, "limit": limit},
        )
        return search_response.get("results").get("chunk_search_results")

    def _login(self):
        """Create a fresh client and log in.

        Returns:
            The logged-in client, or ``None`` when authentication is disabled.
        """
        if not self.auth_enabled:
            return None
        client = R2RClient(tool_settings.R2R_BASE_URL)
        client.users.login(tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD)
        # Publish the client only after login succeeded.
        self.client = client
        return self.client

    def _check_login(self):
        """Log in again if auth is enabled and the current token fails the check."""
        if not self.auth_enabled:
            return
        # NOTE(review): assumes verify_jwt_expiration returns truthy while the
        # token is still usable — confirm against app.libs.util.
        if verify_jwt_expiration(self.client.access_token):
            return
        # _login is synchronous. The former ``async def`` version was called
        # here without ``await``, so the re-login coroutine never actually ran.
        self._login()
# Shared module-level wrapper instance.
r2r = R2R()


def run_async():
    """Initialise ``r2r`` and print the chunk hits of a sample search.

    NOTE(review): this helper is invoked unconditionally at import time below,
    so importing the module runs ``r2r.init()`` (likely contacting the R2R
    server) and prints to stdout. Consider moving the call under
    ``if __name__ == "__main__":`` once no importer relies on that.
    """
    asyncio.run(r2r.init())  # drive the async initialiser to completion
    sample_query = "文档内容"  # "document content"
    sample_filters = {
        "file_key": {"$in": ["5a098a2c-e1ef-40b0-84a0-d6606261fee0-test.txt"]}
    }
    print(r2r.search(sample_query, sample_filters))


run_async()
|