audio_parser.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # type: ignore
  2. import logging
  3. import os
  4. import tempfile
  5. from typing import AsyncGenerator
  6. from litellm import atranscription
  7. from core.base.parsers.base_parser import AsyncParser
  8. from core.base.providers import (
  9. CompletionProvider,
  10. DatabaseProvider,
  11. IngestionConfig,
  12. )
  13. logger = logging.getLogger()
  14. class AudioParser(AsyncParser[bytes]):
  15. """A parser for audio data using Whisper transcription."""
  16. def __init__(
  17. self,
  18. config: IngestionConfig,
  19. database_provider: DatabaseProvider,
  20. llm_provider: CompletionProvider,
  21. ):
  22. self.database_provider = database_provider
  23. self.llm_provider = llm_provider
  24. self.config = config
  25. self.atranscription = atranscription
  26. async def ingest( # type: ignore
  27. self, data: bytes, **kwargs
  28. ) -> AsyncGenerator[str, None]:
  29. """Ingest audio data and yield a transcription using Whisper via
  30. LiteLLM.
  31. Args:
  32. data: Raw audio bytes
  33. *args, **kwargs: Additional arguments passed to the transcription call
  34. Yields:
  35. Chunks of transcribed text
  36. """
  37. try:
  38. # Create a temporary file to store the audio data
  39. with tempfile.NamedTemporaryFile(
  40. suffix=".wav", delete=False
  41. ) as temp_file:
  42. temp_file.write(data)
  43. temp_file_path = temp_file.name
  44. # Call Whisper transcription
  45. response = await self.atranscription(
  46. model=self.config.audio_transcription_model
  47. or self.config.app.audio_lm,
  48. file=open(temp_file_path, "rb"),
  49. **kwargs,
  50. )
  51. # The response should contain the transcribed text directly
  52. yield response.text
  53. except Exception as e:
  54. logger.error(f"Error processing audio with Whisper: {str(e)}")
  55. raise
  56. finally:
  57. # Clean up the temporary file
  58. try:
  59. os.unlink(temp_file_path)
  60. except Exception as e:
  61. logger.warning(f"Failed to delete temporary file: {str(e)}")