xlsx_parser.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # type: ignore
  2. from io import BytesIO
  3. from typing import AsyncGenerator
  4. import networkx as nx
  5. import numpy as np
  6. from openpyxl import load_workbook
  7. from core.base.parsers.base_parser import AsyncParser
  8. from core.base.providers import (
  9. CompletionProvider,
  10. DatabaseProvider,
  11. IngestionConfig,
  12. )
  class XLSXParser(AsyncParser[str | bytes]):
      """A parser for XLSX data.

      Produces one comma-separated line of text per worksheet row, walking
      every worksheet of the workbook in order.
      """

      def __init__(
          self,
          config: IngestionConfig,
          database_provider: DatabaseProvider,
          llm_provider: CompletionProvider,
      ):
          self.config = config
          self.database_provider = database_provider
          self.llm_provider = llm_provider
          # The workbook loader is kept as an instance attribute; presumably
          # so it can be replaced per instance — confirm with callers.
          self.load_workbook = load_workbook

      async def ingest(
          self, data: bytes, *args, **kwargs
      ) -> AsyncGenerator[str, None]:
          """Ingest XLSX data and yield text from each row.

          Args:
              data: Raw bytes of an ``.xlsx`` workbook.

          Yields:
              For each row of each worksheet, the row's values converted with
              ``str`` (so a ``None`` value becomes the text ``"None"``) and
              joined with ``", "``.

          Raises:
              ValueError: If ``data`` is a ``str``. Because this is an async
                  generator, the error surfaces when iteration starts.
          """
          if isinstance(data, str):
              raise ValueError("XLSX data must be in bytes format.")

          workbook = self.load_workbook(filename=BytesIO(data))
          for worksheet in workbook.worksheets:
              for values in worksheet.iter_rows(values_only=True):
                  cells = [str(value) for value in values]
                  yield ", ".join(cells)
  class XLSXParserAdvanced(AsyncParser[str | bytes]):
      """A parser for XLSX data that splits each worksheet into tables.

      Each worksheet is modelled as a grid graph of its non-empty cells. Every
      connected group of non-empty cells is treated as one table whose first
      row is assumed to be the header. Each table is emitted in chunks, and
      every chunk is prefixed with the header row.
      """

      def __init__(
          self, config: IngestionConfig, llm_provider: CompletionProvider
      ):
          self.llm_provider = llm_provider
          self.config = config
          # Library handles are stored on the instance and used via ``self``
          # below; presumably so they can be overridden — confirm with callers.
          self.nx = nx
          self.np = np
          self.load_workbook = load_workbook

      def connected_components(self, arr):
          """Yield the bounding box of each connected group of non-empty cells.

          Args:
              arr: 2-D grid of cell values in which ``None`` marks an empty
                  cell. Anything accepted by ``np.asarray`` is allowed.

          Yields:
              For each connected component of non-``None`` cells (adjacency as
              defined by ``networkx.grid_2d_graph``), the sub-array spanning
              the component's bounding box, converted to ``str`` dtype. Every
              cell inside the box is included, so a ``None`` within the box
              becomes the text ``"None"``.
          """
          arr = self.np.asarray(arr)
          # An empty or non-2-D grid (e.g. a sheet without rows) has no
          # components, and building the grid graph from it would fail.
          if arr.ndim != 2 or arr.size == 0:
              return

          g = self.nx.grid_2d_graph(arr.shape[0], arr.shape[1])
          # Element-wise ``is None`` test. Writing ``arr is None`` compares the
          # array object itself (always False), which would remove no cells
          # and merge the whole sheet into a single component.
          is_empty = self.np.vectorize(
              lambda value: value is None, otypes=[bool]
          )(arr)
          empty_cell_indices = [
              (int(r), int(c)) for r, c in self.np.argwhere(is_empty)
          ]
          g.remove_nodes_from(empty_cell_indices)

          for component in self.nx.connected_components(g):
              rows, cols = zip(*component, strict=False)
              min_row, max_row = min(rows), max(rows)
              min_col, max_col = min(cols), max(cols)
              yield arr[min_row : max_row + 1, min_col : max_col + 1].astype(
                  "str"
              )

      async def ingest(
          self, data: bytes, num_col_times_num_rows: int = 100, *args, **kwargs
      ) -> AsyncGenerator[str, None]:
          """Ingest XLSX data and yield text from each connected component.

          Args:
              data: Raw bytes of an ``.xlsx`` workbook.
              num_col_times_num_rows: Approximate cell budget per chunk. Each
                  chunk holds ``max(1, budget // number_of_columns)`` data rows.

          Yields:
              For every table with at least one data row, one or more strings
              made of the header row joined by ``", "``, a newline, and then
              the chunk's data rows, each joined by ``", "`` and separated by
              newlines.

          Raises:
              ValueError: If ``data`` is a ``str``. Because this is an async
                  generator, the error surfaces when iteration starts.
          """
          if isinstance(data, str):
              raise ValueError("XLSX data must be in bytes format.")

          workbook = self.load_workbook(filename=BytesIO(data))
          for ws in workbook.worksheets:
              # dtype=object keeps every cell value as-is. Without it NumPy
              # coerces mixed rows to a common type (e.g. True -> 1, 3 -> 3.0)
              # before the text conversion, and ``None`` detection gets murkier.
              ws_data = self.np.array(
                  [list(row) for row in ws.iter_rows(values_only=True)],
                  dtype=object,
              )
              for table in self.connected_components(ws_data):
                  # Parse like a CSV: the first row is assumed to hold the
                  # column names.
                  num_rows, num_cols = table.shape
                  if num_rows <= 1:
                      continue  # header only, no data rows
                  # Size chunks by width so rows * cols stays near the budget.
                  # Dividing by the row count instead ignores the width and
                  # gives a step of 0 (ValueError from range()) once a table
                  # has more rows than the budget.
                  rows_per_chunk = max(1, num_col_times_num_rows // num_cols)
                  headers = ", ".join(table[0])
                  # Repeat the header in front of every chunk.
                  for start in range(1, num_rows, rows_per_chunk):
                      chunk = table[start : start + rows_per_chunk]
                      yield (
                          headers
                          + "\n"
                          + "\n".join(", ".join(row) for row in chunk)
                      )