Original Source:
import os
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
@torch.no_grad() # disable gradient calculation for efficiency
def test(hn, hf, dataset, chunk_size=10, img_index=0, nb_bins=192, H=400, W=400):
"""
    Render a single test image with the globally defined `model` and save it to disk.

    Args:
        hn: near plane distance
        hf: far plane distance
        dataset: dataset of rays to render
        chunk_size (int, optional): number of image rows rendered per chunk, for memory efficiency. Defaults to 10.
        img_index (int, optional): index of the image to render. Defaults to 0.
        nb_bins (int, optional): number of samples (bins) along each ray. Defaults to 192.
        H (int, optional): image height. Defaults to 400.
        W (int, optional): image width. Defaults to 400.

    Returns:
        None
    """
    # Select the rays belonging to image img_index (the dataset stores H * W rays per image)
ray_origins = dataset[img_index * H * W: (img_index + 1) * H * W, :3]
ray_directions = dataset[img_index * H * W: (img_index + 1) * H * W, 3:6]
data = [] # list of regenerated pixel values
for i in range(int(np.ceil(H / chunk_size))): # iterate over chunks
# Get chunk of rays
        ray_origins_ = ray_origins[i * W * chunk_size: (i + 1) * W * chunk_size].to(device)
        ray_directions_ = ray_directions[i * W * chunk_size: (i + 1) * W * chunk_size].to(device)
# Regenerate pixel values
regenerated_px_values = render_rays(model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
data.append(regenerated_px_values)
    img = torch.cat(data).data.cpu().numpy().reshape(H, W, 3)  # concatenate the row chunks and reshape to (H, W, 3)
# Plot image
plt.figure()
plt.imshow(img)
    os.makedirs('novel_views', exist_ok=True)  # make sure the output directory exists
    plt.savefig(f'novel_views/img_{img_index}.png', bbox_inches='tight')  # bbox_inches='tight' removes white space around the image
plt.close()
class NerfModel(nn.Module):
def __init__(self, embedding_dim_pos=10, embedding_dim_direction=4, hidden_dim=128):
super(NerfModel, self).__init__() # call __init__() of nn.Module
        # block1: MLP over the positional encoding of the sample position
self.block1 = nn.Sequential(nn.Linear(embedding_dim_pos * 6 + 3, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), )
# density estimation
self.block2 = nn.Sequential(nn.Linear(embedding_dim_pos * 6 + hidden_dim + 3, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim + 1), )
# color estimation
self.block3 = nn.Sequential(nn.Linear(embedding_dim_direction * 6 + hidden_dim + 3, hidden_dim // 2), nn.ReLU(), )
self.block4 = nn.Sequential(nn.Linear(hidden_dim // 2, 3), nn.Sigmoid(), )
# hyperparameters
self.embedding_dim_pos = embedding_dim_pos
self.embedding_dim_direction = embedding_dim_direction
self.relu = nn.ReLU()
@staticmethod
    def positional_encoding(x, L):  # frequency encoding from the NeRF paper: sin/cos at L octaves
out = [x]
for j in range(L):
out.append(torch.sin(2 ** j * x))
out.append(torch.cos(2 ** j * x))
return torch.cat(out, dim=1)
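    # Shape sketch (assuming x of shape [batch_size, 3]): the output keeps the raw
    # coordinates and appends sin/cos pairs at L frequencies, i.e. 3 + 2 * L * 3
    # features. For L = embedding_dim_pos = 10 that is 63, which matches block1's
    # input size embedding_dim_pos * 6 + 3.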
def forward(self, o, d):
        emb_x = self.positional_encoding(o, self.embedding_dim_pos)  # emb_x: [batch_size, embedding_dim_pos * 6 + 3]
        emb_d = self.positional_encoding(d, self.embedding_dim_direction)  # emb_d: [batch_size, embedding_dim_direction * 6 + 3]
h = self.block1(emb_x) # h: [batch_size, hidden_dim]
tmp = self.block2(torch.cat((h, emb_x), dim=1)) # tmp: [batch_size, hidden_dim + 1]
h, sigma = tmp[:, :-1], self.relu(tmp[:, -1]) # h: [batch_size, hidden_dim], sigma: [batch_size]
h = self.block3(torch.cat((h, emb_d), dim=1)) # h: [batch_size, hidden_dim // 2]
c = self.block4(h) # c: [batch_size, 3]
return c, sigma
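# A minimal shape check for the model above, as a sketch (not part of the original script):
#   m = NerfModel()
#   c, sigma = m(torch.rand(8, 3), torch.rand(8, 3))
#   # c: [8, 3], RGB in (0, 1) via the Sigmoid; sigma: [8], non-negative via the ReLU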
def compute_accumulated_transmittance(alphas):
accumulated_transmittance = torch.cumprod(alphas, 1)
return torch.cat((torch.ones((accumulated_transmittance.shape[0], 1), device=alphas.device),
accumulated_transmittance[:, :-1]), dim=-1)
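# Worked example: for alphas = 1 - alpha = [[0.5, 0.5, 0.5]], the cumulative products
# are [0.5, 0.25, 0.125]; dropping the last entry and prepending 1 gives
# T = [1.0, 0.5, 0.25], i.e. T_i = prod_{j < i} (1 - alpha_j), the probability that
# the ray travels through the first i bins without being absorbed.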
def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
# Get the device where the input tensors are stored (CPU or GPU)
device = ray_origins.device
# Generate a linearly spaced tensor from hn to hf with nb_bins elements
t = torch.linspace(hn, hf, nb_bins, device=device).expand(ray_origins.shape[0], nb_bins)
# Perturb sampling along each ray.
mid = (t[:, :-1] + t[:, 1:]) / 2.
lower = torch.cat((t[:, :1], mid), -1)
upper = torch.cat((mid, t[:, -1:]), -1)
u = torch.rand(t.shape, device=device)
t = lower + (upper - lower) * u # [batch_size, nb_bins]
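    # This is the stratified sampling used in the NeRF paper: one uniform draw inside
    # each bin keeps the samples sorted along the ray while ensuring the network is
    # not always queried at the same fixed depths.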
# Compute the difference between consecutive values of t along each ray
delta = torch.cat((t[:, 1:] - t[:, :-1], torch.tensor([1e10], device=device).expand(ray_origins.shape[0], 1)), -1)
# Compute the 3D points along each ray
x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1) # [batch_size, nb_bins, 3]
    # Broadcast ray_directions to [batch_size, nb_bins, 3] so each sample along a ray carries its ray's direction
    ray_directions = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)
# Compute the output color and density sigma
colors, sigma = nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))
colors = colors.reshape(x.shape)
sigma = sigma.reshape(x.shape[:-1])
# Compute the accumulated transmittance
alpha = 1 - torch.exp(-sigma * delta) # [batch_size, nb_bins]
weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)
# Compute the pixel values as a weighted sum of colors along each ray
c = (weights * colors).sum(dim=1) # Pixel values
# Regularization for white background
weight_sum = weights.sum(-1).sum(-1)
# Return the rendered image
return c + 1 - weight_sum.unsqueeze(-1)
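# The weighted sum above implements the volume rendering quadrature of the NeRF paper,
#   C(r) = sum_i T_i * (1 - exp(-sigma_i * delta_i)) * c_i,
# and the trailing `+ 1 - weight_sum` blends unabsorbed rays toward a white background.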
def train(nerf_model, optimizer, scheduler, data_loader, device='cpu', hn=0, hf=1, nb_epochs=int(1e5),
nb_bins=192, H=400, W=400):
training_loss = []
for _ in tqdm(range(nb_epochs)):
for batch in data_loader:
            # Move the batch to the device the model runs on (GPU or CPU)
ray_origins = batch[:, :3].to(device)
ray_directions = batch[:, 3:6].to(device)
ground_truth_px_values = batch[:, 6:].to(device)
# forward pass
regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)
loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()
# backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
training_loss.append(loss.item())
scheduler.step()
        # Render the held-out test views at the end of each epoch
        for img_index in range(200):
test(hn, hf, testing_dataset, img_index=img_index, nb_bins=nb_bins, H=H, W=W)
return training_loss
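# Optional metric, as a sketch (not computed in the original script): since `loss` is a
# summed squared error over pixels in [0, 1], a per-batch PSNR could be derived inside
# the loop as
#   mse = loss.item() / ground_truth_px_values.numel()
#   psnr = -10 * np.log10(mse)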
if __name__ == '__main__':
    device = 'cuda'  # this script assumes a CUDA-capable GPU is available
# Load the training and testing datasets
training_dataset = torch.from_numpy(np.load('training_data.pkl', allow_pickle=True))
testing_dataset = torch.from_numpy(np.load('testing_data.pkl', allow_pickle=True))
# Create the NeRF model
model = NerfModel(hidden_dim=256).to(device)
# Create the optimizer and the learning rate scheduler
model_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)
# Create the data loader
data_loader = DataLoader(training_dataset, batch_size=1024, shuffle=True)
train(model, model_optimizer, scheduler, data_loader, nb_epochs=16, device=device, hn=2, hf=6, nb_bins=192, H=400,
W=400)
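    # Expected data layout (inferred from the slicing above): each dataset row stores
    # 9 floats per ray, [origin(3), direction(3), RGB(3)], so the loaded tensors have
    # shape [N, 9] with N = nb_images * H * W.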