[MLIR] Add a lmhlo.reduce -> linalg.generic converter
Doesn't support tensors right now, as it's somewhat hairy to support both at the same time.

Since we use a generic lowering, the result is messy and needs a mem2reg pass to eliminate extra load/store/allocas.

PiperOrigin-RevId: 339562971
parent e58bfd48e6
commit 3bf4277ea4
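To make "messy" concrete, here is a hand-written sketch of roughly what the pattern emits for the reduce_add regression test added below, reconstructed from its FileCheck expectations (the SSA names are illustrative, not taken from the diff). The alloca/store/load traffic inside the body is exactly what a mem2reg-style cleanup should fold away:

    %init_val = load %init[] : memref<f32>
    linalg.fill(%result, %init_val) : memref<100xf32>, f32
    linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                     affine_map<(d0, d1) -> (d0)>],
                    iterator_types = ["parallel", "reduction"]}
        ins(%arg : memref<100x10xf32>) outs(%result : memref<100xf32>) {
    ^bb0(%a: f32, %b: f32):
      // Scratch buffers bridging the buffer-based lmhlo body to scalar SSA form.
      %lhs_buf = alloca() : memref<f32>
      %rhs_buf = alloca() : memref<f32>
      %res_buf = alloca() : memref<f32>
      store %a, %lhs_buf[] : memref<f32>
      store %b, %rhs_buf[] : memref<f32>
      %l = load %lhs_buf[] : memref<f32>
      %r = load %rhs_buf[] : memref<f32>
      %sum = addf %l, %r : f32
      store %sum, %res_buf[] : memref<f32>
      %out = load %res_buf[] : memref<f32>
      linalg.yield %out : f32
    }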
@@ -752,6 +752,96 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
  }
};

class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = reduce_op.getLoc();
    lmhlo::ReduceOp::Adaptor adaptor(args);
    auto operand_shape =
        adaptor.operands()[0].getType().template dyn_cast<ShapedType>();
    if (!operand_shape || !operand_shape.hasRank()) {
      emitError(loc, "lhlo to linalg conversion expects known-rank args");
      return failure();
    }

    // First fill the output buffer with the init value.
    Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
    rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);

    DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
    SmallVector<int, 4> reduction_dims;
    for (const auto& dim : dimensions_attr.getIntValues()) {
      reduction_dims.push_back(dim.getSExtValue());
    }

    SmallVector<AffineExpr, 2> src_exprs;
    SmallVector<AffineExpr, 2> dst_exprs;
    SmallVector<StringRef, 4> types;
    for (int i = 0, rank = operand_shape.getRank(); i != rank; ++i) {
      bool is_reduced = llvm::is_contained(reduction_dims, i);
      types.push_back(is_reduced ? getReductionIteratorTypeName()
                                 : getParallelIteratorTypeName());

      src_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (!is_reduced) {
        dst_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      }
    }

    auto maps = AffineMap::inferFromExprList({src_exprs, dst_exprs});

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/ArrayRef<Type>{},
        /*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(),
        /*initTensors=*/ValueRange{}, maps, types);
    linalg_op.region().takeBody(reduce_op.body());
    {
      OpBuilder::InsertionGuard region_guard(rewriter);
      Block* block = linalg_op.getBody();
      rewriter.setInsertionPoint(&block->front());

      // The incoming region is operating on buffers, while linalg.generic
      // expects scalar SSA values. Add some allocs around the original op to
      // make it compatible.
      auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
      Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type);
      Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type);
      Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type);

      // Now turn the existing signature
      //   (memref<X>, memref<X>, memref<X>) -> ()
      // into
      //   (X, X) -> X
      TypeConverter::SignatureConversion signature_converter(3);
      signature_converter.remapInput(0, alloc_a);
      signature_converter.remapInput(1, alloc_b);
      signature_converter.remapInput(2, alloc_res);
      signature_converter.addInputs(
          {arg_type.getElementType(), arg_type.getElementType()});
      Block* entry_block = rewriter.applySignatureConversion(
          &linalg_op.region(), signature_converter);

      // Store the arguments into the newly allocated buffers.
      rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
      rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a);
      rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b);
      rewriter.replaceOp(entry_block->getTerminator(), {});

      // Load & yield the result.
      rewriter.setInsertionPointToEnd(entry_block);
      auto load_res = rewriter.create<LoadOp>(loc, alloc_res);
      rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
    }

    rewriter.replaceOp(reduce_op, linalg_op.getOperation()->getResults());
    return success();
  }
};

// TODO(b/156787842): Support the lowering for dynamic shapes.
template <typename OpTy, bool isLHLO = true>
class ReverseConverter
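To make the affine-map construction concrete: for a rank-2 operand reduced over dimension 1 (the case exercised by the lit tests below), the src_exprs/dst_exprs loop above yields the following maps and iterator types. This is an illustration of the code, not additional diff content:

    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,   // source: all dims
                     affine_map<(d0, d1) -> (d0)>]       // dest: reduced d1 dropped
    iterator_types = ["parallel", "reduction"]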
@@ -853,9 +943,11 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                   PointwiseToLinalgConverter<lmhlo::SubOp>,
                   PointwiseToLinalgConverter<lmhlo::TanhOp>,
                   PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
                   ReduceConverter,
                   ReshapeOpConverter<lmhlo::ReshapeOp>,
                   ReverseConverter<lmhlo::ReverseOp>,
                   ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
                   ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
                   SliceConverter,
                   TransposeConverter<lmhlo::TransposeOp>
                  >(context);
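For context, here is a minimal sketch of how populateLHLOToLinalgConversionPattern is typically driven. The pass struct name is hypothetical and the exact dialect registration may differ from the real pass, but ConversionTarget and applyPartialConversion are the standard MLIR conversion APIs of this vintage:

struct LhloToLinalgSketchPass
    : public PassWrapper<LhloToLinalgSketchPass, FunctionPass> {
  void runOnFunction() override {
    // Ops produced by the patterns (linalg.generic, std load/store/alloca)
    // must be marked legal for the partial conversion to succeed.
    ConversionTarget target(getContext());
    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();

    OwningRewritePatternList patterns;
    populateLHLOToLinalgConversionPattern(&getContext(), &patterns);
    if (failed(applyPartialConversion(getFunction(), target, patterns)))
      signalPassFailure();
  }
};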
@@ -0,0 +1,93 @@
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
// RUN: -hlo-legalize-to-lhlo=results-escape-function=true -buffer-hoisting \
// RUN: -buffer-deallocation -copy-removal -canonicalize -cse \
// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
// RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \
// RUN: -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main \
// RUN: -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
// RUN: FileCheck %s

func @main() -> () {
  call @reduce_add() : () -> ()
  call @reduce_max() : () -> ()
  return
}

func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }

func @reduce_add() {
  %c0 = constant 0 : index
  %c1 = constant 1 : index

  // Initialize input.
  %input = alloc() : memref<2x3xf32>
  %dim_x = dim %input, %c0 : memref<2x3xf32>
  %dim_y = dim %input, %c1 : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
    %i_i64 = index_cast %i : index to i64
    %i_f32 = sitofp %i_i64 : i64 to f32
    store %i_f32, %input[%i, %j] : memref<2x3xf32>
  }
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
  // CHECK: [0, 0, 0]
  // CHECK: [1, 1, 1]

  %in = tensor_load %input : memref<2x3xf32>
  %init = mhlo.constant dense<0.000000e+00> : tensor<f32>

  %reduce = "mhlo.reduce"(%in, %init) ( {
  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>}
      : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>

  %output = alloc() : memref<2xf32>
  tensor_store %reduce, %output : memref<2xf32>
  %unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
  // CHECK: [0, 3]
  return
}

func @reduce_max() {
  %c0 = constant 0 : index
  %c1 = constant 1 : index

  // Initialize input.
  %input = alloc() : memref<2x3xf32>
  %dim_x = dim %input, %c0 : memref<2x3xf32>
  %dim_y = dim %input, %c1 : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
    %i_i64 = index_cast %i : index to i64
    %i_f32 = sitofp %i_i64 : i64 to f32
    store %i_f32, %input[%i, %j] : memref<2x3xf32>
  }
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
  // CHECK: [0, 0, 0]
  // CHECK: [1, 1, 1]

  %in = tensor_load %input : memref<2x3xf32>
  %init = mhlo.constant dense<0xff800000> : tensor<f32>

  %reduce = "mhlo.reduce"(%in, %init) ( {
  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
    %1 = mhlo.maximum %arg2, %arg3 : tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>}
      : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>

  %output = alloc() : memref<2xf32>
  tensor_store %reduce, %output : memref<2xf32>
  %unranked_output = memref_cast %output : memref<2xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 1 offset = 0 sizes = [2] strides = [1]
  // CHECK: [0, 1]
  return
}
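A quick sanity check of the expected values: the scf.parallel loop stores the row index into every element, so the input is [[0, 0, 0], [1, 1, 1]]. Reducing dimension 1 with add and init 0.0 gives [0+0+0, 1+1+1] = [0, 3]; reducing with maximum and init -inf (0xff800000) gives [0, 1], matching the CHECK lines above.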
@@ -846,3 +846,76 @@ func @transpose(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
  return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[TRANSPOSE_INPUT_MAP]], #[[TRANSPOSE_OUTPUT_MAP]]]

// -----

// CHECK-DAG: #[[REDUCE_INPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[REDUCE_OUTPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: func @reduce_add
func @reduce_add(%arg: memref<100x10xf32>,
                 %init: memref<f32>,
                 %result: memref<100xf32>) {
  "lmhlo.reduce"(%arg, %init, %result) ( {
  ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
    "lmhlo.add"(%lhs, %rhs, %res)
        : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[1]> : tensor<1xi64>}
      : (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
  return
}
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
// CHECK: alloca
// CHECK-NEXT: alloca
// CHECK-NEXT: alloca
// CHECK-NEXT: store
// CHECK-NEXT: store
// CHECK-NEXT: load
// CHECK-NEXT: load
// CHECK-NEXT: addf
// CHECK-NEXT: store
// CHECK-NEXT: load
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: }

// -----

// CHECK-DAG: #[[REDUCE_INPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[REDUCE_OUTPUT_MAP:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: func @reduce_maximum
func @reduce_maximum(%arg: memref<100x10xf32>,
                     %init: memref<f32>,
                     %result: memref<100xf32>) {
  "lmhlo.reduce"(%arg, %init, %result) ( {
  ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
    "lmhlo.maximum"(%lhs, %rhs, %res)
        : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[1]> : tensor<1xi64>}
      : (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
  return
}
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
// CHECK: alloca
// CHECK-NEXT: alloca
// CHECK-NEXT: alloca
// CHECK-NEXT: store
// CHECK-NEXT: store
// CHECK-NEXT: load
// CHECK-NEXT: load
// CHECK-NEXT: cmpf
// CHECK-NEXT: select
// CHECK-NEXT: store
// CHECK-NEXT: load
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: }