Add safe softmax demo code.
This commit is contained in:
parent
c4e9637c10
commit
0600d46f2f
|
@ -0,0 +1,46 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
# 定义输入向量
|
||||
input_tensor = torch.tensor([1.12, 2.32, 3.43, 4.76, 4.543, 5.43, 6.75, 7.24])
|
||||
|
||||
# 标准 Softmax 计算
|
||||
standard_softmax = F.softmax(input_tensor, dim=0)
|
||||
|
||||
print("标准 Softmax 结果:\n", standard_softmax.numpy())
|
||||
|
||||
|
||||
# 将输入向量拆分为两部分
|
||||
part1 = input_tensor[:4] # 前4个元素
|
||||
part2 = input_tensor[4:] # 后4个元素
|
||||
|
||||
part1_max = part1.max()
|
||||
part1_exp = torch.exp(part1 - part1_max)
|
||||
part1_exp_sum = part1_exp.sum()
|
||||
softmax_part1 = part1_exp
|
||||
|
||||
part2_max = part2.max()
|
||||
part2_exp = torch.exp(part2 - part2_max)
|
||||
part2_exp_sum = part2_exp.sum()
|
||||
softmax_part2 = part2_exp
|
||||
|
||||
print("softmax_part1 结果:\n", softmax_part1.numpy())
|
||||
print("softmax_part2 结果:\n", softmax_part2.numpy())
|
||||
|
||||
# 计算全局最大值
|
||||
max_global = torch.max(part1_max, part2_max)
|
||||
exp_sum_global = part1_exp_sum * torch.exp(part1_max - max_global) + part2_exp_sum * torch.exp(part2_max - max_global)
|
||||
|
||||
global_coeff_part1 = softmax_part1 * torch.exp(part1_max - max_global) / exp_sum_global
|
||||
global_coeff_part2 = softmax_part2 * torch.exp(part2_max - max_global) / exp_sum_global
|
||||
|
||||
merged_softmax = torch.cat([global_coeff_part1, global_coeff_part2], dim=-1)
|
||||
|
||||
# 输出结果
|
||||
print("标准 Softmax 结果:\n", standard_softmax.numpy())
|
||||
print("合并后的 Softmax 结果:\n", merged_softmax.numpy())
|
||||
|
||||
# 验证结果是否一致
|
||||
print("结果是否一致:", torch.allclose(standard_softmax, merged_softmax, atol=1e-6))
|
Loading…
Reference in New Issue