Merge pull request #18 from clang-ykt/matmul-shape
Infer shape for MatMul operation
Commit 31116ec3c2
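For orientation before reading the diff: the change teaches ONNXMatMulOp::inferShapes() the full numpy-style MatMul shape rule (1-D, 2-D, and batched K-D operands) instead of only the plain 2-D case. The sketch below restates that rule as a standalone helper. It is not code from this PR: the names (Shape, compatible, inferMatMulShape) are hypothetical, -1 stands for a dynamic dimension as in the diff, error reporting is simplified to returning std::nullopt instead of emitError, and the K-D x K-D batch-broadcasting branch is deferred to a second sketch after the code hunk.

#include <cstdint>
#include <optional>
#include <vector>

// Illustrative stand-in types/names only; -1 marks a dynamic dimension.
using Shape = std::vector<int64_t>;

// Two extents are compatible if either is dynamic or both agree.
static bool compatible(int64_t a, int64_t b) {
  return a == -1 || b == -1 || a == b;
}

// Returns the inferred result shape, or std::nullopt for illegal operands.
std::optional<Shape> inferMatMulShape(const Shape &lhs, const Shape &rhs) {
  if (lhs.empty() || rhs.empty())
    return std::nullopt; // scalar operands are rejected
  if (lhs.size() == 1 && rhs.size() == 1) {
    // 1-D x 1-D: conceptually 1xN times Nx1; the op keeps a single unit dim.
    if (!compatible(lhs[0], rhs[0]))
      return std::nullopt;
    return Shape{1};
  }
  if (lhs.size() > 2 && rhs.size() == 2) {
    // (s1 x ... x sK x M x N) MATMUL (N x P) => (s1 x ... x sK x M x P)
    if (!compatible(lhs.back(), rhs[0]))
      return std::nullopt;
    Shape out(lhs.begin(), lhs.end() - 1);
    out.push_back(rhs[1]);
    return out;
  }
  if (lhs.size() == 2 && rhs.size() > 2) {
    // (M x N) MATMUL (s1 x ... x sK x N x P) => (s1 x ... x sK x M x P)
    if (!compatible(lhs[1], rhs[rhs.size() - 2]))
      return std::nullopt;
    Shape out(rhs.begin(), rhs.end() - 2);
    out.push_back(lhs[0]);
    out.push_back(rhs.back());
    return out;
  }
  if (lhs.size() > 2 && rhs.size() > 2)
    return std::nullopt; // batch-broadcasting case, sketched separately below
  // Remaining 1-D/2-D combinations: drop the contracted dimension.
  // (As in the diff's TODO, mixed 1-D x K-D pairings are not fully handled.)
  int64_t lhsDim = lhs.size() > 1 ? lhs[1] : lhs[0];
  int64_t rhsDim = rhs[0];
  if (!compatible(lhsDim, rhsDim))
    return std::nullopt;
  Shape out;
  if (lhs.size() > 1)
    out.push_back(lhs[0]);
  if (rhs.size() > 1)
    out.push_back(rhs[1]);
  return out;
}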
@@ -343,17 +343,98 @@ void ONNXMatMulOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>())
     return;
 
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
 
   SmallVector<int64_t, 2> dims;
-  dims.emplace_back(lhsTy.getShape()[0]);
-  dims.emplace_back(rhsTy.getShape()[1]);
-  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
+  auto lhsShape = lhsTy.getShape();
+  auto rhsShape = rhsTy.getShape();
+
+  if (lhsShape.size() < 1 && rhsShape.size() < 1) {
+    // Multiplication by scalars is not allowed.
+    emitError("Multiplication by scalar arguments not allowed.");
+  } else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
+    // Special case when both arrays are 1-dimensional and according to
+    // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
+    // need to be removed after the multiplication but cannot be removed if all
+    // sizes are 1.
+    if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
+        lhsShape[0] != rhsShape[0])
+      emitError("Attempt to multiply incompatible matrices.");
+    dims.emplace_back(1);
+  } else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
+    // (s1 x s2 x ... x sK x M x N) MATMUL (N x P)
+    // =>
+    // (s1 x s2 x ... x sK x M x P)
+
+    // Check legality of matrix multiplication.
+    unsigned lhsRank = lhsShape.size();
+    if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
+        lhsShape[lhsRank - 1] != rhsShape[0])
+      emitError("Attempt to multiply incompatible matrices.");
+
+    for (int i = 0; i < lhsRank - 1; ++i)
+      dims.emplace_back(lhsShape[i]);
+    dims.emplace_back(rhsShape[1]);
+  } else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
+    // (M x N) MATMUL (s1 x s2 x ... x sK x N x P)
+    // =>
+    // (s1 x s2 x ... x sK x M x P)
+
+    // Check legality of matrix multiplication.
+    unsigned rhsRank = rhsShape.size();
+    if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
+        lhsShape[1] != rhsShape[rhsRank - 2])
+      emitError("Attempt to multiply incompatible matrices.");
+
+    for (int i = 0; i < rhsRank - 2; ++i)
+      dims.emplace_back(rhsShape[i]);
+    dims.emplace_back(lhsShape[0]);
+    dims.emplace_back(rhsShape[rhsRank - 1]);
+  } else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
+    // (s1 x s2 x ... x sK x M x N) MATMUL (t1 x t2 x ... x tK x N x P)
+    // =>
+    // (u1 x u2 x ... x uK x M x P)
+
+    // Check legality of matrix multiplication.
+    unsigned lhsRank = lhsShape.size();
+    unsigned rhsRank = rhsShape.size();
+    if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
+        lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
+      emitError("Attempt to multiply incompatible matrices.");
+
+    // Check and perform broadcasting for the shapes.
+    SmallVector<int64_t, 2> lhsBcastShape;
+    for (int i = 0; i < lhsRank - 2; ++i)
+      lhsBcastShape.emplace_back(lhsShape[i]);
+    SmallVector<int64_t, 2> rhsBcastShape;
+    for (int i = 0; i < rhsRank - 2; ++i)
+      rhsBcastShape.emplace_back(rhsShape[i]);
+    if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
+      emitError("Broadcasted dimensions are incompatible.");
+
+    dims.emplace_back(lhsShape[lhsRank - 2]);
+    dims.emplace_back(rhsShape[rhsRank - 1]);
+  } else {
+    // This case covers all remaining combinations of 1 and 2-D matrices.
+    int64_t lhsDim = lhsShape[0];
+    int64_t rhsDim = rhsShape[0];
+    if (lhsShape.size() > 1) {
+      lhsDim = lhsShape[1];
+      dims.emplace_back(lhsShape[0]);
+    }
+
+    // TODO:
+    // Verify that matrix sizes are valid.
+    // Take into account the dimensionality of the matrix.
+    // Check legality of matrix multiplication.
+    if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
+      emitError("Attempt to multiply incompatible matrices.");
+
+    if (rhsShape.size() > 1)
+      dims.emplace_back(rhsShape[1]);
+  }
+
+  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
 
 //===----------------------------------------------------------------------===//
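In the K-D x K-D branch above, the leading (batch) dimensions are matched with MLIR's getBroadcastedShape helper. The minimal sketch below shows right-aligned, numpy-style broadcasting for orientation only; it is an assumed stand-in, and in particular its treatment of dynamic (-1) extents is simplified and may not match the MLIR helper exactly.

#include <algorithm>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Right-aligned broadcast of two batch-dimension lists. Returns false when
// two static extents disagree and neither is 1. Dynamic extents (-1) are
// resolved optimistically to the other operand's extent.
static bool broadcastBatchDims(const Shape &a, const Shape &b, Shape &out) {
  size_t rank = std::max(a.size(), b.size());
  out.assign(rank, 1);
  for (size_t i = 0; i < rank; ++i) {
    // Read dimension i counted from the left of the padded shapes;
    // missing leading dimensions act as 1.
    int64_t da = i < rank - a.size() ? 1 : a[i - (rank - a.size())];
    int64_t db = i < rank - b.size() ? 1 : b[i - (rank - b.size())];
    if (da == -1 && db == -1)
      out[i] = -1;                 // both unknown: stay dynamic
    else if (da == -1 || db == -1)
      out[i] = std::max(da, db);   // keep the known extent
    else if (da == db || da == 1 || db == 1)
      out[i] = std::max(da, db);
    else
      return false;                // incompatible static extents
  }
  return true;
}

Applied to the batch dimensions of test_matmul_5 below ([16, ?] against [32, ?, 64]), this yields [32, 16, 64], consistent with the 32x16x64x?x32 result type that the FileCheck lines expect.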
@@ -8,22 +8,114 @@
func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
  %0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()
}

// CHECK-LABEL: test_default_transpose
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
}

/// Test shape inference for transposition when perm attribute is specified.

func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
  %0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()
}

// CHECK-LABEL: test_transpose
// CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32>
}

//===----------------------------------------------------------------------===//
/// Test the shape inferencing scheme for the matmul operation.
//===----------------------------------------------------------------------===//

/// MatMul: 1-D x 1-D

func @test_matmul_1(%arg0 : tensor<32xf32>, %arg1 : tensor<32xf32>) -> tensor<*xf32> {
  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_matmul_1
  // CHECK: [[RES1:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<1xf32>
  // CHECK: return [[RES1]] : tensor<1xf32>
}

/// MatMul: K-D x 2-D (K > 2)

func @test_matmul_2(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<42x32xf32>) -> tensor<*xf32> {
  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_matmul_2
  // CHECK: [[RES2:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<16x?x64x32xf32>
  // CHECK: return [[RES2]] : tensor<16x?x64x32xf32>
}

/// MatMul: 2-D x K-D (K > 2)

func @test_matmul_3(%arg0 : tensor<64x42xf32>, %arg1 : tensor<16x?x42x32xf32>) -> tensor<*xf32> {
  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<16x?x42x32xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_matmul_3
  // CHECK: [[RES3:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<16x?x42x32xf32>) -> tensor<16x?x64x32xf32>
  // CHECK: return [[RES3]] : tensor<16x?x64x32xf32>
}

/// MatMul: 2-D x K-D (K > 2)

func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor<?x?x?x?xf32>) -> tensor<*xf32> {
  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<?x?x?x?xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_matmul_4
  // CHECK: [[RES4:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x64x?xf32>
  // CHECK: return [[RES4]] : tensor<?x?x64x?xf32>
}

/// MatMul: K1-D x K2-D (K1 > 2, K2 > 2)

func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32xf32>) -> tensor<*xf32> {
  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_matmul_5
  // CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x?x32xf32>
  // CHECK: return [[RES5]] : tensor<32x16x64x?x32xf32>
}

/// MatMul: 1-D x 2-D

func @test_matmul_6(%arg0 : tensor<32xf32>, %arg1 : tensor<32x64xf32>) -> tensor<*xf32> {
  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_matmul_6
  // CHECK: [[RES6:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<64xf32>
  // CHECK: return [[RES6]] : tensor<64xf32>
}

/// MatMul: 2-D x 1-D

func @test_matmul_7(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64xf32>) -> tensor<*xf32> {
  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_matmul_7
  // CHECK: [[RES7:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<32xf32>
  // CHECK: return [[RES7]] : tensor<32xf32>
}

/// MatMul: 2-D x 2-D

func @test_matmul_8(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64x128xf32>) -> tensor<*xf32> {
  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_matmul_8
  // CHECK: [[RES8:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<32x128xf32>
  // CHECK: return [[RES8]] : tensor<32x128xf32>
}
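As a hand cross-check of the expected result types in these tests, the short standalone program below replays test_matmul_2 using the -1 sentinel that the C++ hunk above uses for `?` dimensions. It is an editorial illustration only, uses no MLIR APIs, and is not part of the PR.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // tensor<16x?x64x42xf32> and tensor<42x32xf32> from test_matmul_2,
  // written as shape vectors with -1 for the '?' (dynamic) extent.
  std::vector<int64_t> lhs = {16, -1, 64, 42};
  std::vector<int64_t> rhs = {42, 32};

  // The legality check in inferShapes() only fires when both contracted
  // extents are static and differ.
  bool bad = lhs.back() != -1 && rhs.front() != -1 && lhs.back() != rhs.front();
  std::cout << (bad ? "incompatible" : "compatible") << "\n"; // "compatible"

  // Result: copy all but the last lhs dimension, then append rhs[1].
  std::vector<int64_t> dims(lhs.begin(), lhs.end() - 1);
  dims.push_back(rhs[1]);
  for (int64_t d : dims) // prints 16 -1 64 32, i.e. tensor<16x?x64x32xf32>
    std::cout << d << " ";
  std::cout << "\n";
  return 0;
}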

//===----------------------------------------------------------------------===//
/// Test shape inference for ConvNoBias operation and all its attributes.