第三章 PyTorch 資料模組
第三章簡介
經過前兩章的鋪墊,本章終於可以講講專案代碼中重要的模組——資料模組。
資料模組包括哪些內容呢?相信大家多少會有一些感覺,不過最好結合具體任務來剖析資料模組。
我們回顧2.2中的COVID-19分類任務,觀察一下資料是如何從硬碟到模型輸入的。
我們倒著推,
模型接收的訓練資料是
- data:outputs = model(data)
- data來自train_loader: for data, labels in train_loader:
- train_loader 來自 DataLoader與train_data:train_loader = DataLoader(dataset=train_data, batch_size=2)
- train_data 來自 COVID19Dataset:train_data = COVID19Dataset(root_dir=img_dir, txt_path=path_txt_train, transform=transforms_func)
- COVID19Dataset繼承於Dataset:COVID19Dataset(Dataset)
至此,知道整個資料處理過程會涉及pytorch的兩個核心——Dataset, DataLoader。
Dataset是一個抽象基類,提供給使用者定義自己的資料讀取方式,最核心在於getitem中間對資料的處理。
DataLoader是pytorch資料載入的核心,其中包括多個功能,如打亂資料,採樣機制(實現均衡1:1採樣),多進程資料載入,組裝成Batch形式等豐富的功能。
本章將圍繞著它們兩個展開介紹pytorch的資料讀取、預處理、載入等功能。
3.1 torch.utils.data.Dataset
資料交互模組——Dataset
雖然說pytorch資料模組的核心是DataLoader,但是對於使用者而言,改動最多的、與來源資料最接近的是Dataset, 本小節就詳細分析Dataset的作用,並通過三個案例學習如何編寫自訂Dataset來讀取自己的資料集。
Dataset的功能
pytorch提供的torch.utils.data.Dataset類是一個抽象基類An abstract class representing a Dataset,供用戶繼承,編寫自己的dataset,實現對資料的讀取。在Dataset類的編寫中必須要實現的兩個函數是__getitem__和__len__(由於markdown語法問題,後續雙底線就省略了)。
- getitem:需要實現讀取一個樣本的功能。通常是傳入索引(index,可以是序號或key),然後實現從磁片中讀取資料,並進行預處理(包括online的資料增強),然後返回一個樣本的資料。資料可以是包括模型需要的輸入、標籤,也可以是其他元資訊,例如圖片的路徑。getitem返回的資料會在dataloader中組裝成一個batch。即,通常情況下是在dataloader中調用Dataset的getitem函數獲取一個樣本。
- len:返回資料集的大小,資料集的大小也是個最要的資訊,它在dataloader中也會用到。如果這個函數返回的是0,dataloader會報錯:"ValueError: num_samples should be a positive integer value, but got num_samples=0"
這個報錯相信大家經常會遇到,這通常是檔路徑沒寫對,導致你的dataset找不到資料,資料個數為0。
瞭解Dataset類的概念,下面通過一幅示意圖,來理解Dataset與DataLoader的關係。
dataset負責與磁片打交道,將磁片上的資料讀取並預處理好,提供給DataLoader,而DataLoader只需要關心如何組裝成批資料,以及如何採樣。採樣的體現是出現在傳入getitem函數的索引,這裡採樣的規則可以通過sampler由使用者自訂,可以方便地實現均衡採樣、隨機採樣、有偏採樣、漸進式採樣等,這個留在DataLoader中會詳細展開。
此處,先分析Dataset如何與磁片構建聯繫。
從2.2節的例子中可以看到,我們為COVID19Dataset定義了一個_get_img_info函數,該函數就是用來建立磁片關係的,在這個函數中收集並處理樣本的路徑資訊、標籤資訊,存儲到一個list中,供getitem函數使用。getitem函數只需要拿到序號,就可獲得圖片的路徑資訊、標籤資訊,接著進行圖片預處理,最後返回一個樣本資訊。
希望大家體會_get_img_info函數的作用,對於各種不同的資料形式,都可以用這個範本實現Dataset的構建,只需將資料資訊(路徑、標籤等)收集並存儲至清單中,供__getitem__函數使用”。
三個Dataset案例
相信大家在做自己的任務時,遇到的第一個問題就是,怎麼把自己的資料放到github的模型上跑起來。很多朋友通常會把自己的資料調整為與現成專案資料一模一樣的資料形式,然後執行相關代碼。這樣雖然快捷,但缺少靈活性。
為了讓大家能掌握各類資料形式的讀取,這裡構建三個不同的資料形式進行編寫Dataset。
第一個:2.2中的類型。資料的劃分及標籤在txt中。
第二個:資料的劃分及標籤在資料夾中體現
第三個:資料的劃分及標籤在csv中
詳情請結合 配套代碼,深刻理解_get_img_info及Dataset做了什麼。
代碼輸出主要有兩部分,
第一部分是兩種dataset的getitem輸出。
第二部分是結合DataLoader進行資料載入。
先看第一部分,輸出的是 PIL物件及圖像標籤,這裡可以進入getitem函數看到採用了
img = Image.open(path_img).convert('L')
對圖片進行了讀取,得到了PIL物件,由於transform為None,不對圖像進行任何預處理,因此getitem函數返回的圖像是PIL物件。
2 (, 1) 2 (, 1)
第二部分是結合DataLoader的使用,這種形式更貼近真實場景,在這裡為Dataset設置了一些transform,有圖像的縮放,ToTensor, normalize三個方法。因此,getitem返回的圖像變為了張量的形式,並且在DataLoader中組裝成了batchsize的形式。大家可以嘗試修改縮放的大小來觀察輸出,也可以注釋normalize來觀察它們的作用。
0 torch.Size([2, 1, 4, 4]) tensor([[[[-0.0431, -0.1216, -0.0980, -0.1373], [-0.0667, -0.2000, -0.0824, -0.2392], [-0.1137, 0.0353, 0.1843, -0.2078], [ 0.0510, 0.3255, 0.3490, -0.0510]]],
[[[-0.3569, -0.2863, -0.3333, -0.4118],
[ 0.0196, -0.3098, -0.2941, 0.1059],
[-0.2392, -0.1294, 0.0510, -0.2314],
[-0.1059, 0.4118, 0.4667, 0.0275]]]]) torch.Size([2]) tensor([1, 0])
Copy
關於transform的系列方法以及工作原理,將在本章後半部分講解資料增強部分再詳細展開。
小結
本小結介紹了torch.utils.data.Dataset類的結構及工作原理,並通過三個案例實踐,加深大家對自行編寫Dataset的認識,關於Dataset的編寫,torchvision也有很多常用公開資料集的Dataset範本,建議大家學習,本章後半部分也會挑選幾個Dataset進行分析。下一小節將介紹DataLoader類的使用。
補充學習建議
- IDE的debug: 下一小節的代碼將採用debug模式進行逐步分析,建議大家提前熟悉pycharm等IDE的debug功能。
- python的反覆運算器:相信很多初學者對代碼中的“next(iter(train_set))”不太瞭解,這裡建議大家瞭解iter概念、next概念、反覆運算器概念、以及雙底線函數概念。
留言列表