Decouple-Then-Merge:
Finetune Diffusion Models as Multi-Task Learning

1 Shanghai Jiao Tong University    2 Tsinghua University   
3 Shanghai AI Laboratory    4 miguo.ai   
CVPR 2025

Denotes Corresponding Authors

Figure 1: (a) Cosine similarity between gradients at different timesteps on CIFAR10, and the distribution of gradient similarity for t ∈ [0, 1000] and t ∈ [0, 250]. Non-adjacent timesteps have low gradient similarity, indicating conflicts during training, whereas adjacent timesteps have similar gradients. (b) & (c): Comparison between the traditional training paradigm and ours. The traditional paradigm trains one diffusion model on all timesteps, which leads to conflicts between different timesteps. Our method addresses this problem by decoupling the training of diffusion models into N different timestep ranges.
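The gradient comparison in Figure 1(a) can be illustrated with a minimal sketch. The function name `cosine_similarity` and the three toy vectors are assumptions made up for illustration; they are not gradients measured on CIFAR10, and a real experiment would flatten the per-timestep gradients of the denoising loss over all model parameters.

```python
import math

def cosine_similarity(g1, g2):
    """Cosine similarity between two flattened gradient vectors."""
    dot = sum(a * b for a, b in zip(g1, g2))
    norm1 = math.sqrt(sum(a * a for a in g1))
    norm2 = math.sqrt(sum(b * b for b in g2))
    return dot / (norm1 * norm2)

# Hypothetical toy gradients (illustrative values only, not real data).
grad_t100 = [1.0, 2.0, 0.5]   # stand-in for the gradient at t = 100
grad_t120 = [0.9, 2.1, 0.4]   # stand-in for a nearby timestep
grad_t900 = [-1.0, 0.2, 1.5]  # stand-in for a distant timestep

sim_adjacent = cosine_similarity(grad_t100, grad_t120)
sim_distant = cosine_similarity(grad_t100, grad_t900)
print(sim_adjacent > sim_distant)
```

A value near 1 means the two timesteps push the shared parameters in nearly the same direction; a value near 0 or below means their updates are unrelated or opposed.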

Abstract

Diffusion models are trained by learning a sequence of models that reverse each step of noise corruption. Typically, the model parameters are fully shared across all timesteps to improve training efficiency. However, since the denoising task differs at each timestep, the gradients computed at different timesteps may conflict, potentially degrading the overall quality of image generation. To address this issue, this work proposes a Decouple-then-Merge (DeMe) framework, which begins with a pretrained model and finetunes separate models tailored to specific timestep ranges. We introduce several improved techniques during the finetuning stage to promote effective knowledge sharing while minimizing training interference across timesteps. After finetuning, these separate models can be merged into a single model in the parameter space, ensuring efficient and practical inference. Experimental results show significant improvements in generation quality on six benchmarks, including Stable Diffusion on COCO30K, ImageNet1K, and PartiPrompts, and DDPM on LSUN Church, LSUN Bedroom, and CIFAR10. Code is available at GitHub.


Figure 2: Pipeline of our framework. The following training techniques are incorporated into the finetuning process. The consistency loss preserves the original knowledge that the diffusion model learned at all timesteps by minimizing the difference between the pre-finetuned and post-finetuned diffusion models. The probabilistic sampling strategy samples from both the corresponding timestep range and the other timesteps with different probabilities, helping the diffusion model avoid forgetting knowledge learned at other timesteps. Channel-wise projection enables the diffusion model to directly capture feature differences along the channel dimension. The model merging scheme merges the parameters of all the finetuned models into one unified model to promote knowledge sharing across different timestep ranges.
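The probabilistic sampling strategy described in the caption can be sketched as follows. This is a minimal sketch under stated assumptions: the function name `sample_timestep`, the parameter `p_other` (the probability of drawing from outside the assigned range), its default value of 0.2, and the uniform choice within each region are all hypothetical and are not taken from the paper.

```python
import random

def sample_timestep(range_lo, range_hi, total_steps=1000, p_other=0.2, rng=random):
    """Draw a training timestep for a model assigned to [range_lo, range_hi).

    With probability p_other (an assumed hyperparameter) the timestep comes
    from outside the assigned range; otherwise it comes from inside it.
    """
    width = range_hi - range_lo
    outside = total_steps - width  # number of timesteps outside the range
    if outside > 0 and rng.random() < p_other:
        # Index the outside timesteps as one contiguous block, then skip
        # over the assigned range to map the index back to a real timestep.
        k = rng.randrange(outside)
        return k if k < range_lo else k + width
    return rng.randrange(range_lo, range_hi)

# Example: a model finetuned for the range [0, 250).
rng = random.Random(0)
draws = [sample_timestep(0, 250, rng=rng) for _ in range(2000)]
inside_fraction = sum(1 for t in draws if t < 250) / len(draws)
print(round(inside_fraction, 2))
```

Setting `p_other` to 0 recovers fully decoupled training on the assigned range, while larger values expose the model to more of the other timesteps.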

Qualitative Results on Text-to-Image Generation

Qualitative Results on Unconditional Image Generation

Quantitative Results

Table 1: Quantitative results on text-to-image generation


Table 2: Quantitative results on unconditional image generation


Why does DeMe work?


Loss landscape of the pretrained diffusion model in different timestep ranges on CIFAR10. We use dimension reduction methods to visualize the high-dimensional parameter space of the neural network. Contour line density reflects how rapidly the loss varies (i.e., the gradient magnitude), with blue representing low loss and red representing high loss. For the overall timesteps t ∈ [0, 1000), the pretrained model resides at a critical point (with zero gradient) surrounded by sparse contour lines. When the training process is decoupled, however, it tends to be located in regions with densely packed contour lines, suggesting that there still exist gradients that enable the pretrained model to escape from the critical point.


Loss landscape for applying task vectors. The optimal model parameters are neither the pretrained ones nor the finetuned ones, but lie within the plane spanned by the task vectors computed in Sec. 3.3. We use the pretrained parameters and the parameters of two finetuned models to obtain the two task vectors, and we compute an orthonormal basis for the plane they span. Each axis denotes a movement direction in the parameter space.
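The construction in this caption can be sketched on toy parameter vectors. The sketch assumes that a task vector is the difference between finetuned and pretrained parameters and that merging adds a weighted sum of task vectors to the pretrained parameters; the weighting scheme, the helper names, and the toy values are assumptions for illustration and may differ from the method in Sec. 3.3. The orthonormal basis is obtained here with Gram-Schmidt, which is one common choice.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def task_vector(finetuned, pretrained):
    """Parameter difference between a finetuned and the pretrained model."""
    return [f - p for f, p in zip(finetuned, pretrained)]

def merge(pretrained, task_vectors, weights):
    """Add a weighted sum of task vectors to the pretrained parameters."""
    merged = list(pretrained)
    for w, tv in zip(weights, task_vectors):
        merged = [m + w * d for m, d in zip(merged, tv)]
    return merged

def orthonormal_basis(v1, v2):
    """Gram-Schmidt basis for the plane spanned by v1 and v2."""
    e1 = [a / math.sqrt(dot(v1, v1)) for a in v1]
    proj = dot(v2, e1)
    residual = [b - proj * a for a, b in zip(e1, v2)]  # part of v2 orthogonal to e1
    e2 = [r / math.sqrt(dot(residual, residual)) for r in residual]
    return e1, e2

# Hypothetical toy parameters (illustrative values only).
theta_pre = [1.0, 2.0, 3.0]
theta_a = [1.5, 2.0, 3.0]  # stand-in for a model finetuned on one timestep range
theta_b = [1.0, 2.5, 3.5]  # stand-in for a model finetuned on another range

tv_a = task_vector(theta_a, theta_pre)
tv_b = task_vector(theta_b, theta_pre)
theta_merged = merge(theta_pre, [tv_a, tv_b], weights=[1.0, 1.0])
e1, e2 = orthonormal_basis(tv_a, tv_b)
print(theta_merged)
```

Evaluating the loss at `theta_pre` plus a combination of `e1` and `e2` over a grid of coefficients would give a 2D landscape of the kind the caption describes.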

BibTeX

@InProceedings{ma2024decouple,
    title={Decouple-Then-Merge: Towards Better Training for Diffusion Models},
    author={Ma, Qianli and Ning, Xuefei and Liu, Dongrui and Niu, Li and Zhang, Linfeng},
    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
    year={2025}
}