PR #40745: [MLIR] Add constant folder for xla_hlo.broadcast_in_dim op

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/40745

Fold broadcast_in_dim op if the operand is the result of a tensor splat.
Copybara import of the project:

--
26c9f631448b8d6ffd20ece39ea8d4132b5550c7 by Uday Bondhugula <uday@polymagelabs.com>:

[MLIR] Add constant folder for xla_hlo.broadcast_in_dim op

Fold broadcast_in_dim op if the operand is the result of a tensor
splat.

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/40745 from polymage-labs:broadcast_in_dim_fold 26c9f631448b8d6ffd20ece39ea8d4132b5550c7
PiperOrigin-RevId: 320365164
This commit is contained in:
Uday Bondhugula 2020-07-09 10:36:37 +00:00 committed by Mehdi Amini
parent 506ddd9c4a
commit de0578b4f9
3 changed files with 38 additions and 13 deletions

View File

@ -620,19 +620,27 @@ static LogicalResult Verify(BroadcastInDimOp op) {
return success(); return success();
} }
OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute>) { OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute> attrs) {
auto type = getType().cast<RankedTensorType>(); auto type = getType().cast<RankedTensorType>();
if (type != getOperand().getType()) { if (type == getOperand().getType()) {
return nullptr;
}
auto broadcast_values = broadcast_dimensions().getValues<int64_t>(); auto broadcast_values = broadcast_dimensions().getValues<int64_t>();
if (!std::equal(broadcast_values.begin(), broadcast_values.end(), if (!std::equal(broadcast_values.begin(), broadcast_values.end(),
llvm::seq<int64_t>(0, type.getRank()).begin())) { llvm::seq<int64_t>(0, type.getRank()).begin())) {
return nullptr; return {};
} }
return getOperand(); return getOperand();
} }
// Constant fold when an operand is a splat tensor attribute.
if (!attrs[0] || !type.hasStaticShape()) return {};
auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>();
if (!splatOperandAttr) return {};
// MLIR core bug (https://bugs.llvm.org/show_bug.cgi?id=46588): dense element
// attribute iterator not implemented for complex element types.
if (type.getElementType().isa<ComplexType>()) return {};
return SplatElementsAttr::get(type, splatOperandAttr.getSplatValue());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// DynamicBroadcastInDimOp // DynamicBroadcastInDimOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -365,6 +365,24 @@ func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %ar
return %0 : tensor<5x4xf32> return %0 : tensor<5x4xf32>
} }
// CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d
func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> {
%cst = mhlo.constant dense<0.000000e+00> : tensor<f32>
%b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x64x224x224xf32>
return %b : tensor<1x64x224x224xf32>
}
// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32>
// CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xf32>
// CHECK-LABEL: func @broadcast_in_dim_constant_fold
func @broadcast_in_dim_constant_fold() -> tensor<1x64x4x4xf32> {
%cst = mhlo.constant dense<0.000000e+00> : tensor<4x4xf32>
%b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32>
return %b : tensor<1x64x4x4xf32>
}
// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x4x4xf32>
// CHECK-NEXT: return %[[CST]] : tensor<1x64x4x4xf32>
// CHECK-LABEL: @complex_expand_fold // CHECK-LABEL: @complex_expand_fold
func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>) %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)

View File

@ -10,8 +10,7 @@ func @batchNormInference_2D_inner_features(
%x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>) %mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<4x256xf32>) { -> (tensor<4x256xf32>) {
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.001000e-05> : tensor<f32> // CHECK-DAG: %[[EPS_BCAST:.+]] = mhlo.constant dense<1.001000e-05> : tensor<256xf32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32> // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
@ -51,7 +50,7 @@ func @batchNormInference_4D_middle_features(
// ----- // -----
// CHECK-LABEL: @batchNormInference_f64 // CHECK-LABEL: @batchNormInference_f64
// Validate that epsilon is properly promoted to f64 // Validate that epsilon is properly promoted to f64
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f64> // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf64>
func @batchNormInference_f64( func @batchNormInference_f64(
%x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>, %x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>,
%mean: tensor<256xf64>, %variance: tensor<256xf64>) %mean: tensor<256xf64>, %variance: tensor<256xf64>)
@ -66,7 +65,7 @@ func @batchNormInference_f64(
// ----- // -----
// CHECK-LABEL: @batchNormInference_f16 // CHECK-LABEL: @batchNormInference_f16
// Validate that epsilon is properly promoted to f64 // Validate that epsilon is properly promoted to f64
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f16> // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf16>
func @batchNormInference_f16( func @batchNormInference_f16(
%x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>, %x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>,
%mean: tensor<256xf16>, %variance: tensor<256xf16>) %mean: tensor<256xf16>, %variance: tensor<256xf16>)