diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 80d089e..ec388a2 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -752,6 +752,96 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
   }
 };
 
+class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
+ public:
+  using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      lmhlo::ReduceOp reduce_op, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const final {
+    auto loc = reduce_op.getLoc();
+    lmhlo::ReduceOp::Adaptor adaptor(args);
+    auto operand_shape =
+        adaptor.operands()[0].getType().template dyn_cast<ShapedType>();
+    if (!operand_shape || !operand_shape.hasRank()) {
+      emitError(loc, "lhlo to linalg conversion expects known-rank args");
+      return failure();
+    }
+
+    // First fill the output buffer with the init value.
+    Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
+    rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);
+
+    DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
+    SmallVector<int, 4> reduction_dims;
+    for (const auto& dim : dimensions_attr.getIntValues()) {
+      reduction_dims.push_back(dim.getSExtValue());
+    }
+
+    SmallVector<AffineExpr, 2> src_exprs;
+    SmallVector<AffineExpr, 2> dst_exprs;
+    SmallVector<StringRef, 4> types;
+    for (int i = 0, rank = operand_shape.getRank(); i != rank; ++i) {
+      bool is_reduced = llvm::is_contained(reduction_dims, i);
+      types.push_back(is_reduced ? getReductionIteratorTypeName()
+                                 : getParallelIteratorTypeName());
+
+      src_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
+      if (!is_reduced) {
+        dst_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
+      }
+    }
+
+    auto maps = AffineMap::inferFromExprList({src_exprs, dst_exprs});
+
+    auto linalg_op = rewriter.create<linalg::GenericOp>(
+        loc, /*resultTensorTypes=*/ArrayRef<Type>{},
+        /*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(),
+        /*initTensors=*/ValueRange{}, maps, types);
+    linalg_op.region().takeBody(reduce_op.body());
+    {
+      OpBuilder::InsertionGuard region_guard(rewriter);
+      Block* block = linalg_op.getBody();
+      rewriter.setInsertionPoint(&block->front());
+
+      // The incoming region is operating on buffers, while linalg.generic
+      // expects scalar SSA values. Add some allocs around the original op to
+      // make it compatible.
+      auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
+      Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type);
+      Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type);
+      Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type);
+
+      // Now turn the existing signature
+      //   (memref<X>, memref<X>, memref<X>) -> ()
+      // into
+      //   (X, X) -> X
+      TypeConverter::SignatureConversion signature_converter(3);
+      signature_converter.remapInput(0, alloc_a);
+      signature_converter.remapInput(1, alloc_b);
+      signature_converter.remapInput(2, alloc_res);
+      signature_converter.addInputs(
+          {arg_type.getElementType(), arg_type.getElementType()});
+      Block* entry_block = rewriter.applySignatureConversion(
+          &linalg_op.region(), signature_converter);
+
+      // Store the arguments into the newly allocated buffers.
+      rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
+      rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a);
+      rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b);
+      rewriter.replaceOp(entry_block->getTerminator(), {});
+
+      // Load & yield the result.
+      rewriter.setInsertionPointToEnd(entry_block);
+      auto load_res = rewriter.create<LoadOp>(loc, alloc_res);
+      rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
+    }
+
+    rewriter.replaceOp(reduce_op, linalg_op.getOperation()->getResults());
+    return success();
+  }
+};
+
 // TODO(b/156787842): Support the lowering for dynamic shapes.
 template <typename OpTy, bool isLHLO = true>
 class ReverseConverter
@@ -853,9 +943,11 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                    PointwiseToLinalgConverter<lmhlo::SqrtOp>,
                    PointwiseToLinalgConverter<lmhlo::SubOp>,
                    PointwiseToLinalgConverter<lmhlo::TanhOp>,
+                   ReduceConverter,
                    ReshapeOpConverter<lmhlo::ReshapeOp>,
                    ReverseConverter<lmhlo::ReverseOp>,
                    ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
+                   ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
                    SliceConverter,
                    TransposeConverter<lmhlo::TransposeOp>
                    >(context);
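Reviewer note: for the 2D add-reduction along dimension 1 exercised by the lit
test below, the pattern above is expected to emit IR of roughly the following
shape. This is a hand-written sketch illustrating the alloca/signature-conversion
trick, not actual pass output; the SSA names are illustrative only.

  #map_in  = affine_map<(d0, d1) -> (d0, d1)>
  #map_out = affine_map<(d0, d1) -> (d0)>

  // Fill the output with the init value, then reduce dimension 1.
  %init_val = load %init[] : memref<f32>
  linalg.fill(%result, %init_val) : memref<100xf32>, f32
  linalg.generic {indexing_maps = [#map_in, #map_out],
                  iterator_types = ["parallel", "reduction"]}
      ins(%operand : memref<100x10xf32>) outs(%result : memref<100xf32>) {
  ^bb0(%elem: f32, %acc: f32):
    // Scratch buffers bridge the buffer-based reduction body to the
    // scalar block arguments that linalg.generic provides.
    %a = alloca() : memref<f32>
    %b = alloca() : memref<f32>
    %res = alloca() : memref<f32>
    store %elem, %a[] : memref<f32>
    store %acc, %b[] : memref<f32>
    // The original lmhlo.add body, lowered to scalar ops on the allocas.
    %lhs = load %a[] : memref<f32>
    %rhs = load %b[] : memref<f32>
    %sum = addf %lhs, %rhs : f32
    store %sum, %res[] : memref<f32>
    // Load the result back out and yield it into the output buffer.
    %out = load %res[] : memref<f32>
    linalg.yield %out : f32
  }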
diff --git a/tests/end2end/reduce.mlir b/tests/end2end/reduce.mlir
new file mode 100644
index 0000000..8eb6553
--- /dev/null
+++ b/tests/end2end/reduce.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
+// RUN: -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting \
+// RUN: -buffer-deallocation -copy-removal -canonicalize -cse \
+// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
+// RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \
+// RUN: -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main \
+// RUN: -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @main() -> () {
+  call @reduce_add() : () -> ()
+  call @reduce_max() : () -> ()
+  return
+}
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+func @reduce_add() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // Initialize input.
+  %input = alloc() : memref<2x3xf32>
+  %dim_x = dim %input, %c0 : memref<2x3xf32>
+  %dim_y = dim %input, %c1 : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
+    %i_i64 = index_cast %i : index to i64
+    %i_f32 = sitofp %i_i64 : i64 to f32
+    store %i_f32, %input[%i, %j] : memref<2x3xf32>
+  }
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK: [0, 0, 0]
+  // CHECK: [1, 1, 1]
+
+  %in = tensor_load %input : memref<2x3xf32>
+  %init = mhlo.constant dense<0.000000e+00> : tensor<f32>
+
+  %reduce = "mhlo.reduce"(%in, %init) ( {
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {dimensions = dense<1> : tensor<1xi64>}
+      : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
+
+  %output = alloc() : memref<2xf32>
+  tensor_store %reduce, %output : memref<2xf32>
+  %unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
+  // CHECK: [0, 3]
+  return
+}
+
+func @reduce_max() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // Initialize input.
+  %input = alloc() : memref<2x3xf32>
+  %dim_x = dim %input, %c0 : memref<2x3xf32>
+  %dim_y = dim %input, %c1 : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
+    %i_i64 = index_cast %i : index to i64
+    %i_f32 = sitofp %i_i64 : i64 to f32
+    store %i_f32, %input[%i, %j] : memref<2x3xf32>
+  }
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK: [0, 0, 0]
+  // CHECK: [1, 1, 1]
+
+  %in = tensor_load %input : memref<2x3xf32>
+  %init = mhlo.constant dense<0xff800000> : tensor<f32>
+
+  %reduce = "mhlo.reduce"(%in, %init) ( {
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.maximum %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {dimensions = dense<1> : tensor<1xi64>}
+      : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
+
+  %output = alloc() : memref<2xf32>
+  tensor_store %reduce, %output : memref<2xf32>
+  %unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
+  // CHECK: [0, 1]
+  return
+}
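Reviewer note: after -convert-linalg-to-loops and the cleanup passes in the RUN
pipeline above, the reduce_add case should boil down to a loop nest roughly like
the sketch below. This is a simplified, hand-written approximation that ignores
the scalar scratch allocas inside the loop body; names are illustrative only.

  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %c3 = constant 3 : index
  // Outer loop over the parallel dimension, inner over the reduced one.
  scf.for %i = %c0 to %c2 step %c1 {
    scf.for %j = %c0 to %c3 step %c1 {
      %elem = load %input[%i, %j] : memref<2x3xf32>
      %acc = load %output[%i] : memref<2xf32>
      %sum = addf %elem, %acc : f32
      store %sum, %output[%i] : memref<2xf32>
    }
  }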
diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir
index 4715108..8548300 100644
--- a/tests/lhlo-legalize-to-linalg.mlir
+++ b/tests/lhlo-legalize-to-linalg.mlir
@@ -846,3 +846,76 @@ func @transpose(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
   return
 }
 // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[TRANSPOSE_INPUT_MAP]], #[[TRANSPOSE_OUTPUT_MAP]]]
+
+// -----
+
+// CHECK-DAG: #[[REDUCE_INPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[REDUCE_OUTPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL: func @reduce_add
+func @reduce_add(%arg: memref<100x10xf32>,
+                 %init: memref<f32>,
+                 %result: memref<100xf32>) {
+  "lmhlo.reduce"(%arg, %init, %result) ( {
+  ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
+    "lmhlo.add"(%lhs, %rhs, %res)
+        : (memref<f32>, memref<f32>, memref<f32>) -> ()
+    "lmhlo.terminator"() : () -> ()
+  } ) {dimensions = dense<[1]> : tensor<1xi64>}
+      : (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
+  return
+}
+// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
+// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
+// CHECK: linalg.generic {
+// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
+// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
+// CHECK: alloca
+// CHECK-NEXT: alloca
+// CHECK-NEXT: alloca
+// CHECK-NEXT: store
+// CHECK-NEXT: store
+// CHECK-NEXT: load
+// CHECK-NEXT: load
+// CHECK-NEXT: addf
+// CHECK-NEXT: store
+// CHECK-NEXT: load
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: }
+
+// -----
+
+// CHECK-DAG: #[[REDUCE_INPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[REDUCE_OUTPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL: func @reduce_maximum
+func @reduce_maximum(%arg: memref<100x10xf32>,
+                     %init: memref<f32>,
+                     %result: memref<100xf32>) {
+  "lmhlo.reduce"(%arg, %init, %result) ( {
+  ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
+    "lmhlo.maximum"(%lhs, %rhs, %res)
+        : (memref<f32>, memref<f32>, memref<f32>) -> ()
+    "lmhlo.terminator"() : () -> ()
+  } ) {dimensions = dense<[1]> : tensor<1xi64>}
+      : (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
+  return
+}
+// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
+// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
+// CHECK: linalg.generic {
+// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
+// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
+// CHECK: alloca
+// CHECK-NEXT: alloca
+// CHECK-NEXT: alloca
+// CHECK-NEXT: store
+// CHECK-NEXT: store
+// CHECK-NEXT: load
+// CHECK-NEXT: load
+// CHECK-NEXT: cmpf
+// CHECK-NEXT: select
+// CHECK-NEXT: store
+// CHECK-NEXT: load
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: }
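Reviewer note: the cmpf/select pair matched above comes from the
ScalarPointwiseToStandardConverter<lmhlo::MaxOp> pattern registered in this
change. Inside the reduction body, the scalar lmhlo.maximum is expected to
lower to something like the sketch below (the "ogt" predicate is an assumption
here, and names are illustrative only):

  // Scalar max on the scratch buffers, expressed with standard-dialect ops.
  %lhs = load %a[] : memref<f32>
  %rhs = load %b[] : memref<f32>
  %pred = cmpf "ogt", %lhs, %rhs : f32
  %max = select %pred, %lhs, %rhs : f32
  store %max, %res[] : memref<f32>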