123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266 |
- import json
- import logging
- from typing import Any, AsyncGenerator
- from uuid import UUID
- from core.base import (
- AsyncState,
- CompletionProvider,
- DatabaseProvider,
- EmbeddingProvider,
- )
- from core.base.abstractions import (
- GraphSearchResult,
- GraphSearchSettings,
- KGCommunityResult,
- KGEntityResult,
- KGRelationshipResult,
- KGSearchResultType,
- SearchSettings,
- )
- from ..abstractions.generator_pipe import GeneratorPipe
- logger = logging.getLogger()
- class GraphSearchSearchPipe(GeneratorPipe):
- """
- Embeds and stores documents using a specified embedding model and database.
- """
- def __init__(
- self,
- llm_provider: CompletionProvider,
- database_provider: DatabaseProvider,
- embedding_provider: EmbeddingProvider,
- config: GeneratorPipe.PipeConfig,
- *args,
- **kwargs,
- ):
- """
- Initializes the embedding pipe with necessary components and configurations.
- """
- super().__init__(
- llm_provider,
- database_provider,
- config,
- *args,
- **kwargs,
- )
- self.database_provider = database_provider
- self.llm_provider = llm_provider
- self.embedding_provider = embedding_provider
- self.pipe_run_info = None
- def filter_responses(self, map_responses):
- filtered_responses = []
- for response in map_responses:
- try:
- parsed_response = json.loads(response)
- for item in parsed_response["points"]:
- try:
- if item["score"] > 0:
- filtered_responses.append(item)
- except KeyError:
- # Skip this item if it doesn't have a 'score' key
- logger.warning(f"Item in response missing 'score' key")
- continue
- except json.JSONDecodeError:
- logger.warning(
- f"Response is not valid JSON: {response[:100]}..."
- )
- continue
- except KeyError:
- logger.warning(
- f"Response is missing 'points' key: {response[:100]}..."
- )
- continue
- filtered_responses = sorted(
- filtered_responses, key=lambda x: x["score"], reverse=True
- )
- responses = "\n".join(
- [
- response.get("description", "")
- for response in filtered_responses
- ]
- )
- return responses
- async def search(
- self,
- input: GeneratorPipe.Input,
- state: AsyncState,
- run_id: UUID,
- search_settings: SearchSettings,
- *args: Any,
- **kwargs: Any,
- ) -> AsyncGenerator[GraphSearchResult, None]:
- if search_settings.graph_settings.enabled == False:
- return
- async for message in input.message:
- query_embedding = (
- await self.embedding_provider.async_get_embedding(message)
- )
- # entity search
- search_type = "entities"
- base_limit = search_settings.limit
- if search_type not in search_settings.graph_settings.limits:
- logger.warning(
- f"No limit set for graph search type {search_type}, defaulting to global settings limit of {base_limit}"
- )
- async for search_result in self.database_provider.graphs_handler.graph_search( # type: ignore
- message,
- search_type=search_type,
- limit=search_settings.graph_settings.limits.get(
- search_type, base_limit
- ),
- query_embedding=query_embedding,
- property_names=[
- "name",
- "description",
- "chunk_ids",
- ],
- filters=search_settings.filters,
- ):
- yield GraphSearchResult(
- content=KGEntityResult(
- name=search_result["name"],
- description=search_result["description"],
- ),
- result_type=KGSearchResultType.ENTITY,
- score=(
- search_result["similarity_score"]
- if search_settings.include_scores
- else None
- ),
- # chunk_ids=search_result["chunk_ids"],
- metadata=(
- {
- "associated_query": message,
- **(search_result["metadata"] or {}),
- }
- if search_settings.include_metadatas
- else None
- ),
- )
- # # relationship search
- # # disabled for now. We will check evaluations and see if we need it
- search_type = "relationships"
- if search_type not in search_settings.graph_settings.limits:
- logger.warning(
- f"No limit set for graph search type {search_type}, defaulting to global settings limit of {base_limit}"
- )
- async for search_result in self.database_provider.graphs_handler.graph_search( # type: ignore
- input,
- search_type=search_type,
- limit=search_settings.graph_settings.limits.get(
- search_type, base_limit
- ),
- query_embedding=query_embedding,
- property_names=[
- # "name",
- "subject",
- "predicate",
- "object",
- # "name",
- "description",
- # "chunk_ids",
- # "document_ids",
- ],
- ):
- try:
- # TODO - remove this nasty hack
- search_result["metadata"] = json.loads(
- search_result["metadata"]
- )
- except:
- pass
- yield GraphSearchResult(
- content=KGRelationshipResult(
- # name=search_result["name"],
- subject=search_result["subject"],
- predicate=search_result["predicate"],
- object=search_result["object"],
- description=search_result["description"],
- ),
- result_type=KGSearchResultType.RELATIONSHIP,
- score=(
- search_result["similarity_score"]
- if search_settings.include_scores
- else None
- ),
- # chunk_ids=search_result["chunk_ids"],
- # document_ids=search_result["document_ids"],
- metadata=(
- {
- "associated_query": message,
- **(search_result["metadata"] or {}),
- }
- if search_settings.include_metadatas
- else None
- ),
- )
- # community search
- search_type = "communities"
- async for search_result in self.database_provider.graphs_handler.graph_search( # type: ignore
- message,
- search_type=search_type,
- limit=search_settings.graph_settings.limits.get(
- search_type, base_limit
- ),
- # embedding_type="embedding",
- query_embedding=query_embedding,
- property_names=[
- "community_id",
- "name",
- "findings",
- "rating",
- "rating_explanation",
- "summary",
- ],
- filters=search_settings.filters,
- ):
- yield GraphSearchResult(
- content=KGCommunityResult(
- name=search_result["name"],
- summary=search_result["summary"],
- rating=search_result["rating"],
- rating_explanation=search_result["rating_explanation"],
- findings=search_result["findings"],
- ),
- result_type=KGSearchResultType.COMMUNITY,
- metadata=(
- {
- "associated_query": message,
- **(search_result["metadata"] or {}),
- }
- if search_settings.include_metadatas
- else None
- ),
- score=(
- search_result["similarity_score"]
- if search_settings.include_scores
- else None
- ),
- )
- async def _run_logic( # type: ignore
- self,
- input: GeneratorPipe.Input,
- state: AsyncState,
- run_id: UUID,
- search_settings: GraphSearchSettings,
- *args: Any,
- **kwargs: Any,
- ) -> AsyncGenerator[GraphSearchResult, None]:
- async for result in self.search(input, state, run_id, search_settings):
- yield result
|