PyTorch module notes
torchvision.datasets (CocoCaptions)
yuuuun
2021. 4. 20. 15:59
pytorch.org/vision/stable/datasets.html
torchvision.datasets — Torchvision master documentation
torchvision.datasets All datasets are subclasses of torch.utils.data.Dataset i.e., they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader which can load multiple samples in parallel using torch.multiprocessing workers.
pytorch.org
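The quoted passage above says that every torchvision dataset just implements __getitem__ and __len__. A minimal pure-Python sketch of that protocol (ToyCaptionDataset and its toy samples are my own illustration, not part of torchvision) — any object shaped like this can be handed to torch.utils.data.DataLoader:

```python
class ToyCaptionDataset:
    def __init__(self, samples):
        # samples: list of (image_id, [captions]) pairs (toy data)
        self.samples = samples

    def __len__(self):
        # number of samples in the dataset
        return len(self.samples)

    def __getitem__(self, idx):
        # return one (image, target) pair, like CocoCaptions does
        return self.samples[idx]

ds = ToyCaptionDataset([("img0.jpg", ["a cat", "a kitten"]),
                        ("img1.jpg", ["a dog"])])
print(len(ds))   # → 2
print(ds[1])     # → ('img1.jpg', ['a dog'])
```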
img_data = torchvision.datasets.ImageNet('[imagenet directory]')
data_loader = torch.utils.data.DataLoader(img_data, batch_size=4, shuffle=True, num_workers=args.nThreads)
CocoCaptions
from torchvision.datasets import CocoCaptions
CLASS torchvision.datasets.CocoCaptions(root: str, annFile: str,
transform: Union[Callable, NoneType] = None,
target_transform: Union[Callable, NoneType] = None,
transforms: Union[Callable, NoneType] = None)
root: directory where the images have been downloaded
annFile: path to the JSON annotation file
transform: transform applied to the image (e.g. PIL Image → tensor)
target_transform: transform applied to the target, i.e. the list of captions for the image
From the CLIP code (check code)
import torchvision.transforms as transforms
import torchvision.datasets as dset
import random

def rand_choice(x):
    # pick one caption at random from the list of captions for an image
    return random.choice(x)

tfms = transforms.Compose([transforms.Resize((128, 128)),
                           transforms.ToTensor(),
                           transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                                std=(0.229, 0.224, 0.225))])
target_tfm = rand_choice

cap = dset.CocoCaptions(root='./datasets/train2014/',
                        annFile='./datasets/annotations/captions_train2014.json',
                        transform=tfms,
                        target_transform=target_tfm)
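As a side note, the Normalize step in tfms standardizes each RGB channel as out = (in - mean) / std. A quick pure-Python check of that arithmetic, using the ImageNet mean/std values from the code above and an assumed mid-gray pixel:

```python
# ImageNet channel statistics, as used in the Normalize call above
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# a hypothetical RGB pixel after ToTensor() (values scaled to [0, 1])
pixel = (0.5, 0.5, 0.5)

# per-channel standardization: (value - mean) / std
normalized = tuple((p - m) / s for p, m, s in zip(pixel, mean, std))
print(normalized)
```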
tfms: the form of the returned image data (resize, tensor conversion, normalization with ImageNet statistics)
target_tfm: this was originally set as a lambda, but that raised an error, so it was changed to a named function
The files under root are jpg files
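My guess about the lambda error mentioned above (an assumption, since the original error message is not shown): DataLoader with num_workers > 0 pickles the dataset to send it to worker processes, and lambdas cannot be pickled, while a module-level named function can. A quick demonstration:

```python
import pickle
import random

def rand_choice(x):
    # module-level named function: picklable by reference
    return random.choice(x)

# equivalent lambda: pickle cannot look up the name "<lambda>"
lambda_tfm = lambda x: random.choice(x)

pickle.dumps(rand_choice)  # succeeds

try:
    pickle.dumps(lambda_tfm)
    lambda_ok = True
except (pickle.PicklingError, AttributeError):
    lambda_ok = False
print("lambda picklable:", lambda_ok)  # → lambda picklable: False
```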