Wrapping Other Sampler Strategies in PyTorch's DistributedSampler
After switching to PyTorch distributed training, I added a DistributedSampler to every dataloader.
The problem I ran into next was sampling examples by class. PyTorch's built-in WeightedRandomSampler is not quite the right tool here, since it cannot sample over classes directly, so I rolled my own sampler to solve the problem.
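For context, WeightedRandomSampler expects one weight per sample rather than per class; the usual workaround is to expand class weights into per-sample weights, roughly like this sketch (the labels and weights below are made up):

import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([0, 0, 0, 0, 1, 1, 2])  # made-up class labels
class_weight = torch.tensor([0.2, 0.3, 0.5])  # desired per-class probability
sample_weight = class_weight[labels]          # expand to one weight per sample
sampler = WeightedRandomSampler(sample_weight, num_samples=len(labels), replacement=True)

This only balances classes in expectation: it guarantees neither exact per-class counts nor a fixed epoch length, which is exactly what the sampler below provides.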
WeightedBalanceClassSampler
The first thing to solve is per-class sampling; part of the code here is borrowed from catalyst [1]. catalyst provides a BalanceClassSampler that implements class-balanced sampling, but in my scenario the class imbalance is severe, and BalanceClassSampler simply forces every class to the same count, which does not meet my needs.
On top of BalanceClassSampler, I implemented a WeightedBalanceClassSampler for weighted sampling: weight specifies the sampling proportion of each class, and length specifies the size of the resampled dataset.
weight is normalized and then multiplied by length to get the per-class sample counts; saferound (here taken from the iteround package) guarantees that the total number of samples is preserved after the conversion to integers.
In the __iter__ method, np.random.choice draws indices within each class.
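For example, with made-up numbers, the count computation works like this:

import numpy as np
from iteround import saferound

weight = np.array([5.0, 3.0, 2.0])   # relative class weights, need not sum to one
length = 101                         # desired size of the resampled dataset
weight = weight / weight.sum()       # -> [0.5, 0.3, 0.2]
per_class = weight * length          # -> [50.5, 30.3, 20.2]
# naive rounding would give a total of 100; saferound nudges the entries
# so that the integer counts still sum exactly to `length`
per_class = np.array(saferound(per_class.tolist(), places=0)).astype(int)
print(per_class, per_class.sum())    # e.g. [51 30 20] 101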
from typing import Iterator, List

import numpy as np
from iteround import saferound
from torch.utils.data import Sampler


class WeightedBalanceClassSampler(Sampler):
    """Allows you to create a stratified sample on unbalanced classes
    with given probabilities (weights).

    Args:
        labels: list of class labels for each element in the dataset
        weight: a sequence of weights to balance classes,
            not necessarily summing up to one
        length: the length of the resampled dataset
    """

    def __init__(
        self, labels: List[int], weight: List[float], length: int,
    ):
        """Sampler initialisation."""
        super().__init__(labels)
        labels = np.array(labels).astype(int)  # np.int was removed in NumPy 1.24
        # map each label to the dataset indices that carry it; labels are
        # assumed to be contiguous integers 0..C-1, matching the order of `weight`
        self.lbl2idx = {
            label: np.arange(len(labels))[labels == label].tolist()
            for label in set(labels)
        }
        weight = np.array(weight)
        weight = weight / weight.sum()
        samples_per_class = weight * length
        # saferound keeps the integer counts summing exactly to `length`;
        # it expects a plain list rather than an ndarray
        samples_per_class = np.array(
            saferound(samples_per_class.tolist(), places=0)
        ).astype(int)
        self.labels = labels
        self.samples_per_class = samples_per_class
        self.length = length

    def __iter__(self) -> Iterator[int]:
        """
        Yields:
            indices of the stratified sample
        """
        indices = []
        for key in sorted(self.lbl2idx):
            # sample with replacement only when a class is asked for
            # more samples than it actually has
            replace_flag = self.samples_per_class[key] > len(self.lbl2idx[key])
            indices += np.random.choice(
                self.lbl2idx[key], self.samples_per_class[key], replace=replace_flag
            ).tolist()
        assert len(indices) == self.length
        np.random.shuffle(indices)
        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
            length of the resulting sample
        """
        return self.length
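As a quick single-process sanity check (the toy dataset and numbers below are purely illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

# toy imbalanced dataset: 100 samples of class 0, 10 samples of class 1
labels = [0] * 100 + [1] * 10
data = torch.arange(len(labels)).float().unsqueeze(1)
dataset = TensorDataset(data, torch.tensor(labels))

# resample to 50 examples per epoch with a 60/40 class split
sampler = WeightedBalanceClassSampler(labels, weight=[0.6, 0.4], length=50)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

counts = torch.zeros(2, dtype=torch.long)
for _, y in loader:
    counts += torch.bincount(y, minlength=2)
print(counts)  # exactly tensor([30, 20]): class 1 is oversampled with replacement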
DistributedSamplerWrapper
PyTorch's DistributedSampler wraps a dataset directly; since we already have a WeightedBalanceClassSampler layer, the inner sampler now needs to be placed inside a DistributedSampler.
Here I again used two classes from catalyst: DatasetFromSampler and DistributedSamplerWrapper.
DatasetFromSampler wraps the inner sampler behind a dataset interface.
import torch


class DatasetFromSampler(torch.utils.data.Dataset):
    """Dataset that yields the indices produced by a `Sampler`.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Initialisation for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets an element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        # materialise the sampler lazily on first access, so each new
        # DatasetFromSampler instance re-runs the inner sampler
        if self.sampler_list is None:
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)
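As a tiny illustration, indexing the wrapper simply replays the inner sampler's output:

from torch.utils.data import SequentialSampler

ds = DatasetFromSampler(SequentialSampler(range(5)))
print(ds[2], len(ds))  # -> 2 5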
DistributedSamplerWrapper, in turn, inherits from PyTorch's own DistributedSampler.
Reading the PyTorch DistributedSampler source [2] shows that a subclass needs to override its __iter__ method to implement its own iteration.
The parent DistributedSampler's __iter__ returns the dataset indices belonging to the current rank, i.e. the distributed part of the sampling is already handled. So here we can take the indices returned by the parent and use them to index into the inner WeightedBalanceClassSampler once more, which completes the wrapping.
from operator import itemgetter
from typing import Optional

from torch.utils.data import DistributedSampler


class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self):
        """Iterate over this rank's share of the inner sampler's indices."""
        # rebuild the wrapped dataset so the inner sampler is redrawn each epoch
        self.dataset = DatasetFromSampler(self.sampler)
        # the parent class yields positions into `self.dataset` for this rank
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        # map those positions back to the inner sampler's indices
        # (itemgetter assumes each rank receives more than one index)
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
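Putting it all together, here is a minimal sketch of the distributed usage (the process-group setup, toy dataset, and epoch loop are illustrative; in particular, remember to call set_epoch so the per-rank shuffling changes between epochs while staying consistent across ranks):

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset

# assumes a standard torchrun launch so the default process group can initialise
dist.init_process_group("nccl")

labels = [0] * 100 + [1] * 10                       # illustrative labels
data = torch.arange(len(labels)).float().unsqueeze(1)
dataset = TensorDataset(data, torch.tensor(labels))

inner = WeightedBalanceClassSampler(labels, weight=[0.6, 0.4], length=50)
# num_replicas and rank default to the values from the process group
sampler = DistributedSamplerWrapper(inner)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reseeds the parent DistributedSampler each epoch
    for x, y in loader:
        ...  # training step goes here

Each rank then iterates over roughly length / num_replicas of the resampled indices per epoch.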