Add safe softmax demo code.
This commit is contained in:
parent c4e9637c10
commit 0600d46f2f
@@ -0,0 +1,46 @@
import torch
import torch.nn.functional as F
import numpy as np


# Define the input vector
input_tensor = torch.tensor([1.12, 2.32, 3.43, 4.76, 4.543, 5.43, 6.75, 7.24])

# Standard softmax over the full vector
standard_softmax = F.softmax(input_tensor, dim=0)

print("Standard softmax result:\n", standard_softmax.numpy())


# Split the input vector into two blocks
part1 = input_tensor[:4]  # first 4 elements
part2 = input_tensor[4:]  # last 4 elements

# Per-block "safe" exponentials: subtract the block max before exponentiating
part1_max = part1.max()
part1_exp = torch.exp(part1 - part1_max)
part1_exp_sum = part1_exp.sum()
softmax_part1 = part1_exp  # unnormalized block numerators

part2_max = part2.max()
part2_exp = torch.exp(part2 - part2_max)
part2_exp_sum = part2_exp.sum()
softmax_part2 = part2_exp  # unnormalized block numerators

print("softmax_part1 result:\n", softmax_part1.numpy())
print("softmax_part2 result:\n", softmax_part2.numpy())

# Compute the global max and rescale each block's sum to it before merging
max_global = torch.max(part1_max, part2_max)
exp_sum_global = part1_exp_sum * torch.exp(part1_max - max_global) + part2_exp_sum * torch.exp(part2_max - max_global)

global_coeff_part1 = softmax_part1 * torch.exp(part1_max - max_global) / exp_sum_global
global_coeff_part2 = softmax_part2 * torch.exp(part2_max - max_global) / exp_sum_global

merged_softmax = torch.cat([global_coeff_part1, global_coeff_part2], dim=-1)

# Print both results
print("Standard softmax result:\n", standard_softmax.numpy())
print("Merged softmax result:\n", merged_softmax.numpy())

# Verify that the two results match
print("Results match:", torch.allclose(standard_softmax, merged_softmax, atol=1e-6))