mSFT: Addressing Dataset Mixtures Overfitting Heterogeneously in Multi-task SFT
The paper tackles the inefficiency of allocating a homogeneous compute budget in multi-task supervised fine-tuning (SFT), where fast-learning tasks overfit while slow ones remain under-trained. The authors propose mSFT, an iterative algorithm that dynamically excludes overfitting sub-datasets and rolls back to the excluded sub-dataset's optimal checkpoint. mSFT consistently outperforms baselines across 6 models and 10 benchmarks and, at low compute budgets, can improve accuracy while reducing training FLOPs.
The paper presents a compelling solution to a genuine inefficiency in multi-task SFT. The core insight, that sub-datasets overfit heterogeneously, is well validated empirically, and the proposed mSFT algorithm addresses it with a practical iterative roll-back mechanism. As stated in the abstract, "Extensive evaluations demonstrate that mSFT consistently outperforms 4 baselines across 10 benchmarks and 6 base models." The method yields consistent improvements (+1.8% average accuracy) across diverse model families and scales, with particularly strong gains on mathematical reasoning (+3.0%). The simplicity of the approach (a single new hyperparameter, the compute budget $C$) adds to its practical appeal.
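To make the roll-out/roll-back loop concrete, here is a minimal Python sketch under stated assumptions; it is not the authors' implementation. The names `msft`, `train`, `evaluate`, `toy_train`, `toy_eval`, `OPTIMA`, and `max_rounds` are hypothetical. It assumes the compute budget $C$ is measured in epochs, that every active sub-dataset is validated every 1/4 epoch (the granularity mentioned in the reproducibility discussion), and that a sub-dataset "overfits" when its best validation score occurs before the last evaluation point of a roll-out. Checkpoints are kept in memory here, whereas a real run would write them to disk.

```python
import copy
from typing import Callable, Dict, List, Optional, Tuple

State = Dict[int, float]  # hypothetical toy stand-in for model weights


def msft(
    init_state: State,
    n_datasets: int,
    train: Callable[[State, List[int], float], State],
    evaluate: Callable[[State, int], float],
    budget_c: float = 1.0,       # assumed: compute budget C in epochs
    eval_interval: float = 0.25,  # validate every 1/4 epoch
    max_rounds: int = 100,        # hypothetical safety cap, not from the paper
) -> Tuple[State, List[Tuple[Optional[int], int]]]:
    """Iterative roll-out / roll-back loop (sketch, not the authors' code)."""
    state = copy.deepcopy(init_state)
    active = list(range(n_datasets))
    history: List[Tuple[Optional[int], int]] = []
    n_evals = round(budget_c / eval_interval)
    for _ in range(max_rounds):
        if not active:
            break
        # Roll-out: train on the active mixture for up to budget_c epochs,
        # keeping a checkpoint and per-dataset scores at every eval point.
        ckpts = [copy.deepcopy(state)]
        scores = [{i: evaluate(state, i) for i in active}]
        for _ in range(n_evals):
            state = train(state, active, eval_interval)
            ckpts.append(copy.deepcopy(state))
            scores.append({i: evaluate(state, i) for i in active})
        # Peak eval point per sub-dataset (max() keeps the first of any ties).
        peaks = {i: max(range(len(scores)), key=lambda t: scores[t][i]) for i in active}
        overfit = [i for i in active if peaks[i] < n_evals]
        if not overfit:
            # Nothing peaked inside the budget: continue from the
            # end-of-roll-out state without rolling back.
            history.append((None, n_evals))
            continue
        # Exclude the earliest-overfitting sub-dataset, roll back to its peak.
        first = min(overfit, key=lambda i: (peaks[i], i))
        state = ckpts[peaks[first]]
        active.remove(first)
        history.append((first, peaks[first]))
    return state, history


# Hypothetical toy: sub-dataset i's validation score peaks after OPTIMA[i]
# epochs of exposure; training on the mixture advances every active one.
OPTIMA = {0: 0.5, 1: 1.25, 2: 2.0}


def toy_train(state: State, active: List[int], epochs: float) -> State:
    return {i: v + epochs if i in active else v for i, v in state.items()}


def toy_eval(state: State, i: int) -> float:
    return -(state[i] - OPTIMA[i]) ** 2


final, history = msft({0: 0.0, 1: 0.0, 2: 0.0}, 3, toy_train, toy_eval)
print(final)    # {0: 0.5, 1: 1.25, 2: 2.0}
print(history)  # [(0, 2), (1, 3), (2, 3)]
```

In the toy run, each roll-out removes the earliest-peaking sub-dataset and restarts from its best checkpoint, so each task ends at its own optimum instead of sharing one homogeneous budget. The toy has no interference between tasks, so it does not reproduce the parameter drift discussed in the paper.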
The empirical validation is thorough, covering multiple model families (OLMo 2, Qwen2.5, Qwen3) and demonstrating robustness across dataset sizes (9K–27K samples) and task granularities ($N \in \{5,10,15,21\}$). The motivation section provides clear evidence (Figure 2) that sub-datasets peak at different compute levels, confirming that "heterogeneous learning dynamics cause faster-learning tasks to overfit early while slower ones remain under-fitted." The ablation comparing mSFT against a naive "single roll-out" variant (SRO SFT) convincingly shows that the iterative roll-back mechanism is necessary because excluding datasets causes parameter drift; Figure 3 reports "an absolute shift of 0.91 epochs" in optimal compute. The analysis showing that mSFT reaches lower training loss (Figure 9) offers mechanistic insight into why the method works: excluding post-peak datasets reduces gradient conflict.
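The "optimal compute" discussed around Figure 3 can be illustrated by reading each sub-dataset's peak epoch off its validation curve and averaging the absolute change between two runs. This sketch assumes that is roughly how a shift such as 0.91 epochs is measured; the review does not give the exact definition. The function names, sub-dataset names, and curve values are all hypothetical.

```python
from typing import Dict, Sequence


def peak_epoch(scores: Sequence[float], eval_interval: float = 0.25) -> float:
    """Epoch of the best validation score; scores[k] is taken after k * eval_interval epochs."""
    best_k = max(range(len(scores)), key=lambda k: scores[k])
    return best_k * eval_interval


def mean_abs_shift(before: Dict[str, float], after: Dict[str, float]) -> float:
    """Mean absolute change in peak epoch over sub-datasets present in both runs."""
    shared = sorted(before.keys() & after.keys())
    if not shared:
        raise ValueError("no sub-datasets in common")
    return sum(abs(before[k] - after[k]) for k in shared) / len(shared)


# Hypothetical curves: peaks from one full-mixture roll-out versus a rerun
# from the same starting point with one sub-dataset excluded.
single = {"math": peak_epoch([0.30, 0.42, 0.47, 0.45]),
          "code": peak_epoch([0.20, 0.25, 0.31, 0.36, 0.35])}
rerun = {"math": peak_epoch([0.30, 0.44, 0.43]),
         "code": peak_epoch([0.20, 0.29, 0.33, 0.32])}
print(mean_abs_shift(single, rerun))  # 0.25
```

A nonzero shift means peak estimates taken from a single roll-out go stale once the mixture changes, which is the argument for re-rolling out after each exclusion.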
The paper assumes that validation/test sets for individual sub-datasets are available during training, which may not hold in realistic scenarios where only aggregate validation signals exist; the authors do not address how mSFT would function with only end-task metrics or noisy task-level signals. The authors acknowledge the disk-management overhead (an "average storage footprint by approximately $4.44\times$ SFT"), but this assumes disk is not a bottleneck, a questionable assumption for 70B+ models or resource-constrained environments. The method also requires frequent validation on all sub-datasets during roll-out phases, which could be expensive if sub-datasets are large. Algorithm 1 further notes: "In the rare case that no sub-dataset $\mathcal{D}_i$ over-fitted within the compute budget $C$, the algorithm continues without rolling back," suggesting potential instability when convergence rates are similar. Finally, the theoretical justification is limited: there is no formal analysis of convergence properties or optimality guarantees for the iterative exclusion strategy beyond empirical observation.
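As a back-of-envelope check on the disk concern, the sketch below bounds the checkpoint storage of a single roll-out if every evaluation-point checkpoint were retained. It assumes weights-only bf16 checkpoints (2 bytes per parameter), one checkpoint at the start of the roll-out plus one per evaluation, $C$ measured in epochs, and evaluation every 1/4 epoch. The paper's own disk management, reflected in its 4.44× average, presumably keeps fewer checkpoints. The function name is hypothetical.

```python
def rollout_checkpoint_bytes(
    n_params: int,
    budget_c: float = 1.0,       # assumed: compute budget C in epochs
    eval_interval: float = 0.25,  # one checkpoint per validation point
    bytes_per_param: int = 2,     # assumed: weights-only bf16
) -> int:
    """Worst-case disk for one roll-out if every eval-point checkpoint is kept."""
    n_ckpts = round(budget_c / eval_interval) + 1  # start + each evaluation
    return n_params * bytes_per_param * n_ckpts


for name, n in [("7B", 7_000_000_000), ("70B", 70_000_000_000)]:
    print(name, rollout_checkpoint_bytes(n) / 1e9, "GB")
# 7B 70.0 GB
# 70B 700.0 GB
```

If optimizer state is also saved so training can resume from a rolled-back checkpoint, `bytes_per_param` rises; two fp32 Adam moments alone add 8 bytes per parameter.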
The evidence strongly supports the central claim that heterogeneous early stopping improves performance. Comparisons against 4 baselines (standard SFT, Continual SFT, DynamixSFT, IES) are fair and use identical hyperparameters where possible. However, the result against Continual SFT (which loses $2.2\%$ to catastrophic forgetting) is unsurprising given the known limitations of sequential training; comparisons against more sophisticated task-interleaving methods would strengthen the evaluation. The FLOPs analysis (Figure 6) is rigorous, showing that at $C=1$ mSFT can reduce compute by 120.3 PFLOPs while improving performance by +3.4%. The claim that performance gains are "not from disproportionate gains on a few outlier tasks" is supported by Figure 4 (left): "mSFT achieves the lowest levels of standard deviation across benchmarks (STD), indicating performance gains are not due to large outliers."
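The outlier argument can be checked with a few lines of arithmetic: compute per-benchmark gains over a baseline, their mean and spread, and the mean after dropping the single largest gain. This is a hypothetical sketch. The review does not say whether Figure 4's STD is taken over accuracies or gains, or whether it is a population or sample standard deviation (population is assumed here), and the leave-top-out mean is an extra check not reported in the paper. Benchmark names and accuracies are placeholders, not results.

```python
from statistics import mean, pstdev
from typing import Dict, Tuple


def gain_summary(method: Dict[str, float], baseline: Dict[str, float]) -> Tuple[float, float, float]:
    """Mean per-benchmark gain, its population std, and the mean gain with the
    single largest gain removed (needs at least two benchmarks)."""
    gains = sorted(method[b] - baseline[b] for b in baseline)
    return mean(gains), pstdev(gains), mean(gains[:-1])


# Placeholder accuracies (percent) on three hypothetical benchmarks.
base = {"bench_a": 50.0, "bench_b": 60.0, "bench_c": 70.0}
ours = {"bench_a": 53.0, "bench_b": 61.0, "bench_c": 72.0}
avg_gain, std_gain, avg_without_top = gain_summary(ours, base)
print(avg_gain, std_gain, avg_without_top)
```

A mean gain that stays positive after removing the top benchmark, together with a small spread, is the pattern the "not due to large outliers" claim implies.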
The paper provides detailed algorithmic descriptions (Algorithms 1–4) and experimental settings in Appendices E and F, including hyperparameters (learning rate $1\times 10^{-5}$, batch size 64, constant schedule). However, no code repository is mentioned or linked; one would be essential for exactly reproducing the iterative checkpointing logic. The FLOP-counting methodology (Appendix F) is clearly specified using the Kaplan et al. formula, enabling independent verification of compute estimates. The authors "use a single seed (20) as preliminary experiments with Qwen2.5 3B on seeds 20, 30, 40 lead to virtually identical performance gains"; despite the single seed, Table 6 reports statistical significance ($p=0.023$). The use of specific Hugging Face datasets aids reproducibility for the OLMo experiments, though exact preprocessing for the other datasets is less detailed. Sensitivity to the compute budget $C$ is well analyzed (Figure 6) and shows little effect, but other potential sensitivities (e.g., the validation granularity of every $1/4$ epoch) are not explored.
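For independent verification of compute estimates, the standard Kaplan et al. approximation counts about 6 FLOPs per parameter per training token (roughly 2 for the forward pass, 4 for the backward pass) and about 2 per parameter per token for a forward-only pass. This sketch assumes Appendix F follows this form; whether validation passes are included in the paper's totals is not stated in the review, and the model size and token counts below are hypothetical.

```python
def train_flops(n_params: float, train_tokens: float) -> float:
    """Kaplan et al. approximation: ~6 FLOPs per parameter per training token."""
    return 6.0 * n_params * train_tokens


def eval_flops(n_params: float, val_tokens: float, n_evals: int) -> float:
    """Forward-only validation: ~2 FLOPs per parameter per token, per evaluation."""
    return 2.0 * n_params * val_tokens * n_evals


PFLOP = 1e15
n = 3e9  # hypothetical 3B-parameter model
print(train_flops(n, 1e7) / PFLOP)    # 180.0  (training on 10M tokens)
print(eval_flops(n, 1e6, 4) / PFLOP)  # 24.0   (4 validation passes over 1M tokens)
```

Keeping the two terms separate makes the validation overhead of frequent per-sub-dataset evaluation visible alongside the training cost.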
Current language model training commonly applies multi-task Supervised Fine-Tuning (SFT) using a homogeneous compute budget across all sub-datasets. This approach is fundamentally sub-optimal: heterogeneous learning dynamics cause faster-learning tasks to overfit early while slower ones remain under-fitted. To address this, we introduce mSFT, an iterative, overfitting-aware search algorithm for multi-task data mixtures. mSFT trains the model on an active mixture, identifies and excludes the earliest overfitting sub-dataset, and reverts to that sub-dataset's optimal checkpoint before continuing. Extensive evaluations demonstrate that mSFT consistently outperforms 4 baselines across 10 benchmarks and 6 base models. Further analysis confirms that mSFT maintains robust gains across diverse dataset sizes and task granularities, and is insensitive to its single new hyperparameter (compute budget). Notably, at low compute budgets, mSFT can improve performance while lowering training FLOPs. Ultimately, mSFT establishes a practical overfitting-aware algorithm for multi-task SFT that maximizes the potential of models across diverse data mixtures.