123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- # type: ignore
- from typing import IO, AsyncGenerator, Optional
- from core.base.parsers.base_parser import AsyncParser
- from core.base.providers import (
- CompletionProvider,
- DatabaseProvider,
- IngestionConfig,
- )
class CSVParser(AsyncParser[str | bytes]):
    """Parser that turns CSV input into one text line per row."""

    def __init__(
        self,
        config: IngestionConfig,
        database_provider: DatabaseProvider,
        llm_provider: CompletionProvider,
    ):
        """Store the configuration, the providers, and the CSV helpers.

        Args:
            config: Ingestion configuration kept on ``self.config``.
            database_provider: Kept on ``self.database_provider``.
            llm_provider: Kept on ``self.llm_provider``.
        """
        # Imported here rather than at module level; the module and the
        # StringIO class are kept on the instance for use in ``ingest``.
        import csv
        import io

        self.csv = csv
        self.StringIO = io.StringIO

        self.config = config
        self.database_provider = database_provider
        self.llm_provider = llm_provider

    async def ingest(
        self, data: str | bytes, *args, **kwargs
    ) -> AsyncGenerator[str, None]:
        """Yield each CSV row as its fields joined by ``", "``.

        Args:
            data: CSV content; ``bytes`` input is decoded as UTF-8 first.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Yields:
            One string per parsed row, in input order.
        """
        text = data.decode("utf-8") if isinstance(data, bytes) else data
        buffer = self.StringIO(text)
        for fields in self.csv.reader(buffer):
            yield ", ".join(fields)
class CSVParserAdvanced(AsyncParser[str | bytes]):
    """CSV parser that sniffs the delimiter and yields header-prefixed chunks."""

    def __init__(
        self, config: IngestionConfig, llm_provider: CompletionProvider
    ):
        """Store the configuration, the LLM provider, and the CSV helpers.

        Args:
            config: Ingestion configuration kept on ``self.config``.
            llm_provider: Kept on ``self.llm_provider``.
        """
        self.llm_provider = llm_provider
        self.config = config

        # Imported here rather than at module level; the module and the
        # StringIO class are kept on the instance for the methods below.
        import csv
        from io import StringIO

        self.csv = csv
        self.StringIO = StringIO

    def get_delimiter(
        self, file_path: Optional[str] = None, file: Optional[IO] = None
    ) -> str:
        """Guess whether a CSV sample is comma- or semicolon-delimited.

        Args:
            file_path: Path of a CSV file to sample. Used only when ``file``
                is not given; the file is read as UTF-8.
            file: Open stream to sample. Both binary and text streams are
                accepted; the stream is rewound to position 0 afterwards.

        Returns:
            The delimiter detected by ``csv.Sniffer`` (``","`` or ``";"``).

        Raises:
            ValueError: If neither ``file`` nor ``file_path`` is provided.
            csv.Error: If the sniffer cannot determine a delimiter.
        """
        num_bytes = 65536  # size hint: stop reading lines after ~64 KiB

        if file is not None:
            lines = file.readlines(num_bytes)
            file.seek(0)
            # readlines() keeps line terminators, so join with "" to avoid
            # doubling every line break. Text streams (e.g. StringIO) yield
            # str, binary streams yield bytes; support both.
            sample = "".join(
                ln.decode("utf-8") if isinstance(ln, bytes) else ln
                for ln in lines
            )
        elif file_path is not None:
            # Decode as UTF-8 to match the byte-stream branch above.
            with open(file_path, encoding="utf-8") as f:
                sample = "".join(f.readlines(num_bytes))
        else:
            raise ValueError("Either 'file' or 'file_path' must be provided")

        return self.csv.Sniffer().sniff(sample, delimiters=",;").delimiter

    async def ingest(
        self,
        data: str | bytes,
        num_col_times_num_rows: int = 100,
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Yield the CSV content in chunks, each prefixed with the header row.

        The first row is treated as the header. Data rows are grouped so that
        ``columns * rows_per_chunk`` stays close to ``num_col_times_num_rows``
        (always at least one row per chunk).

        Args:
            data: CSV content; ``bytes`` input is decoded as UTF-8 first.
            num_col_times_num_rows: Approximate cell budget per chunk.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Yields:
            Strings made of the ``", "``-joined header, a newline, and the
            ``", "``-joined rows of the chunk separated by newlines. Nothing
            is yielded when the input has no data rows.
        """
        if isinstance(data, bytes):
            data = data.decode("utf-8")

        try:
            delimiter = self.get_delimiter(file=self.StringIO(data))
        except self.csv.Error:
            # The sniffer gives up on empty or single-column input; fall back
            # to the conventional comma instead of aborting ingestion.
            delimiter = ","

        csv_reader = self.csv.reader(self.StringIO(data), delimiter=delimiter)

        # let the first row be the header
        header = next(csv_reader, None)
        if header is None:
            # Empty input. A bare next() would raise StopIteration, which
            # becomes RuntimeError inside a generator.
            return

        # csv.reader already returns the header as a list of fields.
        num_cols = max(len(header), 1)
        # Never let the chunk size drop to zero for very wide tables.
        num_rows = max(num_col_times_num_rows // num_cols, 1)
        header_line = ", ".join(header)

        chunk_rows: list[str] = []
        for row in csv_reader:
            chunk_rows.append(", ".join(row))
            if len(chunk_rows) >= num_rows:
                yield header_line + "\n" + "\n".join(chunk_rows)
                chunk_rows = []

        # Flush the final, partially filled chunk.
        if chunk_rows:
            yield header_line + "\n" + "\n".join(chunk_rows)
|