memory.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. """
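  2. Because LLMs limit context length and cost has to be controlled, the context must be filtered and consolidated; this module describes the strategies used for that.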
  3. """
  4. from enum import Enum
  5. from typing import List
  6. from abc import ABC, abstractmethod
  7. class MemoryType(str, Enum):
  8. """
  9. support 3 kind of context memory
  10. """
  11. WINDOW = "window"
  12. ZERO = "zero"
  13. NAIVE = "naive"
  14. class Memory(ABC):
  15. """
  16. interface for context memory
  17. """
  18. @abstractmethod
  19. def integrate_context(self, messages: List[dict]) -> List[dict]:
  20. """
  21. integrate context according to the memory
  22. """
  23. class WindowMemory(Memory):
  24. """
  25. limit the context length to a fixed window size
  26. """
  27. def __init__(self, window_size: int = 20, max_token_size: int = 4000):
  28. if window_size < 1 or max_token_size < 1:
  29. raise ValueError("window size and max token size should be greater than 0")
  30. self.window_size = window_size
  31. self.max_token_size = max_token_size
  32. def integrate_context(self, messages: List[dict]) -> List[dict]:
  33. if not messages:
  34. return []
  35. histories = messages[-self.window_size :]
  36. total_length = 0
  37. for i, message in enumerate(reversed(histories)):
  38. total_length += len(message["content"])
  39. if total_length >= self.max_token_size:
  40. return histories[len(histories) - i - 1 :]
  41. return histories
  42. class NaiveMemory(Memory):
  43. """
  44. navie memory contains all the context
  45. """
  46. def integrate_context(self, messages: List[dict]) -> List[dict]:
  47. return messages
  48. class ZeroMemory(Memory):
  49. """
  50. zero memory contains last user message
  51. """
  52. def integrate_context(self, messages: List[dict]) -> List[dict]:
  53. if not messages:
  54. return []
  55. for i, message in enumerate(reversed(messages)):
  56. if message["role"] == "user":
  57. return messages[len(messages) - i - 1 :]
  58. Memories = {MemoryType.WINDOW: WindowMemory, MemoryType.ZERO: ZeroMemory, MemoryType.NAIVE: NaiveMemory}
  59. def find_memory(assistants_metadata: dict) -> Memory:
  60. memory_type = assistants_metadata.get("type", MemoryType.NAIVE)
  61. kwargs = assistants_metadata.copy()
  62. kwargs.pop("type", None)
  63. if kwargs:
  64. return Memories[memory_type](**kwargs)
  65. else:
  66. return Memories[memory_type]()