# Source: commit 7ce1f69 by gkaucic, "More bbox detection fixes"
import base64
from io import BytesIO
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from typing import Any, Dict
# Checkpoint identifier handed to `from_pretrained` for both the processor
# and the zero-shot object detection model in EndpointHandler.__init__.
MODEL_ID = "IDEA-Research/grounding-dino-base"
class EndpointHandler():
    """Locate the object described by a text prompt inside a base64 image.

    The handler loads the processor and zero-shot detection model named by
    ``MODEL_ID`` once, then answers each request with a single bounding box
    (the highest-scoring detection) or an ``{'error': ...}`` dictionary.
    """

    # Shared message for every "required field is absent" outcome.
    _MISSING_INPUT_ERROR = 'No image_b64 or prompt provided'

    def __init__(self, path=""):
        """Load the processor and model, placing the model on GPU if present.

        Args:
            path: Not used by this handler; weights always come from
                ``MODEL_ID``. Kept so the constructor signature is unchanged.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(MODEL_ID)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(
            MODEL_ID).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run detection for one request.

        Args:
            data: Request body. ``data['inputs']`` must be a mapping with
                ``image_b64`` (base64-encoded image bytes) and ``prompt``
                (text describing the object to find).

        Returns:
            ``{'x1', 'x2', 'y1', 'y2'}`` with the pixel coordinates of the
            best-scoring box (``x1``/``y1`` top-left, ``x2``/``y2``
            bottom-right), or ``{'error': message}`` when the request is
            incomplete, the image cannot be decoded, or nothing is detected.
        """
        payload = data.get('inputs')
        # A missing or non-mapping 'inputs' is reported like any other
        # missing field instead of failing on `.get`.
        if not isinstance(payload, dict):
            return {'error': self._MISSING_INPUT_ERROR}

        image_b64 = payload.get('image_b64')
        prompt = payload.get('prompt')
        if image_b64 is None or prompt is None:
            return {'error': self._MISSING_INPUT_ERROR}

        try:
            raw_bytes = base64.b64decode(image_b64)
            # convert() forces the lazy decode to happen here and gives the
            # processor a 3-channel image even for RGBA/palette/grayscale
            # uploads.
            image = Image.open(BytesIO(raw_bytes)).convert('RGB')
        except (TypeError, ValueError, OSError) as exc:
            # binascii.Error is a ValueError; PIL's UnidentifiedImageError is
            # an OSError. Report them in the handler's error-dict style.
            return {'error': f'Could not decode image_b64: {exc}'}

        model_inputs = self.processor(
            images=image,
            text=self._normalize_prompt(prompt),
            return_tensors='pt',
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**model_inputs)

        results = self.processor.post_process_grounded_object_detection(
            outputs,
            model_inputs.input_ids,
            text_threshold=0.3,
            # PIL reports (width, height); the processor wants (height, width).
            target_sizes=[image.size[::-1]],
        )

        if len(results) == 0 or len(results[0]['boxes']) == 0:
            return {'error': 'No bounding boxes found'}

        return self._best_box(results[0])

    @staticmethod
    def _normalize_prompt(prompt):
        """Return a string prompt lowercased and terminated with a dot.

        Grounding DINO's documentation asks for lowercased queries ending in
        '.'; non-string prompts are passed through untouched.
        """
        if not isinstance(prompt, str):
            return prompt
        text = prompt.strip().lower()
        if text and not text.endswith('.'):
            text += '.'
        return text

    @staticmethod
    def _best_box(detection):
        """Convert the highest-scoring box of one image's result to a dict.

        Post-processed boxes are laid out as (x_min, y_min, x_max, y_max) and
        are not ordered by score, so the best one is picked explicitly.
        """
        boxes = detection['boxes']
        scores = detection.get('scores')
        if scores is not None and len(scores) == len(boxes):
            best = int(scores.argmax())
        else:
            # No usable scores: fall back to the first box.
            best = 0
        x_min, y_min, x_max, y_max = boxes[best].tolist()
        return {
            'x1': x_min,
            'x2': x_max,
            'y1': y_min,
            'y2': y_max
        }