3.5 torchvision 經典dataset學習
前面已經學習了Dataset,DataLoader,以及常用的函數,通常足以應對大多數需求,但距離熟練編寫自己的Dataset可能還有一段距離。 為了讓大家能輕鬆掌握各種情況下的dataset編寫,本小節對torchvision中提供的幾個常見dataset進行分析,觀察它們的代碼共性,總結編寫dataset的經驗。
X-MNIST
由於MNIST資料使用廣泛,在多領域均可基於這個小資料集進行初步的研發與驗證,因此基於MNIST資料格式的各類X-MNIST資料層出不窮,在mnist.py檔中也提供了多個X-MNIST的編寫,這裡需要大家體會類繼承。
示例表明FashionMNIST、KMNIST兩個dataset僅需要修改資料url(mirrors、resources)和類別名稱(classes),其餘的函數均可複用MNIST中寫好的功能,這一點體現了物件導向程式設計的優點。
來看dataset的 getitem,十分簡潔,因為已經把圖片和標籤處理好,存在self.data和self.targets中使用了:
def __getitem__(self, index: int) -> Tuple[Any, Any]:
img, target = self.data[index], int(self.targets[index])
img = Image.fromarray(img.numpy(), mode=
'L')
if
self.transform
is not None:
img = self.transform(img)
if
self.target_transform
is not None:
target = self.target_transform(target)
return
img, target
代碼參閱:D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\datasets\mnist.py
cifar-10
cifar-10是除MNIST之外使用最多的公開資料集,同樣,讓我們直接關注其 Dataset
實現的關鍵部分
def __getitem__(self, index: int) -> Tuple[Any, Any]:
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if
self.transform
is not None:
img = self.transform(img)
if
self.target_transform
is not None:
target = self.target_transform(target)
return
img, target
核心代碼還是這一行: img, target = self.data[index], self.targets[index]
接下來,去分析data和self.targets是如何從磁片上獲取的?通過代碼搜索可以看到它們來自這裡(D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\datasets\cifar.py CIFAR10 類的 init函數):
# now load the picked numpy arrays
for
file_name, checksum
indownloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with
open(file_path,
'rb')
asf:
entry = pickle.load(f, encoding=
'latin1')
self.data.append(entry[
'data'])
if 'labels' in
entry:
self.targets.extend(entry[
'labels'])
else
:
self.targets.extend(entry[
'fine_labels'])
self.data = np.vstack(self.data).reshape(
-1,
3,
32,
32)
self.data = self.data.transpose((
0,
2,
3,
1))
# convert to HWC
這一段的作用於MNIST的_load_data(), 我們的_get_img_info()一樣,就是讀取資料資訊。
總結:
- getitem函數中十分簡潔,邏輯簡單
- 初始化時需完成資料資訊的採集,存儲到變數中,供getitem使用
VOC
之前討論的資料集主要用於教學目的,比較複雜的目標檢測資料是否具有較高的編寫難度?答案是,一點也不,仍舊可以用我們分析出來的邏輯進行編寫。
下面來看第一個大規模應用的目標檢測資料集——PASCAL VOC,
D:\Anaconda_data\envs\pytorch_1.10_gpu\Lib\site-packages\torchvision\datasets\voc.py的
VOCDetection類的getitem函數:
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is a dictionary of the XML tree.
"""
img = Image.open(self.images[index]).convert(
"RGB")
target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
if
self.transforms
is not None:
img, target = self.transforms(img, target)
return
img, target
更簡潔了,與我們的案例中的getitem一樣一樣的,那麼images和annotations從哪裡來?相信大家已經知道答案了,那就是初始化的時候根據資料格式、資料組織結構,從磁片中讀取。
COCO
說到目標檢測就不得不提COCO資料集,COCO資料集是微軟提出的大規模視覺資料集,主要用於目標檢測,它從資料量、類別量都遠超VOC,對於深度學習模型的落地應用起到了推動作用。
對於CV那麼重要的COCO,它的dataset難嗎?答案是,不難。反而更簡單了,整個類僅40多行。
getitem函數連注釋都顯得是多餘的:
def __getitem__(self, index: int) -> Tuple[Any, Any]:
id = self.ids[index]
image = self._load_image(id)
target = self._load_target(id)
if
self.transforms
is not None:
image, target = self.transforms(image, target)
return
image, target
其實,這一切得益于COCO的應用過於廣泛,因此有了針對COCO資料集的輪子——pycocotools,它非常好用,建議使用COCO資料集的話,一定要花幾天時間熟悉pycocotools。pycocotools把getitem需要的東西都準備好了,因此這個類只需要40多行代碼。
小結
本章從資料模組中兩個核心——Dataset&Dataloader出發,剖析pytorch是如何從硬碟中讀取資料、組裝資料和處理資料的。在資料處理流程中深入介紹資料預處理、資料增強模組transforms,並通過notebook的形式展示了常用的transforms方法使用,最後歸納總結torchvision中常見的dataset,為大家將來應對五花八門的任務時都能寫出dataset代碼。 下一章將介紹模型模組。