import torch
import torch.nn.functional as F


def to_binary_tensor(input_tensor, bits):
    """Expand every element of ``input_tensor`` into a ``bits``-wide binary code.

    Each value is rounded to the nearest integer and clamped to
    ``[0, 2**bits - 1]`` before conversion, so the encoding is unsigned and
    saturating.

    Args:
        input_tensor: Tensor of numeric values.
        bits: Number of bits produced per element.

    Returns:
        An int64 tensor of shape ``input_tensor.shape + (bits,)`` containing
        0/1 values, most-significant bit first.
    """
    int_tensor = torch.round(input_tensor).clamp(0, 2**bits - 1).to(torch.int64)
    # Shift amounts bits-1, ..., 0 so position 0 of the last axis is the MSB.
    shifts = torch.arange(bits - 1, -1, -1, device=int_tensor.device)
    binary_bits = (int_tensor.unsqueeze(-1) >> shifts) & 1
    return binary_bits


def _pair(value):
    """Return ``value`` as a 2-tuple, duplicating it when it is a single int."""
    return (value, value) if isinstance(value, int) else tuple(value)


def _arange_along(length, dim, ndim=6):
    """Return ``arange(length)`` reshaped to broadcast along axis ``dim`` of an ``ndim``-D grid."""
    view_shape = [1] * ndim
    view_shape[dim] = length
    return torch.arange(length, dtype=torch.long).view(view_shape)


def generate_unfold_index(input_shape, kernel_size, stride, dilation=1):
    """Build advanced-indexing tensors that reproduce ``F.unfold`` without padding.

    Indexing a ``(B, C, H, W)`` tensor ``x`` with the returned tuple
    (``x[index]``) gathers every sliding-window element. The result is a 1-D
    tensor laid out in row-major ``(b, c, i, j, kh, kw)`` order, where
    ``(i, j)`` is the window position and ``(kh, kw)`` the offset inside the
    kernel. Padding is not supported: indices always refer to the unpadded
    input.

    Args:
        input_shape: 4-element shape ``(B, C, H, W)``.
        kernel_size: int or ``(kH, kW)``.
        stride: int or ``(sH, sW)``.
        dilation: int or ``(dH, dW)``; defaults to 1.

    Returns:
        ``(batch_idx, channel_idx, h_idx, w_idx)``: four 1-D int64 tensors,
        each of length ``B * C * out_h * out_w * kH * kW``.
    """
    B, C, H, W = input_shape
    kH, kW = _pair(kernel_size)
    sH, sW = _pair(stride)
    dH, dW = _pair(dilation)

    # Number of window positions per spatial axis (F.unfold formula with
    # padding=0). Clamp at 0 so a kernel larger than the input gives no
    # windows rather than a negative length for torch.arange.
    out_h = max((H - dH * (kH - 1) - 1) // sH + 1, 0)
    out_w = max((W - dW * (kW - 1) - 1) // sW + 1, 0)

    # Broadcast one arange per axis over the full (b, c, i, j, kh, kw) grid
    # instead of looping in Python; reshape(-1) flattens in row-major order,
    # matching the nesting order of an explicit loop over these axes.
    full_shape = (B, C, out_h, out_w, kH, kW)
    b = _arange_along(B, 0)
    c = _arange_along(C, 1)
    i = _arange_along(out_h, 2)
    j = _arange_along(out_w, 3)
    kh = _arange_along(kH, 4)
    kw = _arange_along(kW, 5)

    batch_idx = b.expand(full_shape).reshape(-1)
    channel_idx = c.expand(full_shape).reshape(-1)
    h_idx = (i * sH + kh * dH).expand(full_shape).reshape(-1)
    w_idx = (j * sW + kw * dW).expand(full_shape).reshape(-1)
    return (batch_idx, channel_idx, h_idx, w_idx)


def test(batch_size=2, channels=2, height=2, width=2, kernel_size=2, stride=1, dilation=1):
    """Check that index-based unfolding matches ``F.unfold`` on random input.

    ``kernel_size``, ``stride`` and ``dilation`` may each be an int or a
    2-tuple. Prints both result shapes.

    Returns:
        True if both approaches produce the same ``(B, C*kH*kW, L)`` tensor.
    """
    x = torch.randn(batch_size, channels, height, width)
    index = generate_unfold_index(
        input_shape=x.shape, kernel_size=kernel_size, stride=stride, dilation=dilation
    )
    unfolded_by_f = F.unfold(x, kernel_size=kernel_size, stride=stride, padding=0, dilation=dilation)

    # Gathered values are ordered (b, c, window, kernel offset), while F.unfold
    # lays them out as (b, c * kernel offset, window): swap the last two axes.
    kH, kW = _pair(kernel_size)
    unfolded_by_index = x[index].view(batch_size, channels, -1, kH * kW)
    unfolded_by_index = unfolded_by_index.permute(0, 1, 3, 2)
    unfolded_by_index = unfolded_by_index.reshape(unfolded_by_f.shape)

    print("Shape of unfolded_by_index:", unfolded_by_index.shape)
    print("Shape of unfolded_by_f:", unfolded_by_f.shape)
    # Check that both results agree.
    return torch.allclose(unfolded_by_index, unfolded_by_f)


if __name__ == "__main__":
    configs = [
        # (batch_size, channels, height, width, kernel_size, stride, dilation)
        (2, 2, 2, 2, 2, 1, 1),
        (5, 3, 4, 5, 2, 2, 2),
    ]
    for config in configs:
        result = test(*config)
        print("Are the results equal?", result)