Add support for legalizing lmhlo.transpose to linalg.generic

PiperOrigin-RevId: 330130704
This commit is contained in:
Eugene Burmako 2020-09-04 15:37:02 -07:00 committed by TensorFlow MLIR Team
parent f5d12604ed
commit 2a9d7ac084
2 changed files with 15 additions and 1 deletions

View File

@ -841,7 +841,8 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
ReshapeOpConverter<lmhlo::ReshapeOp>,
ReverseConverter<lmhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
SliceConverter
SliceConverter,
TransposeConverter<lmhlo::TransposeOp>
>(context);
// clang-format on
}

View File

@ -748,3 +748,16 @@ func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: m
"lmhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
// CHECK-DAG: #[[TRANSPOSE_INPUT_MAP:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-DAG: #[[TRANSPOSE_OUTPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @transpose
func @transpose(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
"lmhlo.transpose"(%arg0, %arg1) {
permutation = dense<[1, 0]> : tensor<2xi64>
} : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[TRANSPOSE_INPUT_MAP]], #[[TRANSPOSE_OUTPUT_MAP]]]