Canonicalize multidimensional iota to use broadcast

There is no reason to have a multidimensional iota for codegen.
This should be canonicalized to a single dimensional iota followed
by a broadcast. Changing iota to on a single dimension  and a broadcast
substantially simplifies implementing iota operations.

PiperOrigin-RevId: 320095470
This commit is contained in:
Robert Suderman 2020-07-08 00:35:27 +00:00 committed by Mehdi Amini
parent 8900222fed
commit e1651b6090
3 changed files with 130 additions and 0 deletions

View File

@ -78,6 +78,7 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp {
// TODO(b/130357376): Iota has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
let hasCanonicalizer = 1;
}
def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {

View File

@ -214,6 +214,41 @@ static LogicalResult Verify(IotaOp op) {
return success();
}
// Iota operations across multiple dimensions can be reduced to an iota and a
// ranked broadcast.
struct IotaBroadcast : public OpRewritePattern<IotaOp> {
using OpRewritePattern<IotaOp>::OpRewritePattern;
LogicalResult matchAndRewrite(IotaOp iota,
PatternRewriter& rewriter) const override {
auto result_ty = iota.getType().cast<ShapedType>();
if (!result_ty.hasRank() || result_ty.getRank() < 2) {
return failure();
}
auto iota_dimension = iota.iota_dimension();
auto iota_type = RankedTensorType::get(
{result_ty.getDimSize(iota_dimension.getLimitedValue())},
result_ty.getElementType());
auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
rewriter.getI64IntegerAttr(0));
auto broadcast_attr = DenseIntElementsAttr::get(
RankedTensorType::get({1}, rewriter.getIntegerType(64)),
{iota_dimension});
rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, result_ty, new_iota,
broadcast_attr);
return success();
}
};
void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<IotaBroadcast>(context);
}
//===----------------------------------------------------------------------===//
// DynamicIotaOp
//===----------------------------------------------------------------------===//
@ -235,11 +270,63 @@ struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> {
}
};
// Dynamic Iota operations across multiple dimensions can be reduced to an iota
// and a ranked broadcast.
struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicIotaOp iota,
PatternRewriter& rewriter) const override {
auto result_ty = iota.getType().cast<ShapedType>();
if (!result_ty.hasRank() || result_ty.getRank() < 2) {
return failure();
}
auto iota_dimension = iota.iota_dimension();
auto iota_dimension_int = iota_dimension.getLimitedValue();
auto converted_shape = rewriter.create<IndexCastOp>(
iota.getLoc(),
RankedTensorType::get(
iota.output_shape().getType().cast<ShapedType>().getShape(),
rewriter.getI64Type()),
iota.output_shape());
auto sliced_shape = rewriter.create<SliceOp>(
iota.getLoc(), converted_shape,
GetI64ElementsAttr(iota_dimension_int, &rewriter),
GetI64ElementsAttr(iota_dimension_int + 1, &rewriter),
GetI64ElementsAttr(1, &rewriter));
auto converted_sliced_shape = rewriter.create<IndexCastOp>(
iota.getLoc(),
RankedTensorType::get(
{1},
iota.output_shape().getType().cast<ShapedType>().getElementType()),
sliced_shape);
auto iota_type = RankedTensorType::get(
{result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType());
auto new_iota = rewriter.create<DynamicIotaOp>(
iota.getLoc(), iota_type, converted_sliced_shape,
rewriter.getI64IntegerAttr(0));
auto broadcast_attr = DenseIntElementsAttr::get(
RankedTensorType::get({1}, rewriter.getIntegerType(64)),
{iota_dimension});
rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
iota, result_ty, new_iota, iota.output_shape(), broadcast_attr);
return success();
}
};
} // namespace
void DynamicIotaOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicIotaIsStatic>(context);
results.insert<DynamicIotaBroadcast>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -391,6 +391,30 @@ func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> {
return %0 : tensor<4xi32>
}
// CHECK-LABEL: @dynamic_iota_broadcast
func @dynamic_iota_broadcast(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
// CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32>
// CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>, tensor<2xindex>) -> tensor<5x?xi32>
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>
// CHECK: return [[BROADCAST]]
return %0 : tensor<5x?xi32>
}
// CHECK-LABEL: @dynamic_iota_broadcast_second
func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
// CHECK-NEXT: [[CAST1:%.+]] = index_cast %arg0 : tensor<2xindex> to tensor<2xi64>
// CHECK-NEXT: [[SLICE:%.+]] = "mhlo.slice"([[CAST1]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
// CHECK-NEXT: [[CAST2:%.+]] = index_cast [[SLICE]] : tensor<1xi64> to tensor<1xindex>
// CHECK-NEXT: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[CAST2]]) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<?xi32>
// CHECK-NEXT: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<2xindex>) -> tensor<5x?xi32>
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>
// CHECK: return [[BROADCAST]]
return %0 : tensor<5x?xi32>
}
// CHECK-LABEL: @iota_not_lowered_to_constant
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
@ -399,6 +423,24 @@ func @iota_not_lowered_to_constant() -> tensor<4xi32> {
return %0 : tensor<4xi32>
}
// CHECK-LABEL: @iota_broadcast
func @iota_broadcast() -> tensor<5x4xi32> {
// CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32>
// CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>) -> tensor<5x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5x4xi32>
return %0 : tensor<5x4xi32>
}
// CHECK-LABEL: @iota_broadcast
func @iota_broadcast_second() -> tensor<5x4xi32> {
// CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
// CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<5x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<5x4xi32>
return %0 : tensor<5x4xi32>
}
// CHECK-LABEL: @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>