# Witllm/binary/unfold.py
# (101 lines, 3.2 KiB, Python)

import torch
import torch.nn.functional as F
def to_binary_tensor(input_tensor, bits):
    """Quantize values to unsigned integers and expand each into its bits.

    Every element is rounded with ``torch.round``, clamped into
    ``[0, 2**bits - 1]``, and cast to int64. Its binary digits are then
    unpacked along a new trailing axis, most significant bit first.

    Args:
        input_tensor: tensor of values to encode.
        bits: number of bits per value.

    Returns:
        int64 tensor of shape ``input_tensor.shape + (bits,)`` holding 0/1.
    """
    max_value = 2**bits - 1
    quantized = torch.round(input_tensor)
    quantized = quantized.clamp(0, max_value)
    quantized = quantized.to(torch.int64)

    # Shift amounts ordered MSB first: bits-1, ..., 1, 0.
    bit_positions = torch.arange(bits - 1, -1, -1, device=quantized.device)

    # Broadcasting (..., 1) against (bits,) yields one shifted copy per bit.
    expanded = quantized.unsqueeze(-1)
    return (expanded >> bit_positions) & 1
def generate_unfold_index(input_shape, kernel_size, stride, dilation=1):
    """Build advanced-indexing tensors that gather sliding windows like ``F.unfold``.

    Padding is not supported (it is fixed at 0), so every produced index lies
    inside the input.

    Args:
        input_shape: 4-element shape ``(B, C, H, W)`` of the tensor to index.
        kernel_size: int or ``(kH, kW)`` window size.
        stride: int or ``(sH, sW)`` step between window origins.
        dilation: int or ``(dH, dW)`` spacing between kernel taps.

    Returns:
        Tuple ``(batch_idx, channel_idx, h_idx, w_idx)`` of 1-D ``torch.long``
        tensors, each of length ``B * C * out_h * out_w * kH * kW``. Entries
        are ordered with batch outermost, then channel, window row, window
        column, kernel row, and kernel column innermost. Consequently
        ``x[index]`` can be viewed as ``(B, C, out_h * out_w, kH * kW)``.
    """
    B, C, H, W = input_shape
    kH, kW = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    sH, sW = (stride, stride) if isinstance(stride, int) else stride
    dH, dW = (dilation, dilation) if isinstance(dilation, int) else dilation

    # Number of window positions per axis (F.unfold's formula with padding=0).
    # Clamped at 0 so a kernel larger than the input yields empty indices
    # instead of making torch.arange fail on a negative bound.
    out_h = max((H - dH * (kH - 1) - 1) // sH + 1, 0)
    out_w = max((W - dW * (kW - 1) - 1) // sW + 1, 0)

    full_shape = (B, C, out_h, out_w, kH, kW)

    def _axis(size, dim):
        """Return arange(size) shaped to broadcast along axis *dim* of full_shape."""
        view_shape = [1] * len(full_shape)
        view_shape[dim] = size
        return torch.arange(size, dtype=torch.long).view(view_shape)

    batch = _axis(B, 0)
    channel = _axis(C, 1)
    # Row/column of each tap: window origin (index * stride) plus tap offset
    # (kernel index * dilation).
    row = _axis(out_h, 2) * sH + _axis(kH, 4) * dH
    col = _axis(out_w, 3) * sW + _axis(kW, 5) * dW

    # Expanding to the full grid and flattening in row-major order reproduces
    # the nested-loop ordering b -> c -> i -> j -> kh -> kw.
    return tuple(t.expand(full_shape).reshape(-1) for t in (batch, channel, row, col))
def test(batch_size=2, channels=2, height=2, width=2, kernel_size=2, stride=1, dilation=1):
    """Check that gathering with generate_unfold_index reproduces F.unfold.

    A random ``(batch_size, channels, height, width)`` tensor is unfolded in
    both ways, and the results are compared with ``torch.allclose``.
    ``kernel_size``, ``stride`` and ``dilation`` may each be an int or an
    ``(h, w)`` pair, as accepted by ``F.unfold``.

    Returns:
        True when the two unfolded tensors match.
    """
    x = torch.randn(batch_size, channels, height, width)
    index = generate_unfold_index(input_shape=x.shape, kernel_size=kernel_size, stride=stride, dilation=dilation)
    unfolded_by_index = x[index]
    unfolded_by_f = F.unfold(x, kernel_size=kernel_size, stride=stride, padding=0, dilation=dilation)

    # Elements per window. The old `kernel_size * kernel_size` breaks when
    # kernel_size is an (h, w) tuple.
    if isinstance(kernel_size, int):
        kernel_numel = kernel_size * kernel_size
    else:
        kernel_numel = kernel_size[0] * kernel_size[1]
    num_windows = unfolded_by_f.shape[-1]

    # The gathered order is (B, C, window, kernel tap). F.unfold lays data out
    # as (B, C * taps, window), so swap the last two axes before flattening.
    unfolded_by_index = unfolded_by_index.view(batch_size, channels, num_windows, kernel_numel)
    unfolded_by_index = unfolded_by_index.permute(0, 1, 3, 2)
    unfolded_by_index = unfolded_by_index.reshape(unfolded_by_f.shape)
    print("Shape of unfolded_by_index:", unfolded_by_index.shape)
    print("Shape of unfolded_by_f:", unfolded_by_f.shape)
    # Check whether the two results agree
    return torch.allclose(unfolded_by_index, unfolded_by_f)
if __name__ == "__main__":
    # Each entry: (batch_size, channels, height, width, kernel_size, stride, dilation)
    configurations = [
        (2, 2, 2, 2, 2, 1, 1),
        (5, 3, 4, 5, 2, 2, 2),
    ]
    for config in configurations:
        # Compare index-based unfolding against F.unfold for this configuration.
        result = test(*config)
        print("Are the results equal?", result)