# type: ignore
from typing import IO, AsyncGenerator, Optional

from core.base.parsers.base_parser import AsyncParser
from core.base.providers import (
    CompletionProvider,
    DatabaseProvider,
    IngestionConfig,
)


class CSVParser(AsyncParser[str | bytes]):
    """Parser that turns each CSV row into one comma-joined line of text."""

    def __init__(
        self,
        config: IngestionConfig,
        database_provider: DatabaseProvider,
        llm_provider: CompletionProvider,
    ):
        """Store the providers/config and bind the stdlib helpers used by ingest."""
        self.database_provider = database_provider
        self.llm_provider = llm_provider
        self.config = config

        # Imported here rather than at module level; the module and class are
        # kept on the instance because ``ingest`` reaches them through ``self``.
        import csv
        from io import StringIO

        self.csv = csv
        self.StringIO = StringIO

    async def ingest(
        self, data: str | bytes, *args, **kwargs
    ) -> AsyncGenerator[str, None]:
        """Ingest CSV data and yield text from each row.

        Args:
            data: CSV content. ``bytes`` input is decoded as UTF-8.
            *args: Ignored.
            **kwargs: Ignored.

        Yields:
            One string per CSV row (the header row included), with the cells
            joined by ``", "``.

        Raises:
            UnicodeDecodeError: If ``data`` is bytes that are not valid UTF-8.
        """
        if isinstance(data, bytes):
            data = data.decode("utf-8")
        csv_reader = self.csv.reader(self.StringIO(data))
        for row in csv_reader:
            yield ", ".join(row)


class CSVParserAdvanced(AsyncParser[str | bytes]):
    """Parser that sniffs the delimiter and yields header-prefixed row chunks."""

    # Encodings tried, in order, when the input bytes are not valid UTF-8.
    # cp1252 comes first because it rejects the few byte values it leaves
    # undefined, whereas latin-1 maps every possible byte and so never fails;
    # placing latin-1 first would make every later entry unreachable.
    _FALLBACK_ENCODINGS = ("cp1252", "latin-1")

    def __init__(
        self,
        config: IngestionConfig,
        database_provider: DatabaseProvider,
        llm_provider: CompletionProvider,
    ):
        """Store the providers/config and bind the stdlib helpers used below."""
        self.database_provider = database_provider
        self.llm_provider = llm_provider
        self.config = config

        # Imported here rather than at module level; the module and class are
        # kept on the instance because the methods reach them through ``self``.
        import csv
        from io import StringIO

        self.csv = csv
        self.StringIO = StringIO

    def get_delimiter(
        self, file_path: Optional[str] = None, file: Optional[IO[bytes]] = None
    ):
        """Guess whether the CSV uses ``,`` or ``;`` as its delimiter.

        Roughly the first 64 KiB of whole lines are sampled and handed to
        ``csv.Sniffer``. ``file`` takes precedence over ``file_path``.

        Args:
            file_path: Path of a CSV file to sample.
            file: Open file-like object to sample. Text and binary streams are
                both accepted; it is rewound to position 0 after sampling.

        Returns:
            The detected delimiter character.

        Raises:
            ValueError: If neither ``file`` nor ``file_path`` is provided.
            csv.Error: If the sniffer cannot determine a delimiter.
        """
        num_bytes = 65536

        if file is not None:
            lines = file.readlines(num_bytes)
            file.seek(0)
            # Lines keep their own line endings, so join without a separator.
            # Byte lines (binary streams) are decoded leniently: the sample is
            # only used for delimiter detection.
            data = "".join(
                ln.decode("utf-8", errors="replace")
                if isinstance(ln, bytes)
                else ln
                for ln in lines
            )
        elif file_path is not None:
            with open(file_path, encoding="utf-8", errors="replace") as f:
                data = "".join(f.readlines(num_bytes))
        else:
            raise ValueError("Either file_path or file must be provided")

        return self.csv.Sniffer().sniff(data, delimiters=",;").delimiter

    @classmethod
    def _decode_bytes(cls, raw: bytes) -> str:
        """Decode ``raw`` as UTF-8, falling back to the legacy encodings."""
        for encoding in ("utf-8", *cls._FALLBACK_ENCODINGS):
            try:
                return raw.decode(encoding)
            except UnicodeDecodeError:
                continue
        # Safeguard only: latin-1 accepts every byte sequence, so this is
        # reached only if the fallback list is changed to exclude it.
        raise ValueError(
            "Unable to decode the provided byte data with any supported encoding"
        )

    @staticmethod
    def _format_chunk(header_line: str, rows: list) -> str:
        """Render the header line followed by one comma-joined line per row."""
        return header_line + "\n" + "\n".join(", ".join(row) for row in rows)

    async def ingest(
        self,
        data: str | bytes,
        num_col_times_num_rows: int = 100,
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Ingest CSV data and yield header-prefixed chunks of rows.

        The first row is treated as the header and repeated at the top of
        every chunk so that each chunk is self-describing.

        Args:
            data: CSV content. ``bytes`` input is decoded as UTF-8, falling
                back to cp1252 and then latin-1.
            num_col_times_num_rows: Approximate cell budget per chunk. Each
                chunk holds ``num_col_times_num_rows // number_of_columns``
                data rows, but never fewer than one.
            *args: Ignored.
            **kwargs: Ignored.

        Yields:
            Text chunks of the form ``"<header>\\n<row>\\n<row>..."`` with the
            cells joined by ``", "``. Nothing is yielded when the input has no
            rows, or only a header row.

        Raises:
            ValueError: If byte input cannot be decoded with any supported
                encoding.
        """
        if isinstance(data, bytes):
            data = self._decode_bytes(data)

        try:
            delimiter = self.get_delimiter(file=self.StringIO(data))
        except self.csv.Error:
            # The sniffer gives up on empty or single-column input (there is
            # no delimiter to find); a comma is a safe default in both cases.
            delimiter = ","

        csv_reader = self.csv.reader(self.StringIO(data), delimiter=delimiter)

        # A bare next() would raise StopIteration on empty input, which an
        # async generator converts into RuntimeError.
        header = next(csv_reader, None)
        if header is None:
            return

        # Guard against a blank first line (zero columns) and against budgets
        # smaller than the column count: always emit at least one row a chunk.
        num_cols = max(1, len(header))
        rows_per_chunk = max(1, num_col_times_num_rows // num_cols)
        header_line = ", ".join(header)

        chunk_rows = []
        for row in csv_reader:
            chunk_rows.append(row)
            if len(chunk_rows) >= rows_per_chunk:
                yield self._format_chunk(header_line, chunk_rows)
                chunk_rows = []

        # Flush the final, partially filled chunk.
        if chunk_rows:
            yield self._format_chunk(header_line, chunk_rows)