import mmcv
import matplotlib.pyplot as plt
from fastcore.basics import *
from fastai.vision.all import *
from fastai.torch_basics import *
import warnings
warnings.filterwarnings("ignore")
import kornia
from kornia.constants import Resample
from kornia.color import *
from kornia import augmentation as K
import kornia.augmentation as F
import kornia.augmentation.random_generator as rg
from torchvision.transforms import functional as tvF
from torchvision.transforms import transforms
from torchvision.transforms import PILToTensor
from einops import rearrange, reduce, repeat
set_seed(105)

# Every input lives under one Potsdam root. Derive each sub-directory from it
# so the dataset location only has to change in one place.
potsdam_root = Path("/home/ubuntu/sharedData/swp/dlLab/fastaiRepository/fastai/data/rsData/kaggleOriginal/Potsdam")
train_a_path = potsdam_root / "2_Ortho_RGB"
label_a_path = potsdam_root / "5_labels_for_participants"
dsm_path = potsdam_root / "1_dsm" / "1_DSM"
ndsm_path = potsdam_root / "1_dsm_normalisation" / "1_DSM_normalisation"

# get_image_files does not sort its result (it follows os.walk order), so the
# same index is not guaranteed to name the same file across runs or machines.
# Sort each list so that index i is deterministic.
# NOTE(review): pairing RGB / label / DSM by index assumes the three folders
# use file names that sort into the same tile order. Confirm before relying on it.
imgNames = get_image_files(train_a_path).sorted()
lblNames = get_image_files(label_a_path).sorted()
dsmNames = get_image_files(dsm_path).sorted()

to_tensor = transforms.ToTensor()
to_pil = transforms.ToPILImage()

# First sample from each folder, opened as PIL images.
rgbImage = Image.open(imgNames[0])
lblImage = Image.open(lblNames[0])
dsmImage = Image.open(dsmNames[0])

# fastai's image2tensor produces channel-first (C, H, W) byte tensors.
rgbTensor = image2tensor(rgbImage)
lblTensor = image2tensor(lblImage)
dsmTensor = image2tensor(dsmImage)

# Wrapping the PIL image directly keeps the array's channel-last layout, so
# temp is (H, W, C), as the recorded output below shows. This differs from
# rgbTensor above.
temp = TensorImage(rgbImage)
temp.shape
# Out: (6000, 6000, 3)
temp.show()
# Out: <AxesSubplot:>
# Downsample the RGB tile to 520 rows x 512 columns.
# `Tensor.resize_` must not be used for this. It does not resample. It
# reinterprets the tensor's storage in place, keeping only the first
# 520*512*3 values of the row-major (6000, 6000, 3) buffer (about the top 44
# image rows) and re-wrapping them into a scrambled (520, 512, 3) picture.
# Resample the PIL image instead, using Pillow's default filter. PIL sizes are
# (width, height), so (512, 520) yields an HWC uint8 tensor of shape
# (520, 512, 3).
temp = TensorImage(rgbImage.resize((512, 520)))
temp.show()
# Out: <AxesSubplot:>
temp.shape
# Out: (520, 512, 3)

# Swap the two spatial axes (an einops transpose).
rearrange(temp, 'h w c -> w h c').shape
# Out: (512, 520, 3)

# Split the channel axis into c1=3 and c2=1 and fold them into the spatial
# axes: rows become (c1 b) = 3*512, columns become (c2 g) = 1*520.
rearrange(temp, 'g b (c1 c2) -> (c1 b) (c2 g)', c1=3).shape
# Out: (1536, 520)