Implement shape inference for SplitOp (#95)
* Implement shape inference for SplitOp * Change spitOpt to SplitAttribute and check the axis range before updating the axis attribute Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
7c29da191e
commit
9a874007ce
|
@ -1537,6 +1537,79 @@ bool ONNXConcatOp::inferShapes() {
|
|||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Split
|
||||
|
||||
bool ONNXSplitOp::inferShapes() {
|
||||
if (!getOperand().getType().cast<RankedTensorType>()) {
|
||||
emitError("Input tensor not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
int numOfResults = getNumResults();
|
||||
auto inputType = getOperand().getType().cast<RankedTensorType>();
|
||||
auto inputShape = inputType.getShape();
|
||||
int64_t inputRank = inputShape.size();
|
||||
|
||||
// Checking value of axis parameter.
|
||||
auto axisIndex = axis().getSExtValue();
|
||||
if (axisIndex < -inputRank || axisIndex >= inputRank) {
|
||||
emitError("Split axis value out of bound");
|
||||
return false;
|
||||
}
|
||||
// Negative axis means values are counted from the opposite side.
|
||||
if (axisIndex < 0) {
|
||||
axisIndex = inputRank + axisIndex;
|
||||
auto builder = mlir::Builder(getContext());
|
||||
axisAttr(builder.getI64IntegerAttr(axisIndex));
|
||||
}
|
||||
|
||||
// Checking value of split parameter.
|
||||
auto splitAttribute = split();
|
||||
SmallVector<int64_t, 4> splitLengths;
|
||||
if (splitAttribute.hasValue()) {
|
||||
if (ArrayAttrSize(splitAttribute) != numOfResults) {
|
||||
emitError("Split size not equal to the number of results");
|
||||
}
|
||||
for (int i = 0; i < numOfResults; ++i)
|
||||
splitLengths.emplace_back(ArrayAttrIntVal(splitAttribute, i));
|
||||
|
||||
} else {
|
||||
if (inputShape[axisIndex] <= 0) {
|
||||
emitError("The dimension at the split axis is expected to be known at "
|
||||
"compile time");
|
||||
return false;
|
||||
}
|
||||
if (inputShape[axisIndex] % numOfResults != 0) {
|
||||
emitError("The dimension at the split axis is expected to be divisible "
|
||||
"by the number of results");
|
||||
return false;
|
||||
}
|
||||
// If split parameter is not specified, the dimension is split to
|
||||
// equal-sized parts.
|
||||
for (int i = 0; i < numOfResults; ++i)
|
||||
splitLengths.emplace_back(inputShape[axisIndex] / numOfResults);
|
||||
// Build attribute and store attribute.
|
||||
auto builder = mlir::Builder(getContext());
|
||||
splitAttr(builder.getI64ArrayAttr(llvm::makeArrayRef(splitLengths)));
|
||||
}
|
||||
|
||||
// Build result types.
|
||||
for (int i = 0; i < numOfResults; ++i) {
|
||||
SmallVector<int64_t, 3> resultShape;
|
||||
for (int j = 0; j < inputRank; ++j) {
|
||||
if (j == axisIndex) {
|
||||
resultShape.emplace_back(splitLengths[i]);
|
||||
} else {
|
||||
resultShape.emplace_back(inputShape[j]);
|
||||
}
|
||||
}
|
||||
getResults()[i].setType(
|
||||
RankedTensorType::get(resultShape, inputType.getElementType()));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -3241,7 +3241,7 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth",
|
|||
}
|
||||
|
||||
def ONNXSplitOp:ONNX_Op<"Split",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX Split operation";
|
||||
let description = [{
|
||||
"Split a tensor into a list of tensors, along the specified"
|
||||
|
|
|
@ -124,6 +124,7 @@ public:
|
|||
op->getName().getStringRef() != "onnx.Abs" &&
|
||||
op->getName().getStringRef() != "onnx.Constant" &&
|
||||
op->getName().getStringRef() != "onnx.Concat" &&
|
||||
op->getName().getStringRef() != "onnx.Split" &&
|
||||
op->getName().getStringRef() != "onnx.Neg" &&
|
||||
op->getName().getStringRef() != "onnx.Unsqueeze")
|
||||
return false;
|
||||
|
|
|
@ -610,3 +610,36 @@ func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg
|
|||
// CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
|
||||
// CHECK: return [[RES]] : tensor<5x9x32xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_split_1(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
|
||||
%0, %1 = "onnx.Split"(%arg0) { axis = 1 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_split_1
|
||||
// CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [16, 16]} : (tensor<16x32x64xf32>) -> (tensor<16x16x64xf32>, tensor<16x16x64xf32>)
|
||||
// CHECK: return [[RES]]#0 : tensor<16x16x64xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_split_2(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
|
||||
%0, %1 = "onnx.Split"(%arg0) { axis = -2 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_split_2
|
||||
// CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [16, 16]} : (tensor<16x32x64xf32>) -> (tensor<16x16x64xf32>, tensor<16x16x64xf32>)
|
||||
// CHECK: return [[RES]]#0 : tensor<16x16x64xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_split_3(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
|
||||
%0, %1 = "onnx.Split"(%arg0) { axis = 1, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_split_3
|
||||
// CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<16x2x64xf32>, tensor<16x30x64xf32>)
|
||||
// CHECK: return [[RES]]#0 : tensor<16x2x64xf32>
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ OpsWithShapeInference = [
|
|||
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
||||
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg'
|
||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'Split'
|
||||
]
|
||||
|
||||
# Operations supporting canonicalization.
|
||||
|
|
Loading…
Reference in New Issue