Implement shape inference for SplitOp (#95)

* Implement shape inference for SplitOp

* Change splitOpt to splitAttribute and check the axis range before updating the axis attribute

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tung D. Le 2020-05-13 19:07:27 +09:00 committed by GitHub
parent 7c29da191e
commit 9a874007ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 109 additions and 2 deletions

View File

@ -1537,6 +1537,79 @@ bool ONNXConcatOp::inferShapes() {
return true; return true;
} }
//===----------------------------------------------------------------------===//
// Split
bool ONNXSplitOp::inferShapes() {
if (!getOperand().getType().cast<RankedTensorType>()) {
emitError("Input tensor not ranked");
return false;
}
int numOfResults = getNumResults();
auto inputType = getOperand().getType().cast<RankedTensorType>();
auto inputShape = inputType.getShape();
int64_t inputRank = inputShape.size();
// Checking value of axis parameter.
auto axisIndex = axis().getSExtValue();
if (axisIndex < -inputRank || axisIndex >= inputRank) {
emitError("Split axis value out of bound");
return false;
}
// Negative axis means values are counted from the opposite side.
if (axisIndex < 0) {
axisIndex = inputRank + axisIndex;
auto builder = mlir::Builder(getContext());
axisAttr(builder.getI64IntegerAttr(axisIndex));
}
// Checking value of split parameter.
auto splitAttribute = split();
SmallVector<int64_t, 4> splitLengths;
if (splitAttribute.hasValue()) {
if (ArrayAttrSize(splitAttribute) != numOfResults) {
emitError("Split size not equal to the number of results");
}
for (int i = 0; i < numOfResults; ++i)
splitLengths.emplace_back(ArrayAttrIntVal(splitAttribute, i));
} else {
if (inputShape[axisIndex] <= 0) {
emitError("The dimension at the split axis is expected to be known at "
"compile time");
return false;
}
if (inputShape[axisIndex] % numOfResults != 0) {
emitError("The dimension at the split axis is expected to be divisible "
"by the number of results");
return false;
}
// If split parameter is not specified, the dimension is split to
// equal-sized parts.
for (int i = 0; i < numOfResults; ++i)
splitLengths.emplace_back(inputShape[axisIndex] / numOfResults);
// Build attribute and store attribute.
auto builder = mlir::Builder(getContext());
splitAttr(builder.getI64ArrayAttr(llvm::makeArrayRef(splitLengths)));
}
// Build result types.
for (int i = 0; i < numOfResults; ++i) {
SmallVector<int64_t, 3> resultShape;
for (int j = 0; j < inputRank; ++j) {
if (j == axisIndex) {
resultShape.emplace_back(splitLengths[i]);
} else {
resultShape.emplace_back(inputShape[j]);
}
}
getResults()[i].setType(
RankedTensorType::get(resultShape, inputType.getElementType()));
}
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -3241,7 +3241,7 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth",
} }
def ONNXSplitOp:ONNX_Op<"Split", def ONNXSplitOp:ONNX_Op<"Split",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Split operation"; let summary = "ONNX Split operation";
let description = [{ let description = [{
"Split a tensor into a list of tensors, along the specified" "Split a tensor into a list of tensors, along the specified"

View File

@ -124,6 +124,7 @@ public:
op->getName().getStringRef() != "onnx.Abs" && op->getName().getStringRef() != "onnx.Abs" &&
op->getName().getStringRef() != "onnx.Constant" && op->getName().getStringRef() != "onnx.Constant" &&
op->getName().getStringRef() != "onnx.Concat" && op->getName().getStringRef() != "onnx.Concat" &&
op->getName().getStringRef() != "onnx.Split" &&
op->getName().getStringRef() != "onnx.Neg" && op->getName().getStringRef() != "onnx.Neg" &&
op->getName().getStringRef() != "onnx.Unsqueeze") op->getName().getStringRef() != "onnx.Unsqueeze")
return false; return false;

View File

@ -610,3 +610,36 @@ func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg
// CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32> // CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
// CHECK: return [[RES]] : tensor<5x9x32xf32> // CHECK: return [[RES]] : tensor<5x9x32xf32>
} }
// -----
// Split along a non-negative axis with no `split` attribute: the 32-wide
// axis 1 is expected to be divided evenly between the two results, and the
// inferred lengths [16, 16] are expected to appear as the `split` attribute.
func @test_split_1(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
%0, %1 = "onnx.Split"(%arg0) { axis = 1 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_split_1
// CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [16, 16]} : (tensor<16x32x64xf32>) -> (tensor<16x16x64xf32>, tensor<16x16x64xf32>)
// CHECK: return [[RES]]#0 : tensor<16x16x64xf32>
}
// -----
// Split along a negative axis: axis = -2 on a rank-3 input is expected to be
// rewritten to axis = 1, with the same even [16, 16] split as the
// non-negative case.
func @test_split_2(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
%0, %1 = "onnx.Split"(%arg0) { axis = -2 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_split_2
// CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [16, 16]} : (tensor<16x32x64xf32>) -> (tensor<16x16x64xf32>, tensor<16x16x64xf32>)
// CHECK: return [[RES]]#0 : tensor<16x16x64xf32>
}
// -----
// Split with an explicit, uneven `split` attribute: the lengths [2, 30] are
// expected to be kept as given and used as the axis-1 extents of the two
// results.
func @test_split_3(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
%0, %1 = "onnx.Split"(%arg0) { axis = 1, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_split_3
// CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<16x2x64xf32>, tensor<16x30x64xf32>)
// CHECK: return [[RES]]#0 : tensor<16x2x64xf32>
}

View File

@ -63,7 +63,7 @@ OpsWithShapeInference = [
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg' 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'Split'
] ]
# Operations supporting canonicalization. # Operations supporting canonicalization.