# markdown.py

import re
from typing import Any, Generator, Iterable, Optional, Tuple

from flupy import flu

from .base import AdapterContext, AdapterStep
  5. class MarkdownChunker(AdapterStep):
  6. """
  7. MarkdownChunker is an AdapterStep that splits a markdown string into chunks where a heading signifies the start of a chunk, and yields each chunk as a separate record.
  8. """
  9. def __init__(self, *, skip_during_query: bool):
  10. """
  11. Initializes the MarkdownChunker adapter.
  12. Args:
  13. skip_during_query (bool): Whether to skip chunking during querying.
  14. """
  15. self.skip_during_query = skip_during_query
  16. @staticmethod
  17. def split_by_heading(
  18. md: str, max_tokens: int
  19. ) -> Generator[str, None, None]:
  20. regex_split = r"^(#{1,6}\s+.+)$"
  21. headings = [
  22. match.span()[0]
  23. for match in re.finditer(regex_split, md, flags=re.MULTILINE)
  24. ]
  25. if headings == [] or headings[0] != 0:
  26. headings.insert(0, 0)
  27. sections = [md[i:j] for i, j in zip(headings, headings[1:] + [None])]
  28. for section in sections:
  29. chunks = flu(section.split(" ")).chunk(max_tokens)
  30. is_not_useless_chunk = lambda i: not i in [
  31. "",
  32. "\n",
  33. [],
  34. ] # noqa: E731, E713
  35. joined_chunks = filter(
  36. is_not_useless_chunk,
  37. [" ".join(chunk) for chunk in chunks], # noqa: E731, E713
  38. )
  39. for joined_chunk in joined_chunks:
  40. yield joined_chunk
  41. def __call__(
  42. self,
  43. records: Iterable[Tuple[str, Any, Optional[dict]]],
  44. adapter_context: AdapterContext,
  45. max_tokens: int = 99999999,
  46. ) -> Generator[Tuple[str, Any, dict], None, None]:
  47. """
  48. Splits each markdown string in the records into chunks where each heading starts a new chunk, and yields each chunk
  49. as a separate record. If the `skip_during_query` attribute is set to True,
  50. this step is skipped during querying.
  51. Args:
  52. records (Iterable[Tuple[str, Any, Optional[dict]]]): Iterable of tuples each containing an id, a markdown string and an optional dict.
  53. adapter_context (AdapterContext): Context of the adapter.
  54. max_tokens (int): The maximum number of tokens per chunk
  55. Yields:
  56. Tuple[str, Any, dict]: The id appended with chunk index, the chunk, and the metadata.
  57. """
  58. if max_tokens and max_tokens < 1:
  59. raise ValueError("max_tokens must be a nonzero positive integer")
  60. if (
  61. adapter_context == AdapterContext("query")
  62. and self.skip_during_query
  63. ):
  64. for id, markdown, metadata in records:
  65. yield (id, markdown, metadata or {})
  66. else:
  67. for id, markdown, metadata in records:
  68. headings = MarkdownChunker.split_by_heading(
  69. markdown, max_tokens
  70. )
  71. for heading_ix, heading in enumerate(headings):
  72. yield (
  73. f"{id}_head_{str(heading_ix).zfill(3)}",
  74. heading,
  75. metadata or {},
  76. )