ColBERT-Zero: INT8 ONNX

A dynamic-INT8 ONNX export of lightonai/ColBERT-Zero for CPU inference. ~150 MB on disk (vs. ~600 MB for the fp32 safetensors). The projection (768 → 128) and L2 normalization are fused into the graph, so the ONNX model returns per-token embeddings that are already unit-norm.
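The roughly 4x size reduction follows from storing each weight in one byte instead of four. Below is a minimal numpy sketch of symmetric per-tensor INT8 weight quantization, for illustration only. The 768 x 128 shape mirrors the projection mentioned above, the weights are random stand-ins, and the scheme itself is an assumption: it is not necessarily what the export script or onnxruntime does internally.

```python
import numpy as np


def quantize_int8(w):
    """Symmetric per-tensor quantization: w ~= scale * q, with q in [-127, 127]."""
    scale = float(np.abs(w).max()) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    """Map int8 codes back to approximate float32 weights."""
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
w = rng.standard_normal((768, 128)).astype(np.float32)  # random stand-in weights
q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)

size_ratio = w.nbytes / q.nbytes          # 4 bytes per fp32 weight vs. 1 per int8
max_err = float(np.abs(w - w_hat).max())  # rounding error is about scale / 2 at most
print(size_ratio, max_err, scale / 2)
```

The per-weight error is bounded by roughly half the quantization step, which is why ranking quality can survive the compression; whether it does for a given model has to be measured.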

What's in this repo

| File | Purpose |
| --- | --- |
| model_int8.onnx | Dynamic-INT8 quantized graph |
| tokenizer.json | pylate-extended tokenizer (adds [Q] / [D] special tokens) |
| onnx_config.json | prompts, prefixes, lengths, skiplist token IDs |
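The keys that the standalone usage code in this README reads from onnx_config.json can be checked up front. This is a small sketch that assumes only those keys; `REQUIRED_KEYS` and `load_config` are illustrative names, and the real file may contain additional fields.

```python
import json

# Keys read by the standalone usage code in this README (the file may hold more).
REQUIRED_KEYS = (
    "pad_token_id", "skiplist_words",
    "query_prompt", "query_prefix", "query_length",
    "document_prompt", "document_prefix", "document_length",
)


def load_config(path):
    """Load the JSON config and fail early if an expected key is missing."""
    with open(path, encoding="utf-8") as f:
        cfg = json.load(f)
    missing = [k for k in REQUIRED_KEYS if k not in cfg]
    if missing:
        raise KeyError(f"onnx_config.json is missing keys: {missing}")
    return cfg
```

Failing at load time gives a clearer error than a `KeyError` raised later, in the middle of encoding.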

Attribution

Only the serving artifacts are redistributed here; the upstream license applies.

Standalone usage (onnxruntime + tokenizers + numpy only)

No pylate, sentence-transformers, or torch needed at inference time; huggingface_hub is used only to download the files.

import json
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from tokenizers import Tokenizer

local = snapshot_download("thomasht86/ColBERT-Zero-onnx-int8")
session = ort.InferenceSession(f"{local}/model_int8.onnx")
tokenizer = Tokenizer.from_file(f"{local}/tokenizer.json")
with open(f"{local}/onnx_config.json", encoding="utf-8") as f:
    cfg = json.load(f)  # close the file handle deterministically

PAD = cfg["pad_token_id"]
SKIP = {PAD} | {
    tid for w in cfg["skiplist_words"]
    if (tid := tokenizer.token_to_id(w)) is not None
}


def encode(texts, prompt, prefix, max_len, pad_to_max):
    tokenizer.enable_truncation(max_length=max_len)
    if pad_to_max:
        tokenizer.enable_padding(length=max_len, pad_id=PAD, pad_token="[PAD]")
    else:
        tokenizer.no_padding()
    encs = tokenizer.encode_batch([prompt + prefix + t for t in texts])
    L = max(len(e.ids) for e in encs)
    ids = np.full((len(encs), L), PAD, dtype=np.int64)
    mask = np.zeros_like(ids)
    for i, e in enumerate(encs):
        ids[i, : len(e.ids)] = e.ids
        mask[i, : len(e.ids)] = e.attention_mask
    out = session.run(["output"], {"input_ids": ids, "attention_mask": mask})[0]
    return ids, out  # out shape [B, L, 128], L2-normalized


def maxsim(query, documents):
    q_ids, q_out = encode(
        [query], cfg["query_prompt"], cfg["query_prefix"], cfg["query_length"], pad_to_max=True
    )
    d_ids, d_out = encode(
        documents, cfg["document_prompt"], cfg["document_prefix"], cfg["document_length"], pad_to_max=False
    )
    q_emb = q_out[0][q_ids[0] != PAD]  # drop query padding
    scores = []
    for i in range(len(documents)):
        keep = np.array([t not in SKIP for t in d_ids[i]])
        d_emb = d_out[i][keep]  # drop padding + punctuation
        scores.append(float((q_emb @ d_emb.T).max(axis=1).sum()))
    return scores


docs = [
    "Paris is the capital of France.",
    "Berlin is the capital of Germany.",
    "The Eiffel Tower is a famous Parisian landmark.",
]
for doc, s in sorted(zip(docs, maxsim("what is the capital of France?", docs)),
                     key=lambda x: -x[1]):
    print(f"{s:6.3f}  {doc}")

Output:

 9.xxx  Paris is the capital of France.
 8.xxx  The Eiffel Tower is a famous Parisian landmark.
 7.xxx  Berlin is the capital of Germany.
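The scoring step in maxsim is late interaction: each query token takes its best-matching document token, and those maxima are summed. The following is a hedged, batched numpy sketch of the same computation on synthetic unit-norm embeddings. `maxsim_batch`, `unit`, and the boolean `keep` mask are illustrative names, not part of this repo.

```python
import numpy as np


def maxsim_batch(q_emb, d_emb, keep):
    """Late-interaction scores for one query against a padded batch of documents.

    q_emb: [Q, D] query token embeddings
    d_emb: [B, L, D] document token embeddings (padded to a common length L)
    keep:  [B, L] bool, False for padding / skiplist tokens
    Returns a [B] array: sum over query tokens of the max similarity per document.
    """
    sim = np.einsum("qd,bld->bql", q_emb, d_emb)    # [B, Q, L] dot products
    sim = np.where(keep[:, None, :], sim, -np.inf)  # masked tokens can never win the max
    return sim.max(axis=2).sum(axis=1)


def unit(x):
    """L2-normalize along the last axis."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


rng = np.random.default_rng(0)
q = unit(rng.standard_normal((4, 128)))
d = unit(rng.standard_normal((3, 6, 128)))
keep = np.ones((3, 6), dtype=bool)
keep[1, 4:] = False  # pretend the last two tokens of document 1 are padding
scores = maxsim_batch(q, d, keep)
```

Masking with `-inf` gives the same result as dropping the rows before the matrix product, provided every document keeps at least one token; a fully masked document would score `-inf`.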

ONNX I/O

| Tensor | Shape | Dtype |
| --- | --- | --- |
| input_ids | [batch, seq] | int64 |
| attention_mask | [batch, seq] | int64 |
| output | [batch, seq, 128] | float32 (unit-norm along last dim) |

query_length = 39 and document_length = 519; the batch and seq axes are both dynamic.
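A small sketch of sanity checks that follow from the I/O description above: both inputs are int64 with the same [batch, seq] shape, and every output vector should have an L2 norm close to 1. `check_inputs` and `is_unit_norm` are hypothetical helpers, and the tolerance is an arbitrary choice.

```python
import numpy as np


def check_inputs(input_ids, attention_mask):
    """Validate arrays against the documented input signature."""
    for name, arr in (("input_ids", input_ids), ("attention_mask", attention_mask)):
        if arr.dtype != np.int64:
            raise TypeError(f"{name} must be int64, got {arr.dtype}")
        if arr.ndim != 2:
            raise ValueError(f"{name} must have shape [batch, seq], got {arr.shape}")
    if input_ids.shape != attention_mask.shape:
        raise ValueError("input_ids and attention_mask must have the same shape")


def is_unit_norm(output, atol=1e-3):
    """True if every vector along the last axis has L2 norm within atol of 1."""
    norms = np.linalg.norm(output, axis=-1)
    return bool(np.all(np.abs(norms - 1.0) <= atol))
```

Calling `check_inputs(ids, mask)` before `session.run` and `is_unit_norm(out)` after it turns a silent dtype or shape mismatch into an explicit error.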

Reproducing the export

uv run --python 3.12 --with pylate --with torch --with onnx \
       --with onnxruntime --with onnxscript --with accelerate \
       scripts/export_colbert_onnx.py

See scripts/export_colbert_onnx.py for the underlying pattern.
