updated on 2020-03-01
torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)
A generic data loader where the images are arranged in this way:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Parameters:
root (string) – Root directory path.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
loader (callable, optional) – A function to load an image given its path.
is_valid_file (callable, optional) – A function that takes the path of an image file and checks whether it is a valid file (used to check for corrupt files).
__getitem__(index)
Parameters: index (int) – Index
Returns: (sample, target) where target is class_index of the target class.
Return type: tuple
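As a rough illustration of how the folder layout above turns into (sample, target) pairs, here is a minimal standard-library sketch. It is a simplified stand-in, not the torchvision code: the helper names find_classes, make_samples and demo, and the shortened extension list, are assumptions made for this example. The part it mirrors is that class folders are sorted by name and numbered from 0, and that the dataset keeps a list of (path, class_index) pairs, which ImageFolder exposes as samples and imgs.

```python
import os
import tempfile

# Simplified stand-in for ImageFolder's directory scan (not the library code).
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png")  # assumption: a subset of the real list


def find_classes(root):
    """Return the sorted class folder names and a name -> index mapping."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


def make_samples(root, class_to_idx):
    """Return (path, class_index) pairs, ordered by class name, then file name."""
    samples = []
    for name in sorted(class_to_idx):
        class_dir = os.path.join(root, name)
        for dirpath, _, filenames in sorted(os.walk(class_dir)):
            for fname in sorted(filenames):
                if fname.lower().endswith(IMG_EXTENSIONS):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return samples


def demo():
    """Build a small root/dog, root/cat layout in a temp dir and scan it."""
    with tempfile.TemporaryDirectory() as root:
        for cls, files in {"dog": ["xxx.png", "xxy.png"], "cat": ["123.png"]}.items():
            os.makedirs(os.path.join(root, cls))
            for fname in files:
                # Empty placeholder files: only the paths matter for this sketch.
                open(os.path.join(root, cls, fname), "wb").close()
        classes, class_to_idx = find_classes(root)
        samples = make_samples(root, class_to_idx)
        # Report paths relative to root so the result does not depend on the temp dir.
        rel = [(os.path.relpath(p, root).replace(os.sep, "/"), idx) for p, idx in samples]
        return classes, class_to_idx, rel
```

Calling demo() returns the classes ['cat', 'dog'], the mapping {'cat': 0, 'dog': 1}, and the pairs ('cat/123.png', 0), ('dog/xxx.png', 1), ('dog/xxy.png', 1). Indexing that list by position gives the path that belongs to a given sample.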
To get the image path back along with each sample, replace torchvision.datasets.ImageFolder with a subclass of the original ImageFolder that overrides __getitem__.
See the example below:
import torch
import torchvision
from torchvision import datasets, transforms


class MyImageFolder(datasets.ImageFolder):
    def __getitem__(self, index):
        # Return the usual (sample, target) plus the (path, class_index) entry from self.imgs
        return super(MyImageFolder, self).__getitem__(index), self.imgs[index]


# transform
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

# create test loader
testset = MyImageFolder(root='/image_folder/test', transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

# note: ImageFolder numbers classes by sorted folder name; testset.classes gives the actual order
class_names = ('REAL', 'FAKE')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Check that the image path is now returned:
# check if path can print
for i, data in enumerate(testloader):
    (images, labels), (path, _) = data
    images, labels = images.to(device), labels.to(device)
    print(path, "\n")

Output:

('/mtcnn_detect_resized/test/REAL/dmmvuaikkv.png', '/mtcnn_detect_resized/test/REAL/dnmowthjcj.png', '/mtcnn_detect_resized/test/REAL/doniqevxeg.png', '/mtcnn_detect_resized/test/REAL/dozjwhnedd.png')
('/mtcnn_detect_resized/test/REAL/dpevefkefv.png', '/mtcnn_detect_resized/test/REAL/dpmgoiwhuf.png', '/mtcnn_detect_resized/test/REAL/dsnxgrfdmd.png', '/mtcnn_detect_resized/test/REAL/dtozwcapoa.png')
...
('/mtcnn_detect_resized/test/REAL/dvkdfhrpph.png', '/mtcnn_detect_resized/test/REAL/dvtpwatuja.png', '/mtcnn_detect_resized/test/REAL/dvwpvqdflx.png', '/mtcnn_detect_resized/test/REAL/dxfdovivlw.png')
I recommend reading the source code of ImageFolder to understand what the override above does:
https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py#L45-L74