Add safe softmax demo code.
This commit is contained in:
parent
c4e9637c10
commit
0600d46f2f
|
@ -0,0 +1,46 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# 定义输入向量
|
||||||
|
input_tensor = torch.tensor([1.12, 2.32, 3.43, 4.76, 4.543, 5.43, 6.75, 7.24])
|
||||||
|
|
||||||
|
# 标准 Softmax 计算
|
||||||
|
standard_softmax = F.softmax(input_tensor, dim=0)
|
||||||
|
|
||||||
|
print("标准 Softmax 结果:\n", standard_softmax.numpy())
|
||||||
|
|
||||||
|
|
||||||
|
# 将输入向量拆分为两部分
|
||||||
|
part1 = input_tensor[:4] # 前4个元素
|
||||||
|
part2 = input_tensor[4:] # 后4个元素
|
||||||
|
|
||||||
|
part1_max = part1.max()
|
||||||
|
part1_exp = torch.exp(part1 - part1_max)
|
||||||
|
part1_exp_sum = part1_exp.sum()
|
||||||
|
softmax_part1 = part1_exp
|
||||||
|
|
||||||
|
part2_max = part2.max()
|
||||||
|
part2_exp = torch.exp(part2 - part2_max)
|
||||||
|
part2_exp_sum = part2_exp.sum()
|
||||||
|
softmax_part2 = part2_exp
|
||||||
|
|
||||||
|
print("softmax_part1 结果:\n", softmax_part1.numpy())
|
||||||
|
print("softmax_part2 结果:\n", softmax_part2.numpy())
|
||||||
|
|
||||||
|
# 计算全局最大值
|
||||||
|
max_global = torch.max(part1_max, part2_max)
|
||||||
|
exp_sum_global = part1_exp_sum * torch.exp(part1_max - max_global) + part2_exp_sum * torch.exp(part2_max - max_global)
|
||||||
|
|
||||||
|
global_coeff_part1 = softmax_part1 * torch.exp(part1_max - max_global) / exp_sum_global
|
||||||
|
global_coeff_part2 = softmax_part2 * torch.exp(part2_max - max_global) / exp_sum_global
|
||||||
|
|
||||||
|
merged_softmax = torch.cat([global_coeff_part1, global_coeff_part2], dim=-1)
|
||||||
|
|
||||||
|
# 输出结果
|
||||||
|
print("标准 Softmax 结果:\n", standard_softmax.numpy())
|
||||||
|
print("合并后的 Softmax 结果:\n", merged_softmax.numpy())
|
||||||
|
|
||||||
|
# 验证结果是否一致
|
||||||
|
print("结果是否一致:", torch.allclose(standard_softmax, merged_softmax, atol=1e-6))
|
Loading…
Reference in New Issue