def get_mean_std_value(loader):
'''
求数据集的均值和标准差
:param loader:
:return:
'''
data_sum,data_squared_sum,num_batches = 0,0,0
for data,_ in loader:
# data: [batch_size,channels,height,width]
# 计算 dim=0,2,3 维度的均值和,dim=1 为通道数量,不用参与计算
data_sum += torch.mean(data,dim=[0,2,3]) # [batch_size,channels,height,width]
# 计算 dim=0,2,3 维度的平方均值和,dim=1 为通道数量,不用参与计算
data_squared_sum += torch.mean(data**2,dim=[0,2,3]) # [batch_size,channels,height,width]
# 统计 batch 的数量
num_batches += 1
# 计算均值
mean = data_sum/num_batches
# 计算标准差
std = (data_squared_sum/num_batches - mean**2)**0.5
return mean,std
为什么可以这样计算均值,从这个代码中我的到一个结论:"每个样本均值的和/样本数=整体数据的均值"
有点不太理解这个东西,有大佬能用数学公式证明一下吗
简单说明一下数据情况:这是 CIFAR10 数据集,每个样本的结构是( batch_size,channels,height,width), 即(样本数量,RGB 通道,图片高度,图片宽度)
1
dji38838c 10 天前
这个很显然呀。
比如:假如一共有 300 个数据(a1, a2,... a300),分成 100 组,每组 3 个。 那么 [(a1+a2+a3)/3 + (a4+a5+a6)/3 + .... (a298+a299+a300) / 3] / 100 可以整理成 [(a1+a2+a3+...+a300)/3] / 100 = (a1+a2+...+a300)/300 |
2
NessajCN 10 天前
这里能这么算的前提是每个样本的采样数量,也就是计算始终用来计算 torch.mean() 的分母,都是一样的才成立
|
3
bler OP @dji38838c 我发现这个方法还是存在很大的问题的,这种方法只适用于"总数据量/batch_size=整数"这种情况下的计算出的结果才能成立。假设最后的数据恰好是一个异常数据,那么通过这个计算方式计算出来的均值就是有极大异常的均值
|
4
Eureka0 10 天前 via iPhone
这个只对每个样本的样本容量都一样的情况成立,其实就是
均值=求和(样本均值 i*样本容量 i)/求和(样本容量 i) 样本容量都一样就可以约掉了 |
6
Sawyerhou 10 天前 via Android
如楼上所说,数据量很大的情况下,怎么算都差不多。
|
7
faterazer 10 天前
理论和现实是有 gap 的,楼主的疑问很正常,这就是一种近似计算,当然你也可以算整个数据集的精确均值和方差(更麻烦以及更多的计算时间)。在实践中,近似计算和精确计算不会带来太大的性能差异,一般都是按方便的来。另外 CIFAR10 这样的开源数据集的均值方差都有算好的直接用就行
|
8
bler OP 已经发现这个问题了,我用 chatgpt 问了一下,好多答案都是计算一个大概值,不是一个精确值
|