Fix 1 and 2 dimensional cases. Add test for 1 and 2 dimensional combinations.

This commit is contained in:
Doru Bercea 2020-01-10 15:16:45 -05:00
parent 642f77abed
commit da0e9b01b1
2 changed files with 35 additions and 2 deletions

View File

@ -389,8 +389,11 @@ void ONNXMatMulOp::inferShapes() {
dims.emplace_back(rhsShape[rightDims - 1]);
} else {
// This case covers all remaining combinations of 1 and 2-D matrices.
dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
if (lhsShape.size() != 1)
dims.emplace_back(lhsShape[0]);
if (rhsShape.size() != 1)
dims.emplace_back(rhsShape[1]);
}
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));

View File

@ -60,4 +60,34 @@ func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32x
// CHECK-LABEL: test_matmul_5
// CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x?x32xf32>
// CHECK: return [[RES5]] : tensor<32x16x64x?x32xf32>
}
/// MatMul: 1-D x 2-D
func @test_matmul_6(%arg0 : tensor<32xf32>, %arg1 : tensor<32x64xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_matmul_6
// CHECK: [[RES6:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<64xf32>
// CHECK: return [[RES6]] : tensor<64xf32>
}
/// MatMul: 2-D x 1-D
func @test_matmul_7(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_matmul_7
// CHECK: [[RES7:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<32xf32>
// CHECK: return [[RES7]] : tensor<32xf32>
}
/// MatMul: 2-D x 2-D
func @test_matmul_8(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64x128xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_matmul_8
// CHECK: [[RES8:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<32x128xf32>
// CHECK: return [[RES8]] : tensor<32x128xf32>
}