Single-Stream DiT with Global Fourier Filters (Proof-of-Concept)
This repository contains the codebase for a Single-Stream Diffusion Transformer (DiT) Proof-of-Concept, heavily inspired by modern architectures like FLUX.1, Z-Image, and Lumina Image 2.
The primary objective was to demonstrate the feasibility and training stability of coupling the high-fidelity FLUX.1-VAE with the powerful T5Gemma2 text encoder for image generation on consumer-grade hardware (NVIDIA RTX 5060 Ti 16GB).
Note: The entire project codebase can be found on the GitHub page!
Project Overview
How it started
Verification and Result Comparison
Core Models and Architecture
| Component | Model ID / Function | Purpose |
|---|---|---|
| Generator | `SingleStreamDiTV2` | Custom Single-Stream DiT featuring Visual Fusion blocks, Context Refiners, and Fourier Filters. DiT parameters: hidden size 768, 12 heads, depth 16, refiner depth 2, text token length 128, patch size 2. |
| Text Encoder | `google/t5gemma-2-1b-1b` | Generates rich, 1152-dimensional text embeddings for high-quality semantic guidance. |
| VAE | `diffusers/FLUX.1-vae` | A 16-channel VAE with an 8x downsample factor, providing superior reconstruction for complex textures. |
| Training Method | Flow Matching (V-Prediction) | Optimized with a velocity-based objective and an optional Self-Evaluation (Self-E) consistency loss. |
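For context on the Flow Matching (V-Prediction) entry above, the snippet below is a minimal sketch of the velocity objective, assuming a linear interpolation path and a placeholder model call signature; it is illustrative rather than the exact `train.py` implementation.

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, latents, text_emb):
    """Minimal v-prediction flow-matching step: the model is trained to predict
    the velocity (noise - latents) along a linear path between data and noise."""
    b = latents.shape[0]
    t = torch.rand(b, device=latents.device)       # uniform timesteps (shift = 1.0)
    t_ = t.view(b, 1, 1, 1)
    noise = torch.randn_like(latents)
    x_t = (1.0 - t_) * latents + t_ * noise        # interpolate between data (t=0) and noise (t=1)
    v_target = noise - latents                     # velocity target
    v_pred = model(x_t, t, text_emb)               # placeholder call signature
    return F.mse_loss(v_pred, v_target)
```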
New in V3
- Refinement Stages: Separate noise and context refiner blocks to "prep" tokens before the joint fusion phase.
- Fourier Filters: Frequency-domain processing layers to improve global structural coherence (a minimal sketch follows this list).
- Local Spatial Bias: Conv2D-based depthwise biases to reinforce local texture within the transformer.
- Rotary Embeddings (RoPE): Dynamic 2D-RoPE grid support for area-preserving bucketing.
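The Fourier Filter idea can be pictured as a GFNet-style global filter: move the token grid into the frequency domain, apply a learnable gate, and transform back. The class below is a sketch under that assumption; the gating details and names in `model.py` may differ.

```python
import torch
import torch.nn as nn

class GlobalFourierFilter(nn.Module):
    """Sketch of a global Fourier filter over an image token grid."""
    def __init__(self, dim, grid_h, grid_w):
        super().__init__()
        # Learnable complex gate per frequency bin and channel (rfft2 keeps grid_w // 2 + 1 columns).
        self.gate = nn.Parameter(torch.randn(grid_h, grid_w // 2 + 1, dim, 2) * 0.02)
        self.grid_h, self.grid_w = grid_h, grid_w

    def forward(self, x):
        # x: (batch, tokens, dim) with tokens laid out as a grid_h x grid_w grid
        b, n, c = x.shape
        x = x.view(b, self.grid_h, self.grid_w, c)
        x_f = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")      # to frequency domain
        x_f = x_f * torch.view_as_complex(self.gate)            # learnable frequency gate
        x = torch.fft.irfft2(x_f, s=(self.grid_h, self.grid_w), dim=(1, 2), norm="ortho")
        return x.reshape(b, n, c)
```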
Training Progression
Data Curation and Preprocessing
The model was tested on a curated dataset of 200 images (10 categories of flowers) before scaling to larger datasets.
| Component | Tool / Method | Purpose / Detail |
|---|---|---|
| Pre/Post-processing | Dataset Helpers | Used to resize images (using DPID, Detail-Preserving Image Downscaling) and edit the Qwen3-VL captions. |
| Captioning | Qwen3-VL-4B-Instruct | Captions include precise botanical details: texture (waxy, serrated), plant anatomy (stamen, pistil), and camera lighting. |
| Data Encoding | `preprocess.py` | Encodes images via the FLUX VAE and text via T5Gemma2, applying aspect-ratio bucketing (sketched below). |
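The caching step performed by `preprocess.py` can be pictured roughly as follows. This is a simplified sketch: the text embedding is assumed to come from the T5Gemma2 encoder, and bucketing, dtype handling, and the repository's actual cache layout are omitted.

```python
import torch
from PIL import Image
from diffusers import AutoencoderKL
import torchvision.transforms.functional as TF

# Hypothetical cache step: encode one image to a FLUX latent and store it alongside
# its precomputed text embedding. The repository's scaling logic (latents.py) is omitted.
vae = AutoencoderKL.from_pretrained("diffusers/FLUX.1-vae").to("cuda").eval()

@torch.no_grad()
def cache_sample(image_path, text_embedding, out_path):
    img = Image.open(image_path).convert("RGB")
    x = TF.to_tensor(img) * 2.0 - 1.0                  # scale pixels to [-1, 1]
    x = x.unsqueeze(0).to("cuda")
    latent = vae.encode(x).latent_dist.sample()        # (1, 16, H/8, W/8) latent
    torch.save({"latent": latent.cpu(), "text_emb": text_embedding.cpu()}, out_path)
```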
Qwen3-VL-4B-Instruct System Instruction (Captioning Prompt)
You are a specialized botanical image analysis system operating within a research environment. Your task is to generate concise, scientifically accurate, and visually descriptive captions for flower images. All output must be strictly factual, objective, and devoid of non-visual assumptions.
Your task is to generate captions for images based on the visual content and a provided reference flower category name. Captions must be precise, comprehensive, and meticulously aligned with the visual details of the plant structure, color gradients, and lighting.
Caption Style: Generate concise captions that are no more than 50 words. Focus on combining descriptors into brief phrases (separated by commas). Follow this structure: "A <view type> of a <flower name>, having <petal details>, the center is <center details>, the background is <background description>, <lighting/style information>"
Hierarchical Description: Begin with the flower name and its primary state (blooming, budding, wilting). Move to the petals (color, shape, texture), then the reproductive parts (stamen, pistil, pollen), then the stem/leaves, and finally the environment.
Factual Accuracy & Label Verification: The provided "Input Flower Name" is a reference tag. You must visually verify this tag against the image content.
- Match: If the visual features match the tag, use the provided name.
- Correction: If the visual characteristics definitively belong to a different species (e.g., input says "Sunflower" but the image clearly shows a "Rose"), you must override the input and use the visually correct botanical name in the caption.
- Ambiguity: If the species is unclear, describe the visual features precisely without forcing a specific name.
Precise Botanical Terminology: Use correct terminology for plant anatomy.
- Petals: Describe edges (serrated, smooth, ruffled), texture (velvety, waxy, delicate), and arrangement (overlapping, sparse, symmetrical).
- Center: Use terms like "stamen", "pistil", "anthers", "pollen", "cone", or "disk" when visible.
- Leaves/Stem: Describe shape (lance-shaped, oval), arrangement, and surface (glossy, hairy, thorny).
Color and Texture: Be specific about colors. Do not just say "pink"; use "pale pink fading to white at the edges", "vibrant magenta", or "speckled purple". Describe patterns like "veining", "spots", "stripes", or "gradients".
Condition and State: Describe the physical state of the flower. Examples: "fully in bloom", "closed bud", "drooping petals", "withered edges", or "covered in dew droplets".
Environmental Description: Describe the setting strictly as seen. Examples: "green leafy background", "blurry garden setting", "studio black background", "natural sunlight", "dirt ground".
Camera Perspective and Style: Crucial for DiT training. Specify:
- Shot Type: "Extreme close-up", "macro shot", "eye-level shot", "top-down view".
- Focus: "Shallow depth of field", "bokeh background", "sharp focus", "soft focus".
- Lighting: "Natural lighting", "harsh shadows", "dappled sunlight", "studio lighting".
Output Format: Output a single string containing the caption, without double quotes, using commas to separate phrases.
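As an illustration only, the system instruction above could be wired into a captioning loop along the following lines; the hub ID, processor workflow, and generation settings are assumptions and may differ from the script actually used to caption the dataset.

```python
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

MODEL_ID = "Qwen/Qwen3-VL-4B-Instruct"        # assumed hub ID for the captioner
SYSTEM_PROMPT = "You are a specialized botanical image analysis system..."  # full instruction above

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
)

def caption_image(image_path: str, flower_name: str) -> str:
    """Builds a chat prompt with the system instruction and the reference label, then generates a caption."""
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": f"Input Flower Name: {flower_name}"},
        ]},
    ]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt], images=[Image.open(image_path)], return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=80)
    new_tokens = output[:, inputs["input_ids"].shape[1]:]          # keep only the generated caption
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
```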
Training History and Configuration
Training uses 8-bit AdamW and a cosine schedule with 5% warmup, with an MSE objective, for 1200 epochs (stopped early). A sketch of this setup follows the table below.
| Configuration | Value | Purpose |
|---|---|---|
| Loss | MSE at 2e-4 | Trained with MSE only. |
| Batch Size | 16 | Gradient checkpointing enabled and accumulation steps set to 2. |
| Shift Value | 1.0 (Uniform) | Ensures balanced training across all noise levels, critical for learning geometry on small datasets. |
| Latent Norm | 0.0 Mean / 1.0 Std | Hardcoded identity normalization to preserve the relative channel relationships of the FLUX VAE. Note: Using a mean and std calculated from the dataset resulted in poor reconstruction with artifacts. |
| EMA Decay | 0.999 | Maintains a moving average of weights for smoother, higher-quality inference. |
| Self-Evolution | Disabled | Optional teacher-student distillation. (Note: Not used in this PoC to maintain baseline architectural clarity.) |
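The optimizer, schedule, EMA, and accumulation settings from the table can be assembled roughly as sketched below using bitsandbytes and diffusers utilities; `flow_matching_loss` refers to the earlier sketch, and the exact `train.py` wiring may differ.

```python
import bitsandbytes as bnb
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers.training_utils import EMAModel

def train(model, dataloader, num_epochs, accum_steps=2):
    """Sketch of the optimizer/schedule/EMA wiring described in the table above."""
    optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=2e-4)      # 8-bit AdamW
    total_steps = (len(dataloader) * num_epochs) // accum_steps
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.05 * total_steps),                     # 5% warmup
        num_training_steps=total_steps,
    )
    ema = EMAModel(model.parameters(), decay=0.999)                   # smoothed weights for inference

    for epoch in range(num_epochs):
        for step, batch in enumerate(dataloader):
            loss = flow_matching_loss(model, batch["latent"], batch["text_emb"])
            (loss / accum_steps).backward()
            if (step + 1) % accum_steps == 0:                         # gradient accumulation of 2
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                ema.step(model.parameters())
```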
Loss & Fourier Gate Progression
Training Time Estimate:
- GPU Time: Approximately 6 hours and 21 minutes of total GPU compute time for 1200 epochs (RTX 5060 Ti 16GB).
- Project Time (Human): 13 days of R&D, including hyperparameter tuning.
Reproducibility
This repository is designed to be fully reproducible. The following data is included in the respective directories:
- Raw Dataset: The original `.png` images and the Qwen3-VL-4B-Instruct generated and reviewed `.txt` captions.
- Cached Dataset: The processed, tokenized, and VAE-encoded latents (`.pt` files).
Repository File Breakdown
Training & Core Scripts
| File | Purpose | Notes |
|---|---|---|
| `train.py` | Main training script. Supports EMA, Self-E, and Gradient Accumulation. | Includes automatic model compilation on Linux. |
| `model.py` | Defines `SingleStreamDiTV2` with Visual Fusion, Fourier Filters, and SwiGLU. | The core architecture definition. |
| `config.py` | Central configuration for paths, model dims, and hyperparameters. | All model settings are controlled here. |
| `sanity_check.py` | A utility to ensure the model can overfit to a single cached latent file. | Used for debugging architecture changes. |
Utility & Preprocessing
| File | Purpose | Notes |
|---|---|---|
| `preprocess.py` | Prepares raw image/text data into cached `.pt` files using the VAE and T5. | Run this before starting training. |
| `calculate_cache_statistics.py` | Analyzes cached latents to find Mean/Std for normalization settings. | Note: Use results with caution; defaults of 0.0/1.0 are often better. |
| `debug_vae_pipeline.py` | Tests the VAE reconstruction pipeline in float32 to isolate VAE issues. | Useful for troubleshooting color shifts. |
| `check_cache.py` | Decodes a single cached latent back to an image to verify preprocessing. | Fast integrity check (a round-trip sketch follows this table). |
| `generate_graph.py` | Generates the loss curve visualization from the training CSV logs. | Creates `loss_curve.png`. |
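A `check_cache.py`-style round trip can be as small as the sketch below; the `.pt` key name and output path are placeholders, and the scaling handled by `latents.py` in the repository is intentionally left out.

```python
import torch
from diffusers import AutoencoderKL
from torchvision.utils import save_image

@torch.no_grad()
def decode_cached_latent(pt_path, out_png="cache_check.png"):
    """Decode one cached latent back to pixels to verify the preprocessing step."""
    vae = AutoencoderKL.from_pretrained("diffusers/FLUX.1-vae").to("cuda").eval()
    latent = torch.load(pt_path)["latent"].to("cuda", vae.dtype)   # assumed cache key "latent"
    image = vae.decode(latent).sample                              # pixels in [-1, 1]
    save_image((image * 0.5 + 0.5).clamp(0, 1), out_png)           # rescale to [0, 1] and write PNG
```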
Inference & Data
| File | Purpose | Notes |
|---|---|---|
| `inferenceNotebook.ipynb` | Primary inference tool. Supports text-to-image with Euler/RK4. | Best for interactive testing. |
| `samplers.py` | Numerical integration steps for Euler and Runge-Kutta 4 (RK4). | Logic for flow-matching inference (a minimal Euler sketch follows this table). |
| `latents.py` | Scaling and normalization logic for VAE latents. | Shared across preprocess, train, and inference. |
| `dataset.py` | Bucket-batching and RAM-caching dataset implementation. | Handles the training data pipeline. |
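For the Euler path mentioned in the `samplers.py` row, a minimal flow-matching sampler might look like the sketch below; the function and model signature are illustrative, not the repository's exact API.

```python
import torch

@torch.no_grad()
def sample_euler(model, text_emb, shape, steps=30, device="cuda"):
    """Integrate dx/dt = v(x, t) from t = 1 (noise) to t = 0 (image latent)
    with fixed-step Euler, matching the v-prediction convention used in training."""
    x = torch.randn(shape, device=device)                   # start from pure noise at t = 1
    ts = torch.linspace(1.0, 0.0, steps + 1, device=device)
    for i in range(steps):
        t = ts[i].expand(shape[0])                           # per-sample timestep
        v = model(x, t, text_emb)                            # predicted velocity (noise - image)
        dt = ts[i + 1] - ts[i]                               # negative step toward t = 0
        x = x + v * dt                                       # Euler update
    return x                                                 # final latent; decode with the VAE
```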