implement shape inference for concat (#74)
* implement shape inference for concat * better checking of axis being concatenated: constant values only
This commit is contained in:
parent
37399fd8b8
commit
f5bed72e13
|
@ -46,7 +46,7 @@ OpsWithShapeInference = [
|
|||
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
||||
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv'
|
||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat'
|
||||
]
|
||||
|
||||
# Operations supporting canonicalization.
|
||||
|
|
|
@ -1463,6 +1463,65 @@ bool ONNXConstantOp::inferShapes() {
|
|||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Concat
|
||||
|
||||
bool ONNXConcatOp::inferShapes() {
|
||||
int inputNum = getNumOperands();
|
||||
for (int i = 0; i < inputNum; ++i) {
|
||||
if (!getOperand(i).getType().cast<RankedTensorType>()) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Checking value of axis parameter.
|
||||
auto commonType = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto commonShape = commonType.getShape();
|
||||
auto commonRank = commonShape.size();
|
||||
auto axisIndex = axis().getSExtValue();
|
||||
if (!(axisIndex >= 0 && axisIndex < commonRank)) {
|
||||
emitError("Concat axis value out of bound");
|
||||
return false;
|
||||
}
|
||||
// Initial cummlative size is that of the first operand.
|
||||
int cummulativeAxisSize = commonShape[axisIndex];
|
||||
|
||||
// Compute the cummlative size with all of the other ones, and make sure that
|
||||
// the other sizes are all alike.
|
||||
for (int i = 1; i < inputNum; ++i) {
|
||||
auto currShape =
|
||||
getOperand(i).getType().cast<RankedTensorType>().getShape();
|
||||
if (currShape.size() != commonRank) {
|
||||
emitError("Concat input must all have the same rank");
|
||||
return false;
|
||||
}
|
||||
for (int j = 0; j < commonRank; ++j) {
|
||||
if (j == axisIndex) {
|
||||
// Check that the value is positive.
|
||||
if (currShape[j] <= 0) {
|
||||
emitError("Concat axis being concatenated is expected to be known at "
|
||||
"compile time for now");
|
||||
return false;
|
||||
}
|
||||
} else if (currShape[j] != commonShape[j]) {
|
||||
emitError("Concat input dimensions must be all identical, except for "
|
||||
"dimension on the axis of the concatenation");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
cummulativeAxisSize += currShape[axisIndex];
|
||||
}
|
||||
|
||||
// Set output size and type
|
||||
SmallVector<int64_t, 4> outputDims;
|
||||
for (int j = 0; j < commonRank; ++j)
|
||||
outputDims.emplace_back(
|
||||
j == axisIndex ? cummulativeAxisSize : commonShape[j]);
|
||||
getResult().setType(
|
||||
RankedTensorType::get(outputDims, commonType.getElementType()));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -314,7 +314,7 @@ def ONNXCompressOp:ONNX_Op<"Compress",
|
|||
}
|
||||
|
||||
def ONNXConcatOp:ONNX_Op<"Concat",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX Concat operation";
|
||||
let description = [{
|
||||
"Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on."
|
||||
|
|
|
@ -123,6 +123,7 @@ public:
|
|||
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
|
||||
op->getName().getStringRef() != "onnx.Abs" &&
|
||||
op->getName().getStringRef() != "onnx.Constant" &&
|
||||
op->getName().getStringRef() != "onnx.Concat" &&
|
||||
op->getName().getStringRef() != "onnx.Unsqueeze")
|
||||
return false;
|
||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||
|
|
|
@ -487,3 +487,25 @@ func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
|||
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<80x5x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<80x5x2xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Test the reshape op inference when concat are present.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @test_concat_1(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>, %arg2 : tensor<5x5x5x32xf32>) -> tensor<*xf32> {
|
||||
%1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 2 } : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_concat_1
|
||||
// CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 2 : i64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32>
|
||||
// CHECK: return [[RES]] : tensor<5x5x9x32xf32>
|
||||
}
|
||||
|
||||
func @test_concat_2(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg2 : tensor<5x5x32xf32>) -> tensor<*xf32> {
|
||||
%1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 1 } : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_concat_2
|
||||
// CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
|
||||
// CHECK: return [[RES]] : tensor<5x9x32xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue