PyTorch unfold: extract patches from an image
A tutorial on how to extract patches from a large image and rebuild the original image from the extracted patches.
Using PyTorch's unfold and fold to construct the sliding window manually.
import torch
import torch.nn as nn
import math
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage
from fastai2.vision.all import *  # provides test_eq, PILImage, etc. used below
from pdb import set_trace
x = torch.arange(48).view(3, 4, 4)
x.shape
# x.view(8,8)
x
x.unfold(0, 2, 1).shape
x.unfold(0, 2, 1)
# experiment 1: a window of size 3 with step 3 along dim 0
x.unfold(0, 3, 3).shape
x.unfold(0, 3, 3)
# experiment 2: chain a second unfold along dim 1
x.unfold(0, 3, 3).unfold(1, 2, 2).shape
x.unfold(0, 3, 3).unfold(1, 2, 2)
# experiment 3: chain a third unfold along dim 2
x.unfold(0, 3, 3).unfold(1, 2, 2).unfold(2, 2, 2).shape
x.unfold(0, 3, 3).unfold(1, 2, 2).unfold(2, 2, 2)
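Each chained unfold replaces the size of the dimension it slides over with the number of windows and appends the window size as a new trailing dimension. A quick walk through the shapes above:
y = x.unfold(0, 3, 3)  # (3, 4, 4)       -> (1, 4, 4, 3)
y = y.unfold(1, 2, 2)  # (1, 4, 4, 3)    -> (1, 2, 4, 3, 2)
y = y.unfold(2, 2, 2)  # (1, 2, 4, 3, 2) -> (1, 2, 2, 3, 2, 2)
assert y.shape == (1, 2, 2, 3, 2, 2)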
temp = torch.randint(0, 10, (3, 5176, 3793))  # stand-in for a large 3-channel image
temp.shape
patches = temp.unfold(0, 3, 3)
patches.shape
test_eq(temp.unfold(0, 3, 3), temp.unfold(0, 3, 4))  # dim 0 has size 3, so a size-3 window fits exactly once and the step is irrelevant
patches = patches.unfold(1, 128, 128)
patches.shape
patches = patches.unfold(2, 128, 128)
# test_eq(temp.unfold(0,3,3),temp.unfold(0,3,66))
patches.shape
math.floor((5176 - 128) / 128) + 1  # 40 patch rows
math.floor((3793 - 128) / 128) + 1  # 29 patch columns
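These counts match the patch grid produced by the chained unfold above; a quick assertion against the `patches` tensor from the previous cell:
assert patches.shape[1] == math.floor((5176 - 128) / 128) + 1  # 40
assert patches.shape[2] == math.floor((3793 - 128) / 128) + 1  # 29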
Important: given `(a, b) = x.shape`, `x.unfold(c, d, e)` slides a window of size `d` with step `e` along dimension `c`. After `unfold`, the size of dimension `c` becomes `math.floor((a - d) / e) + 1`, and the window size is appended as a new last dimension, so unfolding dimension 0 yields shape
**(math.floor((a - d) / e) + 1, b, d)**
BTW: the trailing `d` is the size value passed to the unfold method.
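A minimal check of this shape rule on a small tensor:
t = torch.randn(10, 7)  # a = 10, b = 7
d, e = 4, 3             # window size and step
assert t.unfold(0, d, e).shape == (math.floor((10 - d) / e) + 1, 7, d)  # (3, 7, 4)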
# the convolution example from the torch.nn.Unfold docs
inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)
inp_unf = torch.nn.functional.unfold(inp, (4, 5))
inp_unf.shape
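The weight tensor `w` goes unused above; the nn.Unfold documentation continues this example into a full convolution implemented as unfold + matmul + fold, along these lines:
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))  # or: out_unf.view(1, 2, 7, 8)
print((torch.nn.functional.conv2d(inp, w) - out).abs().max())  # near zero: both paths agree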
# !wget https://eoimages.gsfc.nasa.gov/images/imagerecords/88000/88094/niobrara_photo_lrg.jpg
patch_size=512
stride=patch_size
pil2tensor = transforms.ToTensor()
file=Path('niobrara_photo_lrg.jpg')
filename=file.stem
im1 = Image.open(file)
print(im1.size)  # PIL images expose .size (width, height), not .shape
# im1.resize(5120,5120)
im1 = im1.resize((1500,1500),Image.BILINEAR)
im1
rgb_image = pil2tensor(im1)
rgb_image.shape
rgb_image.data.type()
# dim 0: a single window over all 3 channels; dims 1-2: non-overlapping 512x512 tiles
patches = rgb_image.data.unfold(0, 3, 3).unfold(1, patch_size, stride).unfold(2, patch_size, stride)
print(patches.shape)
a = list(patches.shape)
a
torch.from_numpy(np.arange(0, a[1]))  # equivalent to torch.arange(a[1])
patches[:, torch.from_numpy(np.arange(0, a[1])), :, :, :, :].shape  # indexing with the full range leaves the shape unchanged
x = patches[:, torch.from_numpy(np.arange(0, a[1])), :, :, :, :].split(1, dim=1)
# the full-range index above is a no-op, so this simpler call produces the same tiles
x = patches.split(1, dim=1)
# x = patches.split(1, dim=2)
len(x)
x[0].shape
x[1].shape
to_pil = ToPILImage()
math.floor(1500/512)  # only 2 full 512-px tiles fit per spatial dim after the resize
6000/512
x = patches[:, torch.from_numpy(np.arange(0, a[1])), :, :, :, :].split(1, dim=1)
for i in list(np.arange(a[1])):
    y = x[i][:, :, torch.from_numpy(np.arange(0, a[2])), :, :, :].split(1, dim=2)
    for j in list(np.arange(a[2])):
        # each y[j] has shape (1, 1, 1, 3, 512, 512); squeeze down to (3, 512, 512)
        img = to_pil(y[j].squeeze(0).squeeze(0).squeeze(0))
        img  # note: inside a loop only the last expression renders in a notebook
        # set_trace()
        # save_image(y[j], filename + '-' + str(i) + '-' + str(j) + '.png')
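A more compact alternative (a sketch, not part of the original loop): collapse the patch grid into a batch dimension and save every tile in one flat loop.
flat = patches.squeeze(0).contiguous().view(a[1] * a[2], 3, patch_size, patch_size)
for k, tile in enumerate(flat):
    save_image(tile, f'{filename}-{k}.png')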
def split_tensor(tensor, tile_size=256):
    mask = torch.ones_like(tensor)
    # use torch.nn.Unfold with 50% overlap between neighbouring tiles
    stride = tile_size // 2
    unfold = nn.Unfold(kernel_size=(tile_size, tile_size), stride=stride)
    # Apply to mask and original image
    mask_p = unfold(mask)
    patches = unfold(tensor)
    # (1, 3*tile*tile, L) columns -> (L, 3, tile, tile) tiles
    patches = patches.reshape(3, tile_size, tile_size, -1).permute(3, 0, 1, 2)
    if tensor.is_cuda:
        patches_base = torch.zeros(patches.size(), device=tensor.get_device())
    else:
        patches_base = torch.zeros(patches.size())
    tiles = []
    for t in range(patches.size(0)):
        tiles.append(patches[[t], :, :, :])
    return tiles, mask_p, patches_base, (tensor.size(2), tensor.size(3))
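nn.Unfold flattens every tile into a column and returns shape (N, C * k * k, L), where L is the number of tile positions. A quick shape check (the sizes here are just for illustration):
t = torch.randn(1, 3, 1024, 1024)
cols = nn.Unfold(kernel_size=(512, 512), stride=256)(t)
print(cols.shape)  # torch.Size([1, 786432, 9]): 3*512*512 values per column, 3x3 = 9 tiles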
def rebuild_tensor(tensor_list, mask_t, base_tensor, t_size, tile_size=256):
    stride = tile_size // 2
    # base_tensor here is used as a container
    for t, tile in enumerate(tensor_list):
        print(tile.size())
        base_tensor[[t], :, :] = tile
    # back to the (1, 3*tile*tile, L) column layout that nn.Fold expects
    base_tensor = base_tensor.permute(1, 2, 3, 0).reshape(3 * tile_size * tile_size, base_tensor.size(0)).unsqueeze(0)
    fold = nn.Fold(output_size=(t_size[0], t_size[1]), kernel_size=(tile_size, tile_size), stride=stride)
    # https://discuss.pytorch.org/t/seemlessly-blending-tensors-together/65235/2?u=bowenroom
    # fold() sums overlapping tiles; dividing by the folded mask turns the sum into an average
    output_tensor = fold(base_tensor) / fold(mask_t)
    # output_tensor = fold(base_tensor)
    return output_tensor
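The division by fold(mask_t) is the key trick: fold sums every overlapping tile, so dividing by a folded tensor of ones averages the overlaps. A minimal self-contained sketch of the round trip:
demo = torch.arange(16.).reshape(1, 1, 4, 4)
u = nn.Unfold(kernel_size=2, stride=1)
f = nn.Fold(output_size=(4, 4), kernel_size=2, stride=1)
recon = f(u(demo)) / f(u(torch.ones_like(demo)))
print((recon - demo).abs().max())  # tensor(0.): the overlap-averaged rebuild is lossless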
# %%time
test_image = 'test_image.jpg'  # defined but unused; the niobrara image in file is loaded below
image_size=1024
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
input_tensor = Loader(Image.open(file).convert('RGB')).unsqueeze(0).cuda()  # requires a CUDA device; drop .cuda() to run on CPU
# Split image into overlapping tiles
# caveat: unless (side - tile_size) is a multiple of the stride, the bottom/right
# border is covered by no tile, and the rebuild below divides 0/0 = NaN there
tile_tensors, mask_t, base_tensor, t_size = split_tensor(input_tensor, 660)
# Put tiles back together
output_tensor = rebuild_tensor(tile_tensors, mask_t, base_tensor, t_size, 660)
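A sanity check (a sketch; assumes the same CUDA device as above): with a tile size and stride that cover the input exactly, the rebuilt tensor matches the original up to float error.
chk = torch.rand(1, 3, 1024, 1024).cuda()  # 512-px tiles with stride 256 tile 1024 exactly
chk_tiles, chk_mask, chk_base, chk_size = split_tensor(chk, 512)
chk_out = rebuild_tensor(chk_tiles, chk_mask, chk_base, chk_size, 512)
print((chk_out - chk).abs().max())  # ~0: the round trip is lossless here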
# Save Output
Image2PIL = transforms.ToPILImage()
print(f'the total number of tiles is {len(tile_tensors)}')
# show the small patches
for i in range(len(tile_tensors)):
    print(f'the current tile is {i}')
    Image2PIL(tile_tensors[i].cpu().squeeze(0))  # only the last expression renders in a notebook
print('the reconstructed image')
Image2PIL(output_tensor.cpu().squeeze(0))
# Image2PIL(output_tensor.cpu().squeeze(0)).save('output_image.png')
len(tile_tensors)
tile_tensors[0].size()
tile_tensors[0].squeeze(0)
temp = PILImage(Image2PIL(tile_tensors[0].cpu().squeeze(0)))  # fastai's PILImage adds a .shape property
temp
temp.shape