Fix 1 and 2 dimensional cases. Add test for 1 and 2 dimensional combinations.
This commit is contained in:
parent
642f77abed
commit
da0e9b01b1
|
@ -389,8 +389,11 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
dims.emplace_back(rhsShape[rightDims - 1]);
|
dims.emplace_back(rhsShape[rightDims - 1]);
|
||||||
} else {
|
} else {
|
||||||
// This case covers all remaining combinations of 1 and 2-D matrices.
|
// This case covers all remaining combinations of 1 and 2-D matrices.
|
||||||
dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
|
if (lhsShape.size() != 1)
|
||||||
dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
|
dims.emplace_back(lhsShape[0]);
|
||||||
|
|
||||||
|
if (rhsShape.size() != 1)
|
||||||
|
dims.emplace_back(rhsShape[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||||
|
|
|
@ -60,4 +60,34 @@ func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32x
|
||||||
// CHECK-LABEL: test_matmul_5
|
// CHECK-LABEL: test_matmul_5
|
||||||
// CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x?x32xf32>
|
// CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x?x32xf32>
|
||||||
// CHECK: return [[RES5]] : tensor<32x16x64x?x32xf32>
|
// CHECK: return [[RES5]] : tensor<32x16x64x?x32xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MatMul: 1-D x 2-D
|
||||||
|
func @test_matmul_6(%arg0 : tensor<32xf32>, %arg1 : tensor<32x64xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_matmul_6
|
||||||
|
// CHECK: [[RES6:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<64xf32>
|
||||||
|
// CHECK: return [[RES6]] : tensor<64xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MatMul: 2-D x 1-D
|
||||||
|
func @test_matmul_7(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_matmul_7
|
||||||
|
// CHECK: [[RES7:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<32xf32>
|
||||||
|
// CHECK: return [[RES7]] : tensor<32xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MatMul: 2-D x 2-D
|
||||||
|
func @test_matmul_8(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64x128xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_matmul_8
|
||||||
|
// CHECK: [[RES8:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<32x128xf32>
|
||||||
|
// CHECK: return [[RES8]] : tensor<32x128xf32>
|
||||||
}
|
}
|
Loading…
Reference in New Issue