Add the default shape inference for the transposition operation.
This commit is contained in:
parent
caeba371fb
commit
151f4f8c44
|
@ -267,7 +267,7 @@ def gen_schema(schema) :
|
||||||
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
||||||
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
||||||
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||||
'Identity', 'Cos', 'Log']
|
'Identity', 'Cos', 'Log', 'Transpose']
|
||||||
CanonicalList=['Add', 'Identity']
|
CanonicalList=['Add', 'Identity']
|
||||||
line_indent = ' '
|
line_indent = ' '
|
||||||
|
|
||||||
|
|
|
@ -396,6 +396,24 @@ void ONNXReshapeOp::inferShapes() {
|
||||||
RankedTensorType::get(dims, inputTensorTy.getElementType()));
|
RankedTensorType::get(dims, inputTensorTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Transpose
|
||||||
|
|
||||||
|
void ONNXTransposeOp::inferShapes() {
|
||||||
|
// Cannot infer shape if no shape exists.
|
||||||
|
if (!getOperand()->getType().isa<RankedTensorType>())
|
||||||
|
emitError("Shape tensor not ranked.");
|
||||||
|
|
||||||
|
// Naive transposition which handles the default case of
|
||||||
|
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||||
|
// TODO: Once attributes are supported we can handle the case where the
|
||||||
|
// transposition uses a permutation vector to interchange the axes.
|
||||||
|
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
|
||||||
|
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||||
|
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -3089,7 +3089,7 @@ def ONNXTopKOp:ONNX_Op<"TopK",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXTransposeOp:ONNX_Op<"Transpose",
|
def ONNXTransposeOp:ONNX_Op<"Transpose",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Transpose operation";
|
let summary = "ONNX Transpose operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Transpose the input tensor similar to numpy.transpose. For example, when"
|
"Transpose the input tensor similar to numpy.transpose. For example, when"
|
||||||
|
|
|
@ -115,7 +115,8 @@ public:
|
||||||
op->getName().getStringRef() != "onnx.MatMul" &&
|
op->getName().getStringRef() != "onnx.MatMul" &&
|
||||||
op->getName().getStringRef() != "onnx.Gemm" &&
|
op->getName().getStringRef() != "onnx.Gemm" &&
|
||||||
op->getName().getStringRef() != "onnx.FullGemm" &&
|
op->getName().getStringRef() != "onnx.FullGemm" &&
|
||||||
op->getName().getStringRef() != "onnx.Reshape")
|
op->getName().getStringRef() != "onnx.Reshape" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Transpose")
|
||||||
return false;
|
return false;
|
||||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||||
return !result_type.isa<RankedTensorType>();
|
return !result_type.isa<RankedTensorType>();
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
// RUN: onnf-opt --shape-inference %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
/// Test the default behavior of transpose when no information for the
|
||||||
|
/// permutation of the axes is provided.
|
||||||
|
func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_default_transpose
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
|
Loading…
Reference in New Issue