[MLIR] Add mhlo.logistic lowering to linalg
PiperOrigin-RevId: 353205440
This commit is contained in:
		
							parent
							
								
									c846f925d4
								
							
						
					
					
						commit
						ef8ccdaebc
					
				|  | @ -114,6 +114,8 @@ def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFinit | |||
| 
 | ||||
| def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp; | ||||
| 
 | ||||
| def LHLO_LogisticOp : LHLO_UnaryElementwiseOp<"logistic", LHLO_FpOrComplexBuffer>, BASE_HLO_LogisticOp; | ||||
| 
 | ||||
| def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp; | ||||
| 
 | ||||
| def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp; | ||||
|  |  | |||
|  | @ -61,6 +61,7 @@ MAP_HLO_TO_LHLO(ImagOp); | |||
| MAP_HLO_TO_LHLO(IotaOp); | ||||
| MAP_HLO_TO_LHLO(IsFiniteOp); | ||||
| MAP_HLO_TO_LHLO(LogOp); | ||||
| MAP_HLO_TO_LHLO(LogisticOp); | ||||
| MAP_HLO_TO_LHLO(Log1pOp); | ||||
| MAP_HLO_TO_LHLO(MaxOp); | ||||
| MAP_HLO_TO_LHLO(MinOp); | ||||
|  |  | |||
|  | @ -437,6 +437,19 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc, | |||
|       loc, result_types, args, b); | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>( | ||||
|     Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, | ||||
|     OpBuilder* b) { | ||||
|   auto ty = result_types.front().cast<FloatType>(); | ||||
|   Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0)); | ||||
|   Value x = args.front(); | ||||
|   Value neg_x = b->create<NegFOp>(loc, x); | ||||
|   Value exp_neg_x = b->create<::mlir::ExpOp>(loc, neg_x); | ||||
|   Value one_add_exp_neg_x = b->create<AddFOp>(loc, one, exp_neg_x); | ||||
|   return b->create<DivFOp>(loc, one, one_add_exp_neg_x); | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc, | ||||
|                                                     ArrayRef<Type> result_types, | ||||
|  |  | |||
|  | @ -233,6 +233,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter | |||
|   LogicalResult matchAndRewrite( | ||||
|       mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands, | ||||
|       ConversionPatternRewriter& rewriter) const final { | ||||
|     if (!op.getType().isa<RankedTensorType>()) return failure(); | ||||
|     Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter); | ||||
| 
 | ||||
|     if (insert_copy_) { | ||||
|  |  | |||
|  | @ -474,7 +474,8 @@ class HloDynamicBroadcastInDimConverter | |||
|       Value index = rewriter.create<ConstantIndexOp>(loc, i); | ||||
|       dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index)); | ||||
|     } | ||||
|     auto result_type = op.getType().cast<RankedTensorType>(); | ||||
|     auto result_type = op.getType().dyn_cast<RankedTensorType>(); | ||||
|     if (!result_type) return failure(); | ||||
| 
 | ||||
|     int64_t nloops = result_type.getRank(); | ||||
|     Value init = rewriter.create<linalg::InitTensorOp>( | ||||
|  | @ -1240,6 +1241,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, | |||
|                    PointwiseToLinalgConverter<lmhlo::ImagOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::IsFiniteOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::LogOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::LogisticOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::Log1pOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::MaxOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::MinOp>, | ||||
|  | @ -1364,6 +1366,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, | |||
|                PointwiseToLinalgConverter<mhlo::ImagOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::LogOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::LogisticOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::Log1pOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::MaxOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::MinOp, false>, | ||||
|  |  | |||
|  | @ -176,6 +176,22 @@ func @float_log1p(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @float_logistic | ||||
| func @float_logistic(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { | ||||
|   // CHECK: linalg.generic | ||||
|   // CHECK: ^bb0(%[[ARG:.*]]: f32, %{{.*}}: f32): | ||||
|   // CHECK: %[[C1:.*]] = constant 1.{{.*}}e+00 | ||||
|   // CHECK: %[[NEG_ARG:.*]] = negf %[[ARG]] | ||||
|   // CHECK: %[[EXP_NEG_ARG:.*]] = exp %[[NEG_ARG]] | ||||
|   // CHECK: %[[ONE_ADD_EXP_NEG_ARG:.*]] = addf %[[C1]], %[[EXP_NEG_ARG]] | ||||
|   // CHECK: %[[RESULT:.*]] = divf %[[C1]], %[[ONE_ADD_EXP_NEG_ARG]] | ||||
|   // CHECK: linalg.yield %[[RESULT]] | ||||
|   %0 = "mhlo.logistic"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> | ||||
|   return %0 : tensor<2x2xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @float_ceil | ||||
| func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { | ||||
|   // CHECK: linalg.generic | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue