diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index d0abbe0..12b9f5a 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1079,6 +1079,8 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, // XLA semantics is available. This limitation is because of the current XLA // implementation. let results = (outs I32Tensor); + + let hasFolder = 1; } def HLO_MapOp: HLO_Op<"map", diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index f5deb94..b0fa4ce 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -165,6 +165,20 @@ static LogicalResult Verify(DotGeneralOp op) { return success(); } +//===----------------------------------------------------------------------===// +// GetDimensionSizeOp +//===----------------------------------------------------------------------===// + +/// Fold get_dimension_size when the said shape dimension is a constant. +OpFoldResult GetDimensionSizeOp::fold(ArrayRef attrs) { + RankedTensorType type = operand().getType().cast(); + int32_t dim = dimension().getSExtValue(); + if (type.isDynamic(dim)) return {}; + // The result type is always is a 0-d i32 tensor. + return DenseIntElementsAttr::get( + getResult().getType().cast(), type.getDimSize(dim)); +} + //===----------------------------------------------------------------------===// // IotaOp //===----------------------------------------------------------------------===// diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 0d20c3f..0b6cc1c 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -618,3 +618,11 @@ func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf "lmhlo.constant"(%N) {value = dense<0.0> : tensor} : (memref<256x1024xf32>) -> () return %N : memref<256x1024xf32> } + +// CHECK-LABEL: func @fold_get_dimension_size +func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor { + %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i32} : (tensor<1x128x512xf32>) -> tensor + return %size : tensor + // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor + // CHECK-NEXT: return %[[C]] +}