CycleGAN model
We use the models that were introduced in the CycleGAN paper.
Generator
convT_norm_relu
convT_norm_relu (ch_in:int, ch_out:int, norm_layer:torch.nn.modules.module.Module, ks:int=3, stride:int=2, bias:bool=True)
pad_conv_norm_relu
pad_conv_norm_relu (ch_in:int, ch_out:int, pad_mode:str, norm_layer:torch.nn.modules.module.Module, ks:int=3, bias:bool=True, pad=1, stride:int=1, activ:bool=True, init=<function kaiming_normal_>, init_gain:int=0.02)
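To make the roles of these two helpers concrete, here is a minimal sketch of the kind of layer stacks they build. The exact layer choices (padding mode, output_padding, inplace ReLU) are assumptions for illustration, not the library's verbatim code.

```python
import torch
import torch.nn as nn

# Sketch of a pad -> conv -> norm -> (ReLU) stack, in the spirit of pad_conv_norm_relu.
def pad_conv_norm_relu_sketch(ch_in, ch_out, pad_mode='reflection',
                              norm_layer=nn.InstanceNorm2d, ks=3, pad=1,
                              stride=1, activ=True, bias=True):
    pad_layer = nn.ReflectionPad2d(pad) if pad_mode == 'reflection' else nn.ZeroPad2d(pad)
    layers = [pad_layer,
              nn.Conv2d(ch_in, ch_out, ks, stride=stride, bias=bias),
              norm_layer(ch_out)]
    if activ: layers.append(nn.ReLU(inplace=True))
    return layers

# Sketch of a transposed-conv upsampling stack, in the spirit of convT_norm_relu.
def convT_norm_relu_sketch(ch_in, ch_out, norm_layer=nn.InstanceNorm2d,
                           ks=3, stride=2, bias=True):
    return [nn.ConvTranspose2d(ch_in, ch_out, ks, stride=stride,
                               padding=1, output_padding=1, bias=bias),
            norm_layer(ch_out),
            nn.ReLU(inplace=True)]

# Example: a 2x downsampling block followed by a 2x upsampling block.
down = nn.Sequential(*pad_conv_norm_relu_sketch(64, 128, stride=2))
up   = nn.Sequential(*convT_norm_relu_sketch(128, 64))
x = torch.randn(1, 64, 128, 128)
print(up(down(x)).shape)  # torch.Size([1, 64, 128, 128])
```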
ResnetBlock
ResnetBlock (dim:int, pad_mode:str='reflection', norm_layer:torch.nn.modules.module.Module=None, dropout:float=0.0, bias:bool=True)
nn.Module for the ResNet Block
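A minimal sketch of this kind of block (reflection padding and instance norm are assumptions here, not necessarily the library's exact choices): two padded 3x3 conv blocks with a skip connection, so the block learns a residual on top of its input.

```python
import torch
import torch.nn as nn

class ResnetBlockSketch(nn.Module):
    "Two padded 3x3 conv blocks with a skip connection: the block learns a residual."
    def __init__(self, dim, norm_layer=nn.InstanceNorm2d, dropout=0.0, bias=True):
        super().__init__()
        layers = [nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, 3, bias=bias),
                  norm_layer(dim), nn.ReLU(inplace=True)]
        if dropout > 0: layers.append(nn.Dropout(dropout))
        layers += [nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, 3, bias=bias),
                   norm_layer(dim)]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)  # skip connection; resolution and channels unchanged

# Spatial size and channel count are preserved, so blocks can be stacked freely.
x = torch.randn(1, 256, 64, 64)
print(ResnetBlockSketch(256)(x).shape)  # torch.Size([1, 256, 64, 64])
```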
resnet_generator
resnet_generator (ch_in:int, ch_out:int, n_ftrs:int=64, norm_layer:torch.nn.modules.module.Module=None, dropout:float=0.0, n_blocks:int=9, pad_mode:str='reflection')
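The generator stacks these pieces in the usual CycleGAN layout: an initial 7x7 convolution, two stride-2 downsampling blocks, n_blocks residual blocks at the bottleneck resolution, two transposed-conv upsampling blocks, and a final 7x7 convolution with a Tanh. The sketch below reuses the illustrative helpers from the previous code blocks; the exact layer choices are assumptions, not the library's verbatim code.

```python
def resnet_generator_sketch(ch_in=3, ch_out=3, n_ftrs=64, n_blocks=9):
    layers  = [nn.ReflectionPad2d(3), nn.Conv2d(ch_in, n_ftrs, 7),
               nn.InstanceNorm2d(n_ftrs), nn.ReLU(inplace=True)]
    # two downsampling stages: 64 -> 128 -> 256 channels, 256x256 -> 64x64
    layers += pad_conv_norm_relu_sketch(n_ftrs,     n_ftrs * 2, stride=2)
    layers += pad_conv_norm_relu_sketch(n_ftrs * 2, n_ftrs * 4, stride=2)
    # residual blocks at the bottleneck resolution
    layers += [ResnetBlockSketch(n_ftrs * 4) for _ in range(n_blocks)]
    # two upsampling stages back to the input resolution
    layers += convT_norm_relu_sketch(n_ftrs * 4, n_ftrs * 2)
    layers += convT_norm_relu_sketch(n_ftrs * 2, n_ftrs)
    layers += [nn.ReflectionPad2d(3), nn.Conv2d(n_ftrs, ch_out, 7), nn.Tanh()]
    return nn.Sequential(*layers)
```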
Test generator
Let’s test for a few things:

1. The generator can indeed be initialized correctly.
2. A random image can be passed into the model successfully with the correct size output.
3. The CycleGAN generator is equivalent to the original implementation.
First let’s create a random batch:
```python
img1 = torch.randn(4,3,256,256)
m = resnet_generator(3,3)
with torch.no_grad():
    out1 = m(img1)
out1.shape
```

torch.Size([4, 3, 256, 256])
```python
m_junyanz = define_G(3,3,64,'resnet_9blocks', norm='instance')
with torch.no_grad():
    out2 = m_junyanz(img1)
out2.shape
```

initialize network with normal
torch.Size([4, 3, 256, 256])
compare_networks
compare_networks (a, b)
A simple function to compare the printed model representations as a proxy for actually comparing two models
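A comparison of this kind can be as simple as the following sketch (an assumed, illustrative implementation: it checks only the printed architecture, not the weights):

```python
def compare_networks_sketch(a, b):
    "Return True if two models have the same printed layer structure, child by child."
    children_a, children_b = list(a.children()), list(b.children())
    if len(children_a) != len(children_b): return False
    return all(repr(ca) == repr(cb) for ca, cb in zip(children_a, children_b))
```

Since only the repr is compared, two models with identical architectures but different weights count as equal, which is exactly what we need to check that our models structurally match the reference implementation.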
```python
test_eq(out1.shape,img1.shape)
test_eq(out2.shape,img1.shape)
assert compare_networks(list(m_junyanz.children())[0],m)
```

Passed!
Discriminator
conv_norm_lr
conv_norm_lr (ch_in:int, ch_out:int, norm_layer:torch.nn.modules.module.Module=None, ks:int=3, bias:bool=True, pad:int=1, stride:int=1, activ:bool=True, slope:float=0.2, init=<function normal_>, init_gain:int=0.02)
discriminator
discriminator (ch_in:int, n_ftrs:int=64, n_layers:int=3, norm_layer:torch.nn.modules.module.Module=None, sigmoid:bool=False)
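The discriminator follows the PatchGAN design: a stack of strided conv_norm_lr blocks ending in a single-channel map of per-patch predictions rather than one scalar per image. The sketch below is an assumed, minimal version (not the library's exact code) that reproduces the usual 70x70 PatchGAN layout and explains the 30x30 output seen in the tests that follow.

```python
import torch
import torch.nn as nn

def patchgan_sketch(ch_in=3, n_ftrs=64, n_layers=3):
    # first conv halves the resolution, no normalization
    layers = [nn.Conv2d(ch_in, n_ftrs, 4, stride=2, padding=1),
              nn.LeakyReLU(0.2, inplace=True)]
    nf = n_ftrs
    # n_layers-1 further stride-2 conv -> norm -> LeakyReLU blocks
    for _ in range(1, n_layers):
        layers += [nn.Conv2d(nf, nf * 2, 4, stride=2, padding=1),
                   nn.InstanceNorm2d(nf * 2), nn.LeakyReLU(0.2, inplace=True)]
        nf *= 2
    # one stride-1 block, then a 1-channel conv: one logit per receptive-field patch
    layers += [nn.Conv2d(nf, nf * 2, 4, stride=1, padding=1),
               nn.InstanceNorm2d(nf * 2), nn.LeakyReLU(0.2, inplace=True),
               nn.Conv2d(nf * 2, 1, 4, stride=1, padding=1)]
    return nn.Sequential(*layers)

d_sketch = patchgan_sketch()
print(d_sketch(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 1, 30, 30])
```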
Test discriminator
Let’s test for similar things:

1. The discriminator can indeed be initialized correctly.
2. A random image can be passed into the discriminator successfully with the correct size output.
3. The CycleGAN discriminator is equivalent to the original implementation.
```python
d = discriminator(3)
with torch.no_grad():
    out1 = d(img1)
out1.shape
```

torch.Size([4, 1, 30, 30])
```python
img1 = torch.randn(4,3,256,256)
d_junyanz = define_D(3,64,'basic',norm='instance')
with torch.no_grad():
    out2 = d_junyanz(img1)
out2.shape
```

initialize network with normal
torch.Size([4, 1, 30, 30])
```python
test_eq(out1.shape,torch.Size([4, 1, 30, 30]))
test_eq(out2.shape,torch.Size([4, 1, 30, 30]))
assert compare_networks(list(d_junyanz.children())[0],d)
```

Passed!
Full model
We group two discriminators and two generators in a single model; a Callback (defined in 02_cyclegan_training.ipynb) will take care of training them properly. We use the PyTorchModelHubMixin to provide support for pushing to and loading from the HuggingFace Hub.
CycleGAN
CycleGAN (ch_in:int=3, ch_out:int=3, n_features:int=64, disc_layers:int=3, gen_blocks:int=9, lsgan:bool=True, drop:float=0.0, norm_layer:torch.nn.modules.module.Module=None)
CycleGAN model.
When called, it takes a batch of real images from both domains and outputs fake images for the opposite domains (via the generators). It also outputs identity images, obtained by passing each image through the generator that produces its own domain (needed for the identity loss).
Attributes:

- G_A (nn.Module): takes real input B and generates fake input A
- G_B (nn.Module): takes real input A and generates fake input B
- D_A (nn.Module): trained to distinguish between real input A and fake input A
- D_B (nn.Module): trained to distinguish between real input B and fake input B
CycleGAN.__init__
CycleGAN.__init__ (ch_in:int=3, ch_out:int=3, n_features:int=64, disc_layers:int=3, gen_blocks:int=9, lsgan:bool=True, drop:float=0.0, norm_layer:torch.nn.modules.module.Module=None)
Constructor for CycleGAN model.
Arguments:

- ch_in (int): Number of input channels (default=3)
- ch_out (int): Number of output channels (default=3)
- n_features (int): Number of input features (default=64)
- disc_layers (int): Number of discriminator layers (default=3)
- gen_blocks (int): Number of residual blocks in the generator (default=9)
- lsgan (bool): Whether to use the LSGAN training objective (output an unnormalized float) (default=True)
- drop (float): Level of dropout (default=0)
- norm_layer (nn.Module): Type of normalization layer to use in the models (default=None)
CycleGAN.forward
CycleGAN.forward (input)
Forward function for CycleGAN model. The input is a tuple of a batch of real images from both domains A and B.
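Conceptually, the forward pass runs each real batch through the generator for the opposite domain and also through the generator for its own domain (for the identity loss). The sketch below shows what is computed; the exact return order is an assumption, so check the source or the training callback for how the outputs are consumed.

```python
def cyclegan_forward_sketch(model, real_A, real_B):
    fake_A = model.G_A(real_B)  # B -> A translation
    fake_B = model.G_B(real_A)  # A -> B translation
    idt_A  = model.G_A(real_A)  # A passed through the A-generator (identity loss)
    idt_B  = model.G_B(real_B)  # B passed through the B-generator (identity loss)
    return fake_A, fake_B, idt_A, idt_B
```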
ModelHubMixin.push_to_hub
ModelHubMixin.push_to_hub (repo_id:str, config:Optional[dict]=None, commit_message:str='Push model using huggingface_hub.', private:bool=False, api_endpoint:Optional[str]=None, token:Optional[str]=None, branch:Optional[str]=None, create_pr:Optional[bool]=None, allow_patterns:Union[List[str],str,NoneType]=None, ignore_patterns:Union[List[str],str,NoneType]=None, delete_patterns:Union[List[str],str,NoneType]=None)
Upload model checkpoint to the Hub.
Use allow_patterns and ignore_patterns to precisely filter which files should be pushed to the hub. Use delete_patterns to delete existing remote files in the same commit. See the upload_folder reference for more details.
Args:

- repo_id (str): ID of the repository to push to (example: "username/my-model").
- config (dict, optional): Configuration object to be saved alongside the model weights.
- commit_message (str, optional): Message to commit while pushing.
- private (bool, optional, defaults to False): Whether the repository created should be private.
- api_endpoint (str, optional): The API endpoint to use when pushing the model to the hub.
- token (str, optional): The token to use as HTTP bearer authorization for remote files. By default, it will use the token cached when running huggingface-cli login.
- branch (str, optional): The git branch on which to push the model. This defaults to "main".
- create_pr (boolean, optional): Whether or not to create a Pull Request from branch with that commit. Defaults to False.
- allow_patterns (List[str] or str, optional): If provided, only files matching at least one pattern are pushed.
- ignore_patterns (List[str] or str, optional): If provided, files matching any of the patterns are not pushed.
- delete_patterns (List[str] or str, optional): If provided, remote files matching any of the patterns will be deleted from the repo.

Returns: The url of the commit of your model in the given repository.
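A hedged usage example (the repo name is illustrative, and it assumes you are already logged in via huggingface-cli login):

```python
model = CycleGAN()
# Push only the weights and config, keeping stray local files out of the repo.
model.push_to_hub("username/my-cyclegan",
                  commit_message="Upload CycleGAN weights",
                  allow_patterns=["*.bin", "*.safetensors", "*.json"])
```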
ModelHubMixin.from_pretrained
ModelHubMixin.from_pretrained (cls:Type[~T], pretrained_model_name_or_path:Union[str,pathlib.Path], force_download:bool=False, resume_download:bool=False, proxies:Optional[Dict]=None, token:Union[bool,str,NoneType]=None, cache_dir:Union[pathlib.Path,str,NoneType]=None, local_files_only:bool=False, revision:Optional[str]=None, **model_kwargs)
Download a model from the Huggingface Hub and instantiate it.
Args:

- pretrained_model_name_or_path (str, Path): Either the model_id (string) of a model hosted on the Hub, e.g. bigscience/bloom, or a path to a directory containing model weights saved using [~transformers.PreTrainedModel.save_pretrained], e.g., ../path/to/my_model_directory/.
- revision (str, optional): Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the latest commit on main branch.
- force_download (bool, optional, defaults to False): Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding the existing cache.
- resume_download (bool, optional, defaults to False): Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
- proxies (Dict[str, str], optional): A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on every request.
- token (str or bool, optional): The token to use as HTTP bearer authorization for remote files. By default, it will use the token cached when running huggingface-cli login.
- cache_dir (str, Path, optional): Path to the folder where cached files are stored.
- local_files_only (bool, optional, defaults to False): If True, avoid downloading the file and return the path to the local cached file if it exists.
- model_kwargs (Dict, optional): Additional kwargs to pass to the model during initialization.
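A hedged usage example, loading the test checkpoint pushed later in this notebook; since from_pretrained is a classmethod, calling it on the class and assigning the result is the typical pattern:

```python
loaded_model = CycleGAN.from_pretrained("tmabraham/upit-cyclegan-test")
```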
Quick model tests
Again, let’s check that the model can be called successfully and outputs the correct shapes.
```python
cyclegan_model = CycleGAN()
img1 = torch.randn(4,3,256,256)
img2 = torch.randn(4,3,256,256)
with torch.no_grad():
    cyclegan_output = cyclegan_model((img1,img2))
```
CPU times: user 1min 15s, sys: 6.67 s, total: 1min 22s
Wall time: 2.25 s
```python
test_eq(len(cyclegan_output),4)
for output_batch in cyclegan_output:
    test_eq(output_batch.shape,img1.shape)
```
```python
cyclegan_model.push_to_hub('upit-cyclegan-test')
```
Cloning https://huggingface.co/tmabraham/upit-cyclegan-test into local empty directory.
To https://huggingface.co/tmabraham/upit-cyclegan-test
a41e9e0..2331f7d main -> main
'https://huggingface.co/tmabraham/upit-cyclegan-test/commit/2331f7d345d719ac1fdfb10b2cddf58abd7931bb'
```python
cyclegan_model.from_pretrained('tmabraham/upit-cyclegan-test')
```