PyTorch学习笔记:data.WeightedRandomSampler——数据权重概率采样
创始人
2024-05-25 07:09:10

PyTorch学习笔记:data.WeightedRandomSampler——数据权重概率采样

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)

功能:按给定的权重(概率)[p0,p1,…,pn−1][p_0,p_1,\dots,p_{n-1}][p0​,p1​,…,pn−1​]对样本索引[0,1,…,n−1][0,1,\dots,n-1][0,1,…,n−1]采样

输入:

  • weights:采样权重,权重之和不要求为1,该权重需要与每个样本对应起来,即权重数量等于样本数量
  • num_samples:所采样本的数量,可以小于weights的数量
  • replacement:采样策略,如果为True,则代表使用替换采样策略,即可重复对一个样本进行采样;如果为False,则表示不用替换采样策略,即一个样本最多只能被采一次
  • generator:采样过程中的生成器

代码案例

一般用法

from torch.utils.data import WeightedRandomSamplersampler = WeightedRandomSampler([0.1, 0.6, 1.2, 2.9, 0.8, 0.4, 0.8, 1.0, 0.9], 8)
print([i for i in sampler])

输出

这里采样得到的都是样本的索引

[5, 4, 6, 7, 0, 4, 4, 6]

replacement设为TrueFalse的区别

from torch.utils.data import WeightedRandomSamplersampler_t = WeightedRandomSampler([0.1, 0.6, 1.2, 2.9, 0.8, 0.4, 0.8, 1.0, 0.9], 8, replacement=True)
sampler_f = WeightedRandomSampler([0.1, 0.6, 1.2, 2.9, 0.8, 0.4, 0.8, 1.0, 0.9], 8, replacement=False)
print('sampler_t:', [i for i in sampler_t])
print('sampler_f:', [i for i in sampler_f])

输出

# replacement设为True时,会对同一样本多次采样
sampler_t: [6, 1, 6, 6, 3, 3, 8, 4]
# 否则每个样本只采样一次
sampler_f: [7, 0, 2, 4, 1, 3, 8, 5]

官方文档

torch.utils.data.WeightedRandomSampler:https://pytorch.org/docs/stable/data.html?highlight=sampler#torch.utils.data.WeightedRandomSampler

初步完稿于:2022年2月22日

相关内容

热门资讯

北京的名胜古迹 北京最著名的景... 北京从元代开始,逐渐走上帝国首都的道路,先是成为大辽朝五大首都之一的南京城,随着金灭辽,金代从海陵王...
苗族的传统节日 贵州苗族节日有... 【岜沙苗族芦笙节】岜沙,苗语叫“分送”,距从江县城7.5公里,是世界上最崇拜树木并以树为神的枪手部落...
苗族的传统节日 贵州苗族节日有... 【岜沙苗族芦笙节】岜沙,苗语叫“分送”,距从江县城7.5公里,是世界上最崇拜树木并以树为神的枪手部落...
北京的名胜古迹 北京最著名的景... 北京从元代开始,逐渐走上帝国首都的道路,先是成为大辽朝五大首都之一的南京城,随着金灭辽,金代从海陵王...
应用未安装解决办法 平板应用未... ---IT小技术,每天Get一个小技能!一、前言描述苹果IPad2居然不能安装怎么办?与此IPad不...