GANILLA model
We use the generator that was introduced in the GANILLA paper.
Generator
BasicBlock_Ganilla
BasicBlock_Ganilla (in_planes, planes, use_dropout, stride=1)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:`to`, etc.
.. note:: As per the example above, an ``__init__()`` call to the parent class must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode.
:vartype training: bool
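The text above is PyTorch's generic `nn.Module` docstring, so it does not say what `BasicBlock_Ganilla` computes. As we read the GANILLA paper, its residual blocks concatenate the block input with the block output (instead of summing them) and fuse the result with one more convolution. The sketch below illustrates that idea only; the class name `ConcatSkipBlock`, the reflection padding, the instance normalization and the dropout placement are our assumptions, not a copy of upit's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConcatSkipBlock(nn.Module):
    """Hypothetical residual block whose skip connection concatenates instead of adds."""

    def __init__(self, in_planes: int, planes: int, use_dropout: float = 0.0, stride: int = 1):
        super().__init__()
        # Main branch: two 3x3 convolutions; only the first one may change resolution.
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, bias=False),
            nn.InstanceNorm2d(planes),
        )
        # `use_dropout` is treated as a dropout probability here (0 disables it).
        self.dropout = nn.Dropout(use_dropout) if use_dropout else nn.Identity()
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(planes, planes, kernel_size=3, bias=False),
            nn.InstanceNorm2d(planes),
        )
        # Shortcut: identity unless the shape changes, then a 1x1 projection.
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.InstanceNorm2d(planes),
            )
        else:
            self.shortcut = nn.Identity()
        # Fuse the concatenated [branch, shortcut] features back to `planes` channels.
        self.fuse = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(2 * planes, planes, kernel_size=3, bias=False),
            nn.InstanceNorm2d(planes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.dropout(F.relu(self.conv1(x)))
        out = self.conv2(out)
        out = torch.cat([out, self.shortcut(x)], dim=1)  # concatenate, not sum
        return F.relu(self.fuse(out))
```

With `stride=2` the block halves the spatial size and the shortcut becomes a strided 1x1 projection so that both halves of the concatenation line up.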
PyramidFeatures
PyramidFeatures (C2_size, C3_size, C4_size, C5_size, fpn_weights, feature_size=128)
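`PyramidFeatures` takes the encoder feature maps C2 to C5 together with `fpn_weights`, which suggests a weighted feature-pyramid merge. The sketch below shows one way such a top-down merge can work: 1x1 lateral convolutions bring every level to `feature_size` channels, each level is scaled by its weight, and coarser results are upsampled (nearest neighbour) and added to finer ones. The class name `WeightedTopDown` and the exact place where the weights are applied are assumptions; upit's module may contain further layers not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedTopDown(nn.Module):
    """Hypothetical weighted top-down feature merge in the FPN style."""

    def __init__(self, in_sizes, fpn_weights, feature_size: int = 128):
        super().__init__()
        assert len(in_sizes) == len(fpn_weights), "one weight per pyramid level"
        # One 1x1 lateral convolution per level, finest first (C2, C3, C4, C5).
        self.laterals = nn.ModuleList([nn.Conv2d(c, feature_size, kernel_size=1) for c in in_sizes])
        self.weights = list(fpn_weights)

    def forward(self, feats):
        merged = None
        # Walk from the coarsest level to the finest one.
        for i in reversed(range(len(feats))):
            p = self.laterals[i](feats[i]) * self.weights[i]  # project and weight this level
            if merged is not None:
                # Upsample the coarser result to this level's resolution, then add.
                p = p + F.interpolate(merged, size=p.shape[-2:], mode="nearest")
            merged = p
        return merged
```

Setting a weight to 0 removes that level's contribution entirely, which is a quick way to inspect what each level adds.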
ResNet
ResNet (input_nc, output_nc, ngf, use_dropout, fpn_weights, block, layers)
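`ResNet` receives both `input_nc` and `output_nc` as well as `fpn_weights`, so it appears to be the full generator: a downsampling encoder followed by feature-pyramid upsampling. The shape test on this page expects a 256x256 output for a 256x256 input, so every downsampling step has to be undone. The helpers below make that bookkeeping explicit with the output-size formula documented for `nn.Conv2d`; the four-stage stride schedule in the example is hypothetical and not necessarily the one upit uses.

```python
def conv_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of nn.Conv2d along one dimension (PyTorch's documented formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def trace_sizes(size: int, strides, kernel: int = 3, padding: int = 1):
    """Follow one spatial dimension through a stack of convolutions with the given strides."""
    sizes = []
    for s in strides:
        size = conv_out_size(size, kernel, stride=s, padding=padding)
        sizes.append(size)
    return sizes


# Hypothetical schedule: four stride-2 stages of 3x3 convolutions with padding 1.
print(trace_sizes(256, [2, 2, 2, 2]))  # [128, 64, 32, 16]
```

Under this schedule the decoder side would need a total upsampling factor of 2^4 = 16 to get back to 256.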
init_weights
init_weights (net, init_type='normal', gain=0.02)
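Only the signature of `init_weights` is listed. Its arguments match the initialization helper commonly used in the pix2pix/CycleGAN code base, where `init_type` selects the scheme and `gain` is the standard deviation of the normal initialization. The sketch below shows that scheme under the assumption that upit follows it; `init_weights_sketch` is a hypothetical name and the set of supported `init_type` values is ours.

```python
import torch
import torch.nn as nn
from torch.nn import init


def init_weights_sketch(net: nn.Module, init_type: str = "normal", gain: float = 0.02) -> nn.Module:
    """Hypothetical pix2pix/CycleGAN-style weight initialization."""

    def init_func(m: nn.Module) -> None:
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            if init_type == "normal":
                init.normal_(m.weight, 0.0, gain)  # N(0, gain^2)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight, gain=gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight, gain=gain)
            else:
                raise NotImplementedError(f"init_type {init_type!r} is not supported")
            if m.bias is not None:
                init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.BatchNorm2d):
            # Normalization scales start around 1, shifts at 0.
            init.normal_(m.weight, 1.0, gain)
            init.constant_(m.bias, 0.0)

    net.apply(init_func)  # visits every submodule recursively
    return net
```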
ganilla_generator
ganilla_generator (input_nc, output_nc, ngf, drop, fpn_weights=[1.0, 1.0, 1.0, 1.0], init_type='normal', gain=0.02, **kwargs)
Constructs a ResNet-18 GANILLA generator.
Let’s test a few things:
1. The generator can be initialized correctly.
2. A random image can be passed through the model successfully and produces an output of the correct size.
First, let’s create a random batch and pass it through a newly constructed generator:
img1 = torch.randn(4,3,256,256)
m = ganilla_generator(3,3,64,0.5)
with torch.no_grad():
    out1 = m(img1)
out1.shape
torch.Size([4, 3, 256, 256])
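The two checks above can be wrapped in a small reusable helper. `check_image_to_image` is a hypothetical name; the stand-in model in the test is a single padded convolution, chosen only because it preserves the input shape, not because it resembles the GANILLA generator.

```python
import torch
import torch.nn as nn


def check_image_to_image(model: nn.Module, batch: torch.Tensor) -> torch.Size:
    """Run `model` on `batch` without tracking gradients and check the output shape."""
    with torch.no_grad():
        out = model(batch)
    # An image-to-image generator should return a batch shaped exactly like its input.
    if out.shape != batch.shape:
        raise ValueError(f"expected {tuple(batch.shape)}, got {tuple(out.shape)}")
    return out.shape
```

With the objects defined on this page, `check_image_to_image(m, img1)` performs the same check as the cell above.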
Full model
We group two discriminators and two generators in a single model; a Callback (defined in 02_cyclegan_training.ipynb) then takes care of training them. The discriminator and the training loop are the same as in CycleGAN.
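Since the discriminator is the same as in CycleGAN, here is a sketch of the usual CycleGAN choice: a PatchGAN that scores overlapping image patches instead of the whole image. We assume `disc_layers` is the number of stride-2 blocks, that the normalization defaults to `nn.InstanceNorm2d`, and that `lsgan=False` appends a sigmoid; the function name `patch_discriminator` is hypothetical and upit's exact layer list may differ.

```python
import torch
import torch.nn as nn


def patch_discriminator(ch_in: int = 3, n_features: int = 64, n_layers: int = 3,
                        norm_layer=nn.InstanceNorm2d, lsgan: bool = True) -> nn.Sequential:
    """Hypothetical PatchGAN discriminator in the pix2pix/CycleGAN style."""
    layers = [nn.Conv2d(ch_in, n_features, kernel_size=4, stride=2, padding=1),
              nn.LeakyReLU(0.2, True)]
    mult = 1
    # n_layers - 1 further downsampling blocks, doubling channels up to 8x.
    for n in range(1, n_layers):
        prev, mult = mult, min(2 ** n, 8)
        layers += [
            nn.Conv2d(n_features * prev, n_features * mult, kernel_size=4, stride=2, padding=1),
            norm_layer(n_features * mult),
            nn.LeakyReLU(0.2, True),
        ]
    # One stride-1 block, then a 1-channel map of per-patch scores.
    prev, mult = mult, min(2 ** n_layers, 8)
    layers += [
        nn.Conv2d(n_features * prev, n_features * mult, kernel_size=4, stride=1, padding=1),
        norm_layer(n_features * mult),
        nn.LeakyReLU(0.2, True),
        nn.Conv2d(n_features * mult, 1, kernel_size=4, stride=1, padding=1),
    ]
    if not lsgan:
        layers.append(nn.Sigmoid())  # probabilities for a BCE-style objective
    return nn.Sequential(*layers)
```

With `n_layers=3`, a 64x64 input yields a 6x6 map of patch scores, and a 256x256 input yields a 30x30 map.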
GANILLA
GANILLA (ch_in:int=3, ch_out:int=3, n_features:int=64, disc_layers:int=3, lsgan:bool=True, drop:float=0.0, norm_layer:torch.nn.modules.module.Module=None, fpn_weights:list=[1.0, 1.0, 1.0, 1.0], init_type:str='normal', gain:float=0.02, **kwargs)
GANILLA model.
When called, it takes an input batch of real images from both domains and outputs fake images for the opposite domains (using the generators). It also outputs identity images, obtained by passing each domain's images through the generator that outputs that same domain (needed for the identity loss).
Attributes:
    G_A (nn.Module): takes real input B and generates fake input A
    G_B (nn.Module): takes real input A and generates fake input B
    D_A (nn.Module): trained to distinguish real input A from fake input A
    D_B (nn.Module): trained to distinguish real input B from fake input B
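The attribute list pins down which generator maps which domain. The sketch below wires four sub-modules together accordingly and returns the fake and identity images from `forward`, leaving the discriminators to the training Callback. The class name `TwoDomainModel`, the output order `[fake_A, fake_B, idt_A, idt_B]` and the 1x1-convolution stand-ins in the test are assumptions for illustration; the quick model test on this page only checks that four outputs come back, so upit's actual ordering may differ.

```python
import torch
import torch.nn as nn


class TwoDomainModel(nn.Module):
    """Hypothetical container mirroring the G_A/G_B/D_A/D_B attributes described above."""

    def __init__(self, G_A: nn.Module, G_B: nn.Module, D_A: nn.Module, D_B: nn.Module):
        super().__init__()
        self.G_A, self.G_B = G_A, G_B  # G_A: B -> A, G_B: A -> B
        self.D_A, self.D_B = D_A, D_B  # D_A judges domain A, D_B judges domain B

    def forward(self, inputs):
        real_A, real_B = inputs
        fake_A = self.G_A(real_B)  # translate B into A
        fake_B = self.G_B(real_A)  # translate A into B
        idt_A = self.G_A(real_A)   # A through the generator that outputs A (identity loss)
        idt_B = self.G_B(real_B)   # B through the generator that outputs B (identity loss)
        return [fake_A, fake_B, idt_A, idt_B]
```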
GANILLA.__init__
GANILLA.__init__ (ch_in:int=3, ch_out:int=3, n_features:int=64, disc_layers:int=3, lsgan:bool=True, drop:float=0.0, norm_layer:torch.nn.modules.module.Module=None, fpn_weights:list=[1.0, 1.0, 1.0, 1.0], init_type:str='normal', gain:float=0.02, **kwargs)
Constructor for GANILLA model.
Arguments:
    ch_in (int): Number of input channels (default=3)
    ch_out (int): Number of output channels (default=3)
    n_features (int): Number of input features (default=64)
    disc_layers (int): Number of discriminator layers (default=3)
    lsgan (bool): Whether to use the LSGAN training objective, in which the discriminator outputs unnormalized floats (default=True)
    drop (float): Dropout level (default=0.0)
    norm_layer (nn.Module): Type of normalization layer to use in the discriminator (default=None)
    fpn_weights (list): Weights for the feature pyramid network (default=[1.0, 1.0, 1.0, 1.0])
    init_type (str): Type of initialization (default='normal')
    gain (float): Gain for initialization (default=0.02)
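The `lsgan` flag switches between two adversarial objectives. With LSGAN the discriminator outputs unnormalized scores and the loss is the mean squared error to the target label; otherwise the discriminator output is read as a probability and scored with binary cross-entropy. `adversarial_loss` is a hypothetical helper sketching that switch, assuming targets of 1 for real and 0 for fake; it is not upit's loss function.

```python
import torch
import torch.nn.functional as F


def adversarial_loss(pred: torch.Tensor, target_is_real: bool, lsgan: bool = True) -> torch.Tensor:
    """Hypothetical GAN loss switch.

    lsgan=True : `pred` holds unnormalized scores; least-squares (MSE) loss.
    lsgan=False: `pred` holds probabilities (discriminator ends in a sigmoid); BCE loss.
    """
    target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
    if lsgan:
        return F.mse_loss(pred, target)
    return F.binary_cross_entropy(pred, target)
```

For a score of 0.5 everywhere, the LSGAN loss is 0.25 for either target, while the BCE loss is ln 2 (about 0.693).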
GANILLA.forward
GANILLA.forward (input)
Forward function for the GANILLA model. The input is a tuple of two batches of real images, one from domain A and one from domain B.
Quick model tests
Again, let’s check that the model can be called successfully and outputs the correct shapes.
ganilla_model = GANILLA()
img1 = torch.randn(4,3,256,256)
img2 = torch.randn(4,3,256,256)
with torch.no_grad():
    ganilla_output = ganilla_model((img1,img2))
CPU times: user 42.1 s, sys: 8.36 s, total: 50.5 s
Wall time: 1.38 s
test_eq(len(ganilla_output),4)
for output_batch in ganilla_output:
    test_eq(output_batch.shape,img1.shape)
ganilla_model.push_to_hub('upit-ganilla-test')
Cloning https://huggingface.co/tmabraham/upit-ganilla-test into local empty directory.
To https://huggingface.co/tmabraham/upit-ganilla-test
   264c8d0..38cafb3  main -> main
'https://huggingface.co/tmabraham/upit-ganilla-test/commit/38cafb3d4cca069313b8ed03d88ecc88db28f3d5'
ganilla_model.from_pretrained('tmabraham/upit-ganilla-test')
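After a round trip through the Hub, a quick way to confirm that the downloaded model carries the same weights as the one that was pushed is to compare their state dicts. `same_weights` is a hypothetical helper; so that the example runs offline, it round-trips a small `nn.Linear` through an in-memory buffer with `torch.save`/`torch.load` instead of calling the Hub.

```python
import io

import torch
import torch.nn as nn


def same_weights(a: nn.Module, b: nn.Module) -> bool:
    """True if both modules have the same state-dict keys and identical tensor values."""
    sa, sb = a.state_dict(), b.state_dict()
    return sa.keys() == sb.keys() and all(torch.equal(sa[k], sb[k]) for k in sa)


# Offline stand-in for push_to_hub/from_pretrained: save and reload via a buffer.
torch.manual_seed(0)
original = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)
buffer.seek(0)
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer))
```

With the model on this page, the same check would compare `ganilla_model` against the model returned by `from_pretrained`, assuming that call returns the loaded model.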