[MLIR][KernelGen] Add erf kernel for f32 arguments and missing lowerings
PiperOrigin-RevId: 352381016
This commit is contained in:
parent
b5dc600860
commit
3763740910
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
#define _USE_MATH_DEFINES
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
|
@ -75,6 +76,63 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
|
|||
}
|
||||
};
|
||||
|
||||
Value MaterializePolynomialApproximation(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Value x,
|
||||
const std::vector<float> &coefficients) {
|
||||
Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x);
|
||||
for (float c : coefficients) {
|
||||
poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x);
|
||||
poly = rewriter.create<mhlo::AddOp>(
|
||||
loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x));
|
||||
}
|
||||
return poly;
|
||||
}
|
||||
|
||||
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value operand) {
|
||||
const std::vector<float> kAlpha{
|
||||
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
|
||||
-5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
|
||||
-1.60960333262415e-02f,
|
||||
};
|
||||
const std::vector<float> kBeta{
|
||||
-1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
|
||||
-7.37332916720468e-03f, -1.42647390514189e-02f,
|
||||
};
|
||||
|
||||
// Clamp argument between -4 and 4.
|
||||
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, operand);
|
||||
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, operand);
|
||||
Value x =
|
||||
rewriter.create<mhlo::ClampOp>(loc, operand.getType(), lb, operand, ub);
|
||||
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
|
||||
|
||||
// Materialize polynomial approximation for x in [-4, 4].
|
||||
Value alpha_poly =
|
||||
MaterializePolynomialApproximation(rewriter, loc, x_sq, kAlpha);
|
||||
Value beta_poly =
|
||||
MaterializePolynomialApproximation(rewriter, loc, x_sq, kBeta);
|
||||
Value mul_x_alpha_poly = rewriter.create<mhlo::MulOp>(loc, x, alpha_poly);
|
||||
return rewriter.create<mhlo::DivOp>(loc, mul_x_alpha_poly, beta_poly);
|
||||
}
|
||||
|
||||
struct ConvertErfOp : public OpConversionPattern<ErfOp> {
|
||||
using OpConversionPattern<ErfOp>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
ErfOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type ty = getElementTypeOrSelf(op.getType());
|
||||
|
||||
// For now, we support only f32.
|
||||
if (!ty.isF32()) return failure();
|
||||
|
||||
ErfOp::Adaptor transformed(operands);
|
||||
rewriter.replaceOp(op, MaterializeErfApproximationF32(
|
||||
rewriter, op.getLoc(), transformed.operand()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Converts binary ops that statically are determined to not broadcast directly
|
||||
// to the corresponding mhlo non-broadcasting op.
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
|
@ -226,7 +284,7 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
|||
context, patterns, 5);
|
||||
|
||||
// Other patterns.
|
||||
patterns->insert<ConvertConstantLikeOp>(context);
|
||||
patterns->insert<ConvertConstantLikeOp, ConvertErfOp>(context);
|
||||
}
|
||||
|
||||
} // namespace chlo
|
||||
|
|
|
@ -35,3 +35,54 @@ func @conj(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>> {
|
|||
return %1 : tensor<3xcomplex<f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @erf_f32
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>
|
||||
func @erf_f32(%arg : tensor<f32>) -> tensor<f32> {
|
||||
// CHECK: %[[TMP_0:.*]] = mhlo.constant dense<-4.000000e+00>
|
||||
// CHECK: %[[TMP_1:.*]] = mhlo.constant dense<4.000000e+00>
|
||||
// CHECK: %[[TMP_2:.*]] = "mhlo.clamp"(%[[TMP_0]], %[[ARG]], %[[TMP_1]])
|
||||
// CHECK: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_2]]
|
||||
// CHECK: %[[TMP_4:.*]] = mhlo.constant dense<0.000000e+00>
|
||||
// CHECK: %[[TMP_5:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_6:.*]] = mhlo.constant dense<-2.72614237E-10>
|
||||
// CHECK: %[[TMP_7:.*]] = mhlo.add %[[TMP_5]], %[[TMP_6]]
|
||||
// CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_7]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.77068146E-8>
|
||||
// CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]]
|
||||
// CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_12:.*]] = mhlo.constant dense<-2.10102394E-6>
|
||||
// CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]]
|
||||
// CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_15:.*]] = mhlo.constant dense<-5.69250624E-5>
|
||||
// CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]]
|
||||
// CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[TMP_16]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-7.34990637E-4>
|
||||
// CHECK: %[[TMP_19:.*]] = mhlo.add %[[TMP_17]], %[[TMP_18]]
|
||||
// CHECK: %[[TMP_20:.*]] = mhlo.multiply %[[TMP_19]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_21:.*]] = mhlo.constant dense<-2.954600e-03>
|
||||
// CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_20]], %[[TMP_21]]
|
||||
// CHECK: %[[TMP_23:.*]] = mhlo.multiply %[[TMP_22]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_24:.*]] = mhlo.constant dense<-0.0160960332>
|
||||
// CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_23]], %[[TMP_24]]
|
||||
// CHECK: %[[TMP_26:.*]] = mhlo.constant dense<0.000000e+00>
|
||||
// CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-1.45660715E-5>
|
||||
// CHECK: %[[TMP_29:.*]] = mhlo.add %[[TMP_27]], %[[TMP_28]]
|
||||
// CHECK: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_29]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_31:.*]] = mhlo.constant dense<-2.13374049E-4>
|
||||
// CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_30]], %[[TMP_31]]
|
||||
// CHECK: %[[TMP_33:.*]] = mhlo.multiply %[[TMP_32]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_34:.*]] = mhlo.constant dense<-0.00168282702>
|
||||
// CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_33]], %[[TMP_34]]
|
||||
// CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_35]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_37:.*]] = mhlo.constant dense<-0.00737332925>
|
||||
// CHECK: %[[TMP_38:.*]] = mhlo.add %[[TMP_36]], %[[TMP_37]]
|
||||
// CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_38]], %[[TMP_3]]
|
||||
// CHECK: %[[TMP_40:.*]] = mhlo.constant dense<-0.0142647391>
|
||||
// CHECK: %[[TMP_41:.*]] = mhlo.add %[[TMP_39]], %[[TMP_40]]
|
||||
// CHECK: %[[TMP_42:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_25]]
|
||||
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[TMP_42]], %[[TMP_41]]
|
||||
// CHECK: return %[[RESULT]]
|
||||
%1 = "chlo.erf"(%arg) : (tensor<f32>) -> tensor<f32>
|
||||
return %1 : tensor<f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue