PR #42509: [MLIR] Add folder for mhlo get_dimension_size

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/42509

Add folder for mhlo GetDimensionSizeOp (`mhlo.get_dimension_size`).
`get_dimension_size` folds to a constant when the corresponding tensor
dimension size is statically known / constant.
Copybara import of the project:

--
5994915525ec2e932125aa1f133ce2260ba100af by Uday Bondhugula <uday@polymagelabs.com>:

[MLIR] Add folder for mhlo get_dimension_size

Add folder for mhlo GetDimensionSizeOp. get_dimension_size folds to a
constant when the corresponding tensor dimension size is statically
known / constant.

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/42509 from polymage-labs:get_dimension_size_fold 5994915525ec2e932125aa1f133ce2260ba100af
PiperOrigin-RevId: 328222517
This commit is contained in:
Uday Bondhugula 2020-08-24 15:35:37 -07:00 committed by TensorFlow MLIR Team
parent 73b5a44f33
commit 94296bb7ec
3 changed files with 24 additions and 0 deletions

View File

@ -1079,6 +1079,8 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>,
// XLA semantics is available. This limitation is because of the current XLA
// implementation.
let results = (outs I32Tensor);
let hasFolder = 1;
}
def HLO_MapOp: HLO_Op<"map",

View File

@ -165,6 +165,20 @@ static LogicalResult Verify(DotGeneralOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// GetDimensionSizeOp
//===----------------------------------------------------------------------===//
/// Fold get_dimension_size when the said shape dimension is a constant.
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
RankedTensorType type = operand().getType().cast<RankedTensorType>();
int32_t dim = dimension().getSExtValue();
if (type.isDynamic(dim)) return {};
// The result type is always is a 0-d i32 tensor.
return DenseIntElementsAttr::get<int32_t>(
getResult().getType().cast<RankedTensorType>(), type.getDimSize(dim));
}
//===----------------------------------------------------------------------===//
// IotaOp
//===----------------------------------------------------------------------===//

View File

@ -618,3 +618,11 @@ func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf
"lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
return %N : memref<256x1024xf32>
}
// CHECK-LABEL: func @fold_get_dimension_size
func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor<i32> {
%size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i32} : (tensor<1x128x512xf32>) -> tensor<i32>
return %size : tensor<i32>
// CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32>
// CHECK-NEXT: return %[[C]]
}