DualGAN training loop
set_seed(999, reproducible=True)
DualGANLoss
DualGANLoss (dualgan:torch.nn.modules.module.Module, l_adv:float=1.0, l_rec:float=1.0, l_idt:float=0.0)
DualGAN loss function. The individual loss terms are also attributes of this class, which fastai accesses to record them during training.
Attributes:
self.dualgan (nn.Module): The DualGAN model.
self.l_A (float): lambda_A, weight of domain A losses.
self.l_B (float): lambda_B, weight of domain B losses.
self.crit (AdaptiveLoss): The adversarial loss function (either a BCE or MSE loss, depending on the lsgan argument).
self.real_A and self.real_B (fastai.torch_core.TensorImage): Real images from domains A and B.
self.gen_loss (torch.FloatTensor): The generator loss calculated in the forward function.
self.cyc_loss (torch.FloatTensor): The cyclic loss calculated in the forward function.
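How the three weights in the signature enter the loss can be illustrated with a minimal sketch. It assumes the generator objective is a plain weighted sum of adversarial, reconstruction, and identity terms for both domains, which this page does not confirm. The function and argument names below are hypothetical and are not part of DualGANLoss.

```python
def weighted_generator_loss(adv_A, adv_B, rec_A, rec_B,
                            idt_A=0.0, idt_B=0.0,
                            l_adv=1.0, l_rec=1.0, l_idt=0.0):
    """Hypothetical helper: combine per-domain loss terms into one scalar.

    Assumes a plain weighted sum; the inputs can be Python floats or
    0-dim tensors, since only + and * are used.
    """
    adv_term = l_adv * (adv_A + adv_B)  # adversarial terms for both generators
    rec_term = l_rec * (rec_A + rec_B)  # reconstruction terms
    idt_term = l_idt * (idt_A + idt_B)  # identity terms (off when l_idt == 0)
    return adv_term + rec_term + idt_term


if __name__ == "__main__":
    # Made-up term values, only to show the call shape.
    print(weighted_generator_loss(-1.0, -1.0, 0.25, 0.25))
```

With the default l_idt=0.0 the identity terms drop out entirely, so only the adversarial and reconstruction terms contribute.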
compute_gradient_penalty
compute_gradient_penalty (D, real_samples, fake_samples)
Calculates the gradient penalty loss for WGAN-GP.
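For reference, the standard WGAN-GP penalty can be sketched as follows. This is an illustration of the general technique, not necessarily the exact implementation in this library. The function name gradient_penalty_sketch is hypothetical, and the critic D is assumed to map a batch of images to one score per sample.

```python
import torch


def gradient_penalty_sketch(D, real_samples, fake_samples):
    """WGAN-GP penalty: mean of (||grad_x D(x_hat)||_2 - 1)^2 over the batch."""
    batch_size = real_samples.size(0)
    # One random mixing coefficient per sample, broadcast over the other dims.
    alpha = torch.rand(batch_size, *([1] * (real_samples.dim() - 1)),
                       device=real_samples.device)
    # Random points on the lines between real and fake samples.
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples)
    interpolates = interpolates.detach().requires_grad_(True)
    d_out = D(interpolates)
    # Gradient of the critic output with respect to the interpolated inputs.
    grads = torch.autograd.grad(outputs=d_out, inputs=interpolates,
                                grad_outputs=torch.ones_like(d_out),
                                create_graph=True)[0]
    grads = grads.reshape(batch_size, -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()


if __name__ == "__main__":
    # Toy linear critic whose input gradient has unit norm everywhere.
    n = 3 * 8 * 8
    critic = lambda x: x.flatten(1).sum(dim=1, keepdim=True) / n ** 0.5
    real, fake = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)
    print(gradient_penalty_sketch(critic, real, fake))
```

The penalty is zero when the critic's gradient norm is exactly 1 at the interpolated points, which is the soft Lipschitz constraint that WGAN-GP uses in place of weight clipping.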
DualGANTrainer
DualGANTrainer (n_crit=2, clip_value=0.1, l_gp=None)
Learner Callback for training a DualGAN model.
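The n_crit and clip_value arguments suggest a WGAN-style schedule, in which the critics are updated more often than the generators and their weights are clipped after each update. The sketch below shows those two ingredients in isolation under that assumption. The helper names are hypothetical, and how DualGANTrainer actually schedules updates or uses l_gp is not shown on this page.

```python
import torch
from torch import nn


def clip_critic_weights(critic: nn.Module, clip_value: float) -> None:
    """Clamp every critic parameter to [-clip_value, clip_value] in place."""
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-clip_value, clip_value)


def is_generator_step(batch_idx: int, n_crit: int) -> bool:
    """True on every n_crit-th batch (assumed: critics train on every batch)."""
    return (batch_idx + 1) % n_crit == 0


if __name__ == "__main__":
    critic = nn.Linear(4, 1)  # stand-in for a real critic network
    for batch_idx in range(4):
        # ... a critic optimizer step would go here ...
        clip_critic_weights(critic, clip_value=0.1)
        if is_generator_step(batch_idx, n_crit=2):
            pass  # ... a generator optimizer step would go here ...
```

Under this reading, n_crit=2 would mean one generator update for every two critic updates.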
dual_learner
dual_learner (dls:fastai.data.load.DataLoader, m:upit.models.dualgan.DualGAN, opt_func=<function RMSProp>, loss_func=<class '__main__.DualGANLoss'>, show_imgs:bool=True, imgA:bool=True, imgB:bool=True, show_img_interval:bool=10, metrics:list=[], cbs:list=[], lr:Union[float,slice]=0.001, splitter:<built-in function callable>=<function trainable_params>, path:Union[str,pathlib.Path,NoneType]=None, model_dir:Union[str,pathlib.Path]='models', wd:Union[float,int,NoneType]=None, wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)
Initialize and return a Learner object with the data in dls, DualGAN model m, optimizer function opt_func, metrics metrics, and callbacks cbs. If show_imgs is True, intermediate predictions are shown during training every show_img_interval epochs: domain B-to-A predictions if imgA is True and/or domain A-to-B predictions if imgB is True. Other Learner arguments can be passed as well.
| | Type | Default | Details |
|---|---|---|---|
| dls | DataLoaders | | DataLoaders containing fastai or PyTorch DataLoaders |
| m | DualGAN | | |
| opt_func | Optimizer \| OptimWrapper | RMSProp | Optimization function for training |
| loss_func | callable \| None | DualGANLoss | Loss function |
| show_imgs | bool | True | |
| imgA | bool | True | |
| imgB | bool | True | |
| show_img_interval | bool | 10 | |
| metrics | callable \| MutableSequence \| None | None | Metrics to calculate on validation set |
| cbs | Callback \| MutableSequence \| None | None | Callbacks to add to Learner |
| lr | float \| slice | 0.001 | Default learning rate |
| splitter | callable | trainable_params | Split model into parameter groups. Defaults to one parameter group |
| path | str \| Path \| None | None | Parent directory to save, load, and export models. Defaults to dls path |
| model_dir | str \| Path | models | Subdirectory to save and load models |
| wd | float \| int \| None | None | Default weight decay |
| wd_bn_bias | bool | False | Apply weight decay to normalization and bias parameters |
| train_bn | bool | True | Train frozen normalization layers |
| moms | tuple | (0.95, 0.85, 0.95) | Default momentum for schedulers |
| default_cbs | bool | True | Include default Callbacks |
Quick Test
horse2zebra = untar_data('https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip')
folders = horse2zebra.ls().sorted()
trainA_path = folders[2]
trainB_path = folders[3]
testA_path = folders[0]
testB_path = folders[1]
dls = get_dls(trainA_path, trainB_path, num_A=100)
dual_gan = DualGAN()
learn = dual_learner(dls, dual_gan, show_img_interval=1)
learn.show_training_loop()
Start Fit
- before_fit : [TrainEvalCallback, ShowImgsCallback, Recorder, ProgressCallback]
Start Epoch Loop
- before_epoch : [Recorder, ProgressCallback]
Start Train
- before_train : [TrainEvalCallback, DualGANTrainer, Recorder, ProgressCallback]
Start Batch Loop
- before_batch : [DualGANTrainer]
- after_pred : [DualGANTrainer]
- after_loss : []
- before_backward: []
- before_step : []
- after_step : []
- after_cancel_batch: []
- after_batch : [TrainEvalCallback, Recorder, ProgressCallback]
End Batch Loop
End Train
- after_cancel_train: [Recorder]
- after_train : [Recorder, ProgressCallback]
Start Valid
- before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
Start Batch Loop
- **CBs same as train batch**: []
End Batch Loop
End Valid
- after_cancel_validate: [Recorder]
- after_validate : [Recorder, ProgressCallback]
End Epoch Loop
- after_cancel_epoch: []
- after_epoch : [ShowImgsCallback, Recorder]
End Fit
- after_cancel_fit: []
- after_fit : [ProgressCallback]
test_eq(type(learn),Learner)
learn.fit_flat_lin(5,5,2e-4)

| epoch | train_loss | adv_loss_A | adv_loss_B | rec_loss_A | rec_loss_B | D_A_loss | D_B_loss | time |
|---|---|---|---|---|---|---|---|---|
| 0 | -1.225994 | -0.965739 | -0.931990 | 0.354460 | 0.353114 | -0.010898 | -0.002097 | 00:06 |
| 1 | -1.383732 | -0.999190 | -0.998738 | 0.258967 | 0.261373 | -0.000011 | -0.000268 | 00:06 |
| 2 | -1.448422 | -0.999503 | -0.999257 | 0.244793 | 0.240065 | 0.000066 | -0.000068 | 00:06 |
| 3 | -1.488793 | -0.999653 | -0.999547 | 0.240394 | 0.223500 | 0.000069 | -0.000088 | 00:07 |
| 4 | -1.529549 | -0.999728 | -0.999568 | 0.212509 | 0.204648 | -0.000021 | -0.000083 | 00:07 |
| 5 | -1.574171 | -0.999754 | -0.999664 | 0.190567 | 0.178492 | 0.000010 | -0.000069 | 00:08 |
| 6 | -1.633542 | -0.999816 | -0.999724 | 0.143122 | 0.139768 | 0.000010 | -0.000058 | 00:08 |
| 7 | -1.675023 | -0.999844 | -0.999737 | 0.129500 | 0.133948 | 0.000027 | -0.000105 | 00:08 |
| 8 | -1.706186 | -0.999842 | -0.999745 | 0.118132 | 0.129545 | 0.000022 | -0.000052 | 00:09 |
| 9 | -1.726011 | -0.999837 | -0.999760 | 0.118993 | 0.126145 | -0.000012 | -0.000052 | 00:09 |
/home/tmabraham/anaconda3/lib/python3.7/site-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.
warn("Your generator is empty.")
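The call fit_flat_lin(5,5,2e-4) above trains for 10 epochs in total. Going by its name and arguments, the schedule is assumed to hold the learning rate flat for the first 5 epochs and then decay it linearly over the remaining 5. The following sketch computes such a schedule per epoch. The function name and the final learning rate of 0 are assumptions, and the real schedule may be applied per iteration rather than per epoch.

```python
def flat_then_linear_lr(epoch, n_flat, n_decay, start_lr, end_lr=0.0):
    """Hypothetical helper: learning rate at the start of a given epoch.

    Flat at start_lr for n_flat epochs, then linear interpolation towards
    end_lr over the next n_decay epochs.
    """
    if epoch < n_flat:
        return start_lr
    # Fraction of the decay phase completed at the start of this epoch.
    frac = min((epoch - n_flat) / n_decay, 1.0)
    return start_lr + (end_lr - start_lr) * frac


if __name__ == "__main__":
    for epoch in range(10):
        print(epoch, flat_then_linear_lr(epoch, n_flat=5, n_decay=5, start_lr=2e-4))
```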

learn.recorder.plot_loss(with_valid=False)