PR #42509: [MLIR] Add folder for mhlo get_dimension_size
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/42509 Add folder for mhlo GetDimensionSizeOp (`mhlo.get_dimension_size`). `get_dimension_size` folds to a constant when the corresponding tensor dimension size is statically known / constant. Copybara import of the project: -- 5994915525ec2e932125aa1f133ce2260ba100af by Uday Bondhugula <uday@polymagelabs.com>: [MLIR] Add folder for mhlo get_dimension_size Add folder for mhlo GetDimensionSizeOp. get_dimension_size folds to a constant when the corresponding tensor dimension size is statically known / constant. COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/42509 from polymage-labs:get_dimension_size_fold 5994915525ec2e932125aa1f133ce2260ba100af PiperOrigin-RevId: 328222517
This commit is contained in:
parent
73b5a44f33
commit
94296bb7ec
|
@ -1079,6 +1079,8 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>,
|
|||
// XLA semantics is available. This limitation is because of the current XLA
|
||||
// implementation.
|
||||
let results = (outs I32Tensor);
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def HLO_MapOp: HLO_Op<"map",
|
||||
|
|
|
@ -165,6 +165,20 @@ static LogicalResult Verify(DotGeneralOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetDimensionSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Fold get_dimension_size when the said shape dimension is a constant.
|
||||
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
|
||||
RankedTensorType type = operand().getType().cast<RankedTensorType>();
|
||||
int32_t dim = dimension().getSExtValue();
|
||||
if (type.isDynamic(dim)) return {};
|
||||
// The result type is always is a 0-d i32 tensor.
|
||||
return DenseIntElementsAttr::get<int32_t>(
|
||||
getResult().getType().cast<RankedTensorType>(), type.getDimSize(dim));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IotaOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -618,3 +618,11 @@ func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf
|
|||
"lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
|
||||
return %N : memref<256x1024xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_get_dimension_size
|
||||
func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor<i32> {
|
||||
%size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i32} : (tensor<1x128x512xf32>) -> tensor<i32>
|
||||
return %size : tensor<i32>
|
||||
// CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32>
|
||||
// CHECK-NEXT: return %[[C]]
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue