Explainability-Guided Model Routing

Using XAI to Predict and Minimize Compute Cost

Project Overview

This project explores how explainable AI (XAI) can make deep learning systems more compute-efficient. Instead of running every input through a large, energy-hungry model, the system uses explainability signals to predict how difficult an input is and routes it to the most efficient model that can handle it.

The core idea: the AI analyzes the image, explains what it sees, and decides whether it needs a big brain or a small one. This combines responsible AI, interpretability, and green AI efficiency.

System Workflow

STEP 1: Image Input
User uploads an image. The system prepares it for analysis by resizing to 224x224 pixels and normalizing pixel values.

STEP 2: XAI Feature Extraction
A lightweight probe model (MobileNetV3-Small) analyzes the image and extracts 7 explainability features that indicate complexity: attention patterns, gradient responses, spatial structure, and prediction confidence.

STEP 3: Complexity Prediction
A Gradient Boosting Classifier takes the 7 features and predicts which model tier (TINY, LIGHT, MEDIUM, or HEAVY) is the minimum required for accurate classification.

STEP 4: Dynamic Model Routing
The image is routed to the predicted tier. Simple images go to fast, lightweight models. Complex images go to larger, more accurate models.

STEP 5: Inference and Explanation
The selected model runs inference and outputs the prediction along with visual explanations (GradCAM heatmap) and compute statistics (latency, FLOPs).
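
The five steps can be read as one function that chains pluggable stages. The sketch below is a minimal illustration under assumptions, not the project's actual code: `run_pipeline`, `RoutedResult`, and the stage callables are hypothetical names, and the stand-in stages in the usage lines exist only so the example runs without any ML library.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Sequence


@dataclass
class RoutedResult:
    tier: str              # TINY, LIGHT, MEDIUM, or HEAVY
    prediction: Any        # output of the selected model
    features: List[float]  # the XAI features used for routing


def run_pipeline(
    image: Any,
    preprocess: Callable[[Any], Any],
    extract_features: Callable[[Any], Sequence[float]],
    predict_tier: Callable[[Sequence[float]], str],
    models: Dict[str, Callable[[Any], Any]],
) -> RoutedResult:
    x = preprocess(image)            # Step 1: resize and normalize
    features = extract_features(x)   # Step 2: probe model + XAI features
    tier = predict_tier(features)    # Step 3: complexity prediction
    model = models[tier]             # Step 4: dynamic routing
    prediction = model(x)            # Step 5: inference on the chosen tier
    return RoutedResult(tier, prediction, list(features))


# Stand-in stages (hypothetical) that only demonstrate the control flow.
demo = run_pipeline(
    image="raw-image",
    preprocess=lambda img: img,
    extract_features=lambda x: [0.2] * 7,
    predict_tier=lambda f: "TINY" if f[0] < 0.35 else "HEAVY",
    models={"TINY": lambda x: "cat", "HEAVY": lambda x: "cat"},
)
print(demo.tier, demo.prediction)
```

Passing the stages in as callables keeps the routing logic independent of any particular framework, so the probe model or the tier predictor can be swapped without touching the control flow.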

Why GradCAM?

GradCAM (Gradient-weighted Class Activation Mapping) produces a heatmap showing which regions of the image the model focuses on when making a prediction. We use it for two purposes:

Purpose                Explanation
Complexity Estimation  If attention is scattered across many regions, the image is complex. If focused on one area, it is simple.
Decision Transparency  Users can visually verify that the model is looking at relevant parts of the image, building trust in the system.
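
As a rough illustration of how such a heatmap is formed, the sketch below applies the standard Grad-CAM recipe to plain nested lists: each channel is weighted by the spatial average of its gradients, the weighted channels are summed, and negative values are clipped. The function name `grad_cam` and the toy inputs are assumptions for illustration; a real pipeline would read activations and gradients from the probe model's last convolutional layer.

```python
from typing import List

Map2D = List[List[float]]


def grad_cam(activations: List[Map2D], gradients: List[Map2D]) -> Map2D:
    """Combine K feature maps into one heatmap, Grad-CAM style.

    activations[k][i][j]: value of channel k at position (i, j)
    gradients[k][i][j]:   gradient of the class score w.r.t. that value
    """
    height, width = len(activations[0]), len(activations[0][0])
    # Channel weight = spatial average of that channel's gradients.
    weights = [
        sum(sum(row) for row in grad) / (height * width) for grad in gradients
    ]
    # Weighted sum of channels, then ReLU to keep positive evidence only.
    cam = [
        [
            max(0.0, sum(w * act[i][j] for w, act in zip(weights, activations)))
            for j in range(width)
        ]
        for i in range(height)
    ]
    # Scale to [0, 1] for display; an all-zero map is left unchanged.
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam


# Toy example: channel 0 supports the class, channel 1 argues against it.
toy_activations = [[[1.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 2.0]]]
toy_gradients = [[[1.0, 1.0], [1.0, 1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]
heatmap = grad_cam(toy_activations, toy_gradients)
print(heatmap)
```

In the toy example only the top-left cell stays hot, because the second channel has a negative weight and its contribution is removed by the clipping step.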

XAI Feature Values

We extract 7 features from each image. These features act as a "complexity fingerprint" that tells us how difficult the image is for the model to process.

Feature              What It Measures
Attention Entropy    How scattered or focused the model's attention is across the image.
Saliency Sparsity    What fraction of the image contains important information.
Gradient Magnitude   How strongly the image pixels influence the model's prediction.
Feature Variance     How stable or variable the importance is across different regions.
Spatial Complexity   How many edges, textures, and color variations exist in the image.
Confidence Margin    The gap between the top prediction and second prediction probabilities.
Activation Sparsity  What fraction of neurons in the network stay inactive for this image. Higher sparsity means fewer neurons are needed.
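
Two of these features are simple enough to sketch directly. The source does not give exact formulas, so the definitions below are assumptions: attention entropy is taken as the Shannon entropy of the heatmap normalized to the range 0 to 1, and confidence margin as the difference between the two largest class probabilities. The function names are hypothetical.

```python
import math
from typing import Sequence


def attention_entropy(heatmap: Sequence[Sequence[float]]) -> float:
    """Shannon entropy of a non-negative heatmap, scaled to [0, 1]."""
    values = [v for row in heatmap for v in row]
    total = sum(values)
    if total <= 0 or len(values) < 2:
        return 0.0
    probs = [v / total for v in values if v > 0]
    entropy = sum(-p * math.log(p) for p in probs)
    return entropy / math.log(len(values))  # 1.0 means perfectly uniform


def confidence_margin(probabilities: Sequence[float]) -> float:
    """Gap between the highest and second-highest class probability."""
    top, second = sorted(probabilities, reverse=True)[:2]
    return top - second


focused = [[1.0, 0.0], [0.0, 0.0]]   # all attention on one cell
scattered = [[1.0, 1.0], [1.0, 1.0]]  # attention spread evenly
print(attention_entropy(focused), attention_entropy(scattered))
print(confidence_margin([0.7, 0.2, 0.1]))
```

Under these definitions a fully focused heatmap scores 0 and an evenly spread one scores 1, which matches the intuition that low entropy indicates a simple image.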

Complexity Profile

The complexity profile is a radar chart that visualizes all 7 features at once. It allows quick comparison between the current image and typical profiles for each tier.

Example: Simple vs Complex Image

Simple Image (Single Object)

Attention Entropy: 0.15
Spatial Complexity: 0.79
Activation Sparsity: 0.65

Complex Image (Crowded Scene)

Attention Entropy: 0.85
Spatial Complexity: 0.98
Activation Sparsity: 0.10

In this example, the simple image shows low attention entropy (0.15, focused attention), lower spatial complexity (0.79 versus 0.98), and high activation sparsity (0.65, few neurons needed). The complex image shows the opposite pattern on all three features.

SHAP Analysis

SHAP (SHapley Additive exPlanations) quantifies how much each feature contributed to the routing decision. It answers the question: "Why did this image get routed to this tier?"

Contribution     Meaning
Positive (Red)   This feature pushed the decision toward a more complex tier (MEDIUM/HEAVY).
Negative (Blue)  This feature pushed the decision toward a simpler tier (TINY/LIGHT).

For example, if Attention Entropy contributes +0.35, it means the scattered attention pattern strongly suggested this image needs a heavier model. SHAP makes every routing decision interpretable.
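
To make the idea of a per-feature contribution concrete, the sketch below computes exact Shapley values by enumerating every feature subset. It is not the `shap` library and not the project's explainer: `shapley_values`, the linear `toy_score`, its weights, and the baseline are all hypothetical, chosen so the result is easy to check by hand.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List, Sequence


def shapley_values(
    score: Callable[[Sequence[float]], float],
    x: Sequence[float],
    baseline: Sequence[float],
) -> List[float]:
    """Exact Shapley values by enumerating every feature subset.

    Features outside a subset are replaced by their baseline value.
    Cost grows as 2^n, which is still manageable for 7 features.
    """
    n = len(x)

    def value(subset: frozenset) -> float:
        mixed = [x[i] if i in subset else baseline[i] for i in range(n)]
        return score(mixed)

    contributions = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        phi = 0.0
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for combo in combinations(others, size):
                s = frozenset(combo)
                phi += weight * (value(s | {i}) - value(s))
        contributions.append(phi)
    return contributions


# Hypothetical "complexity score" over three features:
# attention entropy, activation sparsity, confidence margin.
def toy_score(f: Sequence[float]) -> float:
    return 0.6 * f[0] - 0.3 * f[1] - 0.1 * f[2]


image_features = [0.85, 0.10, 0.20]
baseline_features = [0.50, 0.40, 0.50]
phi = shapley_values(toy_score, image_features, baseline_features)
print([round(p, 3) for p in phi])
```

The contributions always sum to the difference between the score for the image and the score for the baseline, which is the additivity property that lets each routing decision be split into per-feature pushes toward heavier or lighter tiers.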

Models Used

TINY: MobileNetV3-Small
2.5M params, 60 MFLOPs, 8ms latency
Smallest and fastest model, designed for mobile devices. Best for simple images with single objects and clean backgrounds.

LIGHT: MobileNetV3-Large
5.4M params, 220 MFLOPs, 15ms latency
Larger mobile-optimized model with more capacity. Handles images with a few objects and moderate background complexity.

MEDIUM: EfficientNet-B0
5.3M params, 400 MFLOPs, 25ms latency
Uses compound scaling for optimal depth-width-resolution balance. Suitable for moderately complex scenes with multiple interacting objects.

HEAVY: ResNet-50
25.6M params, 4100 MFLOPs, 80ms latency
Deep residual network with 50 layers. Most accurate and robust, used for crowded scenes, fine-grained details, and ambiguous images.
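
The tier list lends itself to a small lookup table. The sketch below encodes the figures quoted above and derives how much cheaper each tier is than HEAVY; `TierSpec`, `TIERS`, and `savings_vs_heavy` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TierSpec:
    name: str
    model: str
    params_millions: float
    mflops: float
    latency_ms: float


# Ordered from cheapest to most expensive, using the figures listed above.
TIERS = [
    TierSpec("TINY", "MobileNetV3-Small", 2.5, 60, 8),
    TierSpec("LIGHT", "MobileNetV3-Large", 5.4, 220, 15),
    TierSpec("MEDIUM", "EfficientNet-B0", 5.3, 400, 25),
    TierSpec("HEAVY", "ResNet-50", 25.6, 4100, 80),
]
TIER_BY_NAME = {tier.name: tier for tier in TIERS}


def savings_vs_heavy(tier_name: str) -> Tuple[float, float]:
    """Return (latency ratio, FLOPs ratio) of HEAVY relative to a tier."""
    tier, heavy = TIER_BY_NAME[tier_name], TIER_BY_NAME["HEAVY"]
    return heavy.latency_ms / tier.latency_ms, heavy.mflops / tier.mflops


for tier in TIERS:
    speedup, compute = savings_vs_heavy(tier.name)
    print(f"{tier.name:<6} {speedup:4.1f}x faster, {compute:4.1f}x less compute than HEAVY")
```

Note that MEDIUM has slightly fewer parameters than LIGHT but nearly twice the FLOPs, so the tiers are ordered by compute cost rather than by parameter count.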

How Tier Assignment Works

The complexity predictor uses feature boundaries learned from training data. The primary routing signal is Attention Entropy (60% importance).

Tier    Attention Entropy  Activation Sparsity  Typical Content
TINY    0.10 - 0.35        0.45 - 0.70          Single object, plain background
LIGHT   0.30 - 0.55        0.25 - 0.55          Few objects, some context
MEDIUM  0.50 - 0.75        0.05 - 0.30          Multiple objects, busy scene
HEAVY   0.70 - 0.95        0.02 - 0.18          Crowded, complex, ambiguous
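
The ranges in this table overlap at the edges, so they describe typical values per tier rather than hard cut-offs. The sketch below is only a lookup over the table, not the trained Gradient Boosting Classifier: it lists every tier whose ranges contain a given pair of values, which shows where the table alone is ambiguous. `TIER_RANGES` and `candidate_tiers` are hypothetical names.

```python
from typing import Dict, List, Tuple

Range = Tuple[float, float]

# Typical ranges copied from the table above, cheapest tier first.
TIER_RANGES: Dict[str, Dict[str, Range]] = {
    "TINY": {"attention_entropy": (0.10, 0.35), "activation_sparsity": (0.45, 0.70)},
    "LIGHT": {"attention_entropy": (0.30, 0.55), "activation_sparsity": (0.25, 0.55)},
    "MEDIUM": {"attention_entropy": (0.50, 0.75), "activation_sparsity": (0.05, 0.30)},
    "HEAVY": {"attention_entropy": (0.70, 0.95), "activation_sparsity": (0.02, 0.18)},
}


def candidate_tiers(attention_entropy: float, activation_sparsity: float) -> List[str]:
    """Tiers whose typical ranges contain both values, cheapest first."""
    observed = {
        "attention_entropy": attention_entropy,
        "activation_sparsity": activation_sparsity,
    }
    matches = []
    for tier, ranges in TIER_RANGES.items():
        if all(low <= observed[name] <= high for name, (low, high) in ranges.items()):
            matches.append(tier)
    return matches


print(candidate_tiers(0.15, 0.65))  # the simple-image example values
print(candidate_tiers(0.85, 0.10))  # the complex-image example values
print(candidate_tiers(0.32, 0.50))  # a borderline case inside two ranges
```

For borderline inputs that match more than one tier, or none, the learned predictor and its remaining features are what settle the decision.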

Understanding the Routing Decision

When the system processes an image, it outputs a routing decision with four key metrics:

Tier: LIGHT
Confidence: 0.6759
Latency: 15ms
FLOPs: 220M

Metric      What It Means
Tier        The selected model tier (TINY, LIGHT, MEDIUM, or HEAVY). This is the smallest model predicted to handle the image correctly.
Confidence  How certain the predictor is about this routing decision (0.0 to 1.0). A confidence of 0.6759 means 67.59% certainty that LIGHT is the correct tier.
Latency     Expected inference time in milliseconds. Lower is faster. LIGHT at 15ms is 5.3x faster than HEAVY at 80ms.
FLOPs       Floating point operations required (in millions). Measures computational cost. LIGHT at 220M uses 18.6x less compute than HEAVY at 4100M.

Results Summary

Metric            Baseline (Always HEAVY)  XAI Routing  Improvement
Routing Accuracy  -                        86.7%        -
Average Latency   80ms                     21.7ms       73% lower
Average FLOPs     4100M                    517M         87% less compute
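
The improvement column follows directly from the baseline and routed averages. The short check below recomputes it from the figures in the table; `percent_reduction` is a hypothetical helper name.

```python
def percent_reduction(baseline: float, routed: float) -> float:
    """Relative saving of the routed system versus the baseline, in percent."""
    return 100.0 * (baseline - routed) / baseline


latency_saving = percent_reduction(80.0, 21.7)   # average latency in ms
flops_saving = percent_reduction(4100.0, 517.0)  # average FLOPs in millions

print(f"Latency: {latency_saving:.1f}% lower, {80.0 / 21.7:.1f}x speedup")
print(f"FLOPs:   {flops_saving:.1f}% lower, {4100.0 / 517.0:.1f}x less compute")
```

Expressed as ratios, the same figures correspond to roughly a 3.7x speedup and about 7.9x less compute on average compared with always using HEAVY.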