Over the last few months, I have been building an in-browser differentiable photo editor. The goal is to make each image processing function easy to parameterise with a machine learning model, or to replace it with a learned transformation, so that each one can be used in a loss function for downstream tasks. Today I want to share an approach to automatic white balance correction that should also work for other global image transformations.
I won't make any hard claims about state-of-the-art results, as I am primarily looking for a fast, parameter-efficient approach (about 4.5k parameters!) that performs better than grey-world and YUV-coordinate-based transforms. That said, in my preliminary tests it works pretty well compared to U-Net-based approaches. Importantly, it can be trained in a couple of hours on a GTX 1070 without any hyperparameter tuning, and it runs well in the browser on a lightweight device, which is my primary use case.
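For reference, the grey-world baseline assumes that the average colour of a scene is neutral grey and rescales each channel so that its mean matches the overall mean. Here is a minimal sketch in PyTorch, assuming an (N, 3, H, W) tensor with values in [0, 1]; `grey_world` is a hypothetical helper name, not part of the editor's code:

```python
import torch


def grey_world(image: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Grey-world white balance for an (N, 3, H, W) batch with values in [0, 1]."""
    # Per-image, per-channel means: shape (N, 3, 1, 1)
    channel_means = image.mean(dim=(2, 3), keepdim=True)
    # Target grey level: the mean over the three channel means, shape (N, 1, 1, 1)
    grey = channel_means.mean(dim=1, keepdim=True)
    # Scale each channel so that its mean matches the grey level
    gains = grey / (channel_means + eps)
    return torch.clamp(image * gains, 0, 1)


# A flat image with a blue cast is pulled back towards neutral grey
cast = torch.tensor([0.2, 0.4, 0.6]).view(1, 3, 1, 1).expand(1, 3, 4, 4)
balanced = grey_world(cast)
```

Because it applies one fixed per-channel gain derived from image statistics alone, grey world fails whenever the scene itself is not grey on average. A learned transform is meant to do better in exactly that case.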
So, let’s jump into the code.
```python
import torch
import torch.nn as nn


def weights_init(module):
    # Not shown in the original post: a Kaiming initialisation for the convolutions is assumed here.
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, a=0.2, nonlinearity="leaky_relu")
        nn.init.zeros_(module.bias)


class WhiteBalance(nn.Module):
    def __init__(self, hidden_dim=10):
        super().__init__()
        # Small CNN that summarises a 32x32 input into a 3 x 8 x 8 feature map
        self.parameter_network = nn.Sequential(
            nn.Conv2d(3, 7, kernel_size=(3, 3), padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(7, 14, kernel_size=(3, 3), padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(14, 3, kernel_size=(3, 3), padding=1),
        )
        self.feature_dim = 192  # 3 channels x 8 x 8 after two 2x poolings of a 32x32 input
        self.output_dim = 96  # Enough parameters for a 16-dim inner neural network (2 x 16 x 3)
        self.weight_size = self.output_dim // 2
        self.parameter_network.apply(weights_init)
        # Maps the flattened features to the weights of the inner network
        self.linear = nn.Sequential(
            nn.Linear(self.feature_dim, hidden_dim),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(hidden_dim, self.output_dim),
        )

    def inner_network(self, image, params):
        batch, c, h, w = image.size()
        # Split the parameters for each layer of the inner network
        w_1, w_2 = params[:, :self.weight_size], params[:, self.weight_size:]
        hidden_dim_inner = self.weight_size // c
        # Reshape into (hidden_dim x channels) and (channels x hidden_dim) matrices
        w_1 = w_1.view(-1, hidden_dim_inner, c)
        w_2 = w_2.view(-1, c, hidden_dim_inner)
        # Apply a batch of neural networks to the image pixel-wise
        pixels = image.permute(0, 2, 3, 1).contiguous().view(batch, -1, c)
        output = torch.selu(pixels.bmm(w_1.permute(0, 2, 1)))
        output = output.bmm(w_2.permute(0, 2, 1)).view(-1, h, w, c).permute(0, 3, 1, 2)
        return output

    def forward(self, input_image, full_resolution_image):
        batch = input_image.size(0)
        # Predict the inner network's parameters from the downsampled input image or patch
        params = self.linear(self.parameter_network(input_image).view(batch, self.feature_dim))
        # Apply the pixel-wise neural network to the full-resolution image
        output = self.inner_network(full_resolution_image, params)
        # Clamp to [0, 1]
        return torch.clamp(output, 0, 1)
```
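As a sanity check on the parameter budget, the sketch below rebuilds the same layer stack as the class above (the parameter network plus the linear head with the default `hidden_dim=10`) and counts the trainable parameters. The inner network's weights are predicted rather than learned directly, so they do not add to the count.

```python
import torch
import torch.nn as nn

# Same layers as WhiteBalance.parameter_network
parameter_network = nn.Sequential(
    nn.Conv2d(3, 7, kernel_size=3, padding=1),
    nn.MaxPool2d(kernel_size=2),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Conv2d(7, 14, kernel_size=3, padding=1),
    nn.MaxPool2d(kernel_size=2),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Conv2d(14, 3, kernel_size=3, padding=1),
)
# Same layers as WhiteBalance.linear with hidden_dim=10
linear = nn.Sequential(
    nn.Linear(192, 10),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(10, 96),
)


def count(module: nn.Module) -> int:
    """Number of trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


conv_params = count(parameter_network)  # 196 + 896 + 381 = 1473
linear_params = count(linear)  # 1930 + 1056 = 2986
total = conv_params + linear_params  # 4459

# A 32x32 input shrinks to 3 x 8 x 8 = 192 features, matching feature_dim
features = parameter_network(torch.rand(2, 3, 32, 32))
```

That comes to 4,459 parameters, about two thirds of them in the linear head.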
So, a little about the architecture: the model takes a 32x32 sRGB image as input and outputs the parameters of a small two-layer neural network (without biases) with a hidden dimension of 16 and an input and output size of 3, one value per RGB channel. In other words, it learns to output a pixel-wise transformation that can be applied to an image of any resolution.
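To make the "any resolution" point concrete, here is the inner network written as a standalone function. Randomly drawn weights stand in for the 96 predicted parameters, which is an assumption for illustration only. Because the same 3 → 16 → 3 map is applied to every pixel independently, the function accepts any image size, and each output pixel depends only on that pixel's colour. `apply_inner_network` is a hypothetical name; it mirrors `inner_network` above, using `torch.einsum` instead of `bmm`.

```python
import torch


def apply_inner_network(image: torch.Tensor, params: torch.Tensor, hidden: int = 16) -> torch.Tensor:
    """Apply a per-image 3 -> hidden -> 3 network to every pixel of an (N, 3, H, W) batch."""
    n, c, _, _ = image.shape
    half = hidden * c
    w_1 = params[:, :half].view(n, hidden, c)  # first layer: (hidden x channels)
    w_2 = params[:, half:].view(n, c, hidden)  # second layer: (channels x hidden)
    # Contract over the channel axis independently at every (h, w) location
    hidden_act = torch.selu(torch.einsum("nkc,nchw->nkhw", w_1, image))
    return torch.einsum("nck,nkhw->nchw", w_2, hidden_act)


torch.manual_seed(0)
params = torch.randn(1, 96)  # stand-in for the parameter network's output
small = apply_inner_network(torch.rand(1, 3, 32, 32), params)
large = apply_inner_network(torch.rand(1, 3, 480, 640), params)

# A single pixel processed on its own matches the same pixel inside the full image
image = torch.rand(1, 3, 480, 640)
full = apply_inner_network(image, params)
single = apply_inner_network(image[:, :, 10:11, 20:21], params)
```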
There are a lot of benefits to this:
- We can easily pass the parameters to a WebGL shader.
- We can build a 3D lookup table (LUT) by passing every color on a regular grid through the network, which lets us apply the same white balancing to a video in real time.
- It is quite flexible in terms of the kinds of nonlinear pixel-wise transforms it can learn.
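The LUT idea from the list above can be sketched as follows: evaluate the pixel-wise network once on a regular grid of RGB colours to get a size × size × size × 3 table, then apply it to images (or video frames) with trilinear interpolation via `torch.nn.functional.grid_sample`. The helper names `pixel_net`, `bake_lut` and `apply_lut`, the default grid size of 33, and the random stand-in weights are all assumptions for illustration; in the browser you would upload the table as a 3D texture instead.

```python
import torch
import torch.nn.functional as F


def pixel_net(rgb: torch.Tensor, w_1: torch.Tensor, w_2: torch.Tensor) -> torch.Tensor:
    """The inner 3 -> 16 -> 3 network applied to an (..., 3) tensor of colours."""
    return torch.selu(rgb @ w_1.T) @ w_2.T


def bake_lut(w_1: torch.Tensor, w_2: torch.Tensor, size: int = 33) -> torch.Tensor:
    """Evaluate the network on a size^3 colour grid; returns a LUT indexed [r, g, b, channel]."""
    levels = torch.linspace(0, 1, size)
    r, g, b = torch.meshgrid(levels, levels, levels, indexing="ij")
    grid = torch.stack([r, g, b], dim=-1)  # (size, size, size, 3)
    return pixel_net(grid, w_1, w_2).clamp(0, 1)


def apply_lut(image: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
    """Trilinearly interpolate an (N, 3, H, W) image in [0, 1] through the LUT."""
    n = image.shape[0]
    # grid_sample expects (N, C, D, H, W); here D = red, H = green, W = blue
    volume = lut.permute(3, 0, 1, 2).unsqueeze(0).repeat(n, 1, 1, 1, 1)
    # Sampling coordinates are ordered (x, y, z) = (blue, green, red) and scaled to [-1, 1]
    coords = image.permute(0, 2, 3, 1).flip(-1) * 2 - 1  # (N, H, W, 3)
    # For 5-D inputs, mode="bilinear" performs trilinear interpolation
    out = F.grid_sample(volume, coords.unsqueeze(1), mode="bilinear", align_corners=True)
    return out.squeeze(2)  # (N, 3, H, W)


torch.manual_seed(0)
w_1 = torch.randn(16, 3) * 0.5  # stand-in first-layer weights
w_2 = torch.randn(3, 16) * 0.5  # stand-in second-layer weights
lut = bake_lut(w_1, w_2, size=5)
# Colours that sit exactly on the grid nodes are reproduced by the LUT
image = torch.randint(0, 5, (2, 3, 8, 8)).float() / 4
direct = pixel_net(image.permute(0, 2, 3, 1), w_1, w_2).clamp(0, 1).permute(0, 3, 1, 2)
via_lut = apply_lut(image, lut)
```

Between grid nodes, the LUT interpolates the network's output rather than evaluating it exactly, so a larger grid trades memory for accuracy.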
As far as training goes, the dataset consists of pairs of photos before and after white balancing, trained with an MAE (L1) loss or similar. I noticed performance gains from computing the loss at a higher resolution than the input to the parameter network. For example, the parameter network might see a 32x32 image while the loss is computed on a 224x224 version. It seems to have a regularising effect.
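Here is a sketch of how the 32x32 input and the 224x224 image/target pair might be built from one before/after photo pair, assuming both photos are (N, 3, H, W) tensors in [0, 1]. The post does not say how the images were resized, so area downsampling for the small input and bilinear resizing for the larger pair are assumptions, and `make_training_batch` is a hypothetical helper:

```python
import torch
import torch.nn.functional as F


def make_training_batch(before: torch.Tensor, after: torch.Tensor, low: int = 32, high: int = 224):
    """Return (input_image, larger_image, larger_target) for one before/after pair."""
    # Small thumbnail that the parameter network sees
    input_image = F.interpolate(before, size=(low, low), mode="area")
    # Larger versions on which the predicted pixel-wise network is applied and the loss computed
    larger_image = F.interpolate(before, size=(high, high), mode="bilinear", align_corners=False)
    larger_target = F.interpolate(after, size=(high, high), mode="bilinear", align_corners=False)
    return input_image, larger_image, larger_target


before = torch.rand(4, 3, 512, 768)
after = torch.rand(4, 3, 512, 768)
input_image, larger_image, larger_target = make_training_batch(before, after)
```

Resizing straight to a square squashes non-square photos; random crops are a reasonable alternative if that distortion matters.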
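Something like this:

```python
import torch.nn.functional as F

input_image, larger_image, larger_target = batch  # 32x32 input, 224x224 image and target
output = model(input_image, larger_image)
loss = F.l1_loss(output, larger_target)  # MAE against the larger target
```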
Here are some samples (ignore the duplicates; I was testing slight input variations).

[Figure: before and after samples]
There are ways to speed this up and reduce the parameter count further if you need something for a low-energy device. For example, you could apply channel-wise average pooling to the output of the convolutional parameter network (perhaps with a learned weighting). That would let you remove the linear layers as well as the inner neural network.
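One possible reading of that cheaper variant, sketched under assumptions: keep the convolutional parameter network, replace the flatten-and-linear head with a learned weighted average over the 8x8 output positions (one softmax-normalised weight per position), and treat the three pooled values as log per-channel gains, i.e. a diagonal white-balance correction. The class name, the softmax weighting and the `exp` gain parameterisation are my assumptions, not the post's design.

```python
import torch
import torch.nn as nn


class ChannelGainWhiteBalance(nn.Module):
    """Hypothetical low-cost variant: per-channel gains from pooled CNN features."""

    def __init__(self):
        super().__init__()
        # Same convolutional stack as WhiteBalance.parameter_network (3 x 8 x 8 output for a 32x32 input)
        self.parameter_network = nn.Sequential(
            nn.Conv2d(3, 7, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(7, 14, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(14, 3, kernel_size=3, padding=1),
        )
        # Learned weighting over the 8x8 positions; zeros give a plain average at initialisation
        self.pool_logits = nn.Parameter(torch.zeros(8 * 8))

    def forward(self, input_image, full_resolution_image):
        features = self.parameter_network(input_image).flatten(2)  # (N, 3, 64)
        weights = torch.softmax(self.pool_logits, dim=0)  # (64,), sums to 1
        log_gains = (features * weights).sum(dim=-1)  # weighted average per channel: (N, 3)
        gains = torch.exp(log_gains)[:, :, None, None]  # positive per-channel multipliers
        return torch.clamp(full_resolution_image * gains, 0, 1)


model = ChannelGainWhiteBalance()
out = model(torch.rand(2, 3, 32, 32), torch.rand(2, 3, 300, 400))
n_params = sum(p.numel() for p in model.parameters())  # 1473 conv + 64 pooling weights
```

The trade-off is that per-channel gains can only express a diagonal colour correction, which is exactly the flexibility a global nonlinear pixel-wise transform keeps.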
But I prefer the flexibility of a global nonlinear pixel-wise transform for my use case.
Thanks for reading! If you're interested in this sort of stuff, I'll be posting quite often on my Twitter.