Add shape inference for Ops used by BERT (#249)

* Add shape inference for Ops used by BERT

* Erf
* Pow
* ReduceMean
* Dropout
* Expand
  https://github.com/onnx/onnx/blob/master/docs/Operators.md#expand
  Deduce the value of the shape operand by looking at the producer
  of the operand.
  Currently supported producers are: onnx.Constant and onnx.Shape.

* Add corresponding tests for each op.

* Sort the list of ops with shape inference in gen_onnx_mlir.py
in alphabetic order for clarity.

* Restart CI

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
albertomagni-ms 2020-08-07 06:08:00 +01:00 committed by GitHub
parent 20686d8f0f
commit e1386b0689
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 315 additions and 20 deletions

View File

@ -21,8 +21,11 @@ Update as you push code to the master branch.
| Cos | | v | v | v | | | | Cos | | v | v | v | | |
| Cosh | | v | v | v | | | | Cosh | | v | v | v | | |
| Div | | v | v | v | | M | | Div | | v | v | v | | M |
| Dropout | | v | v | | | |
| Elu | | v | v | v | | | | Elu | | v | v | v | | |
| Erf | | v | v | | | |
| Exp | | v | v | v | | | | Exp | | v | v | v | | |
| Expand | | v | v | | | |
| Gemm | | v | v | v | | U | | Gemm | | v | v | v | | U |
| HardSigmoid | | v | v | v | | | | HardSigmoid | | v | v | v | | |
| Identity | | v | v | v | | | | Identity | | v | v | v | | |
@ -35,12 +38,14 @@ Update as you push code to the master branch.
| Mul | | v | v | v | | M | | Mul | | v | v | v | | M |
| Or | | v | v | v | | M | | Or | | v | v | v | | M |
| Pad | | v | V | v | | const only | | Pad | | v | V | v | | const only |
| Pow | | v | v | | | M |
| Reciprocal | | v | v | v | | | | Reciprocal | | v | v | v | | |
| ReduceMax | | v | v | v | | | | ReduceMax | | v | v | v | | |
| ReduceL1 | | v | | | | be decomposed into ReduceSum and Abs | | ReduceL1 | | v | | | | be decomposed into ReduceSum and Abs |
| ReduceL2 | | v | | | | be decomposed into ReduceSumSquare and Sqrt | | ReduceL2 | | v | | | | be decomposed into ReduceSumSquare and Sqrt |
| ReduceLogSum | | v | | | | be decomposed into ReduceSum and Log | | ReduceLogSum | | v | | | | be decomposed into ReduceSum and Log |
| ReduceLogSumExp| | v | | | | be decomposed into ReduceLogSum and Exp | | ReduceLogSumExp| | v | | | | be decomposed into ReduceLogSum and Exp |
| ReduceMean | | v | v | | | |
| ReduceMin | | v | v | v | | | | ReduceMin | | v | v | v | | |
| ReduceProd | | v | v | v | | | | ReduceProd | | v | v | v | | |
| ReduceSum | | v | v | v | | | | ReduceSum | | v | v | v | | |
@ -104,12 +109,9 @@ And add literal tests at each step, and end to end tests once completed.
| DepthToSpace | | | | | | | | DepthToSpace | | | | | | |
| DequantizeLin | | | | | | | | DequantizeLin | | | | | | |
| Det | | | | | | | | Det | | | | | | |
| Dropout | | | | | | |
| DynQuantizeLin | | | | | | | | DynQuantizeLin | | | | | | |
| Einsum | | | | | | V | | Einsum | | | | | | V |
| Equal | | | | | | M | | Equal | | | | | | M |
| Erf | | | | | | |
| Expand | | | | | | |
| EyeLike | | | | | | | | EyeLike | | | | | | |
| Flatten | | | | | | | | Flatten | | | | | | |
| Floor | | | | | | | | Floor | | | | | | |
@ -151,7 +153,6 @@ And add literal tests at each step, and end to end tests once completed.
| Not | | | | | | | | Not | | | | | | |
| OneHot | | | | | | | | OneHot | | | | | | |
| PRelu | | | | | | U | | PRelu | | | | | | U |
| Power | | | | | | M |
| QLinearConv | | | | | | P | | QLinearConv | | | | | | P |
| QLinearMatMul | | | | | | M | | QLinearMatMul | | | | | | M |
| QuantizeLinear | | | | | | | | QuantizeLinear | | | | | | |
@ -161,7 +162,6 @@ And add literal tests at each step, and end to end tests once completed.
| RandUniform | | | | | | | | RandUniform | | | | | | |
| RandUniformLike| | | | | | | | RandUniformLike| | | | | | |
| Range | | | | | | | | Range | | | | | | |
| ReduceMean | | | | | | |
| Resize | | | | | | | | Resize | | | | | | |
| ReverseSequence| | | | | | | | ReverseSequence| | | | | | |
| RoiAlign | | | | | | | | RoiAlign | | | | | | |

View File

@ -769,6 +769,29 @@ LogicalResult ONNXAbsOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Erf
//===----------------------------------------------------------------------===//
LogicalResult ONNXErfOp::inferShapes() {
getResult().setType(getOperand().getType());
return success();
}
//===----------------------------------------------------------------------===//
// Pow
//===----------------------------------------------------------------------===//
LogicalResult ONNXPowOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return emitError("Input tensor(s) not ranked");
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Add // Add
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1294,6 +1317,19 @@ LogicalResult ONNXReduceMaxOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// ReduceMean
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceMeanOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>())
return emitError("Input tensor not ranked");
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReduceMin // ReduceMin
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2655,6 +2691,90 @@ LogicalResult ONNXSliceOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Expand
//===----------------------------------------------------------------------===//
LogicalResult ONNXExpandOp::inferShapes() {
if (!input().getType().isa<RankedTensorType>())
return emitError("Input tensor not ranked");
auto lhsTy = input().getType().cast<RankedTensorType>();
auto elementType = lhsTy.getElementType();
auto lhsShape = lhsTy.getShape();
SmallVector<int64_t, 2> rhsShape;
Operation *shapeDef = shape().getDefiningOp();
if (mlir::ONNXShapeOp shapeOp =
dyn_cast_or_null<mlir::ONNXShapeOp>(shapeDef)) {
// If the shape operand is produced by a onnx.Shape operation, infer its
// shape and use it as the requested shape.
if (!shapeOp.data().getType().isa<RankedTensorType>())
return emitError("Input tensor not ranked");
ArrayRef<int64_t> rhsShapeRef =
shapeOp.data().getType().cast<RankedTensorType>().getShape();
rhsShape.assign(rhsShapeRef.begin(), rhsShapeRef.end());
} else if (mlir::ONNXConstantOp constantOp =
dyn_cast_or_null<mlir::ONNXConstantOp>(shapeDef)) {
// If the shape operand is produced by a onnx.Constant operation, extract
// the actual value of the constant and use it as the reqested shape.
auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
if (shapeTensorTy.getRank() != 1)
return emitError("Shape tensor must have rank one");
DenseElementsAttr valueAttribute =
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
if (!valueAttribute)
return emitError("DenseElementsAttr expected");
int64_t shapeRank = shapeTensorTy.getShape()[0];
rhsShape.resize(shapeRank);
auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
for (int i = 0; i != shapeRank; ++i)
rhsShape[i] = (*valueIt++).cast<IntegerAttr>().getInt();
assert(valueIt == valueAttribute.getValues<IntegerAttr>().end() &&
"Shape of constant does not match its actual value");
} else {
return emitError(
"Shape argument of Expand is the output of an unexpected operation: " +
shapeDef->getName().getStringRef() +
". Supported operations are: onnx.Constant and onnx.Shape");
}
SmallVector<int64_t, 2> resultShape;
if (!getBroadcastedShape(lhsShape, rhsShape, resultShape)) {
return emitError("Tensor not exapandable");
}
getResult().setType(RankedTensorType::get(resultShape, elementType));
return success();
}
//===----------------------------------------------------------------------===//
// Dropout
//===----------------------------------------------------------------------===//
LogicalResult ONNXDropoutOp::inferShapes() {
if (!data().getType().isa<RankedTensorType>())
return emitError("Input tensor not ranked");
getResult(0).setType(data().getType());
auto inputShape = data().getType().cast<RankedTensorType>().getShape();
IntegerType i1Type = IntegerType::get(1, IntegerType::Signless, getContext());
getResult(1).setType(RankedTensorType::get(inputShape, i1Type));
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ONNX type related code // ONNX type related code
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1045,7 +1045,7 @@ def ONNXDivOp:ONNX_Op<"Div",
} }
def ONNXDropoutOp:ONNX_Op<"Dropout", def ONNXDropoutOp:ONNX_Op<"Dropout",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Dropout operation"; let summary = "ONNX Dropout operation";
let description = [{ let description = [{
"Dropout takes one input floating tensor and produces two tensor outputs," "Dropout takes one input floating tensor and produces two tensor outputs,"
@ -1193,7 +1193,7 @@ def ONNXEqualOp:ONNX_Op<"Equal",
} }
def ONNXErfOp:ONNX_Op<"Erf", def ONNXErfOp:ONNX_Op<"Erf",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Erf operation"; let summary = "ONNX Erf operation";
let description = [{ let description = [{
"Computes the error function of the given input tensor element-wise." "Computes the error function of the given input tensor element-wise."
@ -1247,7 +1247,7 @@ def ONNXExpOp:ONNX_Op<"Exp",
} }
def ONNXExpandOp:ONNX_Op<"Expand", def ONNXExpandOp:ONNX_Op<"Expand",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Expand operation"; let summary = "ONNX Expand operation";
let description = [{ let description = [{
"Broadcast the input tensor following the given shape and the broadcast rule." "Broadcast the input tensor following the given shape and the broadcast rule."
@ -3222,7 +3222,7 @@ def ONNXPadOp:ONNX_Op<"Pad",
} }
def ONNXPowOp:ONNX_Op<"Pow", def ONNXPowOp:ONNX_Op<"Pow",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Pow operation"; let summary = "ONNX Pow operation";
let description = [{ let description = [{
"Pow takes input data (Tensor<T>) and exponent Tensor, and" "Pow takes input data (Tensor<T>) and exponent Tensor, and"
@ -3798,7 +3798,7 @@ def ONNXReduceMaxOp:ONNX_Op<"ReduceMax",
} }
def ONNXReduceMeanOp:ONNX_Op<"ReduceMean", def ONNXReduceMeanOp:ONNX_Op<"ReduceMean",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX ReduceMean operation"; let summary = "ONNX ReduceMean operation";
let description = [{ let description = [{
"Computes the mean of the input tensor's element along the provided axes. The resulted" "Computes the mean of the input tensor's element along the provided axes. The resulted"

View File

@ -1356,6 +1356,8 @@ func @test_slice_all_constant_negative_steps(%arg0 : tensor<2x4xf32>) -> tensor<
// CHECK: return [[RES]] : tensor<1x2xf32> // CHECK: return [[RES]] : tensor<1x2xf32>
} }
// -----
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Test the shape inferencing for the scaler operation. /// Test the shape inferencing for the scaler operation.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1367,3 +1369,114 @@ func @test_scaler_no_scale_int(%arg0: tensor<3xi32>) -> tensor<*xf32> {
// CHECK: [[RES_ATTR:%.+]] = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32> // CHECK: [[RES_ATTR:%.+]] = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
// CHECK: return [[RES_ATTR]] : tensor<3xf32> // CHECK: return [[RES_ATTR]] : tensor<3xf32>
} }
// -----
//===----------------------------------------------------------------------===//
/// Test shape inference for Pow.
//===----------------------------------------------------------------------===//
func @test_pow(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
%0 = "onnx.Pow"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<f32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_pow
// CHECK: [[RES:%.+]] = "onnx.Pow"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<f32>) -> tensor<1x2x3x4xf32>
// CHECK: return [[RES]] : tensor<1x2x3x4xf32>
}
// -----
//===----------------------------------------------------------------------===//
/// Test shape inference for Erf.
//===----------------------------------------------------------------------===//
func @test_erf(%arg0: tensor<1x2x3x4xf32>) -> tensor<*xf32> {
%0 = "onnx.Erf"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_erf
// CHECK: [[RES:%.+]] = "onnx.Erf"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
// CHECK: return [[RES]] : tensor<1x2x3x4xf32>
}
// -----
//===----------------------------------------------------------------------===//
/// Test shape inference for Expand.
//===----------------------------------------------------------------------===//
func @test_expand_with_constant(%arg0 : tensor<2x1x6x1xf32>) -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[7, 1, 5]> : tensor<3xi64> } : () -> tensor<3xi64>
%1 = "onnx.Expand"(%arg0, %0) : (tensor<2x1x6x1xf32>, tensor<3xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_expand_with_constant
// CHECK: [[RES:%.+]] = "onnx.Expand"(%arg0, %0) : (tensor<2x1x6x1xf32>, tensor<3xi64>) -> tensor<2x7x6x5xf32>
// CHECK: return [[RES]] : tensor<2x7x6x5xf32>
}
// -----
func @test_expand_with_shape(%arg0 : tensor<2x1x6x1xf32>, %arg1: tensor<6x2xf32>) -> tensor<*xf32> {
%0 = "onnx.Shape"(%arg1) : (tensor<6x2xf32>) -> tensor<*xi64>
%1 = "onnx.Expand"(%arg0, %0) : (tensor<2x1x6x1xf32>, tensor<*xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_expand_with_shape
// CHECK: [[SHAPE:%.+]] = "onnx.Shape"(%arg1) : (tensor<6x2xf32>) -> tensor<2xi64>
// CHECK: [[RES:%.+]] = "onnx.Expand"(%arg0, [[SHAPE]]) : (tensor<2x1x6x1xf32>, tensor<2xi64>) -> tensor<2x1x6x2xf32>
// CHECK: return [[RES]] : tensor<2x1x6x2xf32>
}
// -----
//===----------------------------------------------------------------------===//
/// Test shape inference for ReduceMean.
//===----------------------------------------------------------------------===//
func @test_reduce_mean_1(%arg0: tensor<1x2x3x4xf32>) -> tensor<*xf32> {
%0 = "onnx.ReduceMean"(%arg0) {axes = [-1], keepdims = 1 : i64} : (tensor<1x2x3x4xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reduce_mean_1
// CHECK: [[RES:%.+]] = "onnx.ReduceMean"(%arg0) {axes = [-1], keepdims = 1 : i64} : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x1xf32>
// CHECK: return [[RES]] : tensor<1x2x3x1xf32>
}
// -----
func @test_reduce_mean_2(%arg0: tensor<1x2x3x4xf32>) -> tensor<*xf32> {
%0 = "onnx.ReduceMean"(%arg0) {axes = [2], keepdims = 1 : i64} : (tensor<1x2x3x4xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reduce_mean_2
// CHECK: [[RES:%.+]] = "onnx.ReduceMean"(%arg0) {axes = [2], keepdims = 1 : i64} : (tensor<1x2x3x4xf32>) -> tensor<1x2x1x4xf32>
// CHECK: return [[RES]] : tensor<1x2x1x4xf32>
}
// -----
func @test_reduce_mean_3(%arg0: tensor<1x2x3x4xf32>) -> tensor<*xf32> {
%0 = "onnx.ReduceMean"(%arg0) {axes = [-1], keepdims = 0 : i64} : (tensor<1x2x3x4xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reduce_mean_3
// CHECK: [[RES:%.+]] = "onnx.ReduceMean"(%arg0) {axes = [-1], keepdims = 0 : i64} : (tensor<1x2x3x4xf32>) -> tensor<1x2x3xf32>
// CHECK: return [[RES]] : tensor<1x2x3xf32>
}
// -----
//===----------------------------------------------------------------------===//
/// Test shape inference for Dropout.
//===----------------------------------------------------------------------===//
func @test_dropout(%arg0: tensor<1x2x3x4xf32>) -> (tensor<*xf32>, tensor<*xi1>) {
%output, %mask = "onnx.Dropout"(%arg0) {ratio = 1.000000e-01 : f32} : (tensor<1x2x3x4xf32>) -> (tensor<*xf32>, tensor<*xi1>)
"std.return"(%output, %mask) : (tensor<*xf32>, tensor<*xi1>) -> ()
// CHECK-LABEL: test_dropout
// CHECK: [[RES:%.+]], [[MASK:%.+]] = "onnx.Dropout"(%arg0) {ratio = 1.000000e-01 : f32} : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>)
// CHECK: return [[RES]], [[MASK]] : tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>
}

View File

@ -244,15 +244,77 @@ special_op_handler = dict([
# Operations supporting shape inference. # Operations supporting shape inference.
OpsWithShapeInference=[ OpsWithShapeInference=[
'Exp', 'Atan', 'Tan', 'Tanh', 'Sin', 'Sinh', 'Cosh', 'Sigmoid', 'Relu', 'Abs',
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Sum', 'Max', 'Min', 'MatMul', 'Add',
'Gemm', 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'And',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', 'Atan',
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'AveragePool',
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN', 'Cast',
'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten', 'Concat',
'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger', 'Constant',
'Squeeze', 'Shape', 'Tile', 'Gather', 'ConstantOfShape', 'Slice', 'Scaler' 'ConstantOfShape',
'Conv',
'ConvInteger',
'ConvTranspose',
'Cos',
'Cosh',
'DequantizeLinear',
'Div',
'Dropout',
'DynamicQuantizeLinear',
'Elu',
'Erf',
'Exp',
'Expand',
'Flatten',
'GRU',
'Gather',
'Gemm',
'HardSigmoid',
'Identity',
'LSTM',
'LeakyRelu',
'Log',
'MatMul',
'Max',
'Min',
'Mul',
'Neg',
'Or',
'Pad',
'Pow',
'QuantizeLinear',
'RNN',
'Reciprocal',
'ReduceMax',
'ReduceMean',
'ReduceMin',
'ReduceProd',
'ReduceSum',
'Relu',
'Reshape',
'Scaler',
'Selu',
'Shape',
'Sigmoid',
'Sign',
'Sin',
'Sinh',
'Slice',
'Softmax',
'Softplus',
'Softsign',
'Split',
'Sqrt',
'Squeeze',
'Sub',
'Sum',
'Tan',
'Tanh',
'Tile',
'Transpose',
'Unsqueeze',
'Xor',
] ]
# Operations supporting canonicalization. # Operations supporting canonicalization.