MaxMilan1 committed on
Commit a1f69bb · 1 Parent(s): 883d514
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import os
 
 from util.instantmesh import generate_mvs, make3d, preprocess, check_input_image
+from util.text_img import generate_image, check_prompt
 
 _CITE_ = r"""
 ```bibtex
@@ -16,6 +17,33 @@ _CITE_ = r"""
 
 
 with gr.Blocks() as demo:
+    with gr.Tab("Text to Image Generator"):
+        with gr.Row():
+            with gr.Column():
+                prompt = gr.Textbox(label="Enter a description of a shoe")
+                negative_prompt = gr.Textbox(label="Negative Prompt", value="low quality, bad quality, sketches, legs")
+                scale = gr.Slider(label="Control Image Scale", minimum=0.1, maximum=1.0, step=0.1, value=0.5)
+            with gr.Column():
+                control_image = gr.Image(label="Enter an image of a shoe that you want to use as a reference", type='numpy')
+                # neg_prompt = gr.Textbox(label="Enter a negative prompt", value="low quality, watermark, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft, closed eyes, text, logo")
+        with gr.Row():
+            with gr.Column():
+                gr.Examples(
+                    examples=[
+                        os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
+                    ],
+                    inputs=[control_image],
+                    label="Examples",
+                    cache_examples=False,
+                )
+            with gr.Column():
+                button_gen = gr.Button("Generate Image")
+        with gr.Row():
+            with gr.Column():
+                image_nobg = gr.Image(label="Generated Image", show_download_button=True, show_label=False)
+
+        button_gen.click(check_prompt, inputs=[prompt]).success(generate_image, inputs=[prompt, negative_prompt, control_image, scale], outputs=[image_nobg])
+
     with gr.Row(variant="panel"):
         with gr.Column():
             with gr.Row():
@@ -62,7 +90,6 @@ with gr.Blocks() as demo:
                     inputs=[input_image],
                     label="Examples",
                     cache_examples=False,
-                    examples_per_page=12
                 )
 
         with gr.Column():
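
The new tab imports check_prompt and generate_image from util/text_img.py, which is not part of this diff. Below is a minimal sketch of the contract the wiring above assumes — only the two names and their call signatures are inferred from the import and the .click(...).success(...) chain; the bodies are purely illustrative placeholders:

```python
# Hypothetical sketch of util/text_img.py, inferred from the app.py wiring
# above; only the two names and their signatures come from the diff.
import gradio as gr
import numpy as np

def check_prompt(prompt: str) -> None:
    # Raising gr.Error cancels the event chain, so the .success() handler
    # (generate_image) only runs for a non-empty prompt.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a description of the shoe first.")

def generate_image(prompt: str, negative_prompt: str,
                   control_image: np.ndarray, scale: float) -> np.ndarray:
    # Placeholder body: the real module presumably runs an image-conditioned
    # text-to-image pipeline, with `scale` weighting the reference image.
    return control_image
```

Chaining `.click(fn).success(fn2)` is the standard Gradio pattern for gating one handler on another completing without error.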
diff-gaussian-rasterization/.gitignore ADDED
@@ -0,0 +1,3 @@
+build/
+diff_gaussian_rasterization.egg-info/
+dist/
diff-gaussian-rasterization/.gitmodules ADDED
@@ -0,0 +1,3 @@
+[submodule "third_party/glm"]
+	path = third_party/glm
+	url = https://github.com/g-truc/glm.git
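
Since glm comes in as a submodule, a fresh checkout presumably needs `git clone --recursive` (or `git submodule update --init third_party/glm` after the fact) before the rasterizer will build.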
diff-gaussian-rasterization/CMakeLists.txt ADDED
@@ -0,0 +1,36 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact [email protected]
+#
+
+cmake_minimum_required(VERSION 3.20)
+
+project(DiffRast LANGUAGES CUDA CXX)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_EXTENSIONS OFF)
+set(CMAKE_CUDA_STANDARD 17)
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+
+add_library(CudaRasterizer
+	cuda_rasterizer/backward.h
+	cuda_rasterizer/backward.cu
+	cuda_rasterizer/forward.h
+	cuda_rasterizer/forward.cu
+	cuda_rasterizer/auxiliary.h
+	cuda_rasterizer/rasterizer_impl.cu
+	cuda_rasterizer/rasterizer_impl.h
+	cuda_rasterizer/rasterizer.h
+)
+
+set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86")
+
+target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer)
+target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
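
For context: `CUDA_ARCHITECTURES "70;75;86"` targets Volta, Turing, and Ampere GPUs; running on other architectures (e.g. compute capability 89 for Ada Lovelace) would presumably require extending this list.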
diff-gaussian-rasterization/LICENSE.md ADDED
@@ -0,0 +1,83 @@
+Gaussian-Splatting License
+===========================
+
+**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
+The *Software* is in the process of being registered with the Agence pour la Protection des
+Programmes (APP).
+
+The *Software* is still being developed by the *Licensor*.
+
+*Licensor*'s goal is to allow the research community to use, test and evaluate
+the *Software*.
+
+## 1. Definitions
+
+*Licensee* means any person or entity that uses the *Software* and distributes
+its *Work*.
+
+*Licensor* means the owners of the *Software*, i.e Inria and MPII
+
+*Software* means the original work of authorship made available under this
+License ie gaussian-splatting.
+
+*Work* means the *Software* and any additions to or derivative works of the
+*Software* that are made available under this License.
+
+
+## 2. Purpose
+This license is intended to define the rights granted to the *Licensee* by
+Licensors under the *Software*.
+
+## 3. Rights granted
+
+For the above reasons Licensors have decided to distribute the *Software*.
+Licensors grant non-exclusive rights to use the *Software* for research purposes
+to research users (both academic and industrial), free of charge, without right
+to sublicense.. The *Software* may be used "non-commercially", i.e., for research
+and/or evaluation purposes only.
+
+Subject to the terms and conditions of this License, you are granted a
+non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
+publicly display, publicly perform and distribute its *Work* and any resulting
+derivative works in any form.
+
+## 4. Limitations
+
+**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
+so under this License, (b) you include a complete copy of this License with
+your distribution, and (c) you retain without modification any copyright,
+patent, trademark, or attribution notices that are present in the *Work*.
+
+**4.2 Derivative Works.** You may specify that additional or different terms apply
+to the use, reproduction, and distribution of your derivative works of the *Work*
+("Your Terms") only if (a) Your Terms provide that the use limitation in
+Section 2 applies to your derivative works, and (b) you identify the specific
+derivative works that are subject to Your Terms. Notwithstanding Your Terms,
+this License (including the redistribution requirements in Section 3.1) will
+continue to apply to the *Work* itself.
+
+**4.3** Any other use without of prior consent of Licensors is prohibited. Research
+users explicitly acknowledge having received from Licensors all information
+allowing to appreciate the adequacy between of the *Software* and their needs and
+to undertake all necessary precautions for its execution and use.
+
+**4.4** The *Software* is provided both as a compiled library file and as source
+code. In case of using the *Software* for a publication or other results obtained
+through the use of the *Software*, users are strongly encouraged to cite the
+corresponding publications as explained in the documentation of the *Software*.
+
+## 5. Disclaimer
+
+THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
+WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
+UNAUTHORIZED USE: [email protected] . ANY SUCH ACTION WILL
+CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
+OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
+USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
+ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
+AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
+GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
+HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
+IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
diff-gaussian-rasterization/README.md ADDED
@@ -0,0 +1,19 @@
+# Differential Gaussian Rasterization
+
+Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-Time Rendering of Radiance Fields". If you can make use of it in your own research, please be so kind to cite us.
+
+<section class="section" id="BibTeX">
+  <div class="container is-max-desktop content">
+    <h2 class="title">BibTeX</h2>
+    <pre><code>@Article{kerbl3Dgaussians,
+      author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
+      title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
+      journal      = {ACM Transactions on Graphics},
+      number       = {4},
+      volume       = {42},
+      month        = {July},
+      year         = {2023},
+      url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
+}</code></pre>
+  </div>
+</section>
diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h ADDED
@@ -0,0 +1,175 @@
+/*
+ * Copyright (C) 2023, Inria
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
+ * All rights reserved.
+ *
+ * This software is free for non-commercial, research and evaluation use
+ * under the terms of the LICENSE.md file.
+ *
+ * For inquiries contact [email protected]
+ */
+
+#ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
+#define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
+
+#include "config.h"
+#include "stdio.h"
+
+#define BLOCK_SIZE (BLOCK_X * BLOCK_Y)
+#define NUM_WARPS (BLOCK_SIZE/32)
+
+// Spherical harmonics coefficients
+__device__ const float SH_C0 = 0.28209479177387814f;
+__device__ const float SH_C1 = 0.4886025119029199f;
+__device__ const float SH_C2[] = {
+	1.0925484305920792f,
+	-1.0925484305920792f,
+	0.31539156525252005f,
+	-1.0925484305920792f,
+	0.5462742152960396f
+};
+__device__ const float SH_C3[] = {
+	-0.5900435899266435f,
+	2.890611442640554f,
+	-0.4570457994644658f,
+	0.3731763325901154f,
+	-0.4570457994644658f,
+	1.445305721320277f,
+	-0.5900435899266435f
+};
+
+__forceinline__ __device__ float ndc2Pix(float v, int S)
+{
+	return ((v + 1.0) * S - 1.0) * 0.5;
+}
+
+__forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid)
+{
+	rect_min = {
+		min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))),
+		min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y)))
+	};
+	rect_max = {
+		min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))),
+		min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y)))
+	};
+}
+
+__forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix)
+{
+	float3 transformed = {
+		matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
+		matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
+		matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
+	};
+	return transformed;
+}
+
+__forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix)
+{
+	float4 transformed = {
+		matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
+		matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
+		matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
+		matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15]
+	};
+	return transformed;
+}
+
+__forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix)
+{
+	float3 transformed = {
+		matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z,
+		matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z,
+		matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z,
+	};
+	return transformed;
+}
+
+__forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix)
+{
+	float3 transformed = {
+		matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z,
+		matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z,
+		matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z,
+	};
+	return transformed;
+}
+
+__forceinline__ __device__ float dnormvdz(float3 v, float3 dv)
+{
+	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
+	float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
+	float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
+	return dnormvdz;
+}
+
+__forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv)
+{
+	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
+	float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
+
+	float3 dnormvdv;
+	dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32;
+	dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32;
+	dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
+	return dnormvdv;
+}
+
+__forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv)
+{
+	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
+	float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
+
+	float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w };
+	float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w;
+	float4 dnormvdv;
+	dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32;
+	dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32;
+	dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32;
+	dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32;
+	return dnormvdv;
+}
+
+__forceinline__ __device__ float sigmoid(float x)
+{
+	return 1.0f / (1.0f + expf(-x));
+}
+
+__forceinline__ __device__ bool in_frustum(int idx,
+	const float* orig_points,
+	const float* viewmatrix,
+	const float* projmatrix,
+	bool prefiltered,
+	float3& p_view)
+{
+	float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
+
+	// Bring points to screen space
+	float4 p_hom = transformPoint4x4(p_orig, projmatrix);
+	float p_w = 1.0f / (p_hom.w + 0.0000001f);
+	float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
+	p_view = transformPoint4x3(p_orig, viewmatrix);
+
+	if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3)))
+	{
+		if (prefiltered)
+		{
+			printf("Point is filtered although prefiltered is set. This shouldn't happen!");
+			__trap();
+		}
+		return false;
+	}
+	return true;
+}
+
+#define CHECK_CUDA(A, debug) \
+A; if(debug) { \
+auto ret = cudaDeviceSynchronize(); \
+if (ret != cudaSuccess) { \
+std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \
+throw std::runtime_error(cudaGetErrorString(ret)); \
+} \
+}
+
+#endif
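
For reference, the dnormvdz/dnormvdv helpers above implement the gradient of vector normalization. For n(v) = v/‖v‖ the identity is:

```latex
dn \;=\; \frac{(I - n\,n^{\top})\,dv}{\lVert v \rVert}
   \;=\; \frac{\lVert v \rVert^{2}\,dv \;-\; v\,(v \cdot dv)}{\lVert v \rVert^{3}},
\qquad\text{e.g.}\quad
dn_{z} \;=\; \bigl(-v_{x} v_{z}\,dv_{x} - v_{y} v_{z}\,dv_{y} + (\lVert v \rVert^{2} - v_{z}^{2})\,dv_{z}\bigr)\,\lVert v \rVert^{-3},
```

which is exactly the expression in dnormvdz, with `invsum32` playing the role of ‖v‖⁻³.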
diff-gaussian-rasterization/cuda_rasterizer/backward.cu ADDED
@@ -0,0 +1,657 @@
+/*
+ * Copyright (C) 2023, Inria
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
+ * All rights reserved.
+ *
+ * This software is free for non-commercial, research and evaluation use
+ * under the terms of the LICENSE.md file.
+ *
+ * For inquiries contact [email protected]
+ */
+
+#include "backward.h"
+#include "auxiliary.h"
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+namespace cg = cooperative_groups;
+
+// Backward pass for conversion of spherical harmonics to RGB for
+// each Gaussian.
+__device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs)
+{
+	// Compute intermediate values, as it is done during forward
+	glm::vec3 pos = means[idx];
+	glm::vec3 dir_orig = pos - campos;
+	glm::vec3 dir = dir_orig / glm::length(dir_orig);
+
+	glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
+
+	// Use PyTorch rule for clamping: if clamping was applied,
+	// gradient becomes 0.
+	glm::vec3 dL_dRGB = dL_dcolor[idx];
+	dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1;
+	dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1;
+	dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1;
+
+	glm::vec3 dRGBdx(0, 0, 0);
+	glm::vec3 dRGBdy(0, 0, 0);
+	glm::vec3 dRGBdz(0, 0, 0);
+	float x = dir.x;
+	float y = dir.y;
+	float z = dir.z;
+
+	// Target location for this Gaussian to write SH gradients to
+	glm::vec3* dL_dsh = dL_dshs + idx * max_coeffs;
+
+	// No tricks here, just high school-level calculus.
+	float dRGBdsh0 = SH_C0;
+	dL_dsh[0] = dRGBdsh0 * dL_dRGB;
+	if (deg > 0)
+	{
+		float dRGBdsh1 = -SH_C1 * y;
+		float dRGBdsh2 = SH_C1 * z;
+		float dRGBdsh3 = -SH_C1 * x;
+		dL_dsh[1] = dRGBdsh1 * dL_dRGB;
+		dL_dsh[2] = dRGBdsh2 * dL_dRGB;
+		dL_dsh[3] = dRGBdsh3 * dL_dRGB;
+
+		dRGBdx = -SH_C1 * sh[3];
+		dRGBdy = -SH_C1 * sh[1];
+		dRGBdz = SH_C1 * sh[2];
+
+		if (deg > 1)
+		{
+			float xx = x * x, yy = y * y, zz = z * z;
+			float xy = x * y, yz = y * z, xz = x * z;
+
+			float dRGBdsh4 = SH_C2[0] * xy;
+			float dRGBdsh5 = SH_C2[1] * yz;
+			float dRGBdsh6 = SH_C2[2] * (2.f * zz - xx - yy);
+			float dRGBdsh7 = SH_C2[3] * xz;
+			float dRGBdsh8 = SH_C2[4] * (xx - yy);
+			dL_dsh[4] = dRGBdsh4 * dL_dRGB;
+			dL_dsh[5] = dRGBdsh5 * dL_dRGB;
+			dL_dsh[6] = dRGBdsh6 * dL_dRGB;
+			dL_dsh[7] = dRGBdsh7 * dL_dRGB;
+			dL_dsh[8] = dRGBdsh8 * dL_dRGB;
+
+			dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8];
+			dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] + SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8];
+			dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] + SH_C2[3] * x * sh[7];
+
+			if (deg > 2)
+			{
+				float dRGBdsh9 = SH_C3[0] * y * (3.f * xx - yy);
+				float dRGBdsh10 = SH_C3[1] * xy * z;
+				float dRGBdsh11 = SH_C3[2] * y * (4.f * zz - xx - yy);
+				float dRGBdsh12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy);
+				float dRGBdsh13 = SH_C3[4] * x * (4.f * zz - xx - yy);
+				float dRGBdsh14 = SH_C3[5] * z * (xx - yy);
+				float dRGBdsh15 = SH_C3[6] * x * (xx - 3.f * yy);
+				dL_dsh[9] = dRGBdsh9 * dL_dRGB;
+				dL_dsh[10] = dRGBdsh10 * dL_dRGB;
+				dL_dsh[11] = dRGBdsh11 * dL_dRGB;
+				dL_dsh[12] = dRGBdsh12 * dL_dRGB;
+				dL_dsh[13] = dRGBdsh13 * dL_dRGB;
+				dL_dsh[14] = dRGBdsh14 * dL_dRGB;
+				dL_dsh[15] = dRGBdsh15 * dL_dRGB;
+
+				dRGBdx += (
+					SH_C3[0] * sh[9] * 3.f * 2.f * xy +
+					SH_C3[1] * sh[10] * yz +
+					SH_C3[2] * sh[11] * -2.f * xy +
+					SH_C3[3] * sh[12] * -3.f * 2.f * xz +
+					SH_C3[4] * sh[13] * (-3.f * xx + 4.f * zz - yy) +
+					SH_C3[5] * sh[14] * 2.f * xz +
+					SH_C3[6] * sh[15] * 3.f * (xx - yy));
+
+				dRGBdy += (
+					SH_C3[0] * sh[9] * 3.f * (xx - yy) +
+					SH_C3[1] * sh[10] * xz +
+					SH_C3[2] * sh[11] * (-3.f * yy + 4.f * zz - xx) +
+					SH_C3[3] * sh[12] * -3.f * 2.f * yz +
+					SH_C3[4] * sh[13] * -2.f * xy +
+					SH_C3[5] * sh[14] * -2.f * yz +
+					SH_C3[6] * sh[15] * -3.f * 2.f * xy);
+
+				dRGBdz += (
+					SH_C3[1] * sh[10] * xy +
+					SH_C3[2] * sh[11] * 4.f * 2.f * yz +
+					SH_C3[3] * sh[12] * 3.f * (2.f * zz - xx - yy) +
+					SH_C3[4] * sh[13] * 4.f * 2.f * xz +
+					SH_C3[5] * sh[14] * (xx - yy));
+			}
+		}
+	}
+
+	// The view direction is an input to the computation. View direction
+	// is influenced by the Gaussian's mean, so SHs gradients
+	// must propagate back into 3D position.
+	glm::vec3 dL_ddir(glm::dot(dRGBdx, dL_dRGB), glm::dot(dRGBdy, dL_dRGB), glm::dot(dRGBdz, dL_dRGB));
+
+	// Account for normalization of direction
+	float3 dL_dmean = dnormvdv(float3{ dir_orig.x, dir_orig.y, dir_orig.z }, float3{ dL_ddir.x, dL_ddir.y, dL_ddir.z });
+
+	// Gradients of loss w.r.t. Gaussian means, but only the portion
+	// that is caused because the mean affects the view-dependent color.
+	// Additional mean gradient is accumulated in below methods.
+	dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z);
+}
+
+// Backward version of INVERSE 2D covariance matrix computation
+// (due to length launched as separate kernel before other
+// backward steps contained in preprocess)
+__global__ void computeCov2DCUDA(int P,
+	const float3* means,
+	const int* radii,
+	const float* cov3Ds,
+	const float h_x, float h_y,
+	const float tan_fovx, float tan_fovy,
+	const float* view_matrix,
+	const float* dL_dconics,
+	float3* dL_dmeans,
+	float* dL_dcov)
+{
+	auto idx = cg::this_grid().thread_rank();
+	if (idx >= P || !(radii[idx] > 0))
+		return;
+
+	// Reading location of 3D covariance for this Gaussian
+	const float* cov3D = cov3Ds + 6 * idx;
+
+	// Fetch gradients, recompute 2D covariance and relevant
+	// intermediate forward results needed in the backward.
+	float3 mean = means[idx];
+	float3 dL_dconic = { dL_dconics[4 * idx], dL_dconics[4 * idx + 1], dL_dconics[4 * idx + 3] };
+	float3 t = transformPoint4x3(mean, view_matrix);
+
+	const float limx = 1.3f * tan_fovx;
+	const float limy = 1.3f * tan_fovy;
+	const float txtz = t.x / t.z;
+	const float tytz = t.y / t.z;
+	t.x = min(limx, max(-limx, txtz)) * t.z;
+	t.y = min(limy, max(-limy, tytz)) * t.z;
+
+	const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1;
+	const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1;
+
+	glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z),
+		0.0f, h_y / t.z, -(h_y * t.y) / (t.z * t.z),
+		0, 0, 0);
+
+	glm::mat3 W = glm::mat3(
+		view_matrix[0], view_matrix[4], view_matrix[8],
+		view_matrix[1], view_matrix[5], view_matrix[9],
+		view_matrix[2], view_matrix[6], view_matrix[10]);
+
+	glm::mat3 Vrk = glm::mat3(
+		cov3D[0], cov3D[1], cov3D[2],
+		cov3D[1], cov3D[3], cov3D[4],
+		cov3D[2], cov3D[4], cov3D[5]);
+
+	glm::mat3 T = W * J;
+
+	glm::mat3 cov2D = glm::transpose(T) * glm::transpose(Vrk) * T;
+
+	// Use helper variables for 2D covariance entries. More compact.
+	float a = cov2D[0][0] += 0.3f;
+	float b = cov2D[0][1];
+	float c = cov2D[1][1] += 0.3f;
+
+	float denom = a * c - b * b;
+	float dL_da = 0, dL_db = 0, dL_dc = 0;
+	float denom2inv = 1.0f / ((denom * denom) + 0.0000001f);
+
+	if (denom2inv != 0)
+	{
+		// Gradients of loss w.r.t. entries of 2D covariance matrix,
+		// given gradients of loss w.r.t. conic matrix (inverse covariance matrix).
+		// e.g., dL / da = dL / d_conic_a * d_conic_a / d_a
+		dL_da = denom2inv * (-c * c * dL_dconic.x + 2 * b * c * dL_dconic.y + (denom - a * c) * dL_dconic.z);
+		dL_dc = denom2inv * (-a * a * dL_dconic.z + 2 * a * b * dL_dconic.y + (denom - a * c) * dL_dconic.x);
+		dL_db = denom2inv * 2 * (b * c * dL_dconic.x - (denom + 2 * b * b) * dL_dconic.y + a * b * dL_dconic.z);
+
+		// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
+		// given gradients w.r.t. 2D covariance matrix (diagonal).
+		// cov2D = transpose(T) * transpose(Vrk) * T;
+		dL_dcov[6 * idx + 0] = (T[0][0] * T[0][0] * dL_da + T[0][0] * T[1][0] * dL_db + T[1][0] * T[1][0] * dL_dc);
+		dL_dcov[6 * idx + 3] = (T[0][1] * T[0][1] * dL_da + T[0][1] * T[1][1] * dL_db + T[1][1] * T[1][1] * dL_dc);
+		dL_dcov[6 * idx + 5] = (T[0][2] * T[0][2] * dL_da + T[0][2] * T[1][2] * dL_db + T[1][2] * T[1][2] * dL_dc);
+
+		// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
+		// given gradients w.r.t. 2D covariance matrix (off-diagonal).
+		// Off-diagonal elements appear twice --> double the gradient.
+		// cov2D = transpose(T) * transpose(Vrk) * T;
+		dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_da + (T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][1] * dL_dc;
+		dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_da + (T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][2] * dL_dc;
+		dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_da + (T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_db + 2 * T[1][1] * T[1][2] * dL_dc;
+	}
+	else
+	{
+		for (int i = 0; i < 6; i++)
+			dL_dcov[6 * idx + i] = 0;
+	}
+
+	// Gradients of loss w.r.t. upper 2x3 portion of intermediate matrix T
+	// cov2D = transpose(T) * transpose(Vrk) * T;
+	float dL_dT00 = 2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_da +
+		(T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_db;
+	float dL_dT01 = 2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_da +
+		(T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_db;
+	float dL_dT02 = 2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_da +
+		(T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_db;
+	float dL_dT10 = 2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_dc +
+		(T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_db;
+	float dL_dT11 = 2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_dc +
+		(T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_db;
+	float dL_dT12 = 2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_dc +
+		(T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_db;
+
+	// Gradients of loss w.r.t. upper 3x2 non-zero entries of Jacobian matrix
+	// T = W * J
+	float dL_dJ00 = W[0][0] * dL_dT00 + W[0][1] * dL_dT01 + W[0][2] * dL_dT02;
+	float dL_dJ02 = W[2][0] * dL_dT00 + W[2][1] * dL_dT01 + W[2][2] * dL_dT02;
+	float dL_dJ11 = W[1][0] * dL_dT10 + W[1][1] * dL_dT11 + W[1][2] * dL_dT12;
+	float dL_dJ12 = W[2][0] * dL_dT10 + W[2][1] * dL_dT11 + W[2][2] * dL_dT12;
+
+	float tz = 1.f / t.z;
+	float tz2 = tz * tz;
+	float tz3 = tz2 * tz;
+
+	// Gradients of loss w.r.t. transformed Gaussian mean t
+	float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02;
+	float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12;
+	float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 + (2 * h_x * t.x) * tz3 * dL_dJ02 + (2 * h_y * t.y) * tz3 * dL_dJ12;
+
+	// Account for transformation of mean to t
+	// t = transformPoint4x3(mean, view_matrix);
+	float3 dL_dmean = transformVec4x3Transpose({ dL_dtx, dL_dty, dL_dtz }, view_matrix);
+
+	// Gradients of loss w.r.t. Gaussian means, but only the portion
+	// that is caused because the mean affects the covariance matrix.
+	// Additional mean gradient is accumulated in BACKWARD::preprocess.
+	dL_dmeans[idx] = dL_dmean;
+}
+
+// Backward pass for the conversion of scale and rotation to a
+// 3D covariance matrix for each Gaussian.
+__device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const glm::vec4 rot, const float* dL_dcov3Ds, glm::vec3* dL_dscales, glm::vec4* dL_drots)
+{
+	// Recompute (intermediate) results for the 3D covariance computation.
+	glm::vec4 q = rot;// / glm::length(rot);
+	float r = q.x;
+	float x = q.y;
+	float y = q.z;
+	float z = q.w;
+
+	glm::mat3 R = glm::mat3(
+		1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y),
+		2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
+		2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y)
+	);
+
+	glm::mat3 S = glm::mat3(1.0f);
+
+	glm::vec3 s = mod * scale;
+	S[0][0] = s.x;
+	S[1][1] = s.y;
+	S[2][2] = s.z;
+
+	glm::mat3 M = S * R;
+
+	const float* dL_dcov3D = dL_dcov3Ds + 6 * idx;
+
+	glm::vec3 dunc(dL_dcov3D[0], dL_dcov3D[3], dL_dcov3D[5]);
+	glm::vec3 ounc = 0.5f * glm::vec3(dL_dcov3D[1], dL_dcov3D[2], dL_dcov3D[4]);
+
+	// Convert per-element covariance loss gradients to matrix form
+	glm::mat3 dL_dSigma = glm::mat3(
+		dL_dcov3D[0], 0.5f * dL_dcov3D[1], 0.5f * dL_dcov3D[2],
+		0.5f * dL_dcov3D[1], dL_dcov3D[3], 0.5f * dL_dcov3D[4],
+		0.5f * dL_dcov3D[2], 0.5f * dL_dcov3D[4], dL_dcov3D[5]
+	);
+
+	// Compute loss gradient w.r.t. matrix M
+	// dSigma_dM = 2 * M
+	glm::mat3 dL_dM = 2.0f * M * dL_dSigma;
+
+	glm::mat3 Rt = glm::transpose(R);
+	glm::mat3 dL_dMt = glm::transpose(dL_dM);
+
+	// Gradients of loss w.r.t. scale
+	glm::vec3* dL_dscale = dL_dscales + idx;
+	dL_dscale->x = glm::dot(Rt[0], dL_dMt[0]);
+	dL_dscale->y = glm::dot(Rt[1], dL_dMt[1]);
+	dL_dscale->z = glm::dot(Rt[2], dL_dMt[2]);
+
+	dL_dMt[0] *= s.x;
+	dL_dMt[1] *= s.y;
+	dL_dMt[2] *= s.z;
+
+	// Gradients of loss w.r.t. normalized quaternion
+	glm::vec4 dL_dq;
+	dL_dq.x = 2 * z * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * y * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * x * (dL_dMt[1][2] - dL_dMt[2][1]);
+	dL_dq.y = 2 * y * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * z * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * r * (dL_dMt[1][2] - dL_dMt[2][1]) - 4 * x * (dL_dMt[2][2] + dL_dMt[1][1]);
+	dL_dq.z = 2 * x * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * r * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * z * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * y * (dL_dMt[2][2] + dL_dMt[0][0]);
+	dL_dq.w = 2 * r * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * x * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * y * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * z * (dL_dMt[1][1] + dL_dMt[0][0]);
+
+	// Gradients of loss w.r.t. unnormalized quaternion
+	float4* dL_drot = (float4*)(dL_drots + idx);
+	*dL_drot = float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w };//dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w }, float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w });
+}
+
+// Backward pass of the preprocessing steps, except
+// for the covariance computation and inversion
+// (those are handled by a previous kernel call)
+template<int C>
+__global__ void preprocessCUDA(
+	int P, int D, int M,
+	const float3* means,
+	const int* radii,
+	const float* shs,
+	const bool* clamped,
+	const glm::vec3* scales,
+	const glm::vec4* rotations,
+	const float scale_modifier,
+	const float* proj,
+	const glm::vec3* campos,
+	const float3* dL_dmean2D,
+	glm::vec3* dL_dmeans,
+	float* dL_dcolor,
+	float* dL_dcov3D,
+	float* dL_dsh,
+	glm::vec3* dL_dscale,
+	glm::vec4* dL_drot)
+{
+	auto idx = cg::this_grid().thread_rank();
+	if (idx >= P || !(radii[idx] > 0))
+		return;
+
+	float3 m = means[idx];
+
+	// Taking care of gradients from the screenspace points
+	float4 m_hom = transformPoint4x4(m, proj);
+	float m_w = 1.0f / (m_hom.w + 0.0000001f);
+
+	// Compute loss gradient w.r.t. 3D means due to gradients of 2D means
+	// from rendering procedure
+	glm::vec3 dL_dmean;
+	float mul1 = (proj[0] * m.x + proj[4] * m.y + proj[8] * m.z + proj[12]) * m_w * m_w;
+	float mul2 = (proj[1] * m.x + proj[5] * m.y + proj[9] * m.z + proj[13]) * m_w * m_w;
+	dL_dmean.x = (proj[0] * m_w - proj[3] * mul1) * dL_dmean2D[idx].x + (proj[1] * m_w - proj[3] * mul2) * dL_dmean2D[idx].y;
+	dL_dmean.y = (proj[4] * m_w - proj[7] * mul1) * dL_dmean2D[idx].x + (proj[5] * m_w - proj[7] * mul2) * dL_dmean2D[idx].y;
+	dL_dmean.z = (proj[8] * m_w - proj[11] * mul1) * dL_dmean2D[idx].x + (proj[9] * m_w - proj[11] * mul2) * dL_dmean2D[idx].y;
+
+	// That's the second part of the mean gradient. Previous computation
+	// of cov2D and following SH conversion also affects it.
+	dL_dmeans[idx] += dL_dmean;
+
+	// Compute gradient updates due to computing colors from SHs
+	if (shs)
+		computeColorFromSH(idx, D, M, (glm::vec3*)means, *campos, shs, clamped, (glm::vec3*)dL_dcolor, (glm::vec3*)dL_dmeans, (glm::vec3*)dL_dsh);
+
+	// Compute gradient updates due to computing covariance from scale/rotation
+	if (scales)
+		computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D, dL_dscale, dL_drot);
+}
+
+// Backward version of the rendering procedure.
+template <uint32_t C>
+__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
+renderCUDA(
+	const uint2* __restrict__ ranges,
+	const uint32_t* __restrict__ point_list,
+	int W, int H,
+	const float* __restrict__ bg_color,
+	const float2* __restrict__ points_xy_image,
+	const float4* __restrict__ conic_opacity,
+	const float* __restrict__ colors,
+	const float* __restrict__ final_Ts,
+	const uint32_t* __restrict__ n_contrib,
+	const float* __restrict__ dL_dpixels,
+	float3* __restrict__ dL_dmean2D,
+	float4* __restrict__ dL_dconic2D,
+	float* __restrict__ dL_dopacity,
+	float* __restrict__ dL_dcolors)
+{
+	// We rasterize again. Compute necessary block info.
+	auto block = cg::this_thread_block();
+	const uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
+	const uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y };
+	const uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) };
+	const uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y };
+	const uint32_t pix_id = W * pix.y + pix.x;
+	const float2 pixf = { (float)pix.x, (float)pix.y };
+
+	const bool inside = pix.x < W && pix.y < H;
+	const uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
+
+	const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
+
+	bool done = !inside;
+	int toDo = range.y - range.x;
+
+	__shared__ int collected_id[BLOCK_SIZE];
+	__shared__ float2 collected_xy[BLOCK_SIZE];
+	__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
+	__shared__ float collected_colors[C * BLOCK_SIZE];
+
+	// In the forward, we stored the final value for T, the
+	// product of all (1 - alpha) factors.
+	const float T_final = inside ? final_Ts[pix_id] : 0;
+	float T = T_final;
+
+	// We start from the back. The ID of the last contributing
+	// Gaussian is known from each pixel from the forward.
+	uint32_t contributor = toDo;
+	const int last_contributor = inside ? n_contrib[pix_id] : 0;
+
+	float accum_rec[C] = { 0 };
+	float dL_dpixel[C];
+	if (inside)
+		for (int i = 0; i < C; i++)
+			dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
+
+	float last_alpha = 0;
+	float last_color[C] = { 0 };
+
+	// Gradient of pixel coordinate w.r.t. normalized
+	// screen-space viewport coordinates (-1 to 1)
+	const float ddelx_dx = 0.5 * W;
+	const float ddely_dy = 0.5 * H;
+
+	// Traverse all Gaussians
+	for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
+	{
+		// Load auxiliary data into shared memory, start in the BACK
+		// and load them in reverse order.
+		block.sync();
+		const int progress = i * BLOCK_SIZE + block.thread_rank();
+		if (range.x + progress < range.y)
+		{
+			const int coll_id = point_list[range.y - progress - 1];
+			collected_id[block.thread_rank()] = coll_id;
+			collected_xy[block.thread_rank()] = points_xy_image[coll_id];
+			collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
+			for (int i = 0; i < C; i++)
+				collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i];
+		}
+		block.sync();
+
+		// Iterate over Gaussians
+		for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
+		{
+			// Keep track of current Gaussian ID. Skip, if this one
+			// is behind the last contributor for this pixel.
+			contributor--;
+			if (contributor >= last_contributor)
+				continue;
+
+			// Compute blending values, as before.
+			const float2 xy = collected_xy[j];
+			const float2 d = { xy.x - pixf.x, xy.y - pixf.y };
+			const float4 con_o = collected_conic_opacity[j];
+			const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
+			if (power > 0.0f)
+				continue;
+
+			const float G = exp(power);
+			const float alpha = min(0.99f, con_o.w * G);
+			if (alpha < 1.0f / 255.0f)
+				continue;
+
+			T = T / (1.f - alpha);
+			const float dchannel_dcolor = alpha * T;
+
+			// Propagate gradients to per-Gaussian colors and keep
+			// gradients w.r.t. alpha (blending factor for a Gaussian/pixel
+			// pair).
+			float dL_dalpha = 0.0f;
+			const int global_id = collected_id[j];
+			for (int ch = 0; ch < C; ch++)
+			{
+				const float c = collected_colors[ch * BLOCK_SIZE + j];
+				// Update last color (to be used in the next iteration)
+				accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch];
+				last_color[ch] = c;
+
+				const float dL_dchannel = dL_dpixel[ch];
+				dL_dalpha += (c - accum_rec[ch]) * dL_dchannel;
+				// Update the gradients w.r.t. color of the Gaussian.
+				// Atomic, since this pixel is just one of potentially
+				// many that were affected by this Gaussian.
+				atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel);
+			}
+			dL_dalpha *= T;
+			// Update last alpha (to be used in the next iteration)
+			last_alpha = alpha;
+
+			// Account for fact that alpha also influences how much of
+			// the background color is added if nothing is left to blend
+			float bg_dot_dpixel = 0;
+			for (int i = 0; i < C; i++)
+				bg_dot_dpixel += bg_color[i] * dL_dpixel[i];
+			dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel;
+
+
+			// Helpful reusable temporary variables
+			const float dL_dG = con_o.w * dL_dalpha;
+			const float gdx = G * d.x;
+			const float gdy = G * d.y;
+			const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y;
+			const float dG_ddely = -gdy * con_o.z - gdx * con_o.y;
+
+			// Update gradients w.r.t. 2D mean position of the Gaussian
+			atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx);
+			atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy);
+
+			// Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric)
+			atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG);
+			atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG);
+			atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG);
+
+			// Update gradients w.r.t. opacity of the Gaussian
+			atomicAdd(&(dL_dopacity[global_id]), G * dL_dalpha);
+		}
+	}
+}
+
+void BACKWARD::preprocess(
+	int P, int D, int M,
+	const float3* means3D,
+	const int* radii,
+	const float* shs,
+	const bool* clamped,
+	const glm::vec3* scales,
+	const glm::vec4* rotations,
+	const float scale_modifier,
+	const float* cov3Ds,
+	const float* viewmatrix,
+	const float* projmatrix,
+	const float focal_x, float focal_y,
+	const float tan_fovx, float tan_fovy,
+	const glm::vec3* campos,
+	const float3* dL_dmean2D,
+	const float* dL_dconic,
+	glm::vec3* dL_dmean3D,
+	float* dL_dcolor,
+	float* dL_dcov3D,
+	float* dL_dsh,
+	glm::vec3* dL_dscale,
+	glm::vec4* dL_drot)
+{
+	// Propagate gradients for the path of 2D conic matrix computation.
+	// Somewhat long, thus it is its own kernel rather than being part of
+	// "preprocess". When done, loss gradient w.r.t. 3D means has been
+	// modified and gradient w.r.t. 3D covariance matrix has been computed.
+	computeCov2DCUDA << <(P + 255) / 256, 256 >> > (
+		P,
+		means3D,
+		radii,
+		cov3Ds,
+		focal_x,
+		focal_y,
+		tan_fovx,
+		tan_fovy,
+		viewmatrix,
+		dL_dconic,
+		(float3*)dL_dmean3D,
+		dL_dcov3D);
+
+	// Propagate gradients for remaining steps: finish 3D mean gradients,
+	// propagate color gradients to SH (if desired), propagate 3D covariance
+	// matrix gradients to scale and rotation.
+	preprocessCUDA<NUM_CHANNELS> << < (P + 255) / 256, 256 >> > (
+		P, D, M,
+		(float3*)means3D,
+		radii,
+		shs,
+		clamped,
+		(glm::vec3*)scales,
+		(glm::vec4*)rotations,
+		scale_modifier,
+		projmatrix,
+		campos,
+		(float3*)dL_dmean2D,
+		(glm::vec3*)dL_dmean3D,
+		dL_dcolor,
+		dL_dcov3D,
+		dL_dsh,
+		dL_dscale,
+		dL_drot);
+}
+
+void BACKWARD::render(
+	const dim3 grid, const dim3 block,
+	const uint2* ranges,
+	const uint32_t* point_list,
+	int W, int H,
+	const float* bg_color,
+	const float2* means2D,
+	const float4* conic_opacity,
+	const float* colors,
+	const float* final_Ts,
+	const uint32_t* n_contrib,
+	const float* dL_dpixels,
+	float3* dL_dmean2D,
+	float4* dL_dconic2D,
+	float* dL_dopacity,
+	float* dL_dcolors)
+{
+	renderCUDA<NUM_CHANNELS> << <grid, block >> >(
+		ranges,
+		point_list,
+		W, H,
+		bg_color,
+		means2D,
+		conic_opacity,
+		colors,
+		final_Ts,
+		n_contrib,
+		dL_dpixels,
+		dL_dmean2D,
+		dL_dconic2D,
+		dL_dopacity,
+		dL_dcolors
+	);
+}
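
computeCov2DCUDA above back-propagates through the 2×2 matrix inverse that produced the conic; its per-entry expressions for dL_da, dL_db, dL_dc are the expansion of the standard identity for the gradient through an inverse:

```latex
C \;=\; \Sigma'^{-1}
\quad\Longrightarrow\quad
\frac{\partial L}{\partial \Sigma'} \;=\; -\,C\,\frac{\partial L}{\partial C}\,C
\qquad\bigl(\Sigma' \text{ symmetric},\ \det \Sigma' = a c - b^{2}\bigr)
```

The off-diagonal terms pick up a factor of 2 because the symmetric entry b appears twice in Σ'.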
diff-gaussian-rasterization/cuda_rasterizer/backward.h ADDED
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2023, Inria
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
+ * All rights reserved.
+ *
+ * This software is free for non-commercial, research and evaluation use
+ * under the terms of the LICENSE.md file.
+ *
+ * For inquiries contact [email protected]
+ */
+
+#ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED
+#define CUDA_RASTERIZER_BACKWARD_H_INCLUDED
+
+#include <cuda.h>
+#include "cuda_runtime.h"
+#include "device_launch_parameters.h"
+#define GLM_FORCE_CUDA
+#include <glm/glm.hpp>
+
+namespace BACKWARD
+{
+	void render(
+		const dim3 grid, dim3 block,
+		const uint2* ranges,
+		const uint32_t* point_list,
+		int W, int H,
+		const float* bg_color,
+		const float2* means2D,
+		const float4* conic_opacity,
+		const float* colors,
+		const float* final_Ts,
+		const uint32_t* n_contrib,
+		const float* dL_dpixels,
+		float3* dL_dmean2D,
+		float4* dL_dconic2D,
+		float* dL_dopacity,
+		float* dL_dcolors);
+
+	void preprocess(
+		int P, int D, int M,
+		const float3* means,
+		const int* radii,
+		const float* shs,
+		const bool* clamped,
+		const glm::vec3* scales,
+		const glm::vec4* rotations,
+		const float scale_modifier,
+		const float* cov3Ds,
+		const float* view,
+		const float* proj,
+		const float focal_x, float focal_y,
+		const float tan_fovx, float tan_fovy,
+		const glm::vec3* campos,
+		const float3* dL_dmean2D,
+		const float* dL_dconics,
+		glm::vec3* dL_dmeans,
+		float* dL_dcolor,
+		float* dL_dcov3D,
+		float* dL_dsh,
+		glm::vec3* dL_dscale,
+		glm::vec4* dL_drot);
+}
+
+#endif
diff-gaussian-rasterization/cuda_rasterizer/config.h ADDED
@@ -0,0 +1,19 @@
+/*
+ * Copyright (C) 2023, Inria
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
+ * All rights reserved.
+ *
+ * This software is free for non-commercial, research and evaluation use
+ * under the terms of the LICENSE.md file.
+ *
+ * For inquiries contact [email protected]
+ */
+
+#ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED
+#define CUDA_RASTERIZER_CONFIG_H_INCLUDED
+
+#define NUM_CHANNELS 3 // Default 3, RGB
+#define BLOCK_X 16
+#define BLOCK_Y 16
+
+#endif
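
Combined with the macros in auxiliary.h, these constants fix the per-tile launch geometry:

```latex
\text{BLOCK\_SIZE} \;=\; \text{BLOCK\_X} \times \text{BLOCK\_Y} \;=\; 16 \times 16 \;=\; 256 \text{ threads},
\qquad
\text{NUM\_WARPS} \;=\; 256 / 32 \;=\; 8
```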
diff-gaussian-rasterization/cuda_rasterizer/forward.cu ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact [email protected]
10
+ */
11
+
12
+ #include "forward.h"
13
+ #include "auxiliary.h"
14
+ #include <cooperative_groups.h>
15
+ #include <cooperative_groups/reduce.h>
16
+ namespace cg = cooperative_groups;
17
+
18
+ // Forward method for converting the input spherical harmonics
19
+ // coefficients of each Gaussian to a simple RGB color.
20
+ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped)
21
+ {
22
+ // The implementation is loosely based on code for
23
+ // "Differentiable Point-Based Radiance Fields for
24
+ // Efficient View Synthesis" by Zhang et al. (2022)
25
+ glm::vec3 pos = means[idx];
26
+ glm::vec3 dir = pos - campos;
27
+ dir = dir / glm::length(dir);
28
+
29
+ glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
30
+ glm::vec3 result = SH_C0 * sh[0];
31
+
32
+ if (deg > 0)
33
+ {
34
+ float x = dir.x;
35
+ float y = dir.y;
36
+ float z = dir.z;
37
+ result = result - SH_C1 * y * sh[1] + SH_C1 * z * sh[2] - SH_C1 * x * sh[3];
38
+
39
+ if (deg > 1)
40
+ {
41
+ float xx = x * x, yy = y * y, zz = z * z;
42
+ float xy = x * y, yz = y * z, xz = x * z;
43
+ result = result +
44
+ SH_C2[0] * xy * sh[4] +
45
+ SH_C2[1] * yz * sh[5] +
46
+ SH_C2[2] * (2.0f * zz - xx - yy) * sh[6] +
47
+ SH_C2[3] * xz * sh[7] +
48
+ SH_C2[4] * (xx - yy) * sh[8];
49
+
50
+ if (deg > 2)
51
+ {
52
+ result = result +
53
+ SH_C3[0] * y * (3.0f * xx - yy) * sh[9] +
54
+ SH_C3[1] * xy * z * sh[10] +
55
+ SH_C3[2] * y * (4.0f * zz - xx - yy) * sh[11] +
56
+ SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh[12] +
57
+ SH_C3[4] * x * (4.0f * zz - xx - yy) * sh[13] +
58
+ SH_C3[5] * z * (xx - yy) * sh[14] +
59
+ SH_C3[6] * x * (xx - 3.0f * yy) * sh[15];
60
+ }
61
+ }
62
+ }
63
+ result += 0.5f;
64
+
65
+ // RGB colors are clamped to positive values. If values are
66
+ // clamped, we need to keep track of this for the backward pass.
67
+ clamped[3 * idx + 0] = (result.x < 0);
68
+ clamped[3 * idx + 1] = (result.y < 0);
69
+ clamped[3 * idx + 2] = (result.z < 0);
70
+ return glm::max(result, 0.0f);
71
+ }
72
+
73
+ // Forward version of 2D covariance matrix computation
74
+ __device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix)
75
+ {
76
+ // The following models the steps outlined by equations 29
77
+ // and 31 in "EWA Splatting" (Zwicker et al., 2002).
78
+ // Additionally considers aspect / scaling of viewport.
79
+ // Transposes used to account for row-/column-major conventions.
80
+ float3 t = transformPoint4x3(mean, viewmatrix);
81
+
82
+ const float limx = 1.3f * tan_fovx;
83
+ const float limy = 1.3f * tan_fovy;
84
+ const float txtz = t.x / t.z;
85
+ const float tytz = t.y / t.z;
86
+ t.x = min(limx, max(-limx, txtz)) * t.z;
87
+ t.y = min(limy, max(-limy, tytz)) * t.z;
88
+
89
+ glm::mat3 J = glm::mat3(
90
+ focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z),
91
+ 0.0f, focal_y / t.z, -(focal_y * t.y) / (t.z * t.z),
92
+ 0, 0, 0);
93
+
94
+ glm::mat3 W = glm::mat3(
95
+ viewmatrix[0], viewmatrix[4], viewmatrix[8],
96
+ viewmatrix[1], viewmatrix[5], viewmatrix[9],
97
+ viewmatrix[2], viewmatrix[6], viewmatrix[10]);
98
+
99
+ glm::mat3 T = W * J;
100
+
101
+ glm::mat3 Vrk = glm::mat3(
102
+ cov3D[0], cov3D[1], cov3D[2],
103
+ cov3D[1], cov3D[3], cov3D[4],
104
+ cov3D[2], cov3D[4], cov3D[5]);
105
+
106
+ glm::mat3 cov = glm::transpose(T) * glm::transpose(Vrk) * T;
107
+
108
+ // Apply low-pass filter: every Gaussian should be at least
109
+ // one pixel wide/high. Discard 3rd row and column.
110
+ cov[0][0] += 0.3f;
111
+ cov[1][1] += 0.3f;
112
+ return { float(cov[0][0]), float(cov[0][1]), float(cov[1][1]) };
113
+ }
114
+
115
+ // Forward method for converting scale and rotation properties of each
116
+ // Gaussian to a 3D covariance matrix in world space. Also takes care
117
+ // of quaternion normalization.
118
+ __device__ void computeCov3D(const glm::vec3 scale, float mod, const glm::vec4 rot, float* cov3D)
119
+ {
120
+ // Create scaling matrix
121
+ glm::mat3 S = glm::mat3(1.0f);
122
+ S[0][0] = mod * scale.x;
123
+ S[1][1] = mod * scale.y;
124
+ S[2][2] = mod * scale.z;
125
+
126
+ // Normalize quaternion to get valid rotation
127
+ glm::vec4 q = rot;// / glm::length(rot);
128
+ float r = q.x;
129
+ float x = q.y;
130
+ float y = q.z;
131
+ float z = q.w;
132
+
133
+ // Compute rotation matrix from quaternion
134
+ glm::mat3 R = glm::mat3(
135
+ 1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y),
136
+ 2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
137
+ 2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y)
138
+ );
139
+
140
+ glm::mat3 M = S * R;
141
+
142
+ // Compute 3D world covariance matrix Sigma
143
+ glm::mat3 Sigma = glm::transpose(M) * M;
144
+
145
+ // Covariance is symmetric, only store upper right
146
+ cov3D[0] = Sigma[0][0];
147
+ cov3D[1] = Sigma[0][1];
148
+ cov3D[2] = Sigma[0][2];
149
+ cov3D[3] = Sigma[1][1];
150
+ cov3D[4] = Sigma[1][2];
151
+ cov3D[5] = Sigma[2][2];
152
+ }
153
+
+ // Perform initial steps for each Gaussian prior to rasterization.
+ template<int C>
+ __global__ void preprocessCUDA(int P, int D, int M,
+     const float* orig_points,
+     const glm::vec3* scales,
+     const float scale_modifier,
+     const glm::vec4* rotations,
+     const float* opacities,
+     const float* shs,
+     bool* clamped,
+     const float* cov3D_precomp,
+     const float* colors_precomp,
+     const float* viewmatrix,
+     const float* projmatrix,
+     const glm::vec3* cam_pos,
+     const int W, int H,
+     const float tan_fovx, float tan_fovy,
+     const float focal_x, float focal_y,
+     int* radii,
+     float2* points_xy_image,
+     float* depths,
+     float* cov3Ds,
+     float* rgb,
+     float4* conic_opacity,
+     const dim3 grid,
+     uint32_t* tiles_touched,
+     bool prefiltered)
+ {
+     auto idx = cg::this_grid().thread_rank();
+     if (idx >= P)
+         return;
+
+     // Initialize radius and touched tiles to 0. If this isn't changed,
+     // this Gaussian will not be processed further.
+     radii[idx] = 0;
+     tiles_touched[idx] = 0;
+
+     // Perform near culling, quit if outside.
+     float3 p_view;
+     if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view))
+         return;
+
+     // Transform point by projecting
+     float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
+     float4 p_hom = transformPoint4x4(p_orig, projmatrix);
+     float p_w = 1.0f / (p_hom.w + 0.0000001f);
+     float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
+
+     // If 3D covariance matrix is precomputed, use it, otherwise compute
+     // from scaling and rotation parameters.
+     const float* cov3D;
+     if (cov3D_precomp != nullptr)
+     {
+         cov3D = cov3D_precomp + idx * 6;
+     }
+     else
+     {
+         computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6);
+         cov3D = cov3Ds + idx * 6;
+     }
+
+     // Compute 2D screen-space covariance matrix
+     float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix);
+
+     // Invert covariance (EWA algorithm)
+     float det = (cov.x * cov.z - cov.y * cov.y);
+     if (det == 0.0f)
+         return;
+     float det_inv = 1.f / det;
+     float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv };
+
+     // Compute extent in screen space (by finding eigenvalues of
+     // 2D covariance matrix). Use extent to compute a bounding rectangle
+     // of screen-space tiles that this Gaussian overlaps with. Quit if
+     // rectangle covers 0 tiles.
+     float mid = 0.5f * (cov.x + cov.z);
+     float lambda1 = mid + sqrt(max(0.1f, mid * mid - det));
+     float lambda2 = mid - sqrt(max(0.1f, mid * mid - det));
+     float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));
+     float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) };
+     uint2 rect_min, rect_max;
+     getRect(point_image, my_radius, rect_min, rect_max, grid);
+     if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0)
+         return;
+
+     // If colors have been precomputed, use them, otherwise convert
+     // spherical harmonics coefficients to RGB color.
+     if (colors_precomp == nullptr)
+     {
+         glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped);
+         rgb[idx * C + 0] = result.x;
+         rgb[idx * C + 1] = result.y;
+         rgb[idx * C + 2] = result.z;
+     }
+
+     // Store some useful helper data for the next steps.
+     depths[idx] = p_view.z;
+     radii[idx] = my_radius;
+     points_xy_image[idx] = point_image;
+     // Inverse 2D covariance and opacity neatly pack into one float4
+     conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[idx] };
+     tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x);
+ }
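The conic/extent step above is compact; a small NumPy sketch of the same arithmetic, using the kernel's constants (3-sigma extent, 0.1 eigenvalue floor):

```python
import numpy as np

def conic_and_radius(cov):
    """cov = (a, b, c) packs the symmetric 2x2 screen-space covariance [[a, b], [b, c]]."""
    a, b, c = cov
    det = a * c - b * b
    if det == 0.0:
        return None                                   # degenerate; the kernel just returns
    conic = (c / det, -b / det, a / det)              # inverse covariance, same 3-float packing
    mid = 0.5 * (a + c)
    lam1 = mid + np.sqrt(max(0.1, mid * mid - det))   # larger eigenvalue (floored)
    lam2 = mid - np.sqrt(max(0.1, mid * mid - det))
    radius = np.ceil(3.0 * np.sqrt(max(lam1, lam2)))  # 3-sigma pixel-space extent
    return conic, radius
```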
+
+ // Main rasterization method. Collaboratively works on one tile per
+ // block, each thread treats one pixel. Alternates between fetching
+ // and rasterizing data.
+ template <uint32_t CHANNELS>
+ __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
+ renderCUDA(
+     const uint2* __restrict__ ranges,
+     const uint32_t* __restrict__ point_list,
+     int W, int H,
+     const float2* __restrict__ points_xy_image,
+     const float* __restrict__ features,
+     const float4* __restrict__ conic_opacity,
+     float* __restrict__ final_T,
+     uint32_t* __restrict__ n_contrib,
+     const float* __restrict__ bg_color,
+     float* __restrict__ out_color)
+ {
+     // Identify current tile and associated min/max pixel range.
+     auto block = cg::this_thread_block();
+     uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
+     uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y };
+     uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y, H) };
+     uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y };
+     uint32_t pix_id = W * pix.y + pix.x;
+     float2 pixf = { (float)pix.x, (float)pix.y };
+
+     // Check if this thread is associated with a valid pixel or outside.
+     bool inside = pix.x < W && pix.y < H;
+     // Done threads can help with fetching, but don't rasterize
+     bool done = !inside;
+
+     // Load start/end range of IDs to process in bit sorted list.
+     uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
+     const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
+     int toDo = range.y - range.x;
+
+     // Allocate storage for batches of collectively fetched data.
+     __shared__ int collected_id[BLOCK_SIZE];
+     __shared__ float2 collected_xy[BLOCK_SIZE];
+     __shared__ float4 collected_conic_opacity[BLOCK_SIZE];
+
+     // Initialize helper variables
+     float T = 1.0f;
+     uint32_t contributor = 0;
+     uint32_t last_contributor = 0;
+     float C[CHANNELS] = { 0 };
+
+     // Iterate over batches until all done or range is complete
+     for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
+     {
+         // End if entire block votes that it is done rasterizing
+         int num_done = __syncthreads_count(done);
+         if (num_done == BLOCK_SIZE)
+             break;
+
+         // Collectively fetch per-Gaussian data from global to shared
+         int progress = i * BLOCK_SIZE + block.thread_rank();
+         if (range.x + progress < range.y)
+         {
+             int coll_id = point_list[range.x + progress];
+             collected_id[block.thread_rank()] = coll_id;
+             collected_xy[block.thread_rank()] = points_xy_image[coll_id];
+             collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
+         }
+         block.sync();
+
+         // Iterate over current batch
+         for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
+         {
+             // Keep track of current position in range
+             contributor++;
+
+             // Resample using conic matrix (cf. "Surface
+             // Splatting" by Zwicker et al., 2001)
+             float2 xy = collected_xy[j];
+             float2 d = { xy.x - pixf.x, xy.y - pixf.y };
+             float4 con_o = collected_conic_opacity[j];
+             float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
+             if (power > 0.0f)
+                 continue;
+
+             // Eq. (2) from 3D Gaussian splatting paper.
+             // Obtain alpha by multiplying with Gaussian opacity
+             // and its exponential falloff from mean.
+             // Avoid numerical instabilities (see paper appendix).
+             float alpha = min(0.99f, con_o.w * exp(power));
+             if (alpha < 1.0f / 255.0f)
+                 continue;
+             float test_T = T * (1 - alpha);
+             if (test_T < 0.0001f)
+             {
+                 done = true;
+                 continue;
+             }
+
+             // Eq. (3) from 3D Gaussian splatting paper.
+             for (int ch = 0; ch < CHANNELS; ch++)
+                 C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T;
+
+             T = test_T;
+
+             // Keep track of last range entry to update this
+             // pixel.
+             last_contributor = contributor;
+         }
+     }
+
+     // All threads that treat a valid pixel write out their final
+     // rendering data to the frame and auxiliary buffers.
+     if (inside)
+     {
+         final_T[pix_id] = T;
+         n_contrib[pix_id] = last_contributor;
+         for (int ch = 0; ch < CHANNELS; ch++)
+             out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch];
+     }
+ }
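Stripped of the tiling and shared-memory machinery, the inner loop is standard front-to-back alpha compositing (the Eq. (2)/(3) references in the comments); a pure-Python sketch of what one pixel accumulates:

```python
import math

def composite(pixel, gaussians, bg):
    """gaussians: front-to-back list of ((x, y), (cx, cy, cz), opacity, color) for one pixel."""
    T = 1.0                                     # remaining transmittance
    C = [0.0, 0.0, 0.0]
    for (gx, gy), (cx, cy, cz), opacity, color in gaussians:
        dx, dy = gx - pixel[0], gy - pixel[1]
        power = -0.5 * (cx * dx * dx + cz * dy * dy) - cy * dx * dy
        if power > 0.0:
            continue
        alpha = min(0.99, opacity * math.exp(power))
        if alpha < 1.0 / 255.0:
            continue                            # numerically insignificant contribution
        if T * (1.0 - alpha) < 1e-4:
            break                               # pixel saturated; the kernel marks itself done
        C = [c + col * alpha * T for c, col in zip(C, color)]
        T *= 1.0 - alpha
    return [c + T * b for c, b in zip(C, bg)]   # remaining transmittance blends the background
```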
+
+ void FORWARD::render(
+     const dim3 grid, dim3 block,
+     const uint2* ranges,
+     const uint32_t* point_list,
+     int W, int H,
+     const float2* means2D,
+     const float* colors,
+     const float4* conic_opacity,
+     float* final_T,
+     uint32_t* n_contrib,
+     const float* bg_color,
+     float* out_color)
+ {
+     renderCUDA<NUM_CHANNELS> <<<grid, block>>> (
+         ranges,
+         point_list,
+         W, H,
+         means2D,
+         colors,
+         conic_opacity,
+         final_T,
+         n_contrib,
+         bg_color,
+         out_color);
+ }
+
+ void FORWARD::preprocess(int P, int D, int M,
+     const float* means3D,
+     const glm::vec3* scales,
+     const float scale_modifier,
+     const glm::vec4* rotations,
+     const float* opacities,
+     const float* shs,
+     bool* clamped,
+     const float* cov3D_precomp,
+     const float* colors_precomp,
+     const float* viewmatrix,
+     const float* projmatrix,
+     const glm::vec3* cam_pos,
+     const int W, int H,
+     const float focal_x, float focal_y,
+     const float tan_fovx, float tan_fovy,
+     int* radii,
+     float2* means2D,
+     float* depths,
+     float* cov3Ds,
+     float* rgb,
+     float4* conic_opacity,
+     const dim3 grid,
+     uint32_t* tiles_touched,
+     bool prefiltered)
+ {
+     preprocessCUDA<NUM_CHANNELS> <<<(P + 255) / 256, 256>>> (
+         P, D, M,
+         means3D,
+         scales,
+         scale_modifier,
+         rotations,
+         opacities,
+         shs,
+         clamped,
+         cov3D_precomp,
+         colors_precomp,
+         viewmatrix,
+         projmatrix,
+         cam_pos,
+         W, H,
+         tan_fovx, tan_fovy,
+         focal_x, focal_y,
+         radii,
+         means2D,
+         depths,
+         cov3Ds,
+         rgb,
+         conic_opacity,
+         grid,
+         tiles_touched,
+         prefiltered
+         );
+ }
diff-gaussian-rasterization/cuda_rasterizer/forward.h ADDED
@@ -0,0 +1,66 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact [email protected]
+  */
+
+ #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED
+ #define CUDA_RASTERIZER_FORWARD_H_INCLUDED
+
+ #include <cuda.h>
+ #include "cuda_runtime.h"
+ #include "device_launch_parameters.h"
+ #define GLM_FORCE_CUDA
+ #include <glm/glm.hpp>
+
+ namespace FORWARD
+ {
+     // Perform initial steps for each Gaussian prior to rasterization.
+     void preprocess(int P, int D, int M,
+         const float* orig_points,
+         const glm::vec3* scales,
+         const float scale_modifier,
+         const glm::vec4* rotations,
+         const float* opacities,
+         const float* shs,
+         bool* clamped,
+         const float* cov3D_precomp,
+         const float* colors_precomp,
+         const float* viewmatrix,
+         const float* projmatrix,
+         const glm::vec3* cam_pos,
+         const int W, int H,
+         const float focal_x, float focal_y,
+         const float tan_fovx, float tan_fovy,
+         int* radii,
+         float2* points_xy_image,
+         float* depths,
+         float* cov3Ds,
+         float* colors,
+         float4* conic_opacity,
+         const dim3 grid,
+         uint32_t* tiles_touched,
+         bool prefiltered);
+
+     // Main rasterization method.
+     void render(
+         const dim3 grid, dim3 block,
+         const uint2* ranges,
+         const uint32_t* point_list,
+         int W, int H,
+         const float2* points_xy_image,
+         const float* features,
+         const float4* conic_opacity,
+         float* final_T,
+         uint32_t* n_contrib,
+         const float* bg_color,
+         float* out_color);
+ }
+
+
+ #endif
diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h ADDED
@@ -0,0 +1,88 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact [email protected]
+  */
+
+ #ifndef CUDA_RASTERIZER_H_INCLUDED
+ #define CUDA_RASTERIZER_H_INCLUDED
+
+ #include <vector>
+ #include <functional>
+
+ namespace CudaRasterizer
+ {
+     class Rasterizer
+     {
+     public:
+
+         static void markVisible(
+             int P,
+             float* means3D,
+             float* viewmatrix,
+             float* projmatrix,
+             bool* present);
+
+         static int forward(
+             std::function<char* (size_t)> geometryBuffer,
+             std::function<char* (size_t)> binningBuffer,
+             std::function<char* (size_t)> imageBuffer,
+             const int P, int D, int M,
+             const float* background,
+             const int width, int height,
+             const float* means3D,
+             const float* shs,
+             const float* colors_precomp,
+             const float* opacities,
+             const float* scales,
+             const float scale_modifier,
+             const float* rotations,
+             const float* cov3D_precomp,
+             const float* viewmatrix,
+             const float* projmatrix,
+             const float* cam_pos,
+             const float tan_fovx, float tan_fovy,
+             const bool prefiltered,
+             float* out_color,
+             int* radii = nullptr,
+             bool debug = false);
+
+         static void backward(
+             const int P, int D, int M, int R,
+             const float* background,
+             const int width, int height,
+             const float* means3D,
+             const float* shs,
+             const float* colors_precomp,
+             const float* scales,
+             const float scale_modifier,
+             const float* rotations,
+             const float* cov3D_precomp,
+             const float* viewmatrix,
+             const float* projmatrix,
+             const float* campos,
+             const float tan_fovx, float tan_fovy,
+             const int* radii,
+             char* geom_buffer,
+             char* binning_buffer,
+             char* image_buffer,
+             const float* dL_dpix,
+             float* dL_dmean2D,
+             float* dL_dconic,
+             float* dL_dopacity,
+             float* dL_dcolor,
+             float* dL_dmean3D,
+             float* dL_dcov3D,
+             float* dL_dsh,
+             float* dL_dscale,
+             float* dL_drot,
+             bool debug);
+     };
+ };
+
+ #endif
diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu ADDED
@@ -0,0 +1,434 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact [email protected]
+  */
+
+ #include "rasterizer_impl.h"
+ #include <iostream>
+ #include <fstream>
+ #include <algorithm>
+ #include <numeric>
+ #include <cuda.h>
+ #include "cuda_runtime.h"
+ #include "device_launch_parameters.h"
+ #include <cub/cub.cuh>
+ #include <cub/device/device_radix_sort.cuh>
+ #define GLM_FORCE_CUDA
+ #include <glm/glm.hpp>
+
+ #include <cooperative_groups.h>
+ #include <cooperative_groups/reduce.h>
+ namespace cg = cooperative_groups;
+
+ #include "auxiliary.h"
+ #include "forward.h"
+ #include "backward.h"
+
+ // Helper function to find the next-highest bit of the MSB
+ // on the CPU.
+ uint32_t getHigherMsb(uint32_t n)
+ {
+     uint32_t msb = sizeof(n) * 4;
+     uint32_t step = msb;
+     while (step > 1)
+     {
+         step /= 2;
+         if (n >> msb)
+             msb += step;
+         else
+             msb -= step;
+     }
+     if (n >> msb)
+         msb++;
+     return msb;
+ }
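getHigherMsb binary-searches the bit width of its argument; the result later bounds how many key bits the radix sort has to touch (depth in the low 32 bits, tile ID above). A quick Python cross-check, which for positive inputs should agree with int.bit_length:

```python
def get_higher_msb(n: int) -> int:
    """Python port of getHigherMsb for 32-bit n: binary-search the bit width."""
    msb, step = 16, 16
    while step > 1:
        step //= 2
        msb = msb + step if (n >> msb) else msb - step
    if n >> msb:
        msb += 1
    return msb

assert all(get_higher_msb(n) == n.bit_length() for n in range(1, 1 << 16))
```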
+
+ // Wrapper method to call auxiliary coarse frustum containment test.
+ // Mark all Gaussians that pass it.
+ __global__ void checkFrustum(int P,
+     const float* orig_points,
+     const float* viewmatrix,
+     const float* projmatrix,
+     bool* present)
+ {
+     auto idx = cg::this_grid().thread_rank();
+     if (idx >= P)
+         return;
+
+     float3 p_view;
+     present[idx] = in_frustum(idx, orig_points, viewmatrix, projmatrix, false, p_view);
+ }
+
+ // Generates one key/value pair for all Gaussian / tile overlaps.
+ // Run once per Gaussian (1:N mapping).
+ __global__ void duplicateWithKeys(
+     int P,
+     const float2* points_xy,
+     const float* depths,
+     const uint32_t* offsets,
+     uint64_t* gaussian_keys_unsorted,
+     uint32_t* gaussian_values_unsorted,
+     int* radii,
+     dim3 grid)
+ {
+     auto idx = cg::this_grid().thread_rank();
+     if (idx >= P)
+         return;
+
+     // Generate no key/value pair for invisible Gaussians
+     if (radii[idx] > 0)
+     {
+         // Find this Gaussian's offset in buffer for writing keys/values.
+         uint32_t off = (idx == 0) ? 0 : offsets[idx - 1];
+         uint2 rect_min, rect_max;
+
+         getRect(points_xy[idx], radii[idx], rect_min, rect_max, grid);
+
+         // For each tile that the bounding rect overlaps, emit a
+         // key/value pair. The key is | tile ID | depth |,
+         // and the value is the ID of the Gaussian. Sorting the values
+         // with this key yields Gaussian IDs in a list, such that they
+         // are first sorted by tile and then by depth.
+         for (int y = rect_min.y; y < rect_max.y; y++)
+         {
+             for (int x = rect_min.x; x < rect_max.x; x++)
+             {
+                 uint64_t key = y * grid.x + x;
+                 key <<= 32;
+                 key |= *((uint32_t*)&depths[idx]);
+                 gaussian_keys_unsorted[off] = key;
+                 gaussian_values_unsorted[off] = idx;
+                 off++;
+             }
+         }
+     }
+ }
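The `| tile ID | depth |` key layout is what lets a single radix sort produce per-tile, depth-ordered lists; a sketch of the packing in Python (struct reinterprets the float's bits, as the kernel's pointer cast does):

```python
import struct

def make_key(tile_id: int, depth: float) -> int:
    """64-bit sort key: tile index in the high 32 bits, raw float depth bits in the low 32."""
    depth_bits = struct.unpack("<I", struct.pack("<f", depth))[0]
    return (tile_id << 32) | depth_bits

# Sorting groups duplicates by tile first, then front-to-back by depth; the bit
# trick is valid because non-negative IEEE-754 floats order like their bit patterns.
keys = sorted(make_key(t, d) for t, d in [(1, 2.5), (0, 9.0), (1, 0.5), (0, 1.0)])
```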
+
+ // Check each key to see if it is at the start/end of one tile's range in
+ // the full sorted list. If yes, write start/end of this tile.
+ // Run once per instanced (duplicated) Gaussian ID.
+ __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* ranges)
+ {
+     auto idx = cg::this_grid().thread_rank();
+     if (idx >= L)
+         return;
+
+     // Read tile ID from key. Update start/end of tile range if at limit.
+     uint64_t key = point_list_keys[idx];
+     uint32_t currtile = key >> 32;
+     if (idx == 0)
+         ranges[currtile].x = 0;
+     else
+     {
+         uint32_t prevtile = point_list_keys[idx - 1] >> 32;
+         if (currtile != prevtile)
+         {
+             ranges[prevtile].y = idx;
+             ranges[currtile].x = idx;
+         }
+     }
+     if (idx == L - 1)
+         ranges[currtile].y = L;
+ }
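identifyTileRanges derives per-tile [start, end) ranges purely by comparing neighbouring keys in the sorted list; the same scan, serialized in Python:

```python
def identify_tile_ranges(sorted_keys, num_tiles):
    """Per-tile [start, end) ranges over the sorted (tile << 32 | depth_bits) keys."""
    ranges = [[0, 0] for _ in range(num_tiles)]   # matches the cudaMemset to zero
    for i, key in enumerate(sorted_keys):
        tile = key >> 32
        if i == 0:
            ranges[tile][0] = 0
        else:
            prev = sorted_keys[i - 1] >> 32
            if tile != prev:                      # boundary between two tiles' runs
                ranges[prev][1] = i
                ranges[tile][0] = i
    if sorted_keys:
        ranges[sorted_keys[-1] >> 32][1] = len(sorted_keys)
    return ranges
```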
+
+ // Mark Gaussians as visible/invisible, based on view frustum testing
+ void CudaRasterizer::Rasterizer::markVisible(
+     int P,
+     float* means3D,
+     float* viewmatrix,
+     float* projmatrix,
+     bool* present)
+ {
+     checkFrustum <<<(P + 255) / 256, 256>>> (
+         P,
+         means3D,
+         viewmatrix, projmatrix,
+         present);
+ }
+
+ CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, size_t P)
+ {
+     GeometryState geom;
+     obtain(chunk, geom.depths, P, 128);
+     obtain(chunk, geom.clamped, P * 3, 128);
+     obtain(chunk, geom.internal_radii, P, 128);
+     obtain(chunk, geom.means2D, P, 128);
+     obtain(chunk, geom.cov3D, P * 6, 128);
+     obtain(chunk, geom.conic_opacity, P, 128);
+     obtain(chunk, geom.rgb, P * 3, 128);
+     obtain(chunk, geom.tiles_touched, P, 128);
+     cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P);
+     obtain(chunk, geom.scanning_space, geom.scan_size, 128);
+     obtain(chunk, geom.point_offsets, P, 128);
+     return geom;
+ }
+
+ CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, size_t N)
+ {
+     ImageState img;
+     obtain(chunk, img.accum_alpha, N, 128);
+     obtain(chunk, img.n_contrib, N, 128);
+     obtain(chunk, img.ranges, N, 128);
+     return img;
+ }
+
+ CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, size_t P)
+ {
+     BinningState binning;
+     obtain(chunk, binning.point_list, P, 128);
+     obtain(chunk, binning.point_list_unsorted, P, 128);
+     obtain(chunk, binning.point_list_keys, P, 128);
+     obtain(chunk, binning.point_list_keys_unsorted, P, 128);
+     cub::DeviceRadixSort::SortPairs(
+         nullptr, binning.sorting_size,
+         binning.point_list_keys_unsorted, binning.point_list_keys,
+         binning.point_list_unsorted, binning.point_list, P);
+     obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128);
+     return binning;
+ }
+
+ // Forward rendering procedure for differentiable rasterization
+ // of Gaussians.
+ int CudaRasterizer::Rasterizer::forward(
+     std::function<char* (size_t)> geometryBuffer,
+     std::function<char* (size_t)> binningBuffer,
+     std::function<char* (size_t)> imageBuffer,
+     const int P, int D, int M,
+     const float* background,
+     const int width, int height,
+     const float* means3D,
+     const float* shs,
+     const float* colors_precomp,
+     const float* opacities,
+     const float* scales,
+     const float scale_modifier,
+     const float* rotations,
+     const float* cov3D_precomp,
+     const float* viewmatrix,
+     const float* projmatrix,
+     const float* cam_pos,
+     const float tan_fovx, float tan_fovy,
+     const bool prefiltered,
+     float* out_color,
+     int* radii,
+     bool debug)
+ {
+     const float focal_y = height / (2.0f * tan_fovy);
+     const float focal_x = width / (2.0f * tan_fovx);
+
+     size_t chunk_size = required<GeometryState>(P);
+     char* chunkptr = geometryBuffer(chunk_size);
+     GeometryState geomState = GeometryState::fromChunk(chunkptr, P);
+
+     if (radii == nullptr)
+     {
+         radii = geomState.internal_radii;
+     }
+
+     dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
+     dim3 block(BLOCK_X, BLOCK_Y, 1);
+
+     // Dynamically resize image-based auxiliary buffers during training
+     size_t img_chunk_size = required<ImageState>(width * height);
+     char* img_chunkptr = imageBuffer(img_chunk_size);
+     ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height);
+
+     if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
+     {
+         throw std::runtime_error("For non-RGB, provide precomputed Gaussian colors!");
+     }
+
+     // Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB)
+     CHECK_CUDA(FORWARD::preprocess(
+         P, D, M,
+         means3D,
+         (glm::vec3*)scales,
+         scale_modifier,
+         (glm::vec4*)rotations,
+         opacities,
+         shs,
+         geomState.clamped,
+         cov3D_precomp,
+         colors_precomp,
+         viewmatrix, projmatrix,
+         (glm::vec3*)cam_pos,
+         width, height,
+         focal_x, focal_y,
+         tan_fovx, tan_fovy,
+         radii,
+         geomState.means2D,
+         geomState.depths,
+         geomState.cov3D,
+         geomState.rgb,
+         geomState.conic_opacity,
+         tile_grid,
+         geomState.tiles_touched,
+         prefiltered
+     ), debug)
+
+     // Compute prefix sum over full list of touched tile counts by Gaussians
+     // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
+     CHECK_CUDA(cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size, geomState.tiles_touched, geomState.point_offsets, P), debug)
+
+     // Retrieve total number of Gaussian instances to launch and resize aux buffers
+     int num_rendered;
+     CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost), debug);
+
+     size_t binning_chunk_size = required<BinningState>(num_rendered);
+     char* binning_chunkptr = binningBuffer(binning_chunk_size);
+     BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered);
+
+     // For each instance to be rendered, produce an adequate [ tile | depth ] key
+     // and corresponding duplicated Gaussian indices to be sorted
+     duplicateWithKeys <<<(P + 255) / 256, 256>>> (
+         P,
+         geomState.means2D,
+         geomState.depths,
+         geomState.point_offsets,
+         binningState.point_list_keys_unsorted,
+         binningState.point_list_unsorted,
+         radii,
+         tile_grid)
+     CHECK_CUDA(, debug)
+
+     int bit = getHigherMsb(tile_grid.x * tile_grid.y);
+
+     // Sort complete list of (duplicated) Gaussian indices by keys
+     CHECK_CUDA(cub::DeviceRadixSort::SortPairs(
+         binningState.list_sorting_space,
+         binningState.sorting_size,
+         binningState.point_list_keys_unsorted, binningState.point_list_keys,
+         binningState.point_list_unsorted, binningState.point_list,
+         num_rendered, 0, 32 + bit), debug)
+
+     CHECK_CUDA(cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)), debug);
+
+     // Identify start and end of per-tile workloads in sorted list
+     if (num_rendered > 0)
+         identifyTileRanges <<<(num_rendered + 255) / 256, 256>>> (
+             num_rendered,
+             binningState.point_list_keys,
+             imgState.ranges);
+     CHECK_CUDA(, debug)
+
+     // Let each tile blend its range of Gaussians independently in parallel
+     const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : geomState.rgb;
+     CHECK_CUDA(FORWARD::render(
+         tile_grid, block,
+         imgState.ranges,
+         binningState.point_list,
+         width, height,
+         geomState.means2D,
+         feature_ptr,
+         geomState.conic_opacity,
+         imgState.accum_alpha,
+         imgState.n_contrib,
+         background,
+         out_color), debug)
+
+     return num_rendered;
+ }
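The inclusive prefix sum does double duty: entry i-1 is Gaussian i's write offset into the key buffer (see duplicateWithKeys above), and the last entry is the total instance count to sort and render. In Python terms:

```python
from itertools import accumulate

tiles_touched = [2, 3, 0, 2, 1]
point_offsets = list(accumulate(tiles_touched))  # [2, 5, 5, 7, 8]
num_rendered = point_offsets[-1]                 # 8 key/value pairs in total
# Gaussian i writes its keys starting at point_offsets[i - 1] (0 for i == 0).
```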
+
+ // Produce necessary gradients for optimization, corresponding
+ // to forward render pass
+ void CudaRasterizer::Rasterizer::backward(
+     const int P, int D, int M, int R,
+     const float* background,
+     const int width, int height,
+     const float* means3D,
+     const float* shs,
+     const float* colors_precomp,
+     const float* scales,
+     const float scale_modifier,
+     const float* rotations,
+     const float* cov3D_precomp,
+     const float* viewmatrix,
+     const float* projmatrix,
+     const float* campos,
+     const float tan_fovx, float tan_fovy,
+     const int* radii,
+     char* geom_buffer,
+     char* binning_buffer,
+     char* img_buffer,
+     const float* dL_dpix,
+     float* dL_dmean2D,
+     float* dL_dconic,
+     float* dL_dopacity,
+     float* dL_dcolor,
+     float* dL_dmean3D,
+     float* dL_dcov3D,
+     float* dL_dsh,
+     float* dL_dscale,
+     float* dL_drot,
+     bool debug)
+ {
+     GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
+     BinningState binningState = BinningState::fromChunk(binning_buffer, R);
+     ImageState imgState = ImageState::fromChunk(img_buffer, width * height);
+
+     if (radii == nullptr)
+     {
+         radii = geomState.internal_radii;
+     }
+
+     const float focal_y = height / (2.0f * tan_fovy);
+     const float focal_x = width / (2.0f * tan_fovx);
+
+     const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
+     const dim3 block(BLOCK_X, BLOCK_Y, 1);
+
+     // Compute loss gradients w.r.t. 2D mean position, conic matrix,
+     // opacity and RGB of Gaussians from per-pixel loss gradients.
+     // If we were given precomputed colors and not SHs, use them.
+     const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
+     CHECK_CUDA(BACKWARD::render(
+         tile_grid,
+         block,
+         imgState.ranges,
+         binningState.point_list,
+         width, height,
+         background,
+         geomState.means2D,
+         geomState.conic_opacity,
+         color_ptr,
+         imgState.accum_alpha,
+         imgState.n_contrib,
+         dL_dpix,
+         (float3*)dL_dmean2D,
+         (float4*)dL_dconic,
+         dL_dopacity,
+         dL_dcolor), debug)
+
+     // Take care of the rest of preprocessing. Was the precomputed covariance
+     // given to us or a scales/rot pair? If precomputed, pass that. If not,
+     // use the one we computed ourselves.
+     const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D;
+     CHECK_CUDA(BACKWARD::preprocess(P, D, M,
+         (float3*)means3D,
+         radii,
+         shs,
+         geomState.clamped,
+         (glm::vec3*)scales,
+         (glm::vec4*)rotations,
+         scale_modifier,
+         cov3D_ptr,
+         viewmatrix,
+         projmatrix,
+         focal_x, focal_y,
+         tan_fovx, tan_fovy,
+         (glm::vec3*)campos,
+         (float3*)dL_dmean2D,
+         dL_dconic,
+         (glm::vec3*)dL_dmean3D,
+         dL_dcolor,
+         dL_dcov3D,
+         dL_dsh,
+         (glm::vec3*)dL_dscale,
+         (glm::vec4*)dL_drot), debug)
+ }
diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h ADDED
@@ -0,0 +1,74 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact [email protected]
+  */
+
+ #pragma once
+
+ #include <iostream>
+ #include <vector>
+ #include "rasterizer.h"
+ #include <cuda_runtime_api.h>
+
+ namespace CudaRasterizer
+ {
+     template <typename T>
+     static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment)
+     {
+         std::size_t offset = (reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) & ~(alignment - 1);
+         ptr = reinterpret_cast<T*>(offset);
+         chunk = reinterpret_cast<char*>(ptr + count);
+     }
+
+     struct GeometryState
+     {
+         size_t scan_size;
+         float* depths;
+         char* scanning_space;
+         bool* clamped;
+         int* internal_radii;
+         float2* means2D;
+         float* cov3D;
+         float4* conic_opacity;
+         float* rgb;
+         uint32_t* point_offsets;
+         uint32_t* tiles_touched;
+
+         static GeometryState fromChunk(char*& chunk, size_t P);
+     };
+
+     struct ImageState
+     {
+         uint2* ranges;
+         uint32_t* n_contrib;
+         float* accum_alpha;
+
+         static ImageState fromChunk(char*& chunk, size_t N);
+     };
+
+     struct BinningState
+     {
+         size_t sorting_size;
+         uint64_t* point_list_keys_unsorted;
+         uint64_t* point_list_keys;
+         uint32_t* point_list_unsorted;
+         uint32_t* point_list;
+         char* list_sorting_space;
+
+         static BinningState fromChunk(char*& chunk, size_t P);
+     };
+
+     template<typename T>
+     size_t required(size_t P)
+     {
+         char* size = nullptr;
+         T::fromChunk(size, P);
+         return ((size_t)size) + 128;
+     }
+ };
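obtain is a bump allocator over one flat byte buffer, and required<T> measures a layout by running fromChunk against a null chunk so the final pointer value *is* the size. A Python sketch of the same arithmetic (function and field names here are illustrative, not part of the library):

```python
def obtain(cursor: int, count: int, elem_size: int, alignment: int = 128):
    """Align the cursor, reserve count elements, return (offset, new cursor)."""
    offset = (cursor + alignment - 1) & ~(alignment - 1)
    return offset, offset + count * elem_size

def required(P: int) -> int:
    """Run the layout against cursor 0; the end cursor plus slack is the buffer size."""
    cursor = 0
    _, cursor = obtain(cursor, P, 4)       # e.g. depths: P floats
    _, cursor = obtain(cursor, P * 6, 4)   # e.g. cov3D: 6 floats per point
    return cursor + 128                    # slack so the first alignment always fits
```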
diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py ADDED
@@ -0,0 +1,221 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact [email protected]
+ #
+
+ from typing import NamedTuple
+ import torch.nn as nn
+ import torch
+ from . import _C
+
+ def cpu_deep_copy_tuple(input_tuple):
+     copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple]
+     return tuple(copied_tensors)
+
+ def rasterize_gaussians(
+     means3D,
+     means2D,
+     sh,
+     colors_precomp,
+     opacities,
+     scales,
+     rotations,
+     cov3Ds_precomp,
+     raster_settings,
+ ):
+     return _RasterizeGaussians.apply(
+         means3D,
+         means2D,
+         sh,
+         colors_precomp,
+         opacities,
+         scales,
+         rotations,
+         cov3Ds_precomp,
+         raster_settings,
+     )
+
+ class _RasterizeGaussians(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx,
+         means3D,
+         means2D,
+         sh,
+         colors_precomp,
+         opacities,
+         scales,
+         rotations,
+         cov3Ds_precomp,
+         raster_settings,
+     ):
+
+         # Restructure arguments the way that the C++ lib expects them
+         args = (
+             raster_settings.bg,
+             means3D,
+             colors_precomp,
+             opacities,
+             scales,
+             rotations,
+             raster_settings.scale_modifier,
+             cov3Ds_precomp,
+             raster_settings.viewmatrix,
+             raster_settings.projmatrix,
+             raster_settings.tanfovx,
+             raster_settings.tanfovy,
+             raster_settings.image_height,
+             raster_settings.image_width,
+             sh,
+             raster_settings.sh_degree,
+             raster_settings.campos,
+             raster_settings.prefiltered,
+             raster_settings.debug
+         )
+
+         # Invoke C++/CUDA rasterizer
+         if raster_settings.debug:
+             cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
+             try:
+                 num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
+             except Exception as ex:
+                 torch.save(cpu_args, "snapshot_fw.dump")
+                 print("\nAn error occurred in forward. Please forward snapshot_fw.dump for debugging.")
+                 raise ex
+         else:
+             num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
+
+         # Keep relevant tensors for backward
+         ctx.raster_settings = raster_settings
+         ctx.num_rendered = num_rendered
+         ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
+         return color, radii
+
+     @staticmethod
+     def backward(ctx, grad_out_color, _):
+
+         # Restore necessary values from context
+         num_rendered = ctx.num_rendered
+         raster_settings = ctx.raster_settings
+         colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors
+
+         # Restructure args as C++ method expects them
+         args = (raster_settings.bg,
+             means3D,
+             radii,
+             colors_precomp,
+             scales,
+             rotations,
+             raster_settings.scale_modifier,
+             cov3Ds_precomp,
+             raster_settings.viewmatrix,
+             raster_settings.projmatrix,
+             raster_settings.tanfovx,
+             raster_settings.tanfovy,
+             grad_out_color,
+             sh,
+             raster_settings.sh_degree,
+             raster_settings.campos,
+             geomBuffer,
+             num_rendered,
+             binningBuffer,
+             imgBuffer,
+             raster_settings.debug)
+
+         # Compute gradients for relevant tensors by invoking backward method
+         if raster_settings.debug:
+             cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
+             try:
+                 grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
+             except Exception as ex:
+                 torch.save(cpu_args, "snapshot_bw.dump")
+                 print("\nAn error occurred in backward. Writing snapshot_bw.dump for debugging.\n")
+                 raise ex
+         else:
+             grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
+
+         grads = (
+             grad_means3D,
+             grad_means2D,
+             grad_sh,
+             grad_colors_precomp,
+             grad_opacities,
+             grad_scales,
+             grad_rotations,
+             grad_cov3Ds_precomp,
+             None,
+         )
+
+         return grads
+
+ class GaussianRasterizationSettings(NamedTuple):
+     image_height: int
+     image_width: int
+     tanfovx : float
+     tanfovy : float
+     bg : torch.Tensor
+     scale_modifier : float
+     viewmatrix : torch.Tensor
+     projmatrix : torch.Tensor
+     sh_degree : int
+     campos : torch.Tensor
+     prefiltered : bool
+     debug : bool
+
+ class GaussianRasterizer(nn.Module):
+     def __init__(self, raster_settings):
+         super().__init__()
+         self.raster_settings = raster_settings
+
+     def markVisible(self, positions):
+         # Mark visible points (based on frustum culling for camera) with a boolean
+         with torch.no_grad():
+             raster_settings = self.raster_settings
+             visible = _C.mark_visible(
+                 positions,
+                 raster_settings.viewmatrix,
+                 raster_settings.projmatrix)
+
+         return visible
+
+     def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
+
+         raster_settings = self.raster_settings
+
+         if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
+             raise Exception('Please provide exactly one of either SHs or precomputed colors!')
+
+         if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None):
+             raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')
+
+         if shs is None:
+             shs = torch.Tensor([])
+         if colors_precomp is None:
+             colors_precomp = torch.Tensor([])
+
+         if scales is None:
+             scales = torch.Tensor([])
+         if rotations is None:
+             rotations = torch.Tensor([])
+         if cov3D_precomp is None:
+             cov3D_precomp = torch.Tensor([])
+
+         # Invoke C++/CUDA rasterization routine
+         return rasterize_gaussians(
+             means3D,
+             means2D,
+             shs,
+             colors_precomp,
+             opacities,
+             scales,
+             rotations,
+             cov3D_precomp,
+             raster_settings,
+         )
+
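A hedged usage sketch of the module as defined above (requires a CUDA build of the extension; the identity camera matrices and random inputs are placeholders, not a working camera setup):

```python
import torch
from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings

settings = GaussianRasterizationSettings(
    image_height=512, image_width=512,
    tanfovx=0.5, tanfovy=0.5,
    bg=torch.zeros(3, device="cuda"),
    scale_modifier=1.0,
    viewmatrix=torch.eye(4, device="cuda"),   # placeholder view matrix
    projmatrix=torch.eye(4, device="cuda"),   # placeholder projection matrix
    sh_degree=0,
    campos=torch.zeros(3, device="cuda"),
    prefiltered=False,
    debug=False,
)
rasterizer = GaussianRasterizer(settings)

N = 1000
means3D = torch.randn(N, 3, device="cuda", requires_grad=True)
means2D = torch.zeros(N, 3, device="cuda", requires_grad=True)  # receives 2D mean grads
color, radii = rasterizer(
    means3D, means2D,
    opacities=torch.sigmoid(torch.randn(N, 1, device="cuda")),
    colors_precomp=torch.rand(N, 3, device="cuda"),   # precomputed colors instead of SHs
    scales=torch.rand(N, 3, device="cuda") * 0.01,
    rotations=torch.nn.functional.normalize(torch.randn(N, 4, device="cuda")),
)
color.sum().backward()   # gradients flow back through the custom autograd Function
```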
diff-gaussian-rasterization/ext.cpp ADDED
@@ -0,0 +1,19 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact [email protected]
+  */
+
+ #include <torch/extension.h>
+ #include "rasterize_points.h"
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("rasterize_gaussians", &RasterizeGaussiansCUDA);
+     m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA);
+     m.def("mark_visible", &markVisible);
+ }
diff-gaussian-rasterization/rasterize_points.cu ADDED
@@ -0,0 +1,217 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact [email protected]
+  */
+
+ #include <math.h>
+ #include <torch/extension.h>
+ #include <cstdio>
+ #include <sstream>
+ #include <iostream>
+ #include <tuple>
+ #include <stdio.h>
+ #include <cuda_runtime_api.h>
+ #include <memory>
+ #include "cuda_rasterizer/config.h"
+ #include "cuda_rasterizer/rasterizer.h"
+ #include <fstream>
+ #include <string>
+ #include <functional>
+
+ std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
+     auto lambda = [&t](size_t N) {
+         t.resize_({(long long)N});
+         return reinterpret_cast<char*>(t.contiguous().data_ptr());
+     };
+     return lambda;
+ }
+
+ std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+ RasterizeGaussiansCUDA(
+     const torch::Tensor& background,
+     const torch::Tensor& means3D,
+     const torch::Tensor& colors,
+     const torch::Tensor& opacity,
+     const torch::Tensor& scales,
+     const torch::Tensor& rotations,
+     const float scale_modifier,
+     const torch::Tensor& cov3D_precomp,
+     const torch::Tensor& viewmatrix,
+     const torch::Tensor& projmatrix,
+     const float tan_fovx,
+     const float tan_fovy,
+     const int image_height,
+     const int image_width,
+     const torch::Tensor& sh,
+     const int degree,
+     const torch::Tensor& campos,
+     const bool prefiltered,
+     const bool debug)
+ {
+     if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
+         AT_ERROR("means3D must have dimensions (num_points, 3)");
+     }
+
+     const int P = means3D.size(0);
+     const int H = image_height;
+     const int W = image_width;
+
+     auto int_opts = means3D.options().dtype(torch::kInt32);
+     auto float_opts = means3D.options().dtype(torch::kFloat32);
+
+     torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
+     torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
+
+     torch::Device device(torch::kCUDA);
+     torch::TensorOptions options(torch::kByte);
+     torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
+     torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
+     torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
+     std::function<char*(size_t)> geomFunc = resizeFunctional(geomBuffer);
+     std::function<char*(size_t)> binningFunc = resizeFunctional(binningBuffer);
+     std::function<char*(size_t)> imgFunc = resizeFunctional(imgBuffer);
+
+     int rendered = 0;
+     if(P != 0)
+     {
+         int M = 0;
+         if(sh.size(0) != 0)
+         {
+             M = sh.size(1);
+         }
+
+         rendered = CudaRasterizer::Rasterizer::forward(
+             geomFunc,
+             binningFunc,
+             imgFunc,
+             P, degree, M,
+             background.contiguous().data<float>(),
+             W, H,
+             means3D.contiguous().data<float>(),
+             sh.contiguous().data_ptr<float>(),
+             colors.contiguous().data<float>(),
+             opacity.contiguous().data<float>(),
+             scales.contiguous().data_ptr<float>(),
+             scale_modifier,
+             rotations.contiguous().data_ptr<float>(),
+             cov3D_precomp.contiguous().data<float>(),
+             viewmatrix.contiguous().data<float>(),
+             projmatrix.contiguous().data<float>(),
+             campos.contiguous().data<float>(),
+             tan_fovx,
+             tan_fovy,
+             prefiltered,
+             out_color.contiguous().data<float>(),
+             radii.contiguous().data<int>(),
+             debug);
+     }
+     return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer);
+ }
+
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+ RasterizeGaussiansBackwardCUDA(
+     const torch::Tensor& background,
+     const torch::Tensor& means3D,
+     const torch::Tensor& radii,
+     const torch::Tensor& colors,
+     const torch::Tensor& scales,
+     const torch::Tensor& rotations,
+     const float scale_modifier,
+     const torch::Tensor& cov3D_precomp,
+     const torch::Tensor& viewmatrix,
+     const torch::Tensor& projmatrix,
+     const float tan_fovx,
+     const float tan_fovy,
+     const torch::Tensor& dL_dout_color,
+     const torch::Tensor& sh,
+     const int degree,
+     const torch::Tensor& campos,
+     const torch::Tensor& geomBuffer,
+     const int R,
+     const torch::Tensor& binningBuffer,
+     const torch::Tensor& imageBuffer,
+     const bool debug)
+ {
+     const int P = means3D.size(0);
+     const int H = dL_dout_color.size(1);
+     const int W = dL_dout_color.size(2);
+
+     int M = 0;
+     if(sh.size(0) != 0)
+     {
+         M = sh.size(1);
+     }
+
+     torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options());
+     torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options());
+     torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options());
+     torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options());
+     torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options());
+     torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options());
+     torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options());
+     torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options());
+     torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options());
+
+     if(P != 0)
+     {
+         CudaRasterizer::Rasterizer::backward(P, degree, M, R,
+             background.contiguous().data<float>(),
+             W, H,
+             means3D.contiguous().data<float>(),
+             sh.contiguous().data<float>(),
+             colors.contiguous().data<float>(),
+             scales.data_ptr<float>(),
+             scale_modifier,
+             rotations.data_ptr<float>(),
+             cov3D_precomp.contiguous().data<float>(),
+             viewmatrix.contiguous().data<float>(),
+             projmatrix.contiguous().data<float>(),
+             campos.contiguous().data<float>(),
+             tan_fovx,
+             tan_fovy,
+             radii.contiguous().data<int>(),
+             reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
+             reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
+             reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
+             dL_dout_color.contiguous().data<float>(),
+             dL_dmeans2D.contiguous().data<float>(),
+             dL_dconic.contiguous().data<float>(),
+             dL_dopacity.contiguous().data<float>(),
+             dL_dcolors.contiguous().data<float>(),
+             dL_dmeans3D.contiguous().data<float>(),
+             dL_dcov3D.contiguous().data<float>(),
+             dL_dsh.contiguous().data<float>(),
+             dL_dscales.contiguous().data<float>(),
+             dL_drotations.contiguous().data<float>(),
+             debug);
+     }
+
+     return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations);
+ }
+
+ torch::Tensor markVisible(
+     torch::Tensor& means3D,
+     torch::Tensor& viewmatrix,
+     torch::Tensor& projmatrix)
+ {
+     const int P = means3D.size(0);
+
+     torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
+
+     if(P != 0)
+     {
+         CudaRasterizer::Rasterizer::markVisible(P,
+             means3D.contiguous().data<float>(),
+             viewmatrix.contiguous().data<float>(),
+             projmatrix.contiguous().data<float>(),
+             present.contiguous().data<bool>());
+     }
+
+     return present;
+ }
diff-gaussian-rasterization/rasterize_points.h ADDED
@@ -0,0 +1,67 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact [email protected]
+  */
+
+ #pragma once
+ #include <torch/extension.h>
+ #include <cstdio>
+ #include <tuple>
+ #include <string>
+
+ std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+ RasterizeGaussiansCUDA(
+     const torch::Tensor& background,
+     const torch::Tensor& means3D,
+     const torch::Tensor& colors,
+     const torch::Tensor& opacity,
+     const torch::Tensor& scales,
+     const torch::Tensor& rotations,
+     const float scale_modifier,
+     const torch::Tensor& cov3D_precomp,
+     const torch::Tensor& viewmatrix,
+     const torch::Tensor& projmatrix,
+     const float tan_fovx,
+     const float tan_fovy,
+     const int image_height,
+     const int image_width,
+     const torch::Tensor& sh,
+     const int degree,
+     const torch::Tensor& campos,
+     const bool prefiltered,
+     const bool debug);
+
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+ RasterizeGaussiansBackwardCUDA(
+     const torch::Tensor& background,
+     const torch::Tensor& means3D,
+     const torch::Tensor& radii,
+     const torch::Tensor& colors,
+     const torch::Tensor& scales,
+     const torch::Tensor& rotations,
+     const float scale_modifier,
+     const torch::Tensor& cov3D_precomp,
+     const torch::Tensor& viewmatrix,
+     const torch::Tensor& projmatrix,
+     const float tan_fovx,
+     const float tan_fovy,
+     const torch::Tensor& dL_dout_color,
+     const torch::Tensor& sh,
+     const int degree,
+     const torch::Tensor& campos,
+     const torch::Tensor& geomBuffer,
+     const int R,
+     const torch::Tensor& binningBuffer,
+     const torch::Tensor& imageBuffer,
+     const bool debug);
+
+ torch::Tensor markVisible(
+     torch::Tensor& means3D,
+     torch::Tensor& viewmatrix,
+     torch::Tensor& projmatrix);
diff-gaussian-rasterization/setup.py ADDED
@@ -0,0 +1,34 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact [email protected]
+ #
+
+ from setuptools import setup
+ from torch.utils.cpp_extension import CUDAExtension, BuildExtension
+ import os
+ os.path.dirname(os.path.abspath(__file__))
+
+ setup(
+     name="diff_gaussian_rasterization",
+     packages=['diff_gaussian_rasterization'],
+     ext_modules=[
+         CUDAExtension(
+             name="diff_gaussian_rasterization._C",
+             sources=[
+                 "cuda_rasterizer/rasterizer_impl.cu",
+                 "cuda_rasterizer/forward.cu",
+                 "cuda_rasterizer/backward.cu",
+                 "rasterize_points.cu",
+                 "ext.cpp"],
+             extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]})
+     ],
+     cmdclass={
+         'build_ext': BuildExtension
+     }
+ )
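With the sources and the GLM submodule in place, this builds through the standard setuptools/CUDAExtension flow (for example, `pip install .` from the submodule root), producing the `diff_gaussian_rasterization._C` extension that `__init__.py` imports.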
diff-gaussian-rasterization/third_party/stbi_image_write.h ADDED
@@ -0,0 +1,1724 @@
+ /* stb_image_write - v1.16 - public domain - http://nothings.org/stb
2
+ writes out PNG/BMP/TGA/JPEG/HDR images to C stdio - Sean Barrett 2010-2015
3
+ no warranty implied; use at your own risk
4
+
5
+ Before #including,
6
+
7
+ #define STB_IMAGE_WRITE_IMPLEMENTATION
8
+
9
+ in the file that you want to have the implementation.
10
+
11
+ Will probably not work correctly with strict-aliasing optimizations.
12
+
13
+ ABOUT:
14
+
15
+ This header file is a library for writing images to C stdio or a callback.
16
+
17
+ The PNG output is not optimal; it is 20-50% larger than the file
18
+ written by a decent optimizing implementation; though providing a custom
19
+ zlib compress function (see STBIW_ZLIB_COMPRESS) can mitigate that.
20
+ This library is designed for source code compactness and simplicity,
21
+ not optimal image file size or run-time performance.
22
+
23
+ BUILDING:
24
+
25
+ You can #define STBIW_ASSERT(x) before the #include to avoid using assert.h.
26
+ You can #define STBIW_MALLOC(), STBIW_REALLOC(), and STBIW_FREE() to replace
27
+ malloc,realloc,free.
28
+ You can #define STBIW_MEMMOVE() to replace memmove()
29
+ You can #define STBIW_ZLIB_COMPRESS to use a custom zlib-style compress function
30
+ for PNG compression (instead of the builtin one), it must have the following signature:
31
+ unsigned char * my_compress(unsigned char *data, int data_len, int *out_len, int quality);
32
+ The returned data will be freed with STBIW_FREE() (free() by default),
33
+ so it must be heap allocated with STBIW_MALLOC() (malloc() by default),
34
+
35
+ UNICODE:
36
+
37
+ If compiling for Windows and you wish to use Unicode filenames, compile
38
+ with
39
+ #define STBIW_WINDOWS_UTF8
40
+ and pass utf8-encoded filenames. Call stbiw_convert_wchar_to_utf8 to convert
41
+ Windows wchar_t filenames to utf8.
42
+
43
+ USAGE:
44
+
45
+ There are five functions, one for each image file format:
46
+
47
+ int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);
48
+ int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
49
+ int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);
50
+ int stbi_write_jpg(char const *filename, int w, int h, int comp, const void *data, int quality);
51
+ int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);
52
+
53
+ void stbi_flip_vertically_on_write(int flag); // flag is non-zero to flip data vertically
54
+
55
+ There are also five equivalent functions that use an arbitrary write function. You are
56
+ expected to open/close your file-equivalent before and after calling these:
57
+
58
+ int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes);
59
+ int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
60
+ int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
61
+ int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data);
62
+ int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality);
63
+
64
+ where the callback is:
65
+ void stbi_write_func(void *context, void *data, int size);
66
+
67
+ You can configure it with these global variables:
68
+ int stbi_write_tga_with_rle; // defaults to true; set to 0 to disable RLE
69
+ int stbi_write_png_compression_level; // defaults to 8; set to higher for more compression
70
+ int stbi_write_force_png_filter; // defaults to -1; set to 0..5 to force a filter mode
71
+
72
+
73
+ You can define STBI_WRITE_NO_STDIO to disable the file variant of these
74
+ functions, so the library will not use stdio.h at all. However, this will
75
+ also disable HDR writing, because it requires stdio for formatted output.
76
+
77
+ Each function returns 0 on failure and non-0 on success.
78
+
79
+ The functions create an image file defined by the parameters. The image
80
+ is a rectangle of pixels stored from left-to-right, top-to-bottom.
81
+ Each pixel contains 'comp' channels of data stored interleaved with 8-bits
82
+ per channel, in the following order: 1=Y, 2=YA, 3=RGB, 4=RGBA. (Y is
83
+ monochrome color.) The rectangle is 'w' pixels wide and 'h' pixels tall.
84
+ The *data pointer points to the first byte of the top-left-most pixel.
85
+ For PNG, "stride_in_bytes" is the distance in bytes from the first byte of
86
+ a row of pixels to the first byte of the next row of pixels.
87
+
88
+ PNG creates output files with the same number of components as the input.
89
+ The BMP format expands Y to RGB in the file format and does not
90
+ output alpha.
91
+
92
+ PNG supports writing rectangles of data even when the bytes storing rows of
93
+ data are not consecutive in memory (e.g. sub-rectangles of a larger image),
94
+ by supplying the stride between the beginning of adjacent rows. The other
95
+ formats do not. (Thus you cannot write a native-format BMP through the BMP
96
+ writer, both because it is in BGR order and because it may have padding
97
+ at the end of the line.)
98
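+
+ For example, a sketch writing a w0 x h0 sub-rectangle at offset (x0, y0) of
+ a larger RGBA image that is 'big_w' pixels wide (all names are illustrative):
+
+     const unsigned char *top_left = big + (y0 * big_w + x0) * 4;
+     stbi_write_png("crop.png", w0, h0, 4, top_left, big_w * 4);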
+
99
+ PNG allows you to set the deflate compression level by setting the global
100
+ variable 'stbi_write_png_compression_level' (it defaults to 8).
101
+
102
+ HDR expects linear float data. Since the format is always 32-bit rgb(e)
103
+ data, alpha (if provided) is discarded, and monochrome data is
104
+ replicated across all three channels.
105
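+
+ For example (a single mid-gray pixel; the filename is illustrative):
+
+     float rgb[3] = { 0.5f, 0.5f, 0.5f };    /* linear, not sRGB */
+     stbi_write_hdr("probe.hdr", 1, 1, 3, rgb);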
+
106
+ TGA supports RLE or non-RLE compressed data. To use non-RLE-compressed
107
+ data, set the global variable 'stbi_write_tga_with_rle' to 0.
108
+
109
+ JPEG ignores any alpha channel in the input data; quality is between 1 and 100.
110
+ Higher quality looks better but results in a bigger image.
111
+ Output is baseline JPEG only (progressive JPEG is not supported).
112
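+
+ For example (names are illustrative):
+
+     stbi_write_jpg("photo.jpg", w, h, 3, rgb_pixels, 85);   /* quality 85 */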
+
113
+ CREDITS:
114
+
115
+
116
+ Sean Barrett - PNG/BMP/TGA
117
+ Baldur Karlsson - HDR
118
+ Jean-Sebastien Guay - TGA monochrome
119
+ Tim Kelsey - misc enhancements
120
+ Alan Hickman - TGA RLE
121
+ Emmanuel Julien - initial file IO callback implementation
122
+ Jon Olick - original jo_jpeg.cpp code
123
+ Daniel Gibson - integrate JPEG, allow external zlib
124
+ Aarni Koskela - allow choosing PNG filter
125
+
126
+ bugfixes:
127
+ github:Chribba
128
+ Guillaume Chereau
129
+ github:jry2
130
+ github:romigrou
131
+ Sergio Gonzalez
132
+ Jonas Karlsson
133
+ Filip Wasil
134
+ Thatcher Ulrich
135
+ github:poppolopoppo
136
+ Patrick Boettcher
137
+ github:xeekworx
138
+ Cap Petschulat
139
+ Simon Rodriguez
140
+ Ivan Tikhonov
141
+ github:ignotion
142
+ Adam Schackart
143
+ Andrew Kensler
144
+
145
+ LICENSE
146
+
147
+ See end of file for license information.
148
+
149
+ */
150
+
151
+ #ifndef INCLUDE_STB_IMAGE_WRITE_H
152
+ #define INCLUDE_STB_IMAGE_WRITE_H
153
+
154
+ #include <stdlib.h>
155
+
156
+ // if STB_IMAGE_WRITE_STATIC causes problems, try defining STBIWDEF to 'inline' or 'static inline'
157
+ #ifndef STBIWDEF
158
+ #ifdef STB_IMAGE_WRITE_STATIC
159
+ #define STBIWDEF static
160
+ #else
161
+ #ifdef __cplusplus
162
+ #define STBIWDEF extern "C"
163
+ #else
164
+ #define STBIWDEF extern
165
+ #endif
166
+ #endif
167
+ #endif
168
+
169
+ #ifndef STB_IMAGE_WRITE_STATIC // C++ forbids static forward declarations
170
+ STBIWDEF int stbi_write_tga_with_rle;
171
+ STBIWDEF int stbi_write_png_compression_level;
172
+ STBIWDEF int stbi_write_force_png_filter;
173
+ #endif
174
+
175
+ #ifndef STBI_WRITE_NO_STDIO
176
+ STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);
177
+ STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
178
+ STBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);
179
+ STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);
180
+ STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality);
181
+
182
+ #ifdef STBIW_WINDOWS_UTF8
183
+ STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input);
184
+ #endif
185
+ #endif
186
+
187
+ typedef void stbi_write_func(void *context, void *data, int size);
188
+
189
+ STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes);
190
+ STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
191
+ STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
192
+ STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data);
193
+ STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality);
194
+
195
+ STBIWDEF void stbi_flip_vertically_on_write(int flip_boolean);
196
+
197
+ #endif//INCLUDE_STB_IMAGE_WRITE_H
198
+
199
+ #ifdef STB_IMAGE_WRITE_IMPLEMENTATION
200
+
201
+ #ifdef _WIN32
202
+ #ifndef _CRT_SECURE_NO_WARNINGS
203
+ #define _CRT_SECURE_NO_WARNINGS
204
+ #endif
205
+ #ifndef _CRT_NONSTDC_NO_DEPRECATE
206
+ #define _CRT_NONSTDC_NO_DEPRECATE
207
+ #endif
208
+ #endif
209
+
210
+ #ifndef STBI_WRITE_NO_STDIO
211
+ #include <stdio.h>
212
+ #endif // STBI_WRITE_NO_STDIO
213
+
214
+ #include <stdarg.h>
215
+ #include <stdlib.h>
216
+ #include <string.h>
217
+ #include <math.h>
218
+
219
+ #if defined(STBIW_MALLOC) && defined(STBIW_FREE) && (defined(STBIW_REALLOC) || defined(STBIW_REALLOC_SIZED))
220
+ // ok
221
+ #elif !defined(STBIW_MALLOC) && !defined(STBIW_FREE) && !defined(STBIW_REALLOC) && !defined(STBIW_REALLOC_SIZED)
222
+ // ok
223
+ #else
224
+ #error "Must define all or none of STBIW_MALLOC, STBIW_FREE, and STBIW_REALLOC (or STBIW_REALLOC_SIZED)."
225
+ #endif
226
+
227
+ #ifndef STBIW_MALLOC
228
+ #define STBIW_MALLOC(sz) malloc(sz)
229
+ #define STBIW_REALLOC(p,newsz) realloc(p,newsz)
230
+ #define STBIW_FREE(p) free(p)
231
+ #endif
232
+
233
+ #ifndef STBIW_REALLOC_SIZED
234
+ #define STBIW_REALLOC_SIZED(p,oldsz,newsz) STBIW_REALLOC(p,newsz)
235
+ #endif
236
+
237
+
238
+ #ifndef STBIW_MEMMOVE
239
+ #define STBIW_MEMMOVE(a,b,sz) memmove(a,b,sz)
240
+ #endif
241
+
242
+
243
+ #ifndef STBIW_ASSERT
244
+ #include <assert.h>
245
+ #define STBIW_ASSERT(x) assert(x)
246
+ #endif
247
+
248
+ #define STBIW_UCHAR(x) (unsigned char) ((x) & 0xff)
249
+
250
+ #ifdef STB_IMAGE_WRITE_STATIC
251
+ static int stbi_write_png_compression_level = 8;
252
+ static int stbi_write_tga_with_rle = 1;
253
+ static int stbi_write_force_png_filter = -1;
254
+ #else
255
+ int stbi_write_png_compression_level = 8;
256
+ int stbi_write_tga_with_rle = 1;
257
+ int stbi_write_force_png_filter = -1;
258
+ #endif
259
+
260
+ static int stbi__flip_vertically_on_write = 0;
261
+
262
+ STBIWDEF void stbi_flip_vertically_on_write(int flag)
263
+ {
264
+ stbi__flip_vertically_on_write = flag;
265
+ }
266
+
267
+ typedef struct
268
+ {
269
+ stbi_write_func *func;
270
+ void *context;
271
+ unsigned char buffer[64];
272
+ int buf_used;
273
+ } stbi__write_context;
274
+
275
+ // initialize a callback-based context
276
+ static void stbi__start_write_callbacks(stbi__write_context *s, stbi_write_func *c, void *context)
277
+ {
278
+ s->func = c;
279
+ s->context = context;
280
+ }
281
+
282
+ #ifndef STBI_WRITE_NO_STDIO
283
+
284
+ static void stbi__stdio_write(void *context, void *data, int size)
285
+ {
286
+ fwrite(data,1,size,(FILE*) context);
287
+ }
288
+
289
+ #if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8)
290
+ #ifdef __cplusplus
291
+ #define STBIW_EXTERN extern "C"
292
+ #else
293
+ #define STBIW_EXTERN extern
294
+ #endif
295
+ STBIW_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide);
296
+ STBIW_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default);
297
+
298
+ STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input)
299
+ {
300
+ return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL);
301
+ }
302
+ #endif
303
+
304
+ static FILE *stbiw__fopen(char const *filename, char const *mode)
305
+ {
306
+ FILE *f;
307
+ #if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8)
308
+ wchar_t wMode[64];
309
+ wchar_t wFilename[1024];
310
+ if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename)))
311
+ return 0;
312
+
313
+ if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode)))
314
+ return 0;
315
+
316
+ #if defined(_MSC_VER) && _MSC_VER >= 1400
317
+ if (0 != _wfopen_s(&f, wFilename, wMode))
318
+ f = 0;
319
+ #else
320
+ f = _wfopen(wFilename, wMode);
321
+ #endif
322
+
323
+ #elif defined(_MSC_VER) && _MSC_VER >= 1400
324
+ if (0 != fopen_s(&f, filename, mode))
325
+ f=0;
326
+ #else
327
+ f = fopen(filename, mode);
328
+ #endif
329
+ return f;
330
+ }
331
+
332
+ static int stbi__start_write_file(stbi__write_context *s, const char *filename)
333
+ {
334
+ FILE *f = stbiw__fopen(filename, "wb");
335
+ stbi__start_write_callbacks(s, stbi__stdio_write, (void *) f);
336
+ return f != NULL;
337
+ }
338
+
339
+ static void stbi__end_write_file(stbi__write_context *s)
340
+ {
341
+ fclose((FILE *)s->context);
342
+ }
343
+
344
+ #endif // !STBI_WRITE_NO_STDIO
345
+
346
+ typedef unsigned int stbiw_uint32;
347
+ typedef int stb_image_write_test[sizeof(stbiw_uint32)==4 ? 1 : -1];
348
+
349
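+ // tiny vararg serializer used for the fixed-layout BMP/TGA headers below:
+ // '1'/'2'/'4' in fmt emit a 1/2/4-byte little-endian field from the next
+ // argument; spaces in fmt are ignored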
+ static void stbiw__writefv(stbi__write_context *s, const char *fmt, va_list v)
350
+ {
351
+ while (*fmt) {
352
+ switch (*fmt++) {
353
+ case ' ': break;
354
+ case '1': { unsigned char x = STBIW_UCHAR(va_arg(v, int));
355
+ s->func(s->context,&x,1);
356
+ break; }
357
+ case '2': { int x = va_arg(v,int);
358
+ unsigned char b[2];
359
+ b[0] = STBIW_UCHAR(x);
360
+ b[1] = STBIW_UCHAR(x>>8);
361
+ s->func(s->context,b,2);
362
+ break; }
363
+ case '4': { stbiw_uint32 x = va_arg(v,int);
364
+ unsigned char b[4];
365
+ b[0]=STBIW_UCHAR(x);
366
+ b[1]=STBIW_UCHAR(x>>8);
367
+ b[2]=STBIW_UCHAR(x>>16);
368
+ b[3]=STBIW_UCHAR(x>>24);
369
+ s->func(s->context,b,4);
370
+ break; }
371
+ default:
372
+ STBIW_ASSERT(0);
373
+ return;
374
+ }
375
+ }
376
+ }
377
+
378
+ static void stbiw__writef(stbi__write_context *s, const char *fmt, ...)
379
+ {
380
+ va_list v;
381
+ va_start(v, fmt);
382
+ stbiw__writefv(s, fmt, v);
383
+ va_end(v);
384
+ }
385
+
386
+ static void stbiw__write_flush(stbi__write_context *s)
387
+ {
388
+ if (s->buf_used) {
389
+ s->func(s->context, &s->buffer, s->buf_used);
390
+ s->buf_used = 0;
391
+ }
392
+ }
393
+
394
+ static void stbiw__putc(stbi__write_context *s, unsigned char c)
395
+ {
396
+ s->func(s->context, &c, 1);
397
+ }
398
+
399
+ static void stbiw__write1(stbi__write_context *s, unsigned char a)
400
+ {
401
+ if ((size_t)s->buf_used + 1 > sizeof(s->buffer))
402
+ stbiw__write_flush(s);
403
+ s->buffer[s->buf_used++] = a;
404
+ }
405
+
406
+ static void stbiw__write3(stbi__write_context *s, unsigned char a, unsigned char b, unsigned char c)
407
+ {
408
+ int n;
409
+ if ((size_t)s->buf_used + 3 > sizeof(s->buffer))
410
+ stbiw__write_flush(s);
411
+ n = s->buf_used;
412
+ s->buf_used = n+3;
413
+ s->buffer[n+0] = a;
414
+ s->buffer[n+1] = b;
415
+ s->buffer[n+2] = c;
416
+ }
417
+
418
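+ // write one pixel in the requested channel order; write_alpha < 0 emits the
+ // alpha byte before the color bytes, > 0 after, and 0 drops it (4-channel
+ // input is then composited against a pink background so RGB stays defined)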
+ static void stbiw__write_pixel(stbi__write_context *s, int rgb_dir, int comp, int write_alpha, int expand_mono, unsigned char *d)
419
+ {
420
+ unsigned char bg[3] = { 255, 0, 255}, px[3];
421
+ int k;
422
+
423
+ if (write_alpha < 0)
424
+ stbiw__write1(s, d[comp - 1]);
425
+
426
+ switch (comp) {
427
+ case 2: // 2 channels = mono + alpha; alpha is written separately, so same as the 1-channel case
428
+ case 1:
429
+ if (expand_mono)
430
+ stbiw__write3(s, d[0], d[0], d[0]); // monochrome bmp
431
+ else
432
+ stbiw__write1(s, d[0]); // monochrome TGA
433
+ break;
434
+ case 4:
435
+ if (!write_alpha) {
436
+ // composite against pink background
437
+ for (k = 0; k < 3; ++k)
438
+ px[k] = bg[k] + ((d[k] - bg[k]) * d[3]) / 255;
439
+ stbiw__write3(s, px[1 - rgb_dir], px[1], px[1 + rgb_dir]);
440
+ break;
441
+ }
442
+ /* FALLTHROUGH */
443
+ case 3:
444
+ stbiw__write3(s, d[1 - rgb_dir], d[1], d[1 + rgb_dir]);
445
+ break;
446
+ }
447
+ if (write_alpha > 0)
448
+ stbiw__write1(s, d[comp - 1]);
449
+ }
450
+
451
+ static void stbiw__write_pixels(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, void *data, int write_alpha, int scanline_pad, int expand_mono)
452
+ {
453
+ stbiw_uint32 zero = 0;
454
+ int i,j, j_end;
455
+
456
+ if (y <= 0)
457
+ return;
458
+
459
+ if (stbi__flip_vertically_on_write)
460
+ vdir *= -1;
461
+
462
+ if (vdir < 0) {
463
+ j_end = -1; j = y-1;
464
+ } else {
465
+ j_end = y; j = 0;
466
+ }
467
+
468
+ for (; j != j_end; j += vdir) {
469
+ for (i=0; i < x; ++i) {
470
+ unsigned char *d = (unsigned char *) data + (j*x+i)*comp;
471
+ stbiw__write_pixel(s, rgb_dir, comp, write_alpha, expand_mono, d);
472
+ }
473
+ stbiw__write_flush(s);
474
+ s->func(s->context, &zero, scanline_pad);
475
+ }
476
+ }
477
+
478
+ static int stbiw__outfile(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, int expand_mono, void *data, int alpha, int pad, const char *fmt, ...)
479
+ {
480
+ if (y < 0 || x < 0) {
481
+ return 0;
482
+ } else {
483
+ va_list v;
484
+ va_start(v, fmt);
485
+ stbiw__writefv(s, fmt, v);
486
+ va_end(v);
487
+ stbiw__write_pixels(s,rgb_dir,vdir,x,y,comp,data,alpha,pad, expand_mono);
488
+ return 1;
489
+ }
490
+ }
491
+
492
+ static int stbi_write_bmp_core(stbi__write_context *s, int x, int y, int comp, const void *data)
493
+ {
494
+ if (comp != 4) {
495
+ // write RGB bitmap
496
+ int pad = (-x*3) & 3;
497
+ return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *) data,0,pad,
498
+ "11 4 22 4" "4 44 22 444444",
499
+ 'B', 'M', 14+40+(x*3+pad)*y, 0,0, 14+40, // file header
500
+ 40, x,y, 1,24, 0,0,0,0,0,0); // bitmap header
501
+ } else {
502
+ // RGBA bitmaps need a v4 header
503
+ // use BI_BITFIELDS mode with 32bpp and alpha mask
504
+ // (straight BI_RGB with alpha mask doesn't work in most readers)
505
+ return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *)data,1,0,
506
+ "11 4 22 4" "4 44 22 444444 4444 4 444 444 444 444",
507
+ 'B', 'M', 14+108+x*y*4, 0, 0, 14+108, // file header
508
+ 108, x,y, 1,32, 3,0,0,0,0,0, 0xff0000,0xff00,0xff,0xff000000u, 0, 0,0,0, 0,0,0, 0,0,0, 0,0,0); // bitmap V4 header
509
+ }
510
+ }
511
+
512
+ STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data)
513
+ {
514
+ stbi__write_context s = { 0 };
515
+ stbi__start_write_callbacks(&s, func, context);
516
+ return stbi_write_bmp_core(&s, x, y, comp, data);
517
+ }
518
+
519
+ #ifndef STBI_WRITE_NO_STDIO
520
+ STBIWDEF int stbi_write_bmp(char const *filename, int x, int y, int comp, const void *data)
521
+ {
522
+ stbi__write_context s = { 0 };
523
+ if (stbi__start_write_file(&s,filename)) {
524
+ int r = stbi_write_bmp_core(&s, x, y, comp, data);
525
+ stbi__end_write_file(&s);
526
+ return r;
527
+ } else
528
+ return 0;
529
+ }
530
+ #endif //!STBI_WRITE_NO_STDIO
531
+
532
+ static int stbi_write_tga_core(stbi__write_context *s, int x, int y, int comp, void *data)
533
+ {
534
+ int has_alpha = (comp == 2 || comp == 4);
535
+ int colorbytes = has_alpha ? comp-1 : comp;
536
+ int format = colorbytes < 2 ? 3 : 2; // TGA image type: 2 = uncompressed true-color (RGB/RGBA), 3 = uncompressed grayscale (Y/YA)
537
+
538
+ if (y < 0 || x < 0)
539
+ return 0;
540
+
541
+ if (!stbi_write_tga_with_rle) {
542
+ return stbiw__outfile(s, -1, -1, x, y, comp, 0, (void *) data, has_alpha, 0,
543
+ "111 221 2222 11", 0, 0, format, 0, 0, 0, 0, 0, x, y, (colorbytes + has_alpha) * 8, has_alpha * 8);
544
+ } else {
545
+ int i,j,k;
546
+ int jend, jdir;
547
+
548
+ stbiw__writef(s, "111 221 2222 11", 0,0,format+8, 0,0,0, 0,0,x,y, (colorbytes + has_alpha) * 8, has_alpha * 8);
549
+
550
+ if (stbi__flip_vertically_on_write) {
551
+ j = 0;
552
+ jend = y;
553
+ jdir = 1;
554
+ } else {
555
+ j = y-1;
556
+ jend = -1;
557
+ jdir = -1;
558
+ }
559
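+ // TGA RLE packets: a header byte < 128 announces (header+1) literal
+ // pixels; a header byte >= 128 announces one pixel repeated (header-127)
+ // times -- which is what the len-1 / len-129 math below encodes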
+ for (; j != jend; j += jdir) {
560
+ unsigned char *row = (unsigned char *) data + j * x * comp;
561
+ int len;
562
+
563
+ for (i = 0; i < x; i += len) {
564
+ unsigned char *begin = row + i * comp;
565
+ int diff = 1;
566
+ len = 1;
567
+
568
+ if (i < x - 1) {
569
+ ++len;
570
+ diff = memcmp(begin, row + (i + 1) * comp, comp);
571
+ if (diff) {
572
+ const unsigned char *prev = begin;
573
+ for (k = i + 2; k < x && len < 128; ++k) {
574
+ if (memcmp(prev, row + k * comp, comp)) {
575
+ prev += comp;
576
+ ++len;
577
+ } else {
578
+ --len;
579
+ break;
580
+ }
581
+ }
582
+ } else {
583
+ for (k = i + 2; k < x && len < 128; ++k) {
584
+ if (!memcmp(begin, row + k * comp, comp)) {
585
+ ++len;
586
+ } else {
587
+ break;
588
+ }
589
+ }
590
+ }
591
+ }
592
+
593
+ if (diff) {
594
+ unsigned char header = STBIW_UCHAR(len - 1);
595
+ stbiw__write1(s, header);
596
+ for (k = 0; k < len; ++k) {
597
+ stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin + k * comp);
598
+ }
599
+ } else {
600
+ unsigned char header = STBIW_UCHAR(len - 129);
601
+ stbiw__write1(s, header);
602
+ stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin);
603
+ }
604
+ }
605
+ }
606
+ stbiw__write_flush(s);
607
+ }
608
+ return 1;
609
+ }
610
+
611
+ STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data)
612
+ {
613
+ stbi__write_context s = { 0 };
614
+ stbi__start_write_callbacks(&s, func, context);
615
+ return stbi_write_tga_core(&s, x, y, comp, (void *) data);
616
+ }
617
+
618
+ #ifndef STBI_WRITE_NO_STDIO
619
+ STBIWDEF int stbi_write_tga(char const *filename, int x, int y, int comp, const void *data)
620
+ {
621
+ stbi__write_context s = { 0 };
622
+ if (stbi__start_write_file(&s,filename)) {
623
+ int r = stbi_write_tga_core(&s, x, y, comp, (void *) data);
624
+ stbi__end_write_file(&s);
625
+ return r;
626
+ } else
627
+ return 0;
628
+ }
629
+ #endif
630
+
631
+ // *************************************************************************************************
632
+ // Radiance RGBE HDR writer
633
+ // by Baldur Karlsson
634
+
635
+ #define stbiw__max(a, b) ((a) > (b) ? (a) : (b))
636
+
637
+ #ifndef STBI_WRITE_NO_STDIO
638
+
639
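+ // pack one linear RGB triple into Radiance RGBE: a shared exponent (biased
+ // by 128) is taken from the largest component; each channel keeps 8 bits of
+ // mantissa, and values below 1e-32 collapse to all-zero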
+ static void stbiw__linear_to_rgbe(unsigned char *rgbe, float *linear)
640
+ {
641
+ int exponent;
642
+ float maxcomp = stbiw__max(linear[0], stbiw__max(linear[1], linear[2]));
643
+
644
+ if (maxcomp < 1e-32f) {
645
+ rgbe[0] = rgbe[1] = rgbe[2] = rgbe[3] = 0;
646
+ } else {
647
+ float normalize = (float) frexp(maxcomp, &exponent) * 256.0f/maxcomp;
648
+
649
+ rgbe[0] = (unsigned char)(linear[0] * normalize);
650
+ rgbe[1] = (unsigned char)(linear[1] * normalize);
651
+ rgbe[2] = (unsigned char)(linear[2] * normalize);
652
+ rgbe[3] = (unsigned char)(exponent + 128);
653
+ }
654
+ }
655
+
656
+ static void stbiw__write_run_data(stbi__write_context *s, int length, unsigned char databyte)
657
+ {
658
+ unsigned char lengthbyte = STBIW_UCHAR(length+128);
659
+ STBIW_ASSERT(length+128 <= 255);
660
+ s->func(s->context, &lengthbyte, 1);
661
+ s->func(s->context, &databyte, 1);
662
+ }
663
+
664
+ static void stbiw__write_dump_data(stbi__write_context *s, int length, unsigned char *data)
665
+ {
666
+ unsigned char lengthbyte = STBIW_UCHAR(length);
667
+ STBIW_ASSERT(length <= 128); // inconsistent with spec but consistent with official code
668
+ s->func(s->context, &lengthbyte, 1);
669
+ s->func(s->context, data, length);
670
+ }
671
+
672
+ static void stbiw__write_hdr_scanline(stbi__write_context *s, int width, int ncomp, unsigned char *scratch, float *scanline)
673
+ {
674
+ unsigned char scanlineheader[4] = { 2, 2, 0, 0 };
675
+ unsigned char rgbe[4];
676
+ float linear[3];
677
+ int x;
678
+
679
+ scanlineheader[2] = (width&0xff00)>>8;
680
+ scanlineheader[3] = (width&0x00ff);
681
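+ // the 2,2 prefix marks an adaptive-RLE scanline and the next two bytes
+ // carry the width, which is why RLE is skipped for width < 8 or >= 32768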
+
682
+ /* skip RLE for images too small or large */
683
+ if (width < 8 || width >= 32768) {
684
+ for (x=0; x < width; x++) {
685
+ switch (ncomp) {
686
+ case 4: /* fallthrough */
687
+ case 3: linear[2] = scanline[x*ncomp + 2];
688
+ linear[1] = scanline[x*ncomp + 1];
689
+ linear[0] = scanline[x*ncomp + 0];
690
+ break;
691
+ default:
692
+ linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0];
693
+ break;
694
+ }
695
+ stbiw__linear_to_rgbe(rgbe, linear);
696
+ s->func(s->context, rgbe, 4);
697
+ }
698
+ } else {
699
+ int c,r;
700
+ /* encode into scratch buffer */
701
+ for (x=0; x < width; x++) {
702
+ switch(ncomp) {
703
+ case 4: /* fallthrough */
704
+ case 3: linear[2] = scanline[x*ncomp + 2];
705
+ linear[1] = scanline[x*ncomp + 1];
706
+ linear[0] = scanline[x*ncomp + 0];
707
+ break;
708
+ default:
709
+ linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0];
710
+ break;
711
+ }
712
+ stbiw__linear_to_rgbe(rgbe, linear);
713
+ scratch[x + width*0] = rgbe[0];
714
+ scratch[x + width*1] = rgbe[1];
715
+ scratch[x + width*2] = rgbe[2];
716
+ scratch[x + width*3] = rgbe[3];
717
+ }
718
+
719
+ s->func(s->context, scanlineheader, 4);
720
+
721
+ /* RLE each component separately */
722
+ for (c=0; c < 4; c++) {
723
+ unsigned char *comp = &scratch[width*c];
724
+
725
+ x = 0;
726
+ while (x < width) {
727
+ // find first run
728
+ r = x;
729
+ while (r+2 < width) {
730
+ if (comp[r] == comp[r+1] && comp[r] == comp[r+2])
731
+ break;
732
+ ++r;
733
+ }
734
+ if (r+2 >= width)
735
+ r = width;
736
+ // dump up to first run
737
+ while (x < r) {
738
+ int len = r-x;
739
+ if (len > 128) len = 128;
740
+ stbiw__write_dump_data(s, len, &comp[x]);
741
+ x += len;
742
+ }
743
+ // if there's a run, output it
744
+ if (r+2 < width) { // same condition the search loop breaks on, so only true if a run was actually found
745
+ // find next byte after run
746
+ while (r < width && comp[r] == comp[x])
747
+ ++r;
748
+ // output run up to r
749
+ while (x < r) {
750
+ int len = r-x;
751
+ if (len > 127) len = 127;
752
+ stbiw__write_run_data(s, len, comp[x]);
753
+ x += len;
754
+ }
755
+ }
756
+ }
757
+ }
758
+ }
759
+ }
760
+
761
+ static int stbi_write_hdr_core(stbi__write_context *s, int x, int y, int comp, float *data)
762
+ {
763
+ if (y <= 0 || x <= 0 || data == NULL)
764
+ return 0;
765
+ else {
766
+ // Each component is stored separately. Allocate scratch space for full output scanline.
767
+ unsigned char *scratch = (unsigned char *) STBIW_MALLOC(x*4);
768
+ int i, len;
769
+ char buffer[128];
770
+ char header[] = "#?RADIANCE\n# Written by stb_image_write.h\nFORMAT=32-bit_rle_rgbe\n";
771
+ s->func(s->context, header, sizeof(header)-1);
772
+
773
+ #ifdef __STDC_LIB_EXT1__
774
+ len = sprintf_s(buffer, sizeof(buffer), "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x);
775
+ #else
776
+ len = sprintf(buffer, "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x);
777
+ #endif
778
+ s->func(s->context, buffer, len);
779
+
780
+ for(i=0; i < y; i++)
781
+ stbiw__write_hdr_scanline(s, x, comp, scratch, data + comp*x*(stbi__flip_vertically_on_write ? y-1-i : i));
782
+ STBIW_FREE(scratch);
783
+ return 1;
784
+ }
785
+ }
786
+
787
+ STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const float *data)
788
+ {
789
+ stbi__write_context s = { 0 };
790
+ stbi__start_write_callbacks(&s, func, context);
791
+ return stbi_write_hdr_core(&s, x, y, comp, (float *) data);
792
+ }
793
+
794
+ STBIWDEF int stbi_write_hdr(char const *filename, int x, int y, int comp, const float *data)
795
+ {
796
+ stbi__write_context s = { 0 };
797
+ if (stbi__start_write_file(&s,filename)) {
798
+ int r = stbi_write_hdr_core(&s, x, y, comp, (float *) data);
799
+ stbi__end_write_file(&s);
800
+ return r;
801
+ } else
802
+ return 0;
803
+ }
804
+ #endif // STBI_WRITE_NO_STDIO
805
+
806
+
807
+ //////////////////////////////////////////////////////////////////////////////
808
+ //
809
+ // PNG writer
810
+ //
811
+
812
+ #ifndef STBIW_ZLIB_COMPRESS
813
+ // stretchy buffer; stbiw__sbpush() == vector<>::push_back() -- stbiw__sbcount() == vector<>::size()
814
+ #define stbiw__sbraw(a) ((int *) (void *) (a) - 2)
815
+ #define stbiw__sbm(a) stbiw__sbraw(a)[0]
816
+ #define stbiw__sbn(a) stbiw__sbraw(a)[1]
817
+
818
+ #define stbiw__sbneedgrow(a,n) ((a)==0 || stbiw__sbn(a)+n >= stbiw__sbm(a))
819
+ #define stbiw__sbmaybegrow(a,n) (stbiw__sbneedgrow(a,(n)) ? stbiw__sbgrow(a,n) : 0)
820
+ #define stbiw__sbgrow(a,n) stbiw__sbgrowf((void **) &(a), (n), sizeof(*(a)))
821
+
822
+ #define stbiw__sbpush(a, v) (stbiw__sbmaybegrow(a,1), (a)[stbiw__sbn(a)++] = (v))
823
+ #define stbiw__sbcount(a) ((a) ? stbiw__sbn(a) : 0)
824
+ #define stbiw__sbfree(a) ((a) ? STBIW_FREE(stbiw__sbraw(a)),0 : 0)
825
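+ // layout: two ints ([0] = capacity, [1] = count) live immediately before the
+ // user-visible array pointer; stbiw__sbraw() steps back onto that header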
+
826
+ static void *stbiw__sbgrowf(void **arr, int increment, int itemsize)
827
+ {
828
+ int m = *arr ? 2*stbiw__sbm(*arr)+increment : increment+1;
829
+ void *p = STBIW_REALLOC_SIZED(*arr ? stbiw__sbraw(*arr) : 0, *arr ? (stbiw__sbm(*arr)*itemsize + sizeof(int)*2) : 0, itemsize * m + sizeof(int)*2);
830
+ STBIW_ASSERT(p);
831
+ if (p) {
832
+ if (!*arr) ((int *) p)[1] = 0;
833
+ *arr = (void *) ((int *) p + 2);
834
+ stbiw__sbm(*arr) = m;
835
+ }
836
+ return *arr;
837
+ }
838
+
839
+ static unsigned char *stbiw__zlib_flushf(unsigned char *data, unsigned int *bitbuffer, int *bitcount)
840
+ {
841
+ while (*bitcount >= 8) {
842
+ stbiw__sbpush(data, STBIW_UCHAR(*bitbuffer));
843
+ *bitbuffer >>= 8;
844
+ *bitcount -= 8;
845
+ }
846
+ return data;
847
+ }
848
+
849
+ static int stbiw__zlib_bitrev(int code, int codebits)
850
+ {
851
+ int res=0;
852
+ while (codebits--) {
853
+ res = (res << 1) | (code & 1);
854
+ code >>= 1;
855
+ }
856
+ return res;
857
+ }
858
+
859
+ static unsigned int stbiw__zlib_countm(unsigned char *a, unsigned char *b, int limit)
860
+ {
861
+ int i;
862
+ for (i=0; i < limit && i < 258; ++i)
863
+ if (a[i] != b[i]) break;
864
+ return i;
865
+ }
866
+
867
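+ // integer mix of the next 3 bytes (deflate's minimum match length); used to
+ // index the table of previous positions that might start a match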
+ static unsigned int stbiw__zhash(unsigned char *data)
868
+ {
869
+ stbiw_uint32 hash = data[0] + (data[1] << 8) + (data[2] << 16);
870
+ hash ^= hash << 3;
871
+ hash += hash >> 5;
872
+ hash ^= hash << 4;
873
+ hash += hash >> 17;
874
+ hash ^= hash << 25;
875
+ hash += hash >> 6;
876
+ return hash;
877
+ }
878
+
879
+ #define stbiw__zlib_flush() (out = stbiw__zlib_flushf(out, &bitbuf, &bitcount))
880
+ #define stbiw__zlib_add(code,codebits) \
881
+ (bitbuf |= (code) << bitcount, bitcount += (codebits), stbiw__zlib_flush())
882
+ #define stbiw__zlib_huffa(b,c) stbiw__zlib_add(stbiw__zlib_bitrev(b,c),c)
883
+ // default huffman tables
884
+ #define stbiw__zlib_huff1(n) stbiw__zlib_huffa(0x30 + (n), 8)
885
+ #define stbiw__zlib_huff2(n) stbiw__zlib_huffa(0x190 + (n)-144, 9)
886
+ #define stbiw__zlib_huff3(n) stbiw__zlib_huffa(0 + (n)-256,7)
887
+ #define stbiw__zlib_huff4(n) stbiw__zlib_huffa(0xc0 + (n)-280,8)
888
+ #define stbiw__zlib_huff(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : (n) <= 255 ? stbiw__zlib_huff2(n) : (n) <= 279 ? stbiw__zlib_huff3(n) : stbiw__zlib_huff4(n))
889
+ #define stbiw__zlib_huffb(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : stbiw__zlib_huff2(n))
890
+
891
+ #define stbiw__ZHASH 16384
892
+
893
+ #endif // STBIW_ZLIB_COMPRESS
894
+
895
+ STBIWDEF unsigned char * stbi_zlib_compress(unsigned char *data, int data_len, int *out_len, int quality)
896
+ {
897
+ #ifdef STBIW_ZLIB_COMPRESS
898
+ // user provided a zlib compress implementation, use that
899
+ return STBIW_ZLIB_COMPRESS(data, data_len, out_len, quality);
900
+ #else // use builtin
901
+ static unsigned short lengthc[] = { 3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258, 259 };
902
+ static unsigned char lengtheb[]= { 0,0,0,0,0,0,0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 };
903
+ static unsigned short distc[] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577, 32768 };
904
+ static unsigned char disteb[] = { 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13 };
905
+ unsigned int bitbuf=0;
906
+ int i,j, bitcount=0;
907
+ unsigned char *out = NULL;
908
+ unsigned char ***hash_table = (unsigned char***) STBIW_MALLOC(stbiw__ZHASH * sizeof(unsigned char**));
909
+ if (hash_table == NULL)
910
+ return NULL;
911
+ if (quality < 5) quality = 5;
912
+
913
+ stbiw__sbpush(out, 0x78); // DEFLATE 32K window
914
+ stbiw__sbpush(out, 0x5e); // FLEVEL = 1
915
+ stbiw__zlib_add(1,1); // BFINAL = 1
916
+ stbiw__zlib_add(1,2); // BTYPE = 1 -- fixed huffman
917
+
918
+ for (i=0; i < stbiw__ZHASH; ++i)
919
+ hash_table[i] = NULL;
920
+
921
+ i=0;
922
+ while (i < data_len-3) {
923
+ // hash next 3 bytes of data to be compressed
924
+ int h = stbiw__zhash(data+i)&(stbiw__ZHASH-1), best=3;
925
+ unsigned char *bestloc = 0;
926
+ unsigned char **hlist = hash_table[h];
927
+ int n = stbiw__sbcount(hlist);
928
+ for (j=0; j < n; ++j) {
929
+ if (hlist[j]-data > i-32768) { // if entry lies within window
930
+ int d = stbiw__zlib_countm(hlist[j], data+i, data_len-i);
931
+ if (d >= best) { best=d; bestloc=hlist[j]; }
932
+ }
933
+ }
934
+ // when hash table entry is too long, delete half the entries
935
+ if (hash_table[h] && stbiw__sbn(hash_table[h]) == 2*quality) {
936
+ STBIW_MEMMOVE(hash_table[h], hash_table[h]+quality, sizeof(hash_table[h][0])*quality);
937
+ stbiw__sbn(hash_table[h]) = quality;
938
+ }
939
+ stbiw__sbpush(hash_table[h],data+i);
940
+
941
+ if (bestloc) {
942
+ // "lazy matching" - check match at *next* byte, and if it's better, do cur byte as literal
943
+ h = stbiw__zhash(data+i+1)&(stbiw__ZHASH-1);
944
+ hlist = hash_table[h];
945
+ n = stbiw__sbcount(hlist);
946
+ for (j=0; j < n; ++j) {
947
+ if (hlist[j]-data > i-32767) {
948
+ int e = stbiw__zlib_countm(hlist[j], data+i+1, data_len-i-1);
949
+ if (e > best) { // if next match is better, bail on current match
950
+ bestloc = NULL;
951
+ break;
952
+ }
953
+ }
954
+ }
955
+ }
956
+
957
+ if (bestloc) {
958
+ int d = (int) (data+i - bestloc); // distance back
959
+ STBIW_ASSERT(d <= 32767 && best <= 258);
960
+ for (j=0; best > lengthc[j+1]-1; ++j);
961
+ stbiw__zlib_huff(j+257);
962
+ if (lengtheb[j]) stbiw__zlib_add(best - lengthc[j], lengtheb[j]);
963
+ for (j=0; d > distc[j+1]-1; ++j);
964
+ stbiw__zlib_add(stbiw__zlib_bitrev(j,5),5);
965
+ if (disteb[j]) stbiw__zlib_add(d - distc[j], disteb[j]);
966
+ i += best;
967
+ } else {
968
+ stbiw__zlib_huffb(data[i]);
969
+ ++i;
970
+ }
971
+ }
972
+ // write out final bytes
973
+ for (;i < data_len; ++i)
974
+ stbiw__zlib_huffb(data[i]);
975
+ stbiw__zlib_huff(256); // end of block
976
+ // pad with 0 bits to byte boundary
977
+ while (bitcount)
978
+ stbiw__zlib_add(0,1);
979
+
980
+ for (i=0; i < stbiw__ZHASH; ++i)
981
+ (void) stbiw__sbfree(hash_table[i]);
982
+ STBIW_FREE(hash_table);
983
+
984
+ // store uncompressed instead if compression was worse
985
+ if (stbiw__sbn(out) > data_len + 2 + ((data_len+32766)/32767)*5) {
986
+ stbiw__sbn(out) = 2; // truncate to DEFLATE 32K window and FLEVEL = 1
987
+ for (j = 0; j < data_len;) {
988
+ int blocklen = data_len - j;
989
+ if (blocklen > 32767) blocklen = 32767;
990
+ stbiw__sbpush(out, data_len - j == blocklen); // BFINAL = ?, BTYPE = 0 -- no compression
991
+ stbiw__sbpush(out, STBIW_UCHAR(blocklen)); // LEN
992
+ stbiw__sbpush(out, STBIW_UCHAR(blocklen >> 8));
993
+ stbiw__sbpush(out, STBIW_UCHAR(~blocklen)); // NLEN
994
+ stbiw__sbpush(out, STBIW_UCHAR(~blocklen >> 8));
995
+ memcpy(out+stbiw__sbn(out), data+j, blocklen);
996
+ stbiw__sbn(out) += blocklen;
997
+ j += blocklen;
998
+ }
999
+ }
1000
+
1001
+ {
1002
+ // compute adler32 on input
1003
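+ // adler32: s1 = 1 + sum of bytes, s2 = sum of the running s1 values, both
+ // mod 65521; 5552 is the largest block size whose sums cannot overflow
+ // 32 bits before the modulo is applied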
+ unsigned int s1=1, s2=0;
1004
+ int blocklen = (int) (data_len % 5552);
1005
+ j=0;
1006
+ while (j < data_len) {
1007
+ for (i=0; i < blocklen; ++i) { s1 += data[j+i]; s2 += s1; }
1008
+ s1 %= 65521; s2 %= 65521;
1009
+ j += blocklen;
1010
+ blocklen = 5552;
1011
+ }
1012
+ stbiw__sbpush(out, STBIW_UCHAR(s2 >> 8));
1013
+ stbiw__sbpush(out, STBIW_UCHAR(s2));
1014
+ stbiw__sbpush(out, STBIW_UCHAR(s1 >> 8));
1015
+ stbiw__sbpush(out, STBIW_UCHAR(s1));
1016
+ }
1017
+ *out_len = stbiw__sbn(out);
1018
+ // make returned pointer freeable
1019
+ STBIW_MEMMOVE(stbiw__sbraw(out), out, *out_len);
1020
+ return (unsigned char *) stbiw__sbraw(out);
1021
+ #endif // STBIW_ZLIB_COMPRESS
1022
+ }
1023
+
1024
+ static unsigned int stbiw__crc32(unsigned char *buffer, int len)
1025
+ {
1026
+ #ifdef STBIW_CRC32
1027
+ return STBIW_CRC32(buffer, len);
1028
+ #else
1029
+ static unsigned int crc_table[256] =
1030
+ {
1031
+ 0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, 0x9E6495A3,
1032
+ 0x0eDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91,
1033
+ 0x1DB71064, 0x6AB020F2, 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7,
1034
+ 0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5,
1035
+ 0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B,
1036
+ 0x35B5A8FA, 0x42B2986C, 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59,
1037
+ 0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, 0xCFBA9599, 0xB8BDA50F,
1038
+ 0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D,
1039
+ 0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433,
1040
+ 0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, 0x91646C97, 0xE6635C01,
1041
+ 0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457,
1042
+ 0x65B0D9C6, 0x12B7E950, 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65,
1043
+ 0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB,
1044
+ 0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9,
1045
+ 0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F,
1046
+ 0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, 0xB7BD5C3B, 0xC0BA6CAD,
1047
+ 0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683,
1048
+ 0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1,
1049
+ 0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7,
1050
+ 0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5,
1051
+ 0xD6D6A3E8, 0xA1D1937E, 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B,
1052
+ 0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, 0x4669BE79,
1053
+ 0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, 0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F,
1054
+ 0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D,
1055
+ 0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, 0x72076785, 0x05005713,
1056
+ 0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21,
1057
+ 0x86D3D2D4, 0xF1D4E242, 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777,
1058
+ 0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45,
1059
+ 0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB,
1060
+ 0xAED16A4A, 0xD9D65ADC, 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9,
1061
+ 0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693, 0x54DE5729, 0x23D967BF,
1062
+ 0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D
1063
+ };
1064
+
1065
+ unsigned int crc = ~0u;
1066
+ int i;
1067
+ for (i=0; i < len; ++i)
1068
+ crc = (crc >> 8) ^ crc_table[buffer[i] ^ (crc & 0xff)];
1069
+ return ~crc;
1070
+ #endif
1071
+ }
1072
+
1073
+ #define stbiw__wpng4(o,a,b,c,d) ((o)[0]=STBIW_UCHAR(a),(o)[1]=STBIW_UCHAR(b),(o)[2]=STBIW_UCHAR(c),(o)[3]=STBIW_UCHAR(d),(o)+=4)
1074
+ #define stbiw__wp32(data,v) stbiw__wpng4(data, (v)>>24,(v)>>16,(v)>>8,(v));
1075
+ #define stbiw__wptag(data,s) stbiw__wpng4(data, s[0],s[1],s[2],s[3])
1076
+
1077
+ static void stbiw__wpcrc(unsigned char **data, int len)
1078
+ {
1079
+ unsigned int crc = stbiw__crc32(*data - len - 4, len+4);
1080
+ stbiw__wp32(*data, crc);
1081
+ }
1082
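+ // PNG Paeth predictor: return whichever of left (a), up (b), up-left (c) is
+ // closest to the linear estimate p = a + b - c, preferring a, then b, on ties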
+
1083
+ static unsigned char stbiw__paeth(int a, int b, int c)
1084
+ {
1085
+ int p = a + b - c, pa = abs(p-a), pb = abs(p-b), pc = abs(p-c);
1086
+ if (pa <= pb && pa <= pc) return STBIW_UCHAR(a);
1087
+ if (pb <= pc) return STBIW_UCHAR(b);
1088
+ return STBIW_UCHAR(c);
1089
+ }
1090
+
1091
+ // @OPTIMIZE: provide an option that always forces left-predict or paeth predict
1092
+ static void stbiw__encode_png_line(unsigned char *pixels, int stride_bytes, int width, int height, int y, int n, int filter_type, signed char *line_buffer)
1093
+ {
1094
+ static int mapping[] = { 0,1,2,3,4 };
1095
+ static int firstmap[] = { 0,1,0,5,6 };
1096
+ int *mymap = (y != 0) ? mapping : firstmap;
1097
+ int i;
1098
+ int type = mymap[filter_type];
1099
+ unsigned char *z = pixels + stride_bytes * (stbi__flip_vertically_on_write ? height-1-y : y);
1100
+ int signed_stride = stbi__flip_vertically_on_write ? -stride_bytes : stride_bytes;
1101
+
1102
+ if (type==0) {
1103
+ memcpy(line_buffer, z, width*n);
1104
+ return;
1105
+ }
1106
+
1107
+ // first loop isn't optimized since it's just one pixel
1108
+ for (i = 0; i < n; ++i) {
1109
+ switch (type) {
1110
+ case 1: line_buffer[i] = z[i]; break;
1111
+ case 2: line_buffer[i] = z[i] - z[i-signed_stride]; break;
1112
+ case 3: line_buffer[i] = z[i] - (z[i-signed_stride]>>1); break;
1113
+ case 4: line_buffer[i] = (signed char) (z[i] - stbiw__paeth(0,z[i-signed_stride],0)); break;
1114
+ case 5: line_buffer[i] = z[i]; break;
1115
+ case 6: line_buffer[i] = z[i]; break;
1116
+ }
1117
+ }
1118
+ switch (type) {
1119
+ case 1: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-n]; break;
1120
+ case 2: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-signed_stride]; break;
1121
+ case 3: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - ((z[i-n] + z[i-signed_stride])>>1); break;
1122
+ case 4: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], z[i-signed_stride], z[i-signed_stride-n]); break;
1123
+ case 5: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - (z[i-n]>>1); break;
1124
+ case 6: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], 0,0); break;
1125
+ }
1126
+ }
1127
+
1128
+ STBIWDEF unsigned char *stbi_write_png_to_mem(const unsigned char *pixels, int stride_bytes, int x, int y, int n, int *out_len)
1129
+ {
1130
+ int force_filter = stbi_write_force_png_filter;
1131
+ int ctype[5] = { -1, 0, 4, 2, 6 };
1132
+ unsigned char sig[8] = { 137,80,78,71,13,10,26,10 };
1133
+ unsigned char *out,*o, *filt, *zlib;
1134
+ signed char *line_buffer;
1135
+ int j,zlen;
1136
+
1137
+ if (stride_bytes == 0)
1138
+ stride_bytes = x * n;
1139
+
1140
+ if (force_filter >= 5) {
1141
+ force_filter = -1;
1142
+ }
1143
+
1144
+ filt = (unsigned char *) STBIW_MALLOC((x*n+1) * y); if (!filt) return 0;
1145
+ line_buffer = (signed char *) STBIW_MALLOC(x * n); if (!line_buffer) { STBIW_FREE(filt); return 0; }
1146
+ for (j=0; j < y; ++j) {
1147
+ int filter_type;
1148
+ if (force_filter > -1) {
1149
+ filter_type = force_filter;
1150
+ stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, force_filter, line_buffer);
1151
+ } else { // Estimate the best filter by running through all of them:
1152
+ int best_filter = 0, best_filter_val = 0x7fffffff, est, i;
1153
+ for (filter_type = 0; filter_type < 5; filter_type++) {
1154
+ stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, filter_type, line_buffer);
1155
+
1156
+ // Score the filtered line by the sum of absolute filtered values; lower predicts better compression (a sum-of-absolute-differences heuristic, not true entropy).
1157
+ est = 0;
1158
+ for (i = 0; i < x*n; ++i) {
1159
+ est += abs((signed char) line_buffer[i]);
1160
+ }
1161
+ if (est < best_filter_val) {
1162
+ best_filter_val = est;
1163
+ best_filter = filter_type;
1164
+ }
1165
+ }
1166
+ if (filter_type != best_filter) { // If the last iteration already got us the best filter, don't redo it
1167
+ stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, best_filter, line_buffer);
1168
+ filter_type = best_filter;
1169
+ }
1170
+ }
1171
+ // when we get here, filter_type contains the filter type, and line_buffer contains the data
1172
+ filt[j*(x*n+1)] = (unsigned char) filter_type;
1173
+ STBIW_MEMMOVE(filt+j*(x*n+1)+1, line_buffer, x*n);
1174
+ }
1175
+ STBIW_FREE(line_buffer);
1176
+ zlib = stbi_zlib_compress(filt, y*( x*n+1), &zlen, stbi_write_png_compression_level);
1177
+ STBIW_FREE(filt);
1178
+ if (!zlib) return 0;
1179
+
1180
+ // each tag requires 12 bytes of overhead
1181
+ out = (unsigned char *) STBIW_MALLOC(8 + 12+13 + 12+zlen + 12);
1182
+ if (!out) return 0;
1183
+ *out_len = 8 + 12+13 + 12+zlen + 12;
1184
+
1185
+ o=out;
1186
+ STBIW_MEMMOVE(o,sig,8); o+= 8;
1187
+ stbiw__wp32(o, 13); // header length
1188
+ stbiw__wptag(o, "IHDR");
1189
+ stbiw__wp32(o, x);
1190
+ stbiw__wp32(o, y);
1191
+ *o++ = 8;
1192
+ *o++ = STBIW_UCHAR(ctype[n]);
1193
+ *o++ = 0;
1194
+ *o++ = 0;
1195
+ *o++ = 0;
1196
+ stbiw__wpcrc(&o,13);
1197
+
1198
+ stbiw__wp32(o, zlen);
1199
+ stbiw__wptag(o, "IDAT");
1200
+ STBIW_MEMMOVE(o, zlib, zlen);
1201
+ o += zlen;
1202
+ STBIW_FREE(zlib);
1203
+ stbiw__wpcrc(&o, zlen);
1204
+
1205
+ stbiw__wp32(o,0);
1206
+ stbiw__wptag(o, "IEND");
1207
+ stbiw__wpcrc(&o,0);
1208
+
1209
+ STBIW_ASSERT(o == out + *out_len);
1210
+
1211
+ return out;
1212
+ }
1213
+
1214
+ #ifndef STBI_WRITE_NO_STDIO
1215
+ STBIWDEF int stbi_write_png(char const *filename, int x, int y, int comp, const void *data, int stride_bytes)
1216
+ {
1217
+ FILE *f;
1218
+ int len;
1219
+ unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len);
1220
+ if (png == NULL) return 0;
1221
+
1222
+ f = stbiw__fopen(filename, "wb");
1223
+ if (!f) { STBIW_FREE(png); return 0; }
1224
+ fwrite(png, 1, len, f);
1225
+ fclose(f);
1226
+ STBIW_FREE(png);
1227
+ return 1;
1228
+ }
1229
+ #endif
1230
+
1231
+ STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int stride_bytes)
1232
+ {
1233
+ int len;
1234
+ unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len);
1235
+ if (png == NULL) return 0;
1236
+ func(context, png, len);
1237
+ STBIW_FREE(png);
1238
+ return 1;
1239
+ }
1240
+
1241
+
1242
+ /* ***************************************************************************
1243
+ *
1244
+ * JPEG writer
1245
+ *
1246
+ * This is based on Jon Olick's jo_jpeg.cpp:
1247
+ * public domain Simple, Minimalistic JPEG writer - http://www.jonolick.com/code.html
1248
+ */
1249
+
1250
+ static const unsigned char stbiw__jpg_ZigZag[] = { 0,1,5,6,14,15,27,28,2,4,7,13,16,26,29,42,3,8,12,17,25,30,41,43,9,11,18,
1251
+ 24,31,40,44,53,10,19,23,32,39,45,52,54,20,22,33,38,46,51,55,60,21,34,37,47,50,56,59,61,35,36,48,49,57,58,62,63 };
1252
+
1253
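+ // flush whole bytes from the MSB-first bit accumulator; a 0x00 stuff byte
+ // follows every 0xFF so the entropy-coded stream cannot alias a JPEG marker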
+ static void stbiw__jpg_writeBits(stbi__write_context *s, int *bitBufP, int *bitCntP, const unsigned short *bs) {
1254
+ int bitBuf = *bitBufP, bitCnt = *bitCntP;
1255
+ bitCnt += bs[1];
1256
+ bitBuf |= bs[0] << (24 - bitCnt);
1257
+ while(bitCnt >= 8) {
1258
+ unsigned char c = (bitBuf >> 16) & 255;
1259
+ stbiw__putc(s, c);
1260
+ if(c == 255) {
1261
+ stbiw__putc(s, 0);
1262
+ }
1263
+ bitBuf <<= 8;
1264
+ bitCnt -= 8;
1265
+ }
1266
+ *bitBufP = bitBuf;
1267
+ *bitCntP = bitCnt;
1268
+ }
1269
+
1270
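+ // in-place 8-point 1-D float DCT (AAN-style factorization); the per-
+ // coefficient scale factors are folded into the quantization table (fdtbl)
+ // via the aasf[] constants rather than applied here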
+ static void stbiw__jpg_DCT(float *d0p, float *d1p, float *d2p, float *d3p, float *d4p, float *d5p, float *d6p, float *d7p) {
1271
+ float d0 = *d0p, d1 = *d1p, d2 = *d2p, d3 = *d3p, d4 = *d4p, d5 = *d5p, d6 = *d6p, d7 = *d7p;
1272
+ float z1, z2, z3, z4, z5, z11, z13;
1273
+
1274
+ float tmp0 = d0 + d7;
1275
+ float tmp7 = d0 - d7;
1276
+ float tmp1 = d1 + d6;
1277
+ float tmp6 = d1 - d6;
1278
+ float tmp2 = d2 + d5;
1279
+ float tmp5 = d2 - d5;
1280
+ float tmp3 = d3 + d4;
1281
+ float tmp4 = d3 - d4;
1282
+
1283
+ // Even part
1284
+ float tmp10 = tmp0 + tmp3; // phase 2
1285
+ float tmp13 = tmp0 - tmp3;
1286
+ float tmp11 = tmp1 + tmp2;
1287
+ float tmp12 = tmp1 - tmp2;
1288
+
1289
+ d0 = tmp10 + tmp11; // phase 3
1290
+ d4 = tmp10 - tmp11;
1291
+
1292
+ z1 = (tmp12 + tmp13) * 0.707106781f; // c4
1293
+ d2 = tmp13 + z1; // phase 5
1294
+ d6 = tmp13 - z1;
1295
+
1296
+ // Odd part
1297
+ tmp10 = tmp4 + tmp5; // phase 2
1298
+ tmp11 = tmp5 + tmp6;
1299
+ tmp12 = tmp6 + tmp7;
1300
+
1301
+ // The rotator is modified from fig 4-8 to avoid extra negations.
1302
+ z5 = (tmp10 - tmp12) * 0.382683433f; // c6
1303
+ z2 = tmp10 * 0.541196100f + z5; // c2-c6
1304
+ z4 = tmp12 * 1.306562965f + z5; // c2+c6
1305
+ z3 = tmp11 * 0.707106781f; // c4
1306
+
1307
+ z11 = tmp7 + z3; // phase 5
1308
+ z13 = tmp7 - z3;
1309
+
1310
+ *d5p = z13 + z2; // phase 6
1311
+ *d3p = z13 - z2;
1312
+ *d1p = z11 + z4;
1313
+ *d7p = z11 - z4;
1314
+
1315
+ *d0p = d0; *d2p = d2; *d4p = d4; *d6p = d6;
1316
+ }
1317
+
1318
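+ // JPEG magnitude coding: bits[1] = category (bit count of |val|) and
+ // bits[0] = the amplitude bits, with negative values stored as val-1
+ // masked to the category width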
+ static void stbiw__jpg_calcBits(int val, unsigned short bits[2]) {
1319
+ int tmp1 = val < 0 ? -val : val;
1320
+ val = val < 0 ? val-1 : val;
1321
+ bits[1] = 1;
1322
+ while(tmp1 >>= 1) {
1323
+ ++bits[1];
1324
+ }
1325
+ bits[0] = val & ((1<<bits[1])-1);
1326
+ }
1327
+
1328
+ static int stbiw__jpg_processDU(stbi__write_context *s, int *bitBuf, int *bitCnt, float *CDU, int du_stride, float *fdtbl, int DC, const unsigned short HTDC[256][2], const unsigned short HTAC[256][2]) {
1329
+ const unsigned short EOB[2] = { HTAC[0x00][0], HTAC[0x00][1] };
1330
+ const unsigned short M16zeroes[2] = { HTAC[0xF0][0], HTAC[0xF0][1] };
1331
+ int dataOff, i, j, n, diff, end0pos, x, y;
1332
+ int DU[64];
1333
+
1334
+ // DCT rows
1335
+ for(dataOff=0, n=du_stride*8; dataOff<n; dataOff+=du_stride) {
1336
+ stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+1], &CDU[dataOff+2], &CDU[dataOff+3], &CDU[dataOff+4], &CDU[dataOff+5], &CDU[dataOff+6], &CDU[dataOff+7]);
1337
+ }
1338
+ // DCT columns
1339
+ for(dataOff=0; dataOff<8; ++dataOff) {
1340
+ stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+du_stride], &CDU[dataOff+du_stride*2], &CDU[dataOff+du_stride*3], &CDU[dataOff+du_stride*4],
1341
+ &CDU[dataOff+du_stride*5], &CDU[dataOff+du_stride*6], &CDU[dataOff+du_stride*7]);
1342
+ }
1343
+ // Quantize/descale/zigzag the coefficients
1344
+ for(y = 0, j=0; y < 8; ++y) {
1345
+ for(x = 0; x < 8; ++x,++j) {
1346
+ float v;
1347
+ i = y*du_stride+x;
1348
+ v = CDU[i]*fdtbl[j];
1349
+ // DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? ceilf(v - 0.5f) : floorf(v + 0.5f));
1350
+ // ceilf() and floorf() are C99, not C89, but I /think/ they're not needed here anyway?
1351
+ DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? v - 0.5f : v + 0.5f);
1352
+ }
1353
+ }
1354
+
1355
+ // Encode DC
1356
+ diff = DU[0] - DC;
1357
+ if (diff == 0) {
1358
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[0]);
1359
+ } else {
1360
+ unsigned short bits[2];
1361
+ stbiw__jpg_calcBits(diff, bits);
1362
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[bits[1]]);
1363
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits);
1364
+ }
1365
+ // Encode ACs
1366
+ end0pos = 63;
1367
+ for(; (end0pos>0)&&(DU[end0pos]==0); --end0pos) {
1368
+ }
1369
+ // end0pos = first element in reverse order !=0
1370
+ if(end0pos == 0) {
1371
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB);
1372
+ return DU[0];
1373
+ }
1374
+ for(i = 1; i <= end0pos; ++i) {
1375
+ int startpos = i;
1376
+ int nrzeroes;
1377
+ unsigned short bits[2];
1378
+ for (; DU[i]==0 && i<=end0pos; ++i) {
1379
+ }
1380
+ nrzeroes = i-startpos;
1381
+ if ( nrzeroes >= 16 ) {
1382
+ int lng = nrzeroes>>4;
1383
+ int nrmarker;
1384
+ for (nrmarker=1; nrmarker <= lng; ++nrmarker)
1385
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, M16zeroes);
1386
+ nrzeroes &= 15;
1387
+ }
1388
+ stbiw__jpg_calcBits(DU[i], bits);
1389
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTAC[(nrzeroes<<4)+bits[1]]);
1390
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits);
1391
+ }
1392
+ if(end0pos != 63) {
1393
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB);
1394
+ }
1395
+ return DU[0];
1396
+ }
1397
+
1398
+ static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) {
1399
+ // Constants that don't pollute global namespace
1400
+ static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0};
1401
+ static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};
1402
+ static const unsigned char std_ac_luminance_nrcodes[] = {0,0,2,1,3,3,2,4,3,5,5,4,4,0,0,1,0x7d};
1403
+ static const unsigned char std_ac_luminance_values[] = {
1404
+ 0x01,0x02,0x03,0x00,0x04,0x11,0x05,0x12,0x21,0x31,0x41,0x06,0x13,0x51,0x61,0x07,0x22,0x71,0x14,0x32,0x81,0x91,0xa1,0x08,
1405
+ 0x23,0x42,0xb1,0xc1,0x15,0x52,0xd1,0xf0,0x24,0x33,0x62,0x72,0x82,0x09,0x0a,0x16,0x17,0x18,0x19,0x1a,0x25,0x26,0x27,0x28,
1406
+ 0x29,0x2a,0x34,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,0x59,
1407
+ 0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x83,0x84,0x85,0x86,0x87,0x88,0x89,
1408
+ 0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,0xb5,0xb6,
1409
+ 0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,0xe1,0xe2,
1410
+ 0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf1,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa
1411
+ };
1412
+ static const unsigned char std_dc_chrominance_nrcodes[] = {0,0,3,1,1,1,1,1,1,1,1,1,0,0,0,0,0};
1413
+ static const unsigned char std_dc_chrominance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};
1414
+ static const unsigned char std_ac_chrominance_nrcodes[] = {0,0,2,1,2,4,4,3,4,7,5,4,4,0,1,2,0x77};
1415
+ static const unsigned char std_ac_chrominance_values[] = {
1416
+ 0x00,0x01,0x02,0x03,0x11,0x04,0x05,0x21,0x31,0x06,0x12,0x41,0x51,0x07,0x61,0x71,0x13,0x22,0x32,0x81,0x08,0x14,0x42,0x91,
1417
+ 0xa1,0xb1,0xc1,0x09,0x23,0x33,0x52,0xf0,0x15,0x62,0x72,0xd1,0x0a,0x16,0x24,0x34,0xe1,0x25,0xf1,0x17,0x18,0x19,0x1a,0x26,
1418
+ 0x27,0x28,0x29,0x2a,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,
1419
+ 0x59,0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x82,0x83,0x84,0x85,0x86,0x87,
1420
+ 0x88,0x89,0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,
1421
+ 0xb5,0xb6,0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,
1422
+ 0xe2,0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa
1423
+ };
1424
+ // Huffman tables
1425
+ static const unsigned short YDC_HT[256][2] = { {0,2},{2,3},{3,3},{4,3},{5,3},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9}};
1426
+ static const unsigned short UVDC_HT[256][2] = { {0,2},{1,2},{2,2},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9},{1022,10},{2046,11}};
1427
+ static const unsigned short YAC_HT[256][2] = {
1428
+ {10,4},{0,2},{1,2},{4,3},{11,4},{26,5},{120,7},{248,8},{1014,10},{65410,16},{65411,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1429
+ {12,4},{27,5},{121,7},{502,9},{2038,11},{65412,16},{65413,16},{65414,16},{65415,16},{65416,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1430
+ {28,5},{249,8},{1015,10},{4084,12},{65417,16},{65418,16},{65419,16},{65420,16},{65421,16},{65422,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1431
+ {58,6},{503,9},{4085,12},{65423,16},{65424,16},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1432
+ {59,6},{1016,10},{65430,16},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1433
+ {122,7},{2039,11},{65438,16},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1434
+ {123,7},{4086,12},{65446,16},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1435
+ {250,8},{4087,12},{65454,16},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1436
+ {504,9},{32704,15},{65462,16},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1437
+ {505,9},{65470,16},{65471,16},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1438
+ {506,9},{65479,16},{65480,16},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1439
+ {1017,10},{65488,16},{65489,16},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1440
+ {1018,10},{65497,16},{65498,16},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1441
+ {2040,11},{65506,16},{65507,16},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1442
+ {65515,16},{65516,16},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{0,0},{0,0},{0,0},{0,0},{0,0},
1443
+ {2041,11},{65525,16},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0}
1444
+ };
1445
+ static const unsigned short UVAC_HT[256][2] = {
1446
+ {0,2},{1,2},{4,3},{10,4},{24,5},{25,5},{56,6},{120,7},{500,9},{1014,10},{4084,12},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1447
+ {11,4},{57,6},{246,8},{501,9},{2038,11},{4085,12},{65416,16},{65417,16},{65418,16},{65419,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1448
+ {26,5},{247,8},{1015,10},{4086,12},{32706,15},{65420,16},{65421,16},{65422,16},{65423,16},{65424,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1449
+ {27,5},{248,8},{1016,10},{4087,12},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{65430,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1450
+ {58,6},{502,9},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{65438,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1451
+ {59,6},{1017,10},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{65446,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1452
+ {121,7},{2039,11},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{65454,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1453
+ {122,7},{2040,11},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{65462,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1454
+ {249,8},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{65470,16},{65471,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1455
+ {503,9},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{65479,16},{65480,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1456
+ {504,9},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{65488,16},{65489,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1457
+ {505,9},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{65497,16},{65498,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1458
+ {506,9},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{65506,16},{65507,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1459
+ {2041,11},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{65515,16},{65516,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1460
+ {16352,14},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{65525,16},{0,0},{0,0},{0,0},{0,0},{0,0},
1461
+ {1018,10},{32707,15},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0}
1462
+ };
1463
+ static const int YQT[] = {16,11,10,16,24,40,51,61,12,12,14,19,26,58,60,55,14,13,16,24,40,57,69,56,14,17,22,29,51,87,80,62,18,22,
1464
+ 37,56,68,109,103,77,24,35,55,64,81,104,113,92,49,64,78,87,103,121,120,101,72,92,95,98,112,100,103,99};
1465
+ static const int UVQT[] = {17,18,24,47,99,99,99,99,18,21,26,66,99,99,99,99,24,26,56,99,99,99,99,99,47,66,99,99,99,99,99,99,
1466
+ 99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99};
1467
+ static const float aasf[] = { 1.0f * 2.828427125f, 1.387039845f * 2.828427125f, 1.306562965f * 2.828427125f, 1.175875602f * 2.828427125f,
1468
+ 1.0f * 2.828427125f, 0.785694958f * 2.828427125f, 0.541196100f * 2.828427125f, 0.275899379f * 2.828427125f };
1469
+
1470
+ int row, col, i, k, subsample;
1471
+ float fdtbl_Y[64], fdtbl_UV[64];
1472
+ unsigned char YTable[64], UVTable[64];
1473
+
1474
+ if(!data || !width || !height || comp > 4 || comp < 1) {
1475
+ return 0;
1476
+ }
1477
+
1478
+ quality = quality ? quality : 90;
1479
+ subsample = quality <= 90 ? 1 : 0;
1480
+ quality = quality < 1 ? 1 : quality > 100 ? 100 : quality;
1481
+ quality = quality < 50 ? 5000 / quality : 200 - quality * 2;
1482
+
1483
+ for(i = 0; i < 64; ++i) {
1484
+ int uvti, yti = (YQT[i]*quality+50)/100;
1485
+ YTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (yti < 1 ? 1 : yti > 255 ? 255 : yti);
1486
+ uvti = (UVQT[i]*quality+50)/100;
1487
+ UVTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (uvti < 1 ? 1 : uvti > 255 ? 255 : uvti);
1488
+ }
1489
+
1490
+ for(row = 0, k = 0; row < 8; ++row) {
1491
+ for(col = 0; col < 8; ++col, ++k) {
1492
+ fdtbl_Y[k] = 1 / (YTable [stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]);
1493
+ fdtbl_UV[k] = 1 / (UVTable[stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]);
1494
+ }
1495
+ }
1496
+
1497
+ // Write Headers
1498
+ {
1499
+ static const unsigned char head0[] = { 0xFF,0xD8,0xFF,0xE0,0,0x10,'J','F','I','F',0,1,1,0,0,1,0,1,0,0,0xFF,0xDB,0,0x84,0 };
1500
+ static const unsigned char head2[] = { 0xFF,0xDA,0,0xC,3,1,0,2,0x11,3,0x11,0,0x3F,0 };
1501
+ const unsigned char head1[] = { 0xFF,0xC0,0,0x11,8,(unsigned char)(height>>8),STBIW_UCHAR(height),(unsigned char)(width>>8),STBIW_UCHAR(width),
1502
+ 3,1,(unsigned char)(subsample?0x22:0x11),0,2,0x11,1,3,0x11,1,0xFF,0xC4,0x01,0xA2,0 };
1503
+ s->func(s->context, (void*)head0, sizeof(head0));
1504
+ s->func(s->context, (void*)YTable, sizeof(YTable));
1505
+ stbiw__putc(s, 1);
1506
+ s->func(s->context, UVTable, sizeof(UVTable));
1507
+ s->func(s->context, (void*)head1, sizeof(head1));
1508
+ s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1);
1509
+ s->func(s->context, (void*)std_dc_luminance_values, sizeof(std_dc_luminance_values));
1510
+ stbiw__putc(s, 0x10); // HTYACinfo
1511
+ s->func(s->context, (void*)(std_ac_luminance_nrcodes+1), sizeof(std_ac_luminance_nrcodes)-1);
1512
+ s->func(s->context, (void*)std_ac_luminance_values, sizeof(std_ac_luminance_values));
1513
+ stbiw__putc(s, 1); // HTUDCinfo
1514
+ s->func(s->context, (void*)(std_dc_chrominance_nrcodes+1), sizeof(std_dc_chrominance_nrcodes)-1);
1515
+ s->func(s->context, (void*)std_dc_chrominance_values, sizeof(std_dc_chrominance_values));
1516
+ stbiw__putc(s, 0x11); // HTUACinfo
1517
+ s->func(s->context, (void*)(std_ac_chrominance_nrcodes+1), sizeof(std_ac_chrominance_nrcodes)-1);
1518
+ s->func(s->context, (void*)std_ac_chrominance_values, sizeof(std_ac_chrominance_values));
1519
+ s->func(s->context, (void*)head2, sizeof(head2));
1520
+ }
1521
+
1522
+ // Encode 8x8 macroblocks
1523
+ {
1524
+ static const unsigned short fillBits[] = {0x7F, 7};
1525
+ int DCY=0, DCU=0, DCV=0;
1526
+ int bitBuf=0, bitCnt=0;
1527
+ // comp == 2 is grey+alpha (alpha is ignored)
1528
+ int ofsG = comp > 2 ? 1 : 0, ofsB = comp > 2 ? 2 : 0;
1529
+ const unsigned char *dataR = (const unsigned char *)data;
1530
+ const unsigned char *dataG = dataR + ofsG;
1531
+ const unsigned char *dataB = dataR + ofsB;
1532
+ int x, y, pos;
1533
+ if(subsample) {
1534
+ for(y = 0; y < height; y += 16) {
1535
+ for(x = 0; x < width; x += 16) {
1536
+ float Y[256], U[256], V[256];
1537
+ for(row = y, pos = 0; row < y+16; ++row) {
1538
+ // row >= height => use last input row
1539
+ int clamped_row = (row < height) ? row : height - 1;
1540
+ int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp;
1541
+ for(col = x; col < x+16; ++col, ++pos) {
1542
+ // if col >= width => use pixel from last input column
1543
+ int p = base_p + ((col < width) ? col : (width-1))*comp;
1544
+ float r = dataR[p], g = dataG[p], b = dataB[p];
1545
+ Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128;
1546
+ U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b;
1547
+ V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b;
1548
+ }
1549
+ }
1550
+ DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+0, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
1551
+ DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+8, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
1552
+ DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+128, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
1553
+ DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+136, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
1554
+
1555
+ // subsample U,V
1556
+ {
1557
+ float subU[64], subV[64];
1558
+ int yy, xx;
1559
+ for(yy = 0, pos = 0; yy < 8; ++yy) {
1560
+ for(xx = 0; xx < 8; ++xx, ++pos) {
1561
+ int j = yy*32+xx*2;
1562
+ subU[pos] = (U[j+0] + U[j+1] + U[j+16] + U[j+17]) * 0.25f;
1563
+ subV[pos] = (V[j+0] + V[j+1] + V[j+16] + V[j+17]) * 0.25f;
1564
+ }
1565
+ }
1566
+ DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subU, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT);
1567
+ DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subV, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT);
1568
+ }
1569
+ }
1570
+ }
1571
+ } else {
1572
+ for(y = 0; y < height; y += 8) {
1573
+ for(x = 0; x < width; x += 8) {
1574
+ float Y[64], U[64], V[64];
1575
+ for(row = y, pos = 0; row < y+8; ++row) {
1576
+ // row >= height => use last input row
1577
+ int clamped_row = (row < height) ? row : height - 1;
1578
+ int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp;
1579
+ for(col = x; col < x+8; ++col, ++pos) {
1580
+ // if col >= width => use pixel from last input column
1581
+ int p = base_p + ((col < width) ? col : (width-1))*comp;
1582
+ float r = dataR[p], g = dataG[p], b = dataB[p];
1583
+ Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128;
1584
+ U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b;
1585
+ V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b;
1586
+ }
1587
+ }
1588
+
1589
+ DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y, 8, fdtbl_Y, DCY, YDC_HT, YAC_HT);
1590
+ DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, U, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT);
1591
+ DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, V, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT);
1592
+ }
1593
+ }
1594
+ }
1595
+
1596
+ // Do the bit alignment of the EOI marker
1597
+ stbiw__jpg_writeBits(s, &bitBuf, &bitCnt, fillBits);
1598
+ }
1599
+
1600
+ // EOI
1601
+ stbiw__putc(s, 0xFF);
1602
+ stbiw__putc(s, 0xD9);
1603
+
1604
+ return 1;
1605
+ }
1606
+
1607
+ STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality)
1608
+ {
1609
+ stbi__write_context s = { 0 };
1610
+ stbi__start_write_callbacks(&s, func, context);
1611
+ return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality);
1612
+ }
1613
+
1614
+
1615
+ #ifndef STBI_WRITE_NO_STDIO
1616
+ STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality)
1617
+ {
1618
+ stbi__write_context s = { 0 };
1619
+ if (stbi__start_write_file(&s,filename)) {
1620
+ int r = stbi_write_jpg_core(&s, x, y, comp, data, quality);
1621
+ stbi__end_write_file(&s);
1622
+ return r;
1623
+ } else
1624
+ return 0;
1625
+ }
1626
+ #endif
1627
+
1628
+ #endif // STB_IMAGE_WRITE_IMPLEMENTATION
1629
+
1630
+ /* Revision history
1631
+ 1.16 (2021-07-11)
1632
+ make Deflate code emit uncompressed blocks when it would otherwise expand
1633
+ support writing BMPs with alpha channel
1634
+ 1.15 (2020-07-13) unknown
1635
+ 1.14 (2020-02-02) updated JPEG writer to downsample chroma channels
1636
+ 1.13
1637
+ 1.12
1638
+ 1.11 (2019-08-11)
1639
+
1640
+ 1.10 (2019-02-07)
1641
+ support utf8 filenames in Windows; fix warnings and platform ifdefs
1642
+ 1.09 (2018-02-11)
1643
+ fix typo in zlib quality API, improve STB_I_W_STATIC in C++
1644
+ 1.08 (2018-01-29)
1645
+ add stbi__flip_vertically_on_write, external zlib, zlib quality, choose PNG filter
1646
+ 1.07 (2017-07-24)
1647
+ doc fix
1648
+ 1.06 (2017-07-23)
1649
+ writing JPEG (using Jon Olick's code)
1650
+ 1.05 ???
1651
+ 1.04 (2017-03-03)
1652
+ monochrome BMP expansion
1653
+ 1.03 ???
1654
+ 1.02 (2016-04-02)
1655
+ avoid allocating large structures on the stack
1656
+ 1.01 (2016-01-16)
1657
+ STBIW_REALLOC_SIZED: support allocators with no realloc support
1658
+ avoid race-condition in crc initialization
1659
+ minor compile issues
1660
+ 1.00 (2015-09-14)
1661
+ installable file IO function
1662
+ 0.99 (2015-09-13)
1663
+ warning fixes; TGA rle support
1664
+ 0.98 (2015-04-08)
1665
+ added STBIW_MALLOC, STBIW_ASSERT etc
1666
+ 0.97 (2015-01-18)
1667
+ fixed HDR asserts, rewrote HDR rle logic
1668
+ 0.96 (2015-01-17)
1669
+ add HDR output
1670
+ fix monochrome BMP
1671
+ 0.95 (2014-08-17)
1672
+ add monochrome TGA output
1673
+ 0.94 (2014-05-31)
1674
+ rename private functions to avoid conflicts with stb_image.h
1675
+ 0.93 (2014-05-27)
1676
+ warning fixes
1677
+ 0.92 (2010-08-01)
1678
+ casts to unsigned char to fix warnings
1679
+ 0.91 (2010-07-17)
1680
+ first public release
1681
+ 0.90 first internal release
1682
+ */
1683
+
1684
+ /*
1685
+ ------------------------------------------------------------------------------
1686
+ This software is available under 2 licenses -- choose whichever you prefer.
1687
+ ------------------------------------------------------------------------------
1688
+ ALTERNATIVE A - MIT License
1689
+ Copyright (c) 2017 Sean Barrett
1690
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
1691
+ this software and associated documentation files (the "Software"), to deal in
1692
+ the Software without restriction, including without limitation the rights to
1693
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
1694
+ of the Software, and to permit persons to whom the Software is furnished to do
1695
+ so, subject to the following conditions:
1696
+ The above copyright notice and this permission notice shall be included in all
1697
+ copies or substantial portions of the Software.
1698
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1699
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
1700
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
1701
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
1702
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
1703
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1704
+ SOFTWARE.
1705
+ ------------------------------------------------------------------------------
1706
+ ALTERNATIVE B - Public Domain (www.unlicense.org)
1707
+ This is free and unencumbered software released into the public domain.
1708
+ Anyone is free to copy, modify, publish, use, compile, sell, or distribute this
1709
+ software, either in source code form or as a compiled binary, for any purpose,
1710
+ commercial or non-commercial, and by any means.
1711
+ In jurisdictions that recognize copyright laws, the author or authors of this
1712
+ software dedicate any and all copyright interest in the software to the public
1713
+ domain. We make this dedication for the benefit of the public at large and to
1714
+ the detriment of our heirs and successors. We intend this dedication to be an
1715
+ overt act of relinquishment in perpetuity of all present and future rights to
1716
+ this software under copyright law.
1717
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1718
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
1719
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
1720
+ AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
1721
+ ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
1722
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1723
+ ------------------------------------------------------------------------------
1724
+ */
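For orientation, the vendored writer's public entry point declared above takes interleaved 8-bit pixels and a 1–100 quality setting (after the 0→90 default, values of 90 or below take the 4:2:0 chroma-subsampling path). A minimal, hypothetical caller sketch — the file name and gradient buffer are illustrative only, not part of this commit:

```c
/* Sketch only: exercise stbi_write_jpg() from the header above.
   Exactly one translation unit must define STB_IMAGE_WRITE_IMPLEMENTATION. */
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"

int main(void) {
    enum { W = 64, H = 64, COMP = 3 };              /* 3 components = RGB */
    static unsigned char pixels[W * H * COMP];
    for (int y = 0; y < H; ++y)                     /* fill a simple gradient */
        for (int x = 0; x < W; ++x) {
            unsigned char *p = pixels + (y * W + x) * COMP;
            p[0] = (unsigned char)(x * 4);
            p[1] = (unsigned char)(y * 4);
            p[2] = 128;
        }
    /* quality <= 90 selects the subsampled (4:2:0) encode path shown above */
    return stbi_write_jpg("out.jpg", W, H, COMP, pixels, 90) ? 0 : 1;
}
```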
util/text_img.py ADDED
@@ -0,0 +1,60 @@
1
+ import spaces
2
+ import rembg
3
+ import torch
4
+ from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+ import gradio as gr
9
+
10
+ # pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
11
+ # pipe.to("cuda")
12
+
13
+ def check_prompt(prompt):
14
+     if prompt is None or prompt.strip() == "":  # an empty textbox arrives as "", not None
15
+         raise gr.Error("Please enter a prompt!")
16
+
17
+ controlnet = ControlNetModel.from_pretrained(
18
+     "diffusers/controlnet-canny-sdxl-1.0",
19
+     torch_dtype=torch.float16,
20
+     use_safetensors=True
21
+ )
22
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
23
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
24
+     "stabilityai/stable-diffusion-xl-base-1.0",
25
+     controlnet=controlnet,
26
+     vae=vae,
27
+     torch_dtype=torch.float16,
28
+     use_safetensors=True
29
+ )
30
+
31
+ pipe.to("cuda")
32
+
33
+ # Generate an image with SDXL guided by a canny ControlNet, then strip the background
34
+ @spaces.GPU
35
+ def generate_image(prompt, negative_prompt, control_image, scale=0.5):
36
+     prompt += ", no background, side view, minimalist shot, single shoe, no legs, product photo"  # leading comma keeps the suffix separate from the user's prompt
37
+
38
+     canny_image = get_canny(control_image)
39
+
40
+     image = pipe(
41
+         prompt,
42
+         negative_prompt=negative_prompt,
43
+         image=canny_image,
44
+         controlnet_conditioning_scale=scale,
45
+     ).images[0]
46
+     image2 = rembg.remove(image)  # drop the background so only the shoe remains
47
+
48
+     return image2
49
+
50
+ def get_canny(image):
51
+     image = np.array(image)
52
+
53
+     low_threshold = 100
54
+     high_threshold = 200
55
+
56
+     image = cv2.Canny(image, low_threshold, high_threshold)
57
+     image = image[:, :, None]  # add a channel axis: HxW -> HxWx1
58
+     image = np.concatenate([image, image, image], axis=2)  # replicate the edge map into 3 channels
59
+     canny_image = Image.fromarray(image)
60
+     return canny_image
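With the module in place, a quick local smoke test of these helpers outside the Space might look like the sketch below; the example image path and prompts are placeholders, and importing the module already downloads the SDXL/ControlNet weights and calls `pipe.to("cuda")`, so a CUDA GPU is assumed:

```python
# Hypothetical local check of util/text_img.py -- path and prompts are placeholders.
import numpy as np
from PIL import Image
from util.text_img import check_prompt, get_canny, generate_image

control = np.array(Image.open("examples/shoe.png").convert("RGB"))  # any RGB reference photo
check_prompt("a red leather sneaker")   # raises gr.Error if the prompt is missing or empty
edges = get_canny(control)              # PIL image: 3-channel canny edge map fed to ControlNet
result = generate_image(
    "a red leather sneaker",
    "low quality, bad quality, sketches, legs",
    control,
    scale=0.5,
)
result.save("generated_shoe.png")       # rembg output keeps an alpha channel, so save as PNG
```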