import torch
import torch.nn.functional as F
import numpy as np


def partial_softmax_stats(x):
    """Compute safe-softmax partial statistics for a 1-D chunk.

    Returns:
        (exp_shifted, exp_sum, chunk_max) where exp_shifted = exp(x - max(x)),
        exp_sum is its scalar sum, and chunk_max = max(x).  These three values
        are all that is needed to later merge this chunk with other chunks
        into the exact global softmax (the online-softmax trick).
    """
    chunk_max = x.max()
    exp_shifted = torch.exp(x - chunk_max)
    return exp_shifted, exp_shifted.sum(), chunk_max


def merge_two_softmax(exp1, sum1, max1, exp2, sum2, max2):
    """Merge two partial safe-softmax chunks into the global softmax.

    Each chunk's shifted exponentials are rescaled by exp(chunk_max - global_max)
    so both refer to the same (global) maximum; dividing by the rescaled global
    sum then reproduces softmax over the concatenated input exactly, without
    ever exponentiating an un-shifted value (numerically safe).
    """
    global_max = torch.max(max1, max2)
    global_sum = (sum1 * torch.exp(max1 - global_max)
                  + sum2 * torch.exp(max2 - global_max))
    merged1 = exp1 * torch.exp(max1 - global_max) / global_sum
    merged2 = exp2 * torch.exp(max2 - global_max) / global_sum
    return torch.cat([merged1, merged2], dim=-1)


# Demo input vector.
input_tensor = torch.tensor([1.12, 2.32, 3.43, 4.76, 4.543, 5.43, 6.75, 7.24])

# Reference: standard softmax over the whole vector.
standard_softmax = F.softmax(input_tensor, dim=0)

print("标准 Softmax 结果:\n", standard_softmax.numpy())


# Split the input into two chunks and compute per-chunk statistics.
part1 = input_tensor[:4]  # first 4 elements
part2 = input_tensor[4:]  # last 4 elements

part1_exp, part1_exp_sum, part1_max = partial_softmax_stats(part1)
# NOTE: these are unnormalized exp(x - chunk_max) values, not a true softmax yet.
softmax_part1 = part1_exp

part2_exp, part2_exp_sum, part2_max = partial_softmax_stats(part2)
softmax_part2 = part2_exp

print("softmax_part1 结果:\n", softmax_part1.numpy())
print("softmax_part2 结果:\n", softmax_part2.numpy())

# Merge the two chunks by rescaling to the global maximum.
merged_softmax = merge_two_softmax(
    part1_exp, part1_exp_sum, part1_max,
    part2_exp, part2_exp_sum, part2_max,
)

# Print both results side by side.
print("标准 Softmax 结果:\n", standard_softmax.numpy())
print("合并后的 Softmax 结果:\n", merged_softmax.numpy())

# Verify that the chunked computation matches the standard softmax.
print("结果是否一致:", torch.allclose(standard_softmax, merged_softmax, atol=1e-6))