diff --git a/SharingWork.md b/SharingWork.md
index 257a9a7..1e2cfc3 100644
--- a/SharingWork.md
+++ b/SharingWork.md
@@ -21,8 +21,11 @@ Update as you push code to the master branch.
 | Cos | | v | v | v | | |
 | Cosh | | v | v | v | | |
 | Div | | v | v | v | | M |
+| Dropout | | v | v | | | |
 | Elu | | v | v | v | | |
+| Erf | | v | v | | | |
 | Exp | | v | v | v | | |
+| Expand | | v | v | | | |
 | Gemm | | v | v | v | | U |
 | HardSigmoid | | v | v | v | | |
 | Identity | | v | v | v | | |
@@ -35,12 +38,14 @@ Update as you push code to the master branch.
 | Mul | | v | v | v | | M |
 | Or | | v | v | v | | M |
 | Pad | | v | V | v | | const only |
+| Pow | | v | v | | | M |
 | Reciprocal | | v | v | v | | |
 | ReduceMax | | v | v | v | | |
 | ReduceL1 | | v | | | | be decomposed into ReduceSum and Abs |
 | ReduceL2 | | v | | | | be decomposed into ReduceSumSquare and Sqrt |
 | ReduceLogSum | | v | | | | be decomposed into ReduceSum and Log |
 | ReduceLogSumExp| | v | | | | be decomposed into ReduceLogSum and Exp |
+| ReduceMean | | v | v | | | |
 | ReduceMin | | v | v | v | | |
 | ReduceProd | | v | v | v | | |
 | ReduceSum | | v | v | v | | |
@@ -104,12 +109,9 @@ And add literal tests at each step, and end to end tests once completed.
 | DepthToSpace | | | | | | |
 | DequantizeLin | | | | | | |
 | Det | | | | | | |
-| Dropout | | | | | | |
 | DynQuantizeLin | | | | | | |
 | Einsum | | | | | | V |
 | Equal | | | | | | M |
-| Erf | | | | | | |
-| Expand | | | | | | |
 | EyeLike | | | | | | |
 | Flatten | | | | | | |
 | Floor | | | | | | |
@@ -151,7 +153,6 @@ And add literal tests at each step, and end to end tests once completed.
 | Not | | | | | | |
 | OneHot | | | | | | |
 | PRelu | | | | | | U |
-| Power | | | | | | M |
 | QLinearConv | | | | | | P |
 | QLinearMatMul | | | | | | M |
 | QuantizeLinear | | | | | | |
@@ -161,7 +162,6 @@ And add literal tests at each step, and end to end tests once completed.
 | RandUniform | | | | | | |
 | RandUniformLike| | | | | | |
 | Range | | | | | | |
-| ReduceMean | | | | | | |
 | Resize | | | | | | |
 | ReverseSequence| | | | | | |
 | RoiAlign | | | | | | |
diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index 052df8b..ce5dfa3 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -769,6 +769,29 @@ LogicalResult ONNXAbsOp::inferShapes() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Erf
+//===----------------------------------------------------------------------===//
+
+LogicalResult ONNXErfOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pow
+//===----------------------------------------------------------------------===//
+
+LogicalResult ONNXPowOp::inferShapes() {
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
+    return emitError("Input tensor(s) not ranked");
+  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Add
 //===----------------------------------------------------------------------===//
@@ -1294,6 +1317,19 @@ LogicalResult ONNXReduceMaxOp::inferShapes() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ReduceMean
+//===----------------------------------------------------------------------===//
+
+LogicalResult ONNXReduceMeanOp::inferShapes() {
+  if (!getOperand().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
+
+  auto operandTy = getOperand().getType().cast<RankedTensorType>();
+  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ReduceMin
 //===----------------------------------------------------------------------===//
@@ -2655,6 +2691,90 @@ LogicalResult ONNXSliceOp::inferShapes() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Expand
+//===----------------------------------------------------------------------===//
+
+LogicalResult ONNXExpandOp::inferShapes() {
+  if (!input().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
+
+  auto lhsTy = input().getType().cast<RankedTensorType>();
+
+  auto elementType = lhsTy.getElementType();
+  auto lhsShape = lhsTy.getShape();
+  SmallVector<int64_t, 4> rhsShape;
+
+  Operation *shapeDef = shape().getDefiningOp();
+
+  if (mlir::ONNXShapeOp shapeOp =
+          dyn_cast_or_null<mlir::ONNXShapeOp>(shapeDef)) {
+    // If the shape operand is produced by an onnx.Shape operation, infer its
+    // shape and use it as the requested shape.
+    if (!shapeOp.data().getType().isa<RankedTensorType>())
+      return emitError("Input tensor not ranked");
+
+    ArrayRef<int64_t> rhsShapeRef =
+        shapeOp.data().getType().cast<RankedTensorType>().getShape();
+    rhsShape.assign(rhsShapeRef.begin(), rhsShapeRef.end());
+
+  } else if (mlir::ONNXConstantOp constantOp =
+                 dyn_cast_or_null<mlir::ONNXConstantOp>(shapeDef)) {
+    // If the shape operand is produced by an onnx.Constant operation, extract
+    // the actual value of the constant and use it as the requested shape.
+
+    auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
+
+    if (shapeTensorTy.getRank() != 1)
+      return emitError("Shape tensor must have rank one");
+
+    DenseElementsAttr valueAttribute =
+        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
+    if (!valueAttribute)
+      return emitError("DenseElementsAttr expected");
+
+    int64_t shapeRank = shapeTensorTy.getShape()[0];
+    rhsShape.resize(shapeRank);
+
+    auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
+    for (int i = 0; i != shapeRank; ++i)
+      rhsShape[i] = (*valueIt++).cast<IntegerAttr>().getInt();
+
+    assert(valueIt == valueAttribute.getValues<IntegerAttr>().end() &&
+           "Shape of constant does not match its actual value");
+  } else {
+    return emitError(
+        "Shape argument of Expand is the output of an unexpected operation: " +
+        shapeDef->getName().getStringRef() +
+        ". Supported operations are: onnx.Constant and onnx.Shape");
+  }
+
+  SmallVector<int64_t, 4> resultShape;
+  if (!getBroadcastedShape(lhsShape, rhsShape, resultShape)) {
+    return emitError("Tensor not expandable");
+  }
+
+  getResult().setType(RankedTensorType::get(resultShape, elementType));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Dropout
+//===----------------------------------------------------------------------===//
+
+LogicalResult ONNXDropoutOp::inferShapes() {
+  if (!data().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
+
+  getResult(0).setType(data().getType());
+
+  auto inputShape = data().getType().cast<RankedTensorType>().getShape();
+
+  IntegerType i1Type = IntegerType::get(1, IntegerType::Signless, getContext());
+  getResult(1).setType(RankedTensorType::get(inputShape, i1Type));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ONNX type related code
 //===----------------------------------------------------------------------===//
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index 48916a3..6a3a16c 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -1045,7 +1045,7 @@ def ONNXDivOp:ONNX_Op<"Div",
 }
 
 def ONNXDropoutOp:ONNX_Op<"Dropout",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Dropout operation";
   let description = [{
   "Dropout takes one input floating tensor and produces two tensor outputs,"
@@ -1193,7 +1193,7 @@ def ONNXEqualOp:ONNX_Op<"Equal",
 }
 
 def ONNXErfOp:ONNX_Op<"Erf",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Erf operation";
   let description = [{
   "Computes the error function of the given input tensor element-wise."
@@ -1247,7 +1247,7 @@ def ONNXExpOp:ONNX_Op<"Exp",
 }
 
 def ONNXExpandOp:ONNX_Op<"Expand",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Expand operation";
   let description = [{
   "Broadcast the input tensor following the given shape and the broadcast rule."
@@ -3222,7 +3222,7 @@ def ONNXPadOp:ONNX_Op<"Pad",
 }
 
 def ONNXPowOp:ONNX_Op<"Pow",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Pow operation";
   let description = [{
   "Pow takes input data (Tensor<T>) and exponent Tensor, and"
@@ -3798,7 +3798,7 @@ def ONNXReduceMaxOp:ONNX_Op<"ReduceMax",
 }
 
 def ONNXReduceMeanOp:ONNX_Op<"ReduceMean",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX ReduceMean operation";
   let description = [{
   "Computes the mean of the input tensor's element along the provided axes. The resulted"
The resulted" diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index d69f9e2..a51f4e3 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -1356,6 +1356,8 @@ func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> tensor< // CHECK: return [[RES]] : tensor<1x2xf32> } +// ----- + //===----------------------------------------------------------------------===// /// Test the shape inferencing for the scaler operation. //===----------------------------------------------------------------------===// @@ -1367,3 +1369,114 @@ func @test_scaler_no_scale_int(%arg0: tensor<3xi32>) -> tensor<*xf32> { // CHECK: [[RES_ATTR:%.+]] = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32> // CHECK: return [[RES_ATTR]] : tensor<3xf32> } + +// ----- + +//===----------------------------------------------------------------------===// +/// Test shape inference for Pow. +//===----------------------------------------------------------------------===// + +func @test_pow(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor) -> tensor<*xf32> { + %0 = "onnx.Pow"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_pow + // CHECK: [[RES:%.+]] = "onnx.Pow"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor) -> tensor<1x2x3x4xf32> + // CHECK: return [[RES]] : tensor<1x2x3x4xf32> +} + +// ----- + +//===----------------------------------------------------------------------===// +/// Test shape inference for Erf. +//===----------------------------------------------------------------------===// + +func @test_erf(%arg0: tensor<1x2x3x4xf32>) -> tensor<*xf32> { + %0 = "onnx.Erf"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_erf + // CHECK: [[RES:%.+]] = "onnx.Erf"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> + // CHECK: return [[RES]] : tensor<1x2x3x4xf32> +} + +// ----- + +//===----------------------------------------------------------------------===// +/// Test shape inference for Expand. +//===----------------------------------------------------------------------===// + +func @test_expand_with_constant(%arg0 : tensor<2x1x6x1xf32>) -> tensor<*xf32> { + %0 = "onnx.Constant"() {value = dense<[7, 1, 5]> : tensor<3xi64> } : () -> tensor<3xi64> + %1 = "onnx.Expand"(%arg0, %0) : (tensor<2x1x6x1xf32>, tensor<3xi64>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_expand_with_constant + // CHECK: [[RES:%.+]] = "onnx.Expand"(%arg0, %0) : (tensor<2x1x6x1xf32>, tensor<3xi64>) -> tensor<2x7x6x5xf32> + // CHECK: return [[RES]] : tensor<2x7x6x5xf32> +} + +// ----- + +func @test_expand_with_shape(%arg0 : tensor<2x1x6x1xf32>, %arg1: tensor<6x2xf32>) -> tensor<*xf32> { + %0 = "onnx.Shape"(%arg1) : (tensor<6x2xf32>) -> tensor<*xi64> + %1 = "onnx.Expand"(%arg0, %0) : (tensor<2x1x6x1xf32>, tensor<*xi64>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_expand_with_shape + // CHECK: [[SHAPE:%.+]] = "onnx.Shape"(%arg1) : (tensor<6x2xf32>) -> tensor<2xi64> + // CHECK: [[RES:%.+]] = "onnx.Expand"(%arg0, [[SHAPE]]) : (tensor<2x1x6x1xf32>, tensor<2xi64>) -> tensor<2x1x6x2xf32> + // CHECK: return [[RES]] : tensor<2x1x6x2xf32> +} + +// ----- + +//===----------------------------------------------------------------------===// +/// Test shape inference for ReduceMean. 
+//===----------------------------------------------------------------------===//
+
+func @test_reduce_mean_1(%arg0: tensor<1x2x3x4xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ReduceMean"(%arg0) {axes = [-1], keepdims = 1 : i64} : (tensor<1x2x3x4xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_reduce_mean_1
+  // CHECK: [[RES:%.+]] = "onnx.ReduceMean"(%arg0) {axes = [-1], keepdims = 1 : i64} : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x1xf32>
+  // CHECK: return [[RES]] : tensor<1x2x3x1xf32>
+}
+
+// -----
+
+func @test_reduce_mean_2(%arg0: tensor<1x2x3x4xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ReduceMean"(%arg0) {axes = [2], keepdims = 1 : i64} : (tensor<1x2x3x4xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_reduce_mean_2
+  // CHECK: [[RES:%.+]] = "onnx.ReduceMean"(%arg0) {axes = [2], keepdims = 1 : i64} : (tensor<1x2x3x4xf32>) -> tensor<1x2x1x4xf32>
+  // CHECK: return [[RES]] : tensor<1x2x1x4xf32>
+}
+
+// -----
+
+func @test_reduce_mean_3(%arg0: tensor<1x2x3x4xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ReduceMean"(%arg0) {axes = [-1], keepdims = 0 : i64} : (tensor<1x2x3x4xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_reduce_mean_3
+  // CHECK: [[RES:%.+]] = "onnx.ReduceMean"(%arg0) {axes = [-1], keepdims = 0 : i64} : (tensor<1x2x3x4xf32>) -> tensor<1x2x3xf32>
+  // CHECK: return [[RES]] : tensor<1x2x3xf32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+/// Test shape inference for Dropout.
+//===----------------------------------------------------------------------===//
+
+func @test_dropout(%arg0: tensor<1x2x3x4xf32>) -> (tensor<*xf32>, tensor<*xi1>) {
+  %output, %mask = "onnx.Dropout"(%arg0) {ratio = 1.000000e-01 : f32} : (tensor<1x2x3x4xf32>) -> (tensor<*xf32>, tensor<*xi1>)
+  "std.return"(%output, %mask) : (tensor<*xf32>, tensor<*xi1>) -> ()
+
+  // CHECK-LABEL: test_dropout
+  // CHECK: [[RES:%.+]], [[MASK:%.+]] = "onnx.Dropout"(%arg0) {ratio = 1.000000e-01 : f32} : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>)
+  // CHECK: return [[RES]], [[MASK]] : tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>
+}
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index 143c0e7..3c33374 100644
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -243,16 +243,78 @@ special_op_handler = dict([
 ])
 
 # Operations supporting shape inference.
-OpsWithShapeInference = [
-    'Exp', 'Atan', 'Tan', 'Tanh', 'Sin', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
-    'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Sum', 'Max', 'Min', 'MatMul',
-    'Gemm', 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
-    'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
-    'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
-    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
-    'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
-    'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
-    'Squeeze', 'Shape', 'Tile', 'Gather', 'ConstantOfShape', 'Slice', 'Scaler'
+OpsWithShapeInference = [
+    'Abs',
+    'Add',
+    'And',
+    'Atan',
+    'AveragePool',
+    'Cast',
+    'Concat',
+    'Constant',
+    'ConstantOfShape',
+    'Conv',
+    'ConvInteger',
+    'ConvTranspose',
+    'Cos',
+    'Cosh',
+    'DequantizeLinear',
+    'Div',
+    'Dropout',
+    'DynamicQuantizeLinear',
+    'Elu',
+    'Erf',
+    'Exp',
+    'Expand',
+    'Flatten',
+    'GRU',
+    'Gather',
+    'Gemm',
+    'HardSigmoid',
+    'Identity',
+    'LSTM',
+    'LeakyRelu',
+    'Log',
+    'MatMul',
+    'Max',
+    'Min',
+    'Mul',
+    'Neg',
+    'Or',
+    'Pad',
+    'Pow',
+    'QuantizeLinear',
+    'RNN',
+    'Reciprocal',
+    'ReduceMax',
+    'ReduceMean',
+    'ReduceMin',
+    'ReduceProd',
+    'ReduceSum',
+    'Relu',
+    'Reshape',
+    'Scaler',
+    'Selu',
+    'Shape',
+    'Sigmoid',
+    'Sign',
+    'Sin',
+    'Sinh',
+    'Slice',
+    'Softmax',
+    'Softplus',
+    'Softsign',
+    'Split',
+    'Sqrt',
+    'Squeeze',
+    'Sub',
+    'Sum',
+    'Tan',
+    'Tanh',
+    'Tile',
+    'Transpose',
+    'Unsqueeze',
+    'Xor',
 ]
 
 # Operations supporting canonicalization.
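
Note on the broadcast semantics used above: ONNXPowOp::inferShapes and ONNXExpandOp::inferShapes both delegate to MLIR's multidirectional broadcasting helpers (getBroadcastedType / getBroadcastedShape), so size-1 dimensions stretch against the other operand and missing leading dimensions are implied. A minimal sketch of the expected behavior for a non-scalar exponent, written in the same style as the tests above; the function name test_pow_broadcast is hypothetical and this case is not part of the patch's test suite:

func @test_pow_broadcast(%arg0 : tensor<3x1x5xf32>, %arg1 : tensor<4x1xf32>) -> tensor<*xf32> {
  // Broadcasting <3x1x5> against <4x1> aligns shapes from the right,
  // giving (3,1,5) vs (1,4,1), so inference should produce tensor<3x4x5xf32>.
  %0 = "onnx.Pow"(%arg0, %arg1) : (tensor<3x1x5xf32>, tensor<4x1xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()
}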