
Commit 797dbb6: initial commit (0 parents)

39 files changed: +2533 additions, -0 deletions

.gitignore

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
data/
log/
__pycache__/
model_saves/
.idea/

README.md

Lines changed: 236 additions & 0 deletions
@@ -0,0 +1,236 @@
# PyTorch mip-NeRF

A reimplementation of mip-NeRF in PyTorch.

![mipnerf](misc/images/nerfTomipnerf.png)

Not exactly 1-to-1 with the official repo, as we organized the code to our own liking (mostly in how the datasets are structured, plus hyperparameter changes so the code runs on a consumer-level graphics card), made it more modular, and removed some repetitive code, but it achieves the same results.

## Features

* Can use spherical or spiral poses to generate videos for all 3 datasets
  * Spherical:

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/lego/video.mp4" type="video/mp4">)

[//]: # (</video>)

  * Spiral:

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/lego/video_spiral.mp4" type="video/mp4">)

[//]: # (</video>)

* Depth and Normals video renderings:
  * Depth:

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/lego/depth.mp4" type="video/mp4">)

[//]: # (</video>)

  * Normals:

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/lego/normals.mp4" type="video/mp4">)

[//]: # (</video>)

* Can extract meshes
  * Default mesh:

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/lego/mesh.mkv" type="video/mkv">)

[//]: # (</video>)

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/mic/mesh.mkv" type="video/mkv">)

[//]: # (</video>)

## Future Plans

In the future we plan on implementing/changing:

* Factor out more repetitive/redundant code; optimize GPU memory and rps
* Clean up and expand the mesh extraction code
* Zoomed poses for the multicam dataset
* [Mip-NeRF 360: Unbounded Anti-Aliased Neural Radiance Fields](https://jonbarron.info/mipnerf360/) support
* [NeRV: Neural Reflectance and Visibility Fields for Relighting and View Synthesis](https://pratulsrinivasan.github.io/nerv/) support

## Installation/Running

1. Create a conda environment using `mipNeRF.yml`
2. Get the training data
   1. Run `bash scripts/download_data.sh` to download all 3 datasets: LLFF, Blender, and Multicam.
   2. Or, individually run the bash script corresponding to a single dataset:
      * `bash scripts/download_llff.sh` to download LLFF
      * `bash scripts/download_blender.sh` to download Blender
      * `bash scripts/download_multicam.sh` to download Multicam (note that this also downloads the Blender dataset, since Multicam is derived from it)
3. Optionally change config parameters: edit the defaults in `config.py` or specify them as command line arguments (see the sketch after this list)
   * The default config is set up to run on a high-end consumer-level graphics card (~8-12 GB)
4. Run `python train.py` to train
   * Run `python -m tensorboard.main --logdir=log` to start the tensorboard
5. Run `python visualize.py` to render a video from the trained model
6. Run `python extract_mesh.py` to extract a mesh from the trained model
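
As a concrete example of step 3, command line arguments are parsed by `get_config()` in `config.py` (listed in full further down). The snippet below is a minimal sketch, assuming `config.py` is importable from the repo root; it simulates one training invocation and prints the llff-specific defaults that `get_config()` applies automatically:

```python
import sys

# Simulate: python train.py --dataset_name llff --scene trex --batch_size 1024
sys.argv = ["train.py", "--dataset_name", "llff",
            "--scene", "trex", "--batch_size", "1024"]

from config import get_config

config = get_config()
# llff overrides kick in because --override_defaults was not passed
print(config.ray_shape, config.factor, config.white_bkgd)  # cylinder 4 False
print(config.base_dir)  # data/nerf_llff_data/trex
```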

## Code Structure

I explain the specifics of the code in more detail [here](misc/Code.md), but here is a basic rundown.

* `config.py`: Specifies hyperparameters.
* `datasets.py`: Base generic `Dataset` class + 3 default dataset implementations.
  * `NeRFDataset`: Base class that all datasets should inherit from.
  * `Multicam`: Used for multicam data as in the original mip-NeRF paper.
  * `Blender`: Used for the synthetic dataset as in the original NeRF.
  * `LLFF`: Used for the llff dataset as in the original NeRF.
* `loss.py`: mip-NeRF loss, pretty much just MSE, but also calculates PSNR (see the sketch after this list).
* `model.py`: mip-NeRF model, not as modular as the original authors' version, but its structure is easier to understand when laid out verbatim like this.
* `pose_utils.py`: Various functions used to generate poses.
* `ray_utils.py`: Various functions involving the rays that the model uses as input; most are used within the forward function of the model.
* `scheduler.py`: mip-NeRF learning rate scheduler.
* `train.py`: Trains a mip-NeRF model.
* `visualize.py`: Creates the videos using a trained mip-NeRF.
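
To make the `loss.py` description concrete, here is a minimal sketch, assuming it follows the mip-NeRF paper: MSE on both the coarse and fine renderings, with the coarse term scaled by `coarse_weight_decay` from `config.py`, and PSNR computed from the fine MSE (the function name and signature here are illustrative, not the repo's actual API):

```python
import torch


def mipnerf_loss(coarse_rgb, fine_rgb, target_rgb, coarse_weight_decay=0.1):
    # Down-weighted coarse MSE + full-weight fine MSE, as in the paper.
    coarse_mse = ((coarse_rgb - target_rgb) ** 2).mean()
    fine_mse = ((fine_rgb - target_rgb) ** 2).mean()
    loss = coarse_weight_decay * coarse_mse + fine_mse
    # PSNR is a log-scaled MSE; assumes pixel values in [0, 1].
    psnr = -10.0 * torch.log10(fine_mse)
    return loss, psnr
```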

## mip-NeRF Summary

Here's a summary of how NeRF and mip-NeRF work that I wrote while originally writing this code.

* [Summary](misc/Summary.md)
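
In one snippet, the key difference from NeRF: mip-NeRF casts cones rather than rays and encodes each conical frustum as a Gaussian with an integrated positional encoding (IPE), which attenuates frequencies the frustum is too wide to resolve. Below is a minimal sketch of the IPE from the paper, with diagonal covariance and the `min_deg`/`max_deg` defaults from `config.py`; the names are illustrative and `ray_utils.py` may implement this differently:

```python
import torch


def integrated_pos_enc(means, variances, min_deg=0, max_deg=16):
    # Expected positional encoding of x ~ N(mean, variance) per coordinate:
    # E[sin(s * x)] = sin(s * mean) * exp(-0.5 * s^2 * variance), so higher
    # frequencies are damped more strongly for wider frustums.
    scales = 2.0 ** torch.arange(min_deg, max_deg)                  # [L]
    mu = (means[..., None, :] * scales[..., None]).flatten(-2)      # [..., L*3]
    var = (variances[..., None, :] * scales[..., None] ** 2).flatten(-2)
    damp = torch.exp(-0.5 * var)
    return torch.cat([damp * torch.sin(mu), damp * torch.cos(mu)], dim=-1)
```

With the defaults this produces a 2 * 16 * 3 = 96-dimensional feature per sample.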

## Results

### LLFF - Trex

<div>
<img src="misc/results/trex/LR.png" alt="pic0" width="49%">
<img src="misc/results/trex/Evaluation_PSNR.png" alt="pic1" width="49%">
</div>
<div>
<img src="misc/results/trex/Train_Loss.png" alt="pic2" width="49%">
<img src="misc/results/trex/Train_PSNR.png" alt="pic3" width="49%">
</div>

<br>
Video:
<br>

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/trex/video.mp4" type="video/mp4">)

[//]: # (</video>)

<br>
Depth:
<br>

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/trex/depth.mp4" type="video/mp4">)

[//]: # (</video>)

<br>
Normals:
<br>

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/trex/normals.mp4" type="video/mp4">)

[//]: # (</video>)

### Blender - Lego

<div>
<img src="misc/results/lego/LR.png" alt="pic0" width="49%">
<img src="misc/results/lego/Evaluation_PSNR.png" alt="pic1" width="49%">
</div>
<div>
<img src="misc/results/lego/Train_Loss.png" alt="pic2" width="49%">
<img src="misc/results/lego/Train_PSNR.png" alt="pic3" width="49%">
</div>
Video:
<br>

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/lego/video.mp4" type="video/mp4">)

[//]: # (</video>)

<br>
Depth:
<br>

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/lego/depth.mp4" type="video/mp4">)

[//]: # (</video>)

<br>
Normals:
<br>

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/lego/normals.mp4" type="video/mp4">)

[//]: # (</video>)

### Multicam - Mic

<div>
<img src="misc/results/mic/LR.png" alt="pic0" width="49%">
<img src="misc/results/mic/Evaluation_PSNR.png" alt="pic1" width="49%">
</div>
<div>
<img src="misc/results/mic/Train_Loss.png" alt="pic2" width="49%">
<img src="misc/results/mic/Train_PSNR.png" alt="pic3" width="49%">
</div>
Video:
<br>

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/mic/video.mp4" type="video/mp4">)

[//]: # (</video>)

<br>
Depth:
<br>

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/mic/depth.mp4" type="video/mp4">)

[//]: # (</video>)

<br>
Normals:
<br>

[//]: # (<video controls>)

[//]: # (    <source src="misc/results/mic/normals.mp4" type="video/mp4">)

[//]: # (</video>)

## References/Contributions

* Thanks to [Nina](https://github.com/ninaahmed) for helping with the code
* [Original NeRF Code in Tensorflow](https://github.com/bmild/nerf)
* [NeRF Project Page](https://www.matthewtancik.com/nerf)
* [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](https://arxiv.org/abs/2003.08934)
* [Original mip-NeRF Code in JAX](https://github.com/google/mipnerf)
* [mip-NeRF Project Page](https://jonbarron.info/mipnerf/)
* [Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields](https://arxiv.org/abs/2103.13415)
* [nerf_pl](https://github.com/kwea123/nerf_pl)

config.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
import argparse
import torch
from os import path


def get_config():
    config = argparse.ArgumentParser()
    # note: action="store_false" flags below default to True; passing the flag turns them off

    # basic hyperparams to specify where to load/save data from/to
    config.add_argument("--log_dir", type=str, default="log")
    config.add_argument("--dataset_name", type=str, default="blender")
    config.add_argument("--scene", type=str, default="lego")
    # model hyperparams
    config.add_argument("--use_viewdirs", action="store_false")
    config.add_argument("--randomized", action="store_false")
    config.add_argument("--ray_shape", type=str, default="cone")  # should be "cylinder" if llff
    config.add_argument("--white_bkgd", action="store_false")  # should be False if using llff
    config.add_argument("--override_defaults", action="store_true")
    config.add_argument("--num_levels", type=int, default=2)
    config.add_argument("--num_samples", type=int, default=128)
    config.add_argument("--hidden", type=int, default=256)
    config.add_argument("--density_noise", type=float, default=0.0)
    config.add_argument("--density_bias", type=float, default=-1.0)
    config.add_argument("--rgb_padding", type=float, default=0.001)
    config.add_argument("--resample_padding", type=float, default=0.01)
    config.add_argument("--min_deg", type=int, default=0)
    config.add_argument("--max_deg", type=int, default=16)
    config.add_argument("--viewdirs_min_deg", type=int, default=0)
    config.add_argument("--viewdirs_max_deg", type=int, default=4)
    # loss and optimizer hyperparams
    config.add_argument("--coarse_weight_decay", type=float, default=0.1)
    config.add_argument("--lr_init", type=float, default=1e-3)
    config.add_argument("--lr_final", type=float, default=5e-5)
    config.add_argument("--lr_delay_steps", type=int, default=2500)
    config.add_argument("--lr_delay_mult", type=float, default=0.1)
    config.add_argument("--weight_decay", type=float, default=1e-5)
    # training hyperparams
    config.add_argument("--factor", type=int, default=2)
    config.add_argument("--max_steps", type=int, default=200_000)
    config.add_argument("--batch_size", type=int, default=2048)
    config.add_argument("--do_eval", action="store_false")
    config.add_argument("--continue_training", action="store_true")
    config.add_argument("--save_every", type=int, default=1000)
    config.add_argument("--device", type=str, default="cuda")
    # visualization hyperparams
    config.add_argument("--chunks", type=int, default=8192)
    config.add_argument("--model_weight_path", default="log/model.pt")
    config.add_argument("--visualize_depth", action="store_true")
    config.add_argument("--visualize_normals", action="store_true")
    # extracting mesh hyperparams
    config.add_argument("--x_range", nargs="+", type=float, default=[-1.2, 1.2])
    config.add_argument("--y_range", nargs="+", type=float, default=[-1.2, 1.2])
    config.add_argument("--z_range", nargs="+", type=float, default=[-1.2, 1.2])
    config.add_argument("--grid_size", type=int, default=256)
    config.add_argument("--sigma_threshold", type=float, default=50.0)
    config.add_argument("--occ_threshold", type=float, default=0.2)

    config = config.parse_args()

    # default configs for llff, automatically set if dataset is llff and not override_defaults
    if config.dataset_name == "llff" and not config.override_defaults:
        config.factor = 4
        config.ray_shape = "cylinder"
        config.white_bkgd = False
        config.density_noise = 1.0

    config.device = torch.device(config.device)
    base_data_path = "data/nerf_llff_data/"
    if config.dataset_name == "blender":
        base_data_path = "data/nerf_synthetic/"
    elif config.dataset_name == "multicam":
        base_data_path = "data/nerf_multiscale/"
    config.base_dir = path.join(base_data_path, config.scene)

    return config
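
For reference, a minimal sketch of the schedule the `lr_*` parameters above presumably drive: log-linear interpolation from `lr_init` to `lr_final` over `max_steps`, scaled by a sine-shaped warmup controlled by `lr_delay_steps` and `lr_delay_mult`, as in the official mip-NeRF implementation (the repo's `scheduler.py` may differ in detail):

```python
import math


def learning_rate(step, lr_init=1e-3, lr_final=5e-5, max_steps=200_000,
                  lr_delay_steps=2500, lr_delay_mult=0.1):
    # Warmup: ramps from lr_delay_mult * lr up to the full lr over the
    # first lr_delay_steps steps.
    if lr_delay_steps > 0:
        delay = lr_delay_mult + (1 - lr_delay_mult) * math.sin(
            0.5 * math.pi * min(max(step / lr_delay_steps, 0.0), 1.0))
    else:
        delay = 1.0
    # Log-linear interpolation between lr_init and lr_final.
    t = min(max(step / max_steps, 0.0), 1.0)
    return delay * math.exp((1 - t) * math.log(lr_init) + t * math.log(lr_final))
```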
