Spaces:
Runtime error
Runtime error
| import math | |
| import torch | |
| import swapae.util as util | |
| import torch.nn.functional as F | |
| from swapae.models.networks import BaseNetwork | |
| from swapae.models.networks.stylegan2_layers import ConvLayer, ToRGB, EqualLinear, StyledConv | |
class UpsamplingBlock(torch.nn.Module):
    """Two style-modulated convolutions, the first of which upsamples.

    ``conv1`` is a ``StyledConv`` with ``upsample=True`` mapping ``inch`` to
    ``outch`` channels. ``conv2`` is a ``StyledConv`` with
    ``upsample=False`` that keeps ``outch`` channels. Both are driven by the
    same style code. Unlike ``UpsamplingResnetBlock`` there is no shortcut
    path.

    Args:
        inch: number of input channels.
        outch: number of output channels.
        styledim: dimensionality of the style code given to ``forward``.
        blur_kernel: kernel forwarded to the upsampling ``StyledConv``.
            ``None`` (the default) means ``[1, 3, 3, 1]``.
        use_noise: forwarded to both ``StyledConv`` layers.
    """

    def __init__(self, inch, outch, styledim,
                 blur_kernel=None, use_noise=False):
        super().__init__()
        # Build a fresh list per instance instead of sharing one mutable
        # default object between every block created without a kernel.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]
        self.inch, self.outch, self.styledim = inch, outch, styledim
        self.conv1 = StyledConv(inch, outch, 3, styledim, upsample=True,
                                blur_kernel=blur_kernel, use_noise=use_noise)
        self.conv2 = StyledConv(outch, outch, 3, styledim, upsample=False,
                                use_noise=use_noise)

    def forward(self, x, style):
        """Apply ``conv1`` then ``conv2``, both modulated by ``style``."""
        return self.conv2(self.conv1(x, style), style)
class ResolutionPreservingResnetBlock(torch.nn.Module):
    """Residual block of two style-modulated convs without upsampling.

    Main path: ``conv1`` (``inch`` -> ``outch``) followed by ``conv2``
    (``outch`` -> ``outch``), both ``StyledConv`` with ``upsample=False``.
    Shortcut: a 1x1 ``ConvLayer`` without activation or bias when the
    channel count changes, otherwise the identity. The two paths are summed
    and the sum is divided by sqrt(2).

    Args:
        opt: accepted for signature compatibility; not read by this block.
        inch: number of input channels.
        outch: number of output channels.
        styledim: dimensionality of the style code given to ``forward``.
    """

    def __init__(self, opt, inch, outch, styledim):
        super().__init__()
        self.conv1 = StyledConv(inch, outch, 3, styledim, upsample=False)
        self.conv2 = StyledConv(outch, outch, 3, styledim, upsample=False)
        # Project the shortcut only when the channel counts disagree.
        self.skip = (
            ConvLayer(inch, outch, 1, activate=False, bias=False)
            if inch != outch
            else torch.nn.Identity()
        )

    def forward(self, x, style):
        """Return ``(skip(x) + conv2(conv1(x, style), style)) / sqrt(2)``."""
        shortcut = self.skip(x)
        hidden = self.conv1(x, style)
        hidden = self.conv2(hidden, style)
        return (shortcut + hidden) / math.sqrt(2)
class UpsamplingResnetBlock(torch.nn.Module):
    """Residual block that doubles spatial resolution, modulated by a style.

    Main path: ``conv1`` (``StyledConv`` with ``upsample=True``,
    ``inch`` -> ``outch``) followed by ``conv2`` (``StyledConv`` with
    ``upsample=False``). Shortcut: a 1x1 ``ConvLayer`` when the channel count
    changes (identity otherwise), then bilinear interpolation by a factor of
    2. The two paths are summed and the sum is divided by sqrt(2).

    Args:
        inch: number of input channels.
        outch: number of output channels.
        styledim: dimensionality of the style code given to ``forward``.
        blur_kernel: kernel forwarded to the upsampling ``StyledConv``.
            ``None`` (the default) means ``[1, 3, 3, 1]``.
        use_noise: forwarded to both ``StyledConv`` layers.
    """

    def __init__(self, inch, outch, styledim, blur_kernel=None, use_noise=False):
        super().__init__()
        # Build a fresh list per instance instead of sharing one mutable
        # default object between every block created without a kernel.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]
        self.inch, self.outch, self.styledim = inch, outch, styledim
        self.conv1 = StyledConv(inch, outch, 3, styledim, upsample=True,
                                blur_kernel=blur_kernel, use_noise=use_noise)
        self.conv2 = StyledConv(outch, outch, 3, styledim, upsample=False,
                                use_noise=use_noise)
        # NOTE(review): unlike ResolutionPreservingResnetBlock, this shortcut
        # conv is activated and has a bias. Changing that would alter the
        # parameter set, so it is kept as-is.
        if inch != outch:
            self.skip = ConvLayer(inch, outch, 1, activate=True, bias=True)
        else:
            self.skip = torch.nn.Identity()

    def forward(self, x, style):
        """Return ``(upsample2x(skip(x)) + conv2(conv1(x, style), style)) / sqrt(2)``."""
        # The shortcut is interpolated 2x so it matches the upsampled main path.
        skip = F.interpolate(self.skip(x), scale_factor=2, mode='bilinear',
                             align_corners=False)
        res = self.conv2(self.conv1(x, style), style)
        return (skip + res) / math.sqrt(2)
class GeneratorModulation(torch.nn.Module):
    """Modulate a feature map with a learned, style-dependent scale and bias.

    Two ``EqualLinear`` layers map the style (``styledim`` channels) to a
    per-channel scale and a per-channel bias (``outch`` channels each). The
    output is ``x * scale(style) + bias(style)``.

    Args:
        styledim: number of style channels consumed by the linear layers.
        outch: number of channels produced; should equal the channel count
            of the feature map passed to ``forward``.
    """

    def __init__(self, styledim, outch):
        super().__init__()
        self.scale = EqualLinear(styledim, outch)
        self.bias = EqualLinear(styledim, outch)

    def forward(self, x, style):
        """Apply the modulation to ``x``.

        Args:
            x: feature map indexed as (batch, channel, height, width).
            style: either a vector code (``ndim <= 2``, presumably
                (batch, styledim)) broadcast over all positions, or a spatial
                style map (``ndim > 2``, presumably (batch, styledim, h, w))
                that is first resized bilinearly to ``x``'s spatial size and
                applied per position.

        Returns:
            The modulated feature map, with ``x``'s spatial size.
        """
        if style.ndimension() <= 2:
            scale = self.scale(style)[:, :, None, None]
            bias = self.bias(style)[:, :, None, None]
            return x * scale + bias

        style = F.interpolate(style, size=(x.size(2), x.size(3)),
                              mode='bilinear', align_corners=False)
        # Linear layers act on the last dimension, so move the style channels
        # there (N, H, W, C), apply the layers per position, then move the
        # resulting channels back to dim 1 to line up with x.
        style = style.permute(0, 2, 3, 1)
        scale = self.scale(style).permute(0, 3, 1, 2)
        bias = self.bias(style).permute(0, 3, 1, 2)
        return x * scale + bias
class StyleGAN2ResnetGenerator(BaseNetwork):
    """The generator (decoder) architecture described in Figure 18 of
    Swapping Autoencoder (https://arxiv.org/abs/2007.00653).

    At a high level, the architecture consists of regular and upsampling
    residual blocks that transform the structure code into an RGB image.
    The global code is applied at each layer as modulation.

    In more detail:

    1. SpatialCodeModulation: first, modulate the structure code with the
       global code.
    2. HeadResnetBlock: residual blocks at the resolution of the structure
       code, which also incorporate modulation from the global code.
    3. UpsamplingResnetBlock: residual blocks that upsample by a factor of 2
       until the resolution of the output RGB image is reached, again
       modulated by the global code.
    4. ToRGB: final layer that transforms the output into 3 channels (RGB).

    Each component borrows heavily from the StyleGAN2 implementation by
    Seonghyeon Kim: https://github.com/rosinality/stylegan2-pytorch
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this generator's command-line options on ``parser``.

        ``is_train`` is accepted for interface compatibility and is not used.
        Returns the same ``parser``.
        """
        parser.add_argument("--netG_scale_capacity", default=1.0, type=float)
        parser.add_argument(
            "--netG_num_base_resnet_layers",
            default=2, type=int,
            help="The number of resnet layers before the upsampling layers."
        )
        parser.add_argument("--netG_use_noise", type=util.str2bool, nargs='?', const=True, default=True)
        # NOTE(review): netG_resnet_ch is not read anywhere in this class.
        parser.add_argument("--netG_resnet_ch", type=int, default=256)
        return parser

    def __init__(self, opt):
        """Build the modulation, head, upsampling and ToRGB layers.

        Reads from ``opt``: ``netE_num_downsampling_sp`` (number of
        upsampling blocks), ``use_antialias`` (blur kernel choice),
        ``global_code_ch`` and ``num_classes`` (their sum is the global-code
        width), ``spatial_code_ch``, ``netG_num_base_resnet_layers`` and
        ``netG_use_noise``. ``nf`` additionally reads
        ``netG_scale_capacity``.
        """
        super().__init__(opt)
        num_upsamplings = opt.netE_num_downsampling_sp
        blur_kernel = [1, 3, 3, 1] if opt.use_antialias else [1]
        num_head_layers = opt.netG_num_base_resnet_layers

        self.global_code_ch = opt.global_code_ch + opt.num_classes

        self.add_module(
            "SpatialCodeModulation",
            GeneratorModulation(self.global_code_ch, opt.spatial_code_ch))

        in_channel = opt.spatial_code_ch
        for i in range(num_head_layers):
            # Gradually increase the number of channels towards nf(0),
            # never dropping below the structure-code width.
            out_channel = (i + 1) / num_head_layers * self.nf(0)
            out_channel = max(opt.spatial_code_ch, round(out_channel))
            new_layer = ResolutionPreservingResnetBlock(
                opt, in_channel, out_channel, self.global_code_ch)
            self.add_module(self._head_block_name(i), new_layer)
            in_channel = out_channel

        for j in range(num_upsamplings):
            out_channel = self.nf(j + 1)
            new_layer = UpsamplingResnetBlock(
                in_channel, out_channel, self.global_code_ch,
                blur_kernel=blur_kernel, use_noise=opt.netG_use_noise)
            self.add_module(self._upsampling_block_name(j), new_layer)
            in_channel = out_channel

        # in_channel always holds the channel count of the last block built
        # (or spatial_code_ch when there were none); out_channel would be
        # unbound if both loops above had zero iterations.
        last_layer = ToRGB(in_channel, self.global_code_ch,
                           blur_kernel=blur_kernel)
        self.add_module("ToRGB", last_layer)

    @staticmethod
    def _head_block_name(i):
        """Module name of the i-th resolution-preserving head block."""
        return "HeadResnetBlock%d" % i

    @staticmethod
    def _upsampling_block_name(j):
        """Module name of the j-th upsampling block, keyed by 2 ** (4 + j)."""
        return "UpsamplingResBlock%d" % (2 ** (4 + j))

    def nf(self, num_up):
        """Channel width after ``num_up`` upsampling steps.

        Computed as ``128 * 2 ** (netE_num_downsampling_sp - num_up)``,
        capped at 512, multiplied by ``netG_scale_capacity`` and truncated
        to ``int``. ``self.opt`` is presumably stored by
        ``BaseNetwork.__init__``.
        """
        ch = 128 * (2 ** (self.opt.netE_num_downsampling_sp - num_up))
        ch = int(min(512, ch) * self.opt.netG_scale_capacity)
        return ch

    def forward(self, spatial_code, global_code):
        """Decode a (structure code, global code) pair into an RGB output.

        Both codes are passed through ``util.normalize`` first. The global
        code then modulates every stage: the initial modulation, each head
        block, each upsampling block and the final ``ToRGB`` layer.
        """
        spatial_code = util.normalize(spatial_code)
        global_code = util.normalize(global_code)

        x = self.SpatialCodeModulation(spatial_code, global_code)
        for i in range(self.opt.netG_num_base_resnet_layers):
            resblock = getattr(self, self._head_block_name(i))
            x = resblock(x, global_code)

        for j in range(self.opt.netE_num_downsampling_sp):
            upsampling_layer = getattr(self, self._upsampling_block_name(j))
            x = upsampling_layer(x, global_code)

        rgb = self.ToRGB(x, global_code, None)
        return rgb