🚀 Model Predictive Task Sampling: Enhancing Robustness and Efficiency in Adaptive Learning
🧠 Introduction
Modern foundation models, such as large language models and generalist robots, have transformed adaptive learning across diverse tasks. However, ensuring robust generalization remains challenging, especially under task distribution shifts.
A key issue:
- Uniform task sampling often underrepresents rare but critical scenarios, leading to catastrophic failures.
- CVaR optimization improves robustness, yet its task evaluations are expensive in computation or annotation.

In short, prior solutions that prioritize challenging tasks (e.g., via CVaR sampling) pay for robustness with prohibitively high computational costs.
✨ To address this, Model Predictive Task Sampling (MPTS) proposes a lightweight predictive strategy that learns to model task risks and actively guides task selection, enabling robust and efficient adaptation without massive resource overhead.
🎯 Motivation
Real-world demands for robust learning are rising:
- Autonomous driving must handle rare traffic accidents, not just typical road conditions.
- Embodied robots must generalize to diverse, sometimes unstructured environments.
- Large foundation models require resilience against rare distributional shifts.
Existing uniform sampling and adversarial task optimization approaches either:
- Miss rare cases, or
- Require full evaluations over all tasks per iteration → Expensive!
🔎 Key Insight of MPTS:
Learn a generative model to predict task risks directly from historical optimization results, thus amortizing task evaluations and guiding robust adaptation.
🛠️ Methodology: How MPTS Works
MPTS consists of three interacting modules:
1. 🔮 Risk Learner (Generative Risk Model)
- Learns to predict task-specific adaptation risk.
- Modeled via a probabilistic latent variable model (e.g., a VAE).
- Trained with streaming variational inference using past task batches.
Formally, the risk learner approximates: \(p(\ell \mid \tau, H_{1:t}; \theta_t)\), where \(H_{1:t}\) records past adaptation results.
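Below is a minimal PyTorch sketch of a risk learner. Note the assumptions: the paper's model is a latent variable model (e.g., a VAE) trained with streaming variational inference, whereas this stand-in uses a simpler heteroscedastic Gaussian head; names such as `RiskLearner` and `task_dim` are hypothetical.

```python
# Illustrative risk learner: predicts a Gaussian over the adaptation risk
# ell given a task parameterization tau. The paper trains a latent-variable
# model with streaming variational inference; this heteroscedastic head is
# a simplified stand-in for exposition only.
import torch
import torch.nn as nn

class RiskLearner(nn.Module):
    def __init__(self, task_dim: int, hidden: int = 64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(task_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.mean_head = nn.Linear(hidden, 1)     # predictive E[ell | tau]
        self.logvar_head = nn.Linear(hidden, 1)   # predictive log Var[ell | tau]

    def forward(self, tau: torch.Tensor):
        h = self.encoder(tau)
        return self.mean_head(h), self.logvar_head(h)

def streaming_update(model, optimizer, tau, ell):
    """One update on the latest batch of (task, observed risk) pairs from H_t."""
    mean, logvar = model(tau)
    nll = 0.5 * (logvar + (ell - mean) ** 2 / logvar.exp()).mean()
    optimizer.zero_grad()
    nll.backward()
    optimizer.step()
    return nll.item()
```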
2. 🚀 Amortized Evaluation (Efficient Risk Prediction)
Instead of evaluating each task exhaustively, MPTS:
- Samples a pseudo batch of task candidates.
- Predicts their risks using the trained risk learner.
- Scores tasks without expensive rollouts.
Predicted task acquisition score:
\(a(\tau) = \gamma_0 \cdot \mathbb{E}[\ell] + \gamma_1 \cdot \sqrt{\text{Var}[\ell]}\)
where:
- \(\gamma_0\): worst-case focus weight
- \(\gamma_1\): exploration (uncertainty) weight
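In code, the score is a one-liner on top of the risk learner's predictive mean and variance (continuing the hypothetical `RiskLearner` sketch above):

```python
import torch

def acquisition_score(model, tau, gamma0: float = 1.0, gamma1: float = 1.0):
    """a(tau) = gamma0 * E[ell] + gamma1 * sqrt(Var[ell]) under the predictive Gaussian."""
    with torch.no_grad():
        mean, logvar = model(tau)
        return gamma0 * mean + gamma1 * logvar.exp().sqrt()
```

Larger \(\gamma_1\) pushes sampling toward tasks the risk learner is unsure about, while larger \(\gamma_0\) concentrates on tasks predicted to be hard.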
3. 🎯 Active Task Sampling (Guided Optimization)
Select the top-B tasks maximizing \(a(\tau)\) for adaptation (sketched in code after this list), balancing:
- Worst-case risk minimization
- Uncertainty-driven exploration
This ensures that:
- Challenging tasks are prioritized.
- Diverse scenarios are still explored.
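One plausible reading of the selection step, reusing `acquisition_score` from the sketch above (the pseudo-batch size and B below are hypothetical, not the paper's settings):

```python
import torch

def select_top_b(model, candidate_taus: torch.Tensor, B: int,
                 gamma0: float = 1.0, gamma1: float = 1.0) -> torch.Tensor:
    """Score a pseudo batch of candidates and keep the B tasks with the highest a(tau)."""
    scores = acquisition_score(model, candidate_taus, gamma0, gamma1).squeeze(-1)
    top_idx = torch.topk(scores, k=B).indices
    return candidate_taus[top_idx]

# Usage: scoring candidates needs no rollouts; only the B selected tasks
# are actually evaluated and adapted on.
# candidates = sample_pseudo_batch(n=256)              # hypothetical task sampler
# train_batch = select_top_b(risk_learner, candidates, B=16)
```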
🔍 Theoretical Guarantee
Under mild assumptions (e.g., Lipschitz continuity and boundedness of the risk), MPTS enjoys difficulty-ranking stability: predicted task rankings remain valid even after small parameter updates.
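The flavor of the argument can be stated in one line (a sketch under a Lipschitz assumption, not the paper's exact theorem): if every task risk satisfies \(|\ell(\tau;\theta') - \ell(\tau;\theta)| \le L\,\lVert\theta'-\theta\rVert\), then

\[
\ell(\tau_i;\theta) - \ell(\tau_j;\theta) > 2L\,\lVert\theta'-\theta\rVert
\;\Longrightarrow\;
\ell(\tau_i;\theta') > \ell(\tau_j;\theta'),
\]

so any pair of tasks whose risk gap exceeds twice the induced perturbation keeps its relative order after a small parameter update.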
Thus, the risk learner reliably amortizes task evaluation across training iterations.
🔗 Related Work
| Topic | Relationship to MPTS |
|---|---|
| Robust meta-learning (CVaR, DRM) | Shares the worst-case risk focus; improves its efficiency |
| Bayesian active learning | Leverages predictive uncertainty for task sampling |
| Variational meta-learning | Learns latent task embeddings, as in Neural Processes |
| Model predictive control | Plans active task selection based on risk prediction |
MPTS bridges these domains into a unified, efficient framework for task-robust adaptation.
🧪 Experiments: Where MPTS Excels
Benchmarks:
- Few-shot sinusoid regression
- Few-shot image classification (CLIP + MaPLe)
- Meta reinforcement learning (HalfCheetah, Walker2D, Reacher)
- Robotic domain randomization (Ergo-Reacher, Lunar-Lander)
- Prompt tuning of foundation models (ImageNet and OOD datasets)
| Metric | Gains with MPTS |
|---|---|
| Adaptation robustness (CVaR) | ✅ Significant improvement |
| Average accuracy/return | ✅ Higher across tasks |
| Computational efficiency | ✅ 60–80% less runtime vs. DRM |
| Memory footprint | ✅ Marginal overhead |
| Exploration diversity | ✅ Maintained by pseudo-batch sampling |
✅ Result: MPTS outperforms strong baselines such as ERM, DRM, and GDRM across almost all domains.
🏆 Conclusion: Why MPTS Matters
- Efficiency: Reduces the cost of robust learning by avoiding massive environment interactions or expensive forward passes.
- Scalability: Easily deployable to large backbone models and expensive simulators.
- Generalizability: Compatible with zero-shot, few-shot, and fine-tuning paradigms.
MPTS shows that predicting what to learn next, not just learning blindly, is key to building truly adaptive foundation models for the future.
📚 References
```bibtex
@misc{wang2025modelpredictivetasksampling,
  title         = {Model Predictive Task Sampling for Efficient and Robust Adaptation},
  author        = {Qi Cheems Wang and Zehao Xiao and Yixiu Mao and Yun Qu and Jiayi Shen and Yiqin Lv and Xiangyang Ji},
  year          = {2025},
  eprint        = {2501.11039},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2501.11039},
}
```