api_call.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import asyncio
  2. import base64
  3. import os
  4. from io import BytesIO
  5. from typing import Any, List
  6. import aiohttp
  7. import torch
  8. from PIL import Image
  9. from tqdm.asyncio import tqdm_asyncio
  10. class IlluinAPIModelWrapper:
  11. def __init__(
  12. self,
  13. model_name: str,
  14. **kwargs: Any,
  15. ):
  16. """Wrapper for Illuin API embedding model"""
  17. self.model_name = model_name
  18. self.url = model_name
  19. self.HEADERS = {
  20. "Accept": "application/json",
  21. "Authorization": f"Bearer {os.getenv('HF_TOKEN')}",
  22. "Content-Type": "application/json",
  23. }
  24. @staticmethod
  25. def convert_image_to_base64(image: Image.Image) -> str:
  26. buffer = BytesIO()
  27. image.save(buffer, format="JPEG")
  28. return base64.b64encode(buffer.getvalue()).decode("utf-8")
  29. async def post_images(self, session: aiohttp.ClientSession, encoded_images: List[str]):
  30. payload = {"inputs": {"images": encoded_images}}
  31. async with session.post(self.url, headers=self.HEADERS, json=payload) as response:
  32. return await response.json()
  33. async def post_queries(self, session: aiohttp.ClientSession, queries: List[str]):
  34. payload = {"inputs": {"queries": queries}}
  35. async with session.post(self.url, headers=self.HEADERS, json=payload) as response:
  36. return await response.json()
  37. async def call_api_queries(self, queries: List[str]):
  38. embeddings = []
  39. semaphore = asyncio.Semaphore(16)
  40. async with aiohttp.ClientSession() as session:
  41. async def sem_post(batch):
  42. async with semaphore:
  43. return await self.post_queries(session, batch)
  44. tasks = [asyncio.create_task(sem_post([batch])) for batch in queries]
  45. # ORDER-PRESERVING
  46. results = await tqdm_asyncio.gather(*tasks, desc="Query batches")
  47. for result in results:
  48. embeddings.extend(result.get("embeddings", []))
  49. return embeddings
  50. async def call_api_images(self, images_b64: List[str]):
  51. embeddings = []
  52. semaphore = asyncio.Semaphore(16)
  53. async with aiohttp.ClientSession() as session:
  54. async def sem_post(batch):
  55. async with semaphore:
  56. return await self.post_images(session, batch)
  57. tasks = [asyncio.create_task(sem_post([batch])) for batch in images_b64]
  58. # ORDER-PRESERVING
  59. results = await tqdm_asyncio.gather(*tasks, desc="Doc batches")
  60. for result in results:
  61. embeddings.extend(result.get("embeddings", []))
  62. return embeddings
  63. def forward_queries(self, queries: List[str]) -> torch.Tensor:
  64. response = asyncio.run(self.call_api_queries(queries))
  65. return response
  66. def forward_passages(self, passages: List[Image.Image]) -> torch.Tensor:
  67. response = asyncio.run(self.call_api_images([self.convert_image_to_base64(doc) for doc in passages]))
  68. return response
  69. if __name__ == "__main__":
  70. # Example usage
  71. client = IlluinAPIModelWrapper(
  72. model_name="https://sxeg6spz1yy8unh7.us-east-1.aws.endpoints.huggingface.cloud",
  73. )
  74. embed_queries = client.forward_queries(["What is the capital of France?", "Explain quantum computing."])
  75. images = [
  76. Image.new("RGB", (32, 32), color="white"),
  77. Image.new("RGB", (128, 128), color="black"),
  78. ]
  79. embed_images = client.forward_passages(images)
  80. print("Query embeddings shape:", len(embed_queries))
  81. print("Image embeddings shape:", len(embed_images))