| import streamlit as st |
| import pandas as pd |
| import torch |
| import vec2text |
| from transformers import AutoModel, AutoTokenizer |
| from sklearn.decomposition import PCA |
| from utils import file_cache |
| from transformers import PreTrainedModel, PreTrainedTokenizer |
|
|
| |
@st.cache_resource
def load_corrector():
    """Return vec2text's pretrained corrector for the "gtr-base" embedder.

    Because of ``st.cache_resource``, the corrector is loaded once per
    Streamlit server process and the same object is shared by every session
    and rerun.
    """
    corrector = vec2text.load_pretrained_corrector("gtr-base")
    return corrector
|
|
| |
@st.cache_data
def load_data():
    """Download ``train.csv`` of the reddit-syac-urls dataset as a DataFrame.

    ``st.cache_data`` memoizes the result. The file is fetched over the
    network only on the first call, and later calls get a copy of the cached
    DataFrame.
    """
    csv_url = "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    frame = pd.read_csv(csv_url)
    return frame
|
|
|
|
def vector_compressor_from_config():
    """Build a fresh, unfitted PCA that projects embeddings to 2 components.

    This function is deliberately NOT wrapped in ``st.cache_resource``. That
    decorator hands every caller (every session and rerun) the very same
    object, but this estimator is stateful: ``reduce_embeddings`` calls
    ``fit_transform`` on it, which mutates it in place. A shared instance
    would let concurrent sessions fit the same PCA at once. It would also mean
    each new caller receives an estimator already fitted by someone else.
    Constructing a PCA is trivially cheap, so a new one is returned per call.

    Returns:
        An unfitted ``sklearn.decomposition.PCA`` with ``n_components=2``.
    """
    return PCA(n_components=2)
|
|
|
|
@st.cache_data
@file_cache(".cache/reducer_embeddings.pickle")
def reduce_embeddings(embeddings):
    """Fit the configured reducer on ``embeddings`` and project them.

    Returns:
        A ``(reduced, reducer)`` tuple:
        - ``reduced`` is the output of ``reducer.fit_transform(embeddings)``.
        - ``reducer`` is the estimator from ``vector_compressor_from_config()``,
          now fitted.

    NOTE(review): ``file_cache`` comes from the local ``utils`` module, and its
    keying/invalidation is not visible here. The cache path is a fixed string,
    so confirm whether a stored result is reused even when ``embeddings``
    changes.
    """
    compressor = vector_compressor_from_config()
    projected = compressor.fit_transform(embeddings)
    return projected, compressor
|
|
| |
@st.cache_resource
def load_model_and_tokenizer(device="cpu"):
    """Load the GTR-T5-base encoder stack and its tokenizer.

    The whole ``sentence-transformers/gtr-t5-base`` model is loaded, but only
    its ``.encoder`` submodule is kept and moved to ``device``. Because of
    ``st.cache_resource``, the pair is cached per ``device`` value and shared
    across sessions.

    Args:
        device: Torch device string that the encoder is moved to.

    Returns:
        ``(encoder, tokenizer)``.
    """
    checkpoint = "sentence-transformers/gtr-t5-base"
    full_model = AutoModel.from_pretrained(checkpoint)
    encoder = full_model.encoder.to(device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return encoder, tokenizer
|
|
|
|
def get_gtr_embeddings(text_list: list[str],
                       encoder: PreTrainedModel,
                       tokenizer: PreTrainedTokenizer,
                       device: str,
                       ) -> torch.Tensor:
    """Embed ``text_list`` by mean-pooling the encoder's last hidden state.

    Steps:
    1. Each text is tokenized, truncated and padded to exactly 128 tokens.
    2. The batch is moved to ``device``, so ``encoder`` is expected to live on
       the same device.
    3. The encoder runs under ``torch.no_grad()``.
    4. ``vec2text``'s ``mean_pool`` averages the last hidden state using the
       attention mask.

    Returns:
        The pooled embeddings tensor, presumably ``(len(text_list), hidden_dim)``.
        TODO confirm against ``mean_pool``.
    """
    batch = tokenizer(
        text_list,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length",
    ).to(device)
    token_ids = batch["input_ids"]
    mask = batch["attention_mask"]

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        last_hidden = encoder(input_ids=token_ids, attention_mask=mask).last_hidden_state
        # The mask keeps padding positions from contributing to the average.
        return vec2text.models.model_utils.mean_pool(last_hidden, mask)