Scalable Diffusion Models with Transformers

Article Link:https://arxiv.org/pdf/2212.09748.pdf

Abstract

We explore a new class of diffusion models based on the transformer architecture. We train latent diffusion models of images, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops—through increased transformer depth/width or increased number of input tokens—consistently have lower FID. In addition to possessing good scalability properties, our largest DiT-XL/2 models outperform all prior diffusion models on the classconditional ImageNet 512 512 and 256 256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.

我们探索了一种基于 Transformer 架构的新型扩散模型。我们训练图像的潜在扩散模型,用一个在潜在补丁上运行的 Transformer 替换了常用的 U-Net 主干网络。我们通过前向传递复杂度的视角(以 GFLOPs 衡量)分析了我们的扩散 Transformer (DiT) 的可扩展性。我们发现,具有更高 GFLOPs 的 DiT(通过增加 Transformer 的深度/宽度或增加输入 tokens 的数量)始终具有更低的 FID。除了具有良好的可扩展性之外,我们最大的 DiT-XL/2 模型在 classconditional ImageNet 512x512 和 256x256 基准测试中优于所有先前的扩散模型,在后者上实现了 2.27 的最先进 FID。

U-Net 架构

U-Net 是一种常用于图像分割任务的深度学习架构,其核心是一个 U 形的编码-解码结构,由以下三部分组成:

  • 编码器(Encoder):压缩器 - 提取特征(卷积层组成,池化层)
  • 解码器(Decoder):解压器 - 补回特征(多个卷积块)
  • 跳跃连接(Skip Connection):连接编码器和解码器

U-Net架构示意图

复杂度说明

GFLOPs(每秒十亿次浮点运算)用于衡量网络在处理数据时需要进行的运算量,包括加法、乘法等操作。

FID(Fréchet Inception Distance)是一种用于评估生成模型(如生成对抗网络GAN)生成图像质量的指标。

FID值越低,表示生成图像与真实图像越相似,质量越高