[MLIR][KernelGen] Add erf kernel for f32 arguments and missing lowerings

PiperOrigin-RevId: 352381016
This commit is contained in:
A. Unique TensorFlower 2021-01-18 03:34:19 -08:00 committed by TensorFlow MLIR Team
parent b5dc600860
commit 3763740910
2 changed files with 110 additions and 1 deletions

View File

@ -20,6 +20,7 @@ limitations under the License.
#define _USE_MATH_DEFINES
#include <cmath>
#include <numeric>
#include <vector>
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
@ -75,6 +76,63 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
}
};
Value MaterializePolynomialApproximation(
ConversionPatternRewriter &rewriter, Location loc, Value x,
const std::vector<float> &coefficients) {
Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x);
for (float c : coefficients) {
poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x);
poly = rewriter.create<mhlo::AddOp>(
loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x));
}
return poly;
}
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, Value operand) {
const std::vector<float> kAlpha{
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
-5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
-1.60960333262415e-02f,
};
const std::vector<float> kBeta{
-1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
-7.37332916720468e-03f, -1.42647390514189e-02f,
};
// Clamp argument between -4 and 4.
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, operand);
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, operand);
Value x =
rewriter.create<mhlo::ClampOp>(loc, operand.getType(), lb, operand, ub);
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
// Materialize polynomial approximation for x in [-4, 4].
Value alpha_poly =
MaterializePolynomialApproximation(rewriter, loc, x_sq, kAlpha);
Value beta_poly =
MaterializePolynomialApproximation(rewriter, loc, x_sq, kBeta);
Value mul_x_alpha_poly = rewriter.create<mhlo::MulOp>(loc, x, alpha_poly);
return rewriter.create<mhlo::DivOp>(loc, mul_x_alpha_poly, beta_poly);
}
struct ConvertErfOp : public OpConversionPattern<ErfOp> {
using OpConversionPattern<ErfOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ErfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Type ty = getElementTypeOrSelf(op.getType());
// For now, we support only f32.
if (!ty.isF32()) return failure();
ErfOp::Adaptor transformed(operands);
rewriter.replaceOp(op, MaterializeErfApproximationF32(
rewriter, op.getLoc(), transformed.operand()));
return success();
}
};
// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
@ -226,7 +284,7 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
context, patterns, 5);
// Other patterns.
patterns->insert<ConvertConstantLikeOp>(context);
patterns->insert<ConvertConstantLikeOp, ConvertErfOp>(context);
}
} // namespace chlo

View File

@ -35,3 +35,54 @@ func @conj(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>> {
return %1 : tensor<3xcomplex<f32>>
}
// CHECK-LABEL: @erf_f32
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>
func @erf_f32(%arg : tensor<f32>) -> tensor<f32> {
// CHECK: %[[TMP_0:.*]] = mhlo.constant dense<-4.000000e+00>
// CHECK: %[[TMP_1:.*]] = mhlo.constant dense<4.000000e+00>
// CHECK: %[[TMP_2:.*]] = "mhlo.clamp"(%[[TMP_0]], %[[ARG]], %[[TMP_1]])
// CHECK: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_2]]
// CHECK: %[[TMP_4:.*]] = mhlo.constant dense<0.000000e+00>
// CHECK: %[[TMP_5:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_3]]
// CHECK: %[[TMP_6:.*]] = mhlo.constant dense<-2.72614237E-10>
// CHECK: %[[TMP_7:.*]] = mhlo.add %[[TMP_5]], %[[TMP_6]]
// CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_7]], %[[TMP_3]]
// CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.77068146E-8>
// CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]]
// CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_3]]
// CHECK: %[[TMP_12:.*]] = mhlo.constant dense<-2.10102394E-6>
// CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]]
// CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_3]]
// CHECK: %[[TMP_15:.*]] = mhlo.constant dense<-5.69250624E-5>
// CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]]
// CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[TMP_16]], %[[TMP_3]]
// CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-7.34990637E-4>
// CHECK: %[[TMP_19:.*]] = mhlo.add %[[TMP_17]], %[[TMP_18]]
// CHECK: %[[TMP_20:.*]] = mhlo.multiply %[[TMP_19]], %[[TMP_3]]
// CHECK: %[[TMP_21:.*]] = mhlo.constant dense<-2.954600e-03>
// CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_20]], %[[TMP_21]]
// CHECK: %[[TMP_23:.*]] = mhlo.multiply %[[TMP_22]], %[[TMP_3]]
// CHECK: %[[TMP_24:.*]] = mhlo.constant dense<-0.0160960332>
// CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_23]], %[[TMP_24]]
// CHECK: %[[TMP_26:.*]] = mhlo.constant dense<0.000000e+00>
// CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_3]]
// CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-1.45660715E-5>
// CHECK: %[[TMP_29:.*]] = mhlo.add %[[TMP_27]], %[[TMP_28]]
// CHECK: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_29]], %[[TMP_3]]
// CHECK: %[[TMP_31:.*]] = mhlo.constant dense<-2.13374049E-4>
// CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_30]], %[[TMP_31]]
// CHECK: %[[TMP_33:.*]] = mhlo.multiply %[[TMP_32]], %[[TMP_3]]
// CHECK: %[[TMP_34:.*]] = mhlo.constant dense<-0.00168282702>
// CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_33]], %[[TMP_34]]
// CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_35]], %[[TMP_3]]
// CHECK: %[[TMP_37:.*]] = mhlo.constant dense<-0.00737332925>
// CHECK: %[[TMP_38:.*]] = mhlo.add %[[TMP_36]], %[[TMP_37]]
// CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_38]], %[[TMP_3]]
// CHECK: %[[TMP_40:.*]] = mhlo.constant dense<-0.0142647391>
// CHECK: %[[TMP_41:.*]] = mhlo.add %[[TMP_39]], %[[TMP_40]]
// CHECK: %[[TMP_42:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_25]]
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[TMP_42]], %[[TMP_41]]
// CHECK: return %[[RESULT]]
%1 = "chlo.erf"(%arg) : (tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
}