Canonicalize multidimensional iota to use broadcast
There is no reason to have a multidimensional iota for codegen. This should be canonicalized to a single dimensional iota followed by a broadcast. Changing iota to on a single dimension and a broadcast substantially simplifies implementing iota operations. PiperOrigin-RevId: 320095470
This commit is contained in:
parent
8900222fed
commit
e1651b6090
|
@ -78,6 +78,7 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp {
|
||||||
|
|
||||||
// TODO(b/130357376): Iota has special conversion logic to HLO.
|
// TODO(b/130357376): Iota has special conversion logic to HLO.
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {
|
def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {
|
||||||
|
|
|
@ -214,6 +214,41 @@ static LogicalResult Verify(IotaOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Iota operations across multiple dimensions can be reduced to an iota and a
|
||||||
|
// ranked broadcast.
|
||||||
|
struct IotaBroadcast : public OpRewritePattern<IotaOp> {
|
||||||
|
using OpRewritePattern<IotaOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(IotaOp iota,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
auto result_ty = iota.getType().cast<ShapedType>();
|
||||||
|
if (!result_ty.hasRank() || result_ty.getRank() < 2) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto iota_dimension = iota.iota_dimension();
|
||||||
|
|
||||||
|
auto iota_type = RankedTensorType::get(
|
||||||
|
{result_ty.getDimSize(iota_dimension.getLimitedValue())},
|
||||||
|
result_ty.getElementType());
|
||||||
|
|
||||||
|
auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
|
||||||
|
rewriter.getI64IntegerAttr(0));
|
||||||
|
|
||||||
|
auto broadcast_attr = DenseIntElementsAttr::get(
|
||||||
|
RankedTensorType::get({1}, rewriter.getIntegerType(64)),
|
||||||
|
{iota_dimension});
|
||||||
|
rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, result_ty, new_iota,
|
||||||
|
broadcast_attr);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||||
|
MLIRContext* context) {
|
||||||
|
results.insert<IotaBroadcast>(context);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// DynamicIotaOp
|
// DynamicIotaOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -235,11 +270,63 @@ struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Dynamic Iota operations across multiple dimensions can be reduced to an iota
|
||||||
|
// and a ranked broadcast.
|
||||||
|
struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
|
||||||
|
using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(DynamicIotaOp iota,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
auto result_ty = iota.getType().cast<ShapedType>();
|
||||||
|
if (!result_ty.hasRank() || result_ty.getRank() < 2) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto iota_dimension = iota.iota_dimension();
|
||||||
|
auto iota_dimension_int = iota_dimension.getLimitedValue();
|
||||||
|
|
||||||
|
auto converted_shape = rewriter.create<IndexCastOp>(
|
||||||
|
iota.getLoc(),
|
||||||
|
RankedTensorType::get(
|
||||||
|
iota.output_shape().getType().cast<ShapedType>().getShape(),
|
||||||
|
rewriter.getI64Type()),
|
||||||
|
iota.output_shape());
|
||||||
|
|
||||||
|
auto sliced_shape = rewriter.create<SliceOp>(
|
||||||
|
iota.getLoc(), converted_shape,
|
||||||
|
GetI64ElementsAttr(iota_dimension_int, &rewriter),
|
||||||
|
GetI64ElementsAttr(iota_dimension_int + 1, &rewriter),
|
||||||
|
GetI64ElementsAttr(1, &rewriter));
|
||||||
|
|
||||||
|
auto converted_sliced_shape = rewriter.create<IndexCastOp>(
|
||||||
|
iota.getLoc(),
|
||||||
|
RankedTensorType::get(
|
||||||
|
{1},
|
||||||
|
iota.output_shape().getType().cast<ShapedType>().getElementType()),
|
||||||
|
sliced_shape);
|
||||||
|
|
||||||
|
auto iota_type = RankedTensorType::get(
|
||||||
|
{result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType());
|
||||||
|
|
||||||
|
auto new_iota = rewriter.create<DynamicIotaOp>(
|
||||||
|
iota.getLoc(), iota_type, converted_sliced_shape,
|
||||||
|
rewriter.getI64IntegerAttr(0));
|
||||||
|
|
||||||
|
auto broadcast_attr = DenseIntElementsAttr::get(
|
||||||
|
RankedTensorType::get({1}, rewriter.getIntegerType(64)),
|
||||||
|
{iota_dimension});
|
||||||
|
rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
|
||||||
|
iota, result_ty, new_iota, iota.output_shape(), broadcast_attr);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void DynamicIotaOp::getCanonicalizationPatterns(
|
void DynamicIotaOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList& results, MLIRContext* context) {
|
OwningRewritePatternList& results, MLIRContext* context) {
|
||||||
results.insert<DynamicIotaIsStatic>(context);
|
results.insert<DynamicIotaIsStatic>(context);
|
||||||
|
results.insert<DynamicIotaBroadcast>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -391,6 +391,30 @@ func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> {
|
||||||
return %0 : tensor<4xi32>
|
return %0 : tensor<4xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @dynamic_iota_broadcast
|
||||||
|
func @dynamic_iota_broadcast(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
|
||||||
|
// CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32>
|
||||||
|
// CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>, tensor<2xindex>) -> tensor<5x?xi32>
|
||||||
|
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>
|
||||||
|
|
||||||
|
// CHECK: return [[BROADCAST]]
|
||||||
|
return %0 : tensor<5x?xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @dynamic_iota_broadcast_second
|
||||||
|
func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
|
||||||
|
// CHECK-NEXT: [[CAST1:%.+]] = index_cast %arg0 : tensor<2xindex> to tensor<2xi64>
|
||||||
|
// CHECK-NEXT: [[SLICE:%.+]] = "mhlo.slice"([[CAST1]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
|
||||||
|
// CHECK-NEXT: [[CAST2:%.+]] = index_cast [[SLICE]] : tensor<1xi64> to tensor<1xindex>
|
||||||
|
// CHECK-NEXT: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[CAST2]]) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<?xi32>
|
||||||
|
// CHECK-NEXT: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<2xindex>) -> tensor<5x?xi32>
|
||||||
|
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>
|
||||||
|
|
||||||
|
// CHECK: return [[BROADCAST]]
|
||||||
|
return %0 : tensor<5x?xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// CHECK-LABEL: @iota_not_lowered_to_constant
|
// CHECK-LABEL: @iota_not_lowered_to_constant
|
||||||
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
|
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
|
||||||
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
|
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
|
||||||
|
@ -399,6 +423,24 @@ func @iota_not_lowered_to_constant() -> tensor<4xi32> {
|
||||||
return %0 : tensor<4xi32>
|
return %0 : tensor<4xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @iota_broadcast
|
||||||
|
func @iota_broadcast() -> tensor<5x4xi32> {
|
||||||
|
// CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32>
|
||||||
|
// CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>) -> tensor<5x4xi32>
|
||||||
|
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5x4xi32>
|
||||||
|
|
||||||
|
return %0 : tensor<5x4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @iota_broadcast
|
||||||
|
func @iota_broadcast_second() -> tensor<5x4xi32> {
|
||||||
|
// CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
|
||||||
|
// CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<5x4xi32>
|
||||||
|
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<5x4xi32>
|
||||||
|
|
||||||
|
return %0 : tensor<5x4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @unary_einsum
|
// CHECK-LABEL: @unary_einsum
|
||||||
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
||||||
// CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
|
// CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
|
||||||
|
|
Loading…
Reference in New Issue