diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index dd98fb4..aeffaea 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -841,7 +841,8 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, ReshapeOpConverter, ReverseConverter, ScalarPointwiseToStandardConverter, - SliceConverter + SliceConverter, + TransposeConverter >(context); // clang-format on } diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index cdebf5b..136bee8 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -748,3 +748,16 @@ func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: m "lmhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> () "lmhlo.terminator"() : () -> () } + +// ----- + +// CHECK-DAG: #[[TRANSPOSE_INPUT_MAP:.*]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-DAG: #[[TRANSPOSE_OUTPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @transpose +func @transpose(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { + "lmhlo.transpose"(%arg0, %arg1) { + permutation = dense<[1, 0]> : tensor<2xi64> + } : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[TRANSPOSE_INPUT_MAP]], #[[TRANSPOSE_OUTPUT_MAP]]]