Add the default shape inference for the transposition operation.
This commit is contained in:
parent
caeba371fb
commit
151f4f8c44
|
@ -267,7 +267,7 @@ def gen_schema(schema) :
|
|||
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
||||
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
||||
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||
'Identity', 'Cos', 'Log']
|
||||
'Identity', 'Cos', 'Log', 'Transpose']
|
||||
CanonicalList=['Add', 'Identity']
|
||||
line_indent = ' '
|
||||
|
||||
|
|
|
@ -396,6 +396,24 @@ void ONNXReshapeOp::inferShapes() {
|
|||
RankedTensorType::get(dims, inputTensorTy.getElementType()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Transpose
|
||||
|
||||
void ONNXTransposeOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!getOperand()->getType().isa<RankedTensorType>())
|
||||
emitError("Shape tensor not ranked.");
|
||||
|
||||
// Naive transposition which handles the default case of
|
||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||
// TODO: Once attributes are supported we can handle the case where the
|
||||
// transposition uses a permutation vector to interchange the axes.
|
||||
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -3089,7 +3089,7 @@ def ONNXTopKOp:ONNX_Op<"TopK",
|
|||
}
|
||||
|
||||
def ONNXTransposeOp:ONNX_Op<"Transpose",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX Transpose operation";
|
||||
let description = [{
|
||||
"Transpose the input tensor similar to numpy.transpose. For example, when"
|
||||
|
|
|
@ -115,7 +115,8 @@ public:
|
|||
op->getName().getStringRef() != "onnx.MatMul" &&
|
||||
op->getName().getStringRef() != "onnx.Gemm" &&
|
||||
op->getName().getStringRef() != "onnx.FullGemm" &&
|
||||
op->getName().getStringRef() != "onnx.Reshape")
|
||||
op->getName().getStringRef() != "onnx.Reshape" &&
|
||||
op->getName().getStringRef() != "onnx.Transpose")
|
||||
return false;
|
||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||
return !result_type.isa<RankedTensorType>();
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
// RUN: onnf-opt --shape-inference %s -split-input-file | FileCheck %s
|
||||
|
||||
/// Test the default behavior of transpose when no information for the
|
||||
/// permutation of the axes is provided.
|
||||
func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_default_transpose
|
||||
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
|
||||
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
|
Loading…
Reference in New Issue