diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index e954917..dc755b1 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -20,6 +20,7 @@ limitations under the License. #define _USE_MATH_DEFINES #include #include +#include #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" @@ -75,6 +76,63 @@ struct ConvertConstantLikeOp : public OpConversionPattern { } }; +Value MaterializePolynomialApproximation( + ConversionPatternRewriter &rewriter, Location loc, Value x, + const std::vector &coefficients) { + Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x); + for (float c : coefficients) { + poly = rewriter.create(loc, x.getType(), poly, x); + poly = rewriter.create( + loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x)); + } + return poly; +} + +Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter, + Location loc, Value operand) { + const std::vector kAlpha{ + -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f, + -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f, + -1.60960333262415e-02f, + }; + const std::vector kBeta{ + -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f, + -7.37332916720468e-03f, -1.42647390514189e-02f, + }; + + // Clamp argument between -4 and 4. + Value lb = chlo::getConstantLike(rewriter, loc, -4.0, operand); + Value ub = chlo::getConstantLike(rewriter, loc, 4.0, operand); + Value x = + rewriter.create(loc, operand.getType(), lb, operand, ub); + Value x_sq = rewriter.create(loc, x, x); + + // Materialize polynomial approximation for x in [-4, 4]. + Value alpha_poly = + MaterializePolynomialApproximation(rewriter, loc, x_sq, kAlpha); + Value beta_poly = + MaterializePolynomialApproximation(rewriter, loc, x_sq, kBeta); + Value mul_x_alpha_poly = rewriter.create(loc, x, alpha_poly); + return rewriter.create(loc, mul_x_alpha_poly, beta_poly); +} + +struct ConvertErfOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + ErfOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Type ty = getElementTypeOrSelf(op.getType()); + + // For now, we support only f32. + if (!ty.isF32()) return failure(); + + ErfOp::Adaptor transformed(operands); + rewriter.replaceOp(op, MaterializeErfApproximationF32( + rewriter, op.getLoc(), transformed.operand())); + return success(); + } +}; + // Converts binary ops that statically are determined to not broadcast directly // to the corresponding mhlo non-broadcasting op. template @@ -226,7 +284,7 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, context, patterns, 5); // Other patterns. - patterns->insert(context); + patterns->insert(context); } } // namespace chlo diff --git a/tests/chlo_legalize_to_mhlo.mlir b/tests/chlo_legalize_to_mhlo.mlir index 42a1154..fbba761 100644 --- a/tests/chlo_legalize_to_mhlo.mlir +++ b/tests/chlo_legalize_to_mhlo.mlir @@ -35,3 +35,54 @@ func @conj(%arg0: tensor<3xcomplex>) -> tensor<3xcomplex> { return %1 : tensor<3xcomplex> } +// CHECK-LABEL: @erf_f32 +// CHECK-SAME: %[[ARG:.*]]: tensor +func @erf_f32(%arg : tensor) -> tensor { + // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<-4.000000e+00> + // CHECK: %[[TMP_1:.*]] = mhlo.constant dense<4.000000e+00> + // CHECK: %[[TMP_2:.*]] = "mhlo.clamp"(%[[TMP_0]], %[[ARG]], %[[TMP_1]]) + // CHECK: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_2]] + // CHECK: %[[TMP_4:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK: %[[TMP_5:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_3]] + // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<-2.72614237E-10> + // CHECK: %[[TMP_7:.*]] = mhlo.add %[[TMP_5]], %[[TMP_6]] + // CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_7]], %[[TMP_3]] + // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.77068146E-8> + // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]] + // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_3]] + // CHECK: %[[TMP_12:.*]] = mhlo.constant dense<-2.10102394E-6> + // CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]] + // CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_3]] + // CHECK: %[[TMP_15:.*]] = mhlo.constant dense<-5.69250624E-5> + // CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]] + // CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[TMP_16]], %[[TMP_3]] + // CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-7.34990637E-4> + // CHECK: %[[TMP_19:.*]] = mhlo.add %[[TMP_17]], %[[TMP_18]] + // CHECK: %[[TMP_20:.*]] = mhlo.multiply %[[TMP_19]], %[[TMP_3]] + // CHECK: %[[TMP_21:.*]] = mhlo.constant dense<-2.954600e-03> + // CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_20]], %[[TMP_21]] + // CHECK: %[[TMP_23:.*]] = mhlo.multiply %[[TMP_22]], %[[TMP_3]] + // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<-0.0160960332> + // CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_23]], %[[TMP_24]] + // CHECK: %[[TMP_26:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_3]] + // CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-1.45660715E-5> + // CHECK: %[[TMP_29:.*]] = mhlo.add %[[TMP_27]], %[[TMP_28]] + // CHECK: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_29]], %[[TMP_3]] + // CHECK: %[[TMP_31:.*]] = mhlo.constant dense<-2.13374049E-4> + // CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_30]], %[[TMP_31]] + // CHECK: %[[TMP_33:.*]] = mhlo.multiply %[[TMP_32]], %[[TMP_3]] + // CHECK: %[[TMP_34:.*]] = mhlo.constant dense<-0.00168282702> + // CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_33]], %[[TMP_34]] + // CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_35]], %[[TMP_3]] + // CHECK: %[[TMP_37:.*]] = mhlo.constant dense<-0.00737332925> + // CHECK: %[[TMP_38:.*]] = mhlo.add %[[TMP_36]], %[[TMP_37]] + // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_38]], %[[TMP_3]] + // CHECK: %[[TMP_40:.*]] = mhlo.constant dense<-0.0142647391> + // CHECK: %[[TMP_41:.*]] = mhlo.add %[[TMP_39]], %[[TMP_40]] + // CHECK: %[[TMP_42:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_25]] + // CHECK: %[[RESULT:.*]] = mhlo.divide %[[TMP_42]], %[[TMP_41]] + // CHECK: return %[[RESULT]] + %1 = "chlo.erf"(%arg) : (tensor) -> tensor + return %1 : tensor +}