[MLIR] Add mhlo.logistic lowering to linalg
PiperOrigin-RevId: 353205440
This commit is contained in:
parent
c846f925d4
commit
ef8ccdaebc
|
@ -114,6 +114,8 @@ def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFinit
|
||||||
|
|
||||||
def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp;
|
def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp;
|
||||||
|
|
||||||
|
def LHLO_LogisticOp : LHLO_UnaryElementwiseOp<"logistic", LHLO_FpOrComplexBuffer>, BASE_HLO_LogisticOp;
|
||||||
|
|
||||||
def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp;
|
def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp;
|
||||||
|
|
||||||
def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp;
|
def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp;
|
||||||
|
|
|
@ -61,6 +61,7 @@ MAP_HLO_TO_LHLO(ImagOp);
|
||||||
MAP_HLO_TO_LHLO(IotaOp);
|
MAP_HLO_TO_LHLO(IotaOp);
|
||||||
MAP_HLO_TO_LHLO(IsFiniteOp);
|
MAP_HLO_TO_LHLO(IsFiniteOp);
|
||||||
MAP_HLO_TO_LHLO(LogOp);
|
MAP_HLO_TO_LHLO(LogOp);
|
||||||
|
MAP_HLO_TO_LHLO(LogisticOp);
|
||||||
MAP_HLO_TO_LHLO(Log1pOp);
|
MAP_HLO_TO_LHLO(Log1pOp);
|
||||||
MAP_HLO_TO_LHLO(MaxOp);
|
MAP_HLO_TO_LHLO(MaxOp);
|
||||||
MAP_HLO_TO_LHLO(MinOp);
|
MAP_HLO_TO_LHLO(MinOp);
|
||||||
|
|
|
@ -437,6 +437,19 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
auto ty = result_types.front().cast<FloatType>();
|
||||||
|
Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0));
|
||||||
|
Value x = args.front();
|
||||||
|
Value neg_x = b->create<NegFOp>(loc, x);
|
||||||
|
Value exp_neg_x = b->create<::mlir::ExpOp>(loc, neg_x);
|
||||||
|
Value one_add_exp_neg_x = b->create<AddFOp>(loc, one, exp_neg_x);
|
||||||
|
return b->create<DivFOp>(loc, one, one_add_exp_neg_x);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
|
|
@ -233,6 +233,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
if (!op.getType().isa<RankedTensorType>()) return failure();
|
||||||
Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
|
Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
|
||||||
|
|
||||||
if (insert_copy_) {
|
if (insert_copy_) {
|
||||||
|
|
|
@ -474,7 +474,8 @@ class HloDynamicBroadcastInDimConverter
|
||||||
Value index = rewriter.create<ConstantIndexOp>(loc, i);
|
Value index = rewriter.create<ConstantIndexOp>(loc, i);
|
||||||
dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index));
|
dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index));
|
||||||
}
|
}
|
||||||
auto result_type = op.getType().cast<RankedTensorType>();
|
auto result_type = op.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!result_type) return failure();
|
||||||
|
|
||||||
int64_t nloops = result_type.getRank();
|
int64_t nloops = result_type.getRank();
|
||||||
Value init = rewriter.create<linalg::InitTensorOp>(
|
Value init = rewriter.create<linalg::InitTensorOp>(
|
||||||
|
@ -1240,6 +1241,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<lmhlo::ImagOp>,
|
PointwiseToLinalgConverter<lmhlo::ImagOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
|
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::LogOp>,
|
PointwiseToLinalgConverter<lmhlo::LogOp>,
|
||||||
|
PointwiseToLinalgConverter<lmhlo::LogisticOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::Log1pOp>,
|
PointwiseToLinalgConverter<lmhlo::Log1pOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::MaxOp>,
|
PointwiseToLinalgConverter<lmhlo::MaxOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::MinOp>,
|
PointwiseToLinalgConverter<lmhlo::MinOp>,
|
||||||
|
@ -1364,6 +1366,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
|
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
||||||
|
PointwiseToLinalgConverter<mhlo::LogisticOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::Log1pOp, false>,
|
PointwiseToLinalgConverter<mhlo::Log1pOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
||||||
|
|
|
@ -176,6 +176,22 @@ func @float_log1p(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @float_logistic
|
||||||
|
func @float_logistic(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: ^bb0(%[[ARG:.*]]: f32, %{{.*}}: f32):
|
||||||
|
// CHECK: %[[C1:.*]] = constant 1.{{.*}}e+00
|
||||||
|
// CHECK: %[[NEG_ARG:.*]] = negf %[[ARG]]
|
||||||
|
// CHECK: %[[EXP_NEG_ARG:.*]] = exp %[[NEG_ARG]]
|
||||||
|
// CHECK: %[[ONE_ADD_EXP_NEG_ARG:.*]] = addf %[[C1]], %[[EXP_NEG_ARG]]
|
||||||
|
// CHECK: %[[RESULT:.*]] = divf %[[C1]], %[[ONE_ADD_EXP_NEG_ARG]]
|
||||||
|
// CHECK: linalg.yield %[[RESULT]]
|
||||||
|
%0 = "mhlo.logistic"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||||
|
return %0 : tensor<2x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @float_ceil
|
// CHECK-LABEL: func @float_ceil
|
||||||
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
|
|
Loading…
Reference in New Issue