🚀 Model Predictive Task Sampling: Enhancing Robustness and Efficiency in Adaptive Learning

🧠 Introduction

Modern foundation models, such as large language models and generalist robots, have transformed adaptive learning across diverse tasks. However, ensuring robust generalization remains challenging, especially under task distribution shifts.

Two key issues:

- Uniform task sampling often underrepresents rare but critical scenarios, leading to catastrophic failures.
- CVaR optimization improves robustness, but the task evaluations it requires are expensive in computation or annotation.

While prior solutions prioritize challenging tasks via methods like CVaR sampling, they often suffer from prohibitively high computational costs.

✨ To address this, Model Predictive Task Sampling (MPTS) proposes a lightweight predictive strategy that learns to model task risks and actively guides task selection, yielding robust and efficient adaptation without massive resource overhead.


🎯 Motivation

Real-world demands for robust learning are rising:

Existing uniform sampling and adversarial task optimization approaches either:
- Miss rare cases, or
- Require full evaluations over all tasks per iteration → Expensive!

🔎 Key Insight of MPTS:

Learn a generative model to predict task risks directly from historical optimization results, thus amortizing task evaluations and guiding robust adaptation.


🛠️ Methodology: How MPTS Works

MPTS consists of three interacting modules:

1. 🔮 Risk Learner (Generative Risk Model)

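The risk learner is a generative model that predicts the distribution of a task's adaptation risk. Formally, it approximates \(p(\ell \mid \tau, H_{1:t}; \theta_t)\), where \(\tau\) is a task, \(\ell\) its adaptation risk (loss), \(\theta_t\) the parameters at iteration \(t\), and \(H_{1:t}\) the record of past adaptation results.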
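The source describes the risk learner as a generative model. As a minimal sketch, assuming tasks can be described by real-valued vectors, a Gaussian process regressor fitted on the history \(H_{1:t}\) yields the two quantities MPTS consumes later: a predictive mean and variance of the risk. The Gaussian process, the class name `GPRiskLearner`, its hyperparameters, and the toy sine-shaped risk curve are all illustrative assumptions swapped in for the paper's model, not its implementation.

```python
import numpy as np


class GPRiskLearner:
    """Hypothetical surrogate risk learner: a Gaussian process over task vectors.

    Fit on a history H_{1:t} of (task, observed loss) pairs; predict returns a
    Gaussian predictive mean and variance of the risk for unseen tasks.
    (Swapped in for illustration; not the MPTS generative model.)
    """

    def __init__(self, length_scale=1.0, signal_var=1.0, noise_var=1e-2):
        self.length_scale = length_scale
        self.signal_var = signal_var
        self.noise_var = noise_var

    def _kernel(self, a, b):
        # Squared-exponential (RBF) kernel between two sets of task vectors.
        sq_dist = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        return self.signal_var * np.exp(-0.5 * sq_dist / self.length_scale**2)

    def fit(self, tasks, losses):
        # tasks: (n, d) task descriptors; losses: (n,) risks observed after adaptation.
        self.x = np.asarray(tasks, dtype=float)
        y = np.asarray(losses, dtype=float)
        self.y_mean = y.mean()
        k = self._kernel(self.x, self.x) + self.noise_var * np.eye(len(self.x))
        self.chol = np.linalg.cholesky(k)  # lower-triangular L with K = L L^T
        # alpha = K^{-1} (y - mean), computed as L^{-T} (L^{-1} (y - mean)).
        self.alpha = np.linalg.solve(
            self.chol.T, np.linalg.solve(self.chol, y - self.y_mean)
        )
        return self

    def predict(self, tasks):
        # Predictive mean and variance of the latent risk at each query task.
        xq = np.asarray(tasks, dtype=float)
        k_star = self._kernel(xq, self.x)              # (m, n)
        mean = self.y_mean + k_star @ self.alpha
        v = np.linalg.solve(self.chol, k_star.T)       # (n, m)
        var = self.signal_var - np.sum(v**2, axis=0)
        return mean, np.maximum(var, 1e-12)            # clip tiny negative round-off


# Toy history: 1-D task descriptors with a made-up risk curve.
rng = np.random.default_rng(0)
history_tasks = rng.uniform(-3.0, 3.0, size=(20, 1))
history_losses = np.sin(history_tasks[:, 0]) + 1.0

learner = GPRiskLearner().fit(history_tasks, history_losses)
mean, var = learner.predict(np.array([[0.0], [10.0]]))
# The task inside the history's range gets a confident prediction;
# the task far outside it keeps roughly the prior variance.
```

A Gaussian process is used here only because it gives closed-form uncertainty in a few lines of NumPy. The amortization idea is the same: once fitted, predicting a task's risk costs a kernel evaluation rather than an adaptation run.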


2. 🚀 Amortized Evaluation (Efficient Risk Prediction)

Instead of evaluating each task exhaustively, MPTS:
- Samples a pseudo-batch of task candidates.
- Predicts their risks using the trained risk learner.
- Scores tasks without expensive rollouts.
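The three steps above can be sketched as follows, assuming a toy generative risk model whose samples are log-normal. `toy_generative_risk`, its weights `w` and `s`, and the uniform task box are hypothetical placeholders for a trained risk learner. Because the model is generative, the mean and variance of each candidate's risk are estimated by Monte Carlo sampling of a latent variable, and no adaptation rollout is executed anywhere in the function.

```python
import numpy as np


def toy_generative_risk(tasks, z, w=0.8, s=0.3):
    """Hypothetical stand-in for a trained generative risk learner.

    Maps each task and latent draw to a risk sample l = exp(w * sum(tau) + s * z).
    tasks: (N, d) pseudo-batch; z: (N, S) latent samples -> (N, S) risk samples.
    """
    return np.exp(w * tasks.sum(axis=1, keepdims=True) + s * z)


def amortized_evaluation(n_candidates, task_dim, n_latent, rng):
    # Step 1: sample a pseudo-batch of candidate tasks from a toy task box [-1, 1]^d.
    candidates = rng.uniform(-1.0, 1.0, size=(n_candidates, task_dim))
    # Step 2: predict risks by sampling the generative model; no rollouts are run.
    z = rng.standard_normal((n_candidates, n_latent))
    risk_samples = toy_generative_risk(candidates, z)
    # Step 3: summarize each candidate by Monte Carlo moments used for scoring.
    return candidates, risk_samples.mean(axis=1), risk_samples.var(axis=1)


rng = np.random.default_rng(0)
candidates, mean, var = amortized_evaluation(8, 2, 20_000, rng)

# For this toy log-normal model the exact moments are known, so the estimates can be checked.
mu = 0.8 * candidates.sum(axis=1)
exact_mean = np.exp(mu + 0.3**2 / 2)
exact_var = (np.exp(0.3**2) - 1.0) * np.exp(2.0 * mu + 0.3**2)
```

The cost of scoring the whole pseudo-batch is one vectorized pass through the risk model, which is what makes evaluating many candidates per iteration affordable.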

Predicted task acquisition score:

\(a(\tau) = \gamma_0 \cdot \mathbb{E}[\ell] + \gamma_1 \cdot \sqrt{\text{Var}[\ell]}\)

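where:
- \(\gamma_0\): worst-case focus weight on the predicted mean risk
- \(\gamma_1\): exploration (uncertainty) weight on the predicted standard deviation

and the mean and variance are taken under the risk learner's predictive distribution \(p(\ell \mid \tau, H_{1:t}; \theta_t)\).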
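The score is a weighted sum of the predicted mean and standard deviation of the risk, in the spirit of an upper-confidence-bound rule. A minimal NumPy sketch, with the function name `acquisition_score` and the candidate numbers chosen purely for illustration:

```python
import numpy as np


def acquisition_score(mean, var, gamma0=1.0, gamma1=1.0):
    """a(tau) = gamma0 * E[l] + gamma1 * sqrt(Var[l]), computed per candidate task."""
    mean = np.asarray(mean, dtype=float)
    std = np.sqrt(np.asarray(var, dtype=float))
    return gamma0 * mean + gamma1 * std


# Predicted risk moments for three candidate tasks (illustrative numbers only).
mean = np.array([0.9, 0.5, 0.6])
var = np.array([0.01, 0.36, 0.04])  # standard deviations 0.1, 0.6, 0.2

worst_case_only = acquisition_score(mean, var, gamma0=1.0, gamma1=0.0)
with_exploration = acquisition_score(mean, var, gamma0=1.0, gamma1=1.0)
print(worst_case_only.argmax())   # 0: highest predicted mean risk
print(with_exploration.argmax())  # 1: large uncertainty bonus outweighs its lower mean
```

Setting \(\gamma_1 = 0\) ranks tasks purely by predicted mean risk; a positive \(\gamma_1\) lets a task with uncertain predictions overtake one that is merely known to be hard.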


3. 🎯 Active Task Sampling (Guided Optimization)

Select the top-B tasks (B being the number of tasks per adaptation batch) with the highest \(a(\tau)\) and use them for adaptation, balancing:
- Worst-case risk minimization
- Uncertainty-driven exploration

This ensures that:
- Challenging tasks are prioritized.
- Diverse scenarios are still explored.
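Selecting the batch is then a top-B operation over the pseudo-batch scores. A minimal sketch, where `select_top_b`, the task parameters, and the predicted moments are hypothetical; it shows how raising \(\gamma_1\) pulls an uncertain task into the batch, which is how diversity is preserved:

```python
import numpy as np


def select_top_b(tasks, mean, std, b, gamma0=1.0, gamma1=1.0):
    """Pick the b tasks with the largest a(tau) = gamma0 * mean + gamma1 * std."""
    scores = gamma0 * np.asarray(mean) + gamma1 * np.asarray(std)
    # Stable sort on negated scores: descending order, ties keep input order.
    top = np.argsort(-scores, kind="stable")[:b]
    return tasks[top], top


# Hypothetical pseudo-batch: 1-D task parameters with predicted risk mean/std.
tasks = np.array([[0.1], [2.5], [1.0], [4.0]])
mean = np.array([0.9, 0.5, 0.6, 0.2])
std = np.array([0.1, 0.6, 0.2, 0.1])

_, exploit_idx = select_top_b(tasks, mean, std, b=2, gamma1=0.0)
_, balanced_idx = select_top_b(tasks, mean, std, b=2, gamma1=1.0)
print(exploit_idx)   # [0 2]: only the two highest predicted risks
print(balanced_idx)  # [1 0]: the uncertain task 1 enters the batch
```

In the loop described above, only the B selected tasks go through actual adaptation, and their results extend the history \(H_{1:t}\) used by the risk learner. `np.argsort` performs a full sort; for very large pseudo-batches `np.argpartition` would find the top B without sorting everything.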


🔍 Theoretical Guarantee

Under mild assumptions (e.g., Lipschitz continuity, boundedness),
MPTS enjoys difficulty ranking stability:

Predicted task rankings remain valid even after small parameter updates.

Thus, the risk learner reliably amortizes task evaluation across training iterations.
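One way to see why Lipschitz continuity yields ranking stability is the following simplified argument. It assumes the risk \(\ell(\tau;\theta)\) is \(L\)-Lipschitz in the parameters uniformly over tasks, and it is an illustration of the intuition, not the paper's exact theorem or proof:

\[
\begin{aligned}
&|\ell(\tau;\theta) - \ell(\tau;\theta')| \le L\,\|\theta - \theta'\| \quad \text{for all } \tau, \\
&\ell(\tau_i;\theta) - \ell(\tau_j;\theta) > 2L\,\|\theta - \theta'\| \\
&\quad\Longrightarrow\; \ell(\tau_i;\theta') \ge \ell(\tau_i;\theta) - L\,\|\theta - \theta'\| > \ell(\tau_j;\theta) + L\,\|\theta - \theta'\| \ge \ell(\tau_j;\theta').
\end{aligned}
\]

Any pair of tasks whose risk gap exceeds \(2L\,\|\theta - \theta'\|\) keeps its order after the update, so a difficulty ranking learned at iteration \(t\) stays informative when the parameters move only slightly.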


🔗 Related Work

| Topic | MPTS Relationship |
|---|---|
| Robust Meta-Learning (CVaR, DRM) | Shares worst-case risk focus, improves efficiency |
| Bayesian Active Learning | Leverages predictive uncertainty for task sampling |
| Variational Meta-Learning | Learns latent task embeddings like Neural Processes |
| Model Predictive Control | Plans active task selection based on risk prediction |

MPTS bridges these domains into a unified, efficient framework for task-robust adaptation.


🧪 Experiments: Where MPTS Excels

Benchmarks:
- Few-shot Sinusoid Regression
- Few-shot Image Classification (CLIP + MaPLe)
- Meta Reinforcement Learning (HalfCheetah, Walker2D, Reacher)
- Robotic Domain Randomization (Ergo-Reacher, Lunar-Lander)
- Prompt-tuning Foundation Models (ImageNet and OOD datasets)

| Metric | Gains with MPTS |
|---|---|
| Adaptation Robustness (CVaR) | ✅ Significant improvement |
| Average Accuracy/Return | ✅ Higher across tasks |
| Computational Efficiency | ✅ 60–80% less runtime vs. DRM |
| Memory Footprint | ✅ Marginal overhead |
| Exploration Diversity | ✅ Maintained by pseudo-batch sampling |

Result: MPTS outperforms strong baselines such as ERM, DRM, and GDRM across almost all domains.


🏆 Conclusion: Why MPTS Matters

MPTS shows that predicting what to learn next, not just learning blindly, is key to building truly adaptive foundation models for the future.


📚 References

@misc{wang2025modelpredictivetasksampling,
  title={Model Predictive Task Sampling for Efficient and Robust Adaptation},
  author={Qi Cheems Wang and Zehao Xiao and Yixiu Mao and Yun Qu and Jiayi Shen and Yiqin Lv and Xiangyang Ji},
  year={2025},
  eprint={2501.11039},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2501.11039},
}