audio_parser.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # type: ignore
  2. import logging
  3. import os
  4. import tempfile
  5. from typing import AsyncGenerator
  6. from litellm import atranscription
  7. from core.base.parsers.base_parser import AsyncParser
  8. from core.base.providers import (
  9. CompletionProvider,
  10. DatabaseProvider,
  11. IngestionConfig,
  12. )
# Module-level logger used by AudioParser for error and cleanup messages.
# NOTE(review): getLogger() without a name returns the *root* logger; the
# usual library convention is logging.getLogger(__name__) -- confirm whether
# logging through the root logger is intentional here.
logger = logging.getLogger()
  14. class AudioParser(AsyncParser[bytes]):
  15. """A parser for audio data using Whisper transcription."""
  16. def __init__(
  17. self,
  18. config: IngestionConfig,
  19. database_provider: DatabaseProvider,
  20. llm_provider: CompletionProvider,
  21. ):
  22. self.database_provider = database_provider
  23. self.llm_provider = llm_provider
  24. self.config = config
  25. self.atranscription = atranscription
  26. async def ingest( # type: ignore
  27. self, data: bytes, **kwargs
  28. ) -> AsyncGenerator[str, None]:
  29. """Ingest audio data and yield a transcription using Whisper via
  30. LiteLLM.
  31. Args:
  32. data: Raw audio bytes
  33. *args, **kwargs: Additional arguments passed to the transcription call
  34. Yields:
  35. Chunks of transcribed text
  36. """
  37. try:
  38. # Create a temporary file to store the audio data
  39. with tempfile.NamedTemporaryFile(
  40. suffix=".wav", delete=False
  41. ) as temp_file:
  42. temp_file.write(data)
  43. temp_file_path = temp_file.name
  44. kwargs.pop("chunking_strategy", None)
  45. # Call Whisper transcription
  46. response = await self.atranscription(
  47. model=self.config.audio_transcription_model
  48. or self.config.app.audio_lm,
  49. file=open(temp_file_path, "rb"),
  50. **kwargs,
  51. )
  52. # The response should contain the transcribed text directly
  53. yield response.text
  54. except Exception as e:
  55. logger.error(f"Error processing audio with Whisper: {str(e)}")
  56. raise
  57. finally:
  58. # Clean up the temporary file
  59. try:
  60. os.unlink(temp_file_path)
  61. except Exception as e:
  62. logger.warning(f"Failed to delete temporary file: {str(e)}")