upload_hf_textbooks_ex.py
  1. import asyncio
  2. import os
  3. import uuid
  4. from datasets import load_dataset
  5. from r2r import R2RAsyncClient
  6. batch_size = 64
  7. total_batches = 8
  8. rest_time_in_s = 1
  9. def generate_id(label: str) -> uuid.UUID:
  10. return uuid.uuid5(uuid.NAMESPACE_DNS, label)
  11. def remove_file(file_path):
  12. try:
  13. os.remove(file_path)
  14. except Exception as e:
  15. print(f"Error removing file {file_path}: {e}")
  16. async def process_batch(client, batch):
  17. results = await client.ingest_files(batch)
  18. print(f"Submitted {len(results['results'])} files for processing")
  19. # Remove the processed files
  20. for file_path in batch:
  21. remove_file(file_path)
  22. async def process_dataset(client, dataset, batch_size):
  23. current_batch = []
  24. count = 0
  25. tasks = []
  26. for example in dataset:
  27. count += 1
  28. fname = f"example_{generate_id(example['completion'])}.txt"
  29. print(f"Streaming {fname} w/ completion {count} ...")
  30. with open(fname, "w") as f:
  31. f.write(example["completion"])
  32. current_batch.append(fname)
  33. if len(current_batch) == batch_size:
  34. task = asyncio.create_task(process_batch(client, current_batch))
  35. tasks.append(task)
  36. current_batch = []
  37. if len(tasks) == total_batches:
  38. await asyncio.gather(*tasks)
  39. tasks = [] # Reset the tasks list
  40. # await asyncio.sleep(rest_time_in_s)
  41. # Process any remaining files in the last batch
  42. if current_batch:
  43. await process_batch(client, current_batch)
  44. async def main():
  45. r2r_url = os.getenv("R2R_API_URL", "http://localhost:7272")
  46. print(f"Using R2R API at: {r2r_url}")
  47. client = R2RAsyncClient(r2r_url)
  48. dataset = load_dataset(
  49. "SciPhi/textbooks-are-all-you-need-lite", streaming=True
  50. )["train"]
  51. print("Submitting batches for processing ...")
  52. await process_dataset(client, dataset, batch_size)
  53. print("All batches submitted for processing")
  54. if __name__ == "__main__":
  55. asyncio.run(main())