Single-Stream DiT with Global Fourier Filters (Proof-of-Concept)

This repository contains the codebase for a Single-Stream Diffusion Transformer (DiT) Proof-of-Concept, heavily inspired by modern architectures like FLUX.1, Z-Image, and Lumina Image 2.

The primary objective was to demonstrate the feasibility and training stability of coupling the high-fidelity FLUX.1-VAE with the powerful T5Gemma2 text encoder for image generation on consumer-grade hardware (NVIDIA RTX 5060 Ti 16GB).

Note: The entire project codebase can be found on the GitHub page!

Project Overview

How it started

Verification and Result Comparison

Side-by-side comparison: cached latent verification vs. final generated sample (Euler, 50 steps, CFG 3.0).

Core Models and Architecture

| Component | Model ID / Function | Purpose |
| --- | --- | --- |
| Generator | SingleStreamDiTV2 | Custom single-stream DiT featuring Visual Fusion blocks, Context Refiners, and Fourier Filters. DiT parameters: hidden size 768, 12 heads, depth 16, refiner depth 2, text token length 128, patch size 2. |
| Text Encoder | google/t5gemma-2-1b-1b | Generates rich, 1152-dimensional text embeddings for high-quality semantic guidance. |
| VAE | diffusers/FLUX.1-vae | A 16-channel VAE with an 8x downsample factor, providing superior reconstruction for complex textures. |
| Training Method | Flow Matching (V-Prediction) | Optimized with a velocity-based objective and an optional Self-Evaluation (Self-E) consistency loss. |
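
For orientation, these dimensions map onto a handful of hyperparameters. The sketch below is illustrative only; the field names are assumptions and the authoritative values live in config.py.

```python
# Illustrative configuration mirroring the table above.
# Field names are assumptions; see config.py for the real ones.
from dataclasses import dataclass

@dataclass
class DiTConfig:
    hidden_size: int = 768        # transformer width
    num_heads: int = 12           # attention heads
    depth: int = 16               # joint single-stream (Visual Fusion) blocks
    refiner_depth: int = 2        # noise/context refiner blocks before fusion
    max_text_tokens: int = 128    # T5Gemma2 token length
    text_embed_dim: int = 1152    # T5Gemma2 embedding width
    patch_size: int = 2           # latent patch size
    latent_channels: int = 16     # FLUX.1-VAE latent channels
    vae_downsample: int = 8       # FLUX.1-VAE spatial downsample factor
```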

New in V3

  • Refinement Stages: Separate noise and context refiner blocks to "prep" tokens before the joint fusion phase.
  • Fourier Filters: Frequency-domain processing layers to improve global structural coherence (a minimal sketch follows this list).
  • Local Spatial Bias: Conv2D-based depthwise biases to reinforce local texture within the transformer.
  • Rotary Embeddings (RoPE): Dynamic 2D-RoPE grid support for area-preserving bucketing.
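
The Fourier Filters are the most unusual of these additions, so here is a minimal sketch of the general technique, assuming a fixed token grid: take a 2-D FFT over the latent token grid, scale the frequencies with a learnable filter, transform back, and blend the result in through a learnable gate (the gate whose progression is plotted further below). This is an illustration of the idea, not the exact layer in model.py.

```python
import torch
import torch.nn as nn

class GlobalFourierFilter(nn.Module):
    """Frequency-domain mixing over the 2-D latent token grid (illustrative sketch)."""
    def __init__(self, dim: int, grid_h: int, grid_w: int):
        super().__init__()
        # learnable complex-valued filter over the rFFT frequency grid
        self.filter = nn.Parameter(torch.randn(grid_h, grid_w // 2 + 1, dim, 2) * 0.02)
        # scalar gate controlling how much filtered signal is blended back in (starts at 0 = identity)
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, H, W, C) token grid
        freq = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")        # (B, H, W//2+1, C), complex
        freq = freq * torch.view_as_complex(self.filter)           # elementwise frequency scaling
        out = torch.fft.irfft2(freq, s=x.shape[1:3], dim=(1, 2), norm="ortho")
        return x + torch.tanh(self.gate) * out                     # gated residual
```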

Training Progression

Samples from an early epoch (epoch 25), the final epoch (epoch 1200), and the full progression over training.

Data Curation and Preprocessing

The model was tested on a curated dataset of 200 images (10 categories of flowers) before scaling to larger datasets.

| Component | Tool / Method | Purpose / Detail |
| --- | --- | --- |
| Pre/Post-processing | Dataset helpers | Used to resize images (with DPID, Detail-Preserving Image Downscaling) and edit the Qwen3-VL captions. |
| Captioning | Qwen3-VL-4B-Instruct | Captions include precise botanical details: texture (waxy, serrated), plant anatomy (stamen, pistil), and camera lighting. |
| Data Encoding | preprocess.py | Encodes images via the FLUX VAE and text via T5Gemma2, applying aspect-ratio bucketing. |
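
A rough outline of what the caching step does, under stated assumptions: the model identifiers come from the tables above, the cached key names are guesses, and the exact way preprocess.py loads the T5Gemma2 encoder may differ. Latent scaling/shifting is deliberately left out here because the repository keeps that logic in latents.py.

```python
import torch
from diffusers import AutoencoderKL
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import torchvision.transforms.functional as TF

# Model identifiers from the tables above; loading details are assumptions.
vae = AutoencoderKL.from_pretrained("diffusers/FLUX.1-vae").eval()
tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-2-1b-1b")
text_model = AutoModel.from_pretrained("google/t5gemma-2-1b-1b").eval()
text_encoder = text_model.get_encoder() if hasattr(text_model, "get_encoder") else text_model

@torch.no_grad()
def cache_example(image_path: str, caption: str, out_path: str) -> None:
    image = Image.open(image_path).convert("RGB")           # already bucket-resized by the dataset helpers
    pixels = TF.to_tensor(image) * 2.0 - 1.0                # scale pixels to [-1, 1]
    latents = vae.encode(pixels.unsqueeze(0)).latent_dist.sample()   # (1, 16, H/8, W/8)

    tokens = tokenizer(caption, max_length=128, padding="max_length",
                       truncation=True, return_tensors="pt")
    text_emb = text_encoder(**tokens).last_hidden_state     # (1, 128, 1152)

    # Cached key names below are assumptions, not the repository's actual schema.
    torch.save({"latents": latents.squeeze(0),
                "text_emb": text_emb.squeeze(0),
                "attention_mask": tokens["attention_mask"].squeeze(0)}, out_path)
```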

Qwen3-VL-4B-Instruct System Instruction (Captioning Prompt)

You are a specialized botanical image analysis system operating within a research environment. Your task is to generate concise, scientifically accurate, and visually descriptive captions for flower images. All output must be strictly factual, objective, and devoid of non-visual assumptions.

Your task is to generate captions for images based on the visual content and a provided reference flower category name. Captions must be precise, comprehensive, and meticulously aligned with the visual details of the plant structure, color gradients, and lighting.

Caption Style: Generate concise captions that are no more than 50 words. Focus on combining descriptors into brief phrases (separated by commas). Follow this structure: "A <view type> of a <flower name>, having <petal details>, the center is <center details>, the background is <background description>, <lighting/style information>"

Hierarchical Description: Begin with the flower name and its primary state (blooming, budding, wilting). Move to the petals (color, shape, texture), then the reproductive parts (stamen, pistil, pollen), then the stem/leaves, and finally the environment.

Factual Accuracy & Label Verification: The provided "Input Flower Name" is a reference tag. You must visually verify this tag against the image content.

  • Match: If the visual features match the tag, use the provided name.
  • Correction: If the visual characteristics definitively belong to a different species (e.g., input says "Sunflower" but the image clearly shows a "Rose"), you must override the input and use the visually correct botanical name in the caption.
  • Ambiguity: If the species is unclear, describe the visual features precisely without forcing a specific name.

Precise Botanical Terminology: Use correct terminology for plant anatomy.

  • Petals: Describe edges (serrated, smooth, ruffled), texture (velvety, waxy, delicate), and arrangement (overlapping, sparse, symmetrical).
  • Center: Use terms like "stamen", "pistil", "anthers", "pollen", "cone", or "disk" when visible.
  • Leaves/Stem: Describe shape (lance-shaped, oval), arrangement, and surface (glossy, hairy, thorny).

Color and Texture: Be specific about colors. Do not just say "pink"; use "pale pink fading to white at the edges", "vibrant magenta", or "speckled purple". Describe patterns like "veining", "spots", "stripes", or "gradients".

Condition and State: Describe the physical state of the flower. Examples: "fully in bloom", "closed bud", "drooping petals", "withered edges", or "covered in dew droplets".

Environmental Description: Describe the setting strictly as seen. Examples: "green leafy background", "blurry garden setting", "studio black background", "natural sunlight", "dirt ground".

Camera Perspective and Style: Crucial for DiT training. Specify:

  • Shot Type: "Extreme close-up", "macro shot", "eye-level shot", "top-down view".
  • Focus: "Shallow depth of field", "bokeh background", "sharp focus", "soft focus".
  • Lighting: "Natural lighting", "harsh shadows", "dappled sunlight", "studio lighting".

Output Format: Output a single string containing the caption, without double quotes, using commas to separate phrases.
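
For completeness, the system instruction above can be paired with the reference flower tag as an ordinary chat-style request. The snippet below only assembles the message list; the helper name and the exact Qwen3-VL inference call are not part of this repository and are shown as assumptions.

```python
# Hypothetical helper: combine the system instruction with one image and its reference tag.
SYSTEM_INSTRUCTION = "You are a specialized botanical image analysis system ..."  # full prompt above

def build_caption_request(image_path: str, flower_name: str) -> list[dict]:
    return [
        {"role": "system", "content": SYSTEM_INSTRUCTION},
        {"role": "user", "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": f"Input Flower Name: {flower_name}"},
        ]},
    ]
```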

Training History and Configuration

Training uses 8-bit AdamW and a cosine schedule with 5% warmup; the run used a plain MSE objective and was stopped early at 1200 epochs.

| Configuration | Value | Purpose |
| --- | --- | --- |
| Loss | MSE at 2e-4 | Trained with MSE only. |
| Batch Size | 16 | Gradient checkpointing enabled, with gradient accumulation set to 2 steps. |
| Shift Value | 1.0 (Uniform) | Ensures balanced training across all noise levels, critical for learning geometry on small datasets. |
| Latent Norm | 0.0 mean / 1.0 std | Hardcoded identity normalization to preserve the relative channel relationships of the FLUX VAE. Note: using a mean and std calculated from the dataset resulted in poor reconstructions with artifacts. |
| EMA Decay | 0.999 | Maintains a moving average of weights for smoother, higher-quality inference. |
| Self-Evolution | Disabled | Optional teacher-student distillation. (Not used in this PoC to keep the architectural baseline clear.) |
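
Put together, the velocity objective under these settings (uniform timesteps at shift 1.0, identity latent normalization, plain MSE) and the EMA update can be sketched as follows. This is a minimal illustration; the model call signature and loop structure in train.py will differ.

```python
import torch
import torch.nn.functional as F

def flow_matching_step(model, latents, text_emb, text_mask):
    """One velocity-prediction (flow matching) training step -- illustrative sketch."""
    noise = torch.randn_like(latents)
    t = torch.rand(latents.shape[0], device=latents.device)   # shift 1.0 -> uniform in [0, 1]
    t_ = t.view(-1, 1, 1, 1)

    x_t = (1.0 - t_) * latents + t_ * noise    # linear path between data and noise
    target_velocity = noise - latents          # constant velocity along that path

    pred_velocity = model(x_t, t, text_emb, text_mask)   # assumed call signature
    return F.mse_loss(pred_velocity, target_velocity)

@torch.no_grad()
def ema_update(ema_model, model, decay: float = 0.999):
    """EMA of weights: ema = decay * ema + (1 - decay) * online."""
    for e, p in zip(ema_model.parameters(), model.parameters()):
        e.lerp_(p, 1.0 - decay)
```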

Loss & Fourier Gate Progression

Plots: training loss curve and Fourier gate value over the course of training.

Training Time Estimate:

  • GPU Time: Approximately 6 hours and 21 minutes of total GPU compute time for 1200 epochs (RTX 5060 Ti 16GB).
  • Project Time (Human): 13 days of R&D, including hyperparameter tuning.

Reproducibility

This repository is designed to be fully reproducible. The following data is included in the respective directories:

  • Raw Dataset: The original .png images and the Qwen3-VL-4B-Instruct generated and reviewed .txt captions.
  • Cached Dataset: The processed, tokenized, and VAE-encoded latents (.pt files).

Repository File Breakdown

Training & Core Scripts

| File | Purpose | Notes |
| --- | --- | --- |
| train.py | Main training script. Supports EMA, Self-E, and gradient accumulation. | Includes automatic model compilation on Linux. |
| model.py | Defines SingleStreamDiTV2 with Visual Fusion, Fourier Filters, and SwiGLU. | The core architecture definition. |
| config.py | Central configuration for paths, model dims, and hyperparameters. | All model settings are controlled here. |
| sanity_check.py | A utility to ensure the model can overfit a single cached latent file. | Used for debugging architecture changes. |
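
sanity_check.py is not reproduced here, but the idea is simple enough to sketch: load one cached example and confirm the loss can be driven toward zero. The cache key names and the reuse of the flow_matching_step sketch from the training section are assumptions.

```python
import torch

def sanity_check(model, cached_example, steps: int = 500, lr: float = 1e-4):
    """Overfit a single cached latent to validate the architecture wiring (sketch)."""
    latents = cached_example["latents"].unsqueeze(0).cuda()          # assumed cache keys
    text_emb = cached_example["text_emb"].unsqueeze(0).cuda()
    text_mask = cached_example["attention_mask"].unsqueeze(0).cuda()
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    for step in range(steps):
        loss = flow_matching_step(model, latents, text_emb, text_mask)  # see training sketch above
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step % 50 == 0:
            print(f"step {step}: loss {loss.item():.4f}")   # should steadily approach zero
```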

Utility & Preprocessing

| File | Purpose | Notes |
| --- | --- | --- |
| preprocess.py | Prepares raw image/text data into cached .pt files using the VAE and T5. | Run this before starting training. |
| calculate_cache_statistics.py | Analyzes cached latents to find the mean/std for normalization settings. | Note: use results with caution; the defaults of 0.0/1.0 are often better. |
| debug_vae_pipeline.py | Tests the VAE reconstruction pipeline in float32 to isolate VAE issues. | Useful for troubleshooting color shifts. |
| check_cache.py | Decodes a single cached latent back to an image to verify preprocessing. | Fast integrity check. |
| generate_graph.py | Generates the loss curve visualization from the training CSV logs. | Creates loss_curve.png. |
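
For reference, the per-channel statistics that calculate_cache_statistics.py reports can be computed along these lines (a sketch under assumed cache key names; the script's actual interface may differ):

```python
import glob
import torch

def latent_channel_stats(cache_dir: str):
    """Per-channel mean/std over all cached latents (illustrative sketch)."""
    sums, sq_sums, count = None, None, 0
    for path in glob.glob(f"{cache_dir}/*.pt"):
        latents = torch.load(path)["latents"].float()    # assumed key; shape (16, H/8, W/8)
        flat = latents.flatten(1)                         # (16, H*W/64)
        sums = flat.sum(dim=1) if sums is None else sums + flat.sum(dim=1)
        sq_sums = (flat ** 2).sum(dim=1) if sq_sums is None else sq_sums + (flat ** 2).sum(dim=1)
        count += flat.shape[1]
    mean = sums / count
    std = (sq_sums / count - mean ** 2).clamp(min=0.0).sqrt()
    return mean, std                                      # each of shape (16,)
```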

Inference & Data

| File | Purpose | Notes |
| --- | --- | --- |
| inferenceNotebook.ipynb | Primary inference tool. Supports text-to-image with Euler/RK4. | Best for interactive testing. |
| samplers.py | Numerical integration steps for Euler and Runge-Kutta 4 (RK4). | Logic for flow-matching inference. |
| latents.py | Scaling and normalization logic for VAE latents. | Shared across preprocess, train, and inference. |
| dataset.py | Bucket-batching and RAM-caching dataset implementation. | Handles the training data pipeline. |
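
The Euler path through samplers.py follows the standard flow-matching update, integrating from noise at t = 1 back to data at t = 0. A minimal sketch with classifier-free guidance (the function name, model call signature, and argument names are assumptions):

```python
import torch

@torch.no_grad()
def euler_sample(model, text_emb, text_mask, null_emb, null_mask,
                 shape, steps: int = 50, cfg_scale: float = 3.0, device: str = "cuda"):
    """Integrate from pure noise (t=1) to data (t=0) with Euler steps (sketch)."""
    x = torch.randn(shape, device=device)
    timesteps = torch.linspace(1.0, 0.0, steps + 1, device=device)
    for i in range(steps):
        t = timesteps[i].expand(shape[0])
        v_cond = model(x, t, text_emb, text_mask)
        v_uncond = model(x, t, null_emb, null_mask)
        v = v_uncond + cfg_scale * (v_cond - v_uncond)   # classifier-free guidance
        dt = timesteps[i + 1] - timesteps[i]             # negative step, moving toward t=0
        x = x + dt * v                                   # Euler update
    return x                                             # latents; decode with the FLUX VAE
```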