Add support for lowering the variadic mhlo.reduce op.

Also add more lowerings for body ops: some MinOp and MaxOp patterns are expressed as a CompareOp + SelectOp pair, so those ops can now be legalized inside reduction bodies as well.

PiperOrigin-RevId: 369891551
commit 49df46893c
parent 2dfea55d0c
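The CompareOp and SelectOp region conversions registered at the end of the pass cover reduction bodies in which a maximum (or minimum) has already been expanded into a compare-plus-select pair, as in the new test case at the bottom of this diff. A minimal MHLO sketch of that pattern (SSA names are illustrative, not taken from the commit):

    // max(%lhs, %rhs) spelled as compare + select inside a reduce body
    %pred = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GE"}
        : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %max = "mhlo.select"(%pred, %lhs, %rhs)
        : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>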
@@ -1473,8 +1473,8 @@ struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> {
         })) {
       return failure();
     }
-    Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(op, args[0].getType(),
-                                                        args, &rewriter);
+    Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
+        op, getElementTypeOrSelf(op.getType()), args, &rewriter);
     rewriter.replaceOp(op, result);
     return success();
   }
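The hunk above presumably matters for exactly those new body ops: mhlo.compare produces an i1 element type that differs from its operand type, so the scalar op has to be built from the op's own result element type rather than from args[0].getType(). A sketch of the scalar form inside the lowered body (names are illustrative; the types follow the test below):

    %pred = cmpi sge, %in, %out : i32   // operands are i32, but %pred has type i1
    %max = select %pred, %in, %out : i32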
@@ -1518,58 +1518,65 @@ class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
       ConversionPatternRewriter& rewriter) const final {
     Location loc = op.getLoc();
     mhlo::ReduceOp::Adaptor adaptor(args);
-    if (op.getNumOperands() != 2) {
-      return op.emitError("expects exactly two operands");
-    }
-    Value src = adaptor.inputs()[0];
-    auto src_type = src.getType().cast<ShapedType>();
+
+    int num_inputs = static_cast<int>(adaptor.inputs().size());
+    auto src_type = adaptor.inputs()[0].getType().cast<ShapedType>();
     int src_rank = src_type.getRank();
     if (!src_rank) {
       return rewriter.notifyMatchFailure(op, "expects known-rank args");
     }
 
-    // Check if init_value is constant. If so, inline the value into the region.
-    Value init_value = adaptor.init_values()[0];
-    Attribute init_const_val = GetInitValueAsConst(init_value);
-    if (init_const_val) {
-      init_value = rewriter.create<ConstantOp>(
-          init_value.getDefiningOp()->getLoc(), init_const_val);
-    } else {
-      init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
+    SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());
+
+    SmallVector<Value> inputs, outputs;
+    SmallVector<AffineMap, 3> indexing_maps;
+    for (int i = 0; i < num_inputs; ++i) {
+      Value src = adaptor.inputs()[i];
+      if (src.getType() != src_type) return failure();
+
+      // Check if init_value is constant. If so, inline the value into the
+      // region.
+      Value init_value = adaptor.init_values()[i];
+      Attribute init_const_val = GetInitValueAsConst(init_value);
+      if (init_const_val) {
+        init_value = rewriter.create<ConstantOp>(
+            init_value.getDefiningOp()->getLoc(), init_const_val);
+      } else {
+        init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
+      }
+
+      inputs.push_back(src);
+      auto result_type = op.getResult(i).getType().cast<ShapedType>();
+      SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
+          rewriter, loc, src, result_type, reduction_dims);
+      auto init_tensor = GetInitTensor(rewriter, loc, result_type, dyn_shape);
+      Value filled_tensor =
+          rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
+              .result();
+      outputs.push_back(filled_tensor);
     }
 
-    // Prepare indexing maps for linalg generic op. The elements are for src and
-    // dst. Transpose `src` to make the reduction loops be the innermost,
+    // Prepare indexing maps for linalg generic op. The elements are for src
+    // and dst. Transpose `src` to make the reduction loops be the innermost,
     // because it's easier to fully utilize processors.
-    SmallVector<AffineMap, 3> indexing_maps;
-    SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());
-    indexing_maps.emplace_back(GetTransposeMapForReduction(
-        rewriter.getContext(), src_rank, reduction_dims));
+    indexing_maps.append(
+        num_inputs, GetTransposeMapForReduction(rewriter.getContext(), src_rank,
                                                 reduction_dims));
 
     // The indexing map of `dst` should drop the reduction loops. Since the
     // reduction loops now are all in the innermost, drops
-    // `reduction_dims.size()` dimensions. We don't need an inverse permutation
-    // here because they are the same.
+    // `reduction_dims.size()` dimensions. We don't need an inverse
+    // permutation here because they are the same.
     SmallVector<AffineExpr, 4> exprs;
     for (int i = 0, e = src_rank - reduction_dims.size(); i < e; ++i)
       exprs.push_back(rewriter.getAffineDimExpr(i));
-    indexing_maps.emplace_back(AffineMap::get(src_rank, /*symbolCount=*/0,
-                                              exprs, rewriter.getContext()));
-    SmallVector<Value, 2> inputs = {adaptor.inputs()[0]};
-    Type result_type = op.getResult(0).getType();
-    auto shaped_type = result_type.cast<ShapedType>();
-    SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
-        rewriter, loc, adaptor.inputs()[0], result_type.cast<ShapedType>(),
-        reduction_dims);
-    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
-    Value filled_tensor =
-        rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
-            .getResult(0);
+    indexing_maps.append(num_inputs,
+                         AffineMap::get(src_rank, /*symbolCount=*/0, exprs,
+                                        rewriter.getContext()));
 
     auto linalg_op = rewriter.create<linalg::GenericOp>(
         loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
-        /*outputBuffers=*/ValueRange{filled_tensor}, indexing_maps,
+        /*outputBuffers=*/ValueRange{outputs}, indexing_maps,
         GetParallelAndReductionIterators(src_rank, reduction_dims.size()));
 
     // Convert the signature of the body. The reduce op region apply function
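For the two-input reduction in the new test below, a 9x2 tensor reduced over dimension 0, the maps assembled above come out as follows (taken from the CHECK-DAG lines in the test): each input gets the transposed map so the reduced dimension becomes the innermost loop, and each output gets the map that drops it.

    #map0 = affine_map<(d0, d1) -> (d1, d0)>   // one per input (src), reduction loop innermost
    #map1 = affine_map<(d0, d1) -> (d0)>       // one per output (dst), reduction loop dropped
    // indexing_maps = [#map0, #map0, #map1, #map1]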
@@ -1579,10 +1586,10 @@ class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
     // be converted to "(f32, f32, f32)".
     Region& region = linalg_op.region();
     rewriter.inlineRegionBefore(op.body(), region, region.end());
-    TypeConverter::SignatureConversion signatureConverter(2);
-    signatureConverter.addInputs(0, src_type.getElementType());
-    signatureConverter.addInputs(1, src_type.getElementType());
-    rewriter.applySignatureConversion(&region, signatureConverter);
+    TypeConverter::SignatureConversion signature_converter(num_inputs * 2);
+    for (int i = 0; i < num_inputs * 2; ++i)
+      signature_converter.addInputs(i, src_type.getElementType());
+    rewriter.applySignatureConversion(&region, signature_converter);
     rewriter.replaceOp(op, linalg_op.getResults());
     return success();
   }
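In the variadic case the inlined region has num_inputs * 2 block arguments, all of which are converted to the scalar element type. For the two-input test below, the entry block signature changes roughly like this (argument names are illustrative; the test's CHECK lines use %[[IN0]], %[[OUT0]], and so on):

    // before conversion, inlined from mhlo.reduce:
    ^bb0(%in0: tensor<i32>, %in1: tensor<i32>, %acc0: tensor<i32>, %acc1: tensor<i32>):
    // after applySignatureConversion with num_inputs * 2 element types:
    ^bb0(%in0: i32, %in1: i32, %acc0: i32, %acc1: i32):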
@@ -2272,6 +2279,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                    ReduceRegionXLAOpConversion<mhlo::MaxOp>,
                    ReduceRegionXLAOpConversion<mhlo::AndOp>,
                    ReduceRegionXLAOpConversion<mhlo::OrOp>,
+                   ReduceRegionXLAOpConversion<mhlo::SelectOp>,
+                   ReduceRegionXLAOpConversion<mhlo::CompareOp>,
                    ReduceRegionReturnOpConversion>(context);
 }
@@ -1506,6 +1506,49 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32
 
 // -----
 
+func @variadic_reduce(%arg0: tensor<9x2xi32>, %arg1: tensor<9x2xi32>) -> (tensor<2xi32>, tensor<2xi32>) {
+  %cst0 = mhlo.constant dense<-2147483648> : tensor<i32>
+  %cst1 = mhlo.constant dense<0> : tensor<i32>
+  %res0, %res1 = "mhlo.reduce"(%arg0, %arg1, %cst0, %cst1) ( {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg15: tensor<i32>, %arg16: tensor<i32>):  // no predecessors
+    %669 = "mhlo.compare"(%arg2, %arg15) {comparison_direction = "GE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %670 = "mhlo.select"(%669, %arg2, %arg15) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
+    %671 = "mhlo.compare"(%arg2, %arg15) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %672 = mhlo.minimum %arg3, %arg16 : tensor<i32>
+    %673 = "mhlo.select"(%669, %arg3, %arg16) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
+    %674 = "mhlo.select"(%671, %672, %673) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
+    "mhlo.return"(%670, %674) : (tensor<i32>, tensor<i32>) -> ()
+  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<9x2xi32>, tensor<9x2xi32>, tensor<i32>, tensor<i32>) -> (tensor<2xi32>, tensor<2xi32>)
+  return %res0, %res1 : tensor<2xi32>, tensor<2xi32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: func @variadic_reduce
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK: %[[CST0:.*]] = constant -2147483648 : i32
+// CHECK: %[[INIT0:.*]] = linalg.init_tensor [2] : tensor<2xi32>
+// CHECK: %[[FILL0:.*]] = linalg.fill(%[[INIT0]], %[[CST0]])
+// CHECK: %[[CST1:.*]] = constant 0 : i32
+// CHECK: %[[INIT1:.*]] = linalg.init_tensor [2] : tensor<2xi32>
+// CHECK: %[[FILL1:.*]] = linalg.fill(%[[INIT1]], %[[CST1]])
+// CHECK: %[[RES:.+]]:2 = linalg.generic {
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<9x2xi32>, tensor<9x2xi32>)
+// CHECK-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<2xi32>, tensor<2xi32>)
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: i32, %[[IN1:.*]]: i32, %[[OUT0:.*]]: i32, %[[OUT1:.*]]: i32):
+// CHECK-NEXT:   %[[T1:.*]] = cmpi sge, %[[IN0]], %[[OUT0]] : i32
+// CHECK-NEXT:   %[[T2:.*]] = select %[[T1]], %[[IN0]], %[[OUT0]] : i32
+// CHECK-NEXT:   %[[T3:.*]] = cmpi eq, %[[IN0]], %[[OUT0]] : i32
+// CHECK-NEXT:   %[[T4:.*]] = cmpi slt, %[[IN1]], %[[OUT1]] : i32
+// CHECK-NEXT:   %[[T5:.*]] = select %[[T4]], %[[IN1]], %[[OUT1]] : i32
+// CHECK-NEXT:   %[[T6:.*]] = select %[[T1]], %[[IN1]], %[[OUT1]] : i32
+// CHECK-NEXT:   %[[T7:.*]] = select %[[T3]], %[[T5]], %[[T6]] : i32
+// CHECK-NEXT:   linalg.yield %[[T2]], %[[T7]]
+
+// -----
 
 func @slice_whole_stride(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
   %0 = "mhlo.slice"(%arg0) {
     start_indices = dense<[1, 0]> : tensor<2xi64>,