101 lines
3.2 KiB
Python
101 lines
3.2 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def to_binary_tensor(input_tensor, bits):
    """Expand each element of *input_tensor* into its ``bits``-wide binary digits.

    Values are rounded with ``torch.round``, clamped to ``[0, 2**bits - 1]``
    and converted to ``int64`` before the bits are extracted.

    Args:
        input_tensor: Tensor of values to encode.
        bits: Number of binary digits produced per element.

    Returns:
        An ``int64`` tensor with one extra trailing dimension of size
        ``bits`` holding 0/1 digits, most significant bit first.
    """
    largest_code = 2**bits - 1
    codes = torch.round(input_tensor).clamp(0, largest_code).to(torch.int64)

    # Shift amounts run from bits-1 down to 0, so the MSB lands at index 0.
    bit_offsets = torch.arange(bits - 1, -1, -1, device=codes.device)

    shifted = codes.unsqueeze(-1) >> bit_offsets
    return shifted & 1
|
|
|
|
|
|
def generate_unfold_index(input_shape, kernel_size, stride, dilation=1):
    """Build index tensors that gather every sliding window of a 4-D input.

    Indexing a ``(B, C, H, W)`` tensor ``x`` with the returned tuple
    (``x[index]``) yields a 1-D tensor whose elements are ordered by
    batch, channel, window row, window column, kernel row, kernel column
    (the last varies fastest). Padding is not supported: windows are laid
    out as for a padding of 0.

    Args:
        input_shape: Four integers ``(B, C, H, W)``.
        kernel_size: Int, or ``(kH, kW)`` pair, giving the window size.
        stride: Int, or ``(sH, sW)`` pair, giving the window step.
        dilation: Int, or ``(dH, dW)`` pair, giving the spacing between
            kernel elements. Defaults to 1.

    Returns:
        Tuple ``(batch_idx, channel_idx, h_idx, w_idx)`` of 1-D ``int64``
        tensors, each of length ``B * C * out_h * out_w * kH * kW``. All
        four are empty when the dilated kernel does not fit in the input.
    """
    B, C, H, W = input_shape

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    kH, kW = kernel_size
    sH, sW = stride
    dH, dW = dilation

    # Number of window positions per axis (padding is 0). Clamp at zero so a
    # kernel larger than the input yields empty indices instead of an error.
    out_h = max((H - dH * (kH - 1) - 1) // sH + 1, 0)
    out_w = max((W - dW * (kW - 1) - 1) // sW + 1, 0)

    # Axes of the broadcast grid; flattening it row-major reproduces the
    # order of a nested loop over b, c, i, j, kh, kw.
    grid = (B, C, out_h, out_w, kH, kW)

    def _on_axis(values, axis):
        # Reshape a 1-D tensor so it varies only along ``axis`` of the grid.
        shape = [1] * len(grid)
        shape[axis] = values.numel()
        return values.view(shape)

    batch_grid = _on_axis(torch.arange(B, dtype=torch.long), 0)
    channel_grid = _on_axis(torch.arange(C, dtype=torch.long), 1)

    # Row/column coordinate = window origin + dilated offset inside the kernel.
    h_grid = (
        _on_axis(torch.arange(out_h, dtype=torch.long) * sH, 2)
        + _on_axis(torch.arange(kH, dtype=torch.long) * dH, 4)
    )
    w_grid = (
        _on_axis(torch.arange(out_w, dtype=torch.long) * sW, 3)
        + _on_axis(torch.arange(kW, dtype=torch.long) * dW, 5)
    )

    # expand() is a broadcast view; reshape() materialises the flat index.
    batch_idx = batch_grid.expand(grid).reshape(-1)
    channel_idx = channel_grid.expand(grid).reshape(-1)
    h_idx = h_grid.expand(grid).reshape(-1)
    w_idx = w_grid.expand(grid).reshape(-1)

    return (batch_idx, channel_idx, h_idx, w_idx)
|
|
|
|
|
|
def test(batch_size=2, channels=2, height=2, width=2, kernel_size=2, stride=1, dilation=1):
    """Compare index-based unfolding with ``F.unfold`` on a random input.

    A random ``(batch_size, channels, height, width)`` tensor is gathered
    with the indices from ``generate_unfold_index`` and rearranged into the
    shape returned by ``F.unfold``. Both shapes are printed.

    ``kernel_size`` is multiplied by itself below, so it is expected to be
    an int here.

    Returns:
        The result of ``torch.allclose`` on the two unfolded tensors.
    """
    x = torch.randn(batch_size, channels, height, width)

    gather_index = generate_unfold_index(
        input_shape=x.shape,
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
    )
    gathered = x[gather_index]

    reference = F.unfold(
        x,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=dilation,
    )

    # The flat gather is ordered (batch, channel, window, kernel element);
    # swap the last two axes so kernel elements precede windows, then
    # collapse to the reference shape.
    window_elems = kernel_size * kernel_size
    rearranged = (
        gathered.view(batch_size, channels, -1, window_elems)
        .permute(0, 1, 3, 2)
        .reshape(reference.shape)
    )

    print("Shape of unfolded_by_index:", rearranged.shape)
    print("Shape of unfolded_by_f:", reference.shape)

    # Check that both approaches agree.
    return torch.allclose(rearranged, reference)
|
|
|
|
|
|
if __name__ == "__main__":
    # Each row: batch_size, channels, height, width, kernel_size, stride, dilation.
    cases = (
        (2, 2, 2, 2, 2, 1, 1),
        (5, 3, 4, 5, 2, 2, 2),
    )
    for batch_size, channels, height, width, kernel_size, stride, dilation in cases:
        # Run the comparison for this configuration and report the outcome.
        result = test(batch_size, channels, height, width, kernel_size, stride, dilation)
        print("Are the results equal?", result)
|