[MLIR] Lower mhlo.clamp to linalg
PiperOrigin-RevId: 351998800
This commit is contained in:
parent
2af0dcd3e7
commit
bcdb3c3548
|
@ -43,6 +43,7 @@ MAP_HLO_TO_LHLO(AndOp);
|
||||||
MAP_HLO_TO_LHLO(Atan2Op);
|
MAP_HLO_TO_LHLO(Atan2Op);
|
||||||
MAP_HLO_TO_LHLO(BroadcastInDimOp);
|
MAP_HLO_TO_LHLO(BroadcastInDimOp);
|
||||||
MAP_HLO_TO_LHLO(CeilOp);
|
MAP_HLO_TO_LHLO(CeilOp);
|
||||||
|
MAP_HLO_TO_LHLO(ClampOp);
|
||||||
MAP_HLO_TO_LHLO(ConstOp);
|
MAP_HLO_TO_LHLO(ConstOp);
|
||||||
MAP_HLO_TO_LHLO(CompareOp);
|
MAP_HLO_TO_LHLO(CompareOp);
|
||||||
MAP_HLO_TO_LHLO(ComplexOp);
|
MAP_HLO_TO_LHLO(ComplexOp);
|
||||||
|
|
|
@ -463,6 +463,23 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
|
||||||
args, b);
|
args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
assert(args.size() == 3 && "expected 3 arguments");
|
||||||
|
Value lb = args[0];
|
||||||
|
Value x = args[1];
|
||||||
|
Value ub = args[2];
|
||||||
|
|
||||||
|
// clamp(lb, x, ub) = max(min(x, ub), lb)
|
||||||
|
Value min_x_ub =
|
||||||
|
MapLhloOpToStdScalarOp<lmhlo::MinOp>(loc, result_types, {x, ub}, b);
|
||||||
|
return MapLhloOpToStdScalarOp<lmhlo::MaxOp>(loc, result_types, {min_x_ub, lb},
|
||||||
|
b);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
|
|
@ -1227,6 +1227,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<lmhlo::AndOp>,
|
PointwiseToLinalgConverter<lmhlo::AndOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::Atan2Op>,
|
PointwiseToLinalgConverter<lmhlo::Atan2Op>,
|
||||||
PointwiseToLinalgConverter<lmhlo::CeilOp>,
|
PointwiseToLinalgConverter<lmhlo::CeilOp>,
|
||||||
|
PointwiseToLinalgConverter<lmhlo::ClampOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::CompareOp>,
|
PointwiseToLinalgConverter<lmhlo::CompareOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::ComplexOp>,
|
PointwiseToLinalgConverter<lmhlo::ComplexOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::ConvertOp>,
|
PointwiseToLinalgConverter<lmhlo::ConvertOp>,
|
||||||
|
@ -1349,6 +1350,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
|
PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
|
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
|
||||||
|
PointwiseToLinalgConverter<mhlo::ClampOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
|
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
|
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
|
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
|
||||||
|
|
|
@ -892,3 +892,24 @@ func @dot_general(%arg0: tensor<?x?x3xf32>,
|
||||||
// CHECK: linalg.batch_matmul
|
// CHECK: linalg.batch_matmul
|
||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xf32>, tensor<?x3x?xf32>)
|
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xf32>, tensor<?x3x?xf32>)
|
||||||
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?xf32>)
|
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?xf32>)
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @clamp
|
||||||
|
// CHECK-SAME: %[[LB:.*]]: tensor<4xf32>, %[[X:.*]]: tensor<4xf32>, %[[UB:.*]]: tensor<4xf32>
|
||||||
|
func @clamp(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>)
|
||||||
|
-> tensor<4xf32> {
|
||||||
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor
|
||||||
|
// CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>)
|
||||||
|
// CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32):
|
||||||
|
// CHECK: %[[LT_X_UB:.*]] = cmpf olt, %[[SCALAR_X]], %[[SCALAR_UB]]
|
||||||
|
// CHECK: %[[X2:.*]] = select %[[LT_X_UB]], %[[SCALAR_X]], %[[SCALAR_UB]]
|
||||||
|
// CHECK: %[[GT_X2_LB:.*]] = cmpf ogt, %[[X2]], %[[SCALAR_LB]]
|
||||||
|
// CHECK: %[[MAX_X2_LB:.*]] = select %[[GT_X2_LB]], %[[X2]], %[[SCALAR_LB]]
|
||||||
|
// CHECK: linalg.yield %[[MAX_X2_LB]]
|
||||||
|
// CHECK: } -> tensor<4xf32>
|
||||||
|
// CHECK: return %[[RESULT]] : tensor<4xf32>
|
||||||
|
%0 = "mhlo.clamp"(%lb, %x, %ub) : (tensor<4xf32>, tensor<4xf32>,
|
||||||
|
tensor<4xf32>) -> tensor<4xf32>
|
||||||
|
return %0 : tensor<4xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue