Using XAI to Predict and Minimize Compute Cost
This project explores how explainable AI (XAI) can make deep learning systems more compute-efficient. Instead of running every input through a large, energy-hungry model, the system uses explainability signals to predict how difficult an input is and routes it to the most efficient model that can handle it.
The core idea: the AI analyzes the image, explains what it sees, and decides whether it needs a big brain or a small one. This combines responsible AI, interpretability, and green AI efficiency.
Grad-CAM (Gradient-weighted Class Activation Mapping) produces a heatmap showing which regions of the image the model focuses on when making a prediction. We use it for two purposes:
| Purpose | Explanation |
|---|---|
| Complexity Estimation | Attention scattered across many regions indicates a complex image; attention focused on one area indicates a simple one. |
| Decision Transparency | Users can visually verify that the model is looking at relevant parts of the image, building trust in the system. |
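The standard Grad-CAM computation fits in a few lines. The sketch below assumes the feature maps of one convolutional layer and the gradients of the predicted class score with respect to them have already been captured (for example with framework hooks). The function name `grad_cam` is illustrative, not this project's code.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM heatmap for one conv layer (hypothetical helper).

    activations: feature maps A^k, shape (K, H, W).
    gradients:   d(class score) / dA^k, same shape (K, H, W).
    Returns an (H, W) heatmap scaled to [0, 1].
    """
    if activations.shape != gradients.shape:
        raise ValueError("activations and gradients must have the same shape")
    # Channel weights alpha_k: global-average-pool the gradients over H and W.
    weights = gradients.mean(axis=(1, 2))  # shape (K,)
    # Weighted sum of feature maps; ReLU keeps only evidence FOR the class.
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # (H, W)
    peak = cam.max()
    # Normalize so heatmaps from different images are comparable.
    return cam / peak if peak > 0 else cam


# Placeholder inputs standing in for real layer outputs and gradients.
rng = np.random.default_rng(0)
heatmap = grad_cam(rng.random((8, 7, 7)), rng.standard_normal((8, 7, 7)))
print(heatmap.shape)  # → (7, 7)
```

The heatmap has the spatial resolution of the chosen layer, so it is usually upsampled to the input size before being overlaid on the image.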
We extract 7 features from each image. These features act as a "complexity fingerprint" that tells us how difficult the image is for the model to process.
| Feature | What It Measures |
|---|---|
| Attention Entropy | How scattered or focused the model's attention is across the image. |
| Saliency Sparsity | What fraction of the image contains important information. |
| Gradient Magnitude | How strongly the image pixels influence the model's prediction. |
| Feature Variance | How stable or variable the importance is across different regions. |
| Spatial Complexity | How many edges, textures, and color variations exist in the image. |
| Confidence Margin | The gap between the top prediction and second prediction probabilities. |
| Activation Sparsity | What fraction of neurons in the network stay inactive for this image (higher means fewer neurons are needed). |
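The table names the features but not their formulas. The sketch below gives one plausible NumPy definition for each, assuming a non-negative Grad-CAM heatmap, a 2-D grayscale image, the input-pixel gradient, the classifier logits, and the post-ReLU activations of some layer. Every formula here (the 0.5 saliency threshold, gradient-magnitude edges as a proxy for spatial complexity, and so on) is a hypothetical choice, not the project's exact definition.

```python
import numpy as np

# Hypothetical feature definitions: one plausible formula per table row.


def attention_entropy(heatmap: np.ndarray) -> float:
    """Shannon entropy of the heatmap, normalized to [0, 1] (0 = focused, 1 = uniform)."""
    p = heatmap.ravel().astype(float)
    total = p.sum()
    if total <= 0 or p.size < 2:
        return 0.0  # no attention mass (or one pixel): treat as fully focused
    p = p / total
    nz = p[p > 0]  # 0 * log(0) is taken as 0
    return float(-(nz * np.log(nz)).sum() / np.log(p.size))


def saliency_fraction(heatmap: np.ndarray, threshold: float = 0.5) -> float:
    """Share of pixels whose importance is at least `threshold` times the peak."""
    peak = heatmap.max()
    return float((heatmap >= threshold * peak).mean()) if peak > 0 else 0.0


def gradient_magnitude(input_grad: np.ndarray) -> float:
    """Mean absolute gradient of the class score w.r.t. the input pixels."""
    return float(np.abs(input_grad).mean())


def feature_variance(heatmap: np.ndarray) -> float:
    """Variance of the importance values across image regions."""
    return float(np.var(heatmap))


def spatial_complexity(gray: np.ndarray) -> float:
    """Mean intensity-gradient magnitude of a 2-D grayscale image (edge proxy)."""
    gy, gx = np.gradient(gray.astype(float))
    return float(np.hypot(gx, gy).mean())


def confidence_margin(logits: np.ndarray) -> float:
    """Top-1 minus top-2 softmax probability."""
    z = np.exp(logits - logits.max())  # subtract max for numerical stability
    probs = np.sort(z / z.sum())
    return float(probs[-1] - probs[-2])


def activation_sparsity(activations: np.ndarray) -> float:
    """Fraction of units that do not fire (activation <= 0)."""
    return float((activations <= 0).mean())


def extract_features(gray, heatmap, input_grad, logits, activations) -> dict:
    """Bundle the seven features into one 'complexity fingerprint'."""
    return {
        "attention_entropy": attention_entropy(heatmap),
        "saliency_sparsity": saliency_fraction(heatmap),
        "gradient_magnitude": gradient_magnitude(input_grad),
        "feature_variance": feature_variance(heatmap),
        "spatial_complexity": spatial_complexity(gray),
        "confidence_margin": confidence_margin(logits),
        "activation_sparsity": activation_sparsity(activations),
    }
```

Normalizing the entropy by `log(N)` keeps it in [0, 1] regardless of heatmap size, which matches the 0.10 to 0.95 ranges used for routing.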
The complexity profile is a radar chart that visualizes all 7 features at once. It allows quick comparison between the current image and typical profiles for each tier.
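Simple images show low attention entropy (focused), moderate spatial complexity, and high activation sparsity (few neurons needed). Complex images show the opposite pattern: high attention entropy (scattered), high spatial complexity, and low activation sparsity (many neurons active).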
SHAP (SHapley Additive exPlanations) quantifies how much each feature contributed to the routing decision. It answers the question: "Why did this image get routed to this tier?"
| Contribution | Meaning |
|---|---|
| Positive (Red) | This feature pushed the decision toward a more complex tier (MEDIUM/HEAVY). |
| Negative (Blue) | This feature pushed the decision toward a simpler tier (TINY/LIGHT). |
For example, if Attention Entropy contributes +0.35, the scattered attention pattern strongly suggests that this image needs a heavier model. SHAP makes every routing decision interpretable.
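To make the "contribution" idea concrete, the sketch below computes exact Shapley values by enumerating every feature coalition, which is cheap for 7 features (128 subsets). It is a simplification of what the SHAP library does: "absent" features are replaced by a single reference vector rather than averaged over a background dataset. The scoring function and its weights are hypothetical stand-ins for the real complexity predictor.

```python
from itertools import combinations
from math import factorial
from typing import Callable, Sequence


def shapley_values(f: Callable[[list[float]], float],
                   x: Sequence[float], baseline: Sequence[float]) -> list[float]:
    """Exact Shapley values of f at x; absent features take baseline values."""
    n = len(x)

    def value(subset: frozenset[int]) -> float:
        # Features in `subset` use the image's values, the rest the baseline.
        return f([x[i] if i in subset else baseline[i] for i in range(n)])

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions S without i.
            w = factorial(size) * factorial(n - size - 1) / factorial(n)
            for combo in combinations(others, size):
                s = frozenset(combo)
                phi[i] += w * (value(s | {i}) - value(s))
    return phi


# Hypothetical linear score over (attention_entropy, activation_sparsity);
# higher scores mean a heavier tier.
def complexity_score(v: list[float]) -> float:
    return 2.0 * v[0] - 1.0 * v[1]


image = [0.80, 0.10]      # scattered attention, few inactive neurons
reference = [0.45, 0.40]  # hypothetical "average image" reference
print([round(p, 2) for p in shapley_values(complexity_score, image, reference)])
# → [0.7, 0.3]
```

Both values are positive (red): high entropy and low activation sparsity each push toward a heavier tier. The contributions always sum to `f(image) - f(reference)`, which is what makes the explanation additive.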
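The complexity predictor uses feature boundaries learned from training data. Its primary routing signal is Attention Entropy, which accounts for 60% of feature importance. The ranges below are typical values for each tier; adjacent ranges overlap, so no single feature determines the tier on its own.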
| Tier | Attention Entropy | Activation Sparsity | Typical Content |
|---|---|---|---|
| TINY | 0.10 - 0.35 | 0.45 - 0.70 | Single object, plain background |
| LIGHT | 0.30 - 0.55 | 0.25 - 0.55 | Few objects, some context |
| MEDIUM | 0.50 - 0.75 | 0.05 - 0.30 | Multiple objects, busy scene |
| HEAVY | 0.70 - 0.95 | 0.02 - 0.18 | Crowded, complex, ambiguous |
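As a rough illustration of how these ranges could drive routing, here is a hand-written rule: attention entropy decides the tier, and only where two entropy ranges overlap does activation sparsity break the tie (higher sparsity means the simpler tier). The cut points are hypothetical, taken as midpoints of the overlaps in the table above; the project's actual predictor is learned and weighs all seven features.

```python
from bisect import bisect_right

TIERS = ["TINY", "LIGHT", "MEDIUM", "HEAVY"]

# Hypothetical entropy cut points: midpoints of the overlapping table ranges.
ENTROPY_CUTS = [0.325, 0.525, 0.725]

# (index of the simpler tier, entropy overlap low, high, activation-sparsity cut)
# Sparsity cuts are midpoints of the overlapping sparsity ranges in the table.
OVERLAPS = [
    (0, 0.30, 0.35, 0.50),   # TINY vs LIGHT
    (1, 0.50, 0.55, 0.275),  # LIGHT vs MEDIUM
    (2, 0.70, 0.75, 0.115),  # MEDIUM vs HEAVY
]


def route(entropy: float, activation_sparsity: float) -> str:
    """Rule-based stand-in for the learned complexity predictor."""
    for simpler, lo, hi, sparsity_cut in OVERLAPS:
        if lo <= entropy <= hi:
            # Entropy is ambiguous here; more inactive neurons -> simpler tier.
            if activation_sparsity >= sparsity_cut:
                return TIERS[simpler]
            return TIERS[simpler + 1]
    # Outside the overlaps, entropy alone picks the tier.
    return TIERS[bisect_right(ENTROPY_CUTS, entropy)]


print(route(0.20, 0.60), route(0.32, 0.30), route(0.90, 0.05))
# → TINY LIGHT HEAVY
```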
When the system processes an image, it outputs a routing decision with four key metrics:
| Metric | What It Means |
|---|---|
| Tier | The selected model tier (TINY, LIGHT, MEDIUM, or HEAVY). This is the smallest model predicted to handle the image correctly. |
| Confidence | How certain the predictor is about this routing decision (0.0 to 1.0). For example, a confidence of 0.6759 for a LIGHT decision means the predictor is 67.59% certain that LIGHT is the correct tier. |
| Latency | Expected inference time in milliseconds (lower is better). For example, LIGHT at 15ms is about 5.3x faster than HEAVY at 80ms. |
| FLOPs | Number of floating point operations required, in millions (M); a measure of computational cost. For example, LIGHT at 220M uses about 18.6x less compute than HEAVY at 4100M. |
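The four output metrics map naturally onto a small record. The sketch below is a hypothetical data structure (the class and constant names are illustrative) that validates the tier and confidence and reproduces the speedup arithmetic quoted in the table, using the HEAVY costs given there.

```python
from dataclasses import dataclass

TIERS = ("TINY", "LIGHT", "MEDIUM", "HEAVY")
# HEAVY-tier costs quoted above, used as the comparison point.
HEAVY_LATENCY_MS = 80.0
HEAVY_MFLOPS = 4100.0


@dataclass(frozen=True)
class RoutingDecision:
    tier: str          # selected tier, one of TIERS
    confidence: float  # predictor's certainty in the tier, 0.0 to 1.0
    latency_ms: float  # expected inference time in milliseconds
    mflops: float      # compute cost in millions of floating point operations

    def __post_init__(self) -> None:
        if self.tier not in TIERS:
            raise ValueError(f"unknown tier: {self.tier}")
        if not 0.0 <= self.confidence <= 1.0:
            raise ValueError("confidence must be in [0, 1]")

    def speedup_vs_heavy(self) -> float:
        """How many times faster than running the HEAVY model."""
        return HEAVY_LATENCY_MS / self.latency_ms

    def compute_saving_vs_heavy(self) -> float:
        """How many times fewer FLOPs than the HEAVY model."""
        return HEAVY_MFLOPS / self.mflops


decision = RoutingDecision(tier="LIGHT", confidence=0.6759, latency_ms=15.0, mflops=220.0)
print(f"{decision.speedup_vs_heavy():.1f}x faster, "
      f"{decision.compute_saving_vs_heavy():.1f}x less compute")
# → 5.3x faster, 18.6x less compute
```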

Compared with a baseline that always runs the HEAVY model, XAI routing achieves the following:

| Metric | Baseline (Always HEAVY) | XAI Routing | Improvement |
|---|---|---|---|
| Routing Accuracy | - | 86.7% | - |
| Average Latency | 80ms | 21.7ms | 73% lower (about 3.7x faster) |
| Average FLOPs | 4100M | 517M | 87% less compute |
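The improvement column follows directly from the averages. This short check (helper names are illustrative) recomputes both the relative reduction and the equivalent speedup factor from the table's numbers; it does not re-derive the averages themselves, which would require the per-image results.

```python
def percent_reduction(baseline: float, routed: float) -> float:
    """Relative saving versus the baseline, in percent."""
    return 100.0 * (baseline - routed) / baseline


def speedup(baseline: float, routed: float) -> float:
    """How many times cheaper the routed system is than the baseline."""
    return baseline / routed


# Averages from the results table: always-HEAVY baseline vs. XAI routing.
for name, base, routed in [("latency (ms)", 80.0, 21.7), ("FLOPs (M)", 4100.0, 517.0)]:
    print(f"{name}: {percent_reduction(base, routed):.0f}% lower, "
          f"{speedup(base, routed):.1f}x")
# → latency (ms): 73% lower, 3.7x
# → FLOPs (M): 87% lower, 7.9x
```

A 73% cut in latency corresponds to roughly a 3.7x speedup, not a 1.73x one, which is why the reduction and the speedup factor are worth reporting separately.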
Explainability-Guided Model Routing | Generated using Claude