implement shape inference for concat (#74)
* implement shape inference for concat * better checking of axis being concatenated: constant values only
This commit is contained in:
parent
37399fd8b8
commit
f5bed72e13
|
@ -46,7 +46,7 @@ OpsWithShapeInference = [
|
||||||
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
||||||
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
||||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv'
|
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat'
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
|
|
|
@ -1463,6 +1463,65 @@ bool ONNXConstantOp::inferShapes() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Concat
|
||||||
|
|
||||||
|
bool ONNXConcatOp::inferShapes() {
|
||||||
|
int inputNum = getNumOperands();
|
||||||
|
for (int i = 0; i < inputNum; ++i) {
|
||||||
|
if (!getOperand(i).getType().cast<RankedTensorType>()) {
|
||||||
|
emitError("Input tensor(s) not ranked");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Checking value of axis parameter.
|
||||||
|
auto commonType = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
|
auto commonShape = commonType.getShape();
|
||||||
|
auto commonRank = commonShape.size();
|
||||||
|
auto axisIndex = axis().getSExtValue();
|
||||||
|
if (!(axisIndex >= 0 && axisIndex < commonRank)) {
|
||||||
|
emitError("Concat axis value out of bound");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Initial cummlative size is that of the first operand.
|
||||||
|
int cummulativeAxisSize = commonShape[axisIndex];
|
||||||
|
|
||||||
|
// Compute the cummlative size with all of the other ones, and make sure that
|
||||||
|
// the other sizes are all alike.
|
||||||
|
for (int i = 1; i < inputNum; ++i) {
|
||||||
|
auto currShape =
|
||||||
|
getOperand(i).getType().cast<RankedTensorType>().getShape();
|
||||||
|
if (currShape.size() != commonRank) {
|
||||||
|
emitError("Concat input must all have the same rank");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int j = 0; j < commonRank; ++j) {
|
||||||
|
if (j == axisIndex) {
|
||||||
|
// Check that the value is positive.
|
||||||
|
if (currShape[j] <= 0) {
|
||||||
|
emitError("Concat axis being concatenated is expected to be known at "
|
||||||
|
"compile time for now");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else if (currShape[j] != commonShape[j]) {
|
||||||
|
emitError("Concat input dimensions must be all identical, except for "
|
||||||
|
"dimension on the axis of the concatenation");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cummulativeAxisSize += currShape[axisIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set output size and type
|
||||||
|
SmallVector<int64_t, 4> outputDims;
|
||||||
|
for (int j = 0; j < commonRank; ++j)
|
||||||
|
outputDims.emplace_back(
|
||||||
|
j == axisIndex ? cummulativeAxisSize : commonShape[j]);
|
||||||
|
getResult().setType(
|
||||||
|
RankedTensorType::get(outputDims, commonType.getElementType()));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -314,7 +314,7 @@ def ONNXCompressOp:ONNX_Op<"Compress",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXConcatOp:ONNX_Op<"Concat",
|
def ONNXConcatOp:ONNX_Op<"Concat",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Concat operation";
|
let summary = "ONNX Concat operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on."
|
"Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on."
|
||||||
|
|
|
@ -123,6 +123,7 @@ public:
|
||||||
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
|
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
|
||||||
op->getName().getStringRef() != "onnx.Abs" &&
|
op->getName().getStringRef() != "onnx.Abs" &&
|
||||||
op->getName().getStringRef() != "onnx.Constant" &&
|
op->getName().getStringRef() != "onnx.Constant" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Concat" &&
|
||||||
op->getName().getStringRef() != "onnx.Unsqueeze")
|
op->getName().getStringRef() != "onnx.Unsqueeze")
|
||||||
return false;
|
return false;
|
||||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||||
|
|
|
@ -487,3 +487,25 @@ func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<80x5x2xf32>
|
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<80x5x2xf32>
|
||||||
// CHECK: return [[RES]] : tensor<80x5x2xf32>
|
// CHECK: return [[RES]] : tensor<80x5x2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test the reshape op inference when concat are present.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @test_concat_1(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>, %arg2 : tensor<5x5x5x32xf32>) -> tensor<*xf32> {
|
||||||
|
%1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 2 } : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_concat_1
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 2 : i64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<5x5x9x32xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_concat_2(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg2 : tensor<5x5x32xf32>) -> tensor<*xf32> {
|
||||||
|
%1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 1 } : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_concat_2
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<5x9x32xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue