2.2 新冠肺炎X光分類

上一節,我們學習了pytorch python API的結構,本節將以一個具體的案例介紹pytorch模型訓練流程,並提出一系列問題,供大家思考。當然,這些問題也是本書後續章節一一解答的內容。

相信絕大多數朋友接觸過或者看到的第一個Hello World級圖像分類都是Mnist,思來想去覺得還是換點東西,於是選擇了當下與所有人都息息相關的案例——新型冠狀病毒肺炎Corona Virus Disease 2019COVID-19,簡稱新冠肺炎。關於新冠肺炎的背景,大家很熟悉了,口罩、綠碼、核酸檢測已經融入了我們的生活。因此,想讓大家更進一步的瞭解COVID-19,所以選用此案例。當然,最重要的目的是要瞭解pytorch如何完成模型訓練。

案例背景

20201月底2月初的時候,新冠在國內/外大流行。而確定一個人是否患有新冠肺炎,是尤為重要的事情。新冠肺炎的確診需要通過核酸檢測完成,但是核酸檢測並不是那麼容易完成,需要醫護人員採樣、送檢、在PCR儀器上進行檢測、出結果、發報告等一系列複雜工序,核酸檢測產能完全達不到當時的檢測需求。在疫情初期,就有醫生提出,是否可以採用特殊方法進行診斷,例如通過CTX光的方法,給病人進行X光攝影,幾分鐘就能看出結果,比核酸檢測快了不少。於是,新冠肺炎患者的胸片X光資料就不斷的被收集,並發佈到網上供全球科學家使用,共同抗擊新冠疫情。這裡就採用了https://github.com/ieee8023/covid-chestxray-dataset上的資料,同時採用了正常人的X光片,來自於:https://github.com/zoogzog/chexnet

由於本案例目的是pytorch流程學習,為了簡化學習過程,資料僅選擇了4張圖片,分為2類,正常與新冠,訓練集2張,驗證集2張。標籤資訊存儲於TXT檔中。具體目錄結構如下:

├─imgs

  ├─covid-19

        auntminnie-a-2020_01_28_23_51_6665_2020_01_28_Vietnam_coronavirus.jpeg

        ryct.2020200028.fig1a.jpeg

 

  └─no-finding

          00001215_000.png

          00001215_001.png

└─labels

   train.txt

   valid.txt

Copy

建模思路

這是一個典型的圖像分類任務,此處採用面向過程的思路給大家介紹如何進行代碼編寫。

step 1 數據

需要編寫代碼完成資料的讀取,轉換成模型能夠讀取的格式。這裡涉及pytorchdatasetdataloadertransforms等模組。以及需要清楚地知道pytorch的模型需要怎樣的格式。

資料模組需要完整的工作大體如下圖所示:

<<AI人工智慧 PyTorch自學>> 2.2 新冠肺炎X

首先,需要將資料在硬碟上的資訊,如路徑,標籤讀取並存儲起來,然後被使用,這一步驟主要是通過COVID19Dataset這個類。類裡有四個函數,除了Dataset類必須要實現的三個外,我們通過get_img_info()函數實現讀取硬碟中的路徑、標籤等資訊,並存儲到一個清單中。後續大家可以根據不同的任務情況在這個函數中修改,只要能獲取到資料的資訊,供\_getitem__函數進行讀取。

接著,使用dataloader進行封裝,dataloader是一個資料載入器,提供諸多方法進行資料的獲取,如設置一個batch獲取幾個樣本,採用幾個進程進行資料讀取,是否對資料進行隨機化等功能。

下一步,還需要設置對圖像進行預處理(Preprocess)的操作,這裡為了演示,僅採用resize  totensor兩個方法,並且圖片只需要縮放到8x8的大小,並不需要224,256,448,512,1024等大尺寸。(totensor與下一小節內容強相關)

step 2 模型

資料模組構建完畢,需要輸入到模型中,所以我們需要構建神經網路模型,模型接收資料並前向傳播處理,輸出二分類概率向量。此處需要用到nn.Module模組和nn下的各個網路層進行搭建模型,模型的搭建就像搭積木,一層一層地摞起來。模型完成的任務就如下圖所示:下圖示意圖是一張解析度為4x4的圖像輸入到模型中,模型經過運算,輸出二分類概率。中間的“?"是什麼內容呢?

<<AI人工智慧 PyTorch自學>> 2.2 新冠肺炎X

這裡,是構建一個極其簡單的卷積神經網路,僅僅包含兩個網路層,第一層是包含13*3卷積核的2d卷積,第二層是兩個神經元的全連接層(pytorch也叫linear層)。模型的輸入被限制在了8x8,原因在於linear層設置了輸入神經元個數為36 8x836之間是息息相關的,他們之間的關係是什麼?這需要大家對卷積層有一定瞭解。(大家可以改一下36,改為35,或者transforms_func()中的resize改為9x9,看看會報什麼錯誤,這個錯誤是經常會遇到的報錯)

step3 優化

模型可以完成前向傳播之後,根據什麼規則對模型的參數進行更新學習呢?這就需要損失函數和優化器的搭配了,損失函數用於衡量模型輸出與標籤之間的差異,並通過反向傳播獲得每個參數的梯度,有了梯度,就可以用優化器對權重進行更新。這裡就要涉及各種LossFunctionoptim中的優化器,以及學習率調整模組optim.lr_scheduler

這裡,採用的都是常用的方法:交叉熵損失函數(CrossEntropyLoss)、隨機梯度下降法(SGD)和按固定步長下降學習率策略(StepLR)。

step4 反覆運算

有了模型參數更新的必備元件,接下來需要一遍又一遍地給模型喂資料,監控模型訓練狀態,這時候就需要for迴圈,不斷地從dataloader裡取出資料進行前向傳播,反向傳播,參數更新,觀察lossacc,周而復始。當達到條件,如最大反覆運算次數、某指標達到某個值時,進行模型保存,並break迴圈,停止訓練。

以上就是一個經典的面向過程式的代碼編寫,先考慮資料怎麼讀進來,讀進來之後喂給的模型如何搭建,模型如何更新,模型如何反覆運算訓練到滿意。請大家結合代碼一步一步的觀察整體過程。

在經過幾十個epoch的訓練之後達到了100%,模型可以成功區分從未見過的兩張圖片:auntminnie-a-2020_01_28_23_51_6665_2020_01_28_Vietnam_coronavirus.jpeg00001215_000.png

由於資料量少,隨機性非常大,大家多運行幾次,觀察結果。不過本案例結果完全不重要!),可以看到模型的準確率(Accuracy)變化。

一系列問題

通過上述步驟及代碼,雖然完成了一個圖像分類任務,但其中很多細節想必大家還是弄不清楚,例如:

圖像資料是哪用一行代碼讀取進來的?

transforms.Compose是如何工作對圖像資料進行轉換的?

ToTensor又有哪些操作?

自己如何編寫Dataset

DataLoader有什麼功能?如何使用?有什麼需要注意的?

模型如何按自己的資料流程程搭建?

nn有哪些網路層可以調用?

損失函數有哪些?

優化器是如何更新model參數的?

學習率調整有哪些方法?如何設置它們的參數?

model.train()model.eval()作用是什麼?

optimizer.zero_grad()是做什麼?為什麼要梯度清零?

scheduler.step() 作用是什麼?應該放在哪個for迴圈裡?

等等

如果大家能有以上的問題提出,本小節的目的就達到了。大家有了模型訓練的思路,對過程有瞭解,但是使用細節還需進一步學習,更多pytorch基礎內容將會在後續章節一一解答。

下一小節我們將介紹流動在pytorch各個模組中的基礎資料結構——Tensor(張量)。

 

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 HCHUNGW 的頭像
    HCHUNGW

    HCHUNGW的部落格

    HCHUNGW 發表在 痞客邦 留言(0) 人氣()