tsv_parser.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # type: ignore
  2. from typing import IO, AsyncGenerator
  3. from core.base.parsers.base_parser import AsyncParser
  4. from core.base.providers import (
  5. CompletionProvider,
  6. DatabaseProvider,
  7. IngestionConfig,
  8. )
  9. class TSVParser(AsyncParser[str | bytes]):
  10. """A parser for TSV (Tab Separated Values) data."""
  11. def __init__(
  12. self,
  13. config: IngestionConfig,
  14. database_provider: DatabaseProvider,
  15. llm_provider: CompletionProvider,
  16. ):
  17. self.database_provider = database_provider
  18. self.llm_provider = llm_provider
  19. self.config = config
  20. import csv
  21. from io import StringIO
  22. self.csv = csv
  23. self.StringIO = StringIO
  24. async def ingest(
  25. self, data: str | bytes, *args, **kwargs
  26. ) -> AsyncGenerator[str, None]:
  27. """Ingest TSV data and yield text from each row."""
  28. if isinstance(data, bytes):
  29. data = data.decode("utf-8")
  30. tsv_reader = self.csv.reader(self.StringIO(data), delimiter="\t")
  31. for row in tsv_reader:
  32. yield ", ".join(row) # Still join with comma for readability
  33. class TSVParserAdvanced(AsyncParser[str | bytes]):
  34. """An advanced parser for TSV data with chunking support."""
  35. def __init__(
  36. self, config: IngestionConfig, llm_provider: CompletionProvider
  37. ):
  38. self.llm_provider = llm_provider
  39. self.config = config
  40. import csv
  41. from io import StringIO
  42. self.csv = csv
  43. self.StringIO = StringIO
  44. def validate_tsv(self, file: IO[bytes]) -> bool:
  45. """Validate if the file is actually tab-delimited."""
  46. num_bytes = 65536
  47. lines = file.readlines(num_bytes)
  48. file.seek(0)
  49. if not lines:
  50. return False
  51. # Check if tabs exist in first few lines
  52. sample = "\n".join(ln.decode("utf-8") for ln in lines[:5])
  53. return "\t" in sample
  54. async def ingest(
  55. self,
  56. data: str | bytes,
  57. num_col_times_num_rows: int = 100,
  58. *args,
  59. **kwargs,
  60. ) -> AsyncGenerator[str, None]:
  61. """Ingest TSV data and yield text in chunks."""
  62. if isinstance(data, bytes):
  63. data = data.decode("utf-8")
  64. # Validate TSV format
  65. if not self.validate_tsv(self.StringIO(data)):
  66. raise ValueError("File does not appear to be tab-delimited")
  67. tsv_reader = self.csv.reader(self.StringIO(data), delimiter="\t")
  68. # Get header
  69. header = next(tsv_reader)
  70. num_cols = len(header)
  71. num_rows = num_col_times_num_rows // num_cols
  72. chunk_rows = []
  73. for row_num, row in enumerate(tsv_reader):
  74. chunk_rows.append(row)
  75. if row_num % num_rows == 0:
  76. yield ", ".join(header) + "\n" + "\n".join(
  77. [", ".join(row) for row in chunk_rows]
  78. )
  79. chunk_rows = []
  80. # Yield remaining rows
  81. if chunk_rows:
  82. yield ", ".join(header) + "\n" + "\n".join(
  83. [", ".join(row) for row in chunk_rows]
  84. )