Implement shape inference for SplitOp (#95)
* Implement shape inference for SplitOp * Change splitOpt to SplitAttribute and check the axis range before updating the axis attribute Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
7c29da191e
commit
9a874007ce
|
@ -1537,6 +1537,79 @@ bool ONNXConcatOp::inferShapes() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Split
|
||||||
|
|
||||||
|
bool ONNXSplitOp::inferShapes() {
|
||||||
|
if (!getOperand().getType().cast<RankedTensorType>()) {
|
||||||
|
emitError("Input tensor not ranked");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int numOfResults = getNumResults();
|
||||||
|
auto inputType = getOperand().getType().cast<RankedTensorType>();
|
||||||
|
auto inputShape = inputType.getShape();
|
||||||
|
int64_t inputRank = inputShape.size();
|
||||||
|
|
||||||
|
// Checking value of axis parameter.
|
||||||
|
auto axisIndex = axis().getSExtValue();
|
||||||
|
if (axisIndex < -inputRank || axisIndex >= inputRank) {
|
||||||
|
emitError("Split axis value out of bound");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Negative axis means values are counted from the opposite side.
|
||||||
|
if (axisIndex < 0) {
|
||||||
|
axisIndex = inputRank + axisIndex;
|
||||||
|
auto builder = mlir::Builder(getContext());
|
||||||
|
axisAttr(builder.getI64IntegerAttr(axisIndex));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checking value of split parameter.
|
||||||
|
auto splitAttribute = split();
|
||||||
|
SmallVector<int64_t, 4> splitLengths;
|
||||||
|
if (splitAttribute.hasValue()) {
|
||||||
|
if (ArrayAttrSize(splitAttribute) != numOfResults) {
|
||||||
|
emitError("Split size not equal to the number of results");
|
||||||
|
}
|
||||||
|
for (int i = 0; i < numOfResults; ++i)
|
||||||
|
splitLengths.emplace_back(ArrayAttrIntVal(splitAttribute, i));
|
||||||
|
|
||||||
|
} else {
|
||||||
|
if (inputShape[axisIndex] <= 0) {
|
||||||
|
emitError("The dimension at the split axis is expected to be known at "
|
||||||
|
"compile time");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (inputShape[axisIndex] % numOfResults != 0) {
|
||||||
|
emitError("The dimension at the split axis is expected to be divisible "
|
||||||
|
"by the number of results");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// If split parameter is not specified, the dimension is split to
|
||||||
|
// equal-sized parts.
|
||||||
|
for (int i = 0; i < numOfResults; ++i)
|
||||||
|
splitLengths.emplace_back(inputShape[axisIndex] / numOfResults);
|
||||||
|
// Build attribute and store attribute.
|
||||||
|
auto builder = mlir::Builder(getContext());
|
||||||
|
splitAttr(builder.getI64ArrayAttr(llvm::makeArrayRef(splitLengths)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build result types.
|
||||||
|
for (int i = 0; i < numOfResults; ++i) {
|
||||||
|
SmallVector<int64_t, 3> resultShape;
|
||||||
|
for (int j = 0; j < inputRank; ++j) {
|
||||||
|
if (j == axisIndex) {
|
||||||
|
resultShape.emplace_back(splitLengths[i]);
|
||||||
|
} else {
|
||||||
|
resultShape.emplace_back(inputShape[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
getResults()[i].setType(
|
||||||
|
RankedTensorType::get(resultShape, inputType.getElementType()));
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -3241,7 +3241,7 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXSplitOp:ONNX_Op<"Split",
|
def ONNXSplitOp:ONNX_Op<"Split",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Split operation";
|
let summary = "ONNX Split operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Split a tensor into a list of tensors, along the specified"
|
"Split a tensor into a list of tensors, along the specified"
|
||||||
|
|
|
@ -124,6 +124,7 @@ public:
|
||||||
op->getName().getStringRef() != "onnx.Abs" &&
|
op->getName().getStringRef() != "onnx.Abs" &&
|
||||||
op->getName().getStringRef() != "onnx.Constant" &&
|
op->getName().getStringRef() != "onnx.Constant" &&
|
||||||
op->getName().getStringRef() != "onnx.Concat" &&
|
op->getName().getStringRef() != "onnx.Concat" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Split" &&
|
||||||
op->getName().getStringRef() != "onnx.Neg" &&
|
op->getName().getStringRef() != "onnx.Neg" &&
|
||||||
op->getName().getStringRef() != "onnx.Unsqueeze")
|
op->getName().getStringRef() != "onnx.Unsqueeze")
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -610,3 +610,36 @@ func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg
|
||||||
// CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
|
// CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
|
||||||
// CHECK: return [[RES]] : tensor<5x9x32xf32>
|
// CHECK: return [[RES]] : tensor<5x9x32xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Split with a positive axis and no `split` attribute: shape inference must
// divide dim 1 (32) evenly across the two results and materialize
// split = [16, 16] on the op.
func @test_split_1(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
  %0, %1 = "onnx.Split"(%arg0) { axis = 1 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_split_1
  // CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [16, 16]} : (tensor<16x32x64xf32>) -> (tensor<16x16x64xf32>, tensor<16x16x64xf32>)
  // CHECK: return [[RES]]#0 : tensor<16x16x64xf32>
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Split with a negative axis (-2 on a rank-3 tensor): shape inference must
// normalize it to axis = 1 in the attribute and otherwise behave exactly like
// the positive-axis case.
func @test_split_2(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
  %0, %1 = "onnx.Split"(%arg0) { axis = -2 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_split_2
  // CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [16, 16]} : (tensor<16x32x64xf32>) -> (tensor<16x16x64xf32>, tensor<16x16x64xf32>)
  // CHECK: return [[RES]]#0 : tensor<16x16x64xf32>
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Split with an explicit (uneven) `split` attribute: shape inference must use
// the provided lengths [2, 30] for the axis dimension of each result instead
// of dividing evenly.
func @test_split_3(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
  %0, %1 = "onnx.Split"(%arg0) { axis = 1, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_split_3
  // CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<16x2x64xf32>, tensor<16x30x64xf32>)
  // CHECK: return [[RES]]#0 : tensor<16x2x64xf32>
}
|
||||||
|
|
|
@ -63,7 +63,7 @@ OpsWithShapeInference = [
|
||||||
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
||||||
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
||||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg'
|
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'Split'
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
|
|
Loading…
Reference in New Issue