implement shape inference for concat (#74)

* implement shape inference for concat

* better checking of the axis being concatenated: sizes along that axis must be compile-time constants for now (a sketch of the resulting shape rule follows below)
Alexandre Eichenberger 2020-04-07 16:13:41 -04:00 committed by GitHub
parent 37399fd8b8
commit f5bed72e13
5 changed files with 84 additions and 2 deletions
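The change implements the standard ONNX Concat shape rule: all inputs must have the same rank and match on every dimension except the concatenation axis, and the output size along that axis is the sum of the inputs' sizes there. A minimal standalone C++ sketch of the rule (the helper name and plain std::vector types are illustrative only, not part of this commit):

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative helper: compute the Concat output shape from ranked input shapes.
std::vector<int64_t> concatOutputShape(
    const std::vector<std::vector<int64_t>> &inputs, int64_t axis) {
  assert(!inputs.empty() && "need at least one input");
  std::vector<int64_t> out = inputs[0];
  assert(axis >= 0 && axis < (int64_t)out.size() && "axis out of bounds");
  for (size_t i = 1; i < inputs.size(); ++i) {
    assert(inputs[i].size() == out.size() && "all inputs must share one rank");
    for (size_t j = 0; j < out.size(); ++j)
      if ((int64_t)j != axis)
        assert(inputs[i][j] == out[j] && "non-axis dimensions must match");
    out[axis] += inputs[i][axis]; // only the axis dimension accumulates
  }
  return out;
}
// Example: concatOutputShape({{5,5,1,32}, {5,5,3,32}, {5,5,5,32}}, 2) yields
// {5,5,9,32}, matching the lit tests added at the end of this commit.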

View File

@@ -46,7 +46,7 @@ OpsWithShapeInference = [
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
- 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv'
+ 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat'
]
# Operations supporting canonicalization.

View File

@@ -1463,6 +1463,65 @@ bool ONNXConstantOp::inferShapes() {
return true;
}
//===----------------------------------------------------------------------===//
// Concat
bool ONNXConcatOp::inferShapes() {
  int inputNum = getNumOperands();
  for (int i = 0; i < inputNum; ++i) {
    if (!getOperand(i).getType().isa<RankedTensorType>()) {
      emitError("Input tensor(s) not ranked");
      return false;
    }
  }
  // Check the value of the axis parameter.
  auto commonType = getOperand(0).getType().cast<RankedTensorType>();
  auto commonShape = commonType.getShape();
  auto commonRank = commonShape.size();
  auto axisIndex = axis().getSExtValue();
  if (!(axisIndex >= 0 && axisIndex < (int64_t)commonRank)) {
    emitError("Concat axis value out of bounds");
    return false;
  }
  // The initial cumulative size is that of the first operand.
  int cumulativeAxisSize = commonShape[axisIndex];
  // Compute the cumulative size over all other operands, and make sure that
  // the remaining dimensions all match.
  for (int i = 1; i < inputNum; ++i) {
    auto currShape =
        getOperand(i).getType().cast<RankedTensorType>().getShape();
    if (currShape.size() != commonRank) {
      emitError("Concat inputs must all have the same rank");
      return false;
    }
    for (int j = 0; j < commonRank; ++j) {
      if (j == axisIndex) {
        // The size along the concatenation axis must be known (positive) at
        // compile time for now.
        if (currShape[j] <= 0) {
          emitError("Concat axis dimension is expected to be known at "
                    "compile time for now");
          return false;
        }
      } else if (currShape[j] != commonShape[j]) {
        emitError("Concat input dimensions must all be identical, except for "
                  "the dimension on the concatenation axis");
        return false;
      }
    }
    cumulativeAxisSize += currShape[axisIndex];
  }
  // Set the output shape and type.
  SmallVector<int64_t, 4> outputDims;
  for (int j = 0; j < commonRank; ++j)
    outputDims.emplace_back(
        j == axisIndex ? cumulativeAxisSize : commonShape[j]);
  getResult().setType(
      RankedTensorType::get(outputDims, commonType.getElementType()));
  return true;
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

View File

@@ -314,7 +314,7 @@ def ONNXCompressOp:ONNX_Op<"Compress",
}
def ONNXConcatOp:ONNX_Op<"Concat",
- [NoSideEffect]> {
+ [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Concat operation";
let description = [{
"Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on."

View File

@@ -123,6 +123,7 @@ public:
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
op->getName().getStringRef() != "onnx.Abs" &&
op->getName().getStringRef() != "onnx.Constant" &&
op->getName().getStringRef() != "onnx.Concat" &&
op->getName().getStringRef() != "onnx.Unsqueeze")
return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
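The denylist above is part of the check that decides which ops the pass will try to refine; once onnx.Concat is admitted, the pass walk can reach the new hook through the ShapeInference interface, roughly as in the sketch below (a paraphrase under assumptions, not the exact pass source; the diagnostic text is illustrative):

// Sketch: dispatching to an op's inferShapes() via the interface.
if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
  if (!shapeOp.inferShapes())
    op->emitError("shape inference failed"); // actual wording may differ
}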

View File

@@ -487,3 +487,25 @@ func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<80x5x2xf32>
// CHECK: return [[RES]] : tensor<80x5x2xf32>
}
//===----------------------------------------------------------------------===//
/// Test shape inference when Concat ops are present.
//===----------------------------------------------------------------------===//
func @test_concat_1(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>, %arg2 : tensor<5x5x5x32xf32>) -> tensor<*xf32> {
  %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 2 } : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<*xf32>
  "std.return"(%1) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_concat_1
  // CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 2 : i64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32>
  // CHECK: return [[RES]] : tensor<5x5x9x32xf32>
}

func @test_concat_2(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg2 : tensor<5x5x32xf32>) -> tensor<*xf32> {
  %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 1 } : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<*xf32>
  "std.return"(%1) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_concat_2
  // CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
  // CHECK: return [[RES]] : tensor<5x9x32xf32>
}
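As a quick check of the expected result types: in test_concat_1 the axis-2 sizes sum to 1 + 3 + 5 = 9, giving tensor<5x5x9x32xf32>, and in test_concat_2 the axis-1 sizes likewise sum to 9, giving tensor<5x9x32xf32>; every other dimension is carried through unchanged.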