MaxMilan1
commited on
Commit
·
a1f69bb
1
Parent(s):
883d514
changes
Browse files- app.py +28 -1
- diff-gaussian-rasterization/.gitignore +3 -0
- diff-gaussian-rasterization/.gitmodules +3 -0
- diff-gaussian-rasterization/CMakeLists.txt +36 -0
- diff-gaussian-rasterization/LICENSE.md +83 -0
- diff-gaussian-rasterization/README.md +19 -0
- diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h +175 -0
- diff-gaussian-rasterization/cuda_rasterizer/backward.cu +657 -0
- diff-gaussian-rasterization/cuda_rasterizer/backward.h +65 -0
- diff-gaussian-rasterization/cuda_rasterizer/config.h +19 -0
- diff-gaussian-rasterization/cuda_rasterizer/forward.cu +455 -0
- diff-gaussian-rasterization/cuda_rasterizer/forward.h +66 -0
- diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h +88 -0
- diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu +434 -0
- diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h +74 -0
- diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py +221 -0
- diff-gaussian-rasterization/ext.cpp +19 -0
- diff-gaussian-rasterization/rasterize_points.cu +217 -0
- diff-gaussian-rasterization/rasterize_points.h +67 -0
- diff-gaussian-rasterization/setup.py +34 -0
- diff-gaussian-rasterization/third_party/stbi_image_write.h +1724 -0
- util/text_img.py +60 -0
app.py
CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
|
|
2 |
import os
|
3 |
|
4 |
from util.instantmesh import generate_mvs, make3d, preprocess, check_input_image
|
|
|
5 |
|
6 |
_CITE_ = r"""
|
7 |
```bibtex
|
@@ -16,6 +17,33 @@ _CITE_ = r"""
|
|
16 |
|
17 |
|
18 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
with gr.Row(variant="panel"):
|
20 |
with gr.Column():
|
21 |
with gr.Row():
|
@@ -62,7 +90,6 @@ with gr.Blocks() as demo:
|
|
62 |
inputs=[input_image],
|
63 |
label="Examples",
|
64 |
cache_examples=False,
|
65 |
-
examples_per_page=12
|
66 |
)
|
67 |
|
68 |
with gr.Column():
|
|
|
2 |
import os
|
3 |
|
4 |
from util.instantmesh import generate_mvs, make3d, preprocess, check_input_image
|
5 |
+
from util.text_img import generate_image, check_prompt
|
6 |
|
7 |
_CITE_ = r"""
|
8 |
```bibtex
|
|
|
17 |
|
18 |
|
19 |
with gr.Blocks() as demo:
|
20 |
+
with gr.Tab("Text to Image Generator"):
|
21 |
+
with gr.Row():
|
22 |
+
with gr.Column():
|
23 |
+
prompt = gr.Textbox(label="Enter a discription of a shoe")
|
24 |
+
negative_prompt = gr.Textbox(label="Negative Prompt", value="low quality, bad quality, sketches, legs")
|
25 |
+
scale = gr.Slider(label="Control Image Scale", minimum=0.1, maximum=1.0, step=0.1, value=0.5)
|
26 |
+
with gr.Column():
|
27 |
+
control_image = gr.Image(label="Enter an image of a shoe, that you want to use as a reference", type='numpy')
|
28 |
+
# neg_prompt = gr.Textbox(label="Enter a negative prompt", value="low quality, watermark, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft, closed eyes, text, logo")
|
29 |
+
with gr.Row():
|
30 |
+
with gr.Column():
|
31 |
+
gr.Examples(
|
32 |
+
examples=[
|
33 |
+
os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
|
34 |
+
],
|
35 |
+
inputs=[control_image],
|
36 |
+
label="Examples",
|
37 |
+
cache_examples=False,
|
38 |
+
)
|
39 |
+
with gr.Column():
|
40 |
+
button_gen = gr.Button("Generate Image")
|
41 |
+
with gr.Row():
|
42 |
+
with gr.Column():
|
43 |
+
image_nobg = gr.Image(label="Generated Image", show_download_button=True, show_label=False)
|
44 |
+
|
45 |
+
button_gen.click(check_prompt, inputs=[prompt]).succes(generate_image, inputs=[prompt, negative_prompt, control_image, scale], outputs=[image_nobg])
|
46 |
+
|
47 |
with gr.Row(variant="panel"):
|
48 |
with gr.Column():
|
49 |
with gr.Row():
|
|
|
90 |
inputs=[input_image],
|
91 |
label="Examples",
|
92 |
cache_examples=False,
|
|
|
93 |
)
|
94 |
|
95 |
with gr.Column():
|
diff-gaussian-rasterization/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
build/
|
2 |
+
diff_gaussian_rasterization.egg-info/
|
3 |
+
dist/
|
diff-gaussian-rasterization/.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "third_party/glm"]
|
2 |
+
path = third_party/glm
|
3 |
+
url = https://github.com/g-truc/glm.git
|
diff-gaussian-rasterization/CMakeLists.txt
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
cmake_minimum_required(VERSION 3.20)
|
13 |
+
|
14 |
+
project(DiffRast LANGUAGES CUDA CXX)
|
15 |
+
|
16 |
+
set(CMAKE_CXX_STANDARD 17)
|
17 |
+
set(CMAKE_CXX_EXTENSIONS OFF)
|
18 |
+
set(CMAKE_CUDA_STANDARD 17)
|
19 |
+
|
20 |
+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
21 |
+
|
22 |
+
add_library(CudaRasterizer
|
23 |
+
cuda_rasterizer/backward.h
|
24 |
+
cuda_rasterizer/backward.cu
|
25 |
+
cuda_rasterizer/forward.h
|
26 |
+
cuda_rasterizer/forward.cu
|
27 |
+
cuda_rasterizer/auxiliary.h
|
28 |
+
cuda_rasterizer/rasterizer_impl.cu
|
29 |
+
cuda_rasterizer/rasterizer_impl.h
|
30 |
+
cuda_rasterizer/rasterizer.h
|
31 |
+
)
|
32 |
+
|
33 |
+
set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86")
|
34 |
+
|
35 |
+
target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer)
|
36 |
+
target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
diff-gaussian-rasterization/LICENSE.md
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Gaussian-Splatting License
|
2 |
+
===========================
|
3 |
+
|
4 |
+
**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
|
5 |
+
The *Software* is in the process of being registered with the Agence pour la Protection des
|
6 |
+
Programmes (APP).
|
7 |
+
|
8 |
+
The *Software* is still being developed by the *Licensor*.
|
9 |
+
|
10 |
+
*Licensor*'s goal is to allow the research community to use, test and evaluate
|
11 |
+
the *Software*.
|
12 |
+
|
13 |
+
## 1. Definitions
|
14 |
+
|
15 |
+
*Licensee* means any person or entity that uses the *Software* and distributes
|
16 |
+
its *Work*.
|
17 |
+
|
18 |
+
*Licensor* means the owners of the *Software*, i.e Inria and MPII
|
19 |
+
|
20 |
+
*Software* means the original work of authorship made available under this
|
21 |
+
License ie gaussian-splatting.
|
22 |
+
|
23 |
+
*Work* means the *Software* and any additions to or derivative works of the
|
24 |
+
*Software* that are made available under this License.
|
25 |
+
|
26 |
+
|
27 |
+
## 2. Purpose
|
28 |
+
This license is intended to define the rights granted to the *Licensee* by
|
29 |
+
Licensors under the *Software*.
|
30 |
+
|
31 |
+
## 3. Rights granted
|
32 |
+
|
33 |
+
For the above reasons Licensors have decided to distribute the *Software*.
|
34 |
+
Licensors grant non-exclusive rights to use the *Software* for research purposes
|
35 |
+
to research users (both academic and industrial), free of charge, without right
|
36 |
+
to sublicense.. The *Software* may be used "non-commercially", i.e., for research
|
37 |
+
and/or evaluation purposes only.
|
38 |
+
|
39 |
+
Subject to the terms and conditions of this License, you are granted a
|
40 |
+
non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
|
41 |
+
publicly display, publicly perform and distribute its *Work* and any resulting
|
42 |
+
derivative works in any form.
|
43 |
+
|
44 |
+
## 4. Limitations
|
45 |
+
|
46 |
+
**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
|
47 |
+
so under this License, (b) you include a complete copy of this License with
|
48 |
+
your distribution, and (c) you retain without modification any copyright,
|
49 |
+
patent, trademark, or attribution notices that are present in the *Work*.
|
50 |
+
|
51 |
+
**4.2 Derivative Works.** You may specify that additional or different terms apply
|
52 |
+
to the use, reproduction, and distribution of your derivative works of the *Work*
|
53 |
+
("Your Terms") only if (a) Your Terms provide that the use limitation in
|
54 |
+
Section 2 applies to your derivative works, and (b) you identify the specific
|
55 |
+
derivative works that are subject to Your Terms. Notwithstanding Your Terms,
|
56 |
+
this License (including the redistribution requirements in Section 3.1) will
|
57 |
+
continue to apply to the *Work* itself.
|
58 |
+
|
59 |
+
**4.3** Any other use without of prior consent of Licensors is prohibited. Research
|
60 |
+
users explicitly acknowledge having received from Licensors all information
|
61 |
+
allowing to appreciate the adequacy between of the *Software* and their needs and
|
62 |
+
to undertake all necessary precautions for its execution and use.
|
63 |
+
|
64 |
+
**4.4** The *Software* is provided both as a compiled library file and as source
|
65 |
+
code. In case of using the *Software* for a publication or other results obtained
|
66 |
+
through the use of the *Software*, users are strongly encouraged to cite the
|
67 |
+
corresponding publications as explained in the documentation of the *Software*.
|
68 |
+
|
69 |
+
## 5. Disclaimer
|
70 |
+
|
71 |
+
THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
|
72 |
+
WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
|
73 |
+
UNAUTHORIZED USE: [email protected] . ANY SUCH ACTION WILL
|
74 |
+
CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
|
75 |
+
OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
|
76 |
+
USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
|
77 |
+
ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
|
78 |
+
AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
79 |
+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
|
80 |
+
GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
|
81 |
+
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
82 |
+
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
|
83 |
+
IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
|
diff-gaussian-rasterization/README.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Differential Gaussian Rasterization
|
2 |
+
|
3 |
+
Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-Time Rendering of Radiance Fields". If you can make use of it in your own research, please be so kind to cite us.
|
4 |
+
|
5 |
+
<section class="section" id="BibTeX">
|
6 |
+
<div class="container is-max-desktop content">
|
7 |
+
<h2 class="title">BibTeX</h2>
|
8 |
+
<pre><code>@Article{kerbl3Dgaussians,
|
9 |
+
author = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
|
10 |
+
title = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
|
11 |
+
journal = {ACM Transactions on Graphics},
|
12 |
+
number = {4},
|
13 |
+
volume = {42},
|
14 |
+
month = {July},
|
15 |
+
year = {2023},
|
16 |
+
url = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
|
17 |
+
}</code></pre>
|
18 |
+
</div>
|
19 |
+
</section>
|
diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
|
13 |
+
#define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
|
14 |
+
|
15 |
+
#include "config.h"
|
16 |
+
#include "stdio.h"
|
17 |
+
|
18 |
+
#define BLOCK_SIZE (BLOCK_X * BLOCK_Y)
|
19 |
+
#define NUM_WARPS (BLOCK_SIZE/32)
|
20 |
+
|
21 |
+
// Spherical harmonics coefficients
|
22 |
+
__device__ const float SH_C0 = 0.28209479177387814f;
|
23 |
+
__device__ const float SH_C1 = 0.4886025119029199f;
|
24 |
+
__device__ const float SH_C2[] = {
|
25 |
+
1.0925484305920792f,
|
26 |
+
-1.0925484305920792f,
|
27 |
+
0.31539156525252005f,
|
28 |
+
-1.0925484305920792f,
|
29 |
+
0.5462742152960396f
|
30 |
+
};
|
31 |
+
__device__ const float SH_C3[] = {
|
32 |
+
-0.5900435899266435f,
|
33 |
+
2.890611442640554f,
|
34 |
+
-0.4570457994644658f,
|
35 |
+
0.3731763325901154f,
|
36 |
+
-0.4570457994644658f,
|
37 |
+
1.445305721320277f,
|
38 |
+
-0.5900435899266435f
|
39 |
+
};
|
40 |
+
|
41 |
+
__forceinline__ __device__ float ndc2Pix(float v, int S)
|
42 |
+
{
|
43 |
+
return ((v + 1.0) * S - 1.0) * 0.5;
|
44 |
+
}
|
45 |
+
|
46 |
+
__forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid)
|
47 |
+
{
|
48 |
+
rect_min = {
|
49 |
+
min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))),
|
50 |
+
min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y)))
|
51 |
+
};
|
52 |
+
rect_max = {
|
53 |
+
min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))),
|
54 |
+
min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y)))
|
55 |
+
};
|
56 |
+
}
|
57 |
+
|
58 |
+
__forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix)
|
59 |
+
{
|
60 |
+
float3 transformed = {
|
61 |
+
matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
|
62 |
+
matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
|
63 |
+
matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
|
64 |
+
};
|
65 |
+
return transformed;
|
66 |
+
}
|
67 |
+
|
68 |
+
__forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix)
|
69 |
+
{
|
70 |
+
float4 transformed = {
|
71 |
+
matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
|
72 |
+
matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
|
73 |
+
matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
|
74 |
+
matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15]
|
75 |
+
};
|
76 |
+
return transformed;
|
77 |
+
}
|
78 |
+
|
79 |
+
__forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix)
|
80 |
+
{
|
81 |
+
float3 transformed = {
|
82 |
+
matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z,
|
83 |
+
matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z,
|
84 |
+
matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z,
|
85 |
+
};
|
86 |
+
return transformed;
|
87 |
+
}
|
88 |
+
|
89 |
+
__forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix)
|
90 |
+
{
|
91 |
+
float3 transformed = {
|
92 |
+
matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z,
|
93 |
+
matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z,
|
94 |
+
matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z,
|
95 |
+
};
|
96 |
+
return transformed;
|
97 |
+
}
|
98 |
+
|
99 |
+
__forceinline__ __device__ float dnormvdz(float3 v, float3 dv)
|
100 |
+
{
|
101 |
+
float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
|
102 |
+
float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
|
103 |
+
float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
|
104 |
+
return dnormvdz;
|
105 |
+
}
|
106 |
+
|
107 |
+
__forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv)
|
108 |
+
{
|
109 |
+
float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
|
110 |
+
float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
|
111 |
+
|
112 |
+
float3 dnormvdv;
|
113 |
+
dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32;
|
114 |
+
dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32;
|
115 |
+
dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
|
116 |
+
return dnormvdv;
|
117 |
+
}
|
118 |
+
|
119 |
+
__forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv)
|
120 |
+
{
|
121 |
+
float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
|
122 |
+
float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
|
123 |
+
|
124 |
+
float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w };
|
125 |
+
float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w;
|
126 |
+
float4 dnormvdv;
|
127 |
+
dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32;
|
128 |
+
dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32;
|
129 |
+
dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32;
|
130 |
+
dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32;
|
131 |
+
return dnormvdv;
|
132 |
+
}
|
133 |
+
|
134 |
+
__forceinline__ __device__ float sigmoid(float x)
|
135 |
+
{
|
136 |
+
return 1.0f / (1.0f + expf(-x));
|
137 |
+
}
|
138 |
+
|
139 |
+
__forceinline__ __device__ bool in_frustum(int idx,
|
140 |
+
const float* orig_points,
|
141 |
+
const float* viewmatrix,
|
142 |
+
const float* projmatrix,
|
143 |
+
bool prefiltered,
|
144 |
+
float3& p_view)
|
145 |
+
{
|
146 |
+
float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
|
147 |
+
|
148 |
+
// Bring points to screen space
|
149 |
+
float4 p_hom = transformPoint4x4(p_orig, projmatrix);
|
150 |
+
float p_w = 1.0f / (p_hom.w + 0.0000001f);
|
151 |
+
float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
|
152 |
+
p_view = transformPoint4x3(p_orig, viewmatrix);
|
153 |
+
|
154 |
+
if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3)))
|
155 |
+
{
|
156 |
+
if (prefiltered)
|
157 |
+
{
|
158 |
+
printf("Point is filtered although prefiltered is set. This shouldn't happen!");
|
159 |
+
__trap();
|
160 |
+
}
|
161 |
+
return false;
|
162 |
+
}
|
163 |
+
return true;
|
164 |
+
}
|
165 |
+
|
166 |
+
#define CHECK_CUDA(A, debug) \
|
167 |
+
A; if(debug) { \
|
168 |
+
auto ret = cudaDeviceSynchronize(); \
|
169 |
+
if (ret != cudaSuccess) { \
|
170 |
+
std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \
|
171 |
+
throw std::runtime_error(cudaGetErrorString(ret)); \
|
172 |
+
} \
|
173 |
+
}
|
174 |
+
|
175 |
+
#endif
|
diff-gaussian-rasterization/cuda_rasterizer/backward.cu
ADDED
@@ -0,0 +1,657 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include "backward.h"
|
13 |
+
#include "auxiliary.h"
|
14 |
+
#include <cooperative_groups.h>
|
15 |
+
#include <cooperative_groups/reduce.h>
|
16 |
+
namespace cg = cooperative_groups;
|
17 |
+
|
18 |
+
// Backward pass for conversion of spherical harmonics to RGB for
|
19 |
+
// each Gaussian.
|
20 |
+
__device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs)
|
21 |
+
{
|
22 |
+
// Compute intermediate values, as it is done during forward
|
23 |
+
glm::vec3 pos = means[idx];
|
24 |
+
glm::vec3 dir_orig = pos - campos;
|
25 |
+
glm::vec3 dir = dir_orig / glm::length(dir_orig);
|
26 |
+
|
27 |
+
glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
|
28 |
+
|
29 |
+
// Use PyTorch rule for clamping: if clamping was applied,
|
30 |
+
// gradient becomes 0.
|
31 |
+
glm::vec3 dL_dRGB = dL_dcolor[idx];
|
32 |
+
dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1;
|
33 |
+
dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1;
|
34 |
+
dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1;
|
35 |
+
|
36 |
+
glm::vec3 dRGBdx(0, 0, 0);
|
37 |
+
glm::vec3 dRGBdy(0, 0, 0);
|
38 |
+
glm::vec3 dRGBdz(0, 0, 0);
|
39 |
+
float x = dir.x;
|
40 |
+
float y = dir.y;
|
41 |
+
float z = dir.z;
|
42 |
+
|
43 |
+
// Target location for this Gaussian to write SH gradients to
|
44 |
+
glm::vec3* dL_dsh = dL_dshs + idx * max_coeffs;
|
45 |
+
|
46 |
+
// No tricks here, just high school-level calculus.
|
47 |
+
float dRGBdsh0 = SH_C0;
|
48 |
+
dL_dsh[0] = dRGBdsh0 * dL_dRGB;
|
49 |
+
if (deg > 0)
|
50 |
+
{
|
51 |
+
float dRGBdsh1 = -SH_C1 * y;
|
52 |
+
float dRGBdsh2 = SH_C1 * z;
|
53 |
+
float dRGBdsh3 = -SH_C1 * x;
|
54 |
+
dL_dsh[1] = dRGBdsh1 * dL_dRGB;
|
55 |
+
dL_dsh[2] = dRGBdsh2 * dL_dRGB;
|
56 |
+
dL_dsh[3] = dRGBdsh3 * dL_dRGB;
|
57 |
+
|
58 |
+
dRGBdx = -SH_C1 * sh[3];
|
59 |
+
dRGBdy = -SH_C1 * sh[1];
|
60 |
+
dRGBdz = SH_C1 * sh[2];
|
61 |
+
|
62 |
+
if (deg > 1)
|
63 |
+
{
|
64 |
+
float xx = x * x, yy = y * y, zz = z * z;
|
65 |
+
float xy = x * y, yz = y * z, xz = x * z;
|
66 |
+
|
67 |
+
float dRGBdsh4 = SH_C2[0] * xy;
|
68 |
+
float dRGBdsh5 = SH_C2[1] * yz;
|
69 |
+
float dRGBdsh6 = SH_C2[2] * (2.f * zz - xx - yy);
|
70 |
+
float dRGBdsh7 = SH_C2[3] * xz;
|
71 |
+
float dRGBdsh8 = SH_C2[4] * (xx - yy);
|
72 |
+
dL_dsh[4] = dRGBdsh4 * dL_dRGB;
|
73 |
+
dL_dsh[5] = dRGBdsh5 * dL_dRGB;
|
74 |
+
dL_dsh[6] = dRGBdsh6 * dL_dRGB;
|
75 |
+
dL_dsh[7] = dRGBdsh7 * dL_dRGB;
|
76 |
+
dL_dsh[8] = dRGBdsh8 * dL_dRGB;
|
77 |
+
|
78 |
+
dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8];
|
79 |
+
dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] + SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8];
|
80 |
+
dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] + SH_C2[3] * x * sh[7];
|
81 |
+
|
82 |
+
if (deg > 2)
|
83 |
+
{
|
84 |
+
float dRGBdsh9 = SH_C3[0] * y * (3.f * xx - yy);
|
85 |
+
float dRGBdsh10 = SH_C3[1] * xy * z;
|
86 |
+
float dRGBdsh11 = SH_C3[2] * y * (4.f * zz - xx - yy);
|
87 |
+
float dRGBdsh12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy);
|
88 |
+
float dRGBdsh13 = SH_C3[4] * x * (4.f * zz - xx - yy);
|
89 |
+
float dRGBdsh14 = SH_C3[5] * z * (xx - yy);
|
90 |
+
float dRGBdsh15 = SH_C3[6] * x * (xx - 3.f * yy);
|
91 |
+
dL_dsh[9] = dRGBdsh9 * dL_dRGB;
|
92 |
+
dL_dsh[10] = dRGBdsh10 * dL_dRGB;
|
93 |
+
dL_dsh[11] = dRGBdsh11 * dL_dRGB;
|
94 |
+
dL_dsh[12] = dRGBdsh12 * dL_dRGB;
|
95 |
+
dL_dsh[13] = dRGBdsh13 * dL_dRGB;
|
96 |
+
dL_dsh[14] = dRGBdsh14 * dL_dRGB;
|
97 |
+
dL_dsh[15] = dRGBdsh15 * dL_dRGB;
|
98 |
+
|
99 |
+
dRGBdx += (
|
100 |
+
SH_C3[0] * sh[9] * 3.f * 2.f * xy +
|
101 |
+
SH_C3[1] * sh[10] * yz +
|
102 |
+
SH_C3[2] * sh[11] * -2.f * xy +
|
103 |
+
SH_C3[3] * sh[12] * -3.f * 2.f * xz +
|
104 |
+
SH_C3[4] * sh[13] * (-3.f * xx + 4.f * zz - yy) +
|
105 |
+
SH_C3[5] * sh[14] * 2.f * xz +
|
106 |
+
SH_C3[6] * sh[15] * 3.f * (xx - yy));
|
107 |
+
|
108 |
+
dRGBdy += (
|
109 |
+
SH_C3[0] * sh[9] * 3.f * (xx - yy) +
|
110 |
+
SH_C3[1] * sh[10] * xz +
|
111 |
+
SH_C3[2] * sh[11] * (-3.f * yy + 4.f * zz - xx) +
|
112 |
+
SH_C3[3] * sh[12] * -3.f * 2.f * yz +
|
113 |
+
SH_C3[4] * sh[13] * -2.f * xy +
|
114 |
+
SH_C3[5] * sh[14] * -2.f * yz +
|
115 |
+
SH_C3[6] * sh[15] * -3.f * 2.f * xy);
|
116 |
+
|
117 |
+
dRGBdz += (
|
118 |
+
SH_C3[1] * sh[10] * xy +
|
119 |
+
SH_C3[2] * sh[11] * 4.f * 2.f * yz +
|
120 |
+
SH_C3[3] * sh[12] * 3.f * (2.f * zz - xx - yy) +
|
121 |
+
SH_C3[4] * sh[13] * 4.f * 2.f * xz +
|
122 |
+
SH_C3[5] * sh[14] * (xx - yy));
|
123 |
+
}
|
124 |
+
}
|
125 |
+
}
|
126 |
+
|
127 |
+
// The view direction is an input to the computation. View direction
|
128 |
+
// is influenced by the Gaussian's mean, so SHs gradients
|
129 |
+
// must propagate back into 3D position.
|
130 |
+
glm::vec3 dL_ddir(glm::dot(dRGBdx, dL_dRGB), glm::dot(dRGBdy, dL_dRGB), glm::dot(dRGBdz, dL_dRGB));
|
131 |
+
|
132 |
+
// Account for normalization of direction
|
133 |
+
float3 dL_dmean = dnormvdv(float3{ dir_orig.x, dir_orig.y, dir_orig.z }, float3{ dL_ddir.x, dL_ddir.y, dL_ddir.z });
|
134 |
+
|
135 |
+
// Gradients of loss w.r.t. Gaussian means, but only the portion
|
136 |
+
// that is caused because the mean affects the view-dependent color.
|
137 |
+
// Additional mean gradient is accumulated in below methods.
|
138 |
+
dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z);
|
139 |
+
}
|
140 |
+
|
141 |
+
// Backward version of INVERSE 2D covariance matrix computation
|
142 |
+
// (due to length launched as separate kernel before other
|
143 |
+
// backward steps contained in preprocess)
|
144 |
+
__global__ void computeCov2DCUDA(int P,
|
145 |
+
const float3* means,
|
146 |
+
const int* radii,
|
147 |
+
const float* cov3Ds,
|
148 |
+
const float h_x, float h_y,
|
149 |
+
const float tan_fovx, float tan_fovy,
|
150 |
+
const float* view_matrix,
|
151 |
+
const float* dL_dconics,
|
152 |
+
float3* dL_dmeans,
|
153 |
+
float* dL_dcov)
|
154 |
+
{
|
155 |
+
auto idx = cg::this_grid().thread_rank();
|
156 |
+
if (idx >= P || !(radii[idx] > 0))
|
157 |
+
return;
|
158 |
+
|
159 |
+
// Reading location of 3D covariance for this Gaussian
|
160 |
+
const float* cov3D = cov3Ds + 6 * idx;
|
161 |
+
|
162 |
+
// Fetch gradients, recompute 2D covariance and relevant
|
163 |
+
// intermediate forward results needed in the backward.
|
164 |
+
float3 mean = means[idx];
|
165 |
+
float3 dL_dconic = { dL_dconics[4 * idx], dL_dconics[4 * idx + 1], dL_dconics[4 * idx + 3] };
|
166 |
+
float3 t = transformPoint4x3(mean, view_matrix);
|
167 |
+
|
168 |
+
const float limx = 1.3f * tan_fovx;
|
169 |
+
const float limy = 1.3f * tan_fovy;
|
170 |
+
const float txtz = t.x / t.z;
|
171 |
+
const float tytz = t.y / t.z;
|
172 |
+
t.x = min(limx, max(-limx, txtz)) * t.z;
|
173 |
+
t.y = min(limy, max(-limy, tytz)) * t.z;
|
174 |
+
|
175 |
+
const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1;
|
176 |
+
const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1;
|
177 |
+
|
178 |
+
glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z),
|
179 |
+
0.0f, h_y / t.z, -(h_y * t.y) / (t.z * t.z),
|
180 |
+
0, 0, 0);
|
181 |
+
|
182 |
+
glm::mat3 W = glm::mat3(
|
183 |
+
view_matrix[0], view_matrix[4], view_matrix[8],
|
184 |
+
view_matrix[1], view_matrix[5], view_matrix[9],
|
185 |
+
view_matrix[2], view_matrix[6], view_matrix[10]);
|
186 |
+
|
187 |
+
glm::mat3 Vrk = glm::mat3(
|
188 |
+
cov3D[0], cov3D[1], cov3D[2],
|
189 |
+
cov3D[1], cov3D[3], cov3D[4],
|
190 |
+
cov3D[2], cov3D[4], cov3D[5]);
|
191 |
+
|
192 |
+
glm::mat3 T = W * J;
|
193 |
+
|
194 |
+
glm::mat3 cov2D = glm::transpose(T) * glm::transpose(Vrk) * T;
|
195 |
+
|
196 |
+
// Use helper variables for 2D covariance entries. More compact.
|
197 |
+
float a = cov2D[0][0] += 0.3f;
|
198 |
+
float b = cov2D[0][1];
|
199 |
+
float c = cov2D[1][1] += 0.3f;
|
200 |
+
|
201 |
+
float denom = a * c - b * b;
|
202 |
+
float dL_da = 0, dL_db = 0, dL_dc = 0;
|
203 |
+
float denom2inv = 1.0f / ((denom * denom) + 0.0000001f);
|
204 |
+
|
205 |
+
if (denom2inv != 0)
|
206 |
+
{
|
207 |
+
// Gradients of loss w.r.t. entries of 2D covariance matrix,
|
208 |
+
// given gradients of loss w.r.t. conic matrix (inverse covariance matrix).
|
209 |
+
// e.g., dL / da = dL / d_conic_a * d_conic_a / d_a
|
210 |
+
dL_da = denom2inv * (-c * c * dL_dconic.x + 2 * b * c * dL_dconic.y + (denom - a * c) * dL_dconic.z);
|
211 |
+
dL_dc = denom2inv * (-a * a * dL_dconic.z + 2 * a * b * dL_dconic.y + (denom - a * c) * dL_dconic.x);
|
212 |
+
dL_db = denom2inv * 2 * (b * c * dL_dconic.x - (denom + 2 * b * b) * dL_dconic.y + a * b * dL_dconic.z);
|
213 |
+
|
214 |
+
// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
|
215 |
+
// given gradients w.r.t. 2D covariance matrix (diagonal).
|
216 |
+
// cov2D = transpose(T) * transpose(Vrk) * T;
|
217 |
+
dL_dcov[6 * idx + 0] = (T[0][0] * T[0][0] * dL_da + T[0][0] * T[1][0] * dL_db + T[1][0] * T[1][0] * dL_dc);
|
218 |
+
dL_dcov[6 * idx + 3] = (T[0][1] * T[0][1] * dL_da + T[0][1] * T[1][1] * dL_db + T[1][1] * T[1][1] * dL_dc);
|
219 |
+
dL_dcov[6 * idx + 5] = (T[0][2] * T[0][2] * dL_da + T[0][2] * T[1][2] * dL_db + T[1][2] * T[1][2] * dL_dc);
|
220 |
+
|
221 |
+
// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
|
222 |
+
// given gradients w.r.t. 2D covariance matrix (off-diagonal).
|
223 |
+
// Off-diagonal elements appear twice --> double the gradient.
|
224 |
+
// cov2D = transpose(T) * transpose(Vrk) * T;
|
225 |
+
dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_da + (T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][1] * dL_dc;
|
226 |
+
dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_da + (T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][2] * dL_dc;
|
227 |
+
dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_da + (T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_db + 2 * T[1][1] * T[1][2] * dL_dc;
|
228 |
+
}
|
229 |
+
else
|
230 |
+
{
|
231 |
+
for (int i = 0; i < 6; i++)
|
232 |
+
dL_dcov[6 * idx + i] = 0;
|
233 |
+
}
|
234 |
+
|
235 |
+
// Gradients of loss w.r.t. upper 2x3 portion of intermediate matrix T
|
236 |
+
// cov2D = transpose(T) * transpose(Vrk) * T;
|
237 |
+
float dL_dT00 = 2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_da +
|
238 |
+
(T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_db;
|
239 |
+
float dL_dT01 = 2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_da +
|
240 |
+
(T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_db;
|
241 |
+
float dL_dT02 = 2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_da +
|
242 |
+
(T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_db;
|
243 |
+
float dL_dT10 = 2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_dc +
|
244 |
+
(T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_db;
|
245 |
+
float dL_dT11 = 2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_dc +
|
246 |
+
(T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_db;
|
247 |
+
float dL_dT12 = 2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_dc +
|
248 |
+
(T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_db;
|
249 |
+
|
250 |
+
// Gradients of loss w.r.t. upper 3x2 non-zero entries of Jacobian matrix
|
251 |
+
// T = W * J
|
252 |
+
float dL_dJ00 = W[0][0] * dL_dT00 + W[0][1] * dL_dT01 + W[0][2] * dL_dT02;
|
253 |
+
float dL_dJ02 = W[2][0] * dL_dT00 + W[2][1] * dL_dT01 + W[2][2] * dL_dT02;
|
254 |
+
float dL_dJ11 = W[1][0] * dL_dT10 + W[1][1] * dL_dT11 + W[1][2] * dL_dT12;
|
255 |
+
float dL_dJ12 = W[2][0] * dL_dT10 + W[2][1] * dL_dT11 + W[2][2] * dL_dT12;
|
256 |
+
|
257 |
+
float tz = 1.f / t.z;
|
258 |
+
float tz2 = tz * tz;
|
259 |
+
float tz3 = tz2 * tz;
|
260 |
+
|
261 |
+
// Gradients of loss w.r.t. transformed Gaussian mean t
|
262 |
+
float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02;
|
263 |
+
float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12;
|
264 |
+
float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 + (2 * h_x * t.x) * tz3 * dL_dJ02 + (2 * h_y * t.y) * tz3 * dL_dJ12;
|
265 |
+
|
266 |
+
// Account for transformation of mean to t
|
267 |
+
// t = transformPoint4x3(mean, view_matrix);
|
268 |
+
float3 dL_dmean = transformVec4x3Transpose({ dL_dtx, dL_dty, dL_dtz }, view_matrix);
|
269 |
+
|
270 |
+
// Gradients of loss w.r.t. Gaussian means, but only the portion
|
271 |
+
// that is caused because the mean affects the covariance matrix.
|
272 |
+
// Additional mean gradient is accumulated in BACKWARD::preprocess.
|
273 |
+
dL_dmeans[idx] = dL_dmean;
|
274 |
+
}
|
275 |
+
|
276 |
+
// Backward pass for the conversion of scale and rotation to a
|
277 |
+
// 3D covariance matrix for each Gaussian.
|
278 |
+
__device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const glm::vec4 rot, const float* dL_dcov3Ds, glm::vec3* dL_dscales, glm::vec4* dL_drots)
|
279 |
+
{
|
280 |
+
// Recompute (intermediate) results for the 3D covariance computation.
|
281 |
+
glm::vec4 q = rot;// / glm::length(rot);
|
282 |
+
float r = q.x;
|
283 |
+
float x = q.y;
|
284 |
+
float y = q.z;
|
285 |
+
float z = q.w;
|
286 |
+
|
287 |
+
glm::mat3 R = glm::mat3(
|
288 |
+
1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y),
|
289 |
+
2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
|
290 |
+
2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y)
|
291 |
+
);
|
292 |
+
|
293 |
+
glm::mat3 S = glm::mat3(1.0f);
|
294 |
+
|
295 |
+
glm::vec3 s = mod * scale;
|
296 |
+
S[0][0] = s.x;
|
297 |
+
S[1][1] = s.y;
|
298 |
+
S[2][2] = s.z;
|
299 |
+
|
300 |
+
glm::mat3 M = S * R;
|
301 |
+
|
302 |
+
const float* dL_dcov3D = dL_dcov3Ds + 6 * idx;
|
303 |
+
|
304 |
+
glm::vec3 dunc(dL_dcov3D[0], dL_dcov3D[3], dL_dcov3D[5]);
|
305 |
+
glm::vec3 ounc = 0.5f * glm::vec3(dL_dcov3D[1], dL_dcov3D[2], dL_dcov3D[4]);
|
306 |
+
|
307 |
+
// Convert per-element covariance loss gradients to matrix form
|
308 |
+
glm::mat3 dL_dSigma = glm::mat3(
|
309 |
+
dL_dcov3D[0], 0.5f * dL_dcov3D[1], 0.5f * dL_dcov3D[2],
|
310 |
+
0.5f * dL_dcov3D[1], dL_dcov3D[3], 0.5f * dL_dcov3D[4],
|
311 |
+
0.5f * dL_dcov3D[2], 0.5f * dL_dcov3D[4], dL_dcov3D[5]
|
312 |
+
);
|
313 |
+
|
314 |
+
// Compute loss gradient w.r.t. matrix M
|
315 |
+
// dSigma_dM = 2 * M
|
316 |
+
glm::mat3 dL_dM = 2.0f * M * dL_dSigma;
|
317 |
+
|
318 |
+
glm::mat3 Rt = glm::transpose(R);
|
319 |
+
glm::mat3 dL_dMt = glm::transpose(dL_dM);
|
320 |
+
|
321 |
+
// Gradients of loss w.r.t. scale
|
322 |
+
glm::vec3* dL_dscale = dL_dscales + idx;
|
323 |
+
dL_dscale->x = glm::dot(Rt[0], dL_dMt[0]);
|
324 |
+
dL_dscale->y = glm::dot(Rt[1], dL_dMt[1]);
|
325 |
+
dL_dscale->z = glm::dot(Rt[2], dL_dMt[2]);
|
326 |
+
|
327 |
+
dL_dMt[0] *= s.x;
|
328 |
+
dL_dMt[1] *= s.y;
|
329 |
+
dL_dMt[2] *= s.z;
|
330 |
+
|
331 |
+
// Gradients of loss w.r.t. normalized quaternion
|
332 |
+
glm::vec4 dL_dq;
|
333 |
+
dL_dq.x = 2 * z * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * y * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * x * (dL_dMt[1][2] - dL_dMt[2][1]);
|
334 |
+
dL_dq.y = 2 * y * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * z * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * r * (dL_dMt[1][2] - dL_dMt[2][1]) - 4 * x * (dL_dMt[2][2] + dL_dMt[1][1]);
|
335 |
+
dL_dq.z = 2 * x * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * r * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * z * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * y * (dL_dMt[2][2] + dL_dMt[0][0]);
|
336 |
+
dL_dq.w = 2 * r * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * x * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * y * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * z * (dL_dMt[1][1] + dL_dMt[0][0]);
|
337 |
+
|
338 |
+
// Gradients of loss w.r.t. unnormalized quaternion
|
339 |
+
float4* dL_drot = (float4*)(dL_drots + idx);
|
340 |
+
*dL_drot = float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w };//dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w }, float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w });
|
341 |
+
}
|
342 |
+
|
343 |
+
// Backward pass of the preprocessing steps, except
|
344 |
+
// for the covariance computation and inversion
|
345 |
+
// (those are handled by a previous kernel call)
|
346 |
+
template<int C>
|
347 |
+
__global__ void preprocessCUDA(
|
348 |
+
int P, int D, int M,
|
349 |
+
const float3* means,
|
350 |
+
const int* radii,
|
351 |
+
const float* shs,
|
352 |
+
const bool* clamped,
|
353 |
+
const glm::vec3* scales,
|
354 |
+
const glm::vec4* rotations,
|
355 |
+
const float scale_modifier,
|
356 |
+
const float* proj,
|
357 |
+
const glm::vec3* campos,
|
358 |
+
const float3* dL_dmean2D,
|
359 |
+
glm::vec3* dL_dmeans,
|
360 |
+
float* dL_dcolor,
|
361 |
+
float* dL_dcov3D,
|
362 |
+
float* dL_dsh,
|
363 |
+
glm::vec3* dL_dscale,
|
364 |
+
glm::vec4* dL_drot)
|
365 |
+
{
|
366 |
+
auto idx = cg::this_grid().thread_rank();
|
367 |
+
if (idx >= P || !(radii[idx] > 0))
|
368 |
+
return;
|
369 |
+
|
370 |
+
float3 m = means[idx];
|
371 |
+
|
372 |
+
// Taking care of gradients from the screenspace points
|
373 |
+
float4 m_hom = transformPoint4x4(m, proj);
|
374 |
+
float m_w = 1.0f / (m_hom.w + 0.0000001f);
|
375 |
+
|
376 |
+
// Compute loss gradient w.r.t. 3D means due to gradients of 2D means
|
377 |
+
// from rendering procedure
|
378 |
+
glm::vec3 dL_dmean;
|
379 |
+
float mul1 = (proj[0] * m.x + proj[4] * m.y + proj[8] * m.z + proj[12]) * m_w * m_w;
|
380 |
+
float mul2 = (proj[1] * m.x + proj[5] * m.y + proj[9] * m.z + proj[13]) * m_w * m_w;
|
381 |
+
dL_dmean.x = (proj[0] * m_w - proj[3] * mul1) * dL_dmean2D[idx].x + (proj[1] * m_w - proj[3] * mul2) * dL_dmean2D[idx].y;
|
382 |
+
dL_dmean.y = (proj[4] * m_w - proj[7] * mul1) * dL_dmean2D[idx].x + (proj[5] * m_w - proj[7] * mul2) * dL_dmean2D[idx].y;
|
383 |
+
dL_dmean.z = (proj[8] * m_w - proj[11] * mul1) * dL_dmean2D[idx].x + (proj[9] * m_w - proj[11] * mul2) * dL_dmean2D[idx].y;
|
384 |
+
|
385 |
+
// That's the second part of the mean gradient. Previous computation
|
386 |
+
// of cov2D and following SH conversion also affects it.
|
387 |
+
dL_dmeans[idx] += dL_dmean;
|
388 |
+
|
389 |
+
// Compute gradient updates due to computing colors from SHs
|
390 |
+
if (shs)
|
391 |
+
computeColorFromSH(idx, D, M, (glm::vec3*)means, *campos, shs, clamped, (glm::vec3*)dL_dcolor, (glm::vec3*)dL_dmeans, (glm::vec3*)dL_dsh);
|
392 |
+
|
393 |
+
// Compute gradient updates due to computing covariance from scale/rotation
|
394 |
+
if (scales)
|
395 |
+
computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D, dL_dscale, dL_drot);
|
396 |
+
}
|
397 |
+
|
398 |
+
// Backward version of the rendering procedure.
|
399 |
+
template <uint32_t C>
|
400 |
+
__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
|
401 |
+
renderCUDA(
|
402 |
+
const uint2* __restrict__ ranges,
|
403 |
+
const uint32_t* __restrict__ point_list,
|
404 |
+
int W, int H,
|
405 |
+
const float* __restrict__ bg_color,
|
406 |
+
const float2* __restrict__ points_xy_image,
|
407 |
+
const float4* __restrict__ conic_opacity,
|
408 |
+
const float* __restrict__ colors,
|
409 |
+
const float* __restrict__ final_Ts,
|
410 |
+
const uint32_t* __restrict__ n_contrib,
|
411 |
+
const float* __restrict__ dL_dpixels,
|
412 |
+
float3* __restrict__ dL_dmean2D,
|
413 |
+
float4* __restrict__ dL_dconic2D,
|
414 |
+
float* __restrict__ dL_dopacity,
|
415 |
+
float* __restrict__ dL_dcolors)
|
416 |
+
{
|
417 |
+
// We rasterize again. Compute necessary block info.
|
418 |
+
auto block = cg::this_thread_block();
|
419 |
+
const uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
|
420 |
+
const uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y };
|
421 |
+
const uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) };
|
422 |
+
const uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y };
|
423 |
+
const uint32_t pix_id = W * pix.y + pix.x;
|
424 |
+
const float2 pixf = { (float)pix.x, (float)pix.y };
|
425 |
+
|
426 |
+
const bool inside = pix.x < W&& pix.y < H;
|
427 |
+
const uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
|
428 |
+
|
429 |
+
const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
|
430 |
+
|
431 |
+
bool done = !inside;
|
432 |
+
int toDo = range.y - range.x;
|
433 |
+
|
434 |
+
__shared__ int collected_id[BLOCK_SIZE];
|
435 |
+
__shared__ float2 collected_xy[BLOCK_SIZE];
|
436 |
+
__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
|
437 |
+
__shared__ float collected_colors[C * BLOCK_SIZE];
|
438 |
+
|
439 |
+
// In the forward, we stored the final value for T, the
|
440 |
+
// product of all (1 - alpha) factors.
|
441 |
+
const float T_final = inside ? final_Ts[pix_id] : 0;
|
442 |
+
float T = T_final;
|
443 |
+
|
444 |
+
// We start from the back. The ID of the last contributing
|
445 |
+
// Gaussian is known from each pixel from the forward.
|
446 |
+
uint32_t contributor = toDo;
|
447 |
+
const int last_contributor = inside ? n_contrib[pix_id] : 0;
|
448 |
+
|
449 |
+
float accum_rec[C] = { 0 };
|
450 |
+
float dL_dpixel[C];
|
451 |
+
if (inside)
|
452 |
+
for (int i = 0; i < C; i++)
|
453 |
+
dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
|
454 |
+
|
455 |
+
float last_alpha = 0;
|
456 |
+
float last_color[C] = { 0 };
|
457 |
+
|
458 |
+
// Gradient of pixel coordinate w.r.t. normalized
|
459 |
+
// screen-space viewport corrdinates (-1 to 1)
|
460 |
+
const float ddelx_dx = 0.5 * W;
|
461 |
+
const float ddely_dy = 0.5 * H;
|
462 |
+
|
463 |
+
// Traverse all Gaussians
|
464 |
+
for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
|
465 |
+
{
|
466 |
+
// Load auxiliary data into shared memory, start in the BACK
|
467 |
+
// and load them in revers order.
|
468 |
+
block.sync();
|
469 |
+
const int progress = i * BLOCK_SIZE + block.thread_rank();
|
470 |
+
if (range.x + progress < range.y)
|
471 |
+
{
|
472 |
+
const int coll_id = point_list[range.y - progress - 1];
|
473 |
+
collected_id[block.thread_rank()] = coll_id;
|
474 |
+
collected_xy[block.thread_rank()] = points_xy_image[coll_id];
|
475 |
+
collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
|
476 |
+
for (int i = 0; i < C; i++)
|
477 |
+
collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i];
|
478 |
+
}
|
479 |
+
block.sync();
|
480 |
+
|
481 |
+
// Iterate over Gaussians
|
482 |
+
for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
|
483 |
+
{
|
484 |
+
// Keep track of current Gaussian ID. Skip, if this one
|
485 |
+
// is behind the last contributor for this pixel.
|
486 |
+
contributor--;
|
487 |
+
if (contributor >= last_contributor)
|
488 |
+
continue;
|
489 |
+
|
490 |
+
// Compute blending values, as before.
|
491 |
+
const float2 xy = collected_xy[j];
|
492 |
+
const float2 d = { xy.x - pixf.x, xy.y - pixf.y };
|
493 |
+
const float4 con_o = collected_conic_opacity[j];
|
494 |
+
const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
|
495 |
+
if (power > 0.0f)
|
496 |
+
continue;
|
497 |
+
|
498 |
+
const float G = exp(power);
|
499 |
+
const float alpha = min(0.99f, con_o.w * G);
|
500 |
+
if (alpha < 1.0f / 255.0f)
|
501 |
+
continue;
|
502 |
+
|
503 |
+
T = T / (1.f - alpha);
|
504 |
+
const float dchannel_dcolor = alpha * T;
|
505 |
+
|
506 |
+
// Propagate gradients to per-Gaussian colors and keep
|
507 |
+
// gradients w.r.t. alpha (blending factor for a Gaussian/pixel
|
508 |
+
// pair).
|
509 |
+
float dL_dalpha = 0.0f;
|
510 |
+
const int global_id = collected_id[j];
|
511 |
+
for (int ch = 0; ch < C; ch++)
|
512 |
+
{
|
513 |
+
const float c = collected_colors[ch * BLOCK_SIZE + j];
|
514 |
+
// Update last color (to be used in the next iteration)
|
515 |
+
accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch];
|
516 |
+
last_color[ch] = c;
|
517 |
+
|
518 |
+
const float dL_dchannel = dL_dpixel[ch];
|
519 |
+
dL_dalpha += (c - accum_rec[ch]) * dL_dchannel;
|
520 |
+
// Update the gradients w.r.t. color of the Gaussian.
|
521 |
+
// Atomic, since this pixel is just one of potentially
|
522 |
+
// many that were affected by this Gaussian.
|
523 |
+
atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel);
|
524 |
+
}
|
525 |
+
dL_dalpha *= T;
|
526 |
+
// Update last alpha (to be used in the next iteration)
|
527 |
+
last_alpha = alpha;
|
528 |
+
|
529 |
+
// Account for fact that alpha also influences how much of
|
530 |
+
// the background color is added if nothing left to blend
|
531 |
+
float bg_dot_dpixel = 0;
|
532 |
+
for (int i = 0; i < C; i++)
|
533 |
+
bg_dot_dpixel += bg_color[i] * dL_dpixel[i];
|
534 |
+
dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel;
|
535 |
+
|
536 |
+
|
537 |
+
// Helpful reusable temporary variables
|
538 |
+
const float dL_dG = con_o.w * dL_dalpha;
|
539 |
+
const float gdx = G * d.x;
|
540 |
+
const float gdy = G * d.y;
|
541 |
+
const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y;
|
542 |
+
const float dG_ddely = -gdy * con_o.z - gdx * con_o.y;
|
543 |
+
|
544 |
+
// Update gradients w.r.t. 2D mean position of the Gaussian
|
545 |
+
atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx);
|
546 |
+
atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy);
|
547 |
+
|
548 |
+
// Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric)
|
549 |
+
atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG);
|
550 |
+
atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG);
|
551 |
+
atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG);
|
552 |
+
|
553 |
+
// Update gradients w.r.t. opacity of the Gaussian
|
554 |
+
atomicAdd(&(dL_dopacity[global_id]), G * dL_dalpha);
|
555 |
+
}
|
556 |
+
}
|
557 |
+
}
|
558 |
+
|
559 |
+
void BACKWARD::preprocess(
|
560 |
+
int P, int D, int M,
|
561 |
+
const float3* means3D,
|
562 |
+
const int* radii,
|
563 |
+
const float* shs,
|
564 |
+
const bool* clamped,
|
565 |
+
const glm::vec3* scales,
|
566 |
+
const glm::vec4* rotations,
|
567 |
+
const float scale_modifier,
|
568 |
+
const float* cov3Ds,
|
569 |
+
const float* viewmatrix,
|
570 |
+
const float* projmatrix,
|
571 |
+
const float focal_x, float focal_y,
|
572 |
+
const float tan_fovx, float tan_fovy,
|
573 |
+
const glm::vec3* campos,
|
574 |
+
const float3* dL_dmean2D,
|
575 |
+
const float* dL_dconic,
|
576 |
+
glm::vec3* dL_dmean3D,
|
577 |
+
float* dL_dcolor,
|
578 |
+
float* dL_dcov3D,
|
579 |
+
float* dL_dsh,
|
580 |
+
glm::vec3* dL_dscale,
|
581 |
+
glm::vec4* dL_drot)
|
582 |
+
{
|
583 |
+
// Propagate gradients for the path of 2D conic matrix computation.
|
584 |
+
// Somewhat long, thus it is its own kernel rather than being part of
|
585 |
+
// "preprocess". When done, loss gradient w.r.t. 3D means has been
|
586 |
+
// modified and gradient w.r.t. 3D covariance matrix has been computed.
|
587 |
+
computeCov2DCUDA << <(P + 255) / 256, 256 >> > (
|
588 |
+
P,
|
589 |
+
means3D,
|
590 |
+
radii,
|
591 |
+
cov3Ds,
|
592 |
+
focal_x,
|
593 |
+
focal_y,
|
594 |
+
tan_fovx,
|
595 |
+
tan_fovy,
|
596 |
+
viewmatrix,
|
597 |
+
dL_dconic,
|
598 |
+
(float3*)dL_dmean3D,
|
599 |
+
dL_dcov3D);
|
600 |
+
|
601 |
+
// Propagate gradients for remaining steps: finish 3D mean gradients,
|
602 |
+
// propagate color gradients to SH (if desireD), propagate 3D covariance
|
603 |
+
// matrix gradients to scale and rotation.
|
604 |
+
preprocessCUDA<NUM_CHANNELS> << < (P + 255) / 256, 256 >> > (
|
605 |
+
P, D, M,
|
606 |
+
(float3*)means3D,
|
607 |
+
radii,
|
608 |
+
shs,
|
609 |
+
clamped,
|
610 |
+
(glm::vec3*)scales,
|
611 |
+
(glm::vec4*)rotations,
|
612 |
+
scale_modifier,
|
613 |
+
projmatrix,
|
614 |
+
campos,
|
615 |
+
(float3*)dL_dmean2D,
|
616 |
+
(glm::vec3*)dL_dmean3D,
|
617 |
+
dL_dcolor,
|
618 |
+
dL_dcov3D,
|
619 |
+
dL_dsh,
|
620 |
+
dL_dscale,
|
621 |
+
dL_drot);
|
622 |
+
}
|
623 |
+
|
624 |
+
void BACKWARD::render(
|
625 |
+
const dim3 grid, const dim3 block,
|
626 |
+
const uint2* ranges,
|
627 |
+
const uint32_t* point_list,
|
628 |
+
int W, int H,
|
629 |
+
const float* bg_color,
|
630 |
+
const float2* means2D,
|
631 |
+
const float4* conic_opacity,
|
632 |
+
const float* colors,
|
633 |
+
const float* final_Ts,
|
634 |
+
const uint32_t* n_contrib,
|
635 |
+
const float* dL_dpixels,
|
636 |
+
float3* dL_dmean2D,
|
637 |
+
float4* dL_dconic2D,
|
638 |
+
float* dL_dopacity,
|
639 |
+
float* dL_dcolors)
|
640 |
+
{
|
641 |
+
renderCUDA<NUM_CHANNELS> << <grid, block >> >(
|
642 |
+
ranges,
|
643 |
+
point_list,
|
644 |
+
W, H,
|
645 |
+
bg_color,
|
646 |
+
means2D,
|
647 |
+
conic_opacity,
|
648 |
+
colors,
|
649 |
+
final_Ts,
|
650 |
+
n_contrib,
|
651 |
+
dL_dpixels,
|
652 |
+
dL_dmean2D,
|
653 |
+
dL_dconic2D,
|
654 |
+
dL_dopacity,
|
655 |
+
dL_dcolors
|
656 |
+
);
|
657 |
+
}
|
diff-gaussian-rasterization/cuda_rasterizer/backward.h
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED
|
13 |
+
#define CUDA_RASTERIZER_BACKWARD_H_INCLUDED
|
14 |
+
|
15 |
+
#include <cuda.h>
|
16 |
+
#include "cuda_runtime.h"
|
17 |
+
#include "device_launch_parameters.h"
|
18 |
+
#define GLM_FORCE_CUDA
|
19 |
+
#include <glm/glm.hpp>
|
20 |
+
|
21 |
+
namespace BACKWARD
|
22 |
+
{
|
23 |
+
void render(
|
24 |
+
const dim3 grid, dim3 block,
|
25 |
+
const uint2* ranges,
|
26 |
+
const uint32_t* point_list,
|
27 |
+
int W, int H,
|
28 |
+
const float* bg_color,
|
29 |
+
const float2* means2D,
|
30 |
+
const float4* conic_opacity,
|
31 |
+
const float* colors,
|
32 |
+
const float* final_Ts,
|
33 |
+
const uint32_t* n_contrib,
|
34 |
+
const float* dL_dpixels,
|
35 |
+
float3* dL_dmean2D,
|
36 |
+
float4* dL_dconic2D,
|
37 |
+
float* dL_dopacity,
|
38 |
+
float* dL_dcolors);
|
39 |
+
|
40 |
+
void preprocess(
|
41 |
+
int P, int D, int M,
|
42 |
+
const float3* means,
|
43 |
+
const int* radii,
|
44 |
+
const float* shs,
|
45 |
+
const bool* clamped,
|
46 |
+
const glm::vec3* scales,
|
47 |
+
const glm::vec4* rotations,
|
48 |
+
const float scale_modifier,
|
49 |
+
const float* cov3Ds,
|
50 |
+
const float* view,
|
51 |
+
const float* proj,
|
52 |
+
const float focal_x, float focal_y,
|
53 |
+
const float tan_fovx, float tan_fovy,
|
54 |
+
const glm::vec3* campos,
|
55 |
+
const float3* dL_dmean2D,
|
56 |
+
const float* dL_dconics,
|
57 |
+
glm::vec3* dL_dmeans,
|
58 |
+
float* dL_dcolor,
|
59 |
+
float* dL_dcov3D,
|
60 |
+
float* dL_dsh,
|
61 |
+
glm::vec3* dL_dscale,
|
62 |
+
glm::vec4* dL_drot);
|
63 |
+
}
|
64 |
+
|
65 |
+
#endif
|
diff-gaussian-rasterization/cuda_rasterizer/config.h
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED
|
13 |
+
#define CUDA_RASTERIZER_CONFIG_H_INCLUDED
|
14 |
+
|
15 |
+
#define NUM_CHANNELS 3 // Default 3, RGB
|
16 |
+
#define BLOCK_X 16
|
17 |
+
#define BLOCK_Y 16
|
18 |
+
|
19 |
+
#endif
|
diff-gaussian-rasterization/cuda_rasterizer/forward.cu
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include "forward.h"
|
13 |
+
#include "auxiliary.h"
|
14 |
+
#include <cooperative_groups.h>
|
15 |
+
#include <cooperative_groups/reduce.h>
|
16 |
+
namespace cg = cooperative_groups;
|
17 |
+
|
18 |
+
// Forward method for converting the input spherical harmonics
|
19 |
+
// coefficients of each Gaussian to a simple RGB color.
|
20 |
+
__device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped)
|
21 |
+
{
|
22 |
+
// The implementation is loosely based on code for
|
23 |
+
// "Differentiable Point-Based Radiance Fields for
|
24 |
+
// Efficient View Synthesis" by Zhang et al. (2022)
|
25 |
+
glm::vec3 pos = means[idx];
|
26 |
+
glm::vec3 dir = pos - campos;
|
27 |
+
dir = dir / glm::length(dir);
|
28 |
+
|
29 |
+
glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
|
30 |
+
glm::vec3 result = SH_C0 * sh[0];
|
31 |
+
|
32 |
+
if (deg > 0)
|
33 |
+
{
|
34 |
+
float x = dir.x;
|
35 |
+
float y = dir.y;
|
36 |
+
float z = dir.z;
|
37 |
+
result = result - SH_C1 * y * sh[1] + SH_C1 * z * sh[2] - SH_C1 * x * sh[3];
|
38 |
+
|
39 |
+
if (deg > 1)
|
40 |
+
{
|
41 |
+
float xx = x * x, yy = y * y, zz = z * z;
|
42 |
+
float xy = x * y, yz = y * z, xz = x * z;
|
43 |
+
result = result +
|
44 |
+
SH_C2[0] * xy * sh[4] +
|
45 |
+
SH_C2[1] * yz * sh[5] +
|
46 |
+
SH_C2[2] * (2.0f * zz - xx - yy) * sh[6] +
|
47 |
+
SH_C2[3] * xz * sh[7] +
|
48 |
+
SH_C2[4] * (xx - yy) * sh[8];
|
49 |
+
|
50 |
+
if (deg > 2)
|
51 |
+
{
|
52 |
+
result = result +
|
53 |
+
SH_C3[0] * y * (3.0f * xx - yy) * sh[9] +
|
54 |
+
SH_C3[1] * xy * z * sh[10] +
|
55 |
+
SH_C3[2] * y * (4.0f * zz - xx - yy) * sh[11] +
|
56 |
+
SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh[12] +
|
57 |
+
SH_C3[4] * x * (4.0f * zz - xx - yy) * sh[13] +
|
58 |
+
SH_C3[5] * z * (xx - yy) * sh[14] +
|
59 |
+
SH_C3[6] * x * (xx - 3.0f * yy) * sh[15];
|
60 |
+
}
|
61 |
+
}
|
62 |
+
}
|
63 |
+
result += 0.5f;
|
64 |
+
|
65 |
+
// RGB colors are clamped to positive values. If values are
|
66 |
+
// clamped, we need to keep track of this for the backward pass.
|
67 |
+
clamped[3 * idx + 0] = (result.x < 0);
|
68 |
+
clamped[3 * idx + 1] = (result.y < 0);
|
69 |
+
clamped[3 * idx + 2] = (result.z < 0);
|
70 |
+
return glm::max(result, 0.0f);
|
71 |
+
}
|
72 |
+
|
73 |
+
// Forward version of 2D covariance matrix computation
|
74 |
+
__device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix)
|
75 |
+
{
|
76 |
+
// The following models the steps outlined by equations 29
|
77 |
+
// and 31 in "EWA Splatting" (Zwicker et al., 2002).
|
78 |
+
// Additionally considers aspect / scaling of viewport.
|
79 |
+
// Transposes used to account for row-/column-major conventions.
|
80 |
+
float3 t = transformPoint4x3(mean, viewmatrix);
|
81 |
+
|
82 |
+
const float limx = 1.3f * tan_fovx;
|
83 |
+
const float limy = 1.3f * tan_fovy;
|
84 |
+
const float txtz = t.x / t.z;
|
85 |
+
const float tytz = t.y / t.z;
|
86 |
+
t.x = min(limx, max(-limx, txtz)) * t.z;
|
87 |
+
t.y = min(limy, max(-limy, tytz)) * t.z;
|
88 |
+
|
89 |
+
glm::mat3 J = glm::mat3(
|
90 |
+
focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z),
|
91 |
+
0.0f, focal_y / t.z, -(focal_y * t.y) / (t.z * t.z),
|
92 |
+
0, 0, 0);
|
93 |
+
|
94 |
+
glm::mat3 W = glm::mat3(
|
95 |
+
viewmatrix[0], viewmatrix[4], viewmatrix[8],
|
96 |
+
viewmatrix[1], viewmatrix[5], viewmatrix[9],
|
97 |
+
viewmatrix[2], viewmatrix[6], viewmatrix[10]);
|
98 |
+
|
99 |
+
glm::mat3 T = W * J;
|
100 |
+
|
101 |
+
glm::mat3 Vrk = glm::mat3(
|
102 |
+
cov3D[0], cov3D[1], cov3D[2],
|
103 |
+
cov3D[1], cov3D[3], cov3D[4],
|
104 |
+
cov3D[2], cov3D[4], cov3D[5]);
|
105 |
+
|
106 |
+
glm::mat3 cov = glm::transpose(T) * glm::transpose(Vrk) * T;
|
107 |
+
|
108 |
+
// Apply low-pass filter: every Gaussian should be at least
|
109 |
+
// one pixel wide/high. Discard 3rd row and column.
|
110 |
+
cov[0][0] += 0.3f;
|
111 |
+
cov[1][1] += 0.3f;
|
112 |
+
return { float(cov[0][0]), float(cov[0][1]), float(cov[1][1]) };
|
113 |
+
}
|
114 |
+
|
115 |
+
// Forward method for converting scale and rotation properties of each
|
116 |
+
// Gaussian to a 3D covariance matrix in world space. Also takes care
|
117 |
+
// of quaternion normalization.
|
118 |
+
__device__ void computeCov3D(const glm::vec3 scale, float mod, const glm::vec4 rot, float* cov3D)
|
119 |
+
{
|
120 |
+
// Create scaling matrix
|
121 |
+
glm::mat3 S = glm::mat3(1.0f);
|
122 |
+
S[0][0] = mod * scale.x;
|
123 |
+
S[1][1] = mod * scale.y;
|
124 |
+
S[2][2] = mod * scale.z;
|
125 |
+
|
126 |
+
// Normalize quaternion to get valid rotation
|
127 |
+
glm::vec4 q = rot;// / glm::length(rot);
|
128 |
+
float r = q.x;
|
129 |
+
float x = q.y;
|
130 |
+
float y = q.z;
|
131 |
+
float z = q.w;
|
132 |
+
|
133 |
+
// Compute rotation matrix from quaternion
|
134 |
+
glm::mat3 R = glm::mat3(
|
135 |
+
1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y),
|
136 |
+
2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
|
137 |
+
2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y)
|
138 |
+
);
|
139 |
+
|
140 |
+
glm::mat3 M = S * R;
|
141 |
+
|
142 |
+
// Compute 3D world covariance matrix Sigma
|
143 |
+
glm::mat3 Sigma = glm::transpose(M) * M;
|
144 |
+
|
145 |
+
// Covariance is symmetric, only store upper right
|
146 |
+
cov3D[0] = Sigma[0][0];
|
147 |
+
cov3D[1] = Sigma[0][1];
|
148 |
+
cov3D[2] = Sigma[0][2];
|
149 |
+
cov3D[3] = Sigma[1][1];
|
150 |
+
cov3D[4] = Sigma[1][2];
|
151 |
+
cov3D[5] = Sigma[2][2];
|
152 |
+
}
|
153 |
+
|
154 |
+
// Perform initial steps for each Gaussian prior to rasterization.
|
155 |
+
template<int C>
|
156 |
+
__global__ void preprocessCUDA(int P, int D, int M,
|
157 |
+
const float* orig_points,
|
158 |
+
const glm::vec3* scales,
|
159 |
+
const float scale_modifier,
|
160 |
+
const glm::vec4* rotations,
|
161 |
+
const float* opacities,
|
162 |
+
const float* shs,
|
163 |
+
bool* clamped,
|
164 |
+
const float* cov3D_precomp,
|
165 |
+
const float* colors_precomp,
|
166 |
+
const float* viewmatrix,
|
167 |
+
const float* projmatrix,
|
168 |
+
const glm::vec3* cam_pos,
|
169 |
+
const int W, int H,
|
170 |
+
const float tan_fovx, float tan_fovy,
|
171 |
+
const float focal_x, float focal_y,
|
172 |
+
int* radii,
|
173 |
+
float2* points_xy_image,
|
174 |
+
float* depths,
|
175 |
+
float* cov3Ds,
|
176 |
+
float* rgb,
|
177 |
+
float4* conic_opacity,
|
178 |
+
const dim3 grid,
|
179 |
+
uint32_t* tiles_touched,
|
180 |
+
bool prefiltered)
|
181 |
+
{
|
182 |
+
auto idx = cg::this_grid().thread_rank();
|
183 |
+
if (idx >= P)
|
184 |
+
return;
|
185 |
+
|
186 |
+
// Initialize radius and touched tiles to 0. If this isn't changed,
|
187 |
+
// this Gaussian will not be processed further.
|
188 |
+
radii[idx] = 0;
|
189 |
+
tiles_touched[idx] = 0;
|
190 |
+
|
191 |
+
// Perform near culling, quit if outside.
|
192 |
+
float3 p_view;
|
193 |
+
if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view))
|
194 |
+
return;
|
195 |
+
|
196 |
+
// Transform point by projecting
|
197 |
+
float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
|
198 |
+
float4 p_hom = transformPoint4x4(p_orig, projmatrix);
|
199 |
+
float p_w = 1.0f / (p_hom.w + 0.0000001f);
|
200 |
+
float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
|
201 |
+
|
202 |
+
// If 3D covariance matrix is precomputed, use it, otherwise compute
|
203 |
+
// from scaling and rotation parameters.
|
204 |
+
const float* cov3D;
|
205 |
+
if (cov3D_precomp != nullptr)
|
206 |
+
{
|
207 |
+
cov3D = cov3D_precomp + idx * 6;
|
208 |
+
}
|
209 |
+
else
|
210 |
+
{
|
211 |
+
computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6);
|
212 |
+
cov3D = cov3Ds + idx * 6;
|
213 |
+
}
|
214 |
+
|
215 |
+
// Compute 2D screen-space covariance matrix
|
216 |
+
float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix);
|
217 |
+
|
218 |
+
// Invert covariance (EWA algorithm)
|
219 |
+
float det = (cov.x * cov.z - cov.y * cov.y);
|
220 |
+
if (det == 0.0f)
|
221 |
+
return;
|
222 |
+
float det_inv = 1.f / det;
|
223 |
+
float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv };
|
224 |
+
|
225 |
+
// Compute extent in screen space (by finding eigenvalues of
|
226 |
+
// 2D covariance matrix). Use extent to compute a bounding rectangle
|
227 |
+
// of screen-space tiles that this Gaussian overlaps with. Quit if
|
228 |
+
// rectangle covers 0 tiles.
|
229 |
+
float mid = 0.5f * (cov.x + cov.z);
|
230 |
+
float lambda1 = mid + sqrt(max(0.1f, mid * mid - det));
|
231 |
+
float lambda2 = mid - sqrt(max(0.1f, mid * mid - det));
|
232 |
+
float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));
|
233 |
+
float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) };
|
234 |
+
uint2 rect_min, rect_max;
|
235 |
+
getRect(point_image, my_radius, rect_min, rect_max, grid);
|
236 |
+
if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0)
|
237 |
+
return;
|
238 |
+
|
239 |
+
// If colors have been precomputed, use them, otherwise convert
|
240 |
+
// spherical harmonics coefficients to RGB color.
|
241 |
+
if (colors_precomp == nullptr)
|
242 |
+
{
|
243 |
+
glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped);
|
244 |
+
rgb[idx * C + 0] = result.x;
|
245 |
+
rgb[idx * C + 1] = result.y;
|
246 |
+
rgb[idx * C + 2] = result.z;
|
247 |
+
}
|
248 |
+
|
249 |
+
// Store some useful helper data for the next steps.
|
250 |
+
depths[idx] = p_view.z;
|
251 |
+
radii[idx] = my_radius;
|
252 |
+
points_xy_image[idx] = point_image;
|
253 |
+
// Inverse 2D covariance and opacity neatly pack into one float4
|
254 |
+
conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[idx] };
|
255 |
+
tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x);
|
256 |
+
}
|
257 |
+
|
258 |
+
// Main rasterization method. Collaboratively works on one tile per
|
259 |
+
// block, each thread treats one pixel. Alternates between fetching
|
260 |
+
// and rasterizing data.
|
261 |
+
template <uint32_t CHANNELS>
|
262 |
+
__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
|
263 |
+
renderCUDA(
|
264 |
+
const uint2* __restrict__ ranges,
|
265 |
+
const uint32_t* __restrict__ point_list,
|
266 |
+
int W, int H,
|
267 |
+
const float2* __restrict__ points_xy_image,
|
268 |
+
const float* __restrict__ features,
|
269 |
+
const float4* __restrict__ conic_opacity,
|
270 |
+
float* __restrict__ final_T,
|
271 |
+
uint32_t* __restrict__ n_contrib,
|
272 |
+
const float* __restrict__ bg_color,
|
273 |
+
float* __restrict__ out_color)
|
274 |
+
{
|
275 |
+
// Identify current tile and associated min/max pixel range.
|
276 |
+
auto block = cg::this_thread_block();
|
277 |
+
uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
|
278 |
+
uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y };
|
279 |
+
uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) };
|
280 |
+
uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y };
|
281 |
+
uint32_t pix_id = W * pix.y + pix.x;
|
282 |
+
float2 pixf = { (float)pix.x, (float)pix.y };
|
283 |
+
|
284 |
+
// Check if this thread is associated with a valid pixel or outside.
|
285 |
+
bool inside = pix.x < W&& pix.y < H;
|
286 |
+
// Done threads can help with fetching, but don't rasterize
|
287 |
+
bool done = !inside;
|
288 |
+
|
289 |
+
// Load start/end range of IDs to process in bit sorted list.
|
290 |
+
uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
|
291 |
+
const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
|
292 |
+
int toDo = range.y - range.x;
|
293 |
+
|
294 |
+
// Allocate storage for batches of collectively fetched data.
|
295 |
+
__shared__ int collected_id[BLOCK_SIZE];
|
296 |
+
__shared__ float2 collected_xy[BLOCK_SIZE];
|
297 |
+
__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
|
298 |
+
|
299 |
+
// Initialize helper variables
|
300 |
+
float T = 1.0f;
|
301 |
+
uint32_t contributor = 0;
|
302 |
+
uint32_t last_contributor = 0;
|
303 |
+
float C[CHANNELS] = { 0 };
|
304 |
+
|
305 |
+
// Iterate over batches until all done or range is complete
|
306 |
+
for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
|
307 |
+
{
|
308 |
+
// End if entire block votes that it is done rasterizing
|
309 |
+
int num_done = __syncthreads_count(done);
|
310 |
+
if (num_done == BLOCK_SIZE)
|
311 |
+
break;
|
312 |
+
|
313 |
+
// Collectively fetch per-Gaussian data from global to shared
|
314 |
+
int progress = i * BLOCK_SIZE + block.thread_rank();
|
315 |
+
if (range.x + progress < range.y)
|
316 |
+
{
|
317 |
+
int coll_id = point_list[range.x + progress];
|
318 |
+
collected_id[block.thread_rank()] = coll_id;
|
319 |
+
collected_xy[block.thread_rank()] = points_xy_image[coll_id];
|
320 |
+
collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
|
321 |
+
}
|
322 |
+
block.sync();
|
323 |
+
|
324 |
+
// Iterate over current batch
|
325 |
+
for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
|
326 |
+
{
|
327 |
+
// Keep track of current position in range
|
328 |
+
contributor++;
|
329 |
+
|
330 |
+
// Resample using conic matrix (cf. "Surface
|
331 |
+
// Splatting" by Zwicker et al., 2001)
|
332 |
+
float2 xy = collected_xy[j];
|
333 |
+
float2 d = { xy.x - pixf.x, xy.y - pixf.y };
|
334 |
+
float4 con_o = collected_conic_opacity[j];
|
335 |
+
float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
|
336 |
+
if (power > 0.0f)
|
337 |
+
continue;
|
338 |
+
|
339 |
+
// Eq. (2) from 3D Gaussian splatting paper.
|
340 |
+
// Obtain alpha by multiplying with Gaussian opacity
|
341 |
+
// and its exponential falloff from mean.
|
342 |
+
// Avoid numerical instabilities (see paper appendix).
|
343 |
+
float alpha = min(0.99f, con_o.w * exp(power));
|
344 |
+
if (alpha < 1.0f / 255.0f)
|
345 |
+
continue;
|
346 |
+
float test_T = T * (1 - alpha);
|
347 |
+
if (test_T < 0.0001f)
|
348 |
+
{
|
349 |
+
done = true;
|
350 |
+
continue;
|
351 |
+
}
|
352 |
+
|
353 |
+
// Eq. (3) from 3D Gaussian splatting paper.
|
354 |
+
for (int ch = 0; ch < CHANNELS; ch++)
|
355 |
+
C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T;
|
356 |
+
|
357 |
+
T = test_T;
|
358 |
+
|
359 |
+
// Keep track of last range entry to update this
|
360 |
+
// pixel.
|
361 |
+
last_contributor = contributor;
|
362 |
+
}
|
363 |
+
}
|
364 |
+
|
365 |
+
// All threads that treat valid pixel write out their final
|
366 |
+
// rendering data to the frame and auxiliary buffers.
|
367 |
+
if (inside)
|
368 |
+
{
|
369 |
+
final_T[pix_id] = T;
|
370 |
+
n_contrib[pix_id] = last_contributor;
|
371 |
+
for (int ch = 0; ch < CHANNELS; ch++)
|
372 |
+
out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch];
|
373 |
+
}
|
374 |
+
}
|
375 |
+
|
376 |
+
void FORWARD::render(
|
377 |
+
const dim3 grid, dim3 block,
|
378 |
+
const uint2* ranges,
|
379 |
+
const uint32_t* point_list,
|
380 |
+
int W, int H,
|
381 |
+
const float2* means2D,
|
382 |
+
const float* colors,
|
383 |
+
const float4* conic_opacity,
|
384 |
+
float* final_T,
|
385 |
+
uint32_t* n_contrib,
|
386 |
+
const float* bg_color,
|
387 |
+
float* out_color)
|
388 |
+
{
|
389 |
+
renderCUDA<NUM_CHANNELS> << <grid, block >> > (
|
390 |
+
ranges,
|
391 |
+
point_list,
|
392 |
+
W, H,
|
393 |
+
means2D,
|
394 |
+
colors,
|
395 |
+
conic_opacity,
|
396 |
+
final_T,
|
397 |
+
n_contrib,
|
398 |
+
bg_color,
|
399 |
+
out_color);
|
400 |
+
}
|
401 |
+
|
402 |
+
void FORWARD::preprocess(int P, int D, int M,
|
403 |
+
const float* means3D,
|
404 |
+
const glm::vec3* scales,
|
405 |
+
const float scale_modifier,
|
406 |
+
const glm::vec4* rotations,
|
407 |
+
const float* opacities,
|
408 |
+
const float* shs,
|
409 |
+
bool* clamped,
|
410 |
+
const float* cov3D_precomp,
|
411 |
+
const float* colors_precomp,
|
412 |
+
const float* viewmatrix,
|
413 |
+
const float* projmatrix,
|
414 |
+
const glm::vec3* cam_pos,
|
415 |
+
const int W, int H,
|
416 |
+
const float focal_x, float focal_y,
|
417 |
+
const float tan_fovx, float tan_fovy,
|
418 |
+
int* radii,
|
419 |
+
float2* means2D,
|
420 |
+
float* depths,
|
421 |
+
float* cov3Ds,
|
422 |
+
float* rgb,
|
423 |
+
float4* conic_opacity,
|
424 |
+
const dim3 grid,
|
425 |
+
uint32_t* tiles_touched,
|
426 |
+
bool prefiltered)
|
427 |
+
{
|
428 |
+
preprocessCUDA<NUM_CHANNELS> << <(P + 255) / 256, 256 >> > (
|
429 |
+
P, D, M,
|
430 |
+
means3D,
|
431 |
+
scales,
|
432 |
+
scale_modifier,
|
433 |
+
rotations,
|
434 |
+
opacities,
|
435 |
+
shs,
|
436 |
+
clamped,
|
437 |
+
cov3D_precomp,
|
438 |
+
colors_precomp,
|
439 |
+
viewmatrix,
|
440 |
+
projmatrix,
|
441 |
+
cam_pos,
|
442 |
+
W, H,
|
443 |
+
tan_fovx, tan_fovy,
|
444 |
+
focal_x, focal_y,
|
445 |
+
radii,
|
446 |
+
means2D,
|
447 |
+
depths,
|
448 |
+
cov3Ds,
|
449 |
+
rgb,
|
450 |
+
conic_opacity,
|
451 |
+
grid,
|
452 |
+
tiles_touched,
|
453 |
+
prefiltered
|
454 |
+
);
|
455 |
+
}
|
diff-gaussian-rasterization/cuda_rasterizer/forward.h
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED
|
13 |
+
#define CUDA_RASTERIZER_FORWARD_H_INCLUDED
|
14 |
+
|
15 |
+
#include <cuda.h>
|
16 |
+
#include "cuda_runtime.h"
|
17 |
+
#include "device_launch_parameters.h"
|
18 |
+
#define GLM_FORCE_CUDA
|
19 |
+
#include <glm/glm.hpp>
|
20 |
+
|
21 |
+
namespace FORWARD
|
22 |
+
{
|
23 |
+
// Perform initial steps for each Gaussian prior to rasterization.
|
24 |
+
void preprocess(int P, int D, int M,
|
25 |
+
const float* orig_points,
|
26 |
+
const glm::vec3* scales,
|
27 |
+
const float scale_modifier,
|
28 |
+
const glm::vec4* rotations,
|
29 |
+
const float* opacities,
|
30 |
+
const float* shs,
|
31 |
+
bool* clamped,
|
32 |
+
const float* cov3D_precomp,
|
33 |
+
const float* colors_precomp,
|
34 |
+
const float* viewmatrix,
|
35 |
+
const float* projmatrix,
|
36 |
+
const glm::vec3* cam_pos,
|
37 |
+
const int W, int H,
|
38 |
+
const float focal_x, float focal_y,
|
39 |
+
const float tan_fovx, float tan_fovy,
|
40 |
+
int* radii,
|
41 |
+
float2* points_xy_image,
|
42 |
+
float* depths,
|
43 |
+
float* cov3Ds,
|
44 |
+
float* colors,
|
45 |
+
float4* conic_opacity,
|
46 |
+
const dim3 grid,
|
47 |
+
uint32_t* tiles_touched,
|
48 |
+
bool prefiltered);
|
49 |
+
|
50 |
+
// Main rasterization method.
|
51 |
+
void render(
|
52 |
+
const dim3 grid, dim3 block,
|
53 |
+
const uint2* ranges,
|
54 |
+
const uint32_t* point_list,
|
55 |
+
int W, int H,
|
56 |
+
const float2* points_xy_image,
|
57 |
+
const float* features,
|
58 |
+
const float4* conic_opacity,
|
59 |
+
float* final_T,
|
60 |
+
uint32_t* n_contrib,
|
61 |
+
const float* bg_color,
|
62 |
+
float* out_color);
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
#endif
|
diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#ifndef CUDA_RASTERIZER_H_INCLUDED
|
13 |
+
#define CUDA_RASTERIZER_H_INCLUDED
|
14 |
+
|
15 |
+
#include <vector>
|
16 |
+
#include <functional>
|
17 |
+
|
18 |
+
namespace CudaRasterizer
|
19 |
+
{
|
20 |
+
class Rasterizer
|
21 |
+
{
|
22 |
+
public:
|
23 |
+
|
24 |
+
static void markVisible(
|
25 |
+
int P,
|
26 |
+
float* means3D,
|
27 |
+
float* viewmatrix,
|
28 |
+
float* projmatrix,
|
29 |
+
bool* present);
|
30 |
+
|
31 |
+
static int forward(
|
32 |
+
std::function<char* (size_t)> geometryBuffer,
|
33 |
+
std::function<char* (size_t)> binningBuffer,
|
34 |
+
std::function<char* (size_t)> imageBuffer,
|
35 |
+
const int P, int D, int M,
|
36 |
+
const float* background,
|
37 |
+
const int width, int height,
|
38 |
+
const float* means3D,
|
39 |
+
const float* shs,
|
40 |
+
const float* colors_precomp,
|
41 |
+
const float* opacities,
|
42 |
+
const float* scales,
|
43 |
+
const float scale_modifier,
|
44 |
+
const float* rotations,
|
45 |
+
const float* cov3D_precomp,
|
46 |
+
const float* viewmatrix,
|
47 |
+
const float* projmatrix,
|
48 |
+
const float* cam_pos,
|
49 |
+
const float tan_fovx, float tan_fovy,
|
50 |
+
const bool prefiltered,
|
51 |
+
float* out_color,
|
52 |
+
int* radii = nullptr,
|
53 |
+
bool debug = false);
|
54 |
+
|
55 |
+
static void backward(
|
56 |
+
const int P, int D, int M, int R,
|
57 |
+
const float* background,
|
58 |
+
const int width, int height,
|
59 |
+
const float* means3D,
|
60 |
+
const float* shs,
|
61 |
+
const float* colors_precomp,
|
62 |
+
const float* scales,
|
63 |
+
const float scale_modifier,
|
64 |
+
const float* rotations,
|
65 |
+
const float* cov3D_precomp,
|
66 |
+
const float* viewmatrix,
|
67 |
+
const float* projmatrix,
|
68 |
+
const float* campos,
|
69 |
+
const float tan_fovx, float tan_fovy,
|
70 |
+
const int* radii,
|
71 |
+
char* geom_buffer,
|
72 |
+
char* binning_buffer,
|
73 |
+
char* image_buffer,
|
74 |
+
const float* dL_dpix,
|
75 |
+
float* dL_dmean2D,
|
76 |
+
float* dL_dconic,
|
77 |
+
float* dL_dopacity,
|
78 |
+
float* dL_dcolor,
|
79 |
+
float* dL_dmean3D,
|
80 |
+
float* dL_dcov3D,
|
81 |
+
float* dL_dsh,
|
82 |
+
float* dL_dscale,
|
83 |
+
float* dL_drot,
|
84 |
+
bool debug);
|
85 |
+
};
|
86 |
+
};
|
87 |
+
|
88 |
+
#endif
|
diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu
ADDED
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include "rasterizer_impl.h"
|
13 |
+
#include <iostream>
|
14 |
+
#include <fstream>
|
15 |
+
#include <algorithm>
|
16 |
+
#include <numeric>
|
17 |
+
#include <cuda.h>
|
18 |
+
#include "cuda_runtime.h"
|
19 |
+
#include "device_launch_parameters.h"
|
20 |
+
#include <cub/cub.cuh>
|
21 |
+
#include <cub/device/device_radix_sort.cuh>
|
22 |
+
#define GLM_FORCE_CUDA
|
23 |
+
#include <glm/glm.hpp>
|
24 |
+
|
25 |
+
#include <cooperative_groups.h>
|
26 |
+
#include <cooperative_groups/reduce.h>
|
27 |
+
namespace cg = cooperative_groups;
|
28 |
+
|
29 |
+
#include "auxiliary.h"
|
30 |
+
#include "forward.h"
|
31 |
+
#include "backward.h"
|
32 |
+
|
33 |
+
// Helper function to find the next-highest bit of the MSB
|
34 |
+
// on the CPU.
|
35 |
+
uint32_t getHigherMsb(uint32_t n)
|
36 |
+
{
|
37 |
+
uint32_t msb = sizeof(n) * 4;
|
38 |
+
uint32_t step = msb;
|
39 |
+
while (step > 1)
|
40 |
+
{
|
41 |
+
step /= 2;
|
42 |
+
if (n >> msb)
|
43 |
+
msb += step;
|
44 |
+
else
|
45 |
+
msb -= step;
|
46 |
+
}
|
47 |
+
if (n >> msb)
|
48 |
+
msb++;
|
49 |
+
return msb;
|
50 |
+
}
|
51 |
+
|
52 |
+
// Wrapper method to call auxiliary coarse frustum containment test.
|
53 |
+
// Mark all Gaussians that pass it.
|
54 |
+
__global__ void checkFrustum(int P,
|
55 |
+
const float* orig_points,
|
56 |
+
const float* viewmatrix,
|
57 |
+
const float* projmatrix,
|
58 |
+
bool* present)
|
59 |
+
{
|
60 |
+
auto idx = cg::this_grid().thread_rank();
|
61 |
+
if (idx >= P)
|
62 |
+
return;
|
63 |
+
|
64 |
+
float3 p_view;
|
65 |
+
present[idx] = in_frustum(idx, orig_points, viewmatrix, projmatrix, false, p_view);
|
66 |
+
}
|
67 |
+
|
68 |
+
// Generates one key/value pair for all Gaussian / tile overlaps.
|
69 |
+
// Run once per Gaussian (1:N mapping).
|
70 |
+
__global__ void duplicateWithKeys(
|
71 |
+
int P,
|
72 |
+
const float2* points_xy,
|
73 |
+
const float* depths,
|
74 |
+
const uint32_t* offsets,
|
75 |
+
uint64_t* gaussian_keys_unsorted,
|
76 |
+
uint32_t* gaussian_values_unsorted,
|
77 |
+
int* radii,
|
78 |
+
dim3 grid)
|
79 |
+
{
|
80 |
+
auto idx = cg::this_grid().thread_rank();
|
81 |
+
if (idx >= P)
|
82 |
+
return;
|
83 |
+
|
84 |
+
// Generate no key/value pair for invisible Gaussians
|
85 |
+
if (radii[idx] > 0)
|
86 |
+
{
|
87 |
+
// Find this Gaussian's offset in buffer for writing keys/values.
|
88 |
+
uint32_t off = (idx == 0) ? 0 : offsets[idx - 1];
|
89 |
+
uint2 rect_min, rect_max;
|
90 |
+
|
91 |
+
getRect(points_xy[idx], radii[idx], rect_min, rect_max, grid);
|
92 |
+
|
93 |
+
// For each tile that the bounding rect overlaps, emit a
|
94 |
+
// key/value pair. The key is | tile ID | depth |,
|
95 |
+
// and the value is the ID of the Gaussian. Sorting the values
|
96 |
+
// with this key yields Gaussian IDs in a list, such that they
|
97 |
+
// are first sorted by tile and then by depth.
|
98 |
+
for (int y = rect_min.y; y < rect_max.y; y++)
|
99 |
+
{
|
100 |
+
for (int x = rect_min.x; x < rect_max.x; x++)
|
101 |
+
{
|
102 |
+
uint64_t key = y * grid.x + x;
|
103 |
+
key <<= 32;
|
104 |
+
key |= *((uint32_t*)&depths[idx]);
|
105 |
+
gaussian_keys_unsorted[off] = key;
|
106 |
+
gaussian_values_unsorted[off] = idx;
|
107 |
+
off++;
|
108 |
+
}
|
109 |
+
}
|
110 |
+
}
|
111 |
+
}
|
112 |
+
|
113 |
+
// Check keys to see if it is at the start/end of one tile's range in
|
114 |
+
// the full sorted list. If yes, write start/end of this tile.
|
115 |
+
// Run once per instanced (duplicated) Gaussian ID.
|
116 |
+
__global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* ranges)
|
117 |
+
{
|
118 |
+
auto idx = cg::this_grid().thread_rank();
|
119 |
+
if (idx >= L)
|
120 |
+
return;
|
121 |
+
|
122 |
+
// Read tile ID from key. Update start/end of tile range if at limit.
|
123 |
+
uint64_t key = point_list_keys[idx];
|
124 |
+
uint32_t currtile = key >> 32;
|
125 |
+
if (idx == 0)
|
126 |
+
ranges[currtile].x = 0;
|
127 |
+
else
|
128 |
+
{
|
129 |
+
uint32_t prevtile = point_list_keys[idx - 1] >> 32;
|
130 |
+
if (currtile != prevtile)
|
131 |
+
{
|
132 |
+
ranges[prevtile].y = idx;
|
133 |
+
ranges[currtile].x = idx;
|
134 |
+
}
|
135 |
+
}
|
136 |
+
if (idx == L - 1)
|
137 |
+
ranges[currtile].y = L;
|
138 |
+
}
|
139 |
+
|
140 |
+
// Mark Gaussians as visible/invisible, based on view frustum testing
|
141 |
+
void CudaRasterizer::Rasterizer::markVisible(
|
142 |
+
int P,
|
143 |
+
float* means3D,
|
144 |
+
float* viewmatrix,
|
145 |
+
float* projmatrix,
|
146 |
+
bool* present)
|
147 |
+
{
|
148 |
+
checkFrustum << <(P + 255) / 256, 256 >> > (
|
149 |
+
P,
|
150 |
+
means3D,
|
151 |
+
viewmatrix, projmatrix,
|
152 |
+
present);
|
153 |
+
}
|
154 |
+
|
155 |
+
CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, size_t P)
|
156 |
+
{
|
157 |
+
GeometryState geom;
|
158 |
+
obtain(chunk, geom.depths, P, 128);
|
159 |
+
obtain(chunk, geom.clamped, P * 3, 128);
|
160 |
+
obtain(chunk, geom.internal_radii, P, 128);
|
161 |
+
obtain(chunk, geom.means2D, P, 128);
|
162 |
+
obtain(chunk, geom.cov3D, P * 6, 128);
|
163 |
+
obtain(chunk, geom.conic_opacity, P, 128);
|
164 |
+
obtain(chunk, geom.rgb, P * 3, 128);
|
165 |
+
obtain(chunk, geom.tiles_touched, P, 128);
|
166 |
+
cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P);
|
167 |
+
obtain(chunk, geom.scanning_space, geom.scan_size, 128);
|
168 |
+
obtain(chunk, geom.point_offsets, P, 128);
|
169 |
+
return geom;
|
170 |
+
}
|
171 |
+
|
172 |
+
CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, size_t N)
|
173 |
+
{
|
174 |
+
ImageState img;
|
175 |
+
obtain(chunk, img.accum_alpha, N, 128);
|
176 |
+
obtain(chunk, img.n_contrib, N, 128);
|
177 |
+
obtain(chunk, img.ranges, N, 128);
|
178 |
+
return img;
|
179 |
+
}
|
180 |
+
|
181 |
+
CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, size_t P)
|
182 |
+
{
|
183 |
+
BinningState binning;
|
184 |
+
obtain(chunk, binning.point_list, P, 128);
|
185 |
+
obtain(chunk, binning.point_list_unsorted, P, 128);
|
186 |
+
obtain(chunk, binning.point_list_keys, P, 128);
|
187 |
+
obtain(chunk, binning.point_list_keys_unsorted, P, 128);
|
188 |
+
cub::DeviceRadixSort::SortPairs(
|
189 |
+
nullptr, binning.sorting_size,
|
190 |
+
binning.point_list_keys_unsorted, binning.point_list_keys,
|
191 |
+
binning.point_list_unsorted, binning.point_list, P);
|
192 |
+
obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128);
|
193 |
+
return binning;
|
194 |
+
}
|
195 |
+
|
196 |
+
// Forward rendering procedure for differentiable rasterization
|
197 |
+
// of Gaussians.
|
198 |
+
int CudaRasterizer::Rasterizer::forward(
|
199 |
+
std::function<char* (size_t)> geometryBuffer,
|
200 |
+
std::function<char* (size_t)> binningBuffer,
|
201 |
+
std::function<char* (size_t)> imageBuffer,
|
202 |
+
const int P, int D, int M,
|
203 |
+
const float* background,
|
204 |
+
const int width, int height,
|
205 |
+
const float* means3D,
|
206 |
+
const float* shs,
|
207 |
+
const float* colors_precomp,
|
208 |
+
const float* opacities,
|
209 |
+
const float* scales,
|
210 |
+
const float scale_modifier,
|
211 |
+
const float* rotations,
|
212 |
+
const float* cov3D_precomp,
|
213 |
+
const float* viewmatrix,
|
214 |
+
const float* projmatrix,
|
215 |
+
const float* cam_pos,
|
216 |
+
const float tan_fovx, float tan_fovy,
|
217 |
+
const bool prefiltered,
|
218 |
+
float* out_color,
|
219 |
+
int* radii,
|
220 |
+
bool debug)
|
221 |
+
{
|
222 |
+
const float focal_y = height / (2.0f * tan_fovy);
|
223 |
+
const float focal_x = width / (2.0f * tan_fovx);
|
224 |
+
|
225 |
+
size_t chunk_size = required<GeometryState>(P);
|
226 |
+
char* chunkptr = geometryBuffer(chunk_size);
|
227 |
+
GeometryState geomState = GeometryState::fromChunk(chunkptr, P);
|
228 |
+
|
229 |
+
if (radii == nullptr)
|
230 |
+
{
|
231 |
+
radii = geomState.internal_radii;
|
232 |
+
}
|
233 |
+
|
234 |
+
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
|
235 |
+
dim3 block(BLOCK_X, BLOCK_Y, 1);
|
236 |
+
|
237 |
+
// Dynamically resize image-based auxiliary buffers during training
|
238 |
+
size_t img_chunk_size = required<ImageState>(width * height);
|
239 |
+
char* img_chunkptr = imageBuffer(img_chunk_size);
|
240 |
+
ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height);
|
241 |
+
|
242 |
+
if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
|
243 |
+
{
|
244 |
+
throw std::runtime_error("For non-RGB, provide precomputed Gaussian colors!");
|
245 |
+
}
|
246 |
+
|
247 |
+
// Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB)
|
248 |
+
CHECK_CUDA(FORWARD::preprocess(
|
249 |
+
P, D, M,
|
250 |
+
means3D,
|
251 |
+
(glm::vec3*)scales,
|
252 |
+
scale_modifier,
|
253 |
+
(glm::vec4*)rotations,
|
254 |
+
opacities,
|
255 |
+
shs,
|
256 |
+
geomState.clamped,
|
257 |
+
cov3D_precomp,
|
258 |
+
colors_precomp,
|
259 |
+
viewmatrix, projmatrix,
|
260 |
+
(glm::vec3*)cam_pos,
|
261 |
+
width, height,
|
262 |
+
focal_x, focal_y,
|
263 |
+
tan_fovx, tan_fovy,
|
264 |
+
radii,
|
265 |
+
geomState.means2D,
|
266 |
+
geomState.depths,
|
267 |
+
geomState.cov3D,
|
268 |
+
geomState.rgb,
|
269 |
+
geomState.conic_opacity,
|
270 |
+
tile_grid,
|
271 |
+
geomState.tiles_touched,
|
272 |
+
prefiltered
|
273 |
+
), debug)
|
274 |
+
|
275 |
+
// Compute prefix sum over full list of touched tile counts by Gaussians
|
276 |
+
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
|
277 |
+
CHECK_CUDA(cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size, geomState.tiles_touched, geomState.point_offsets, P), debug)
|
278 |
+
|
279 |
+
// Retrieve total number of Gaussian instances to launch and resize aux buffers
|
280 |
+
int num_rendered;
|
281 |
+
CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost), debug);
|
282 |
+
|
283 |
+
size_t binning_chunk_size = required<BinningState>(num_rendered);
|
284 |
+
char* binning_chunkptr = binningBuffer(binning_chunk_size);
|
285 |
+
BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered);
|
286 |
+
|
287 |
+
// For each instance to be rendered, produce adequate [ tile | depth ] key
|
288 |
+
// and corresponding dublicated Gaussian indices to be sorted
|
289 |
+
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
|
290 |
+
P,
|
291 |
+
geomState.means2D,
|
292 |
+
geomState.depths,
|
293 |
+
geomState.point_offsets,
|
294 |
+
binningState.point_list_keys_unsorted,
|
295 |
+
binningState.point_list_unsorted,
|
296 |
+
radii,
|
297 |
+
tile_grid)
|
298 |
+
CHECK_CUDA(, debug)
|
299 |
+
|
300 |
+
int bit = getHigherMsb(tile_grid.x * tile_grid.y);
|
301 |
+
|
302 |
+
// Sort complete list of (duplicated) Gaussian indices by keys
|
303 |
+
CHECK_CUDA(cub::DeviceRadixSort::SortPairs(
|
304 |
+
binningState.list_sorting_space,
|
305 |
+
binningState.sorting_size,
|
306 |
+
binningState.point_list_keys_unsorted, binningState.point_list_keys,
|
307 |
+
binningState.point_list_unsorted, binningState.point_list,
|
308 |
+
num_rendered, 0, 32 + bit), debug)
|
309 |
+
|
310 |
+
CHECK_CUDA(cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)), debug);
|
311 |
+
|
312 |
+
// Identify start and end of per-tile workloads in sorted list
|
313 |
+
if (num_rendered > 0)
|
314 |
+
identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > (
|
315 |
+
num_rendered,
|
316 |
+
binningState.point_list_keys,
|
317 |
+
imgState.ranges);
|
318 |
+
CHECK_CUDA(, debug)
|
319 |
+
|
320 |
+
// Let each tile blend its range of Gaussians independently in parallel
|
321 |
+
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : geomState.rgb;
|
322 |
+
CHECK_CUDA(FORWARD::render(
|
323 |
+
tile_grid, block,
|
324 |
+
imgState.ranges,
|
325 |
+
binningState.point_list,
|
326 |
+
width, height,
|
327 |
+
geomState.means2D,
|
328 |
+
feature_ptr,
|
329 |
+
geomState.conic_opacity,
|
330 |
+
imgState.accum_alpha,
|
331 |
+
imgState.n_contrib,
|
332 |
+
background,
|
333 |
+
out_color), debug)
|
334 |
+
|
335 |
+
return num_rendered;
|
336 |
+
}
|
337 |
+
|
338 |
+
// Produce necessary gradients for optimization, corresponding
|
339 |
+
// to forward render pass
|
340 |
+
void CudaRasterizer::Rasterizer::backward(
|
341 |
+
const int P, int D, int M, int R,
|
342 |
+
const float* background,
|
343 |
+
const int width, int height,
|
344 |
+
const float* means3D,
|
345 |
+
const float* shs,
|
346 |
+
const float* colors_precomp,
|
347 |
+
const float* scales,
|
348 |
+
const float scale_modifier,
|
349 |
+
const float* rotations,
|
350 |
+
const float* cov3D_precomp,
|
351 |
+
const float* viewmatrix,
|
352 |
+
const float* projmatrix,
|
353 |
+
const float* campos,
|
354 |
+
const float tan_fovx, float tan_fovy,
|
355 |
+
const int* radii,
|
356 |
+
char* geom_buffer,
|
357 |
+
char* binning_buffer,
|
358 |
+
char* img_buffer,
|
359 |
+
const float* dL_dpix,
|
360 |
+
float* dL_dmean2D,
|
361 |
+
float* dL_dconic,
|
362 |
+
float* dL_dopacity,
|
363 |
+
float* dL_dcolor,
|
364 |
+
float* dL_dmean3D,
|
365 |
+
float* dL_dcov3D,
|
366 |
+
float* dL_dsh,
|
367 |
+
float* dL_dscale,
|
368 |
+
float* dL_drot,
|
369 |
+
bool debug)
|
370 |
+
{
|
371 |
+
GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
|
372 |
+
BinningState binningState = BinningState::fromChunk(binning_buffer, R);
|
373 |
+
ImageState imgState = ImageState::fromChunk(img_buffer, width * height);
|
374 |
+
|
375 |
+
if (radii == nullptr)
|
376 |
+
{
|
377 |
+
radii = geomState.internal_radii;
|
378 |
+
}
|
379 |
+
|
380 |
+
const float focal_y = height / (2.0f * tan_fovy);
|
381 |
+
const float focal_x = width / (2.0f * tan_fovx);
|
382 |
+
|
383 |
+
const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
|
384 |
+
const dim3 block(BLOCK_X, BLOCK_Y, 1);
|
385 |
+
|
386 |
+
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
|
387 |
+
// opacity and RGB of Gaussians from per-pixel loss gradients.
|
388 |
+
// If we were given precomputed colors and not SHs, use them.
|
389 |
+
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
|
390 |
+
CHECK_CUDA(BACKWARD::render(
|
391 |
+
tile_grid,
|
392 |
+
block,
|
393 |
+
imgState.ranges,
|
394 |
+
binningState.point_list,
|
395 |
+
width, height,
|
396 |
+
background,
|
397 |
+
geomState.means2D,
|
398 |
+
geomState.conic_opacity,
|
399 |
+
color_ptr,
|
400 |
+
imgState.accum_alpha,
|
401 |
+
imgState.n_contrib,
|
402 |
+
dL_dpix,
|
403 |
+
(float3*)dL_dmean2D,
|
404 |
+
(float4*)dL_dconic,
|
405 |
+
dL_dopacity,
|
406 |
+
dL_dcolor), debug)
|
407 |
+
|
408 |
+
// Take care of the rest of preprocessing. Was the precomputed covariance
|
409 |
+
// given to us or a scales/rot pair? If precomputed, pass that. If not,
|
410 |
+
// use the one we computed ourselves.
|
411 |
+
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D;
|
412 |
+
CHECK_CUDA(BACKWARD::preprocess(P, D, M,
|
413 |
+
(float3*)means3D,
|
414 |
+
radii,
|
415 |
+
shs,
|
416 |
+
geomState.clamped,
|
417 |
+
(glm::vec3*)scales,
|
418 |
+
(glm::vec4*)rotations,
|
419 |
+
scale_modifier,
|
420 |
+
cov3D_ptr,
|
421 |
+
viewmatrix,
|
422 |
+
projmatrix,
|
423 |
+
focal_x, focal_y,
|
424 |
+
tan_fovx, tan_fovy,
|
425 |
+
(glm::vec3*)campos,
|
426 |
+
(float3*)dL_dmean2D,
|
427 |
+
dL_dconic,
|
428 |
+
(glm::vec3*)dL_dmean3D,
|
429 |
+
dL_dcolor,
|
430 |
+
dL_dcov3D,
|
431 |
+
dL_dsh,
|
432 |
+
(glm::vec3*)dL_dscale,
|
433 |
+
(glm::vec4*)dL_drot), debug)
|
434 |
+
}
|
diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
|
14 |
+
#include <iostream>
|
15 |
+
#include <vector>
|
16 |
+
#include "rasterizer.h"
|
17 |
+
#include <cuda_runtime_api.h>
|
18 |
+
|
19 |
+
namespace CudaRasterizer
|
20 |
+
{
|
21 |
+
template <typename T>
|
22 |
+
static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment)
|
23 |
+
{
|
24 |
+
std::size_t offset = (reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) & ~(alignment - 1);
|
25 |
+
ptr = reinterpret_cast<T*>(offset);
|
26 |
+
chunk = reinterpret_cast<char*>(ptr + count);
|
27 |
+
}
|
28 |
+
|
29 |
+
struct GeometryState
|
30 |
+
{
|
31 |
+
size_t scan_size;
|
32 |
+
float* depths;
|
33 |
+
char* scanning_space;
|
34 |
+
bool* clamped;
|
35 |
+
int* internal_radii;
|
36 |
+
float2* means2D;
|
37 |
+
float* cov3D;
|
38 |
+
float4* conic_opacity;
|
39 |
+
float* rgb;
|
40 |
+
uint32_t* point_offsets;
|
41 |
+
uint32_t* tiles_touched;
|
42 |
+
|
43 |
+
static GeometryState fromChunk(char*& chunk, size_t P);
|
44 |
+
};
|
45 |
+
|
46 |
+
struct ImageState
|
47 |
+
{
|
48 |
+
uint2* ranges;
|
49 |
+
uint32_t* n_contrib;
|
50 |
+
float* accum_alpha;
|
51 |
+
|
52 |
+
static ImageState fromChunk(char*& chunk, size_t N);
|
53 |
+
};
|
54 |
+
|
55 |
+
struct BinningState
|
56 |
+
{
|
57 |
+
size_t sorting_size;
|
58 |
+
uint64_t* point_list_keys_unsorted;
|
59 |
+
uint64_t* point_list_keys;
|
60 |
+
uint32_t* point_list_unsorted;
|
61 |
+
uint32_t* point_list;
|
62 |
+
char* list_sorting_space;
|
63 |
+
|
64 |
+
static BinningState fromChunk(char*& chunk, size_t P);
|
65 |
+
};
|
66 |
+
|
67 |
+
template<typename T>
|
68 |
+
size_t required(size_t P)
|
69 |
+
{
|
70 |
+
char* size = nullptr;
|
71 |
+
T::fromChunk(size, P);
|
72 |
+
return ((size_t)size) + 128;
|
73 |
+
}
|
74 |
+
};
|
diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
from typing import NamedTuple
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch
|
15 |
+
from . import _C
|
16 |
+
|
17 |
+
def cpu_deep_copy_tuple(input_tuple):
|
18 |
+
copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple]
|
19 |
+
return tuple(copied_tensors)
|
20 |
+
|
21 |
+
def rasterize_gaussians(
|
22 |
+
means3D,
|
23 |
+
means2D,
|
24 |
+
sh,
|
25 |
+
colors_precomp,
|
26 |
+
opacities,
|
27 |
+
scales,
|
28 |
+
rotations,
|
29 |
+
cov3Ds_precomp,
|
30 |
+
raster_settings,
|
31 |
+
):
|
32 |
+
return _RasterizeGaussians.apply(
|
33 |
+
means3D,
|
34 |
+
means2D,
|
35 |
+
sh,
|
36 |
+
colors_precomp,
|
37 |
+
opacities,
|
38 |
+
scales,
|
39 |
+
rotations,
|
40 |
+
cov3Ds_precomp,
|
41 |
+
raster_settings,
|
42 |
+
)
|
43 |
+
|
44 |
+
class _RasterizeGaussians(torch.autograd.Function):
|
45 |
+
@staticmethod
|
46 |
+
def forward(
|
47 |
+
ctx,
|
48 |
+
means3D,
|
49 |
+
means2D,
|
50 |
+
sh,
|
51 |
+
colors_precomp,
|
52 |
+
opacities,
|
53 |
+
scales,
|
54 |
+
rotations,
|
55 |
+
cov3Ds_precomp,
|
56 |
+
raster_settings,
|
57 |
+
):
|
58 |
+
|
59 |
+
# Restructure arguments the way that the C++ lib expects them
|
60 |
+
args = (
|
61 |
+
raster_settings.bg,
|
62 |
+
means3D,
|
63 |
+
colors_precomp,
|
64 |
+
opacities,
|
65 |
+
scales,
|
66 |
+
rotations,
|
67 |
+
raster_settings.scale_modifier,
|
68 |
+
cov3Ds_precomp,
|
69 |
+
raster_settings.viewmatrix,
|
70 |
+
raster_settings.projmatrix,
|
71 |
+
raster_settings.tanfovx,
|
72 |
+
raster_settings.tanfovy,
|
73 |
+
raster_settings.image_height,
|
74 |
+
raster_settings.image_width,
|
75 |
+
sh,
|
76 |
+
raster_settings.sh_degree,
|
77 |
+
raster_settings.campos,
|
78 |
+
raster_settings.prefiltered,
|
79 |
+
raster_settings.debug
|
80 |
+
)
|
81 |
+
|
82 |
+
# Invoke C++/CUDA rasterizer
|
83 |
+
if raster_settings.debug:
|
84 |
+
cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
|
85 |
+
try:
|
86 |
+
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
|
87 |
+
except Exception as ex:
|
88 |
+
torch.save(cpu_args, "snapshot_fw.dump")
|
89 |
+
print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
|
90 |
+
raise ex
|
91 |
+
else:
|
92 |
+
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
|
93 |
+
|
94 |
+
# Keep relevant tensors for backward
|
95 |
+
ctx.raster_settings = raster_settings
|
96 |
+
ctx.num_rendered = num_rendered
|
97 |
+
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
|
98 |
+
return color, radii
|
99 |
+
|
100 |
+
@staticmethod
|
101 |
+
def backward(ctx, grad_out_color, _):
|
102 |
+
|
103 |
+
# Restore necessary values from context
|
104 |
+
num_rendered = ctx.num_rendered
|
105 |
+
raster_settings = ctx.raster_settings
|
106 |
+
colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors
|
107 |
+
|
108 |
+
# Restructure args as C++ method expects them
|
109 |
+
args = (raster_settings.bg,
|
110 |
+
means3D,
|
111 |
+
radii,
|
112 |
+
colors_precomp,
|
113 |
+
scales,
|
114 |
+
rotations,
|
115 |
+
raster_settings.scale_modifier,
|
116 |
+
cov3Ds_precomp,
|
117 |
+
raster_settings.viewmatrix,
|
118 |
+
raster_settings.projmatrix,
|
119 |
+
raster_settings.tanfovx,
|
120 |
+
raster_settings.tanfovy,
|
121 |
+
grad_out_color,
|
122 |
+
sh,
|
123 |
+
raster_settings.sh_degree,
|
124 |
+
raster_settings.campos,
|
125 |
+
geomBuffer,
|
126 |
+
num_rendered,
|
127 |
+
binningBuffer,
|
128 |
+
imgBuffer,
|
129 |
+
raster_settings.debug)
|
130 |
+
|
131 |
+
# Compute gradients for relevant tensors by invoking backward method
|
132 |
+
if raster_settings.debug:
|
133 |
+
cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
|
134 |
+
try:
|
135 |
+
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
|
136 |
+
except Exception as ex:
|
137 |
+
torch.save(cpu_args, "snapshot_bw.dump")
|
138 |
+
print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n")
|
139 |
+
raise ex
|
140 |
+
else:
|
141 |
+
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
|
142 |
+
|
143 |
+
grads = (
|
144 |
+
grad_means3D,
|
145 |
+
grad_means2D,
|
146 |
+
grad_sh,
|
147 |
+
grad_colors_precomp,
|
148 |
+
grad_opacities,
|
149 |
+
grad_scales,
|
150 |
+
grad_rotations,
|
151 |
+
grad_cov3Ds_precomp,
|
152 |
+
None,
|
153 |
+
)
|
154 |
+
|
155 |
+
return grads
|
156 |
+
|
157 |
+
class GaussianRasterizationSettings(NamedTuple):
|
158 |
+
image_height: int
|
159 |
+
image_width: int
|
160 |
+
tanfovx : float
|
161 |
+
tanfovy : float
|
162 |
+
bg : torch.Tensor
|
163 |
+
scale_modifier : float
|
164 |
+
viewmatrix : torch.Tensor
|
165 |
+
projmatrix : torch.Tensor
|
166 |
+
sh_degree : int
|
167 |
+
campos : torch.Tensor
|
168 |
+
prefiltered : bool
|
169 |
+
debug : bool
|
170 |
+
|
171 |
+
class GaussianRasterizer(nn.Module):
|
172 |
+
def __init__(self, raster_settings):
|
173 |
+
super().__init__()
|
174 |
+
self.raster_settings = raster_settings
|
175 |
+
|
176 |
+
def markVisible(self, positions):
|
177 |
+
# Mark visible points (based on frustum culling for camera) with a boolean
|
178 |
+
with torch.no_grad():
|
179 |
+
raster_settings = self.raster_settings
|
180 |
+
visible = _C.mark_visible(
|
181 |
+
positions,
|
182 |
+
raster_settings.viewmatrix,
|
183 |
+
raster_settings.projmatrix)
|
184 |
+
|
185 |
+
return visible
|
186 |
+
|
187 |
+
def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
|
188 |
+
|
189 |
+
raster_settings = self.raster_settings
|
190 |
+
|
191 |
+
if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
|
192 |
+
raise Exception('Please provide excatly one of either SHs or precomputed colors!')
|
193 |
+
|
194 |
+
if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None):
|
195 |
+
raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')
|
196 |
+
|
197 |
+
if shs is None:
|
198 |
+
shs = torch.Tensor([])
|
199 |
+
if colors_precomp is None:
|
200 |
+
colors_precomp = torch.Tensor([])
|
201 |
+
|
202 |
+
if scales is None:
|
203 |
+
scales = torch.Tensor([])
|
204 |
+
if rotations is None:
|
205 |
+
rotations = torch.Tensor([])
|
206 |
+
if cov3D_precomp is None:
|
207 |
+
cov3D_precomp = torch.Tensor([])
|
208 |
+
|
209 |
+
# Invoke C++/CUDA rasterization routine
|
210 |
+
return rasterize_gaussians(
|
211 |
+
means3D,
|
212 |
+
means2D,
|
213 |
+
shs,
|
214 |
+
colors_precomp,
|
215 |
+
opacities,
|
216 |
+
scales,
|
217 |
+
rotations,
|
218 |
+
cov3D_precomp,
|
219 |
+
raster_settings,
|
220 |
+
)
|
221 |
+
|
diff-gaussian-rasterization/ext.cpp
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include <torch/extension.h>
|
13 |
+
#include "rasterize_points.h"
|
14 |
+
|
15 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
16 |
+
m.def("rasterize_gaussians", &RasterizeGaussiansCUDA);
|
17 |
+
m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA);
|
18 |
+
m.def("mark_visible", &markVisible);
|
19 |
+
}
|
diff-gaussian-rasterization/rasterize_points.cu
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include <math.h>
|
13 |
+
#include <torch/extension.h>
|
14 |
+
#include <cstdio>
|
15 |
+
#include <sstream>
|
16 |
+
#include <iostream>
|
17 |
+
#include <tuple>
|
18 |
+
#include <stdio.h>
|
19 |
+
#include <cuda_runtime_api.h>
|
20 |
+
#include <memory>
|
21 |
+
#include "cuda_rasterizer/config.h"
|
22 |
+
#include "cuda_rasterizer/rasterizer.h"
|
23 |
+
#include <fstream>
|
24 |
+
#include <string>
|
25 |
+
#include <functional>
|
26 |
+
|
27 |
+
std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
|
28 |
+
auto lambda = [&t](size_t N) {
|
29 |
+
t.resize_({(long long)N});
|
30 |
+
return reinterpret_cast<char*>(t.contiguous().data_ptr());
|
31 |
+
};
|
32 |
+
return lambda;
|
33 |
+
}
|
34 |
+
|
35 |
+
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
36 |
+
RasterizeGaussiansCUDA(
|
37 |
+
const torch::Tensor& background,
|
38 |
+
const torch::Tensor& means3D,
|
39 |
+
const torch::Tensor& colors,
|
40 |
+
const torch::Tensor& opacity,
|
41 |
+
const torch::Tensor& scales,
|
42 |
+
const torch::Tensor& rotations,
|
43 |
+
const float scale_modifier,
|
44 |
+
const torch::Tensor& cov3D_precomp,
|
45 |
+
const torch::Tensor& viewmatrix,
|
46 |
+
const torch::Tensor& projmatrix,
|
47 |
+
const float tan_fovx,
|
48 |
+
const float tan_fovy,
|
49 |
+
const int image_height,
|
50 |
+
const int image_width,
|
51 |
+
const torch::Tensor& sh,
|
52 |
+
const int degree,
|
53 |
+
const torch::Tensor& campos,
|
54 |
+
const bool prefiltered,
|
55 |
+
const bool debug)
|
56 |
+
{
|
57 |
+
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
|
58 |
+
AT_ERROR("means3D must have dimensions (num_points, 3)");
|
59 |
+
}
|
60 |
+
|
61 |
+
const int P = means3D.size(0);
|
62 |
+
const int H = image_height;
|
63 |
+
const int W = image_width;
|
64 |
+
|
65 |
+
auto int_opts = means3D.options().dtype(torch::kInt32);
|
66 |
+
auto float_opts = means3D.options().dtype(torch::kFloat32);
|
67 |
+
|
68 |
+
torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
|
69 |
+
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
|
70 |
+
|
71 |
+
torch::Device device(torch::kCUDA);
|
72 |
+
torch::TensorOptions options(torch::kByte);
|
73 |
+
torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
|
74 |
+
torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
|
75 |
+
torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
|
76 |
+
std::function<char*(size_t)> geomFunc = resizeFunctional(geomBuffer);
|
77 |
+
std::function<char*(size_t)> binningFunc = resizeFunctional(binningBuffer);
|
78 |
+
std::function<char*(size_t)> imgFunc = resizeFunctional(imgBuffer);
|
79 |
+
|
80 |
+
int rendered = 0;
|
81 |
+
if(P != 0)
|
82 |
+
{
|
83 |
+
int M = 0;
|
84 |
+
if(sh.size(0) != 0)
|
85 |
+
{
|
86 |
+
M = sh.size(1);
|
87 |
+
}
|
88 |
+
|
89 |
+
rendered = CudaRasterizer::Rasterizer::forward(
|
90 |
+
geomFunc,
|
91 |
+
binningFunc,
|
92 |
+
imgFunc,
|
93 |
+
P, degree, M,
|
94 |
+
background.contiguous().data<float>(),
|
95 |
+
W, H,
|
96 |
+
means3D.contiguous().data<float>(),
|
97 |
+
sh.contiguous().data_ptr<float>(),
|
98 |
+
colors.contiguous().data<float>(),
|
99 |
+
opacity.contiguous().data<float>(),
|
100 |
+
scales.contiguous().data_ptr<float>(),
|
101 |
+
scale_modifier,
|
102 |
+
rotations.contiguous().data_ptr<float>(),
|
103 |
+
cov3D_precomp.contiguous().data<float>(),
|
104 |
+
viewmatrix.contiguous().data<float>(),
|
105 |
+
projmatrix.contiguous().data<float>(),
|
106 |
+
campos.contiguous().data<float>(),
|
107 |
+
tan_fovx,
|
108 |
+
tan_fovy,
|
109 |
+
prefiltered,
|
110 |
+
out_color.contiguous().data<float>(),
|
111 |
+
radii.contiguous().data<int>(),
|
112 |
+
debug);
|
113 |
+
}
|
114 |
+
return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer);
|
115 |
+
}
|
116 |
+
|
117 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
118 |
+
RasterizeGaussiansBackwardCUDA(
|
119 |
+
const torch::Tensor& background,
|
120 |
+
const torch::Tensor& means3D,
|
121 |
+
const torch::Tensor& radii,
|
122 |
+
const torch::Tensor& colors,
|
123 |
+
const torch::Tensor& scales,
|
124 |
+
const torch::Tensor& rotations,
|
125 |
+
const float scale_modifier,
|
126 |
+
const torch::Tensor& cov3D_precomp,
|
127 |
+
const torch::Tensor& viewmatrix,
|
128 |
+
const torch::Tensor& projmatrix,
|
129 |
+
const float tan_fovx,
|
130 |
+
const float tan_fovy,
|
131 |
+
const torch::Tensor& dL_dout_color,
|
132 |
+
const torch::Tensor& sh,
|
133 |
+
const int degree,
|
134 |
+
const torch::Tensor& campos,
|
135 |
+
const torch::Tensor& geomBuffer,
|
136 |
+
const int R,
|
137 |
+
const torch::Tensor& binningBuffer,
|
138 |
+
const torch::Tensor& imageBuffer,
|
139 |
+
const bool debug)
|
140 |
+
{
|
141 |
+
const int P = means3D.size(0);
|
142 |
+
const int H = dL_dout_color.size(1);
|
143 |
+
const int W = dL_dout_color.size(2);
|
144 |
+
|
145 |
+
int M = 0;
|
146 |
+
if(sh.size(0) != 0)
|
147 |
+
{
|
148 |
+
M = sh.size(1);
|
149 |
+
}
|
150 |
+
|
151 |
+
torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options());
|
152 |
+
torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options());
|
153 |
+
torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options());
|
154 |
+
torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options());
|
155 |
+
torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options());
|
156 |
+
torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options());
|
157 |
+
torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options());
|
158 |
+
torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options());
|
159 |
+
torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options());
|
160 |
+
|
161 |
+
if(P != 0)
|
162 |
+
{
|
163 |
+
CudaRasterizer::Rasterizer::backward(P, degree, M, R,
|
164 |
+
background.contiguous().data<float>(),
|
165 |
+
W, H,
|
166 |
+
means3D.contiguous().data<float>(),
|
167 |
+
sh.contiguous().data<float>(),
|
168 |
+
colors.contiguous().data<float>(),
|
169 |
+
scales.data_ptr<float>(),
|
170 |
+
scale_modifier,
|
171 |
+
rotations.data_ptr<float>(),
|
172 |
+
cov3D_precomp.contiguous().data<float>(),
|
173 |
+
viewmatrix.contiguous().data<float>(),
|
174 |
+
projmatrix.contiguous().data<float>(),
|
175 |
+
campos.contiguous().data<float>(),
|
176 |
+
tan_fovx,
|
177 |
+
tan_fovy,
|
178 |
+
radii.contiguous().data<int>(),
|
179 |
+
reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
|
180 |
+
reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
|
181 |
+
reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
|
182 |
+
dL_dout_color.contiguous().data<float>(),
|
183 |
+
dL_dmeans2D.contiguous().data<float>(),
|
184 |
+
dL_dconic.contiguous().data<float>(),
|
185 |
+
dL_dopacity.contiguous().data<float>(),
|
186 |
+
dL_dcolors.contiguous().data<float>(),
|
187 |
+
dL_dmeans3D.contiguous().data<float>(),
|
188 |
+
dL_dcov3D.contiguous().data<float>(),
|
189 |
+
dL_dsh.contiguous().data<float>(),
|
190 |
+
dL_dscales.contiguous().data<float>(),
|
191 |
+
dL_drotations.contiguous().data<float>(),
|
192 |
+
debug);
|
193 |
+
}
|
194 |
+
|
195 |
+
return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations);
|
196 |
+
}
|
197 |
+
|
198 |
+
torch::Tensor markVisible(
|
199 |
+
torch::Tensor& means3D,
|
200 |
+
torch::Tensor& viewmatrix,
|
201 |
+
torch::Tensor& projmatrix)
|
202 |
+
{
|
203 |
+
const int P = means3D.size(0);
|
204 |
+
|
205 |
+
torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
|
206 |
+
|
207 |
+
if(P != 0)
|
208 |
+
{
|
209 |
+
CudaRasterizer::Rasterizer::markVisible(P,
|
210 |
+
means3D.contiguous().data<float>(),
|
211 |
+
viewmatrix.contiguous().data<float>(),
|
212 |
+
projmatrix.contiguous().data<float>(),
|
213 |
+
present.contiguous().data<bool>());
|
214 |
+
}
|
215 |
+
|
216 |
+
return present;
|
217 |
+
}
|
diff-gaussian-rasterization/rasterize_points.h
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact [email protected]
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
#include <torch/extension.h>
|
14 |
+
#include <cstdio>
|
15 |
+
#include <tuple>
|
16 |
+
#include <string>
|
17 |
+
|
18 |
+
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
19 |
+
RasterizeGaussiansCUDA(
|
20 |
+
const torch::Tensor& background,
|
21 |
+
const torch::Tensor& means3D,
|
22 |
+
const torch::Tensor& colors,
|
23 |
+
const torch::Tensor& opacity,
|
24 |
+
const torch::Tensor& scales,
|
25 |
+
const torch::Tensor& rotations,
|
26 |
+
const float scale_modifier,
|
27 |
+
const torch::Tensor& cov3D_precomp,
|
28 |
+
const torch::Tensor& viewmatrix,
|
29 |
+
const torch::Tensor& projmatrix,
|
30 |
+
const float tan_fovx,
|
31 |
+
const float tan_fovy,
|
32 |
+
const int image_height,
|
33 |
+
const int image_width,
|
34 |
+
const torch::Tensor& sh,
|
35 |
+
const int degree,
|
36 |
+
const torch::Tensor& campos,
|
37 |
+
const bool prefiltered,
|
38 |
+
const bool debug);
|
39 |
+
|
40 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
41 |
+
RasterizeGaussiansBackwardCUDA(
|
42 |
+
const torch::Tensor& background,
|
43 |
+
const torch::Tensor& means3D,
|
44 |
+
const torch::Tensor& radii,
|
45 |
+
const torch::Tensor& colors,
|
46 |
+
const torch::Tensor& scales,
|
47 |
+
const torch::Tensor& rotations,
|
48 |
+
const float scale_modifier,
|
49 |
+
const torch::Tensor& cov3D_precomp,
|
50 |
+
const torch::Tensor& viewmatrix,
|
51 |
+
const torch::Tensor& projmatrix,
|
52 |
+
const float tan_fovx,
|
53 |
+
const float tan_fovy,
|
54 |
+
const torch::Tensor& dL_dout_color,
|
55 |
+
const torch::Tensor& sh,
|
56 |
+
const int degree,
|
57 |
+
const torch::Tensor& campos,
|
58 |
+
const torch::Tensor& geomBuffer,
|
59 |
+
const int R,
|
60 |
+
const torch::Tensor& binningBuffer,
|
61 |
+
const torch::Tensor& imageBuffer,
|
62 |
+
const bool debug);
|
63 |
+
|
64 |
+
torch::Tensor markVisible(
|
65 |
+
torch::Tensor& means3D,
|
66 |
+
torch::Tensor& viewmatrix,
|
67 |
+
torch::Tensor& projmatrix);
|
diff-gaussian-rasterization/setup.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
from setuptools import setup
|
13 |
+
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
|
14 |
+
import os
|
15 |
+
os.path.dirname(os.path.abspath(__file__))
|
16 |
+
|
17 |
+
setup(
|
18 |
+
name="diff_gaussian_rasterization",
|
19 |
+
packages=['diff_gaussian_rasterization'],
|
20 |
+
ext_modules=[
|
21 |
+
CUDAExtension(
|
22 |
+
name="diff_gaussian_rasterization._C",
|
23 |
+
sources=[
|
24 |
+
"cuda_rasterizer/rasterizer_impl.cu",
|
25 |
+
"cuda_rasterizer/forward.cu",
|
26 |
+
"cuda_rasterizer/backward.cu",
|
27 |
+
"rasterize_points.cu",
|
28 |
+
"ext.cpp"],
|
29 |
+
extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]})
|
30 |
+
],
|
31 |
+
cmdclass={
|
32 |
+
'build_ext': BuildExtension
|
33 |
+
}
|
34 |
+
)
|
diff-gaussian-rasterization/third_party/stbi_image_write.h
ADDED
@@ -0,0 +1,1724 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* stb_image_write - v1.16 - public domain - http://nothings.org/stb
|
2 |
+
writes out PNG/BMP/TGA/JPEG/HDR images to C stdio - Sean Barrett 2010-2015
|
3 |
+
no warranty implied; use at your own risk
|
4 |
+
|
5 |
+
Before #including,
|
6 |
+
|
7 |
+
#define STB_IMAGE_WRITE_IMPLEMENTATION
|
8 |
+
|
9 |
+
in the file that you want to have the implementation.
|
10 |
+
|
11 |
+
Will probably not work correctly with strict-aliasing optimizations.
|
12 |
+
|
13 |
+
ABOUT:
|
14 |
+
|
15 |
+
This header file is a library for writing images to C stdio or a callback.
|
16 |
+
|
17 |
+
The PNG output is not optimal; it is 20-50% larger than the file
|
18 |
+
written by a decent optimizing implementation; though providing a custom
|
19 |
+
zlib compress function (see STBIW_ZLIB_COMPRESS) can mitigate that.
|
20 |
+
This library is designed for source code compactness and simplicity,
|
21 |
+
not optimal image file size or run-time performance.
|
22 |
+
|
23 |
+
BUILDING:
|
24 |
+
|
25 |
+
You can #define STBIW_ASSERT(x) before the #include to avoid using assert.h.
|
26 |
+
You can #define STBIW_MALLOC(), STBIW_REALLOC(), and STBIW_FREE() to replace
|
27 |
+
malloc,realloc,free.
|
28 |
+
You can #define STBIW_MEMMOVE() to replace memmove()
|
29 |
+
You can #define STBIW_ZLIB_COMPRESS to use a custom zlib-style compress function
|
30 |
+
for PNG compression (instead of the builtin one), it must have the following signature:
|
31 |
+
unsigned char * my_compress(unsigned char *data, int data_len, int *out_len, int quality);
|
32 |
+
The returned data will be freed with STBIW_FREE() (free() by default),
|
33 |
+
so it must be heap allocated with STBIW_MALLOC() (malloc() by default),
|
34 |
+
|
35 |
+
UNICODE:
|
36 |
+
|
37 |
+
If compiling for Windows and you wish to use Unicode filenames, compile
|
38 |
+
with
|
39 |
+
#define STBIW_WINDOWS_UTF8
|
40 |
+
and pass utf8-encoded filenames. Call stbiw_convert_wchar_to_utf8 to convert
|
41 |
+
Windows wchar_t filenames to utf8.
|
42 |
+
|
43 |
+
USAGE:
|
44 |
+
|
45 |
+
There are five functions, one for each image file format:
|
46 |
+
|
47 |
+
int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);
|
48 |
+
int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
|
49 |
+
int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);
|
50 |
+
int stbi_write_jpg(char const *filename, int w, int h, int comp, const void *data, int quality);
|
51 |
+
int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);
|
52 |
+
|
53 |
+
void stbi_flip_vertically_on_write(int flag); // flag is non-zero to flip data vertically
|
54 |
+
|
55 |
+
There are also five equivalent functions that use an arbitrary write function. You are
|
56 |
+
expected to open/close your file-equivalent before and after calling these:
|
57 |
+
|
58 |
+
int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes);
|
59 |
+
int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
|
60 |
+
int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
|
61 |
+
int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data);
|
62 |
+
int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality);
|
63 |
+
|
64 |
+
where the callback is:
|
65 |
+
void stbi_write_func(void *context, void *data, int size);
|
66 |
+
|
67 |
+
You can configure it with these global variables:
|
68 |
+
int stbi_write_tga_with_rle; // defaults to true; set to 0 to disable RLE
|
69 |
+
int stbi_write_png_compression_level; // defaults to 8; set to higher for more compression
|
70 |
+
int stbi_write_force_png_filter; // defaults to -1; set to 0..5 to force a filter mode
|
71 |
+
|
72 |
+
|
73 |
+
You can define STBI_WRITE_NO_STDIO to disable the file variant of these
|
74 |
+
functions, so the library will not use stdio.h at all. However, this will
|
75 |
+
also disable HDR writing, because it requires stdio for formatted output.
|
76 |
+
|
77 |
+
Each function returns 0 on failure and non-0 on success.
|
78 |
+
|
79 |
+
The functions create an image file defined by the parameters. The image
|
80 |
+
is a rectangle of pixels stored from left-to-right, top-to-bottom.
|
81 |
+
Each pixel contains 'comp' channels of data stored interleaved with 8-bits
|
82 |
+
per channel, in the following order: 1=Y, 2=YA, 3=RGB, 4=RGBA. (Y is
|
83 |
+
monochrome color.) The rectangle is 'w' pixels wide and 'h' pixels tall.
|
84 |
+
The *data pointer points to the first byte of the top-left-most pixel.
|
85 |
+
For PNG, "stride_in_bytes" is the distance in bytes from the first byte of
|
86 |
+
a row of pixels to the first byte of the next row of pixels.
|
87 |
+
|
88 |
+
PNG creates output files with the same number of components as the input.
|
89 |
+
The BMP format expands Y to RGB in the file format and does not
|
90 |
+
output alpha.
|
91 |
+
|
92 |
+
PNG supports writing rectangles of data even when the bytes storing rows of
|
93 |
+
data are not consecutive in memory (e.g. sub-rectangles of a larger image),
|
94 |
+
by supplying the stride between the beginning of adjacent rows. The other
|
95 |
+
formats do not. (Thus you cannot write a native-format BMP through the BMP
|
96 |
+
writer, both because it is in BGR order and because it may have padding
|
97 |
+
at the end of the line.)
|
98 |
+
|
99 |
+
PNG allows you to set the deflate compression level by setting the global
|
100 |
+
variable 'stbi_write_png_compression_level' (it defaults to 8).
|
101 |
+
|
102 |
+
HDR expects linear float data. Since the format is always 32-bit rgb(e)
|
103 |
+
data, alpha (if provided) is discarded, and for monochrome data it is
|
104 |
+
replicated across all three channels.
|
105 |
+
|
106 |
+
TGA supports RLE or non-RLE compressed data. To use non-RLE-compressed
|
107 |
+
data, set the global variable 'stbi_write_tga_with_rle' to 0.
|
108 |
+
|
109 |
+
JPEG does ignore alpha channels in input data; quality is between 1 and 100.
|
110 |
+
Higher quality looks better but results in a bigger image.
|
111 |
+
JPEG baseline (no JPEG progressive).
|
112 |
+
|
113 |
+
CREDITS:
|
114 |
+
|
115 |
+
|
116 |
+
Sean Barrett - PNG/BMP/TGA
|
117 |
+
Baldur Karlsson - HDR
|
118 |
+
Jean-Sebastien Guay - TGA monochrome
|
119 |
+
Tim Kelsey - misc enhancements
|
120 |
+
Alan Hickman - TGA RLE
|
121 |
+
Emmanuel Julien - initial file IO callback implementation
|
122 |
+
Jon Olick - original jo_jpeg.cpp code
|
123 |
+
Daniel Gibson - integrate JPEG, allow external zlib
|
124 |
+
Aarni Koskela - allow choosing PNG filter
|
125 |
+
|
126 |
+
bugfixes:
|
127 |
+
github:Chribba
|
128 |
+
Guillaume Chereau
|
129 |
+
github:jry2
|
130 |
+
github:romigrou
|
131 |
+
Sergio Gonzalez
|
132 |
+
Jonas Karlsson
|
133 |
+
Filip Wasil
|
134 |
+
Thatcher Ulrich
|
135 |
+
github:poppolopoppo
|
136 |
+
Patrick Boettcher
|
137 |
+
github:xeekworx
|
138 |
+
Cap Petschulat
|
139 |
+
Simon Rodriguez
|
140 |
+
Ivan Tikhonov
|
141 |
+
github:ignotion
|
142 |
+
Adam Schackart
|
143 |
+
Andrew Kensler
|
144 |
+
|
145 |
+
LICENSE
|
146 |
+
|
147 |
+
See end of file for license information.
|
148 |
+
|
149 |
+
*/
|
150 |
+
|
151 |
+
#ifndef INCLUDE_STB_IMAGE_WRITE_H
|
152 |
+
#define INCLUDE_STB_IMAGE_WRITE_H
|
153 |
+
|
154 |
+
#include <stdlib.h>
|
155 |
+
|
156 |
+
// if STB_IMAGE_WRITE_STATIC causes problems, try defining STBIWDEF to 'inline' or 'static inline'
|
157 |
+
#ifndef STBIWDEF
|
158 |
+
#ifdef STB_IMAGE_WRITE_STATIC
|
159 |
+
#define STBIWDEF static
|
160 |
+
#else
|
161 |
+
#ifdef __cplusplus
|
162 |
+
#define STBIWDEF extern "C"
|
163 |
+
#else
|
164 |
+
#define STBIWDEF extern
|
165 |
+
#endif
|
166 |
+
#endif
|
167 |
+
#endif
|
168 |
+
|
169 |
+
#ifndef STB_IMAGE_WRITE_STATIC // C++ forbids static forward declarations
|
170 |
+
STBIWDEF int stbi_write_tga_with_rle;
|
171 |
+
STBIWDEF int stbi_write_png_compression_level;
|
172 |
+
STBIWDEF int stbi_write_force_png_filter;
|
173 |
+
#endif
|
174 |
+
|
175 |
+
#ifndef STBI_WRITE_NO_STDIO
|
176 |
+
STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);
|
177 |
+
STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
|
178 |
+
STBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);
|
179 |
+
STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);
|
180 |
+
STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality);
|
181 |
+
|
182 |
+
#ifdef STBIW_WINDOWS_UTF8
|
183 |
+
STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input);
|
184 |
+
#endif
|
185 |
+
#endif
|
186 |
+
|
187 |
+
typedef void stbi_write_func(void *context, void *data, int size);
|
188 |
+
|
189 |
+
STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes);
|
190 |
+
STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
|
191 |
+
STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
|
192 |
+
STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data);
|
193 |
+
STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality);
|
194 |
+
|
195 |
+
STBIWDEF void stbi_flip_vertically_on_write(int flip_boolean);
|
196 |
+
|
197 |
+
#endif//INCLUDE_STB_IMAGE_WRITE_H
|
198 |
+
|
199 |
+
#ifdef STB_IMAGE_WRITE_IMPLEMENTATION
|
200 |
+
|
201 |
+
#ifdef _WIN32
|
202 |
+
#ifndef _CRT_SECURE_NO_WARNINGS
|
203 |
+
#define _CRT_SECURE_NO_WARNINGS
|
204 |
+
#endif
|
205 |
+
#ifndef _CRT_NONSTDC_NO_DEPRECATE
|
206 |
+
#define _CRT_NONSTDC_NO_DEPRECATE
|
207 |
+
#endif
|
208 |
+
#endif
|
209 |
+
|
210 |
+
#ifndef STBI_WRITE_NO_STDIO
|
211 |
+
#include <stdio.h>
|
212 |
+
#endif // STBI_WRITE_NO_STDIO
|
213 |
+
|
214 |
+
#include <stdarg.h>
|
215 |
+
#include <stdlib.h>
|
216 |
+
#include <string.h>
|
217 |
+
#include <math.h>
|
218 |
+
|
219 |
+
#if defined(STBIW_MALLOC) && defined(STBIW_FREE) && (defined(STBIW_REALLOC) || defined(STBIW_REALLOC_SIZED))
|
220 |
+
// ok
|
221 |
+
#elif !defined(STBIW_MALLOC) && !defined(STBIW_FREE) && !defined(STBIW_REALLOC) && !defined(STBIW_REALLOC_SIZED)
|
222 |
+
// ok
|
223 |
+
#else
|
224 |
+
#error "Must define all or none of STBIW_MALLOC, STBIW_FREE, and STBIW_REALLOC (or STBIW_REALLOC_SIZED)."
|
225 |
+
#endif
|
226 |
+
|
227 |
+
#ifndef STBIW_MALLOC
|
228 |
+
#define STBIW_MALLOC(sz) malloc(sz)
|
229 |
+
#define STBIW_REALLOC(p,newsz) realloc(p,newsz)
|
230 |
+
#define STBIW_FREE(p) free(p)
|
231 |
+
#endif
|
232 |
+
|
233 |
+
#ifndef STBIW_REALLOC_SIZED
|
234 |
+
#define STBIW_REALLOC_SIZED(p,oldsz,newsz) STBIW_REALLOC(p,newsz)
|
235 |
+
#endif
|
236 |
+
|
237 |
+
|
238 |
+
#ifndef STBIW_MEMMOVE
|
239 |
+
#define STBIW_MEMMOVE(a,b,sz) memmove(a,b,sz)
|
240 |
+
#endif
|
241 |
+
|
242 |
+
|
243 |
+
#ifndef STBIW_ASSERT
|
244 |
+
#include <assert.h>
|
245 |
+
#define STBIW_ASSERT(x) assert(x)
|
246 |
+
#endif
|
247 |
+
|
248 |
+
#define STBIW_UCHAR(x) (unsigned char) ((x) & 0xff)
|
249 |
+
|
250 |
+
#ifdef STB_IMAGE_WRITE_STATIC
|
251 |
+
static int stbi_write_png_compression_level = 8;
|
252 |
+
static int stbi_write_tga_with_rle = 1;
|
253 |
+
static int stbi_write_force_png_filter = -1;
|
254 |
+
#else
|
255 |
+
int stbi_write_png_compression_level = 8;
|
256 |
+
int stbi_write_tga_with_rle = 1;
|
257 |
+
int stbi_write_force_png_filter = -1;
|
258 |
+
#endif
|
259 |
+
|
260 |
+
static int stbi__flip_vertically_on_write = 0;
|
261 |
+
|
262 |
+
STBIWDEF void stbi_flip_vertically_on_write(int flag)
|
263 |
+
{
|
264 |
+
stbi__flip_vertically_on_write = flag;
|
265 |
+
}
|
266 |
+
|
267 |
+
typedef struct
|
268 |
+
{
|
269 |
+
stbi_write_func *func;
|
270 |
+
void *context;
|
271 |
+
unsigned char buffer[64];
|
272 |
+
int buf_used;
|
273 |
+
} stbi__write_context;
|
274 |
+
|
275 |
+
// initialize a callback-based context
|
276 |
+
static void stbi__start_write_callbacks(stbi__write_context *s, stbi_write_func *c, void *context)
|
277 |
+
{
|
278 |
+
s->func = c;
|
279 |
+
s->context = context;
|
280 |
+
}
|
281 |
+
|
282 |
+
#ifndef STBI_WRITE_NO_STDIO
|
283 |
+
|
284 |
+
static void stbi__stdio_write(void *context, void *data, int size)
|
285 |
+
{
|
286 |
+
fwrite(data,1,size,(FILE*) context);
|
287 |
+
}
|
288 |
+
|
289 |
+
#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8)
|
290 |
+
#ifdef __cplusplus
|
291 |
+
#define STBIW_EXTERN extern "C"
|
292 |
+
#else
|
293 |
+
#define STBIW_EXTERN extern
|
294 |
+
#endif
|
295 |
+
STBIW_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide);
|
296 |
+
STBIW_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default);
|
297 |
+
|
298 |
+
STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input)
|
299 |
+
{
|
300 |
+
return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL);
|
301 |
+
}
|
302 |
+
#endif
|
303 |
+
|
304 |
+
static FILE *stbiw__fopen(char const *filename, char const *mode)
|
305 |
+
{
|
306 |
+
FILE *f;
|
307 |
+
#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8)
|
308 |
+
wchar_t wMode[64];
|
309 |
+
wchar_t wFilename[1024];
|
310 |
+
if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename)))
|
311 |
+
return 0;
|
312 |
+
|
313 |
+
if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode)))
|
314 |
+
return 0;
|
315 |
+
|
316 |
+
#if defined(_MSC_VER) && _MSC_VER >= 1400
|
317 |
+
if (0 != _wfopen_s(&f, wFilename, wMode))
|
318 |
+
f = 0;
|
319 |
+
#else
|
320 |
+
f = _wfopen(wFilename, wMode);
|
321 |
+
#endif
|
322 |
+
|
323 |
+
#elif defined(_MSC_VER) && _MSC_VER >= 1400
|
324 |
+
if (0 != fopen_s(&f, filename, mode))
|
325 |
+
f=0;
|
326 |
+
#else
|
327 |
+
f = fopen(filename, mode);
|
328 |
+
#endif
|
329 |
+
return f;
|
330 |
+
}
|
331 |
+
|
332 |
+
static int stbi__start_write_file(stbi__write_context *s, const char *filename)
|
333 |
+
{
|
334 |
+
FILE *f = stbiw__fopen(filename, "wb");
|
335 |
+
stbi__start_write_callbacks(s, stbi__stdio_write, (void *) f);
|
336 |
+
return f != NULL;
|
337 |
+
}
|
338 |
+
|
339 |
+
static void stbi__end_write_file(stbi__write_context *s)
|
340 |
+
{
|
341 |
+
fclose((FILE *)s->context);
|
342 |
+
}
|
343 |
+
|
344 |
+
#endif // !STBI_WRITE_NO_STDIO
|
345 |
+
|
346 |
+
typedef unsigned int stbiw_uint32;
|
347 |
+
typedef int stb_image_write_test[sizeof(stbiw_uint32)==4 ? 1 : -1];
|
348 |
+
|
349 |
+
static void stbiw__writefv(stbi__write_context *s, const char *fmt, va_list v)
|
350 |
+
{
|
351 |
+
while (*fmt) {
|
352 |
+
switch (*fmt++) {
|
353 |
+
case ' ': break;
|
354 |
+
case '1': { unsigned char x = STBIW_UCHAR(va_arg(v, int));
|
355 |
+
s->func(s->context,&x,1);
|
356 |
+
break; }
|
357 |
+
case '2': { int x = va_arg(v,int);
|
358 |
+
unsigned char b[2];
|
359 |
+
b[0] = STBIW_UCHAR(x);
|
360 |
+
b[1] = STBIW_UCHAR(x>>8);
|
361 |
+
s->func(s->context,b,2);
|
362 |
+
break; }
|
363 |
+
case '4': { stbiw_uint32 x = va_arg(v,int);
|
364 |
+
unsigned char b[4];
|
365 |
+
b[0]=STBIW_UCHAR(x);
|
366 |
+
b[1]=STBIW_UCHAR(x>>8);
|
367 |
+
b[2]=STBIW_UCHAR(x>>16);
|
368 |
+
b[3]=STBIW_UCHAR(x>>24);
|
369 |
+
s->func(s->context,b,4);
|
370 |
+
break; }
|
371 |
+
default:
|
372 |
+
STBIW_ASSERT(0);
|
373 |
+
return;
|
374 |
+
}
|
375 |
+
}
|
376 |
+
}
|
377 |
+
|
378 |
+
static void stbiw__writef(stbi__write_context *s, const char *fmt, ...)
|
379 |
+
{
|
380 |
+
va_list v;
|
381 |
+
va_start(v, fmt);
|
382 |
+
stbiw__writefv(s, fmt, v);
|
383 |
+
va_end(v);
|
384 |
+
}
|
385 |
+
|
386 |
+
static void stbiw__write_flush(stbi__write_context *s)
|
387 |
+
{
|
388 |
+
if (s->buf_used) {
|
389 |
+
s->func(s->context, &s->buffer, s->buf_used);
|
390 |
+
s->buf_used = 0;
|
391 |
+
}
|
392 |
+
}
|
393 |
+
|
394 |
+
static void stbiw__putc(stbi__write_context *s, unsigned char c)
|
395 |
+
{
|
396 |
+
s->func(s->context, &c, 1);
|
397 |
+
}
|
398 |
+
|
399 |
+
static void stbiw__write1(stbi__write_context *s, unsigned char a)
|
400 |
+
{
|
401 |
+
if ((size_t)s->buf_used + 1 > sizeof(s->buffer))
|
402 |
+
stbiw__write_flush(s);
|
403 |
+
s->buffer[s->buf_used++] = a;
|
404 |
+
}
|
405 |
+
|
406 |
+
static void stbiw__write3(stbi__write_context *s, unsigned char a, unsigned char b, unsigned char c)
|
407 |
+
{
|
408 |
+
int n;
|
409 |
+
if ((size_t)s->buf_used + 3 > sizeof(s->buffer))
|
410 |
+
stbiw__write_flush(s);
|
411 |
+
n = s->buf_used;
|
412 |
+
s->buf_used = n+3;
|
413 |
+
s->buffer[n+0] = a;
|
414 |
+
s->buffer[n+1] = b;
|
415 |
+
s->buffer[n+2] = c;
|
416 |
+
}
|
417 |
+
|
418 |
+
static void stbiw__write_pixel(stbi__write_context *s, int rgb_dir, int comp, int write_alpha, int expand_mono, unsigned char *d)
|
419 |
+
{
|
420 |
+
unsigned char bg[3] = { 255, 0, 255}, px[3];
|
421 |
+
int k;
|
422 |
+
|
423 |
+
if (write_alpha < 0)
|
424 |
+
stbiw__write1(s, d[comp - 1]);
|
425 |
+
|
426 |
+
switch (comp) {
|
427 |
+
case 2: // 2 pixels = mono + alpha, alpha is written separately, so same as 1-channel case
|
428 |
+
case 1:
|
429 |
+
if (expand_mono)
|
430 |
+
stbiw__write3(s, d[0], d[0], d[0]); // monochrome bmp
|
431 |
+
else
|
432 |
+
stbiw__write1(s, d[0]); // monochrome TGA
|
433 |
+
break;
|
434 |
+
case 4:
|
435 |
+
if (!write_alpha) {
|
436 |
+
// composite against pink background
|
437 |
+
for (k = 0; k < 3; ++k)
|
438 |
+
px[k] = bg[k] + ((d[k] - bg[k]) * d[3]) / 255;
|
439 |
+
stbiw__write3(s, px[1 - rgb_dir], px[1], px[1 + rgb_dir]);
|
440 |
+
break;
|
441 |
+
}
|
442 |
+
/* FALLTHROUGH */
|
443 |
+
case 3:
|
444 |
+
stbiw__write3(s, d[1 - rgb_dir], d[1], d[1 + rgb_dir]);
|
445 |
+
break;
|
446 |
+
}
|
447 |
+
if (write_alpha > 0)
|
448 |
+
stbiw__write1(s, d[comp - 1]);
|
449 |
+
}
|
450 |
+
|
451 |
+
static void stbiw__write_pixels(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, void *data, int write_alpha, int scanline_pad, int expand_mono)
|
452 |
+
{
|
453 |
+
stbiw_uint32 zero = 0;
|
454 |
+
int i,j, j_end;
|
455 |
+
|
456 |
+
if (y <= 0)
|
457 |
+
return;
|
458 |
+
|
459 |
+
if (stbi__flip_vertically_on_write)
|
460 |
+
vdir *= -1;
|
461 |
+
|
462 |
+
if (vdir < 0) {
|
463 |
+
j_end = -1; j = y-1;
|
464 |
+
} else {
|
465 |
+
j_end = y; j = 0;
|
466 |
+
}
|
467 |
+
|
468 |
+
for (; j != j_end; j += vdir) {
|
469 |
+
for (i=0; i < x; ++i) {
|
470 |
+
unsigned char *d = (unsigned char *) data + (j*x+i)*comp;
|
471 |
+
stbiw__write_pixel(s, rgb_dir, comp, write_alpha, expand_mono, d);
|
472 |
+
}
|
473 |
+
stbiw__write_flush(s);
|
474 |
+
s->func(s->context, &zero, scanline_pad);
|
475 |
+
}
|
476 |
+
}
|
477 |
+
|
478 |
+
static int stbiw__outfile(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, int expand_mono, void *data, int alpha, int pad, const char *fmt, ...)
|
479 |
+
{
|
480 |
+
if (y < 0 || x < 0) {
|
481 |
+
return 0;
|
482 |
+
} else {
|
483 |
+
va_list v;
|
484 |
+
va_start(v, fmt);
|
485 |
+
stbiw__writefv(s, fmt, v);
|
486 |
+
va_end(v);
|
487 |
+
stbiw__write_pixels(s,rgb_dir,vdir,x,y,comp,data,alpha,pad, expand_mono);
|
488 |
+
return 1;
|
489 |
+
}
|
490 |
+
}
|
491 |
+
|
492 |
+
static int stbi_write_bmp_core(stbi__write_context *s, int x, int y, int comp, const void *data)
|
493 |
+
{
|
494 |
+
if (comp != 4) {
|
495 |
+
// write RGB bitmap
|
496 |
+
int pad = (-x*3) & 3;
|
497 |
+
return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *) data,0,pad,
|
498 |
+
"11 4 22 4" "4 44 22 444444",
|
499 |
+
'B', 'M', 14+40+(x*3+pad)*y, 0,0, 14+40, // file header
|
500 |
+
40, x,y, 1,24, 0,0,0,0,0,0); // bitmap header
|
501 |
+
} else {
|
502 |
+
// RGBA bitmaps need a v4 header
|
503 |
+
// use BI_BITFIELDS mode with 32bpp and alpha mask
|
504 |
+
// (straight BI_RGB with alpha mask doesn't work in most readers)
|
505 |
+
return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *)data,1,0,
|
506 |
+
"11 4 22 4" "4 44 22 444444 4444 4 444 444 444 444",
|
507 |
+
'B', 'M', 14+108+x*y*4, 0, 0, 14+108, // file header
|
508 |
+
108, x,y, 1,32, 3,0,0,0,0,0, 0xff0000,0xff00,0xff,0xff000000u, 0, 0,0,0, 0,0,0, 0,0,0, 0,0,0); // bitmap V4 header
|
509 |
+
}
|
510 |
+
}
|
511 |
+
|
512 |
+
STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data)
|
513 |
+
{
|
514 |
+
stbi__write_context s = { 0 };
|
515 |
+
stbi__start_write_callbacks(&s, func, context);
|
516 |
+
return stbi_write_bmp_core(&s, x, y, comp, data);
|
517 |
+
}
|
518 |
+
|
519 |
+
#ifndef STBI_WRITE_NO_STDIO
|
520 |
+
STBIWDEF int stbi_write_bmp(char const *filename, int x, int y, int comp, const void *data)
|
521 |
+
{
|
522 |
+
stbi__write_context s = { 0 };
|
523 |
+
if (stbi__start_write_file(&s,filename)) {
|
524 |
+
int r = stbi_write_bmp_core(&s, x, y, comp, data);
|
525 |
+
stbi__end_write_file(&s);
|
526 |
+
return r;
|
527 |
+
} else
|
528 |
+
return 0;
|
529 |
+
}
|
530 |
+
#endif //!STBI_WRITE_NO_STDIO
|
531 |
+
|
532 |
+
static int stbi_write_tga_core(stbi__write_context *s, int x, int y, int comp, void *data)
|
533 |
+
{
|
534 |
+
int has_alpha = (comp == 2 || comp == 4);
|
535 |
+
int colorbytes = has_alpha ? comp-1 : comp;
|
536 |
+
int format = colorbytes < 2 ? 3 : 2; // 3 color channels (RGB/RGBA) = 2, 1 color channel (Y/YA) = 3
|
537 |
+
|
538 |
+
if (y < 0 || x < 0)
|
539 |
+
return 0;
|
540 |
+
|
541 |
+
if (!stbi_write_tga_with_rle) {
|
542 |
+
return stbiw__outfile(s, -1, -1, x, y, comp, 0, (void *) data, has_alpha, 0,
|
543 |
+
"111 221 2222 11", 0, 0, format, 0, 0, 0, 0, 0, x, y, (colorbytes + has_alpha) * 8, has_alpha * 8);
|
544 |
+
} else {
|
545 |
+
int i,j,k;
|
546 |
+
int jend, jdir;
|
547 |
+
|
548 |
+
stbiw__writef(s, "111 221 2222 11", 0,0,format+8, 0,0,0, 0,0,x,y, (colorbytes + has_alpha) * 8, has_alpha * 8);
|
549 |
+
|
550 |
+
if (stbi__flip_vertically_on_write) {
|
551 |
+
j = 0;
|
552 |
+
jend = y;
|
553 |
+
jdir = 1;
|
554 |
+
} else {
|
555 |
+
j = y-1;
|
556 |
+
jend = -1;
|
557 |
+
jdir = -1;
|
558 |
+
}
|
559 |
+
for (; j != jend; j += jdir) {
|
560 |
+
unsigned char *row = (unsigned char *) data + j * x * comp;
|
561 |
+
int len;
|
562 |
+
|
563 |
+
for (i = 0; i < x; i += len) {
|
564 |
+
unsigned char *begin = row + i * comp;
|
565 |
+
int diff = 1;
|
566 |
+
len = 1;
|
567 |
+
|
568 |
+
if (i < x - 1) {
|
569 |
+
++len;
|
570 |
+
diff = memcmp(begin, row + (i + 1) * comp, comp);
|
571 |
+
if (diff) {
|
572 |
+
const unsigned char *prev = begin;
|
573 |
+
for (k = i + 2; k < x && len < 128; ++k) {
|
574 |
+
if (memcmp(prev, row + k * comp, comp)) {
|
575 |
+
prev += comp;
|
576 |
+
++len;
|
577 |
+
} else {
|
578 |
+
--len;
|
579 |
+
break;
|
580 |
+
}
|
581 |
+
}
|
582 |
+
} else {
|
583 |
+
for (k = i + 2; k < x && len < 128; ++k) {
|
584 |
+
if (!memcmp(begin, row + k * comp, comp)) {
|
585 |
+
++len;
|
586 |
+
} else {
|
587 |
+
break;
|
588 |
+
}
|
589 |
+
}
|
590 |
+
}
|
591 |
+
}
|
592 |
+
|
593 |
+
if (diff) {
|
594 |
+
unsigned char header = STBIW_UCHAR(len - 1);
|
595 |
+
stbiw__write1(s, header);
|
596 |
+
for (k = 0; k < len; ++k) {
|
597 |
+
stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin + k * comp);
|
598 |
+
}
|
599 |
+
} else {
|
600 |
+
unsigned char header = STBIW_UCHAR(len - 129);
|
601 |
+
stbiw__write1(s, header);
|
602 |
+
stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin);
|
603 |
+
}
|
604 |
+
}
|
605 |
+
}
|
606 |
+
stbiw__write_flush(s);
|
607 |
+
}
|
608 |
+
return 1;
|
609 |
+
}
|
610 |
+
|
611 |
+
STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data)
|
612 |
+
{
|
613 |
+
stbi__write_context s = { 0 };
|
614 |
+
stbi__start_write_callbacks(&s, func, context);
|
615 |
+
return stbi_write_tga_core(&s, x, y, comp, (void *) data);
|
616 |
+
}
|
617 |
+
|
618 |
+
#ifndef STBI_WRITE_NO_STDIO
|
619 |
+
STBIWDEF int stbi_write_tga(char const *filename, int x, int y, int comp, const void *data)
|
620 |
+
{
|
621 |
+
stbi__write_context s = { 0 };
|
622 |
+
if (stbi__start_write_file(&s,filename)) {
|
623 |
+
int r = stbi_write_tga_core(&s, x, y, comp, (void *) data);
|
624 |
+
stbi__end_write_file(&s);
|
625 |
+
return r;
|
626 |
+
} else
|
627 |
+
return 0;
|
628 |
+
}
|
629 |
+
#endif
|
630 |
+
|
631 |
+
// *************************************************************************************************
|
632 |
+
// Radiance RGBE HDR writer
|
633 |
+
// by Baldur Karlsson
|
634 |
+
|
635 |
+
#define stbiw__max(a, b) ((a) > (b) ? (a) : (b))
|
636 |
+
|
637 |
+
#ifndef STBI_WRITE_NO_STDIO
|
638 |
+
|
639 |
+
static void stbiw__linear_to_rgbe(unsigned char *rgbe, float *linear)
|
640 |
+
{
|
641 |
+
int exponent;
|
642 |
+
float maxcomp = stbiw__max(linear[0], stbiw__max(linear[1], linear[2]));
|
643 |
+
|
644 |
+
if (maxcomp < 1e-32f) {
|
645 |
+
rgbe[0] = rgbe[1] = rgbe[2] = rgbe[3] = 0;
|
646 |
+
} else {
|
647 |
+
float normalize = (float) frexp(maxcomp, &exponent) * 256.0f/maxcomp;
|
648 |
+
|
649 |
+
rgbe[0] = (unsigned char)(linear[0] * normalize);
|
650 |
+
rgbe[1] = (unsigned char)(linear[1] * normalize);
|
651 |
+
rgbe[2] = (unsigned char)(linear[2] * normalize);
|
652 |
+
rgbe[3] = (unsigned char)(exponent + 128);
|
653 |
+
}
|
654 |
+
}
|
655 |
+
|
656 |
+
static void stbiw__write_run_data(stbi__write_context *s, int length, unsigned char databyte)
|
657 |
+
{
|
658 |
+
unsigned char lengthbyte = STBIW_UCHAR(length+128);
|
659 |
+
STBIW_ASSERT(length+128 <= 255);
|
660 |
+
s->func(s->context, &lengthbyte, 1);
|
661 |
+
s->func(s->context, &databyte, 1);
|
662 |
+
}
|
663 |
+
|
664 |
+
static void stbiw__write_dump_data(stbi__write_context *s, int length, unsigned char *data)
|
665 |
+
{
|
666 |
+
unsigned char lengthbyte = STBIW_UCHAR(length);
|
667 |
+
STBIW_ASSERT(length <= 128); // inconsistent with spec but consistent with official code
|
668 |
+
s->func(s->context, &lengthbyte, 1);
|
669 |
+
s->func(s->context, data, length);
|
670 |
+
}
|
671 |
+
|
672 |
+
static void stbiw__write_hdr_scanline(stbi__write_context *s, int width, int ncomp, unsigned char *scratch, float *scanline)
|
673 |
+
{
|
674 |
+
unsigned char scanlineheader[4] = { 2, 2, 0, 0 };
|
675 |
+
unsigned char rgbe[4];
|
676 |
+
float linear[3];
|
677 |
+
int x;
|
678 |
+
|
679 |
+
scanlineheader[2] = (width&0xff00)>>8;
|
680 |
+
scanlineheader[3] = (width&0x00ff);
|
681 |
+
|
682 |
+
/* skip RLE for images too small or large */
|
683 |
+
if (width < 8 || width >= 32768) {
|
684 |
+
for (x=0; x < width; x++) {
|
685 |
+
switch (ncomp) {
|
686 |
+
case 4: /* fallthrough */
|
687 |
+
case 3: linear[2] = scanline[x*ncomp + 2];
|
688 |
+
linear[1] = scanline[x*ncomp + 1];
|
689 |
+
linear[0] = scanline[x*ncomp + 0];
|
690 |
+
break;
|
691 |
+
default:
|
692 |
+
linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0];
|
693 |
+
break;
|
694 |
+
}
|
695 |
+
stbiw__linear_to_rgbe(rgbe, linear);
|
696 |
+
s->func(s->context, rgbe, 4);
|
697 |
+
}
|
698 |
+
} else {
|
699 |
+
int c,r;
|
700 |
+
/* encode into scratch buffer */
|
701 |
+
for (x=0; x < width; x++) {
|
702 |
+
switch(ncomp) {
|
703 |
+
case 4: /* fallthrough */
|
704 |
+
case 3: linear[2] = scanline[x*ncomp + 2];
|
705 |
+
linear[1] = scanline[x*ncomp + 1];
|
706 |
+
linear[0] = scanline[x*ncomp + 0];
|
707 |
+
break;
|
708 |
+
default:
|
709 |
+
linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0];
|
710 |
+
break;
|
711 |
+
}
|
712 |
+
stbiw__linear_to_rgbe(rgbe, linear);
|
713 |
+
scratch[x + width*0] = rgbe[0];
|
714 |
+
scratch[x + width*1] = rgbe[1];
|
715 |
+
scratch[x + width*2] = rgbe[2];
|
716 |
+
scratch[x + width*3] = rgbe[3];
|
717 |
+
}
|
718 |
+
|
719 |
+
s->func(s->context, scanlineheader, 4);
|
720 |
+
|
721 |
+
/* RLE each component separately */
|
722 |
+
for (c=0; c < 4; c++) {
|
723 |
+
unsigned char *comp = &scratch[width*c];
|
724 |
+
|
725 |
+
x = 0;
|
726 |
+
while (x < width) {
|
727 |
+
// find first run
|
728 |
+
r = x;
|
729 |
+
while (r+2 < width) {
|
730 |
+
if (comp[r] == comp[r+1] && comp[r] == comp[r+2])
|
731 |
+
break;
|
732 |
+
++r;
|
733 |
+
}
|
734 |
+
if (r+2 >= width)
|
735 |
+
r = width;
|
736 |
+
// dump up to first run
|
737 |
+
while (x < r) {
|
738 |
+
int len = r-x;
|
739 |
+
if (len > 128) len = 128;
|
740 |
+
stbiw__write_dump_data(s, len, &comp[x]);
|
741 |
+
x += len;
|
742 |
+
}
|
743 |
+
// if there's a run, output it
|
744 |
+
if (r+2 < width) { // same test as what we break out of in search loop, so only true if we break'd
|
745 |
+
// find next byte after run
|
746 |
+
while (r < width && comp[r] == comp[x])
|
747 |
+
++r;
|
748 |
+
// output run up to r
|
749 |
+
while (x < r) {
|
750 |
+
int len = r-x;
|
751 |
+
if (len > 127) len = 127;
|
752 |
+
stbiw__write_run_data(s, len, comp[x]);
|
753 |
+
x += len;
|
754 |
+
}
|
755 |
+
}
|
756 |
+
}
|
757 |
+
}
|
758 |
+
}
|
759 |
+
}
|
760 |
+
|
761 |
+
static int stbi_write_hdr_core(stbi__write_context *s, int x, int y, int comp, float *data)
|
762 |
+
{
|
763 |
+
if (y <= 0 || x <= 0 || data == NULL)
|
764 |
+
return 0;
|
765 |
+
else {
|
766 |
+
// Each component is stored separately. Allocate scratch space for full output scanline.
|
767 |
+
unsigned char *scratch = (unsigned char *) STBIW_MALLOC(x*4);
|
768 |
+
int i, len;
|
769 |
+
char buffer[128];
|
770 |
+
char header[] = "#?RADIANCE\n# Written by stb_image_write.h\nFORMAT=32-bit_rle_rgbe\n";
|
771 |
+
s->func(s->context, header, sizeof(header)-1);
|
772 |
+
|
773 |
+
#ifdef __STDC_LIB_EXT1__
|
774 |
+
len = sprintf_s(buffer, sizeof(buffer), "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x);
|
775 |
+
#else
|
776 |
+
len = sprintf(buffer, "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x);
|
777 |
+
#endif
|
778 |
+
s->func(s->context, buffer, len);
|
779 |
+
|
780 |
+
for(i=0; i < y; i++)
|
781 |
+
stbiw__write_hdr_scanline(s, x, comp, scratch, data + comp*x*(stbi__flip_vertically_on_write ? y-1-i : i));
|
782 |
+
STBIW_FREE(scratch);
|
783 |
+
return 1;
|
784 |
+
}
|
785 |
+
}
|
786 |
+
|
787 |
+
STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const float *data)
|
788 |
+
{
|
789 |
+
stbi__write_context s = { 0 };
|
790 |
+
stbi__start_write_callbacks(&s, func, context);
|
791 |
+
return stbi_write_hdr_core(&s, x, y, comp, (float *) data);
|
792 |
+
}
|
793 |
+
|
794 |
+
STBIWDEF int stbi_write_hdr(char const *filename, int x, int y, int comp, const float *data)
|
795 |
+
{
|
796 |
+
stbi__write_context s = { 0 };
|
797 |
+
if (stbi__start_write_file(&s,filename)) {
|
798 |
+
int r = stbi_write_hdr_core(&s, x, y, comp, (float *) data);
|
799 |
+
stbi__end_write_file(&s);
|
800 |
+
return r;
|
801 |
+
} else
|
802 |
+
return 0;
|
803 |
+
}
|
804 |
+
#endif // STBI_WRITE_NO_STDIO
|
805 |
+
|
806 |
+
|
807 |
+
//////////////////////////////////////////////////////////////////////////////
|
808 |
+
//
|
809 |
+
// PNG writer
|
810 |
+
//
|
811 |
+
|
812 |
+
#ifndef STBIW_ZLIB_COMPRESS
|
813 |
+
// stretchy buffer; stbiw__sbpush() == vector<>::push_back() -- stbiw__sbcount() == vector<>::size()
|
814 |
+
#define stbiw__sbraw(a) ((int *) (void *) (a) - 2)
|
815 |
+
#define stbiw__sbm(a) stbiw__sbraw(a)[0]
|
816 |
+
#define stbiw__sbn(a) stbiw__sbraw(a)[1]
|
817 |
+
|
818 |
+
#define stbiw__sbneedgrow(a,n) ((a)==0 || stbiw__sbn(a)+n >= stbiw__sbm(a))
|
819 |
+
#define stbiw__sbmaybegrow(a,n) (stbiw__sbneedgrow(a,(n)) ? stbiw__sbgrow(a,n) : 0)
|
820 |
+
#define stbiw__sbgrow(a,n) stbiw__sbgrowf((void **) &(a), (n), sizeof(*(a)))
|
821 |
+
|
822 |
+
#define stbiw__sbpush(a, v) (stbiw__sbmaybegrow(a,1), (a)[stbiw__sbn(a)++] = (v))
|
823 |
+
#define stbiw__sbcount(a) ((a) ? stbiw__sbn(a) : 0)
|
824 |
+
#define stbiw__sbfree(a) ((a) ? STBIW_FREE(stbiw__sbraw(a)),0 : 0)
|
825 |
+
|
826 |
+
static void *stbiw__sbgrowf(void **arr, int increment, int itemsize)
|
827 |
+
{
|
828 |
+
int m = *arr ? 2*stbiw__sbm(*arr)+increment : increment+1;
|
829 |
+
void *p = STBIW_REALLOC_SIZED(*arr ? stbiw__sbraw(*arr) : 0, *arr ? (stbiw__sbm(*arr)*itemsize + sizeof(int)*2) : 0, itemsize * m + sizeof(int)*2);
|
830 |
+
STBIW_ASSERT(p);
|
831 |
+
if (p) {
|
832 |
+
if (!*arr) ((int *) p)[1] = 0;
|
833 |
+
*arr = (void *) ((int *) p + 2);
|
834 |
+
stbiw__sbm(*arr) = m;
|
835 |
+
}
|
836 |
+
return *arr;
|
837 |
+
}
|
838 |
+
|
839 |
+
static unsigned char *stbiw__zlib_flushf(unsigned char *data, unsigned int *bitbuffer, int *bitcount)
|
840 |
+
{
|
841 |
+
while (*bitcount >= 8) {
|
842 |
+
stbiw__sbpush(data, STBIW_UCHAR(*bitbuffer));
|
843 |
+
*bitbuffer >>= 8;
|
844 |
+
*bitcount -= 8;
|
845 |
+
}
|
846 |
+
return data;
|
847 |
+
}
|
848 |
+
|
849 |
+
static int stbiw__zlib_bitrev(int code, int codebits)
|
850 |
+
{
|
851 |
+
int res=0;
|
852 |
+
while (codebits--) {
|
853 |
+
res = (res << 1) | (code & 1);
|
854 |
+
code >>= 1;
|
855 |
+
}
|
856 |
+
return res;
|
857 |
+
}
|
858 |
+
|
859 |
+
static unsigned int stbiw__zlib_countm(unsigned char *a, unsigned char *b, int limit)
|
860 |
+
{
|
861 |
+
int i;
|
862 |
+
for (i=0; i < limit && i < 258; ++i)
|
863 |
+
if (a[i] != b[i]) break;
|
864 |
+
return i;
|
865 |
+
}
|
866 |
+
|
867 |
+
static unsigned int stbiw__zhash(unsigned char *data)
|
868 |
+
{
|
869 |
+
stbiw_uint32 hash = data[0] + (data[1] << 8) + (data[2] << 16);
|
870 |
+
hash ^= hash << 3;
|
871 |
+
hash += hash >> 5;
|
872 |
+
hash ^= hash << 4;
|
873 |
+
hash += hash >> 17;
|
874 |
+
hash ^= hash << 25;
|
875 |
+
hash += hash >> 6;
|
876 |
+
return hash;
|
877 |
+
}
|
878 |
+
|
879 |
+
#define stbiw__zlib_flush() (out = stbiw__zlib_flushf(out, &bitbuf, &bitcount))
|
880 |
+
#define stbiw__zlib_add(code,codebits) \
|
881 |
+
(bitbuf |= (code) << bitcount, bitcount += (codebits), stbiw__zlib_flush())
|
882 |
+
#define stbiw__zlib_huffa(b,c) stbiw__zlib_add(stbiw__zlib_bitrev(b,c),c)
|
883 |
+
// default huffman tables
|
884 |
+
#define stbiw__zlib_huff1(n) stbiw__zlib_huffa(0x30 + (n), 8)
|
885 |
+
#define stbiw__zlib_huff2(n) stbiw__zlib_huffa(0x190 + (n)-144, 9)
|
886 |
+
#define stbiw__zlib_huff3(n) stbiw__zlib_huffa(0 + (n)-256,7)
|
887 |
+
#define stbiw__zlib_huff4(n) stbiw__zlib_huffa(0xc0 + (n)-280,8)
|
888 |
+
#define stbiw__zlib_huff(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : (n) <= 255 ? stbiw__zlib_huff2(n) : (n) <= 279 ? stbiw__zlib_huff3(n) : stbiw__zlib_huff4(n))
|
889 |
+
#define stbiw__zlib_huffb(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : stbiw__zlib_huff2(n))
|
890 |
+
|
891 |
+
#define stbiw__ZHASH 16384
|
892 |
+
|
893 |
+
#endif // STBIW_ZLIB_COMPRESS
|
894 |
+
|
895 |
+
STBIWDEF unsigned char * stbi_zlib_compress(unsigned char *data, int data_len, int *out_len, int quality)
|
896 |
+
{
|
897 |
+
#ifdef STBIW_ZLIB_COMPRESS
|
898 |
+
// user provided a zlib compress implementation, use that
|
899 |
+
return STBIW_ZLIB_COMPRESS(data, data_len, out_len, quality);
|
900 |
+
#else // use builtin
|
901 |
+
static unsigned short lengthc[] = { 3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258, 259 };
|
902 |
+
static unsigned char lengtheb[]= { 0,0,0,0,0,0,0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 };
|
903 |
+
static unsigned short distc[] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577, 32768 };
|
904 |
+
static unsigned char disteb[] = { 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13 };
|
905 |
+
unsigned int bitbuf=0;
|
906 |
+
int i,j, bitcount=0;
|
907 |
+
unsigned char *out = NULL;
|
908 |
+
unsigned char ***hash_table = (unsigned char***) STBIW_MALLOC(stbiw__ZHASH * sizeof(unsigned char**));
|
909 |
+
if (hash_table == NULL)
|
910 |
+
return NULL;
|
911 |
+
if (quality < 5) quality = 5;
|
912 |
+
|
913 |
+
stbiw__sbpush(out, 0x78); // DEFLATE 32K window
|
914 |
+
stbiw__sbpush(out, 0x5e); // FLEVEL = 1
|
915 |
+
stbiw__zlib_add(1,1); // BFINAL = 1
|
916 |
+
stbiw__zlib_add(1,2); // BTYPE = 1 -- fixed huffman
|
917 |
+
|
918 |
+
for (i=0; i < stbiw__ZHASH; ++i)
|
919 |
+
hash_table[i] = NULL;
|
920 |
+
|
921 |
+
i=0;
|
922 |
+
while (i < data_len-3) {
|
923 |
+
// hash next 3 bytes of data to be compressed
|
924 |
+
int h = stbiw__zhash(data+i)&(stbiw__ZHASH-1), best=3;
|
925 |
+
unsigned char *bestloc = 0;
|
926 |
+
unsigned char **hlist = hash_table[h];
|
927 |
+
int n = stbiw__sbcount(hlist);
|
928 |
+
for (j=0; j < n; ++j) {
|
929 |
+
if (hlist[j]-data > i-32768) { // if entry lies within window
|
930 |
+
int d = stbiw__zlib_countm(hlist[j], data+i, data_len-i);
|
931 |
+
if (d >= best) { best=d; bestloc=hlist[j]; }
|
932 |
+
}
|
933 |
+
}
|
934 |
+
// when hash table entry is too long, delete half the entries
|
935 |
+
if (hash_table[h] && stbiw__sbn(hash_table[h]) == 2*quality) {
|
936 |
+
STBIW_MEMMOVE(hash_table[h], hash_table[h]+quality, sizeof(hash_table[h][0])*quality);
|
937 |
+
stbiw__sbn(hash_table[h]) = quality;
|
938 |
+
}
|
939 |
+
stbiw__sbpush(hash_table[h],data+i);
|
940 |
+
|
941 |
+
if (bestloc) {
|
942 |
+
// "lazy matching" - check match at *next* byte, and if it's better, do cur byte as literal
|
943 |
+
h = stbiw__zhash(data+i+1)&(stbiw__ZHASH-1);
|
944 |
+
hlist = hash_table[h];
|
945 |
+
n = stbiw__sbcount(hlist);
|
946 |
+
for (j=0; j < n; ++j) {
|
947 |
+
if (hlist[j]-data > i-32767) {
|
948 |
+
int e = stbiw__zlib_countm(hlist[j], data+i+1, data_len-i-1);
|
949 |
+
if (e > best) { // if next match is better, bail on current match
|
950 |
+
bestloc = NULL;
|
951 |
+
break;
|
952 |
+
}
|
953 |
+
}
|
954 |
+
}
|
955 |
+
}
|
956 |
+
|
957 |
+
if (bestloc) {
|
958 |
+
int d = (int) (data+i - bestloc); // distance back
|
959 |
+
STBIW_ASSERT(d <= 32767 && best <= 258);
|
960 |
+
for (j=0; best > lengthc[j+1]-1; ++j);
|
961 |
+
stbiw__zlib_huff(j+257);
|
962 |
+
if (lengtheb[j]) stbiw__zlib_add(best - lengthc[j], lengtheb[j]);
|
963 |
+
for (j=0; d > distc[j+1]-1; ++j);
|
964 |
+
stbiw__zlib_add(stbiw__zlib_bitrev(j,5),5);
|
965 |
+
if (disteb[j]) stbiw__zlib_add(d - distc[j], disteb[j]);
|
966 |
+
i += best;
|
967 |
+
} else {
|
968 |
+
stbiw__zlib_huffb(data[i]);
|
969 |
+
++i;
|
970 |
+
}
|
971 |
+
}
|
972 |
+
// write out final bytes
|
973 |
+
for (;i < data_len; ++i)
|
974 |
+
stbiw__zlib_huffb(data[i]);
|
975 |
+
stbiw__zlib_huff(256); // end of block
|
976 |
+
// pad with 0 bits to byte boundary
|
977 |
+
while (bitcount)
|
978 |
+
stbiw__zlib_add(0,1);
|
979 |
+
|
980 |
+
for (i=0; i < stbiw__ZHASH; ++i)
|
981 |
+
(void) stbiw__sbfree(hash_table[i]);
|
982 |
+
STBIW_FREE(hash_table);
|
983 |
+
|
984 |
+
// store uncompressed instead if compression was worse
|
985 |
+
if (stbiw__sbn(out) > data_len + 2 + ((data_len+32766)/32767)*5) {
|
986 |
+
stbiw__sbn(out) = 2; // truncate to DEFLATE 32K window and FLEVEL = 1
|
987 |
+
for (j = 0; j < data_len;) {
|
988 |
+
int blocklen = data_len - j;
|
989 |
+
if (blocklen > 32767) blocklen = 32767;
|
990 |
+
stbiw__sbpush(out, data_len - j == blocklen); // BFINAL = ?, BTYPE = 0 -- no compression
|
991 |
+
stbiw__sbpush(out, STBIW_UCHAR(blocklen)); // LEN
|
992 |
+
stbiw__sbpush(out, STBIW_UCHAR(blocklen >> 8));
|
993 |
+
stbiw__sbpush(out, STBIW_UCHAR(~blocklen)); // NLEN
|
994 |
+
stbiw__sbpush(out, STBIW_UCHAR(~blocklen >> 8));
|
995 |
+
memcpy(out+stbiw__sbn(out), data+j, blocklen);
|
996 |
+
stbiw__sbn(out) += blocklen;
|
997 |
+
j += blocklen;
|
998 |
+
}
|
999 |
+
}
|
1000 |
+
|
1001 |
+
{
|
1002 |
+
// compute adler32 on input
|
1003 |
+
unsigned int s1=1, s2=0;
|
1004 |
+
int blocklen = (int) (data_len % 5552);
|
1005 |
+
j=0;
|
1006 |
+
while (j < data_len) {
|
1007 |
+
for (i=0; i < blocklen; ++i) { s1 += data[j+i]; s2 += s1; }
|
1008 |
+
s1 %= 65521; s2 %= 65521;
|
1009 |
+
j += blocklen;
|
1010 |
+
blocklen = 5552;
|
1011 |
+
}
|
1012 |
+
stbiw__sbpush(out, STBIW_UCHAR(s2 >> 8));
|
1013 |
+
stbiw__sbpush(out, STBIW_UCHAR(s2));
|
1014 |
+
stbiw__sbpush(out, STBIW_UCHAR(s1 >> 8));
|
1015 |
+
stbiw__sbpush(out, STBIW_UCHAR(s1));
|
1016 |
+
}
|
1017 |
+
*out_len = stbiw__sbn(out);
|
1018 |
+
// make returned pointer freeable
|
1019 |
+
STBIW_MEMMOVE(stbiw__sbraw(out), out, *out_len);
|
1020 |
+
return (unsigned char *) stbiw__sbraw(out);
|
1021 |
+
#endif // STBIW_ZLIB_COMPRESS
|
1022 |
+
}
|
1023 |
+
|
1024 |
+
static unsigned int stbiw__crc32(unsigned char *buffer, int len)
|
1025 |
+
{
|
1026 |
+
#ifdef STBIW_CRC32
|
1027 |
+
return STBIW_CRC32(buffer, len);
|
1028 |
+
#else
|
1029 |
+
static unsigned int crc_table[256] =
|
1030 |
+
{
|
1031 |
+
0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, 0x9E6495A3,
|
1032 |
+
0x0eDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91,
|
1033 |
+
0x1DB71064, 0x6AB020F2, 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7,
|
1034 |
+
0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5,
|
1035 |
+
0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B,
|
1036 |
+
0x35B5A8FA, 0x42B2986C, 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59,
|
1037 |
+
0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, 0xCFBA9599, 0xB8BDA50F,
|
1038 |
+
0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D,
|
1039 |
+
0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433,
|
1040 |
+
0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, 0x91646C97, 0xE6635C01,
|
1041 |
+
0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457,
|
1042 |
+
0x65B0D9C6, 0x12B7E950, 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65,
|
1043 |
+
0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB,
|
1044 |
+
0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9,
|
1045 |
+
0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F,
|
1046 |
+
0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, 0xB7BD5C3B, 0xC0BA6CAD,
|
1047 |
+
0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683,
|
1048 |
+
0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1,
|
1049 |
+
0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7,
|
1050 |
+
0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5,
|
1051 |
+
0xD6D6A3E8, 0xA1D1937E, 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B,
|
1052 |
+
0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, 0x4669BE79,
|
1053 |
+
0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, 0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F,
|
1054 |
+
0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D,
|
1055 |
+
0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, 0x72076785, 0x05005713,
|
1056 |
+
0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21,
|
1057 |
+
0x86D3D2D4, 0xF1D4E242, 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777,
|
1058 |
+
0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45,
|
1059 |
+
0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB,
|
1060 |
+
0xAED16A4A, 0xD9D65ADC, 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9,
|
1061 |
+
0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693, 0x54DE5729, 0x23D967BF,
|
1062 |
+
0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D
|
1063 |
+
};
|
1064 |
+
|
1065 |
+
unsigned int crc = ~0u;
|
1066 |
+
int i;
|
1067 |
+
for (i=0; i < len; ++i)
|
1068 |
+
crc = (crc >> 8) ^ crc_table[buffer[i] ^ (crc & 0xff)];
|
1069 |
+
return ~crc;
|
1070 |
+
#endif
|
1071 |
+
}
|
1072 |
+
|
1073 |
+
#define stbiw__wpng4(o,a,b,c,d) ((o)[0]=STBIW_UCHAR(a),(o)[1]=STBIW_UCHAR(b),(o)[2]=STBIW_UCHAR(c),(o)[3]=STBIW_UCHAR(d),(o)+=4)
|
1074 |
+
#define stbiw__wp32(data,v) stbiw__wpng4(data, (v)>>24,(v)>>16,(v)>>8,(v));
|
1075 |
+
#define stbiw__wptag(data,s) stbiw__wpng4(data, s[0],s[1],s[2],s[3])
|
1076 |
+
|
1077 |
+
static void stbiw__wpcrc(unsigned char **data, int len)
|
1078 |
+
{
|
1079 |
+
unsigned int crc = stbiw__crc32(*data - len - 4, len+4);
|
1080 |
+
stbiw__wp32(*data, crc);
|
1081 |
+
}
|
1082 |
+
|
1083 |
+
static unsigned char stbiw__paeth(int a, int b, int c)
|
1084 |
+
{
|
1085 |
+
int p = a + b - c, pa = abs(p-a), pb = abs(p-b), pc = abs(p-c);
|
1086 |
+
if (pa <= pb && pa <= pc) return STBIW_UCHAR(a);
|
1087 |
+
if (pb <= pc) return STBIW_UCHAR(b);
|
1088 |
+
return STBIW_UCHAR(c);
|
1089 |
+
}
|
1090 |
+
|
1091 |
+
// @OPTIMIZE: provide an option that always forces left-predict or paeth predict
|
1092 |
+
static void stbiw__encode_png_line(unsigned char *pixels, int stride_bytes, int width, int height, int y, int n, int filter_type, signed char *line_buffer)
|
1093 |
+
{
|
1094 |
+
static int mapping[] = { 0,1,2,3,4 };
|
1095 |
+
static int firstmap[] = { 0,1,0,5,6 };
|
1096 |
+
int *mymap = (y != 0) ? mapping : firstmap;
|
1097 |
+
int i;
|
1098 |
+
int type = mymap[filter_type];
|
1099 |
+
unsigned char *z = pixels + stride_bytes * (stbi__flip_vertically_on_write ? height-1-y : y);
|
1100 |
+
int signed_stride = stbi__flip_vertically_on_write ? -stride_bytes : stride_bytes;
|
1101 |
+
|
1102 |
+
if (type==0) {
|
1103 |
+
memcpy(line_buffer, z, width*n);
|
1104 |
+
return;
|
1105 |
+
}
|
1106 |
+
|
1107 |
+
// first loop isn't optimized since it's just one pixel
|
1108 |
+
for (i = 0; i < n; ++i) {
|
1109 |
+
switch (type) {
|
1110 |
+
case 1: line_buffer[i] = z[i]; break;
|
1111 |
+
case 2: line_buffer[i] = z[i] - z[i-signed_stride]; break;
|
1112 |
+
case 3: line_buffer[i] = z[i] - (z[i-signed_stride]>>1); break;
|
1113 |
+
case 4: line_buffer[i] = (signed char) (z[i] - stbiw__paeth(0,z[i-signed_stride],0)); break;
|
1114 |
+
case 5: line_buffer[i] = z[i]; break;
|
1115 |
+
case 6: line_buffer[i] = z[i]; break;
|
1116 |
+
}
|
1117 |
+
}
|
1118 |
+
switch (type) {
|
1119 |
+
case 1: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-n]; break;
|
1120 |
+
case 2: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-signed_stride]; break;
|
1121 |
+
case 3: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - ((z[i-n] + z[i-signed_stride])>>1); break;
|
1122 |
+
case 4: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], z[i-signed_stride], z[i-signed_stride-n]); break;
|
1123 |
+
case 5: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - (z[i-n]>>1); break;
|
1124 |
+
case 6: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], 0,0); break;
|
1125 |
+
}
|
1126 |
+
}
|
1127 |
+
|
1128 |
+
STBIWDEF unsigned char *stbi_write_png_to_mem(const unsigned char *pixels, int stride_bytes, int x, int y, int n, int *out_len)
|
1129 |
+
{
|
1130 |
+
int force_filter = stbi_write_force_png_filter;
|
1131 |
+
int ctype[5] = { -1, 0, 4, 2, 6 };
|
1132 |
+
unsigned char sig[8] = { 137,80,78,71,13,10,26,10 };
|
1133 |
+
unsigned char *out,*o, *filt, *zlib;
|
1134 |
+
signed char *line_buffer;
|
1135 |
+
int j,zlen;
|
1136 |
+
|
1137 |
+
if (stride_bytes == 0)
|
1138 |
+
stride_bytes = x * n;
|
1139 |
+
|
1140 |
+
if (force_filter >= 5) {
|
1141 |
+
force_filter = -1;
|
1142 |
+
}
|
1143 |
+
|
1144 |
+
filt = (unsigned char *) STBIW_MALLOC((x*n+1) * y); if (!filt) return 0;
|
1145 |
+
line_buffer = (signed char *) STBIW_MALLOC(x * n); if (!line_buffer) { STBIW_FREE(filt); return 0; }
|
1146 |
+
for (j=0; j < y; ++j) {
|
1147 |
+
int filter_type;
|
1148 |
+
if (force_filter > -1) {
|
1149 |
+
filter_type = force_filter;
|
1150 |
+
stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, force_filter, line_buffer);
|
1151 |
+
} else { // Estimate the best filter by running through all of them:
|
1152 |
+
int best_filter = 0, best_filter_val = 0x7fffffff, est, i;
|
1153 |
+
for (filter_type = 0; filter_type < 5; filter_type++) {
|
1154 |
+
stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, filter_type, line_buffer);
|
1155 |
+
|
1156 |
+
// Estimate the entropy of the line using this filter; the less, the better.
|
1157 |
+
est = 0;
|
1158 |
+
for (i = 0; i < x*n; ++i) {
|
1159 |
+
est += abs((signed char) line_buffer[i]);
|
1160 |
+
}
|
1161 |
+
if (est < best_filter_val) {
|
1162 |
+
best_filter_val = est;
|
1163 |
+
best_filter = filter_type;
|
1164 |
+
}
|
1165 |
+
}
|
1166 |
+
if (filter_type != best_filter) { // If the last iteration already got us the best filter, don't redo it
|
1167 |
+
stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, best_filter, line_buffer);
|
1168 |
+
filter_type = best_filter;
|
1169 |
+
}
|
1170 |
+
}
|
1171 |
+
// when we get here, filter_type contains the filter type, and line_buffer contains the data
|
1172 |
+
filt[j*(x*n+1)] = (unsigned char) filter_type;
|
1173 |
+
STBIW_MEMMOVE(filt+j*(x*n+1)+1, line_buffer, x*n);
|
1174 |
+
}
|
1175 |
+
STBIW_FREE(line_buffer);
|
1176 |
+
zlib = stbi_zlib_compress(filt, y*( x*n+1), &zlen, stbi_write_png_compression_level);
|
1177 |
+
STBIW_FREE(filt);
|
1178 |
+
if (!zlib) return 0;
|
1179 |
+
|
1180 |
+
// each tag requires 12 bytes of overhead
|
1181 |
+
out = (unsigned char *) STBIW_MALLOC(8 + 12+13 + 12+zlen + 12);
|
1182 |
+
if (!out) return 0;
|
1183 |
+
*out_len = 8 + 12+13 + 12+zlen + 12;
|
1184 |
+
|
1185 |
+
o=out;
|
1186 |
+
STBIW_MEMMOVE(o,sig,8); o+= 8;
|
1187 |
+
stbiw__wp32(o, 13); // header length
|
1188 |
+
stbiw__wptag(o, "IHDR");
|
1189 |
+
stbiw__wp32(o, x);
|
1190 |
+
stbiw__wp32(o, y);
|
1191 |
+
*o++ = 8;
|
1192 |
+
*o++ = STBIW_UCHAR(ctype[n]);
|
1193 |
+
*o++ = 0;
|
1194 |
+
*o++ = 0;
|
1195 |
+
*o++ = 0;
|
1196 |
+
stbiw__wpcrc(&o,13);
|
1197 |
+
|
1198 |
+
stbiw__wp32(o, zlen);
|
1199 |
+
stbiw__wptag(o, "IDAT");
|
1200 |
+
STBIW_MEMMOVE(o, zlib, zlen);
|
1201 |
+
o += zlen;
|
1202 |
+
STBIW_FREE(zlib);
|
1203 |
+
stbiw__wpcrc(&o, zlen);
|
1204 |
+
|
1205 |
+
stbiw__wp32(o,0);
|
1206 |
+
stbiw__wptag(o, "IEND");
|
1207 |
+
stbiw__wpcrc(&o,0);
|
1208 |
+
|
1209 |
+
STBIW_ASSERT(o == out + *out_len);
|
1210 |
+
|
1211 |
+
return out;
|
1212 |
+
}
|
1213 |
+
|
1214 |
+
#ifndef STBI_WRITE_NO_STDIO
|
1215 |
+
STBIWDEF int stbi_write_png(char const *filename, int x, int y, int comp, const void *data, int stride_bytes)
|
1216 |
+
{
|
1217 |
+
FILE *f;
|
1218 |
+
int len;
|
1219 |
+
unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len);
|
1220 |
+
if (png == NULL) return 0;
|
1221 |
+
|
1222 |
+
f = stbiw__fopen(filename, "wb");
|
1223 |
+
if (!f) { STBIW_FREE(png); return 0; }
|
1224 |
+
fwrite(png, 1, len, f);
|
1225 |
+
fclose(f);
|
1226 |
+
STBIW_FREE(png);
|
1227 |
+
return 1;
|
1228 |
+
}
|
1229 |
+
#endif
|
1230 |
+
|
1231 |
+
STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int stride_bytes)
|
1232 |
+
{
|
1233 |
+
int len;
|
1234 |
+
unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len);
|
1235 |
+
if (png == NULL) return 0;
|
1236 |
+
func(context, png, len);
|
1237 |
+
STBIW_FREE(png);
|
1238 |
+
return 1;
|
1239 |
+
}
|
1240 |
+
|
1241 |
+
|
1242 |
+
/* ***************************************************************************
|
1243 |
+
*
|
1244 |
+
* JPEG writer
|
1245 |
+
*
|
1246 |
+
* This is based on Jon Olick's jo_jpeg.cpp:
|
1247 |
+
* public domain Simple, Minimalistic JPEG writer - http://www.jonolick.com/code.html
|
1248 |
+
*/
|
1249 |
+
|
1250 |
+
static const unsigned char stbiw__jpg_ZigZag[] = { 0,1,5,6,14,15,27,28,2,4,7,13,16,26,29,42,3,8,12,17,25,30,41,43,9,11,18,
|
1251 |
+
24,31,40,44,53,10,19,23,32,39,45,52,54,20,22,33,38,46,51,55,60,21,34,37,47,50,56,59,61,35,36,48,49,57,58,62,63 };
|
1252 |
+
|
1253 |
+
static void stbiw__jpg_writeBits(stbi__write_context *s, int *bitBufP, int *bitCntP, const unsigned short *bs) {
|
1254 |
+
int bitBuf = *bitBufP, bitCnt = *bitCntP;
|
1255 |
+
bitCnt += bs[1];
|
1256 |
+
bitBuf |= bs[0] << (24 - bitCnt);
|
1257 |
+
while(bitCnt >= 8) {
|
1258 |
+
unsigned char c = (bitBuf >> 16) & 255;
|
1259 |
+
stbiw__putc(s, c);
|
1260 |
+
if(c == 255) {
|
1261 |
+
stbiw__putc(s, 0);
|
1262 |
+
}
|
1263 |
+
bitBuf <<= 8;
|
1264 |
+
bitCnt -= 8;
|
1265 |
+
}
|
1266 |
+
*bitBufP = bitBuf;
|
1267 |
+
*bitCntP = bitCnt;
|
1268 |
+
}
|
1269 |
+
|
1270 |
+
static void stbiw__jpg_DCT(float *d0p, float *d1p, float *d2p, float *d3p, float *d4p, float *d5p, float *d6p, float *d7p) {
|
1271 |
+
float d0 = *d0p, d1 = *d1p, d2 = *d2p, d3 = *d3p, d4 = *d4p, d5 = *d5p, d6 = *d6p, d7 = *d7p;
|
1272 |
+
float z1, z2, z3, z4, z5, z11, z13;
|
1273 |
+
|
1274 |
+
float tmp0 = d0 + d7;
|
1275 |
+
float tmp7 = d0 - d7;
|
1276 |
+
float tmp1 = d1 + d6;
|
1277 |
+
float tmp6 = d1 - d6;
|
1278 |
+
float tmp2 = d2 + d5;
|
1279 |
+
float tmp5 = d2 - d5;
|
1280 |
+
float tmp3 = d3 + d4;
|
1281 |
+
float tmp4 = d3 - d4;
|
1282 |
+
|
1283 |
+
// Even part
|
1284 |
+
float tmp10 = tmp0 + tmp3; // phase 2
|
1285 |
+
float tmp13 = tmp0 - tmp3;
|
1286 |
+
float tmp11 = tmp1 + tmp2;
|
1287 |
+
float tmp12 = tmp1 - tmp2;
|
1288 |
+
|
1289 |
+
d0 = tmp10 + tmp11; // phase 3
|
1290 |
+
d4 = tmp10 - tmp11;
|
1291 |
+
|
1292 |
+
z1 = (tmp12 + tmp13) * 0.707106781f; // c4
|
1293 |
+
d2 = tmp13 + z1; // phase 5
|
1294 |
+
d6 = tmp13 - z1;
|
1295 |
+
|
1296 |
+
// Odd part
|
1297 |
+
tmp10 = tmp4 + tmp5; // phase 2
|
1298 |
+
tmp11 = tmp5 + tmp6;
|
1299 |
+
tmp12 = tmp6 + tmp7;
|
1300 |
+
|
1301 |
+
// The rotator is modified from fig 4-8 to avoid extra negations.
|
1302 |
+
z5 = (tmp10 - tmp12) * 0.382683433f; // c6
|
1303 |
+
z2 = tmp10 * 0.541196100f + z5; // c2-c6
|
1304 |
+
z4 = tmp12 * 1.306562965f + z5; // c2+c6
|
1305 |
+
z3 = tmp11 * 0.707106781f; // c4
|
1306 |
+
|
1307 |
+
z11 = tmp7 + z3; // phase 5
|
1308 |
+
z13 = tmp7 - z3;
|
1309 |
+
|
1310 |
+
*d5p = z13 + z2; // phase 6
|
1311 |
+
*d3p = z13 - z2;
|
1312 |
+
*d1p = z11 + z4;
|
1313 |
+
*d7p = z11 - z4;
|
1314 |
+
|
1315 |
+
*d0p = d0; *d2p = d2; *d4p = d4; *d6p = d6;
|
1316 |
+
}
|
1317 |
+
|
1318 |
+
static void stbiw__jpg_calcBits(int val, unsigned short bits[2]) {
|
1319 |
+
int tmp1 = val < 0 ? -val : val;
|
1320 |
+
val = val < 0 ? val-1 : val;
|
1321 |
+
bits[1] = 1;
|
1322 |
+
while(tmp1 >>= 1) {
|
1323 |
+
++bits[1];
|
1324 |
+
}
|
1325 |
+
bits[0] = val & ((1<<bits[1])-1);
|
1326 |
+
}
|
1327 |
+
|
1328 |
+
static int stbiw__jpg_processDU(stbi__write_context *s, int *bitBuf, int *bitCnt, float *CDU, int du_stride, float *fdtbl, int DC, const unsigned short HTDC[256][2], const unsigned short HTAC[256][2]) {
|
1329 |
+
const unsigned short EOB[2] = { HTAC[0x00][0], HTAC[0x00][1] };
|
1330 |
+
const unsigned short M16zeroes[2] = { HTAC[0xF0][0], HTAC[0xF0][1] };
|
1331 |
+
int dataOff, i, j, n, diff, end0pos, x, y;
|
1332 |
+
int DU[64];
|
1333 |
+
|
1334 |
+
// DCT rows
|
1335 |
+
for(dataOff=0, n=du_stride*8; dataOff<n; dataOff+=du_stride) {
|
1336 |
+
stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+1], &CDU[dataOff+2], &CDU[dataOff+3], &CDU[dataOff+4], &CDU[dataOff+5], &CDU[dataOff+6], &CDU[dataOff+7]);
|
1337 |
+
}
|
1338 |
+
// DCT columns
|
1339 |
+
for(dataOff=0; dataOff<8; ++dataOff) {
|
1340 |
+
stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+du_stride], &CDU[dataOff+du_stride*2], &CDU[dataOff+du_stride*3], &CDU[dataOff+du_stride*4],
|
1341 |
+
&CDU[dataOff+du_stride*5], &CDU[dataOff+du_stride*6], &CDU[dataOff+du_stride*7]);
|
1342 |
+
}
|
1343 |
+
// Quantize/descale/zigzag the coefficients
|
1344 |
+
for(y = 0, j=0; y < 8; ++y) {
|
1345 |
+
for(x = 0; x < 8; ++x,++j) {
|
1346 |
+
float v;
|
1347 |
+
i = y*du_stride+x;
|
1348 |
+
v = CDU[i]*fdtbl[j];
|
1349 |
+
// DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? ceilf(v - 0.5f) : floorf(v + 0.5f));
|
1350 |
+
// ceilf() and floorf() are C99, not C89, but I /think/ they're not needed here anyway?
|
1351 |
+
DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? v - 0.5f : v + 0.5f);
|
1352 |
+
}
|
1353 |
+
}
|
1354 |
+
|
1355 |
+
// Encode DC
|
1356 |
+
diff = DU[0] - DC;
|
1357 |
+
if (diff == 0) {
|
1358 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[0]);
|
1359 |
+
} else {
|
1360 |
+
unsigned short bits[2];
|
1361 |
+
stbiw__jpg_calcBits(diff, bits);
|
1362 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[bits[1]]);
|
1363 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits);
|
1364 |
+
}
|
1365 |
+
// Encode ACs
|
1366 |
+
end0pos = 63;
|
1367 |
+
for(; (end0pos>0)&&(DU[end0pos]==0); --end0pos) {
|
1368 |
+
}
|
1369 |
+
// end0pos = first element in reverse order !=0
|
1370 |
+
if(end0pos == 0) {
|
1371 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB);
|
1372 |
+
return DU[0];
|
1373 |
+
}
|
1374 |
+
for(i = 1; i <= end0pos; ++i) {
|
1375 |
+
int startpos = i;
|
1376 |
+
int nrzeroes;
|
1377 |
+
unsigned short bits[2];
|
1378 |
+
for (; DU[i]==0 && i<=end0pos; ++i) {
|
1379 |
+
}
|
1380 |
+
nrzeroes = i-startpos;
|
1381 |
+
if ( nrzeroes >= 16 ) {
|
1382 |
+
int lng = nrzeroes>>4;
|
1383 |
+
int nrmarker;
|
1384 |
+
for (nrmarker=1; nrmarker <= lng; ++nrmarker)
|
1385 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, M16zeroes);
|
1386 |
+
nrzeroes &= 15;
|
1387 |
+
}
|
1388 |
+
stbiw__jpg_calcBits(DU[i], bits);
|
1389 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTAC[(nrzeroes<<4)+bits[1]]);
|
1390 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits);
|
1391 |
+
}
|
1392 |
+
if(end0pos != 63) {
|
1393 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB);
|
1394 |
+
}
|
1395 |
+
return DU[0];
|
1396 |
+
}
|
1397 |
+
|
1398 |
+
static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) {
|
1399 |
+
// Constants that don't pollute global namespace
|
1400 |
+
static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0};
|
1401 |
+
static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};
|
1402 |
+
static const unsigned char std_ac_luminance_nrcodes[] = {0,0,2,1,3,3,2,4,3,5,5,4,4,0,0,1,0x7d};
|
1403 |
+
static const unsigned char std_ac_luminance_values[] = {
|
1404 |
+
0x01,0x02,0x03,0x00,0x04,0x11,0x05,0x12,0x21,0x31,0x41,0x06,0x13,0x51,0x61,0x07,0x22,0x71,0x14,0x32,0x81,0x91,0xa1,0x08,
|
1405 |
+
0x23,0x42,0xb1,0xc1,0x15,0x52,0xd1,0xf0,0x24,0x33,0x62,0x72,0x82,0x09,0x0a,0x16,0x17,0x18,0x19,0x1a,0x25,0x26,0x27,0x28,
|
1406 |
+
0x29,0x2a,0x34,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,0x59,
|
1407 |
+
0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x83,0x84,0x85,0x86,0x87,0x88,0x89,
|
1408 |
+
0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,0xb5,0xb6,
|
1409 |
+
0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,0xe1,0xe2,
|
1410 |
+
0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf1,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa
|
1411 |
+
};
|
1412 |
+
static const unsigned char std_dc_chrominance_nrcodes[] = {0,0,3,1,1,1,1,1,1,1,1,1,0,0,0,0,0};
|
1413 |
+
static const unsigned char std_dc_chrominance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};
|
1414 |
+
static const unsigned char std_ac_chrominance_nrcodes[] = {0,0,2,1,2,4,4,3,4,7,5,4,4,0,1,2,0x77};
|
1415 |
+
static const unsigned char std_ac_chrominance_values[] = {
|
1416 |
+
0x00,0x01,0x02,0x03,0x11,0x04,0x05,0x21,0x31,0x06,0x12,0x41,0x51,0x07,0x61,0x71,0x13,0x22,0x32,0x81,0x08,0x14,0x42,0x91,
|
1417 |
+
0xa1,0xb1,0xc1,0x09,0x23,0x33,0x52,0xf0,0x15,0x62,0x72,0xd1,0x0a,0x16,0x24,0x34,0xe1,0x25,0xf1,0x17,0x18,0x19,0x1a,0x26,
|
1418 |
+
0x27,0x28,0x29,0x2a,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,
|
1419 |
+
0x59,0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x82,0x83,0x84,0x85,0x86,0x87,
|
1420 |
+
0x88,0x89,0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,
|
1421 |
+
0xb5,0xb6,0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,
|
1422 |
+
0xe2,0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa
|
1423 |
+
};
|
1424 |
+
// Huffman tables
|
1425 |
+
static const unsigned short YDC_HT[256][2] = { {0,2},{2,3},{3,3},{4,3},{5,3},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9}};
|
1426 |
+
static const unsigned short UVDC_HT[256][2] = { {0,2},{1,2},{2,2},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9},{1022,10},{2046,11}};
|
1427 |
+
static const unsigned short YAC_HT[256][2] = {
|
1428 |
+
{10,4},{0,2},{1,2},{4,3},{11,4},{26,5},{120,7},{248,8},{1014,10},{65410,16},{65411,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1429 |
+
{12,4},{27,5},{121,7},{502,9},{2038,11},{65412,16},{65413,16},{65414,16},{65415,16},{65416,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1430 |
+
{28,5},{249,8},{1015,10},{4084,12},{65417,16},{65418,16},{65419,16},{65420,16},{65421,16},{65422,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1431 |
+
{58,6},{503,9},{4085,12},{65423,16},{65424,16},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1432 |
+
{59,6},{1016,10},{65430,16},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1433 |
+
{122,7},{2039,11},{65438,16},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1434 |
+
{123,7},{4086,12},{65446,16},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1435 |
+
{250,8},{4087,12},{65454,16},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1436 |
+
{504,9},{32704,15},{65462,16},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1437 |
+
{505,9},{65470,16},{65471,16},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1438 |
+
{506,9},{65479,16},{65480,16},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1439 |
+
{1017,10},{65488,16},{65489,16},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1440 |
+
{1018,10},{65497,16},{65498,16},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1441 |
+
{2040,11},{65506,16},{65507,16},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1442 |
+
{65515,16},{65516,16},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1443 |
+
{2041,11},{65525,16},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0}
|
1444 |
+
};
|
1445 |
+
static const unsigned short UVAC_HT[256][2] = {
|
1446 |
+
{0,2},{1,2},{4,3},{10,4},{24,5},{25,5},{56,6},{120,7},{500,9},{1014,10},{4084,12},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1447 |
+
{11,4},{57,6},{246,8},{501,9},{2038,11},{4085,12},{65416,16},{65417,16},{65418,16},{65419,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1448 |
+
{26,5},{247,8},{1015,10},{4086,12},{32706,15},{65420,16},{65421,16},{65422,16},{65423,16},{65424,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1449 |
+
{27,5},{248,8},{1016,10},{4087,12},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{65430,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1450 |
+
{58,6},{502,9},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{65438,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1451 |
+
{59,6},{1017,10},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{65446,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1452 |
+
{121,7},{2039,11},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{65454,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1453 |
+
{122,7},{2040,11},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{65462,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1454 |
+
{249,8},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{65470,16},{65471,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1455 |
+
{503,9},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{65479,16},{65480,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1456 |
+
{504,9},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{65488,16},{65489,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1457 |
+
{505,9},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{65497,16},{65498,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1458 |
+
{506,9},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{65506,16},{65507,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1459 |
+
{2041,11},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{65515,16},{65516,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1460 |
+
{16352,14},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{65525,16},{0,0},{0,0},{0,0},{0,0},{0,0},
|
1461 |
+
{1018,10},{32707,15},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0}
|
1462 |
+
};
|
1463 |
+
static const int YQT[] = {16,11,10,16,24,40,51,61,12,12,14,19,26,58,60,55,14,13,16,24,40,57,69,56,14,17,22,29,51,87,80,62,18,22,
|
1464 |
+
37,56,68,109,103,77,24,35,55,64,81,104,113,92,49,64,78,87,103,121,120,101,72,92,95,98,112,100,103,99};
|
1465 |
+
static const int UVQT[] = {17,18,24,47,99,99,99,99,18,21,26,66,99,99,99,99,24,26,56,99,99,99,99,99,47,66,99,99,99,99,99,99,
|
1466 |
+
99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99};
|
1467 |
+
static const float aasf[] = { 1.0f * 2.828427125f, 1.387039845f * 2.828427125f, 1.306562965f * 2.828427125f, 1.175875602f * 2.828427125f,
|
1468 |
+
1.0f * 2.828427125f, 0.785694958f * 2.828427125f, 0.541196100f * 2.828427125f, 0.275899379f * 2.828427125f };
|
1469 |
+
|
1470 |
+
int row, col, i, k, subsample;
|
1471 |
+
float fdtbl_Y[64], fdtbl_UV[64];
|
1472 |
+
unsigned char YTable[64], UVTable[64];
|
1473 |
+
|
1474 |
+
if(!data || !width || !height || comp > 4 || comp < 1) {
|
1475 |
+
return 0;
|
1476 |
+
}
|
1477 |
+
|
1478 |
+
quality = quality ? quality : 90;
|
1479 |
+
subsample = quality <= 90 ? 1 : 0;
|
1480 |
+
quality = quality < 1 ? 1 : quality > 100 ? 100 : quality;
|
1481 |
+
quality = quality < 50 ? 5000 / quality : 200 - quality * 2;
|
1482 |
+
|
1483 |
+
for(i = 0; i < 64; ++i) {
|
1484 |
+
int uvti, yti = (YQT[i]*quality+50)/100;
|
1485 |
+
YTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (yti < 1 ? 1 : yti > 255 ? 255 : yti);
|
1486 |
+
uvti = (UVQT[i]*quality+50)/100;
|
1487 |
+
UVTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (uvti < 1 ? 1 : uvti > 255 ? 255 : uvti);
|
1488 |
+
}
|
1489 |
+
|
1490 |
+
for(row = 0, k = 0; row < 8; ++row) {
|
1491 |
+
for(col = 0; col < 8; ++col, ++k) {
|
1492 |
+
fdtbl_Y[k] = 1 / (YTable [stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]);
|
1493 |
+
fdtbl_UV[k] = 1 / (UVTable[stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]);
|
1494 |
+
}
|
1495 |
+
}
|
1496 |
+
|
1497 |
+
// Write Headers
|
1498 |
+
{
|
1499 |
+
static const unsigned char head0[] = { 0xFF,0xD8,0xFF,0xE0,0,0x10,'J','F','I','F',0,1,1,0,0,1,0,1,0,0,0xFF,0xDB,0,0x84,0 };
|
1500 |
+
static const unsigned char head2[] = { 0xFF,0xDA,0,0xC,3,1,0,2,0x11,3,0x11,0,0x3F,0 };
|
1501 |
+
const unsigned char head1[] = { 0xFF,0xC0,0,0x11,8,(unsigned char)(height>>8),STBIW_UCHAR(height),(unsigned char)(width>>8),STBIW_UCHAR(width),
|
1502 |
+
3,1,(unsigned char)(subsample?0x22:0x11),0,2,0x11,1,3,0x11,1,0xFF,0xC4,0x01,0xA2,0 };
|
1503 |
+
s->func(s->context, (void*)head0, sizeof(head0));
|
1504 |
+
s->func(s->context, (void*)YTable, sizeof(YTable));
|
1505 |
+
stbiw__putc(s, 1);
|
1506 |
+
s->func(s->context, UVTable, sizeof(UVTable));
|
1507 |
+
s->func(s->context, (void*)head1, sizeof(head1));
|
1508 |
+
s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1);
|
1509 |
+
s->func(s->context, (void*)std_dc_luminance_values, sizeof(std_dc_luminance_values));
|
1510 |
+
stbiw__putc(s, 0x10); // HTYACinfo
|
1511 |
+
s->func(s->context, (void*)(std_ac_luminance_nrcodes+1), sizeof(std_ac_luminance_nrcodes)-1);
|
1512 |
+
s->func(s->context, (void*)std_ac_luminance_values, sizeof(std_ac_luminance_values));
|
1513 |
+
stbiw__putc(s, 1); // HTUDCinfo
|
1514 |
+
s->func(s->context, (void*)(std_dc_chrominance_nrcodes+1), sizeof(std_dc_chrominance_nrcodes)-1);
|
1515 |
+
s->func(s->context, (void*)std_dc_chrominance_values, sizeof(std_dc_chrominance_values));
|
1516 |
+
stbiw__putc(s, 0x11); // HTUACinfo
|
1517 |
+
s->func(s->context, (void*)(std_ac_chrominance_nrcodes+1), sizeof(std_ac_chrominance_nrcodes)-1);
|
1518 |
+
s->func(s->context, (void*)std_ac_chrominance_values, sizeof(std_ac_chrominance_values));
|
1519 |
+
s->func(s->context, (void*)head2, sizeof(head2));
|
1520 |
+
}
|
1521 |
+
|
1522 |
+
// Encode 8x8 macroblocks
|
1523 |
+
{
|
1524 |
+
static const unsigned short fillBits[] = {0x7F, 7};
|
1525 |
+
int DCY=0, DCU=0, DCV=0;
|
1526 |
+
int bitBuf=0, bitCnt=0;
|
1527 |
+
// comp == 2 is grey+alpha (alpha is ignored)
|
1528 |
+
int ofsG = comp > 2 ? 1 : 0, ofsB = comp > 2 ? 2 : 0;
|
1529 |
+
const unsigned char *dataR = (const unsigned char *)data;
|
1530 |
+
const unsigned char *dataG = dataR + ofsG;
|
1531 |
+
const unsigned char *dataB = dataR + ofsB;
|
1532 |
+
int x, y, pos;
|
1533 |
+
if(subsample) {
|
1534 |
+
for(y = 0; y < height; y += 16) {
|
1535 |
+
for(x = 0; x < width; x += 16) {
|
1536 |
+
float Y[256], U[256], V[256];
|
1537 |
+
for(row = y, pos = 0; row < y+16; ++row) {
|
1538 |
+
// row >= height => use last input row
|
1539 |
+
int clamped_row = (row < height) ? row : height - 1;
|
1540 |
+
int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp;
|
1541 |
+
for(col = x; col < x+16; ++col, ++pos) {
|
1542 |
+
// if col >= width => use pixel from last input column
|
1543 |
+
int p = base_p + ((col < width) ? col : (width-1))*comp;
|
1544 |
+
float r = dataR[p], g = dataG[p], b = dataB[p];
|
1545 |
+
Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128;
|
1546 |
+
U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b;
|
1547 |
+
V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b;
|
1548 |
+
}
|
1549 |
+
}
|
1550 |
+
DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+0, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
|
1551 |
+
DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+8, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
|
1552 |
+
DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+128, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
|
1553 |
+
DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+136, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
|
1554 |
+
|
1555 |
+
// subsample U,V
|
1556 |
+
{
|
1557 |
+
float subU[64], subV[64];
|
1558 |
+
int yy, xx;
|
1559 |
+
for(yy = 0, pos = 0; yy < 8; ++yy) {
|
1560 |
+
for(xx = 0; xx < 8; ++xx, ++pos) {
|
1561 |
+
int j = yy*32+xx*2;
|
1562 |
+
subU[pos] = (U[j+0] + U[j+1] + U[j+16] + U[j+17]) * 0.25f;
|
1563 |
+
subV[pos] = (V[j+0] + V[j+1] + V[j+16] + V[j+17]) * 0.25f;
|
1564 |
+
}
|
1565 |
+
}
|
1566 |
+
DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subU, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT);
|
1567 |
+
DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subV, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT);
|
1568 |
+
}
|
1569 |
+
}
|
1570 |
+
}
|
1571 |
+
} else {
|
1572 |
+
for(y = 0; y < height; y += 8) {
|
1573 |
+
for(x = 0; x < width; x += 8) {
|
1574 |
+
float Y[64], U[64], V[64];
|
1575 |
+
for(row = y, pos = 0; row < y+8; ++row) {
|
1576 |
+
// row >= height => use last input row
|
1577 |
+
int clamped_row = (row < height) ? row : height - 1;
|
1578 |
+
int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp;
|
1579 |
+
for(col = x; col < x+8; ++col, ++pos) {
|
1580 |
+
// if col >= width => use pixel from last input column
|
1581 |
+
int p = base_p + ((col < width) ? col : (width-1))*comp;
|
1582 |
+
float r = dataR[p], g = dataG[p], b = dataB[p];
|
1583 |
+
Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128;
|
1584 |
+
U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b;
|
1585 |
+
V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b;
|
1586 |
+
}
|
1587 |
+
}
|
1588 |
+
|
1589 |
+
DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y, 8, fdtbl_Y, DCY, YDC_HT, YAC_HT);
|
1590 |
+
DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, U, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT);
|
1591 |
+
DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, V, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT);
|
1592 |
+
}
|
1593 |
+
}
|
1594 |
+
}
|
1595 |
+
|
1596 |
+
// Do the bit alignment of the EOI marker
|
1597 |
+
stbiw__jpg_writeBits(s, &bitBuf, &bitCnt, fillBits);
|
1598 |
+
}
|
1599 |
+
|
1600 |
+
// EOI
|
1601 |
+
stbiw__putc(s, 0xFF);
|
1602 |
+
stbiw__putc(s, 0xD9);
|
1603 |
+
|
1604 |
+
return 1;
|
1605 |
+
}
|
1606 |
+
|
1607 |
+
STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality)
|
1608 |
+
{
|
1609 |
+
stbi__write_context s = { 0 };
|
1610 |
+
stbi__start_write_callbacks(&s, func, context);
|
1611 |
+
return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality);
|
1612 |
+
}
|
1613 |
+
|
1614 |
+
|
1615 |
+
#ifndef STBI_WRITE_NO_STDIO
|
1616 |
+
STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality)
|
1617 |
+
{
|
1618 |
+
stbi__write_context s = { 0 };
|
1619 |
+
if (stbi__start_write_file(&s,filename)) {
|
1620 |
+
int r = stbi_write_jpg_core(&s, x, y, comp, data, quality);
|
1621 |
+
stbi__end_write_file(&s);
|
1622 |
+
return r;
|
1623 |
+
} else
|
1624 |
+
return 0;
|
1625 |
+
}
|
1626 |
+
#endif
|
1627 |
+
|
1628 |
+
#endif // STB_IMAGE_WRITE_IMPLEMENTATION
|
1629 |
+
|
1630 |
+
/* Revision history
|
1631 |
+
1.16 (2021-07-11)
|
1632 |
+
make Deflate code emit uncompressed blocks when it would otherwise expand
|
1633 |
+
support writing BMPs with alpha channel
|
1634 |
+
1.15 (2020-07-13) unknown
|
1635 |
+
1.14 (2020-02-02) updated JPEG writer to downsample chroma channels
|
1636 |
+
1.13
|
1637 |
+
1.12
|
1638 |
+
1.11 (2019-08-11)
|
1639 |
+
|
1640 |
+
1.10 (2019-02-07)
|
1641 |
+
support utf8 filenames in Windows; fix warnings and platform ifdefs
|
1642 |
+
1.09 (2018-02-11)
|
1643 |
+
fix typo in zlib quality API, improve STB_I_W_STATIC in C++
|
1644 |
+
1.08 (2018-01-29)
|
1645 |
+
add stbi__flip_vertically_on_write, external zlib, zlib quality, choose PNG filter
|
1646 |
+
1.07 (2017-07-24)
|
1647 |
+
doc fix
|
1648 |
+
1.06 (2017-07-23)
|
1649 |
+
writing JPEG (using Jon Olick's code)
|
1650 |
+
1.05 ???
|
1651 |
+
1.04 (2017-03-03)
|
1652 |
+
monochrome BMP expansion
|
1653 |
+
1.03 ???
|
1654 |
+
1.02 (2016-04-02)
|
1655 |
+
avoid allocating large structures on the stack
|
1656 |
+
1.01 (2016-01-16)
|
1657 |
+
STBIW_REALLOC_SIZED: support allocators with no realloc support
|
1658 |
+
avoid race-condition in crc initialization
|
1659 |
+
minor compile issues
|
1660 |
+
1.00 (2015-09-14)
|
1661 |
+
installable file IO function
|
1662 |
+
0.99 (2015-09-13)
|
1663 |
+
warning fixes; TGA rle support
|
1664 |
+
0.98 (2015-04-08)
|
1665 |
+
added STBIW_MALLOC, STBIW_ASSERT etc
|
1666 |
+
0.97 (2015-01-18)
|
1667 |
+
fixed HDR asserts, rewrote HDR rle logic
|
1668 |
+
0.96 (2015-01-17)
|
1669 |
+
add HDR output
|
1670 |
+
fix monochrome BMP
|
1671 |
+
0.95 (2014-08-17)
|
1672 |
+
add monochrome TGA output
|
1673 |
+
0.94 (2014-05-31)
|
1674 |
+
rename private functions to avoid conflicts with stb_image.h
|
1675 |
+
0.93 (2014-05-27)
|
1676 |
+
warning fixes
|
1677 |
+
0.92 (2010-08-01)
|
1678 |
+
casts to unsigned char to fix warnings
|
1679 |
+
0.91 (2010-07-17)
|
1680 |
+
first public release
|
1681 |
+
0.90 first internal release
|
1682 |
+
*/
|
1683 |
+
|
1684 |
+
/*
|
1685 |
+
------------------------------------------------------------------------------
|
1686 |
+
This software is available under 2 licenses -- choose whichever you prefer.
|
1687 |
+
------------------------------------------------------------------------------
|
1688 |
+
ALTERNATIVE A - MIT License
|
1689 |
+
Copyright (c) 2017 Sean Barrett
|
1690 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
1691 |
+
this software and associated documentation files (the "Software"), to deal in
|
1692 |
+
the Software without restriction, including without limitation the rights to
|
1693 |
+
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
1694 |
+
of the Software, and to permit persons to whom the Software is furnished to do
|
1695 |
+
so, subject to the following conditions:
|
1696 |
+
The above copyright notice and this permission notice shall be included in all
|
1697 |
+
copies or substantial portions of the Software.
|
1698 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
1699 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
1700 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
1701 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
1702 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
1703 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
1704 |
+
SOFTWARE.
|
1705 |
+
------------------------------------------------------------------------------
|
1706 |
+
ALTERNATIVE B - Public Domain (www.unlicense.org)
|
1707 |
+
This is free and unencumbered software released into the public domain.
|
1708 |
+
Anyone is free to copy, modify, publish, use, compile, sell, or distribute this
|
1709 |
+
software, either in source code form or as a compiled binary, for any purpose,
|
1710 |
+
commercial or non-commercial, and by any means.
|
1711 |
+
In jurisdictions that recognize copyright laws, the author or authors of this
|
1712 |
+
software dedicate any and all copyright interest in the software to the public
|
1713 |
+
domain. We make this dedication for the benefit of the public at large and to
|
1714 |
+
the detriment of our heirs and successors. We intend this dedication to be an
|
1715 |
+
overt act of relinquishment in perpetuity of all present and future rights to
|
1716 |
+
this software under copyright law.
|
1717 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
1718 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
1719 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
1720 |
+
AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
|
1721 |
+
ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
1722 |
+
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
1723 |
+
------------------------------------------------------------------------------
|
1724 |
+
*/
|
util/text_img.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import rembg
|
3 |
+
import torch
|
4 |
+
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
import gradio as gr
|
9 |
+
|
10 |
+
# pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
|
11 |
+
# pipe.to("cuda")
|
12 |
+
|
13 |
+
def check_prompt(prompt):
|
14 |
+
if prompt is None:
|
15 |
+
raise gr.Error("Please enter a prompt!")
|
16 |
+
|
17 |
+
controlnet = ControlNetModel.from_pretrained(
|
18 |
+
"diffusers/controlnet-canny-sdxl-1.0",
|
19 |
+
torch_dtype=torch.float16,
|
20 |
+
use_safetensors=True
|
21 |
+
)
|
22 |
+
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
|
23 |
+
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
24 |
+
"stabilityai/stable-diffusion-xl-base-1.0",
|
25 |
+
controlnet=controlnet,
|
26 |
+
vae=vae,
|
27 |
+
torch_dtype=torch.float16,
|
28 |
+
use_safetensors=True
|
29 |
+
)
|
30 |
+
|
31 |
+
pipe.to("cuda")
|
32 |
+
|
33 |
+
# Function to generate an image from text using diffusion
|
34 |
+
@spaces.GPU
|
35 |
+
def generate_image(prompt, negative_prompt, control_image, scale=0.5):
|
36 |
+
prompt += "no background, side view, minimalist shot, single shoe, no legs, product photo"
|
37 |
+
|
38 |
+
canny_image = get_canny(control_image)
|
39 |
+
|
40 |
+
image = pipe(
|
41 |
+
prompt,
|
42 |
+
negative_prompt=negative_prompt,
|
43 |
+
image=canny_image,
|
44 |
+
controlnet_conditioning_scale=scale,
|
45 |
+
).images[0]
|
46 |
+
image2 = rembg.remove(image)
|
47 |
+
|
48 |
+
return image2
|
49 |
+
|
50 |
+
def get_canny(image):
|
51 |
+
image = np.array(image)
|
52 |
+
|
53 |
+
low_threshold = 100
|
54 |
+
high_threshold = 200
|
55 |
+
|
56 |
+
image = cv2.Canny(image,low_threshold,high_threshold)
|
57 |
+
image = image[:,:,None]
|
58 |
+
image = np.concatenate([image, image, image], axis=2)
|
59 |
+
canny_image = Image.fromarray(image)
|
60 |
+
return canny_image
|