PR #40745: [MLIR] Add constant folder for xla_hlo.broadcast_in_dim op
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/40745 Fold broadcast_in_dim op if the operand is the result of a tensor splat. Copybara import of the project: -- 26c9f631448b8d6ffd20ece39ea8d4132b5550c7 by Uday Bondhugula <uday@polymagelabs.com>: [MLIR] Add constant folder for xla_hlo.broadcast_in_dim op Fold broadcast_in_dim op if the operand is the result of a tensor splat. COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/40745 from polymage-labs:broadcast_in_dim_fold 26c9f631448b8d6ffd20ece39ea8d4132b5550c7 PiperOrigin-RevId: 320365164
This commit is contained in:
parent
506ddd9c4a
commit
de0578b4f9
|
@ -620,17 +620,25 @@ static LogicalResult Verify(BroadcastInDimOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute>) {
|
OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute> attrs) {
|
||||||
auto type = getType().cast<RankedTensorType>();
|
auto type = getType().cast<RankedTensorType>();
|
||||||
if (type != getOperand().getType()) {
|
if (type == getOperand().getType()) {
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto broadcast_values = broadcast_dimensions().getValues<int64_t>();
|
auto broadcast_values = broadcast_dimensions().getValues<int64_t>();
|
||||||
if (!std::equal(broadcast_values.begin(), broadcast_values.end(),
|
if (!std::equal(broadcast_values.begin(), broadcast_values.end(),
|
||||||
llvm::seq<int64_t>(0, type.getRank()).begin())) {
|
llvm::seq<int64_t>(0, type.getRank()).begin())) {
|
||||||
return nullptr;
|
return {};
|
||||||
}
|
}
|
||||||
return getOperand();
|
return getOperand();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constant fold when an operand is a splat tensor attribute.
|
||||||
|
if (!attrs[0] || !type.hasStaticShape()) return {};
|
||||||
|
auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>();
|
||||||
|
if (!splatOperandAttr) return {};
|
||||||
|
// MLIR core bug (https://bugs.llvm.org/show_bug.cgi?id=46588): dense element
|
||||||
|
// attribute iterator not implemented for complex element types.
|
||||||
|
if (type.getElementType().isa<ComplexType>()) return {};
|
||||||
|
return SplatElementsAttr::get(type, splatOperandAttr.getSplatValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -365,6 +365,24 @@ func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %ar
|
||||||
return %0 : tensor<5x4xf32>
|
return %0 : tensor<5x4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d
|
||||||
|
func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> {
|
||||||
|
%cst = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
%b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x64x224x224xf32>
|
||||||
|
return %b : tensor<1x64x224x224xf32>
|
||||||
|
}
|
||||||
|
// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32>
|
||||||
|
// CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @broadcast_in_dim_constant_fold
|
||||||
|
func @broadcast_in_dim_constant_fold() -> tensor<1x64x4x4xf32> {
|
||||||
|
%cst = mhlo.constant dense<0.000000e+00> : tensor<4x4xf32>
|
||||||
|
%b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32>
|
||||||
|
return %b : tensor<1x64x4x4xf32>
|
||||||
|
}
|
||||||
|
// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x4x4xf32>
|
||||||
|
// CHECK-NEXT: return %[[CST]] : tensor<1x64x4x4xf32>
|
||||||
|
|
||||||
// CHECK-LABEL: @complex_expand_fold
|
// CHECK-LABEL: @complex_expand_fold
|
||||||
func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
|
func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
|
||||||
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
|
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
|
||||||
|
|
|
@ -10,8 +10,7 @@ func @batchNormInference_2D_inner_features(
|
||||||
%x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
|
%x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
|
||||||
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
|
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
|
||||||
-> (tensor<4x256xf32>) {
|
-> (tensor<4x256xf32>) {
|
||||||
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.001000e-05> : tensor<f32>
|
// CHECK-DAG: %[[EPS_BCAST:.+]] = mhlo.constant dense<1.001000e-05> : tensor<256xf32>
|
||||||
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
|
|
||||||
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
|
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
|
||||||
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
|
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
|
||||||
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
|
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
|
||||||
|
@ -51,7 +50,7 @@ func @batchNormInference_4D_middle_features(
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: @batchNormInference_f64
|
// CHECK-LABEL: @batchNormInference_f64
|
||||||
// Validate that epsilon is properly promoted to f64
|
// Validate that epsilon is properly promoted to f64
|
||||||
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
|
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf64>
|
||||||
func @batchNormInference_f64(
|
func @batchNormInference_f64(
|
||||||
%x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>,
|
%x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>,
|
||||||
%mean: tensor<256xf64>, %variance: tensor<256xf64>)
|
%mean: tensor<256xf64>, %variance: tensor<256xf64>)
|
||||||
|
@ -66,7 +65,7 @@ func @batchNormInference_f64(
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: @batchNormInference_f16
|
// CHECK-LABEL: @batchNormInference_f16
|
||||||
// Validate that epsilon is properly promoted to f64
|
// Validate that epsilon is properly promoted to f64
|
||||||
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f16>
|
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf16>
|
||||||
func @batchNormInference_f16(
|
func @batchNormInference_f16(
|
||||||
%x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>,
|
%x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>,
|
||||||
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
|
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
|
||||||
|
|
Loading…
Reference in New Issue