Implement shape inference for SplitOp (#95)
* Implement shape inference for SplitOp
* Change spitOpt to SplitAttribute and check the axis range before updating the axis attribute

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
parent 7c29da191e
commit 9a874007ce
@@ -1537,6 +1537,79 @@ bool ONNXConcatOp::inferShapes() {
  return true;
}

//===----------------------------------------------------------------------===//
// Split

bool ONNXSplitOp::inferShapes() {
  if (!getOperand().getType().isa<RankedTensorType>()) {
    emitError("Input tensor not ranked");
    return false;
  }

  int numOfResults = getNumResults();
  auto inputType = getOperand().getType().cast<RankedTensorType>();
  auto inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();

  // Check the value of the axis parameter.
  auto axisIndex = axis().getSExtValue();
  if (axisIndex < -inputRank || axisIndex >= inputRank) {
    emitError("Split axis value out of bound");
    return false;
  }
  // A negative axis means the dimension is counted from the back.
  if (axisIndex < 0) {
    axisIndex = inputRank + axisIndex;
    auto builder = mlir::Builder(getContext());
    axisAttr(builder.getI64IntegerAttr(axisIndex));
  }

  // Check the value of the split parameter.
  auto splitAttribute = split();
  SmallVector<int64_t, 4> splitLengths;
  if (splitAttribute.hasValue()) {
    if (ArrayAttrSize(splitAttribute) != numOfResults) {
      emitError("Split size not equal to the number of results");
      return false;
    }
    for (int i = 0; i < numOfResults; ++i)
      splitLengths.emplace_back(ArrayAttrIntVal(splitAttribute, i));
  } else {
    if (inputShape[axisIndex] <= 0) {
      emitError("The dimension at the split axis is expected to be known at "
                "compile time");
      return false;
    }
    if (inputShape[axisIndex] % numOfResults != 0) {
      emitError("The dimension at the split axis is expected to be divisible "
                "by the number of results");
      return false;
    }
    // If the split parameter is not specified, the dimension is split into
    // equal-sized parts.
    for (int i = 0; i < numOfResults; ++i)
      splitLengths.emplace_back(inputShape[axisIndex] / numOfResults);
    // Build and store the split attribute.
    auto builder = mlir::Builder(getContext());
    splitAttr(builder.getI64ArrayAttr(llvm::makeArrayRef(splitLengths)));
  }

  // Build result types.
  for (int i = 0; i < numOfResults; ++i) {
    SmallVector<int64_t, 3> resultShape;
    for (int j = 0; j < inputRank; ++j) {
      if (j == axisIndex) {
        resultShape.emplace_back(splitLengths[i]);
      } else {
        resultShape.emplace_back(inputShape[j]);
      }
    }
    getResults()[i].setType(
        RankedTensorType::get(resultShape, inputType.getElementType()));
  }
  return true;
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
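As a quick illustration of the shape arithmetic implemented above, here is a minimal, self-contained C++ sketch that reproduces the same steps (axis normalization, default equal split, per-result shapes) with standard-library types instead of MLIR types. The helper name inferSplitShapes and everything in main() are illustrative only and are not part of onnx-mlir.

// Illustrative sketch of the Split shape arithmetic; not onnx-mlir code.
#include <cstdint>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int64_t>;

// Compute one output shape per result of a Split along `axis`.
std::vector<Shape> inferSplitShapes(const Shape &inputShape, int64_t axis,
                                    int numOfResults,
                                    std::optional<std::vector<int64_t>> split) {
  int64_t rank = static_cast<int64_t>(inputShape.size());

  // Axis must lie in [-rank, rank); negative values count from the back.
  if (axis < -rank || axis >= rank)
    throw std::invalid_argument("Split axis value out of bound");
  if (axis < 0)
    axis += rank;

  // Either take the user-provided split lengths or divide the split
  // dimension into equal-sized parts.
  std::vector<int64_t> splitLengths;
  if (split) {
    if (static_cast<int>(split->size()) != numOfResults)
      throw std::invalid_argument("Split size not equal to the number of results");
    splitLengths = *split;
  } else {
    if (inputShape[axis] <= 0)
      throw std::invalid_argument("Split dimension must be known at compile time");
    if (inputShape[axis] % numOfResults != 0)
      throw std::invalid_argument("Split dimension not divisible by the number of results");
    splitLengths.assign(numOfResults, inputShape[axis] / numOfResults);
  }

  // Each result keeps every dimension except the one being split.
  std::vector<Shape> results;
  for (int i = 0; i < numOfResults; ++i) {
    Shape resultShape = inputShape;
    resultShape[axis] = splitLengths[i];
    results.push_back(resultShape);
  }
  return results;
}

int main() {
  // Mirrors test_split_2 below: axis = -2 on a 16x32x64 input with two
  // results and no split attribute yields two 16x16x64 tensors.
  for (const Shape &s : inferSplitShapes({16, 32, 64}, -2, 2, std::nullopt)) {
    for (int64_t d : s)
      std::cout << d << ' ';
    std::cout << '\n';
  }
  return 0;
}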
@@ -3241,7 +3241,7 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth",
}

def ONNXSplitOp:ONNX_Op<"Split",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  let summary = "ONNX Split operation";
  let description = [{
  "Split a tensor into a list of tensors, along the specified"
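Note on the ODS change above: declaring DeclareOpInterfaceMethods<ShapeInferenceOpInterface> makes TableGen add the interface's inferShapes() hook to the generated ONNXSplitOp class, which is what the new definition in the first hunk implements. The snippet below is only a rough, plain-C++ analogy of that pattern, not the actual MLIR interface machinery; all names in it are invented for illustration.

// Rough analogy: ops that opt into shape inference expose an
// inferShapes() hook that a pass can call generically.
#include <iostream>
#include <memory>
#include <vector>

struct ShapeInferenceLike {
  virtual ~ShapeInferenceLike() = default;
  virtual bool inferShapes() = 0; // analogous to the ODS-declared method
};

struct SplitOpLike : ShapeInferenceLike {
  bool inferShapes() override {
    std::cout << "inferring result shapes for Split\n";
    return true;
  }
};

int main() {
  std::vector<std::unique_ptr<ShapeInferenceLike>> ops;
  ops.push_back(std::make_unique<SplitOpLike>());
  for (auto &op : ops)
    op->inferShapes(); // the pass walks the ops and invokes the hook
  return 0;
}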
@@ -124,6 +124,7 @@ public:
        op->getName().getStringRef() != "onnx.Abs" &&
        op->getName().getStringRef() != "onnx.Constant" &&
        op->getName().getStringRef() != "onnx.Concat" &&
+       op->getName().getStringRef() != "onnx.Split" &&
        op->getName().getStringRef() != "onnx.Neg" &&
        op->getName().getStringRef() != "onnx.Unsqueeze")
      return false;
@@ -610,3 +610,36 @@ func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg
  // CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
  // CHECK: return [[RES]] : tensor<5x9x32xf32>
}

// -----

func @test_split_1(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
  %0, %1 = "onnx.Split"(%arg0) { axis = 1 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_split_1
  // CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [16, 16]} : (tensor<16x32x64xf32>) -> (tensor<16x16x64xf32>, tensor<16x16x64xf32>)
  // CHECK: return [[RES]]#0 : tensor<16x16x64xf32>
}

// -----

func @test_split_2(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
  %0, %1 = "onnx.Split"(%arg0) { axis = -2 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_split_2
  // CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [16, 16]} : (tensor<16x32x64xf32>) -> (tensor<16x16x64xf32>, tensor<16x16x64xf32>)
  // CHECK: return [[RES]]#0 : tensor<16x16x64xf32>
}

// -----

func @test_split_3(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
  %0, %1 = "onnx.Split"(%arg0) { axis = 1, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_split_3
  // CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<16x2x64xf32>, tensor<16x30x64xf32>)
  // CHECK: return [[RES]]#0 : tensor<16x2x64xf32>
}
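As a small cross-check of test_split_3 above, the following standalone C++ snippet computes the expected result types for an explicit split attribute. It is illustrative only; it simply mirrors the result-shape loop in ONNXSplitOp::inferShapes().

// Illustrative check of test_split_3: with split = [2, 30] on axis 1 of a
// 16x32x64 input, each result keeps every dimension except the split axis.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> inputShape = {16, 32, 64};
  std::vector<int64_t> splitLengths = {2, 30}; // explicit split attribute
  const int64_t axis = 1;
  for (int64_t length : splitLengths) {
    std::vector<int64_t> resultShape = inputShape;
    resultShape[axis] = length; // only the split axis changes
    for (int64_t d : resultShape)
      std::cout << d << 'x';
    std::cout << "f32\n"; // prints 16x2x64xf32, then 16x30x64xf32
  }
  return 0;
}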
@@ -63,7 +63,7 @@ OpsWithShapeInference = [
    'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
    'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
    'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
-    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg'
+    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'Split'
]

# Operations supporting canonicalization.