3.3 系列APIs

API*022-1月至4月完全沒有更新,這三個月發生的事情太多了,無論如何,過去的終將過去,繼續學習PyTorch與人工智慧吧,加油!

前幾個小節已經把pytorch的資料讀取、載入、預處理機制和邏輯關係理清楚了,下面講一下實用的APIs,包括資料集的拼接、截取、劃分,以及十分重要的採樣策略——sampler

concat

在實際專案中,資料的來源往往是多源的,可能是多個中心收集的,也可能來自多個時間段的收集,很難將可用資料統一到一個資料形式。通常有兩種做法,一種是固定一個資料形式,所有獲取到的資料經過整理,變為統一格式,然後用一個Dataset即可讀取。還有一種更為靈活的方式是為每批資料編寫一個Dataset,然後使用torch.utils.data.ConcatDataset類將他們拼接起來,這種方法可以靈活的處理多來源資料,也可以很好的使用別人的資料及Dataset

下面還是來看COVID-19的例子,大家知道想要獲取大量的COVID-19資料,肯定是多源的,不同國家、不同機構、不同時間的X光片收集過來之後,如何把他們整理起來供模型訓練呢?先看這個github倉庫covid-chestxray-dataset,他們採取了以下方法,將採集到的資料統一整理,並生成metadata(元資訊)。基於現成的Dataset,我們可通過拼接的方法將所有資料拼接成一個大的dataset進行使用。

請結合代碼閱讀,在2.23.2中分別實現了COVID19DatasetCOVID19Dataset2COVID19Dataset3,假設在專案開始時拿到了COVID19Dataset,做了一段時間來了新資料23,那麼像把他們放到一起充當訓練集,可以用concat完成。可以看到代碼將3個資料集拼接得到總的資料集,資料量為2+2+2=6。這裡的concatdataset其實還是一個dataset類,它內部還是有lengetitem,裡面的getitem代碼思路值得學習。concatdataset通過給資料集編號、所有樣本編號,然後在__getitem函數中將dataloader傳進來的整體樣本序號進行計算,得到匹配的資料集序號,以及在該資料集內的樣本編號。

可能有點繞,請看圖:假設dataloader想要第5個樣本,傳入index=4 這時getitem會計算第五個樣本在第三個資料集的第1個位置。然後通過self.datasets[datasetidx][sampleidx]來獲取資料。這樣對外進行一層封裝,內部實現仍舊調用各個dataset__getitem,這樣是不是很巧妙呢?

<<AI人工智慧 PyTorch自學>> 3.3 系列API

    def __getitem__(self, idx):

        if idx < 0:

            if -idx > len(self):

                raise ValueError("absolute value of index should not exceed dataset length")

            idx = len(self) + idx

        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)

        if dataset_idx == 0:

            sample_idx = idx

        else:

            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]

        return self.datasets[dataset_idx][sample_idx]

Copy

Subset

subset可根據指定的索引獲取子資料集,Subset也是Dataset類,同樣包含_len___getitem\,其代碼編寫風格也可以學習一下.

CLASStorch.utils.data.Subset(datasetindices)[SOURCE]

Subset of a dataset at specified indices.

Parameters

  • dataset (Dataset) – The whole Dataset
  • indices (sequence) – Indices in the whole set selected for subset

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:

        self.dataset = dataset

        self.indices = indices

 

    def __getitem__(self, idx):

        if isinstance(idx, list):

            return self.dataset[[self.indices[i] for i in idx]]

        return self.dataset[self.indices[idx]]

 

    def __len__(self):

        return len(self.indices)

Copy

使用上非常簡單,代碼一目了然,不再贅述。

random_split

該函數的功能是隨機的將dataset劃分為多個不重疊的子集,適合用來劃分訓練、驗證集(不過不建議通過它進行,因為對用戶而言,其劃分不可見,不利於分析)。

使用也非常簡單,只需要設置每個子集的資料量,傳給lengths即可。

torch.utils.data.random_split(dataset, lengths, generator=)[SOURCE]

Randomly split a dataset into non-overlapping new datasets of given lengths. Optionally fix the generator for reproducible results, e.g.:

Parameters

  • dataset (Dataset) – Dataset to be split
  • lengths (sequence) – lengths of splits to be produced
  • generator (Generator) – Generator used for the random permutation

---------------------------------------------------------------------- 分割線 ------------------------------------------------------------------

sampler

下面進入另外一個主題——sampler sampler是在dataloader中起到挑選資料的功能,主要是設置挑選策略,如按順序挑選、隨機挑選、按類別分概率挑選等等,這些都可以通過自訂sampler實現。

在上一節我們已經用過了一個sampler,那就是batch_sampler,我們先學習一下它的用法,然後再去瞭解 RandomSampler SequentialSampler 以及SubsetRandomSamplerWeightedRandomSampler

sampler的概念比較複雜,建議大家將BatchSamplerRandomSamplerSequentialSampler放在一起學習。

sampler batch_sampler

首先講一下dataloader類的sampler變數與batch_sampler變數的區別,在dataloader裡會有這兩個變數,第一次碰到時候很懵,怎麼還有兩個採樣器,dataloader到底用的哪一個?還是兩個都用?經過一番調試,終於搞清楚了。

本質上它們兩個都是採樣器,當採用auto_collation時,採用batch_sampler。依據如下:dataloader.py 365

@property

def _index_sampler(self):

​    if self._auto_collation:

​        return self.batch_sampler

​    else:

​        return self.sampler

Copy

來看一下兩者定義:

  • sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shuffle must not be specified.
  • batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.

從定義可知道batch_sampler是一次返回一個batch的索引。通常我們用的都是batch_sampler,其對應的是BatchSampler類。

BatchSampler

下麵先學習BatchSampler類。回顧3.3dataloader獲取一個樣本的機制,會在一個self.nextindex()中調用實際的sampler反覆運算器,繼續進入會來到BatchSampler類的__iter函數,dataloader初始化的時候根據參數配置,自動設置了採樣策略為BatchSampler 依據如下:dataloader.py 272行代碼

        if batch_size is not None and batch_sampler is None:

            # auto_collation without custom batch_sampler

            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

Copy

dataloader.py 365

    @property

    def _index_sampler(self):

        if self._auto_collation:

            return self.batch_sampler

        else:

            return self.sampler

Copy

定位到了BatchSampler,下面來看看類的定義以及傳進去的參數是什麼。

torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

後兩個參數好理解,第一個參數傳入的是一個sampler採樣器,在這裡會有兩種情況,如果需要shuffle,則傳入RandomSampler,不需要打亂,則傳入SequentialSampler

依據如下, dataloader.py 267行。

                if shuffle:

                    sampler = RandomSampler(dataset, generator=generator)

                else:

                    sampler = SequentialSampler(dataset)

Copy

到這裡,BatchSamplerRandomSamplerSequentialSampler三者之間的關係逐漸清晰.

BatchSampler是在其它兩者之上封裝了一個批抽取的功能,一次yield一個batchindex,而樣本採樣的順序取決於RandomSamplerSequentialSample

來學習一下BatchSampler如何產生一個batch的序號,並且支持drop_last的功能。

    def __iter__(self) -> Iterator[List[int]]:

        batch = []

        for idx in self.sampler:

            batch.append(idx)

            if len(batch) == self.batch_size:

                yield batch

                batch = []

   # for迴圈結束,且batch的數量又不滿足batchsize時,則進入以下代碼

    # 其實就是drop_last的邏輯代碼

        if len(batch) > 0 and not self.drop_last:

            yield batch

Copy

理解了三者的關係(BatchSamplerRandomSamplerSequentialSampler),RandomSamplerSequentialSampler就很容易理解,來看它們的核心iter函數,學習一下如何編寫順序反覆運算器以及隨機反覆運算器。

SequentialSampler

順序反覆運算器相對簡單,是得到一個按順序的反覆運算器。這個順序就來自 range()函數。

    def __iter__(self) -> Iterator[int]:

        return iter(range(len(self.data_source)))

Copy

RandomSampler

RandomSampleriter函數核心在於設置一個隨機策略,隨機策略委託給generator實現,在使用的時候非常簡單,預設情況下會使用這行代碼實現:yield from torch.randperm(n, generator=generator).tolist() 利用torch的隨機方法生成一個隨機整數序列,對於generator預設採用的是隨機一個隨機種子進行設置。更多的隨機概念可以自行瞭解torch.Generator()torch.randperm()

    def __iter__(self) -> Iterator[int]:

        n = len(self.data_source)

        if self.generator is None:

            seed = int(torch.empty((), dtype=torch.int64).random_().item())

            generator = torch.Generator()

            generator.manual_seed(seed)

        else:

            generator = self.generator

 

        if self.replacement:

            for _ in range(self.num_samples // 32):

                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()

            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()

        else:

            yield from torch.randperm(n, generator=generator).tolist()

Copy

接下來介紹另外兩個實用的採樣器:SubsetRandomSamplerWeightedRandomSampler

SubsetRandomSampler

顧名思義,可以通過索引定義一個子集的隨機採樣器,直接看代碼

```

​ def iter(self) -> Iterator[int]:

​ for i in torch.randperm(len(self.indices), generator=self.generator):

​ yield self.indices[i]

從代碼可知道,這個採樣器返回的樣本總數是傳入的索引的長度,這裡體現了subset,而隨機則是每次會隨機的從子集裡挑選1個資料返回。

---------------------------------------------------------------------- 分割線 ------------------------------------------------------------------

WeightedRandomSampler

不知大家是否自行處理過資料均衡採樣?最簡單粗暴的方法是否是把資料少的樣本複製n份,直到所有類別樣本數量一致,這是一種辦法,其實可以通過採樣器進行加權的採樣,下面來看看WeightedRandomSampler

先來看它的原型:

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)

Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

  • weights (sequence) – 每個樣本的採樣權重,權重之和不必為1,只需要關心各樣本之間的比例即可。
  • num_samples (int) – 採樣數量,一般設為樣本總量。
  • replacement (bool) –是否有放回採樣。 True,表示有放回。
  • generator (Generator) – 自訂生成器,通常用默認的。

pytorch的機制裡,sampler為每個sample設置權重,因此在設置的時候不僅要指定每個類的採樣概率,還要把各類採樣概率分發到每個樣本上,再傳給WeightedRandomSampler。這個機制與常識有一點點不一樣,直觀的理解應該是為每個類別設置採樣概率就好,但這卻是為每個樣本設置權重,因此需要額外操作兩行代碼。

通過以下兩個案例學習如何使用WeightedRandomSampler

案例1 sampler初認識

# 第一步:計算每個類的採樣概率

weights = torch.tensor([1, 5], dtype=torch.float)

# 第二步:生成每個樣本的採樣概率

train_targets = [sample[1] for sample in train_data.img_info]

samples_weights = weights[train_targets]

# 第三步:產生實體WeightedRandomSampler

sampler_w = WeightedRandomSampler(

​    weights=samples_weights,

​    num_samples=len(samples_weights),

​    replacement=True)

Copy

sampler的構建分三步

  1. 計算各類的採樣概率:這裡手動設置,是為了讓大家可以調整不同的比率,觀察dataloader采出樣本的變化。下一個例子中採用樣本數量進行計算,來達到均衡採樣。
  2. 生成每個樣本的概率:從pytorch機制瞭解到,需要為每個樣本設置採樣概率,這裡採用的方法是按類別分發即可。在這裡有一點需要注意,就是樣本標籤的順序需要與dataset中的getitem中的索引順序保持一致!由於這裡採用了dataset.img_info來維護這個順序,因此可以輕鬆獲得樣本順序。
  3. 產生實體WeightedRandomSampler

通過運行配套代碼可以看到

torch.Size([2]) tensor([1, 1])

torch.Size([2]) tensor([1, 1])

torch.Size([2]) tensor([1, 1])

torch.Size([2]) tensor([1, 1])

torch.Size([2]) tensor([1, 0])

torch.Size([2]) tensor([1, 1])

torch.Size([2]) tensor([1, 1])

torch.Size([2]) tensor([1, 1])

torch.Size([2]) tensor([1, 1])

torch.Size([2]) tensor([1, 1])

Copy

這裡發現出現了很多次[1, 1]。這是因為有放回採樣,並且樣本1的採樣概率比0高很多。

通過這個例子,希望大家能瞭解

  • WeightedRandomSampler的使用流程
  • WeightedRandomSampler採樣機制可以為有放回的
  • 有的樣本在整個loader中可能不會選中

案例2:不均衡資料集進行均衡採樣

點擊進入配套代碼

下面利用WeightedRandomSampler實現一個10類別的不均衡資料集採樣,使它變為1:1的採樣。

下面製作了一個虛擬的不均衡資料集,每個類別數量分別是 10 20..., 100。總共550張樣本,下面希望通過WeightedRandomSampler實現一個dataloader,每次採樣550張樣本,各類別的數量大約為55

代碼的核心在於統計各類樣本的數量,可仔細閱讀

# 第一步:計算各類別的採樣權重

# 計算每個類的樣本數量

train_targets = [sample[1] for sample in train_data.img_info]

label_counter = collections.Counter(train_targets)

class_sample_counts = [label_counter[k] for k in sorted(label_counter)]  # 需要特別注意,此list的順序!

# 計算權重,利用倒數即可

weights = 1. / torch.tensor(class_sample_counts, dtype=torch.float)

Copy

最後可以看到每個epoch採樣到的資料幾乎實現1:1,可以很好的實現按照設置的權重比例採樣。

Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})

 

Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})

 

Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})

 

Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})

 

Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})

 

Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})

 

Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})

 

Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})

 

Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})

 

Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})

Copy

接下來運用sampler

Counter({0: 62, 4: 62, 8: 61, 9: 58, 6: 57, 3: 54, 1: 51, 7: 50, 5: 48, 2: 47})

 

Counter({5: 72, 7: 59, 6: 59, 8: 57, 1: 57, 0: 55, 4: 53, 2: 49, 9: 48, 3: 41})

 

Counter({0: 71, 3: 64, 5: 60, 9: 57, 4: 56, 2: 54, 1: 54, 6: 51, 8: 43, 7: 40})

 

Counter({4: 64, 7: 62, 3: 60, 8: 58, 1: 54, 5: 54, 0: 53, 6: 51, 2: 50, 9: 44})

 

Counter({8: 68, 0: 62, 7: 60, 6: 58, 2: 55, 3: 51, 9: 50, 5: 50, 1: 50, 4: 46})

 

Counter({5: 66, 4: 59, 9: 57, 0: 56, 1: 55, 3: 54, 7: 53, 2: 51, 8: 51, 6: 48})

 

Counter({3: 72, 9: 68, 5: 65, 6: 58, 4: 56, 8: 49, 1: 47, 2: 47, 0: 45, 7: 43})

 

Counter({4: 63, 2: 62, 7: 60, 9: 59, 3: 58, 8: 57, 6: 52, 0: 50, 5: 45, 1: 44})

 

Counter({8: 73, 3: 62, 6: 55, 0: 55, 2: 54, 4: 53, 7: 51, 1: 50, 9: 49, 5: 48})

 

Counter({5: 61, 3: 61, 2: 60, 9: 57, 1: 57, 7: 55, 6: 55, 4: 53, 8: 47, 0: 44})

Copy

進一步地,為了便於大家理解weights (sequence) – a sequence of weights, not necessary summing up to one”這句話,在代碼中增加了

weights = 12345. / torch.tensor(class_sample_counts, dtype=torch.float)

大家可以隨機修改weight的尺度,觀察採樣結果

關於採樣策略有很多的研究,也有多種現成工具庫可以使用,推薦大家看看這個repo

小結

本小結將常用的datasetdataloader配套方法進行了講解,包括資料集的拼接、子集挑選、子集劃分和sampler。其中sampler是漲點神器,推薦掌握。在sampler中,先通過代碼單步調試瞭解RandomSampler,然後順藤摸瓜找到SequentialSamplerSubsetRandomSampler, 最後通過兩個案例詳細介紹漲點神器——WeightedRandomSampler的代碼編寫。

同時推薦大家拓展閱讀關於資料採樣策略對模型精度的論文,典型的主題是——長尾分佈(Long Tail

下一小節將介紹另外一個漲點首選神器——資料增強模組。先從torchvisiontransform模組講起,然後拓展到更強大的Albumentations

 

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 HCHUNGW 的頭像
    HCHUNGW

    HCHUNGW的部落格

    HCHUNGW 發表在 痞客邦 留言(0) 人氣()