[MLIR] Add a lmhlo.reduce -> linalg.generic converter

Doesn't support tensors right now, as it's somewhat hairy to support both at
the same time. Since we use a generic lowering the result is messy
and needs a mem2reg pass to eliminate extra load/store/allocas.

PiperOrigin-RevId: 339562971
This commit is contained in:
Benjamin Kramer 2020-10-28 16:37:38 -07:00 committed by TensorFlow MLIR Team
parent e58bfd48e6
commit 3bf4277ea4
3 changed files with 258 additions and 0 deletions

View File

@ -752,6 +752,96 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
}
};
class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
public:
using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
lmhlo::ReduceOp reduce_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = reduce_op.getLoc();
lmhlo::ReduceOp::Adaptor adaptor(args);
auto operand_shape =
adaptor.operands()[0].getType().template dyn_cast<ShapedType>();
if (!operand_shape || !operand_shape.hasRank()) {
emitError(loc, "lhlo to linalg conversion expects known-rank args");
return failure();
}
// First fill the output buffer with the init value.
Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);
DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
SmallVector<int, 4> reduction_dims;
for (const auto& dim : dimensions_attr.getIntValues()) {
reduction_dims.push_back(dim.getSExtValue());
}
SmallVector<AffineExpr, 2> src_exprs;
SmallVector<AffineExpr, 2> dst_exprs;
SmallVector<StringRef, 4> types;
for (int i = 0, rank = operand_shape.getRank(); i != rank; ++i) {
bool is_reduced = llvm::is_contained(reduction_dims, i);
types.push_back(is_reduced ? getReductionIteratorTypeName()
: getParallelIteratorTypeName());
src_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
if (!is_reduced) {
dst_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
}
}
auto maps = AffineMap::inferFromExprList({src_exprs, dst_exprs});
auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/ArrayRef<Type>{},
/*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(),
/*initTensors=*/ValueRange{}, maps, types);
linalg_op.region().takeBody(reduce_op.body());
{
OpBuilder::InsertionGuard region_guard(rewriter);
Block* block = linalg_op.getBody();
rewriter.setInsertionPoint(&block->front());
// The incoming region is operating on buffers, while linalg.generic
// expects scalar SSA values. Add some allocs around the original op to
// make it compatible.
auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type);
Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type);
Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type);
// Now turn the existing signature
// (memref<X>, memref<X>, memref<X>) -> ()
// into
// (X, X) -> X
TypeConverter::SignatureConversion signature_converter(3);
signature_converter.remapInput(0, alloc_a);
signature_converter.remapInput(1, alloc_b);
signature_converter.remapInput(2, alloc_res);
signature_converter.addInputs(
{arg_type.getElementType(), arg_type.getElementType()});
Block* entry_block = rewriter.applySignatureConversion(
&linalg_op.region(), signature_converter);
// Store the arguments into the newly allocated buffers.
rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a);
rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b);
rewriter.replaceOp(entry_block->getTerminator(), {});
// Load & yield the result.
rewriter.setInsertionPointToEnd(entry_block);
auto load_res = rewriter.create<LoadOp>(loc, alloc_res);
rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
}
rewriter.replaceOp(reduce_op, linalg_op.getOperation()->getResults());
return success();
}
};
// TODO(b/156787842): Support the lowering for dynamic shapes.
template <typename OpTy, bool isLHLO = true>
class ReverseConverter
@ -853,9 +943,11 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::SubOp>,
PointwiseToLinalgConverter<lmhlo::TanhOp>,
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
ReduceConverter,
ReshapeOpConverter<lmhlo::ReshapeOp>,
ReverseConverter<lmhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
SliceConverter,
TransposeConverter<lmhlo::TransposeOp>
>(context);

93
tests/end2end/reduce.mlir Normal file
View File

@ -0,0 +1,93 @@
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
// RUN: -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting \
// RUN: -buffer-deallocation -copy-removal -canonicalize -cse \
// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
// RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \
// RUN: -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main \
// RUN: -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
// RUN: FileCheck %s
func @main() -> () {
call @reduce_add() : () -> ()
call @reduce_max() : () -> ()
return
}
func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
func @reduce_add() {
%c0 = constant 0 : index
%c1 = constant 1 : index
// Initialize input.
%input = alloc() : memref<2x3xf32>
%dim_x = dim %input, %c0 : memref<2x3xf32>
%dim_y = dim %input, %c1 : memref<2x3xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
%i_i64 = index_cast %i : index to i64
%i_f32 = sitofp %i_i64 : i64 to f32
store %i_f32, %input[%i, %j] : memref<2x3xf32>
}
%unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
// CHECK: [0, 0, 0]
// CHECK: [1, 1, 1]
%in = tensor_load %input : memref<2x3xf32>
%init = mhlo.constant dense<0.000000e+00> : tensor<f32>
%reduce = "mhlo.reduce"(%in, %init) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>}
: (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
%output = alloc() : memref<2xf32>
tensor_store %reduce, %output : memref<2xf32>
%unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
// CHECK: [0, 3]
return
}
func @reduce_max() {
%c0 = constant 0 : index
%c1 = constant 1 : index
// Initialize input.
%input = alloc() : memref<2x3xf32>
%dim_x = dim %input, %c0 : memref<2x3xf32>
%dim_y = dim %input, %c1 : memref<2x3xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
%i_i64 = index_cast %i : index to i64
%i_f32 = sitofp %i_i64 : i64 to f32
store %i_f32, %input[%i, %j] : memref<2x3xf32>
}
%unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
// CHECK: [0, 0, 0]
// CHECK: [1, 1, 1]
%in = tensor_load %input : memref<2x3xf32>
%init = mhlo.constant dense<0xff800000> : tensor<f32>
%reduce = "mhlo.reduce"(%in, %init) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.maximum %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>}
: (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
%output = alloc() : memref<2xf32>
tensor_store %reduce, %output : memref<2xf32>
%unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32>
call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
// CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
// CHECK: [0, 1]
return
}

View File

@ -846,3 +846,76 @@ func @transpose(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[TRANSPOSE_INPUT_MAP]], #[[TRANSPOSE_OUTPUT_MAP]]]
// -----
// CHECK-DAG: #[[REDUCE_INPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[REDUCE_OUTPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: func @reduce_add
func @reduce_add(%arg: memref<100x10xf32>,
%init: memref<f32>,
%result: memref<100xf32>) {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
return
}
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
// CHECK: alloca
// CHECK-NEXT: alloca
// CHECK-NEXT: alloca
// CHECK-NEXT: store
// CHECK-NEXT: store
// CHECK-NEXT: load
// CHECK-NEXT: load
// CHECK-NEXT: addf
// CHECK-NEXT: store
// CHECK-NEXT: load
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: }
// -----
// CHECK-DAG: #[[REDUCE_INPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[REDUCE_OUTPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: func @reduce_maximum
func @reduce_maximum(%arg: memref<100x10xf32>,
%init: memref<f32>,
%result: memref<100xf32>) {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"lmhlo.maximum"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
return
}
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
// CHECK: alloca
// CHECK-NEXT: alloca
// CHECK-NEXT: alloca
// CHECK-NEXT: store
// CHECK-NEXT: store
// CHECK-NEXT: load
// CHECK-NEXT: load
// CHECK-NEXT: cmpf
// CHECK-NEXT: select
// CHECK-NEXT: store
// CHECK-NEXT: load
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: }