from __future__ import annotations from typing import Any import torch from transformers import AutoConfig from transformers.feature_extraction_utils import BatchFeature from transformers.image_processing_utils import BaseImageProcessor from transformers.processing_utils import ProcessorMixin from transformers.tokenization_utils_base import PreTrainedTokenizerBase class CapriProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "SiglipImageProcessor" tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") def __init__( self, image_processor: BaseImageProcessor, tokenizer: PreTrainedTokenizerBase, prompt_prefix: str = " Caption:", image_token: str = "", pooled_embedding_dim: int = 768, ): self.prompt_prefix = prompt_prefix self.image_token = image_token self.pooled_embedding_dim = pooled_embedding_dim super().__init__(image_processor, tokenizer) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): kwargs.setdefault("use_fast", False) processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs) config = AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=kwargs.get("trust_remote_code", False), ) processor.prompt_prefix = getattr(config, "prompt_prefix", processor.prompt_prefix) processor.image_token = getattr(config, "image_token", processor.image_token) processor.pooled_embedding_dim = getattr(config, "projector_in_dim", processor.pooled_embedding_dim) return processor def normalize_images(self, images) -> list[Any]: if isinstance(images, torch.Tensor): if images.ndim == 4: return [images[i] for i in range(images.shape[0])] return [images] if isinstance(images, (list, tuple)): return list(images) return [images] def normalize_pooled_embeddings(self, pooled_embeddings) -> torch.Tensor: pooled = torch.as_tensor(pooled_embeddings) if pooled.ndim == 1: pooled = pooled.unsqueeze(0) if pooled.ndim != 2: raise ValueError("`pooled_embeddings` must be a 1D embedding or a 2D batch of embeddings.") if pooled.shape[-1] != self.pooled_embedding_dim: raise ValueError( f"Expected pooled embedding dim {self.pooled_embedding_dim}, got {pooled.shape[-1]}." ) return pooled def __call__( self, images=None, pooled_embeddings=None, text=None, return_tensors: str | None = "pt", padding: bool | str = True, truncation: bool = False, max_length: int | None = None, **kwargs: Any, ) -> BatchFeature: if images is None and pooled_embeddings is None and text is None: raise ValueError("Provide `images`, `pooled_embeddings`, or `text`.") batch = {} batch_size = None if images is not None: image_features = self.image_processor(images=images, return_tensors=return_tensors, **kwargs) batch.update(dict(image_features)) batch_size = batch["pixel_values"].shape[0] if pooled_embeddings is not None: pooled = self.normalize_pooled_embeddings(pooled_embeddings) batch["pooled_embeddings"] = pooled batch_size = pooled.shape[0] if text is None and batch_size is not None: text = [self.prompt_prefix] * batch_size if text is not None: if isinstance(text, str): text = [text] tokenized = self.tokenizer( text, add_special_tokens=False, padding=padding, truncation=truncation, max_length=max_length, return_tensors=return_tensors, ) batch.update(dict(tokenized)) return BatchFeature(data=batch, tensor_type=return_tensors) def batch_decode(self, *args, **kwargs): return self.tokenizer.batch_decode(*args, **kwargs) def decode(self, *args, **kwargs): return self.tokenizer.decode(*args, **kwargs)