[MLIR] Lower mhlo.clamp to linalg

PiperOrigin-RevId: 351998800
This commit is contained in:
A. Unique TensorFlower 2021-01-15 06:44:28 -08:00 committed by TensorFlow MLIR Team
parent 2af0dcd3e7
commit bcdb3c3548
4 changed files with 42 additions and 1 deletions

View File

@ -43,6 +43,7 @@ MAP_HLO_TO_LHLO(AndOp);
MAP_HLO_TO_LHLO(Atan2Op);
MAP_HLO_TO_LHLO(BroadcastInDimOp);
MAP_HLO_TO_LHLO(CeilOp);
MAP_HLO_TO_LHLO(ClampOp);
MAP_HLO_TO_LHLO(ConstOp);
MAP_HLO_TO_LHLO(CompareOp);
MAP_HLO_TO_LHLO(ComplexOp);

View File

@ -463,6 +463,23 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
assert(args.size() == 3 && "expected 3 arguments");
Value lb = args[0];
Value x = args[1];
Value ub = args[2];
// clamp(lb, x, ub) = max(min(x, ub), lb)
Value min_x_ub =
MapLhloOpToStdScalarOp<lmhlo::MinOp>(loc, result_types, {x, ub}, b);
return MapLhloOpToStdScalarOp<lmhlo::MaxOp>(loc, result_types, {min_x_ub, lb},
b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
ArrayRef<Type> result_types,

View File

@ -1227,6 +1227,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::AndOp>,
PointwiseToLinalgConverter<lmhlo::Atan2Op>,
PointwiseToLinalgConverter<lmhlo::CeilOp>,
PointwiseToLinalgConverter<lmhlo::ClampOp>,
PointwiseToLinalgConverter<lmhlo::CompareOp>,
PointwiseToLinalgConverter<lmhlo::ComplexOp>,
PointwiseToLinalgConverter<lmhlo::ConvertOp>,
@ -1349,6 +1350,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<mhlo::ClampOp, false>,
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,

View File

@ -892,3 +892,24 @@ func @dot_general(%arg0: tensor<?x?x3xf32>,
// CHECK: linalg.batch_matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xf32>, tensor<?x3x?xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?xf32>)
// -----
// CHECK-LABEL: @clamp
// CHECK-SAME: %[[LB:.*]]: tensor<4xf32>, %[[X:.*]]: tensor<4xf32>, %[[UB:.*]]: tensor<4xf32>
func @clamp(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>)
-> tensor<4xf32> {
// CHECK: %[[INIT:.*]] = linalg.init_tensor
// CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>)
// CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32):
// CHECK: %[[LT_X_UB:.*]] = cmpf olt, %[[SCALAR_X]], %[[SCALAR_UB]]
// CHECK: %[[X2:.*]] = select %[[LT_X_UB]], %[[SCALAR_X]], %[[SCALAR_UB]]
// CHECK: %[[GT_X2_LB:.*]] = cmpf ogt, %[[X2]], %[[SCALAR_LB]]
// CHECK: %[[MAX_X2_LB:.*]] = select %[[GT_X2_LB]], %[[X2]], %[[SCALAR_LB]]
// CHECK: linalg.yield %[[MAX_X2_LB]]
// CHECK: } -> tensor<4xf32>
// CHECK: return %[[RESULT]] : tensor<4xf32>
%0 = "mhlo.clamp"(%lb, %x, %ub) : (tensor<4xf32>, tensor<4xf32>,
tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}