benetraco committed on
Commit
9c10e1e
verified
1 Parent(s): fc36213

Upload README.md with huggingface_hub

  1. README.md +95 -0
README.md ADDED
@@ -0,0 +1,95 @@
---
license: mit
tags:
- pytorch
- diffusers
- stable-diffusion
- latent-diffusion
- medical-imaging
- brain-mri
- multiple-sclerosis
- dataset-conditioning
---

# Brain MRI Synthesis with Stable Diffusion (Fine-Tuned with Dataset Prompts)
A fine-tuned version of Stable Diffusion v1-4 for brain MRI synthesis.
It uses latent diffusion and dataset-specific prompts to generate realistic 256×256 FLAIR brain scans, with control over the dataset style.

The model performs prompt-conditioned image synthesis and was trained on 2D FLAIR slices from the SHIFTS, VH, and WMH2017 datasets.
Scans are generated as 32×32 latent representations and decoded to 256×256 images; special prompt tokens allow control over the visual style.

## 🔍 Prompt Conditioning

Each training image was paired with a specific dataset prompt:

- "SHIFTS FLAIR MRI"
- "VH FLAIR MRI"
- "WMH2017 FLAIR MRI"

These prompts were added as new tokens in the tokenizer and trained jointly with the model,
enabling conditional generation aligned with each dataset's distribution.

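A minimal sketch of how such prompts could be registered, assuming each full prompt string is added as a single token through the `add_tokens` and `resize_token_embeddings` methods that Hugging Face tokenizers and models provide. The helper name `register_dataset_prompts` is hypothetical, and the original training script (not shown here) may differ.

```python
# Hypothetical helper; the original training code is not part of this README.
DATASET_PROMPTS = ["SHIFTS FLAIR MRI", "VH FLAIR MRI", "WMH2017 FLAIR MRI"]


def register_dataset_prompts(tokenizer, text_encoder, prompts=DATASET_PROMPTS):
    """Add each prompt as a new token and grow the embedding table to match.

    Assumes `tokenizer` exposes `add_tokens` and `__len__`, and `text_encoder`
    exposes `resize_token_embeddings`, as Hugging Face tokenizers and models do.
    Returns the number of tokens that were actually added.
    """
    num_added = tokenizer.add_tokens(list(prompts))
    if num_added > 0:
        # The new embedding rows start untrained and are learned during fine-tuning.
        text_encoder.resize_token_embeddings(len(tokenizer))
    return num_added
```

With the base model, the arguments would presumably be a `CLIPTokenizer` and a `CLIPTextModel` loaded from the `tokenizer` and `text_encoder` subfolders of `CompVis/stable-diffusion-v1-4`.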
## 🧠 Training Details

- Base model: [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4)
- Architecture: Latent Diffusion (U-Net + ResNet + Attention)
- Latent resolution: 32x32 (decoded to 256x256)
- Channels: 4
- Datasets: SHIFTS, VH, WMH2017 (FLAIR MRI)
- Epochs: 50
- Batch size: 8
- Gradient accumulation: 4
- Optimizer: AdamW
  - LR: 1.0e-4
  - Betas: (0.95, 0.999)
  - Weight decay: 1.0e-6
  - Epsilon: 1.0e-8
- LR Scheduler: Cosine decay with 500 warm-up steps
- Noise Scheduler: DDPM
  - Timesteps: 1000
  - Beta schedule: linear (β_start=0.0001, β_end=0.02)
- Gradient Clipping: Max norm 1.0
- Mixed Precision: Disabled
- Hardware: Single NVIDIA A30 GPU (4 dataloader workers)

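To make the optimizer settings above concrete, here is a small sketch of the learning-rate curve and the effective batch size. It assumes the common linear warm-up followed by a half-cosine decay to zero; the total step count `TOTAL_STEPS` is hypothetical, since the README does not state it.

```python
import math

BASE_LR = 1.0e-4
WARMUP_STEPS = 500
BATCH_SIZE = 8
GRAD_ACCUM = 4

# Samples contributing to each optimizer step on the single GPU.
EFFECTIVE_BATCH = BATCH_SIZE * GRAD_ACCUM  # 32


def lr_at(step, total_steps, base_lr=BASE_LR, warmup=WARMUP_STEPS):
    """Linear warm-up followed by half-cosine decay to zero (assumed shape)."""
    if step < warmup:
        return base_lr * step / max(1, warmup)
    progress = (step - warmup) / max(1, total_steps - warmup)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


TOTAL_STEPS = 10_000  # hypothetical; not given in the README
for s in (0, 250, 500, 5_250, 10_000):
    print(s, f"{lr_at(s, TOTAL_STEPS):.2e}")
```

The rate climbs linearly to 1.0e-4 over the first 500 steps and then decays smoothly, which is the usual motivation for pairing warm-up with AdamW when fine-tuning a pretrained model.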
## ✍️ Fine-Tuning Strategy

The text encoder, U-Net, and special prompt embeddings were trained jointly.
Images were encoded into a 32×32 latent space with the VAE, and the diffusion model was trained on those latents.

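As a rough illustration of what training in latent space involves, the sketch below builds the linear DDPM beta schedule listed under Training Details and applies the standard forward-noising formula to a latent. It uses plain Python lists instead of tensors, and it assumes the usual noise-prediction setup of Stable Diffusion v1-4; the function name `add_noise` is chosen here for illustration.

```python
import math

T = 1000
BETA_START, BETA_END = 1.0e-4, 0.02

# Linearly spaced betas from BETA_START (t = 0) to BETA_END (t = T - 1).
betas = [BETA_START + (BETA_END - BETA_START) * t / (T - 1) for t in range(T)]

# Cumulative product of (1 - beta): share of the clean signal left at step t.
alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_cumprod.append(running)


def add_noise(z0, eps, t):
    """Forward process on flat lists: z_t = sqrt(a_t) * z0 + sqrt(1 - a_t) * eps."""
    a = alphas_cumprod[t]
    return [math.sqrt(a) * z + math.sqrt(1.0 - a) * e for z, e in zip(z0, eps)]


# The Stable Diffusion VAE downsamples by 8 per side: 256x256 -> 4 x 32 x 32.
LATENT_SHAPE = (4, 256 // 8, 256 // 8)
```

During training, the U-Net would receive the noised latent, the timestep, and the prompt embedding, and would be optimized to recover `eps`.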
## 🧪 Inference (Guided Sampling)

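The sampling code in this section combines an unconditional and a prompt-conditioned noise prediction at every step (classifier-free guidance). Written out, with $z_t$ the current latent, $c$ the prompt embedding, $\varnothing$ the empty-prompt embedding, and $s$ the `guidance_scale` (notation introduced here for illustration):

$$
\hat{\epsilon} = \epsilon_\theta(z_t, t, \varnothing) + s \, \bigl( \epsilon_\theta(z_t, t, c) - \epsilon_\theta(z_t, t, \varnothing) \bigr)
$$

With $s = 1$ this reduces to the conditional prediction; larger values push samples further toward the prompt.
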
```python
from diffusers import StableDiffusionPipeline
import torch
from torchvision.utils import save_image

pipe = StableDiffusionPipeline.from_pretrained("benetraco/latent_finetuning", torch_dtype=torch.float32).to("cuda")
pipe.scheduler.set_timesteps(999)

def get_embeddings(prompt):
    # Tokenize to the fixed 77-token context and return the per-token hidden states.
    tokens = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=77).to("cuda")
    with torch.no_grad():
        return pipe.text_encoder(**tokens).last_hidden_state

def sample(prompt, guidance_scale=2.0, seed=42):
    torch.manual_seed(seed)
    # Start from Gaussian noise in the 4x32x32 latent space.
    latent = torch.randn(1, 4, 32, 32).to("cuda") * pipe.scheduler.init_noise_sigma
    text_emb = get_embeddings(prompt)
    uncond_emb = get_embeddings("")

    for t in pipe.scheduler.timesteps:
        latent_in = pipe.scheduler.scale_model_input(latent, t)
        with torch.no_grad():
            noise_uncond = pipe.unet(latent_in, t, encoder_hidden_states=uncond_emb).sample
            noise_text = pipe.unet(latent_in, t, encoder_hidden_states=text_emb).sample
            # Classifier-free guidance: move the prediction toward the prompt.
            noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)
            latent = pipe.scheduler.step(noise, t, latent).prev_sample

    # Undo the latent scaling before decoding with the VAE.
    latent /= pipe.vae.config.scaling_factor
    with torch.no_grad():
        decoded = pipe.vae.decode(latent).sample
    # Map the decoder output from [-1, 1] to [0, 1] for saving.
    image = (decoded + 1.0) / 2.0
    image = image.clamp(0, 1)
    save_image(image, f"{prompt.replace(' ', '_')}_g{guidance_scale}.png")

sample("SHIFTS FLAIR MRI", guidance_scale=5.0)
```