[{"data":1,"prerenderedAt":-1},["ShallowReactive",2],{"project-71029":3},{"id":4,"name":5,"fullName":6,"owner":7,"repo":5,"description":8,"homepage":9,"htmlUrl":10,"language":11,"languages":10,"totalLinesOfCode":10,"stars":12,"forks":13,"watchers":14,"openIssues":15,"contributorsCount":16,"subscribersCount":16,"size":16,"stars1d":17,"stars7d":18,"stars30d":19,"stars90d":16,"forks30d":16,"starsTrendScore":20,"compositeScore":21,"rankGlobal":10,"rankLanguage":10,"license":22,"archived":23,"fork":24,"defaultBranch":25,"hasWiki":24,"hasPages":24,"topics":26,"createdAt":10,"pushedAt":10,"updatedAt":27,"readmeContent":28,"aiSummary":29,"trendingCount":16,"starSnapshotCount":16,"syncStatus":14,"lastSyncTime":30,"discoverSource":31},71029,"DiT","facebookresearch\u002FDiT","facebookresearch","Official PyTorch Implementation of \"Scalable Diffusion Models with Transformers\"","",null,"Python",8615,790,2,67,0,5,17,52,15,39.69,"Other",true,false,"main",[],"2026-06-12 02:02:46","## Scalable Diffusion Models with Transformers (DiT)\u003Cbr>\u003Csub>Official PyTorch Implementation\u003C\u002Fsub>\n\n### [Paper](http:\u002F\u002Farxiv.org\u002Fabs\u002F2212.09748) | [Project Page](https:\u002F\u002Fwww.wpeebles.com\u002FDiT) | Run DiT-XL\u002F2 [![Hugging Face Spaces](https:\u002F\u002Fimg.shields.io\u002Fbadge\u002F%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https:\u002F\u002Fhuggingface.co\u002Fspaces\u002Fwpeebles\u002FDiT) [![Open In Colab](https:\u002F\u002Fcolab.research.google.com\u002Fassets\u002Fcolab-badge.svg)](http:\u002F\u002Fcolab.research.google.com\u002Fgithub\u002Ffacebookresearch\u002FDiT\u002Fblob\u002Fmain\u002Frun_DiT.ipynb) \u003Ca href=\"https:\u002F\u002Freplicate.com\u002Farielreplicate\u002Fscalable_diffusion_with_transformers\">\u003Cimg src=\"https:\u002F\u002Freplicate.com\u002Farielreplicate\u002Fscalable_diffusion_with_transformers\u002Fbadge\">\u003C\u002Fa>\n\n![DiT samples](visuals\u002Fsample_grid_0.png)\n\nThis repo contains PyTorch model 
definitions, pre-trained weights and training\u002Fsampling code for our paper exploring \ndiffusion models with transformers (DiTs). You can find more visualizations on our [project page](https:\u002F\u002Fwww.wpeebles.com\u002FDiT).\n\n> [**Scalable Diffusion Models with Transformers**](https:\u002F\u002Fwww.wpeebles.com\u002FDiT)\u003Cbr>\n> [William Peebles](https:\u002F\u002Fwww.wpeebles.com), [Saining Xie](https:\u002F\u002Fwww.sainingxie.com)\n> \u003Cbr>UC Berkeley, New York University\u003Cbr>\n\nWe train latent diffusion models, replacing the commonly used U-Net backbone with a transformer that operates on \nlatent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward-pass \ncomplexity as measured by Gflops. We find that DiTs with higher Gflops (through increased transformer depth\u002Fwidth or\nincreased number of input tokens) consistently have lower FID. In addition to good scalability properties, our \nDiT-XL\u002F2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, \nachieving a state-of-the-art FID of 2.27 on the latter.\n\nThis repository contains:\n\n* 🪐 A simple PyTorch [implementation](models.py) of DiT\n* ⚡️ Pre-trained class-conditional DiT models trained on ImageNet (512x512 and 256x256)\n* 💥 A self-contained [Hugging Face Space](https:\u002F\u002Fhuggingface.co\u002Fspaces\u002Fwpeebles\u002FDiT) and [Colab notebook](http:\u002F\u002Fcolab.research.google.com\u002Fgithub\u002Ffacebookresearch\u002FDiT\u002Fblob\u002Fmain\u002Frun_DiT.ipynb) for running pre-trained DiT-XL\u002F2 models\n* 🛸 A DiT [training script](train.py) using PyTorch DDP\n\nAn implementation of DiT directly in Hugging Face `diffusers` can also be found [here](https:\u002F\u002Fgithub.com\u002Fhuggingface\u002Fdiffusers\u002Fblob\u002Fmain\u002Fdocs\u002Fsource\u002Fen\u002Fapi\u002Fpipelines\u002Fdit.mdx).\n\n\n## Setup\n\nFirst, download and set up the 
repo:\n\n```bash\ngit clone https:\u002F\u002Fgithub.com\u002Ffacebookresearch\u002FDiT.git\ncd DiT\n```\n\nWe provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want \nto run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.\n\n```bash\nconda env create -f environment.yml\nconda activate DiT\n```\n\n\n## Sampling [![Hugging Face Spaces](https:\u002F\u002Fimg.shields.io\u002Fbadge\u002F%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https:\u002F\u002Fhuggingface.co\u002Fspaces\u002Fwpeebles\u002FDiT) [![Open In Colab](https:\u002F\u002Fcolab.research.google.com\u002Fassets\u002Fcolab-badge.svg)](http:\u002F\u002Fcolab.research.google.com\u002Fgithub\u002Ffacebookresearch\u002FDiT\u002Fblob\u002Fmain\u002Frun_DiT.ipynb)\n![More DiT samples](visuals\u002Fsample_grid_1.png)\n\n**Pre-trained DiT checkpoints.** You can sample from our pre-trained DiT models with [`sample.py`](sample.py). Weights for our pre-trained DiT models are \nautomatically downloaded, depending on the model you choose. The script has various arguments to switch between the 256x256\nand 512x512 models (`--image-size`), adjust the number of sampling steps (`--num-sampling-steps`), change the classifier-free guidance scale (`--cfg-scale`), etc. 
For example, to sample from\nour 512x512 DiT-XL\u002F2 model, you can use:\n\n```bash\npython sample.py --image-size 512 --seed 1\n```\n\nFor convenience, our pre-trained DiT models can be downloaded directly here as well:\n\n| DiT Model     | Image Resolution | FID-50K | Inception Score | Gflops | \n|---------------|------------------|---------|-----------------|--------|\n| [XL\u002F2](https:\u002F\u002Fdl.fbaipublicfiles.com\u002FDiT\u002Fmodels\u002FDiT-XL-2-256x256.pt) | 256x256          | 2.27    | 278.24          | 119    |\n| [XL\u002F2](https:\u002F\u002Fdl.fbaipublicfiles.com\u002FDiT\u002Fmodels\u002FDiT-XL-2-512x512.pt) | 512x512          | 3.04    | 240.82          | 525    |\n\n\n**Custom DiT checkpoints.** If you've trained a new DiT model with [`train.py`](train.py) (see [below](#training-dit)), you can add the `--ckpt`\nargument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom \n256x256 DiT-L\u002F4 model, run:\n\n```bash\npython sample.py --model DiT-L\u002F4 --image-size 256 --ckpt \u002Fpath\u002Fto\u002Fmodel.pt\n```\n\n\n## Training DiT\n\nWe provide a training script for DiT in [`train.py`](train.py). This script can be used to train class-conditional \nDiT models, but it can be easily modified to support other types of conditioning. To launch DiT-XL\u002F2 (256x256) training with `N` GPUs on \none node:\n\n```bash\ntorchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL\u002F2 --data-path \u002Fpath\u002Fto\u002Fimagenet\u002Ftrain\n```\n\n### PyTorch Training Results\n\nWe've trained DiT-XL\u002F2 and DiT-B\u002F4 models from scratch with the PyTorch training script\nto verify that it reproduces the original JAX results up to several hundred thousand training iterations. Across our experiments, the PyTorch-trained models give \nsimilar (and sometimes slightly better) results compared to the JAX-trained models up to reasonable random variation. 
Some data points:\n\n| DiT Model  | Train Steps | FID-50K\u003Cbr> (JAX Training) | FID-50K\u003Cbr> (PyTorch Training) | PyTorch Global Training Seed |\n|------------|-------------|----------------------------|--------------------------------|------------------------------|\n| XL\u002F2       | 400K        | 19.5                       | **18.1**                       | 42                           |\n| B\u002F4        | 400K        | **68.4**                   | 68.9                           | 42                           |\n| B\u002F4        | 400K        | 68.4                       | **68.3**                       | 100                          |\n\nThese models were trained at 256x256 resolution; we used 8x A100s to train XL\u002F2 and 4x A100s to train B\u002F4. Note that FID \nhere is computed with 250 DDPM sampling steps, with the `mse` VAE decoder and without guidance (`cfg-scale=1`). \n\n**TF32 Note (important for A100 users).** When we ran the above tests, TF32 matmuls were disabled per PyTorch's defaults. 
\nWe've enabled them at the top of `train.py` and `sample.py` because doing so makes training and sampling much faster on \nA100s (and should on other Ampere GPUs too), but note that using TF32 may lead to some differences compared to \nthe above results.\n\n### Enhancements\nTraining (and sampling) could likely be sped up significantly by:\n- [ ] using [Flash Attention](https:\u002F\u002Fgithub.com\u002FHazyResearch\u002Fflash-attention) in the DiT model\n- [ ] using `torch.compile` in PyTorch 2.0\n\nBasic features that would be nice to add:\n- [ ] Monitor FID and other metrics\n- [ ] Generate and save samples from the EMA model periodically\n- [ ] Resume training from a checkpoint\n- [ ] AMP\u002Fbfloat16 support\n\n**🔥 Feature Update** Check out this repository at https:\u002F\u002Fgithub.com\u002Fchuanyangjin\u002Ffast-DiT to preview a selection of training-acceleration and memory-saving features, including gradient checkpointing, mixed-precision training and pre-extracted VAE features. With these advancements, we have achieved a training speed of 0.84 steps\u002Fsec for DiT-XL\u002F2 using just a single A100 GPU.\n\n## Evaluation (FID, Inception Score, etc.)\n\nWe include a [`sample_ddp.py`](sample_ddp.py) script which samples a large number of images from a DiT model in parallel. This script \ngenerates a folder of samples as well as a `.npz` file which can be directly used with [ADM's TensorFlow\nevaluation suite](https:\u002F\u002Fgithub.com\u002Fopenai\u002Fguided-diffusion\u002Ftree\u002Fmain\u002Fevaluations) to compute FID, Inception Score and\nother metrics. For example, to sample 50K images from our pre-trained DiT-XL\u002F2 model over `N` GPUs, run:\n\n```bash\ntorchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL\u002F2 --num-fid-samples 50000\n```\n\nThere are several additional options; see [`sample_ddp.py`](sample_ddp.py) for details. \n\n\n## Differences from JAX\n\nOur models were originally trained in JAX on TPUs. 
The weights in this repo are ported directly from the JAX models. \nThere may be minor differences in results stemming from sampling with different floating-point precisions. We re-evaluated \nour ported PyTorch weights at FP32, and they actually perform marginally better than sampling in JAX (2.21 FID \nversus 2.27 in the paper).\n\n\n## BibTeX\n\n```bibtex\n@article{Peebles2022DiT,\n  title={Scalable Diffusion Models with Transformers},\n  author={William Peebles and Saining Xie},\n  year={2022},\n  journal={arXiv preprint arXiv:2212.09748},\n}\n```\n\n\n## Acknowledgments\nWe thank Kaiming He, Ronghang Hu, Alexander Berg, Shoubhik Debnath, Tim Brooks, Ilija Radosavovic and Tete Xiao for helpful discussions. \nWilliam Peebles is supported by the NSF Graduate Research Fellowship.\n\nThis codebase borrows from OpenAI's diffusion repos, most notably [ADM](https:\u002F\u002Fgithub.com\u002Fopenai\u002Fguided-diffusion).\n\n\n## License\nThe code and model weights are licensed under CC-BY-NC. See [`LICENSE.txt`](LICENSE.txt) for details.\n","DiT is the official PyTorch implementation, developed by Facebook Research, of diffusion models combined with transformers. By replacing the commonly used U-Net backbone with a transformer that operates on latent patches, the project explores a new approach to diffusion models and provides pre-trained weights and training\u002Fsampling code. Technical highlights include using higher-Gflops transformers to improve image generation quality, reaching state-of-the-art FID scores on the ImageNet 512×512 and 256×256 benchmarks. The project can also be run directly in Hugging Face Spaces and Google Colab. It suits research and applications that need high-quality image generation or better performance than existing diffusion models.","2026-06-11 03:35:34","high_star"]