Add shape inference for Ops used by BERT (#249)
* Add shape inference for Ops used by BERT * Erf * Pow * ReduceMean * Dropout * Expand https://github.com/onnx/onnx/blob/master/docs/Operators.md#expand Deduce the value of the shape operand by looking at the producer of the operand. Currently supported producers are: onnx.Constant and onnx.Shape. * Add corresponding tests for each op. * Sort the list of ops with shape inference in gen_onnx_mlir.py in alphabetic order for clarity. * Restart CI Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
20686d8f0f
commit
e1386b0689
|
@ -21,8 +21,11 @@ Update as you push code to the master branch.
|
||||||
| Cos | | v | v | v | | |
|
| Cos | | v | v | v | | |
|
||||||
| Cosh | | v | v | v | | |
|
| Cosh | | v | v | v | | |
|
||||||
| Div | | v | v | v | | M |
|
| Div | | v | v | v | | M |
|
||||||
|
| Dropout | | v | v | | | |
|
||||||
| Elu | | v | v | v | | |
|
| Elu | | v | v | v | | |
|
||||||
|
| Erf | | v | v | | | |
|
||||||
| Exp | | v | v | v | | |
|
| Exp | | v | v | v | | |
|
||||||
|
| Expand | | v | v | | | |
|
||||||
| Gemm | | v | v | v | | U |
|
| Gemm | | v | v | v | | U |
|
||||||
| HardSigmoid | | v | v | v | | |
|
| HardSigmoid | | v | v | v | | |
|
||||||
| Identity | | v | v | v | | |
|
| Identity | | v | v | v | | |
|
||||||
|
@ -35,12 +38,14 @@ Update as you push code to the master branch.
|
||||||
| Mul | | v | v | v | | M |
|
| Mul | | v | v | v | | M |
|
||||||
| Or | | v | v | v | | M |
|
| Or | | v | v | v | | M |
|
||||||
| Pad | | v | V | v | | const only |
|
| Pad | | v | V | v | | const only |
|
||||||
|
| Pow | | v | v | | | M |
|
||||||
| Reciprocal | | v | v | v | | |
|
| Reciprocal | | v | v | v | | |
|
||||||
| ReduceMax | | v | v | v | | |
|
| ReduceMax | | v | v | v | | |
|
||||||
| ReduceL1 | | v | | | | be decomposed into ReduceSum and Abs |
|
| ReduceL1 | | v | | | | be decomposed into ReduceSum and Abs |
|
||||||
| ReduceL2 | | v | | | | be decomposed into ReduceSumSquare and Sqrt |
|
| ReduceL2 | | v | | | | be decomposed into ReduceSumSquare and Sqrt |
|
||||||
| ReduceLogSum | | v | | | | be decomposed into ReduceSum and Log |
|
| ReduceLogSum | | v | | | | be decomposed into ReduceSum and Log |
|
||||||
| ReduceLogSumExp| | v | | | | be decomposed into ReduceLogSum and Exp |
|
| ReduceLogSumExp| | v | | | | be decomposed into ReduceLogSum and Exp |
|
||||||
|
| ReduceMean | | v | v | | | |
|
||||||
| ReduceMin | | v | v | v | | |
|
| ReduceMin | | v | v | v | | |
|
||||||
| ReduceProd | | v | v | v | | |
|
| ReduceProd | | v | v | v | | |
|
||||||
| ReduceSum | | v | v | v | | |
|
| ReduceSum | | v | v | v | | |
|
||||||
|
@ -104,12 +109,9 @@ And add literal tests at each step, and end to end tests once completed.
|
||||||
| DepthToSpace | | | | | | |
|
| DepthToSpace | | | | | | |
|
||||||
| DequantizeLin | | | | | | |
|
| DequantizeLin | | | | | | |
|
||||||
| Det | | | | | | |
|
| Det | | | | | | |
|
||||||
| Dropout | | | | | | |
|
|
||||||
| DynQuantizeLin | | | | | | |
|
| DynQuantizeLin | | | | | | |
|
||||||
| Einsum | | | | | | V |
|
| Einsum | | | | | | V |
|
||||||
| Equal | | | | | | M |
|
| Equal | | | | | | M |
|
||||||
| Erf | | | | | | |
|
|
||||||
| Expand | | | | | | |
|
|
||||||
| EyeLike | | | | | | |
|
| EyeLike | | | | | | |
|
||||||
| Flatten | | | | | | |
|
| Flatten | | | | | | |
|
||||||
| Floor | | | | | | |
|
| Floor | | | | | | |
|
||||||
|
@ -151,7 +153,6 @@ And add literal tests at each step, and end to end tests once completed.
|
||||||
| Not | | | | | | |
|
| Not | | | | | | |
|
||||||
| OneHot | | | | | | |
|
| OneHot | | | | | | |
|
||||||
| PRelu | | | | | | U |
|
| PRelu | | | | | | U |
|
||||||
| Power | | | | | | M |
|
|
||||||
| QLinearConv | | | | | | P |
|
| QLinearConv | | | | | | P |
|
||||||
| QLinearMatMul | | | | | | M |
|
| QLinearMatMul | | | | | | M |
|
||||||
| QuantizeLinear | | | | | | |
|
| QuantizeLinear | | | | | | |
|
||||||
|
@ -161,7 +162,6 @@ And add literal tests at each step, and end to end tests once completed.
|
||||||
| RandUniform | | | | | | |
|
| RandUniform | | | | | | |
|
||||||
| RandUniformLike| | | | | | |
|
| RandUniformLike| | | | | | |
|
||||||
| Range | | | | | | |
|
| Range | | | | | | |
|
||||||
| ReduceMean | | | | | | |
|
|
||||||
| Resize | | | | | | |
|
| Resize | | | | | | |
|
||||||
| ReverseSequence| | | | | | |
|
| ReverseSequence| | | | | | |
|
||||||
| RoiAlign | | | | | | |
|
| RoiAlign | | | | | | |
|
||||||
|
|
|
@ -769,6 +769,29 @@ LogicalResult ONNXAbsOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Erf
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult ONNXErfOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pow
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult ONNXPowOp::inferShapes() {
|
||||||
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||||
|
!getOperand(1).getType().isa<RankedTensorType>())
|
||||||
|
return emitError("Input tensor(s) not ranked");
|
||||||
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||||
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Add
|
// Add
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1294,6 +1317,19 @@ LogicalResult ONNXReduceMaxOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReduceMean
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult ONNXReduceMeanOp::inferShapes() {
|
||||||
|
if (!getOperand().getType().isa<RankedTensorType>())
|
||||||
|
return emitError("Input tensor not ranked");
|
||||||
|
|
||||||
|
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
||||||
|
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ReduceMin
|
// ReduceMin
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -2655,6 +2691,90 @@ LogicalResult ONNXSliceOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Expand
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult ONNXExpandOp::inferShapes() {
|
||||||
|
if (!input().getType().isa<RankedTensorType>())
|
||||||
|
return emitError("Input tensor not ranked");
|
||||||
|
|
||||||
|
auto lhsTy = input().getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
|
auto elementType = lhsTy.getElementType();
|
||||||
|
auto lhsShape = lhsTy.getShape();
|
||||||
|
SmallVector<int64_t, 2> rhsShape;
|
||||||
|
|
||||||
|
Operation *shapeDef = shape().getDefiningOp();
|
||||||
|
|
||||||
|
if (mlir::ONNXShapeOp shapeOp =
|
||||||
|
dyn_cast_or_null<mlir::ONNXShapeOp>(shapeDef)) {
|
||||||
|
// If the shape operand is produced by a onnx.Shape operation, infer its
|
||||||
|
// shape and use it as the requested shape.
|
||||||
|
if (!shapeOp.data().getType().isa<RankedTensorType>())
|
||||||
|
return emitError("Input tensor not ranked");
|
||||||
|
|
||||||
|
ArrayRef<int64_t> rhsShapeRef =
|
||||||
|
shapeOp.data().getType().cast<RankedTensorType>().getShape();
|
||||||
|
rhsShape.assign(rhsShapeRef.begin(), rhsShapeRef.end());
|
||||||
|
|
||||||
|
} else if (mlir::ONNXConstantOp constantOp =
|
||||||
|
dyn_cast_or_null<mlir::ONNXConstantOp>(shapeDef)) {
|
||||||
|
// If the shape operand is produced by a onnx.Constant operation, extract
|
||||||
|
// the actual value of the constant and use it as the reqested shape.
|
||||||
|
|
||||||
|
auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
|
if (shapeTensorTy.getRank() != 1)
|
||||||
|
return emitError("Shape tensor must have rank one");
|
||||||
|
|
||||||
|
DenseElementsAttr valueAttribute =
|
||||||
|
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
|
||||||
|
if (!valueAttribute)
|
||||||
|
return emitError("DenseElementsAttr expected");
|
||||||
|
|
||||||
|
int64_t shapeRank = shapeTensorTy.getShape()[0];
|
||||||
|
rhsShape.resize(shapeRank);
|
||||||
|
|
||||||
|
auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
|
||||||
|
for (int i = 0; i != shapeRank; ++i)
|
||||||
|
rhsShape[i] = (*valueIt++).cast<IntegerAttr>().getInt();
|
||||||
|
|
||||||
|
assert(valueIt == valueAttribute.getValues<IntegerAttr>().end() &&
|
||||||
|
"Shape of constant does not match its actual value");
|
||||||
|
} else {
|
||||||
|
return emitError(
|
||||||
|
"Shape argument of Expand is the output of an unexpected operation: " +
|
||||||
|
shapeDef->getName().getStringRef() +
|
||||||
|
". Supported operations are: onnx.Constant and onnx.Shape");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t, 2> resultShape;
|
||||||
|
if (!getBroadcastedShape(lhsShape, rhsShape, resultShape)) {
|
||||||
|
return emitError("Tensor not exapandable");
|
||||||
|
}
|
||||||
|
|
||||||
|
getResult().setType(RankedTensorType::get(resultShape, elementType));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dropout
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult ONNXDropoutOp::inferShapes() {
|
||||||
|
if (!data().getType().isa<RankedTensorType>())
|
||||||
|
return emitError("Input tensor not ranked");
|
||||||
|
|
||||||
|
getResult(0).setType(data().getType());
|
||||||
|
|
||||||
|
auto inputShape = data().getType().cast<RankedTensorType>().getShape();
|
||||||
|
|
||||||
|
IntegerType i1Type = IntegerType::get(1, IntegerType::Signless, getContext());
|
||||||
|
getResult(1).setType(RankedTensorType::get(inputShape, i1Type));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ONNX type related code
|
// ONNX type related code
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1045,7 +1045,7 @@ def ONNXDivOp:ONNX_Op<"Div",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXDropoutOp:ONNX_Op<"Dropout",
|
def ONNXDropoutOp:ONNX_Op<"Dropout",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Dropout operation";
|
let summary = "ONNX Dropout operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Dropout takes one input floating tensor and produces two tensor outputs,"
|
"Dropout takes one input floating tensor and produces two tensor outputs,"
|
||||||
|
@ -1193,7 +1193,7 @@ def ONNXEqualOp:ONNX_Op<"Equal",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXErfOp:ONNX_Op<"Erf",
|
def ONNXErfOp:ONNX_Op<"Erf",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Erf operation";
|
let summary = "ONNX Erf operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Computes the error function of the given input tensor element-wise."
|
"Computes the error function of the given input tensor element-wise."
|
||||||
|
@ -1247,7 +1247,7 @@ def ONNXExpOp:ONNX_Op<"Exp",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXExpandOp:ONNX_Op<"Expand",
|
def ONNXExpandOp:ONNX_Op<"Expand",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Expand operation";
|
let summary = "ONNX Expand operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Broadcast the input tensor following the given shape and the broadcast rule."
|
"Broadcast the input tensor following the given shape and the broadcast rule."
|
||||||
|
@ -3222,7 +3222,7 @@ def ONNXPadOp:ONNX_Op<"Pad",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXPowOp:ONNX_Op<"Pow",
|
def ONNXPowOp:ONNX_Op<"Pow",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Pow operation";
|
let summary = "ONNX Pow operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Pow takes input data (Tensor<T>) and exponent Tensor, and"
|
"Pow takes input data (Tensor<T>) and exponent Tensor, and"
|
||||||
|
@ -3798,7 +3798,7 @@ def ONNXReduceMaxOp:ONNX_Op<"ReduceMax",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXReduceMeanOp:ONNX_Op<"ReduceMean",
|
def ONNXReduceMeanOp:ONNX_Op<"ReduceMean",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX ReduceMean operation";
|
let summary = "ONNX ReduceMean operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Computes the mean of the input tensor's element along the provided axes. The resulted"
|
"Computes the mean of the input tensor's element along the provided axes. The resulted"
|
||||||
|
|
|
@ -1356,6 +1356,8 @@ func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> tensor<
|
||||||
// CHECK: return [[RES]] : tensor<1x2xf32>
|
// CHECK: return [[RES]] : tensor<1x2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
/// Test the shape inferencing for the scaler operation.
|
/// Test the shape inferencing for the scaler operation.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1367,3 +1369,114 @@ func @test_scaler_no_scale_int(%arg0: tensor<3xi32>) -> tensor<*xf32> {
|
||||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
|
||||||
// CHECK: return [[RES_ATTR]] : tensor<3xf32>
|
// CHECK: return [[RES_ATTR]] : tensor<3xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test shape inference for Pow.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @test_pow(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Pow"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<f32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_pow
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Pow"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<f32>) -> tensor<1x2x3x4xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<1x2x3x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test shape inference for Erf.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @test_erf(%arg0: tensor<1x2x3x4xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Erf"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_erf
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Erf"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<1x2x3x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test shape inference for Expand.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @test_expand_with_constant(%arg0 : tensor<2x1x6x1xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[7, 1, 5]> : tensor<3xi64> } : () -> tensor<3xi64>
|
||||||
|
%1 = "onnx.Expand"(%arg0, %0) : (tensor<2x1x6x1xf32>, tensor<3xi64>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_expand_with_constant
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Expand"(%arg0, %0) : (tensor<2x1x6x1xf32>, tensor<3xi64>) -> tensor<2x7x6x5xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<2x7x6x5xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @test_expand_with_shape(%arg0 : tensor<2x1x6x1xf32>, %arg1: tensor<6x2xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Shape"(%arg1) : (tensor<6x2xf32>) -> tensor<*xi64>
|
||||||
|
%1 = "onnx.Expand"(%arg0, %0) : (tensor<2x1x6x1xf32>, tensor<*xi64>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_expand_with_shape
|
||||||
|
// CHECK: [[SHAPE:%.+]] = "onnx.Shape"(%arg1) : (tensor<6x2xf32>) -> tensor<2xi64>
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Expand"(%arg0, [[SHAPE]]) : (tensor<2x1x6x1xf32>, tensor<2xi64>) -> tensor<2x1x6x2xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<2x1x6x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test shape inference for ReduceMean.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @test_reduce_mean_1(%arg0: tensor<1x2x3x4xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.ReduceMean"(%arg0) {axes = [-1], keepdims = 1 : i64} : (tensor<1x2x3x4xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_reduce_mean_1
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.ReduceMean"(%arg0) {axes = [-1], keepdims = 1 : i64} : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x1xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<1x2x3x1xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @test_reduce_mean_2(%arg0: tensor<1x2x3x4xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.ReduceMean"(%arg0) {axes = [2], keepdims = 1 : i64} : (tensor<1x2x3x4xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_reduce_mean_2
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.ReduceMean"(%arg0) {axes = [2], keepdims = 1 : i64} : (tensor<1x2x3x4xf32>) -> tensor<1x2x1x4xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<1x2x1x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @test_reduce_mean_3(%arg0: tensor<1x2x3x4xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.ReduceMean"(%arg0) {axes = [-1], keepdims = 0 : i64} : (tensor<1x2x3x4xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_reduce_mean_3
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.ReduceMean"(%arg0) {axes = [-1], keepdims = 0 : i64} : (tensor<1x2x3x4xf32>) -> tensor<1x2x3xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<1x2x3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test shape inference for Dropout.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @test_dropout(%arg0: tensor<1x2x3x4xf32>) -> (tensor<*xf32>, tensor<*xi1>) {
|
||||||
|
%output, %mask = "onnx.Dropout"(%arg0) {ratio = 1.000000e-01 : f32} : (tensor<1x2x3x4xf32>) -> (tensor<*xf32>, tensor<*xi1>)
|
||||||
|
"std.return"(%output, %mask) : (tensor<*xf32>, tensor<*xi1>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_dropout
|
||||||
|
// CHECK: [[RES:%.+]], [[MASK:%.+]] = "onnx.Dropout"(%arg0) {ratio = 1.000000e-01 : f32} : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>)
|
||||||
|
// CHECK: return [[RES]], [[MASK]] : tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>
|
||||||
|
}
|
||||||
|
|
|
@ -244,15 +244,77 @@ special_op_handler = dict([
|
||||||
|
|
||||||
# Operations supporting shape inference.
|
# Operations supporting shape inference.
|
||||||
OpsWithShapeInference=[
|
OpsWithShapeInference=[
|
||||||
'Exp', 'Atan', 'Tan', 'Tanh', 'Sin', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
'Abs',
|
||||||
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Sum', 'Max', 'Min', 'MatMul',
|
'Add',
|
||||||
'Gemm', 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
'And',
|
||||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
'Atan',
|
||||||
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
'AveragePool',
|
||||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
|
'Cast',
|
||||||
'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
|
'Concat',
|
||||||
'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
|
'Constant',
|
||||||
'Squeeze', 'Shape', 'Tile', 'Gather', 'ConstantOfShape', 'Slice', 'Scaler'
|
'ConstantOfShape',
|
||||||
|
'Conv',
|
||||||
|
'ConvInteger',
|
||||||
|
'ConvTranspose',
|
||||||
|
'Cos',
|
||||||
|
'Cosh',
|
||||||
|
'DequantizeLinear',
|
||||||
|
'Div',
|
||||||
|
'Dropout',
|
||||||
|
'DynamicQuantizeLinear',
|
||||||
|
'Elu',
|
||||||
|
'Erf',
|
||||||
|
'Exp',
|
||||||
|
'Expand',
|
||||||
|
'Flatten',
|
||||||
|
'GRU',
|
||||||
|
'Gather',
|
||||||
|
'Gemm',
|
||||||
|
'HardSigmoid',
|
||||||
|
'Identity',
|
||||||
|
'LSTM',
|
||||||
|
'LeakyRelu',
|
||||||
|
'Log',
|
||||||
|
'MatMul',
|
||||||
|
'Max',
|
||||||
|
'Min',
|
||||||
|
'Mul',
|
||||||
|
'Neg',
|
||||||
|
'Or',
|
||||||
|
'Pad',
|
||||||
|
'Pow',
|
||||||
|
'QuantizeLinear',
|
||||||
|
'RNN',
|
||||||
|
'Reciprocal',
|
||||||
|
'ReduceMax',
|
||||||
|
'ReduceMean',
|
||||||
|
'ReduceMin',
|
||||||
|
'ReduceProd',
|
||||||
|
'ReduceSum',
|
||||||
|
'Relu',
|
||||||
|
'Reshape',
|
||||||
|
'Scaler',
|
||||||
|
'Selu',
|
||||||
|
'Shape',
|
||||||
|
'Sigmoid',
|
||||||
|
'Sign',
|
||||||
|
'Sin',
|
||||||
|
'Sinh',
|
||||||
|
'Slice',
|
||||||
|
'Softmax',
|
||||||
|
'Softplus',
|
||||||
|
'Softsign',
|
||||||
|
'Split',
|
||||||
|
'Sqrt',
|
||||||
|
'Squeeze',
|
||||||
|
'Sub',
|
||||||
|
'Sum',
|
||||||
|
'Tan',
|
||||||
|
'Tanh',
|
||||||
|
'Tile',
|
||||||
|
'Transpose',
|
||||||
|
'Unsqueeze',
|
||||||
|
'Xor',
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
|
|
Loading…
Reference in New Issue