Update unsuper minist.

This commit is contained in:
Colin 2024-11-04 14:23:29 +08:00
parent df05002c90
commit f74a5d29bd
2 changed files with 22 additions and 26 deletions

View File

@ -1,21 +1,9 @@
## 方向
1. 输入的信噪比 1. 输入的信噪比
2. loss函数的设计 2. loss函数的设计
3. grad信息的应用 3. grad信息的应用
abs.cpu().detach().numpy()
array([[8.1206687e-02, 2.2388995e-05, 3.7080176e-02, 5.7033218e-02,
1.7404296e-03, 7.6270252e-02, 5.9453689e-02, 4.0801242e-05]],
dtype=float32)
ratio_nor.cpu().detach().numpy()
array([[4.3121886e+00, 3.2778070e-07, 8.9907879e-01, 2.1270120e+00,
1.9807382e-03, 3.8038602e+00, 2.3113825e+00, 1.0885816e-06]],
dtype=float32)
都比较差的时候区分不开
## 发现的问题 ## 发现的问题
1. 3x3的时候会有重复 1. 3x3的时候会有重复
1. 重复的权重,虽然权重看起来都一样,但是有稍微的不同,不是完全一样 1. 重复的权重,虽然权重看起来都一样,但是有稍微的不同,不是完全一样
@ -31,10 +19,13 @@ array([[4.3121886e+00, 3.2778070e-07, 8.9907879e-01, 2.1270120e+00,
2. 网格状就是为了尽量降低最终输出的绝对值降低loss 2. 网格状就是为了尽量降低最终输出的绝对值降低loss
1. 形成4个一组一共2组的形式交替成为最大值把另外一组的输出降低最后都输出最低的绝对值 1. 形成4个一组一共2组的形式交替成为最大值把另外一组的输出降低最后都输出最低的绝对值
2. 需要类似batchnormal的方式对各个conv核心之间进行归一化 2. 需要类似batchnormal的方式对各个conv核心之间进行归一化
3. 采用label不改变原来sample output的 abs均值的限制 3. 采用label不改变原来sample output的 abs均值的限制也就是不改变所有卷积核输出的总能量
1. 生成2个分别交替充当最大值的极端只有2个像素不是0的卷积核其他的卷积核都是输出接近0的网格 1. 生成2个分别交替充当最大值的极端只有2个像素不是0的卷积核
2. 其他的几个卷积核都是输出接近0的网格
3. 为什么生成了极端的卷积核?为什么还是有相同的卷积核?是不是数据集导致了这样的结果?
4. 多个卷积核之间需要有差异,同一个卷积核的不同样本输入也要有差异,卷积核的分布要有要求 4. 多个卷积核之间需要有差异,同一个卷积核的不同样本输入也要有差异,卷积核的分布要有要求
5. 每个卷积核尽量平摊权重到所有像素,而不是集中一个像素?提高鲁棒性? 5. 每个卷积核尽量平摊权重到所有像素,而不是集中一个像素?提高鲁棒性?
6. 采用自动的ratio增益控制之后好像没有重复了卷积核了但是还有极端的2个像素有效的卷积核
## 可能的策略 ## 可能的策略
1. 每个卷积核的改变权重(grad)能量守恒 1. 每个卷积核的改变权重(grad)能量守恒

View File

@ -117,7 +117,7 @@ for epoch in range(epochs):
images = images.to(device) images = images.to(device)
# images = torch.ones((1, 1, 5, 5), device=device) # images = torch.ones((1, 1, 5, 5), device=device)
# type = random.randint(0, 7) # type = random.randint(0, 3)
# if type == 0: # if type == 0:
# rand = random.randint(0, 4) # rand = random.randint(0, 4)
# images[:, :, rand, :] = images[:, :, rand, :] * 0.5 # images[:, :, rand, :] = images[:, :, rand, :] * 0.5
@ -131,8 +131,8 @@ for epoch in range(epochs):
# images[:, :, 3, 3] = images[:, :, 3, 3] * 0.5 # images[:, :, 3, 3] = images[:, :, 3, 3] * 0.5
# images[:, :, 4, 4] = images[:, :, 4, 4] * 0.5 # images[:, :, 4, 4] = images[:, :, 4, 4] * 0.5
# if type == 3: # if type == 3:
# randx = random.randint(0, 2) # randx = random.randint(1, 3)
# randy = random.randint(0, 2) # randy = random.randint(1, 3)
# images[:, :, randx, randy] = images[:, :, randx, randy] * 0.5 # images[:, :, randx, randy] = images[:, :, randx, randy] * 0.5
# images[:, :, randx, randy + 1] = images[:, :, randx, randy + 1] * 0.5 # images[:, :, randx, randy + 1] = images[:, :, randx, randy + 1] * 0.5
# images[:, :, randx, randy - 1] = images[:, :, randx, randy - 1] * 0.5 # images[:, :, randx, randy - 1] = images[:, :, randx, randy - 1] * 0.5
@ -143,7 +143,6 @@ for epoch in range(epochs):
outputs = outputs.permute(0, 2, 3, 1) # 64 8 24 24 -> 64 24 24 8 outputs = outputs.permute(0, 2, 3, 1) # 64 8 24 24 -> 64 24 24 8
sample = outputs.reshape(-1, outputs.shape[3]) # -> 36864 8 sample = outputs.reshape(-1, outputs.shape[3]) # -> 36864 8
# sample = outputs.reshape(-1, 8,24*24) # -> 36864 8 # sample = outputs.reshape(-1, 8,24*24) # -> 36864 8
# sample = torch.mean(sample,dim=2) # -> 36864 8 # sample = torch.mean(sample,dim=2) # -> 36864 8
@ -152,7 +151,15 @@ for epoch in range(epochs):
mean = torch.mean(abs, dim=1) mean = torch.mean(abs, dim=1)
mean = torch.expand_copy(mean.reshape(-1, 1), abs.shape) mean = torch.expand_copy(mean.reshape(-1, 1), abs.shape)
max = torch.expand_copy(max.reshape(-1, 1), abs.shape) max = torch.expand_copy(max.reshape(-1, 1), abs.shape)
ratio = torch.pow(abs / mean, 2)
e = torch.sum(torch.pow(abs - mean, 2), dim=1)
e = torch.expand_copy(e.reshape(-1, 1), abs.shape)
e = 1 / e
e = torch.where(torch.isinf(e), 1.0, e)
e = torch.pow(e, 0.5)
ratio = abs / mean * e
# ratio = torch.pow(abs / mean, e )
ratio = torch.where(torch.isnan(ratio), 0.0, ratio) ratio = torch.where(torch.isnan(ratio), 0.0, ratio)
label = ratio * abs label = ratio * abs
@ -160,15 +167,13 @@ for epoch in range(epochs):
label = label - label_mean + mean label = label - label_mean + mean
sample = torch.abs(sample) sample = torch.abs(sample)
sample_nz = sample[abs > 0] loss = F.l1_loss(sample[abs > 0], label[abs > 0])
label_nz = label[abs > 0]
loss = F.l1_loss(sample_nz, label_nz)
model.conv1.weight.grad = None model.conv1.weight.grad = None
loss.backward() loss.backward()
# if epoch >= (epochs - 1): # if epoch >= (epochs - 1):
# continue # continue
model.conv1.weight.data = model.conv1.weight.data - model.conv1.weight.grad * 0.1 model.conv1.weight.data = model.conv1.weight.data - model.conv1.weight.grad * 0.01
model.conv1.weight.data = model.normal_conv1_weight() model.conv1.weight.data = model.normal_conv1_weight()
if (i + 1) % 100 == 0: if (i + 1) % 100 == 0: