
How to Build a Matryoshka-Optimized Sentence Embedding Model for Ultra-Fast Retrieval with 64-Dimension Truncation

In this tutorial, we fine-tune a Sentence-Transformers embedding model with Matryoshka Representation Learning (MRL) so that the earliest dimensions of the vector carry the most useful semantic signal. We train with MatryoshkaLoss on triplet data and then validate the key promise of MRL by benchmarking retrieval quality after truncating embeddings to 64, 128, and 256 dimensions. At the end, we save the tuned model and demonstrate how to load it with a small truncate_dim setting for fast and memory-efficient vector search. Check out the FULL CODES here.

!pip -q install -U sentence-transformers datasets accelerate


import math
import random
import numpy as np
import torch


from datasets import load_dataset
from torch.utils.data import DataLoader


from sentence_transformers import SentenceTransformer, InputExample
from sentence_transformers import losses
from sentence_transformers.util import cos_sim




def set_seed(seed=42):
   random.seed(seed)
   np.random.seed(seed)
   torch.manual_seed(seed)
   torch.cuda.manual_seed_all(seed)


set_seed(42)

We install the required libraries and import all the modules needed for training and evaluation. We set a deterministic seed so our sampling and training behavior stays consistent across runs, and we make sure the PyTorch and CUDA RNGs are aligned when a GPU is available.

@torch.no_grad()
def retrieval_metrics_mrr_recall_at_k(
    model,
    queries,
    corpus,
    qrels,
    dims_list=(64, 128, 256, None),
    k=10,
    batch_size=64,
):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    qids = list(queries.keys())
    docids = list(corpus.keys())

    q_texts = [queries[qid] for qid in qids]
    d_texts = [corpus[did] for did in docids]

    q_emb = model.encode(q_texts, batch_size=batch_size, convert_to_tensor=True, normalize_embeddings=True)
    d_emb = model.encode(d_texts, batch_size=batch_size, convert_to_tensor=True, normalize_embeddings=True)

    results = {}

    for dim in dims_list:
        if dim is None:
            qe = q_emb
            de = d_emb
            dim_name = "full"
        else:
            # Keep only the first `dim` dimensions, then re-normalize the prefixes.
            qe = q_emb[:, :dim]
            de = d_emb[:, :dim]
            dim_name = str(dim)
            qe = torch.nn.functional.normalize(qe, p=2, dim=1)
            de = torch.nn.functional.normalize(de, p=2, dim=1)

        sims = cos_sim(qe, de)

        mrr_total = 0.0
        recall_total = 0.0

        for i, qid in enumerate(qids):
            rel = qrels.get(qid, set())
            if not rel:
                continue

            topk = torch.topk(sims[i], k=min(k, sims.shape[1]), largest=True).indices.tolist()
            topk_docids = [docids[j] for j in topk]

            recall_total += 1.0 if any(d in rel for d in topk_docids) else 0.0

            rr = 0.0
            for rank, d in enumerate(topk_docids, start=1):
                if d in rel:
                    rr = 1.0 / rank
                    break
            mrr_total += rr

        denom = max(1, len(qids))
        results[dim_name] = {f"MRR@{k}": mrr_total / denom, f"Recall@{k}": recall_total / denom}

    return results




def pretty_print(results, title):
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)
    for dim, metrics in results.items():
        print(f"dim={dim:>4} | " + " | ".join([f"{k}={v:.4f}" for k, v in metrics.items()]))

We implement a lightweight retrieval evaluator that encodes queries and documents, computes cosine similarity, and reports MRR@10 and Recall@10. We re-normalize embeddings after truncation so smaller prefixes remain comparable in cosine space, and we add a compact printer to make before/after comparisons easy to read.
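To see why that re-normalization step matters, here is a tiny standalone sketch (our own toy example, not part of the tutorial's pipeline): a truncated prefix of a unit-length embedding is generally no longer unit-length, so dot-product scores over prefixes only become cosine similarities after normalizing again.

import torch

# Toy 8-dimensional unit vectors standing in for full embeddings (illustrative only).
a = torch.nn.functional.normalize(torch.randn(1, 8), p=2, dim=1)
b = torch.nn.functional.normalize(torch.randn(1, 8), p=2, dim=1)

# A 4-dimensional prefix of a unit vector is usually shorter than unit length...
a4, b4 = a[:, :4], b[:, :4]
print(a4.norm().item())  # typically < 1.0

# ...so we re-normalize before scoring with a dot product (i.e., cosine similarity).
a4n = torch.nn.functional.normalize(a4, p=2, dim=1)
b4n = torch.nn.functional.normalize(b4, p=2, dim=1)
print((a4n * b4n).sum().item())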

DATASET_ID = "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1"
SUBSET = "triplet-hard"
SPLIT = "prepare"


TRAIN_SAMPLES = 4000
EVAL_QUERIES = 300


stream = load_dataset(DATASET_ID, SUBSET, split=SPLIT, streaming=True)


train_examples = []
eval_queries = {}
eval_corpus = {}
eval_qrels = {}


doc_id_counter = 0
qid_counter = 0


for row in stream:
    q = (row.get("query") or "").strip()
    pos = (row.get("positive") or "").strip()
    neg = (row.get("negative") or "").strip()

    if not q or not pos or not neg:
        continue

    train_examples.append(InputExample(texts=[q, pos, neg]))

    if len(eval_queries) < EVAL_QUERIES:
        qid = f"q{qid_counter}"
        qid_counter += 1

        pos_id = f"d{doc_id_counter}"; doc_id_counter += 1
        neg_id = f"d{doc_id_counter}"; doc_id_counter += 1

        eval_queries[qid] = q
        eval_corpus[pos_id] = pos
        eval_corpus[neg_id] = neg
        eval_qrels[qid] = {pos_id}

    if len(train_examples) >= TRAIN_SAMPLES and len(eval_queries) >= EVAL_QUERIES:
        break


print(len(train_examples), len(eval_queries), len(eval_corpus))

We stream a mined MS MARCO triplet dataset and build both a training set (queries, positives, negatives) and a tiny IR benchmark set. We map each query to a relevant positive document and include a negative document to make retrieval meaningful, and we stop early to keep the run Colab-friendly while still large enough to show truncation effects.
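To make the shape of this tiny benchmark concrete, an optional sanity check (our own addition, not in the original script) can print one query together with its relevant document:

# Peek at the first evaluation query and its single relevant (positive) document.
sample_qid = next(iter(eval_queries))
sample_did = next(iter(eval_qrels[sample_qid]))
print("query:", eval_queries[sample_qid])
print("relevant doc id:", sample_did)
print("relevant doc text:", eval_corpus[sample_did][:200])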

MODEL_ID = "BAAI/bge-base-en-v1.5"


system = "cuda" if torch.cuda.is_available() else "cpu"
mannequin = SentenceTransformer(MODEL_ID, system=system)
full_dim = mannequin.get_sentence_embedding_dimension()


baseline = retrieval_metrics_mrr_recall_at_k(
    model,
    queries=eval_queries,
    corpus=eval_corpus,
    qrels=eval_qrels,
    dims_list=(64, 128, 256, None),
    k=10,
)
pretty_print(baseline, "BEFORE")

We load a strong base embedding model and record its full embedding dimension. We run the baseline evaluation across 64/128/256/full dimensions to see how truncation behaves before any training, and we print the results so we can later check whether MRL improves early-dimension quality.

batch_size = 16
epochs = 1
warmup_steps = 100


train_loader = DataLoader(train_examples, batch_size=batch_size, shuffle=True, drop_last=True)


base_loss = losses.MultipleNegativesRankingLoss(model=model)


mrl_dims = [full_dim, 512, 256, 128, 64] if full_dim >= 768 else [full_dim, 256, 128, 64]
mrl_loss = losses.MatryoshkaLoss(
    model=model,
   loss=base_loss,
   matryoshka_dims=mrl_dims
)


model.fit(
   train_objectives=[(train_loader, mrl_loss)],
   epochs=epochs,
   warmup_steps=warmup_steps,
   show_progress_bar=True,
)


after = retrieval_metrics_mrr_recall_at_k(
    model,
    queries=eval_queries,
    corpus=eval_corpus,
    qrels=eval_qrels,
    dims_list=(64, 128, 256, None),
    k=10,
)
pretty_print(after, "AFTER")


out_dir = "mrl-msmarco-demo"
model.save(out_dir)


m64 = SentenceTransformer(out_dir, truncate_dim=64)
emb = m64.encode(
   ["what is the liberal arts?", "liberal arts covers humanities and sciences"],
   normalize_embeddings=True
)
print(emb.shape)

We create a MultipleNegativesRankingLoss and wrap it with MatryoshkaLoss using a descending list of target prefix dimensions. We fine-tune the model on the triplets, then re-run the same truncation benchmark to measure how much quality the small prefixes retain. Finally, we save the model and reload it with truncate_dim=64 to confirm practical usage for compact retrieval.
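As a rough back-of-the-envelope estimate (our own arithmetic, not a measured result), truncating from 768 to 64 dimensions shrinks a float32 index by about 12x: one million documents take roughly 3 GB at 768 dimensions but only about 250 MB at 64 dimensions.

# Approximate index size for float32 embeddings (4 bytes per value).
def index_size_mb(num_docs, dim, bytes_per_value=4):
    return num_docs * dim * bytes_per_value / (1024 ** 2)

print(index_size_mb(1_000_000, 768))  # ~2930 MB for full 768-dim vectors
print(index_size_mb(1_000_000, 64))   # ~244 MB for 64-dim prefixes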

In conclusion, we trained a Matryoshka-optimized embedding model that maintains strong retrieval performance even when we truncate vectors to small prefix dimensions such as 64. We verified the effect by comparing baseline and post-training retrieval metrics across multiple truncation sizes and the full embedding. With the saved model and the truncate_dim loading pattern, we have a clean workflow for building smaller, faster vector indexes while retaining the option to rerank with full-dimensional embeddings.
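As a final illustration of that retrieve-then-rerank idea, here is a minimal sketch (our own example, assuming the eval_corpus dictionary and the saved mrl-msmarco-demo directory from above): retrieve candidates with 64-dimension prefixes, then rerank the shortlist with the full-dimensional embeddings.

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

full_model = SentenceTransformer("mrl-msmarco-demo")  # full-dimensional model saved above

docs = list(eval_corpus.values())
doc_emb = full_model.encode(docs, convert_to_tensor=True, normalize_embeddings=True)

query = "what is the liberal arts?"
q_emb = full_model.encode([query], convert_to_tensor=True, normalize_embeddings=True)

# Stage 1: cheap candidate retrieval using only the first 64 dimensions.
q64 = torch.nn.functional.normalize(q_emb[:, :64], p=2, dim=1)
d64 = torch.nn.functional.normalize(doc_emb[:, :64], p=2, dim=1)
candidates = torch.topk(cos_sim(q64, d64)[0], k=min(20, len(docs))).indices

# Stage 2: rerank only the shortlisted candidates with the full embeddings.
rerank_scores = cos_sim(q_emb, doc_emb[candidates])[0]
best = candidates[rerank_scores.argmax()].item()
print(docs[best][:200])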

