# capri/processing_capri.py
# (uploaded via huggingface_hub — commit fd6509b, verified)
from __future__ import annotations
from typing import Any
import torch
from transformers import AutoConfig
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
class CapriProcessor(ProcessorMixin):
    """Processor for Capri models.

    Pairs a SigLIP image processor with a Qwen2 tokenizer and supports an
    additional "pooled embedding" pathway: instead of raw images, callers may
    provide pre-pooled image embeddings of width ``pooled_embedding_dim``.

    When only images/embeddings are given (no ``text``), each sample is paired
    with ``prompt_prefix`` as its text prompt.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "SiglipImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self,
        image_processor: BaseImageProcessor,
        tokenizer: PreTrainedTokenizerBase,
        prompt_prefix: str = "<image> Caption:",
        image_token: str = "<image>",
        pooled_embedding_dim: int = 768,
    ):
        """Create the processor.

        Args:
            image_processor: SigLIP image processor used for ``pixel_values``.
            tokenizer: Qwen2 tokenizer used for the text branch.
            prompt_prefix: Default text prompt used when no ``text`` is passed.
            image_token: Placeholder token marking the image position in text.
            pooled_embedding_dim: Expected width of pre-pooled embeddings.
        """
        # Assign custom attributes BEFORE ProcessorMixin.__init__, which may
        # serialize the processor state and expects these to exist.
        self.prompt_prefix = prompt_prefix
        self.image_token = image_token
        self.pooled_embedding_dim = pooled_embedding_dim
        super().__init__(image_processor, tokenizer)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the processor, then sync prompt/token/dim settings from the
        model config (``prompt_prefix``, ``image_token``, ``projector_in_dim``).

        Config values override the constructor defaults when present.
        """
        # NOTE(review): slow tokenizer forced by default — presumably the fast
        # tokenizer misbehaves for this checkpoint; confirm before changing.
        kwargs.setdefault("use_fast", False)
        processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=kwargs.get("trust_remote_code", False),
        )
        # Fall back to whatever the processor already carries when the config
        # does not define the attribute.
        processor.prompt_prefix = getattr(config, "prompt_prefix", processor.prompt_prefix)
        processor.image_token = getattr(config, "image_token", processor.image_token)
        processor.pooled_embedding_dim = getattr(config, "projector_in_dim", processor.pooled_embedding_dim)
        return processor

    def normalize_images(self, images) -> list[Any]:
        """Coerce ``images`` into a flat list of single images.

        A 4D tensor is split along its first (batch) dimension; a 3D tensor is
        treated as one image; lists/tuples are shallow-copied; anything else
        (e.g. a single PIL image) is wrapped in a one-element list.
        """
        if isinstance(images, torch.Tensor):
            if images.ndim == 4:
                return [images[i] for i in range(images.shape[0])]
            return [images]
        if isinstance(images, (list, tuple)):
            return list(images)
        return [images]

    def normalize_pooled_embeddings(self, pooled_embeddings) -> torch.Tensor:
        """Validate and reshape pooled embeddings to a 2D ``(batch, dim)`` tensor.

        Raises:
            ValueError: If the input is not 1D/2D or its last dimension does
                not equal ``pooled_embedding_dim``.
        """
        pooled = torch.as_tensor(pooled_embeddings)
        if pooled.ndim == 1:
            # Single embedding -> batch of one.
            pooled = pooled.unsqueeze(0)
        if pooled.ndim != 2:
            raise ValueError("`pooled_embeddings` must be a 1D embedding or a 2D batch of embeddings.")
        if pooled.shape[-1] != self.pooled_embedding_dim:
            raise ValueError(
                f"Expected pooled embedding dim {self.pooled_embedding_dim}, got {pooled.shape[-1]}."
            )
        return pooled

    def __call__(
        self,
        images=None,
        pooled_embeddings=None,
        text=None,
        return_tensors: str | None = "pt",
        padding: bool | str = True,
        truncation: bool = False,
        max_length: int | None = None,
        **kwargs: Any,
    ) -> BatchFeature:
        """Build a model-ready batch from images, pooled embeddings, and/or text.

        At least one of ``images``, ``pooled_embeddings``, ``text`` is required.
        When ``text`` is omitted, ``prompt_prefix`` is repeated once per sample.

        Returns:
            BatchFeature with any of ``pixel_values``, ``pooled_embeddings``,
            and tokenizer outputs (``input_ids``, ``attention_mask``).

        Raises:
            ValueError: If no inputs are provided, or pooled embeddings have
                the wrong shape/width.
        """
        if images is None and pooled_embeddings is None and text is None:
            raise ValueError("Provide `images`, `pooled_embeddings`, or `text`.")
        batch = {}
        batch_size = None
        if images is not None:
            image_features = self.image_processor(images=images, return_tensors=return_tensors, **kwargs)
            batch.update(dict(image_features))
            pixel_values = batch["pixel_values"]
            # BUGFIX: with `return_tensors=None` the image processor returns a
            # plain list of per-image arrays, which has no `.shape`; fall back
            # to `len()` so the default-prompt path still works.
            batch_size = pixel_values.shape[0] if hasattr(pixel_values, "shape") else len(pixel_values)
        if pooled_embeddings is not None:
            pooled = self.normalize_pooled_embeddings(pooled_embeddings)
            batch["pooled_embeddings"] = pooled
            batch_size = pooled.shape[0]
        if text is None and batch_size is not None:
            # No explicit text: pair every visual sample with the default prompt.
            text = [self.prompt_prefix] * batch_size
        if text is not None:
            if isinstance(text, str):
                text = [text]
            tokenized = self.tokenizer(
                text,
                add_special_tokens=False,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
            )
            batch.update(dict(tokenized))
        return BatchFeature(data=batch, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """Forward to ``tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to ``tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)