xls_parser.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # type: ignore
  2. from typing import AsyncGenerator
  3. import networkx as nx
  4. import numpy as np
  5. import xlrd
  6. from core.base.parsers.base_parser import AsyncParser
  7. from core.base.providers import (
  8. CompletionProvider,
  9. DatabaseProvider,
  10. IngestionConfig,
  11. )
  12. class XLSParser(AsyncParser[str | bytes]):
  13. """A parser for XLS (Excel 97-2003) data."""
  14. def __init__(
  15. self,
  16. config: IngestionConfig,
  17. database_provider: DatabaseProvider,
  18. llm_provider: CompletionProvider,
  19. ):
  20. self.database_provider = database_provider
  21. self.llm_provider = llm_provider
  22. self.config = config
  23. self.xlrd = xlrd
  24. async def ingest(
  25. self, data: bytes, *args, **kwargs
  26. ) -> AsyncGenerator[str, None]:
  27. """Ingest XLS data and yield text from each row."""
  28. if isinstance(data, str):
  29. raise ValueError("XLS data must be in bytes format.")
  30. wb = self.xlrd.open_workbook(file_contents=data)
  31. for sheet in wb.sheets():
  32. for row_idx in range(sheet.nrows):
  33. # Get all values in the row
  34. row_values = []
  35. for col_idx in range(sheet.ncols):
  36. cell = sheet.cell(row_idx, col_idx)
  37. # Handle different cell types
  38. if cell.ctype == self.xlrd.XL_CELL_DATE:
  39. try:
  40. value = self.xlrd.xldate_as_datetime(
  41. cell.value, wb.datemode
  42. ).strftime("%Y-%m-%d")
  43. except Exception:
  44. value = str(cell.value)
  45. elif cell.ctype == self.xlrd.XL_CELL_BOOLEAN:
  46. value = str(bool(cell.value)).lower()
  47. elif cell.ctype == self.xlrd.XL_CELL_ERROR:
  48. value = "#ERROR#"
  49. else:
  50. value = str(cell.value).strip()
  51. row_values.append(value)
  52. # Yield non-empty rows
  53. if any(val.strip() for val in row_values):
  54. yield ", ".join(row_values)
  55. class XLSParserAdvanced(AsyncParser[str | bytes]):
  56. """An advanced parser for XLS data with chunking support."""
  57. def __init__(
  58. self, config: IngestionConfig, llm_provider: CompletionProvider
  59. ):
  60. self.llm_provider = llm_provider
  61. self.config = config
  62. self.nx = nx
  63. self.np = np
  64. self.xlrd = xlrd
  65. def connected_components(self, arr):
  66. g = self.nx.grid_2d_graph(len(arr), len(arr[0]))
  67. empty_cell_indices = list(zip(*self.np.where(arr == ""), strict=False))
  68. g.remove_nodes_from(empty_cell_indices)
  69. components = self.nx.connected_components(g)
  70. for component in components:
  71. rows, cols = zip(*component, strict=False)
  72. min_row, max_row = min(rows), max(rows)
  73. min_col, max_col = min(cols), max(cols)
  74. yield arr[min_row : max_row + 1, min_col : max_col + 1]
  75. def get_cell_value(self, cell, workbook):
  76. """Extract cell value handling different data types."""
  77. if cell.ctype == self.xlrd.XL_CELL_DATE:
  78. try:
  79. return self.xlrd.xldate_as_datetime(
  80. cell.value, workbook.datemode
  81. ).strftime("%Y-%m-%d")
  82. except Exception:
  83. return str(cell.value)
  84. elif cell.ctype == self.xlrd.XL_CELL_BOOLEAN:
  85. return str(bool(cell.value)).lower()
  86. elif cell.ctype == self.xlrd.XL_CELL_ERROR:
  87. return "#ERROR#"
  88. else:
  89. return str(cell.value).strip()
  90. async def ingest(
  91. self, data: bytes, num_col_times_num_rows: int = 100, *args, **kwargs
  92. ) -> AsyncGenerator[str, None]:
  93. """Ingest XLS data and yield text from each connected component."""
  94. if isinstance(data, str):
  95. raise ValueError("XLS data must be in bytes format.")
  96. workbook = self.xlrd.open_workbook(file_contents=data)
  97. for sheet in workbook.sheets():
  98. # Convert sheet to numpy array with proper value handling
  99. ws_data = self.np.array(
  100. [
  101. [
  102. self.get_cell_value(sheet.cell(row, col), workbook)
  103. for col in range(sheet.ncols)
  104. ]
  105. for row in range(sheet.nrows)
  106. ]
  107. )
  108. for table in self.connected_components(ws_data):
  109. if len(table) <= 1:
  110. continue
  111. num_rows = len(table)
  112. num_rows_per_chunk = num_col_times_num_rows // num_rows
  113. headers = ", ".join(table[0])
  114. for i in range(1, num_rows, num_rows_per_chunk):
  115. chunk = table[i : i + num_rows_per_chunk]
  116. yield (
  117. headers
  118. + "\n"
  119. + "\n".join([", ".join(row) for row in chunk])
  120. )