from __future__ import annotations

from typing import Any

import torch
from transformers import AutoConfig
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


class CapriProcessor(ProcessorMixin):
    """Processor for Capri image captioning: pairs a SigLIP image processor
    with a Qwen2 tokenizer and carries the prompt and embedding settings the
    model expects."""

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "SiglipImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self,
        image_processor: BaseImageProcessor,
        tokenizer: PreTrainedTokenizerBase,
        prompt_prefix: str = "<image> Caption:",
        image_token: str = "<image>",
        pooled_embedding_dim: int = 768,
    ):
        # Set custom attributes before ProcessorMixin.__init__ so they are in
        # place when the mixin wires up and validates its components.
        self.prompt_prefix = prompt_prefix
        self.image_token = image_token
        self.pooled_embedding_dim = pooled_embedding_dim
        super().__init__(image_processor, tokenizer)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the processor, then adopt prompt and embedding settings from
        the model config when it defines them."""
        kwargs.setdefault("use_fast", False)
        processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=kwargs.get("trust_remote_code", False),
        )
        processor.prompt_prefix = getattr(config, "prompt_prefix", processor.prompt_prefix)
        processor.image_token = getattr(config, "image_token", processor.image_token)
        processor.pooled_embedding_dim = getattr(
            config, "projector_in_dim", processor.pooled_embedding_dim
        )
        return processor
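
    # Hedged usage sketch: "org/capri-base" is a placeholder checkpoint name,
    # not a published model; any repo bundling a SigLIP image processor and a
    # Qwen2 tokenizer with this class would work the same way:
    #
    #   processor = CapriProcessor.from_pretrained(
    #       "org/capri-base", trust_remote_code=True
    #   )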

    def normalize_images(self, images) -> list[Any]:
        """Coerce `images` into a flat list of single images, splitting a
        batched 4D tensor into per-image 3D tensors."""
        if isinstance(images, torch.Tensor):
            if images.ndim == 4:
                return [images[i] for i in range(images.shape[0])]
            return [images]
        if isinstance(images, (list, tuple)):
            return list(images)
        return [images]
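
    # Illustrative shapes: a (2, 3, 224, 224) batch tensor becomes a list of
    # two (3, 224, 224) tensors, the form image processors accept uniformly:
    #   processor.normalize_images(torch.zeros(2, 3, 224, 224))  # list of 2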

    def normalize_pooled_embeddings(self, pooled_embeddings) -> torch.Tensor:
        """Coerce `pooled_embeddings` to a 2D `(batch, dim)` tensor and
        validate the embedding dimension."""
        pooled = torch.as_tensor(pooled_embeddings)
        if pooled.ndim == 1:
            pooled = pooled.unsqueeze(0)
        if pooled.ndim != 2:
            raise ValueError("`pooled_embeddings` must be a 1D embedding or a 2D batch of embeddings.")
        if pooled.shape[-1] != self.pooled_embedding_dim:
            raise ValueError(
                f"Expected pooled embedding dim {self.pooled_embedding_dim}, got {pooled.shape[-1]}."
            )
        return pooled
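
    # Illustrative shapes: a single 768-d vector is promoted to a batch of
    # one, so callers can always treat dim 0 as the batch dimension:
    #   processor.normalize_pooled_embeddings(torch.zeros(768)).shape  # (1, 768)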

    def __call__(
        self,
        images=None,
        pooled_embeddings=None,
        text=None,
        return_tensors: str | None = "pt",
        padding: bool | str = True,
        truncation: bool = False,
        max_length: int | None = None,
        **kwargs: Any,
    ) -> BatchFeature:
        """Prepare model inputs from images, precomputed pooled embeddings,
        and/or text. When no text is given, the prompt prefix is repeated
        once per batch item."""
        if images is None and pooled_embeddings is None and text is None:
            raise ValueError("Provide `images`, `pooled_embeddings`, or `text`.")

        batch = {}
        batch_size = None

        if images is not None:
            image_features = self.image_processor(
                images=self.normalize_images(images), return_tensors=return_tensors, **kwargs
            )
            batch.update(dict(image_features))
            # `len` works whether `pixel_values` is a tensor or, for
            # `return_tensors=None`, a plain list.
            batch_size = len(batch["pixel_values"])

        if pooled_embeddings is not None:
            pooled = self.normalize_pooled_embeddings(pooled_embeddings)
            if batch_size is not None and pooled.shape[0] != batch_size:
                raise ValueError(
                    f"Got {batch_size} images but {pooled.shape[0]} pooled embeddings."
                )
            batch["pooled_embeddings"] = pooled
            batch_size = pooled.shape[0]

        if text is None and batch_size is not None:
            text = [self.prompt_prefix] * batch_size

        if text is not None:
            if isinstance(text, str):
                text = [text]
            tokenized = self.tokenizer(
                text,
                add_special_tokens=False,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
            )
            batch.update(dict(tokenized))

        return BatchFeature(data=batch, tensor_type=return_tensors)
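
    # Illustrative call with precomputed embeddings (the last dimension must
    # match `pooled_embedding_dim`, 768 by default):
    #   batch = processor(pooled_embeddings=torch.randn(4, 768))
    #   batch["input_ids"]  # four copies of the tokenized prompt prefix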

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's `batch_decode`."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's `decode`."""
        return self.tokenizer.decode(*args, **kwargs)
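

if __name__ == "__main__":
    # Hedged smoke test: "org/capri-base" is a placeholder checkpoint name.
    # Point this at any repo that ships the processor's components to try
    # the full path end to end.
    import numpy as np

    processor = CapriProcessor.from_pretrained("org/capri-base")
    dummy_image = np.zeros((224, 224, 3), dtype=np.uint8)
    batch = processor(images=[dummy_image])
    print(batch["pixel_values"].shape)  # (1, 3, H, W) after SigLIP preprocessing
    print(batch["input_ids"].shape)     # one tokenized copy of the prompt prefix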