import mmcv
import matplotlib.pyplot as plt
from fastcore.basics import *
from fastai.vision.all import *
from fastai.torch_basics import *
import warnings
warnings.filterwarnings("ignore")
import kornia
from kornia.constants import Resample
from kornia.color import *
from kornia import augmentation as K
# import kornia.augmentation as F
import kornia.augmentation.random_generator as rg
from torchvision.transforms import functional as tvF
from torchvision.transforms import transforms
from torchvision.transforms import PILToTensor
from functools import partial
from timm.models.layers import trunc_normal_, DropPath
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.vision_transformer import _cfg
from einops import rearrange
from timm.models.registry import register_model

set_seed(105)
EPS = 1e-12

def choose_layer_norm(name, num_features, n_dims=2, eps=EPS, **kwargs):
    if name in ['BN', 'batch', 'batch_norm']:
        if n_dims == 1:
            layer_norm = nn.BatchNorm1d(num_features, eps=eps)
        elif n_dims == 2:
            layer_norm = nn.BatchNorm2d(num_features, eps=eps)
        else:
            raise NotImplementedError("n_dims is expected 1 or 2, but give {}.".format(n_dims))
    else:
        raise NotImplementedError("Not support {} layer normalization.".format(name))
    
    return layer_norm

def choose_nonlinear(name, **kwargs):
    if name == 'relu':
        nonlinear = nn.ReLU()
    else:
        raise NotImplementedError("Invalid nonlinear function is specified. Choose 'relu' instead of {}.".format(name))
    
    return nonlinear
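
# A minimal sanity check of the two factory helpers above (illustrative only; the
# arguments below are arbitrary and not part of the original pipeline).
_norm_check = choose_layer_norm('BN', 16, n_dims=2)  # -> nn.BatchNorm2d(16, eps=1e-12)
_act_check = choose_nonlinear('relu')                # -> nn.ReLU()
print(_norm_check, _act_check)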
from torch.nn.modules.utils import _pair
_pair(1)  # returns (1, 1)
class ConvBlock2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, norm=True, nonlinear='relu', eps=EPS):
        super().__init__()

        assert stride == 1, "`stride` is expected to be 1"

        self.kernel_size = _pair(kernel_size)
        self.dilation = _pair(dilation)

        self.norm = norm
        self.nonlinear = nonlinear

        if self.norm:
            if type(self.norm) is bool:
                name = 'BN'
            else:
                name = self.norm
            self.norm2d = choose_layer_norm(name, in_channels, n_dims=2, eps=eps)
        
        if self.nonlinear is not None:
            self.nonlinear2d = choose_nonlinear(self.nonlinear)
        
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation)

    def forward(self, input):
        """
        Args:
            input (batch_size, in_channels, H, W)
        Returns:
            output (batch_size, out_channels, H, W)
        """
        Kh, Kw = self.kernel_size
        Dh, Dw = self.dilation
        
        padding_height = (Kh - 1) * Dh
        padding_width = (Kw - 1) * Dw
        padding_up = padding_height // 2
        padding_bottom = padding_height - padding_up
        padding_left = padding_width // 2
        padding_right = padding_width - padding_left

        x = input

        if self.norm:
            x = self.norm2d(x)
        if self.nonlinear:
            x = self.nonlinear2d(x)
        
        x = F.pad(x, (padding_left, padding_right, padding_up, padding_bottom))
        output = self.conv2d(x)

        return output
temp = ConvBlock2d(3, 128, kernel_size=3, stride=1, dilation=1, norm=False, nonlinear=None)
temp(torch.randn(1, 3, 32, 32)).shape  # torch.Size([1, 128, 32, 32])
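
# Sketch (my own check, not from the original notebook): the asymmetric padding
# computed in ConvBlock2d.forward should keep H and W unchanged even when the
# kernel is dilated.
_dilated_block = ConvBlock2d(3, 8, kernel_size=3, dilation=4, norm=False, nonlinear=None)
print(_dilated_block(torch.randn(1, 3, 32, 32)).shape)  # expected: torch.Size([1, 8, 32, 32])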
EPS = 1e-12

class D2BlockFixedDilation(nn.Module):
    def __init__(self, in_channels, growth_rate, kernel_size, dilation=1, norm=True, nonlinear='relu', depth=None, eps=EPS):
        """
        Args:
            in_channels <int>: # of input channels
            growth_rate <int> or <list<int>>: # of output channels
            kernel_size <int> or <tuple<int>>: Kernel size
            dilation <int>: Dilation od dilated convolution.
            norm <bool> or <list<bool>>: Applies batch normalization.
            nonlinear <str> or <list<str>>: Applies nonlinear function.
            depth <int>: If `growth_rate` is given by list, len(growth_rate) must be equal to `depth`.
        """
        super().__init__()

        if type(growth_rate) is int:
            assert depth is not None, "Specify `depth`"
            growth_rate = [growth_rate] * depth
        elif type(growth_rate) is list:
            if depth is not None:
                assert depth == len(growth_rate), "`depth` is different from `len(growth_rate)`"
            depth = len(growth_rate)
        else:
            raise ValueError("Not support growth_rate={}".format(growth_rate))
        
        if type(dilation) is not int:
            raise ValueError("Unsupported dilation={}".format(dilation))
        
        if type(norm) is bool:
            assert depth is not None, "Specify `depth`"
            norm = [norm] * depth
        elif type(norm) is list:
            if depth is not None:
                assert depth == len(norm), "`depth` is different from `len(norm)`"
            depth = len(norm)
        else:
            raise ValueError("Not support norm={}".format(norm))

        if type(nonlinear) is bool or type(nonlinear) is str:
            assert depth is not None, "Specify `depth`"
            nonlinear = [nonlinear] * depth
        elif type(nonlinear) is list:
            if depth is not None:
                assert depth == len(nonlinear), "`depth` is different from `len(nonlinear)`"
            depth = len(nonlinear)
        else:
            raise ValueError("Not support nonlinear={}".format(nonlinear))
        
        self.growth_rate = growth_rate
        self.depth = depth

        net = []

        for idx in range(depth):
            if idx == 0:
                _in_channels = in_channels
            else:
                _in_channels = growth_rate[idx - 1]
            _out_channels = sum(growth_rate[idx:])
            
            conv_block = ConvBlock2d(_in_channels, _out_channels, kernel_size=kernel_size, stride=1, dilation=dilation, norm=norm[idx], nonlinear=nonlinear[idx], eps=eps)
            net.append(conv_block)
        
        self.net = nn.Sequential(*net)
    
    def forward(self, input):
        """
        Args:
            input: (batch_size, in_channels, H, W)
        Returns:
            output: (batch_size, out_channels, H, W), where out_channels = growth_rate[-1].
        """
        growth_rate, depth = self.growth_rate, self.depth

        x_residual = 0

        for idx in range(depth):
            if idx == 0:
                x = input
            else:
                # Take the first growth_rate[idx - 1] channels of the running sum as the
                # next input; the remaining channels are carried on as the residual.
                _in_channels = growth_rate[idx - 1]
                sections = [_in_channels, sum(growth_rate[idx:])]
                x, x_residual = torch.split(x_residual, sections, dim=1)
            
            x = self.net[idx](x)
            x_residual = x_residual + x
        
        output = x_residual

        return output
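
# Quick shape check for D2BlockFixedDilation (my own sketch, mirroring the D2Block
# test further down; the growth_rate values here are arbitrary assumptions).
_fixed_block = D2BlockFixedDilation(3, growth_rate=[4, 4, 4], kernel_size=3, dilation=2)
print(_fixed_block(torch.randn(2, 3, 32, 32)).shape)  # expected: torch.Size([2, 4, 32, 32])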

class D2Block(nn.Module):
    def __init__(self, in_channels, growth_rate, kernel_size, dilated=True, norm=True, nonlinear='relu', depth=None, eps=EPS):
        """
        Args:
            in_channels <int>: # of input channels
            growth_rate <int> or <list<int>>: # of output channels
            kernel_size <int> or <tuple<int>>: Kernel size
            dilated <bool> or <list<bool>>: Applies dilated convolution.
            norm <bool> or <list<bool>>: Applies batch normalization.
            nonlinear <str> or <list<str>>: Applies nonlinear function.
            depth <int>: If `growth_rate` is given by list, len(growth_rate) must be equal to `depth`.
        """
        super().__init__()

        if type(growth_rate) is int:
            assert depth is not None, "Specify `depth`"
            growth_rate = [growth_rate] * depth
        elif type(growth_rate) is list:
            if depth is not None:
                assert depth == len(growth_rate), "`depth` is different from `len(growth_rate)`"
            depth = len(growth_rate)
        else:
            raise ValueError("Not support growth_rate={}".format(growth_rate))
        
        if type(dilated) is bool:
            assert depth is not None, "Specify `depth`"
            dilated = [dilated] * depth
        elif type(dilated) is list:
            if depth is not None:
                assert depth == len(dilated), "`depth` is different from `len(dilated)`"
            depth = len(dilated)
        else:
            raise ValueError("Not support dilated={}".format(dilated))
        
        if type(norm) is bool:
            assert depth is not None, "Specify `depth`"
            norm = [norm] * depth
        elif type(norm) is list:
            if depth is not None:
                assert depth == len(norm), "`depth` is different from `len(norm)`"
            depth = len(norm)
        else:
            raise ValueError("Not support norm={}".format(norm))

        if type(nonlinear) is bool or type(nonlinear) is str:
            assert depth is not None, "Specify `depth`"
            nonlinear = [nonlinear] * depth
        elif type(nonlinear) is list:
            if depth is not None:
                assert depth == len(nonlinear), "`depth` is different from `len(nonlinear)`"
            depth = len(nonlinear)
        else:
            raise ValueError("Not support nonlinear={}".format(nonlinear))
        
        self.growth_rate = growth_rate
        self.depth = depth

        net = []

        for idx in range(depth):
            if idx == 0:
                _in_channels = in_channels
            else:
                _in_channels = growth_rate[idx - 1]
            _out_channels = sum(growth_rate[idx:])
            
            if dilated[idx]:
                dilation = 2**idx
            else:
                dilation = 1
            
            conv_block = ConvBlock2d(_in_channels, _out_channels, kernel_size=kernel_size, stride=1, dilation=dilation, norm=norm[idx], nonlinear=nonlinear[idx], eps=eps)
            net.append(conv_block)
        
        self.net = nn.Sequential(*net)
    
    def forward(self, input):
        """
        Args:
            input: (batch_size, in_channels, H, W)
        Returns:
            output: (batch_size, out_channels, H, W), where out_channels = growth_rate[-1].
        """
        growth_rate, depth = self.growth_rate, self.depth

        for idx in range(depth):
            if idx == 0:
                x = input
                x_residual = 0
            else:
                # Take the first growth_rate[idx - 1] channels of the running sum as the
                # next input; the remaining channels are carried on as the residual.
                _in_channels = growth_rate[idx - 1]
                sections = [_in_channels, sum(growth_rate[idx:])]
                x, x_residual = torch.split(x_residual, sections, dim=1)
            
            x = self.net[idx](x)
            x_residual = x_residual + x
        
        output = x_residual

        return output
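
# Hand trace of the channel bookkeeping in D2Block (illustrative only; the values
# below match the ones used in _test_d2block: in_channels=3, growth_rate=2, depth=4).
_growth_rate, _in_ch0 = [2, 2, 2, 2], 3
for _idx in range(len(_growth_rate)):
    _in_ch = _in_ch0 if _idx == 0 else _growth_rate[_idx - 1]
    _out_ch = sum(_growth_rate[_idx:])
    print(f"layer {_idx}: {_in_ch} -> {_out_ch} channels")
# prints: layer 0: 3 -> 8, layer 1: 2 -> 6, layer 2: 2 -> 4, layer 3: 2 -> 2 (= growth_rate[-1])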

def _test_d2block():
    batch_size = 4
    n_bins, n_frames = 64, 64
    in_channels = 3
    kernel_size = (3, 3)
    depth = 4

    input = torch.randn(batch_size, in_channels, n_bins, n_frames)

    print("-"*10, "D2 Block when `growth_rate` is given as int and `dilated` is given as bool.", "-"*10)

    growth_rate = 2
    dilated = True
    model = D2Block(in_channels, growth_rate, kernel_size=kernel_size, dilated=dilated, depth=depth)

    print("-"*10, "D2 Block", "-"*10)
    print(model)
    output = model(input)
    print(input.size(), output.size())
    print()

    # print("-"*10, "D2 Block when `growth_rate` is given as list and `dilated` is given as bool.", "-"*10)

    # growth_rate = [3, 4, 5, 6] # depth = 4
    # dilated = False
    # model = D2Block(in_channels, growth_rate, kernel_size=kernel_size, dilated=dilated)

    # print(model)
    # output = model(input)
    # print(input.size(), output.size())
    # print()

    # print("-"*10, "D2 Block when `growth_rate` is given as list and `dilated` is given as list.", "-"*10)

    # growth_rate = [3, 4, 5, 6] # depth = 4
    # dilated = [True, False, False, True] # depth = 4
    # model = D2Block(in_channels, growth_rate, kernel_size=kernel_size, dilated=dilated)

    # print(model)
    # output = model(input)
    # print(input.size(), output.size())
    print("="*10, "D2 Block", "="*10)
    _test_d2block()
========== D2 Block ==========
---------- D2 Block when `growth_rate` is given as int and `dilated` is given as bool. ----------
---------- D2 Block ----------
D2Block(
  (net): Sequential(
    (0): ConvBlock2d(
      (norm2d): BatchNorm2d(3, eps=1e-12, momentum=0.1, affine=True, track_running_stats=True)
      (nonlinear2d): ReLU()
      (conv2d): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1))
    )
    (1): ConvBlock2d(
      (norm2d): BatchNorm2d(2, eps=1e-12, momentum=0.1, affine=True, track_running_stats=True)
      (nonlinear2d): ReLU()
      (conv2d): Conv2d(2, 6, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
    )
    (2): ConvBlock2d(
      (norm2d): BatchNorm2d(2, eps=1e-12, momentum=0.1, affine=True, track_running_stats=True)
      (nonlinear2d): ReLU()
      (conv2d): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1), dilation=(4, 4))
    )
    (3): ConvBlock2d(
      (norm2d): BatchNorm2d(2, eps=1e-12, momentum=0.1, affine=True, track_running_stats=True)
      (nonlinear2d): ReLU()
      (conv2d): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), dilation=(8, 8))
    )
  )
)
torch.Size([4, 3, 64, 64]) torch.Size([4, 2, 64, 64])