xlsx_parser.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # type: ignore
  2. from io import BytesIO
  3. from typing import AsyncGenerator
  4. from core.base.parsers.base_parser import AsyncParser
  5. from core.base.providers import (
  6. CompletionProvider,
  7. DatabaseProvider,
  8. IngestionConfig,
  9. )
  10. class XLSXParser(AsyncParser[str | bytes]):
  11. """A parser for XLSX data."""
  12. def __init__(
  13. self,
  14. config: IngestionConfig,
  15. database_provider: DatabaseProvider,
  16. llm_provider: CompletionProvider,
  17. ):
  18. self.database_provider = database_provider
  19. self.llm_provider = llm_provider
  20. self.config = config
  21. try:
  22. from openpyxl import load_workbook
  23. self.load_workbook = load_workbook
  24. except ImportError:
  25. raise ValueError(
  26. "Error, `openpyxl` is required to run `XLSXParser`. Please install it using `pip install openpyxl`."
  27. )
  28. async def ingest(
  29. self, data: bytes, *args, **kwargs
  30. ) -> AsyncGenerator[str, None]:
  31. """Ingest XLSX data and yield text from each row."""
  32. if isinstance(data, str):
  33. raise ValueError("XLSX data must be in bytes format.")
  34. wb = self.load_workbook(filename=BytesIO(data))
  35. for sheet in wb.worksheets:
  36. for row in sheet.iter_rows(values_only=True):
  37. yield ", ".join(map(str, row))
  38. class XLSXParserAdvanced(AsyncParser[str | bytes]):
  39. """A parser for XLSX data."""
  40. # identifies connected components in the excel graph and extracts data from each component
  41. def __init__(
  42. self, config: IngestionConfig, llm_provider: CompletionProvider
  43. ):
  44. self.llm_provider = llm_provider
  45. self.config = config
  46. try:
  47. import networkx as nx
  48. import numpy as np
  49. from openpyxl import load_workbook
  50. self.nx = nx
  51. self.np = np
  52. self.load_workbook = load_workbook
  53. except ImportError:
  54. raise ValueError(
  55. "Error, `networkx` and `numpy` are required to run `XLSXParserAdvanced`. Please install them using `pip install networkx numpy`."
  56. )
  57. def connected_components(self, arr):
  58. g = self.nx.grid_2d_graph(len(arr), len(arr[0]))
  59. empty_cell_indices = list(zip(*self.np.where(arr is None)))
  60. g.remove_nodes_from(empty_cell_indices)
  61. components = self.nx.connected_components(g)
  62. for component in components:
  63. rows, cols = zip(*component)
  64. min_row, max_row = min(rows), max(rows)
  65. min_col, max_col = min(cols), max(cols)
  66. yield arr[min_row : max_row + 1, min_col : max_col + 1].astype(
  67. "str"
  68. )
  69. async def ingest(
  70. self, data: bytes, num_col_times_num_rows: int = 100, *args, **kwargs
  71. ) -> AsyncGenerator[str, None]:
  72. """Ingest XLSX data and yield text from each connected component."""
  73. if isinstance(data, str):
  74. raise ValueError("XLSX data must be in bytes format.")
  75. workbook = self.load_workbook(filename=BytesIO(data))
  76. for ws in workbook.worksheets:
  77. ws_data = self.np.array(
  78. [[cell.value for cell in row] for row in ws.iter_rows()]
  79. )
  80. for table in self.connected_components(ws_data):
  81. # parse like a csv parser, assumes that the first row has column names
  82. if len(table) <= 1:
  83. continue
  84. num_rows = len(table)
  85. num_rows_per_chunk = num_col_times_num_rows // num_rows
  86. headers = ", ".join(table[0])
  87. # add header to each one
  88. for i in range(1, num_rows, num_rows_per_chunk):
  89. chunk = table[i : i + num_rows_per_chunk]
  90. yield headers + "\n" + "\n".join(
  91. [", ".join(row) for row in chunk]
  92. )