utils.py

import logging

from litellm import get_model_info, token_counter

logger = logging.getLogger(__name__)


def truncate_texts_to_token_limit(texts: list[str], model: str) -> list[str]:
    """
    Truncate each text so it fits within the model's input token limit.
    """
    try:
        model_info = get_model_info(model=model)
        max_tokens = model_info.get("max_input_tokens")
        if not max_tokens:
            return texts  # No truncation needed if the model has no known limit

        truncated_texts = []
        for text in texts:
            text_tokens = token_counter(model=model, text=text)
            if text_tokens > max_tokens:
                # Truncate by character count, estimating ~3 chars per token
                estimated_chars = max_tokens * 3
                truncated_texts.append(text[:estimated_chars])
                logger.warning(
                    f"Truncated text from {text_tokens} to ~{max_tokens} tokens"
                )
            else:
                truncated_texts.append(text)
        return truncated_texts
    except Exception as e:
        logger.warning(f"Failed to truncate texts: {str(e)}")
        return texts  # Return original texts if truncation fails
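
# Usage sketch (illustrative addition, not part of the original module):
# runs the helper on one short and one oversized string and prints the
# before/after character counts. The model name "gpt-3.5-turbo" is an
# assumption chosen because litellm ships metadata (including
# max_input_tokens) for it; any model litellm recognizes works the same way.
if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    docs = ["a short document", "lorem ipsum " * 50_000]
    safe_docs = truncate_texts_to_token_limit(docs, model="gpt-3.5-turbo")
    for before, after in zip(docs, safe_docs):
        print(f"{len(before)} chars -> {len(after)} chars")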