| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- import asyncio
- import base64
- import os
- from io import BytesIO
- from typing import Any, List
- import aiohttp
- import torch
- from PIL import Image
- from tqdm.asyncio import tqdm_asyncio
class IlluinAPIModelWrapper:
    """HTTP client for an Illuin embedding endpoint.

    Text queries and page images are sent to the endpoint one item per
    request. At most ``max_concurrency`` requests are in flight at a time and
    the returned embeddings are in the same order as the inputs.
    """

    def __init__(
        self,
        model_name: str,
        *,
        max_concurrency: int = 16,
        **kwargs: Any,
    ):
        """Wrapper for Illuin API embedding model.

        Args:
            model_name: Used verbatim as the endpoint URL every request is posted to.
            max_concurrency: Upper bound on simultaneous in-flight requests.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If ``max_concurrency`` is smaller than 1.
        """
        if max_concurrency < 1:
            raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")

        self.model_name = model_name
        self.url = model_name
        self.max_concurrency = max_concurrency
        self.HEADERS = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        # Only authenticate when a token is actually configured; formatting a
        # missing token would send the literal header value "Bearer None".
        token = os.getenv("HF_TOKEN")
        if token:
            self.HEADERS["Authorization"] = f"Bearer {token}"

    @staticmethod
    def convert_image_to_base64(image: "Image.Image") -> str:
        """Encode *image* as JPEG and return the bytes as a base64 ASCII string.

        Images whose mode cannot be written as JPEG (for example ones with an
        alpha channel or a palette) are converted to RGB first.
        """
        buffer = BytesIO()
        try:
            image.save(buffer, format="JPEG")
        except OSError:
            # Pillow refuses to write modes JPEG cannot store; retry on an RGB
            # copy with a fresh buffer so no partial output is kept.
            buffer = BytesIO()
            image.convert("RGB").save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    async def _post(self, session: "aiohttp.ClientSession", payload: dict) -> Any:
        """POST *payload* as JSON to the endpoint and return the decoded JSON body.

        Raises:
            aiohttp.ClientResponseError: If the endpoint answers with an HTTP error status.
        """
        async with session.post(self.url, headers=self.HEADERS, json=payload) as response:
            # Fail loudly instead of handing an error body to the caller.
            response.raise_for_status()
            return await response.json()

    async def post_images(self, session: "aiohttp.ClientSession", encoded_images: List[str]):
        """Send base64-encoded images to the endpoint and return its JSON reply."""
        return await self._post(session, {"inputs": {"images": encoded_images}})

    async def post_queries(self, session: "aiohttp.ClientSession", queries: List[str]):
        """Send text queries to the endpoint and return its JSON reply."""
        return await self._post(session, {"inputs": {"queries": queries}})

    async def _embed_all(self, items: List[Any], post: Any, desc: str) -> List[Any]:
        """Embed *items* one per request and return the embeddings in input order.

        Args:
            items: Inputs to embed; each one is sent as a single-element batch.
            post: Coroutine function ``post(session, batch)`` performing one request
                (``self.post_queries`` or ``self.post_images``).
            desc: Label shown on the progress bar.

        Raises:
            RuntimeError: If a reply is not a JSON object with an ``"embeddings"`` entry.
        """
        if not items:
            return []

        semaphore = asyncio.Semaphore(self.max_concurrency)
        async with aiohttp.ClientSession() as session:

            async def bounded_post(item: Any) -> Any:
                async with semaphore:
                    return await post(session, [item])

            tasks = [asyncio.create_task(bounded_post(item)) for item in items]
            try:
                # ORDER-PRESERVING
                results = await tqdm_asyncio.gather(*tasks, desc=desc)
            except Exception:
                # One request failed: stop the rest before the session closes
                # underneath them, and collect their outcomes so nothing is
                # left pending.
                for task in tasks:
                    task.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)
                raise

        embeddings: List[Any] = []
        for result in results:
            # A reply without embeddings would silently shift every following
            # embedding relative to its input, so treat it as an error.
            if not isinstance(result, dict) or "embeddings" not in result:
                raise RuntimeError(f"Unexpected response from {self.url}: {str(result)[:200]}")
            embeddings.extend(result["embeddings"])
        return embeddings

    async def call_api_queries(self, queries: List[str]):
        """Embed text queries concurrently; embeddings follow the order of *queries*."""
        return await self._embed_all(queries, self.post_queries, "Query batches")

    async def call_api_images(self, images_b64: List[str]):
        """Embed base64-encoded images concurrently; embeddings follow the input order."""
        return await self._embed_all(images_b64, self.post_images, "Doc batches")

    def forward_queries(self, queries: List[str]) -> List[Any]:
        """Synchronously embed *queries* and return the list of embeddings from the endpoint.

        Runs its own event loop via ``asyncio.run``, so it must be called from
        synchronous code rather than from inside a running event loop.
        """
        return asyncio.run(self.call_api_queries(queries))

    def forward_passages(self, passages: List["Image.Image"]) -> List[Any]:
        """Synchronously embed page images and return the list of embeddings from the endpoint.

        Each image is JPEG/base64-encoded before being sent. Runs its own event
        loop via ``asyncio.run``, so it must be called from synchronous code.
        """
        encoded = [self.convert_image_to_base64(page) for page in passages]
        return asyncio.run(self.call_api_images(encoded))
if __name__ == "__main__":
    # Manual smoke test: embeds two queries and two blank images against a
    # live endpoint, then prints how many embeddings came back for each.
    endpoint = "https://sxeg6spz1yy8unh7.us-east-1.aws.endpoints.huggingface.cloud"
    wrapper = IlluinAPIModelWrapper(model_name=endpoint)

    sample_queries = ["What is the capital of France?", "Explain quantum computing."]
    query_embeddings = wrapper.forward_queries(sample_queries)

    # One small white page and one larger black page.
    sample_pages = [
        Image.new("RGB", size, color=fill)
        for size, fill in (((32, 32), "white"), ((128, 128), "black"))
    ]
    page_embeddings = wrapper.forward_passages(sample_pages)

    print("Query embeddings shape:", len(query_embeddings))
    print("Image embeddings shape:", len(page_embeddings))
|