
    Learning to Rank with Bandit: Multimodal Search Guide

    Learn how to upgrade static rerankers into adaptive, personalized retrieval systems using Thompson Sampling. This guide breaks down how online bandit algorithms improve multimodal search by learning from every click.


    TL;DR: Static rerankers give you good results. Learned rerankers give you results that get better every time someone clicks. We'll show you how to build one using Thompson Sampling—the same technique powering TikTok's "For You" page.


    The Problem with Static Reranking

    You've built a solid multimodal retrieval pipeline. Your system combines:

    • CLIP embeddings for visual similarity
    • OCR for text matching
    • Audio transcriptions
    • Object detection (SAM, YOLO)
    • Metadata signals

    And you're fusing them with something like:

    final_score = 0.4 * clip_score + 0.3 * ocr_score + 0.2 * audio_score + 0.1 * metadata_score
    

    The question: How did you pick those weights?

    More importantly: Why are they the same for everyone?

    Your e-commerce customers care about visual similarity. Your edtech customers care about OCR and transcripts. Your ad-tech customers care about motion and audio. But your static weights treat them all the same.

    This is the problem Thompson Sampling solves.


    What is Online Bandit Learning?

    Think of it like A/B testing, but continuous and automatic.

    Traditional A/B test:

    1. Split traffic 50/50
    2. Wait 2 weeks
    3. Pick the winner
    4. Deploy

    Online bandit:

    1. Try different feature weights
    2. Learn from every click
    3. Shift weights toward what works
    4. Never stop learning

    [Diagram: Static pipeline (traditional): Query → Retrieve → Rerank → Results, with no feedback loop. Learned pipeline (adaptive): Query → Retrieve → Rerank → Results, with each click feeding a Thompson Sampling learning step back into the reranker.]

    The "bandit" learns which features predict clicks for different contexts, users, or domains.


    Why Feature-Level Learning?

    Wrong approach: Learn a quality score for each document

    • Creates millions of parameters (one per document)
    • New documents have no history (cold start)
    • Doesn't generalize

    Right approach: Learn importance weights for each feature

    • 5-10 parameters total (one per feature)
    • New documents benefit immediately
    • Learns modality preferences, not document preferences

    [Diagram: Document-level learning (wrong): doc_1: α=1, β=0 (1 interaction); doc_2: α=0, β=1 (1 interaction); doc_3 through doc_9999: α=0, β=0 (no data). Millions of parameters, no generalization. Feature-level learning (right): clip: α=12, β=3 (weight 0.80); ocr: α=2, β=10 (weight 0.17); audio: α=8, β=4 (weight 0.67); metadata: α=6, β=6 (weight 0.50). Four parameters that work for all documents.]
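
    To make the contrast concrete, here is a small sketch (the α/β values come from the diagram above; the new document's feature values are invented for illustration) showing how four feature-level parameters score a document the system has never seen before:

    # Learned feature-level parameters (values from the diagram above)
    feature_params = {
        "clip":     {"alpha": 12, "beta": 3},
        "ocr":      {"alpha": 2,  "beta": 10},
        "audio":    {"alpha": 8,  "beta": 4},
        "metadata": {"alpha": 6,  "beta": 6},
    }
    
    # A brand-new document with no interaction history of its own
    new_doc = {"clip": 0.9, "ocr": 0.1, "audio": 0.4, "metadata": 0.7}
    
    # Mean of Beta(α, β) is α / (α + β)
    mean_weights = {
        f: p["alpha"] / (p["alpha"] + p["beta"])
        for f, p in feature_params.items()
    }
    score = sum(mean_weights[f] * v for f, v in new_doc.items())
    # clip carries weight ~0.80, so the visually similar new doc ranks high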

    Thompson Sampling in 3 Steps

    Step 1: Maintain a Beta Distribution per Feature

    For each feature (clip, ocr, audio, etc.), track:

    • α (alpha): Times this feature was strong in clicked results
    • β (beta): Times this feature was weak in clicked results
    # Initial state: no knowledge
    feature_params = {
        "clip": {"alpha": 1, "beta": 1},
        "ocr": {"alpha": 1, "beta": 1},
        "audio": {"alpha": 1, "beta": 1},
        "metadata": {"alpha": 1, "beta": 1}
    }
    

    Step 2: Sample Weights for Ranking

    When a query arrives, sample a weight from each distribution:

    import numpy as np
    
    def sample_weights(feature_params):
        """Sample feature weights using Thompson Sampling."""
        weights = {}
        for feature, params in feature_params.items():
            # Sample from Beta(α, β)
            weights[feature] = np.random.beta(
                params["alpha"], 
                params["beta"]
            )
        return weights
    
    # Example output:
    # {"clip": 0.78, "ocr": 0.15, "audio": 0.62, "metadata": 0.45}
    

    Then rank documents:

    def compute_score(doc_features, weights):
        """Combine features with learned weights."""
        return sum(
            weights[f] * doc_features[f] 
            for f in weights
        )
    

    Step 3: Update from Clicks

    When a user clicks a result, check which features were strong:

    def update_from_click(clicked_doc_features, feature_params, threshold=0.5):
        """Update α/β based on clicked document features."""
        for feature, value in clicked_doc_features.items():
            if value > threshold:
                # Feature was strong → credit it
                feature_params[feature]["alpha"] += 1
            else:
                # Feature was weak → penalize it
                feature_params[feature]["beta"] += 1
    

    That's it. Three simple steps, zero ML infrastructure.
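
    Putting the three steps together, one query/feedback cycle looks roughly like this (it reuses feature_params, sample_weights, compute_score, and update_from_click from the steps above; the candidate documents and the click are stand-ins):

    # One full cycle: sample weights → rank → observe a click → update
    candidates = [
        {"id": "doc_1", "features": {"clip": 0.85, "ocr": 0.23, "audio": 0.67, "metadata": 0.91}},
        {"id": "doc_2", "features": {"clip": 0.45, "ocr": 0.89, "audio": 0.12, "metadata": 0.56}},
    ]
    
    weights = sample_weights(feature_params)          # Step 2: sample from Beta(α, β)
    ranked = sorted(
        candidates,
        key=lambda d: compute_score(d["features"], weights),
        reverse=True,
    )
    
    clicked = ranked[0]                               # pretend the user clicked the top result
    update_from_click(clicked["features"], feature_params)   # Step 3: credit strong features
    
    # feature_params now carries one more observation per feature,
    # so the next query samples slightly better-informed weights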


    Handling Cold Start: Hierarchical Contexts

    What about new users with no click history?

    Solution: Learn at multiple levels simultaneously.

    [Diagram: Choosing a context for a new user's query. If the user has 5+ clicks, use the personal context (e.g. user_123). Otherwise, if demographics are available (segment, device, etc.), use the demographic context (e.g. segment_tech_enthusiast, shared with similar users). Otherwise fall back to the global context.]

    Key insight: Every click updates:

    1. Personal context (if user has history)
    2. Demographic context (benefits similar users)
    3. Global context (benefits everyone)

    New users get instant personalization from their demographic group.
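
    A minimal sketch of the fallback and the multi-level update, assuming hypothetical helpers interaction_count and a per-user segment, plus a storage.update_params like the one in the DIY section below:

    def choose_context(user_id, segment=None, min_interactions=5):
        """Pick the most specific context we can trust for this query."""
        if interaction_count(user_id) >= min_interactions:   # hypothetical helper
            return f"user_{user_id}"
        if segment is not None:
            return f"segment_{segment}"
        return "global"
    
    def update_all_levels(storage, user_id, segment, feature_updates):
        """Every click improves personal, demographic, and global contexts."""
        for context in [f"user_{user_id}", f"segment_{segment}", "global"]:
            storage.update_params(context, feature_updates)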


    A Worked Example: Video Search

    Let's say you're building video search. Initial state:

    clip:  Beta(1, 1)  → no idea if visual matters
    ocr:   Beta(1, 1)  → no idea if text matters
    audio: Beta(1, 1)  → no idea if audio matters
    

    After 20 clicks in e-commerce:

    clip:  Beta(18, 4)  → 0.82 avg weight (visual matters!)
    ocr:   Beta(5, 17)  → 0.23 avg weight (text doesn't)
    audio: Beta(3, 19)  → 0.14 avg weight (audio irrelevant)
    

    After 20 clicks in edtech:

    clip:  Beta(6, 16)  → 0.27 avg weight (visual less important)
    ocr:   Beta(19, 3)  → 0.86 avg weight (transcripts crucial)
    audio: Beta(17, 5)  → 0.77 avg weight (lecture audio key)
    

    Same features. Different domains. Automatic adaptation.
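
    The "avg weight" shown above is just the mean of each Beta distribution, α / (α + β):

    # Mean of Beta(α, β) = α / (α + β)
    print(18 / (18 + 4))   # ≈ 0.82 → clip after 20 e-commerce clicks
    print(19 / (19 + 3))   # ≈ 0.86 → ocr after 20 edtech clicks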


    Architecture Overview

    [Diagram: Learned reranking architecture. Retrieval pipeline: vector search (top 100 candidates) → static rerank (cross-encoder scoring) → learned rerank (Thompson Sampling) → ranked results to the user. User interactions (clicks, views, etc.) feed a learning loop that extracts features and updates α/β in Redis. Storage: Redis holds the feature weights (α/β parameters); ClickHouse holds the interaction log for analytics. The learned rerank stage is where learning happens; the click feedback closes the loop.]

    Key components:

    1. Redis: Stores current α/β for each feature (hot, fast)
    2. ClickHouse: Stores interaction history (cold, analytics)
    3. Learned Rerank Stage: Samples weights, computes scores
    4. Update Webhook: Processes clicks, updates Redis (a minimal sketch follows below)
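
    As a rough, framework-agnostic sketch of the update webhook (the payload shape is an assumption; storage is the BanditStorage class from the DIY section below), the handler boils down to:

    def handle_interaction_event(event, storage, threshold=0.5):
        """Webhook body: credit or penalize the clicked document's features.
    
        `event` is assumed (hypothetically) to look like:
          {"context_id": "user_123", "interaction": "click",
           "doc_features": {"clip": 0.85, "ocr": 0.23, "audio": 0.67, "metadata": 0.91}}
        """
        if event["interaction"] != "click":
            return  # only positive signals update the bandit in this sketch
    
        updates = {}
        for feature, value in event["doc_features"].items():
            if value > threshold:
                updates[feature] = {"alpha": 1}   # strong feature → credit
            else:
                updates[feature] = {"beta": 1}    # weak feature → penalize
    
        storage.update_params(event["context_id"], updates)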

    DIY Implementation Guide

    Prerequisites

    pip install numpy redis
    

    1. Storage Layer

    import redis
    import json
    
    class BanditStorage:
        def __init__(self, redis_url="redis://localhost:6379"):
            self.redis = redis.from_url(redis_url)
        
        def get_params(self, context_id, features):
            """Get α/β for all features in a context."""
            key = f"bandit:{context_id}"
            params = {}
            
            for feature in features:
                alpha = self.redis.hget(key, f"{feature}.alpha")
                beta = self.redis.hget(key, f"{feature}.beta")
                
                # Default to uniform prior
                params[feature] = {
                    "alpha": float(alpha) if alpha else 1.0,
                    "beta": float(beta) if beta else 1.0
                }
            
            return params
        
        def update_params(self, context_id, feature_updates):
            """Update α or β for features."""
            key = f"bandit:{context_id}"
            
            for feature, updates in feature_updates.items():
                if "alpha" in updates:
                    self.redis.hincrbyfloat(key, f"{feature}.alpha", updates["alpha"])
                if "beta" in updates:
                    self.redis.hincrbyfloat(key, f"{feature}.beta", updates["beta"])
    

    2. Thompson Sampling Algorithm

    import numpy as np
    
    class ThompsonSampling:
        @staticmethod
        def sample_weights(feature_params, exploration_bonus=1.0):
            """Sample feature weights from Beta distributions.

            exploration_bonus > 1 shrinks the effective counts, widening each
            Beta distribution so the sampler explores more; 1.0 uses the
            learned parameters as-is.
            """
            weights = {}
            
            for feature, params in feature_params.items():
                # Dividing α and β keeps the mean but increases variance
                alpha = params["alpha"] / exploration_bonus
                beta = params["beta"] / exploration_bonus
                
                # Sample from Beta(α, β)
                weights[feature] = np.random.beta(alpha, beta)
            
            return weights
        
        @staticmethod
        def compute_score(doc_features, weights):
            """Weighted sum of features."""
            return sum(
                weights.get(f, 0.5) * v 
                for f, v in doc_features.items()
            )
    

    3. Learned Reranker

    class LearnedReranker:
        def __init__(self, storage):
            self.storage = storage
            self.sampler = ThompsonSampling()
        
        def rerank(self, documents, context_id):
            """Rerank documents using learned feature weights."""
            
            # Extract feature names from first doc
            features = list(documents[0]["features"].keys())
            
            # Get current parameters
            params = self.storage.get_params(context_id, features)
            
            # Sample weights
            weights = self.sampler.sample_weights(params)
            
            # Score all documents
            for doc in documents:
                doc["learned_score"] = self.sampler.compute_score(
                    doc["features"], 
                    weights
                )
            
            # Sort by learned score
            documents.sort(key=lambda d: d["learned_score"], reverse=True)
            
            return documents
        
        def update_from_click(self, clicked_doc, context_id, threshold=0.5):
            """Update feature parameters from user click."""
            updates = {}
            
            for feature, value in clicked_doc["features"].items():
                if value > threshold:
                    # Strong feature gets credit
                    updates[feature] = {"alpha": 1}
                else:
                    # Weak feature gets penalty
                    updates[feature] = {"beta": 1}
            
            self.storage.update_params(context_id, updates)
    

    4. Usage Example

    # Initialize
    storage = BanditStorage()
    reranker = LearnedReranker(storage)
    
    # Sample documents with multimodal features
    documents = [
        {
            "id": "doc_1",
            "features": {
                "clip": 0.85,
                "ocr": 0.23,
                "audio": 0.67,
                "metadata": 0.91
            }
        },
        {
            "id": "doc_2",
            "features": {
                "clip": 0.45,
                "ocr": 0.89,
                "audio": 0.12,
                "metadata": 0.56
            }
        }
    ]
    
    # Rerank for a specific context (e.g., user_123 or segment_techies)
    context_id = "user_123"
    ranked_docs = reranker.rerank(documents, context_id)
    
    # User clicks the top result
    clicked_doc = ranked_docs[0]
    reranker.update_from_click(clicked_doc, context_id)
    
    # Next query will use updated weights!
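
    To see the adaptation happen, you can simulate a handful of clicks on visually strong documents and watch the clip parameters grow (this assumes a local Redis is running, since BanditStorage writes to it):

    # Simulate 10 clicks on documents where the visual signal was strong
    for _ in range(10):
        ranked = reranker.rerank(documents, context_id)
        visual_doc = max(ranked, key=lambda d: d["features"]["clip"])
        reranker.update_from_click(visual_doc, context_id)
    
    params = storage.get_params(context_id, ["clip", "ocr", "audio", "metadata"])
    print(params["clip"])   # alpha grows with every click where clip > 0.5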
    

    Performance Characteristics

    Metric               | Static Rerank | Learned Rerank
    Latency overhead     | 0 ms          | ~1-2 ms (sampling)
    Storage per context  | 0 bytes       | ~200 bytes (α/β per feature)
    Training required    | None          | None
    Adaptation speed     | Never         | ~5-10 interactions
    Cold start handling  | N/A           | Hierarchical fallback
    Personalization      | No            | Yes (per user/segment)

    When to Use Bandit Learning

    ✅ Good fit:

    • Multimodal retrieval (images, video, audio, text)
    • User behavior varies by segment/domain
    • Need to optimize click-through or engagement
    • Fast iteration required (no retraining)
    • Cold start is important

    ❌ Not a fit:

    • Static, compliance-driven ranking
    • No user feedback loop
    • Need deep collaborative filtering
    • Regulatory constraints on personalization

    Common Pitfalls

    1. Learning per document instead of per feature

    # ❌ DON'T: Creates millions of parameters
    bandit_params["doc_12345"] = {"alpha": 3, "beta": 1}
    
    # ✅ DO: Learn which features matter
    bandit_params["clip"] = {"alpha": 12, "beta": 3}
    

    2. Forgetting cold start

    # ❌ DON'T: Fail for new users
    if user_has_history(user_id):
        context = f"user_{user_id}"
    else:
        raise Exception("No history!")
    
    # ✅ DO: Hierarchical fallback
    if user_has_history(user_id, min_interactions=5):
        context = f"user_{user_id}"
    elif user_has_demographics(user_id):
        context = f"segment_{user_segment}"
    else:
        context = "global"
    

    3. Not updating multiple contexts

    # ❌ DON'T: Only update personal
    update_params(f"user_{user_id}", feature_updates)
    
    # ✅ DO: Update personal + demographic + global
    for context in [personal_ctx, demographic_ctx, "global"]:
        update_params(context, feature_updates)
    

    Advanced: Context Strategies

    Different use cases need different context strategies:

    # E-commerce: User-level personalization
    context = f"user_{user_id}"
    
    # B2B SaaS: Team-level learning
    context = f"team_{team_id}"
    
    # Content platform: Demographic clustering
    context = f"segment_{age_group}_{device_type}"
    
    # Search engine: Query-type clustering
    context = f"query_type_{intent_category}"
    

    Monitoring Your Bandit

    Key metrics to track:

    def get_feature_stats(storage, context_id, features):
        """Get current state of feature learning."""
        params = storage.get_params(context_id, features)
        
        stats = {}
        for feature, p in params.items():
            alpha, beta = p["alpha"], p["beta"]
            
            stats[feature] = {
                "mean_weight": alpha / (alpha + beta),
                "confidence": alpha + beta,  # Higher = more certain
                "preference": "high" if alpha > beta else "low"
            }
        
        return stats
    
    # Example output:
    # {
    #   "clip": {"mean_weight": 0.82, "confidence": 23, "preference": "high"},
    #   "ocr": {"mean_weight": 0.19, "confidence": 21, "preference": "low"}
    # }
    

    Why This Beats Alternatives

    Approach          | Training        | Latency | Cold Start   | Explainability
    Thompson Sampling | None            | ~1 ms   | Hierarchical | High
    XGBoost reranker  | Batch (weekly)  | ~50 ms  | Poor         | Medium
    Two-tower model   | Offline (daily) | ~10 ms  | Poor         | Low
    Neural reranker   | Continuous      | ~20 ms  | Medium       | Low
    LinUCB bandit     | None            | ~2 ms   | Medium       | High
    Thompson Sampling gives you the best trade-off for most production systems.


    Real-World Impact

    At Mixpeek, we've seen customers using learned reranking achieve:

    • 23% improvement in click-through rate (e-commerce visual search)
    • 31% increase in watch time (video content discovery)
    • 40% faster convergence vs. batch retraining (ad creative ranking)
    • Cold start working in <5 interactions (new user onboarding)

    The key insight: The system learns what matters for YOUR domain automatically.


    Try It With Mixpeek

    Building this from scratch is fun for learning, but production systems need:

    • Distributed state management
    • Interaction tracking
    • Context isolation per tenant
    • Monitoring dashboards
    • Fallback strategies

    Mixpeek's retrieval API includes learned reranking as a single stage:

    {
      "pipeline": [
        {
          "stage_name": "retriever",
          "parameters": { ... }
        },
        {
          "stage_name": "rerank",
          "parameters": {
            "inference_name": "bge-reranker"
          }
        },
        {
          "stage_name": "learned_rerank",
          "parameters": {
            "algorithm": "thompson_sampling",
            "feature_fields": ["features.*"],
            "context_features": ["INPUT.user_id"]
          }
        }
      ]
    }
    

    Learn more about Mixpeek's learned reranking →


    Questions? Drop them in the comments or join our Discord.