Add support for lowering variadic mhlo.reduce op.

Also add more lowering for body ops. Some MinOp and MaxOp can be legalized to
SelectOp + CompareOp.

PiperOrigin-RevId: 369891551
This commit is contained in:
Hanhan Wang 2021-04-22 09:49:40 -07:00 committed by TensorFlow MLIR Team
parent 2dfea55d0c
commit 49df46893c
2 changed files with 93 additions and 41 deletions

View File

@ -1473,8 +1473,8 @@ struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> {
})) {
return failure();
}
Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(op, args[0].getType(),
args, &rewriter);
Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
op, getElementTypeOrSelf(op.getType()), args, &rewriter);
rewriter.replaceOp(op, result);
return success();
}
@ -1518,58 +1518,65 @@ class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
ConversionPatternRewriter& rewriter) const final {
Location loc = op.getLoc();
mhlo::ReduceOp::Adaptor adaptor(args);
if (op.getNumOperands() != 2) {
return op.emitError("expects exactly two operands");
}
Value src = adaptor.inputs()[0];
auto src_type = src.getType().cast<ShapedType>();
int num_inputs = static_cast<int>(adaptor.inputs().size());
auto src_type = adaptor.inputs()[0].getType().cast<ShapedType>();
int src_rank = src_type.getRank();
if (!src_rank) {
return rewriter.notifyMatchFailure(op, "expects known-rank args");
}
// Check if init_value is constant. If so, inline the value into the region.
Value init_value = adaptor.init_values()[0];
Attribute init_const_val = GetInitValueAsConst(init_value);
if (init_const_val) {
init_value = rewriter.create<ConstantOp>(
init_value.getDefiningOp()->getLoc(), init_const_val);
} else {
init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());
SmallVector<Value> inputs, outputs;
SmallVector<AffineMap, 3> indexing_maps;
for (int i = 0; i < num_inputs; ++i) {
Value src = adaptor.inputs()[i];
if (src.getType() != src_type) return failure();
// Check if init_value is constant. If so, inline the value into the
// region.
Value init_value = adaptor.init_values()[i];
Attribute init_const_val = GetInitValueAsConst(init_value);
if (init_const_val) {
init_value = rewriter.create<ConstantOp>(
init_value.getDefiningOp()->getLoc(), init_const_val);
} else {
init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
}
inputs.push_back(src);
auto result_type = op.getResult(i).getType().cast<ShapedType>();
SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
rewriter, loc, src, result_type, reduction_dims);
auto init_tensor = GetInitTensor(rewriter, loc, result_type, dyn_shape);
Value filled_tensor =
rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
.result();
outputs.push_back(filled_tensor);
}
// Prepare indexing maps for linalg generic op. The elements are for src and
// dst. Transpose `src` to make the reduction loops be the innermost,
// Prepare indexing maps for linalg generic op. The elements are for src
// and dst. Transpose `src` to make the reduction loops be the innermost,
// because it's easier to fully utilize processors.
SmallVector<AffineMap, 3> indexing_maps;
SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());
indexing_maps.emplace_back(GetTransposeMapForReduction(
rewriter.getContext(), src_rank, reduction_dims));
indexing_maps.append(
num_inputs, GetTransposeMapForReduction(rewriter.getContext(), src_rank,
reduction_dims));
// The indexing map of `dst` should drop the reduction loops. Since the
// reduction loops now are all in the innermost, drops
// `reduction_dims.size()` dimensions. We don't need an inverse permutation
// here because they are the same.
// `reduction_dims.size()` dimensions. We don't need an inverse
// permutation here because they are the same.
SmallVector<AffineExpr, 4> exprs;
for (int i = 0, e = src_rank - reduction_dims.size(); i < e; ++i)
exprs.push_back(rewriter.getAffineDimExpr(i));
indexing_maps.emplace_back(AffineMap::get(src_rank, /*symbolCount=*/0,
exprs, rewriter.getContext()));
SmallVector<Value, 2> inputs = {adaptor.inputs()[0]};
Type result_type = op.getResult(0).getType();
auto shaped_type = result_type.cast<ShapedType>();
SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
rewriter, loc, adaptor.inputs()[0], result_type.cast<ShapedType>(),
reduction_dims);
auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
Value filled_tensor =
rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
.getResult(0);
indexing_maps.append(num_inputs,
AffineMap::get(src_rank, /*symbolCount=*/0, exprs,
rewriter.getContext()));
auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
/*outputBuffers=*/ValueRange{filled_tensor}, indexing_maps,
/*outputBuffers=*/ValueRange{outputs}, indexing_maps,
GetParallelAndReductionIterators(src_rank, reduction_dims.size()));
// Convert the signature of the body. The reduce op region apply function
@ -1579,10 +1586,10 @@ class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
// be converted to "(f32, f32, f32)".
Region& region = linalg_op.region();
rewriter.inlineRegionBefore(op.body(), region, region.end());
TypeConverter::SignatureConversion signatureConverter(2);
signatureConverter.addInputs(0, src_type.getElementType());
signatureConverter.addInputs(1, src_type.getElementType());
rewriter.applySignatureConversion(&region, signatureConverter);
TypeConverter::SignatureConversion signature_converter(num_inputs * 2);
for (int i = 0; i < num_inputs * 2; ++i)
signature_converter.addInputs(i, src_type.getElementType());
rewriter.applySignatureConversion(&region, signature_converter);
rewriter.replaceOp(op, linalg_op.getResults());
return success();
}
@ -2272,6 +2279,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
ReduceRegionXLAOpConversion<mhlo::MaxOp>,
ReduceRegionXLAOpConversion<mhlo::AndOp>,
ReduceRegionXLAOpConversion<mhlo::OrOp>,
ReduceRegionXLAOpConversion<mhlo::SelectOp>,
ReduceRegionXLAOpConversion<mhlo::CompareOp>,
ReduceRegionReturnOpConversion>(context);
}

View File

@ -1506,6 +1506,49 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32
// -----
func @variadic_reduce(%arg0: tensor<9x2xi32>, %arg1: tensor<9x2xi32>) -> (tensor<2xi32>, tensor<2xi32>) {
%cst0 = mhlo.constant dense<-2147483648> : tensor<i32>
%cst1 = mhlo.constant dense<0> : tensor<i32>
%res0, %res1 = "mhlo.reduce"(%arg0, %arg1, %cst0, %cst1) ( {
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg15: tensor<i32>, %arg16: tensor<i32>): // no predecessors
%669 = "mhlo.compare"(%arg2, %arg15) {comparison_direction = "GE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
%670 = "mhlo.select"(%669, %arg2, %arg15) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
%671 = "mhlo.compare"(%arg2, %arg15) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
%672 = mhlo.minimum %arg3, %arg16 : tensor<i32>
%673 = "mhlo.select"(%669, %arg3, %arg16) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
%674 = "mhlo.select"(%671, %672, %673) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
"mhlo.return"(%670, %674) : (tensor<i32>, tensor<i32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<9x2xi32>, tensor<9x2xi32>, tensor<i32>, tensor<i32>) -> (tensor<2xi32>, tensor<2xi32>)
return %res0, %res1 : tensor<2xi32>, tensor<2xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK: func @variadic_reduce
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[CST0:.*]] = constant -2147483648 : i32
// CHECK: %[[INIT0:.*]] = linalg.init_tensor [2] : tensor<2xi32>
// CHECK: %[[FILL0:.*]] = linalg.fill(%[[INIT0]], %[[CST0]])
// CHECK: %[[CST1:.*]] = constant 0 : i32
// CHECK: %[[INIT1:.*]] = linalg.init_tensor [2] : tensor<2xi32>
// CHECK: %[[FILL1:.*]] = linalg.fill(%[[INIT1]], %[[CST1]])
// CHECK: %[[RES:.+]]:2 = linalg.generic {
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<9x2xi32>, tensor<9x2xi32>)
// CHECK-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<2xi32>, tensor<2xi32>)
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: i32, %[[IN1:.*]]: i32, %[[OUT0:.*]]: i32, %[[OUT1:.*]]: i32):
// CHECK-NEXT: %[[T1:.*]] = cmpi sge, %[[IN0]], %[[OUT0]] : i32
// CHECK-NEXT: %[[T2:.*]] = select %[[T1]], %[[IN0]], %[[OUT0]] : i32
// CHECK-NEXT: %[[T3:.*]] = cmpi eq, %[[IN0]], %[[OUT0]] : i32
// CHECK-NEXT: %[[T4:.*]] = cmpi slt, %[[IN1]], %[[OUT1]] : i32
// CHECK-NEXT: %[[T5:.*]] = select %[[T4]], %[[IN1]], %[[OUT1]] : i32
// CHECK-NEXT: %[[T6:.*]] = select %[[T1]], %[[IN1]], %[[OUT1]] : i32
// CHECK-NEXT: %[[T7:.*]] = select %[[T3]], %[[T5]], %[[T6]] : i32
// CHECK-NEXT: linalg.yield %[[T2]], %[[T7]]
// -----
func @slice_whole_stride(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
%0 = "mhlo.slice"(%arg0) {
start_indices = dense<[1, 0]> : tensor<2xi64>,