Handle 1-D MATMUL N-D (#56)
This commit is contained in:
parent
195bf9d15d
commit
f3047943a1
|
@ -371,6 +371,42 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
lhsShape[0] != rhsShape[0])
|
lhsShape[0] != rhsShape[0])
|
||||||
emitError("Attempt to multiply incompatible matrices.");
|
emitError("Attempt to multiply incompatible matrices.");
|
||||||
dims.emplace_back(1);
|
dims.emplace_back(1);
|
||||||
|
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
|
||||||
|
// If the first argument is 1-D, it is promoted to a matrix by prepending a
|
||||||
|
// 1 to its dimensions. After matrix multiplication the prepended 1 is
|
||||||
|
// removed.
|
||||||
|
//
|
||||||
|
// N MATMUL (s1 x s2 x... x sK x N x P)
|
||||||
|
// =>
|
||||||
|
// (s1 x s2 x... x sK x P)
|
||||||
|
|
||||||
|
// Check legality of matrix multiplication.
|
||||||
|
unsigned rhsRank = rhsShape.size();
|
||||||
|
if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
|
||||||
|
lhsShape[0] != rhsShape[rhsRank - 2])
|
||||||
|
emitError("Attempt to multiply incompatible matrices.");
|
||||||
|
|
||||||
|
for (int i = 0; i < rhsRank - 2; ++i)
|
||||||
|
dims.emplace_back(rhsShape[i]);
|
||||||
|
dims.emplace_back(rhsShape[rhsRank - 1]);
|
||||||
|
} else if (lhsShape.size() >= 2 && rhsShape.size() == 1) {
|
||||||
|
// If the second argument is 1-D, it is promoted to a matrix by appending a
|
||||||
|
// 1 to its dimensions. After matrix multiplication the appended 1 is
|
||||||
|
// removed.
|
||||||
|
//
|
||||||
|
// (s1 x s2 x... x sK x M x N) MATMUL N
|
||||||
|
// =>
|
||||||
|
// (s1 x s2 x... x sK x M)
|
||||||
|
|
||||||
|
// Check legality of matrix multiplication.
|
||||||
|
unsigned lhsRank = lhsShape.size();
|
||||||
|
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
|
||||||
|
lhsShape[lhsRank - 1] != rhsShape[0])
|
||||||
|
emitError("Attempt to multiply incompatible matrices.");
|
||||||
|
|
||||||
|
for (int i = 0; i < lhsRank - 2; ++i)
|
||||||
|
dims.emplace_back(lhsShape[i]);
|
||||||
|
dims.emplace_back(lhsShape[lhsRank - 2]);
|
||||||
} else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
|
} else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
|
||||||
// (s1 x s2 x... x sK x M x N) MATMUL (N x P)
|
// (s1 x s2 x... x sK x M x N) MATMUL (N x P)
|
||||||
// =>
|
// =>
|
||||||
|
|
|
@ -117,6 +117,28 @@ func @test_matmul_8(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64x128xf32>) -> te
|
||||||
// CHECK: return [[RES8]] : tensor<32x128xf32>
|
// CHECK: return [[RES8]] : tensor<32x128xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// MatMul: 1-D x N-D
|
||||||
|
|
||||||
|
func @test_matmul_9(%arg0 : tensor<42xf32>, %arg1 : tensor<?x42x32xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<42xf32>, tensor<?x42x32xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_matmul_9
|
||||||
|
// CHECK: [[RES1:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<42xf32>, tensor<?x42x32xf32>) -> tensor<?x32xf32>
|
||||||
|
// CHECK: return [[RES1]] : tensor<?x32xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MatMul: N-D x 1-D
|
||||||
|
|
||||||
|
func @test_matmul_10(%arg0 : tensor<?x42x32xf32>, %arg1 : tensor<32xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<?x42x32xf32>, tensor<32xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_matmul_10
|
||||||
|
// CHECK: [[RES1:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<?x42x32xf32>, tensor<32xf32>) -> tensor<?x42xf32>
|
||||||
|
// CHECK: return [[RES1]] : tensor<?x42xf32>
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
/// Test shape inference for ConvNoBias operation and all its attributes.
|
/// Test shape inference for ConvNoBias operation and all its attributes.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue