csv_parser.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # type: ignore
  2. from typing import IO, AsyncGenerator, Optional
  3. from core.base.parsers.base_parser import AsyncParser
  4. from core.base.providers import (
  5. CompletionProvider,
  6. DatabaseProvider,
  7. IngestionConfig,
  8. )
  9. class CSVParser(AsyncParser[str | bytes]):
  10. """A parser for CSV data."""
  11. def __init__(
  12. self,
  13. config: IngestionConfig,
  14. database_provider: DatabaseProvider,
  15. llm_provider: CompletionProvider,
  16. ):
  17. self.database_provider = database_provider
  18. self.llm_provider = llm_provider
  19. self.config = config
  20. import csv
  21. from io import StringIO
  22. self.csv = csv
  23. self.StringIO = StringIO
  24. async def ingest(
  25. self, data: str | bytes, *args, **kwargs
  26. ) -> AsyncGenerator[str, None]:
  27. """Ingest CSV data and yield text from each row."""
  28. if isinstance(data, bytes):
  29. data = data.decode("utf-8")
  30. csv_reader = self.csv.reader(self.StringIO(data))
  31. for row in csv_reader:
  32. yield ", ".join(row)
  33. class CSVParserAdvanced(AsyncParser[str | bytes]):
  34. """A parser for CSV data."""
  35. def __init__(
  36. self, config: IngestionConfig, database_provider: DatabaseProvider, llm_provider: CompletionProvider
  37. ):
  38. self.database_provider = database_provider
  39. self.llm_provider = llm_provider
  40. self.config = config
  41. import csv
  42. from io import StringIO
  43. self.csv = csv
  44. self.StringIO = StringIO
  45. def get_delimiter(
  46. self, file_path: Optional[str] = None, file: Optional[IO[bytes]] = None
  47. ):
  48. sniffer = self.csv.Sniffer()
  49. num_bytes = 65536
  50. if file:
  51. lines = file.readlines(num_bytes)
  52. file.seek(0)
  53. #.decode("utf-8")
  54. data = "\n".join(ln for ln in lines)
  55. elif file_path is not None:
  56. with open(file_path) as f:
  57. data = "\n".join(f.readlines(num_bytes))
  58. return sniffer.sniff(data, delimiters=",;").delimiter
  59. async def ingest(
  60. self,
  61. data: str | bytes,
  62. num_col_times_num_rows: int = 100,
  63. *args,
  64. **kwargs,
  65. ) -> AsyncGenerator[str, None]:
  66. """Ingest CSV data and yield text from each row."""
  67. #print(data)
  68. if isinstance(data, bytes):
  69. try:
  70. data = data.decode("utf-8")
  71. except UnicodeDecodeError:
  72. # 尝试其他常见编码
  73. for encoding in ['latin-1', 'cp1252', 'iso-8859-1']:
  74. try:
  75. data = data.decode(encoding)
  76. break
  77. except UnicodeDecodeError:
  78. continue
  79. else:
  80. raise ValueError("Unable to decode the provided byte data with any supported encoding")
  81. # let the first row be the header
  82. print("1")
  83. delimiter = self.get_delimiter(file=self.StringIO(data))
  84. print("2")
  85. csv_reader = self.csv.reader(self.StringIO(data), delimiter=delimiter)
  86. print("3")
  87. header = next(csv_reader)
  88. print("4")
  89. #print(header)
  90. #num_cols = len(header.split(delimiter))
  91. num_cols = len(header)
  92. print("5")
  93. num_rows = num_col_times_num_rows // num_cols
  94. print("6")
  95. print(num_rows)
  96. print(row_num)
  97. chunk_rows = []
  98. for row_num, row in enumerate(csv_reader):
  99. print(row)
  100. chunk_rows.append(row)
  101. #if num_rows > 0 and row_num > 0 and row_num % num_rows == 0:
  102. if True:
  103. yield (
  104. ", ".join(header)
  105. + "\n"
  106. + "\n".join([", ".join(row) for row in chunk_rows])
  107. )
  108. chunk_rows = []
  109. if chunk_rows:
  110. yield (
  111. ", ".join(header)
  112. + "\n"
  113. + "\n".join([", ".join(row) for row in chunk_rows])
  114. )