Add safe softmax demo code.

This commit is contained in:
Colin 2025-03-09 14:39:34 +08:00
parent c4e9637c10
commit 0600d46f2f
1 changed files with 46 additions and 0 deletions

46
test/safe_softmax.py Normal file
View File

@ -0,0 +1,46 @@
import torch
import torch.nn.functional as F
import numpy as np
# 定义输入向量
input_tensor = torch.tensor([1.12, 2.32, 3.43, 4.76, 4.543, 5.43, 6.75, 7.24])
# 标准 Softmax 计算
standard_softmax = F.softmax(input_tensor, dim=0)
print("标准 Softmax 结果:\n", standard_softmax.numpy())
# 将输入向量拆分为两部分
part1 = input_tensor[:4] # 前4个元素
part2 = input_tensor[4:] # 后4个元素
part1_max = part1.max()
part1_exp = torch.exp(part1 - part1_max)
part1_exp_sum = part1_exp.sum()
softmax_part1 = part1_exp
part2_max = part2.max()
part2_exp = torch.exp(part2 - part2_max)
part2_exp_sum = part2_exp.sum()
softmax_part2 = part2_exp
print("softmax_part1 结果:\n", softmax_part1.numpy())
print("softmax_part2 结果:\n", softmax_part2.numpy())
# 计算全局最大值
max_global = torch.max(part1_max, part2_max)
exp_sum_global = part1_exp_sum * torch.exp(part1_max - max_global) + part2_exp_sum * torch.exp(part2_max - max_global)
global_coeff_part1 = softmax_part1 * torch.exp(part1_max - max_global) / exp_sum_global
global_coeff_part2 = softmax_part2 * torch.exp(part2_max - max_global) / exp_sum_global
merged_softmax = torch.cat([global_coeff_part1, global_coeff_part2], dim=-1)
# 输出结果
print("标准 Softmax 结果:\n", standard_softmax.numpy())
print("合并后的 Softmax 结果:\n", merged_softmax.numpy())
# 验证结果是否一致
print("结果是否一致:", torch.allclose(standard_softmax, merged_softmax, atol=1e-6))