[MLIR][KernelGen] Add erf kernel for f32 arguments and missing lowerings
PiperOrigin-RevId: 352381016
This commit is contained in:
parent
b5dc600860
commit
3763740910
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||||
#define _USE_MATH_DEFINES
|
#define _USE_MATH_DEFINES
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
@ -75,6 +76,63 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Emits mhlo ops that evaluate a polynomial in `x` step by step: starting from
// a zero constant shaped like `x`, each coefficient (in the order given)
// contributes one multiply of the running value by `x` followed by one add of
// the coefficient constant. The first coefficient therefore ends up attached
// to the highest power of `x`. Returns the value of the final accumulation;
// with no coefficients this is the zero constant itself.
Value MaterializePolynomialApproximation(
    ConversionPatternRewriter &rewriter, Location loc, Value x,
    const std::vector<float> &coefficients) {
  Type result_ty = x.getType();
  Value accumulator = chlo::getConstantLike(rewriter, loc, 0.0, x);
  for (float coefficient : coefficients) {
    // Op creation order matters for the emitted IR: multiply, constant, add.
    Value scaled =
        rewriter.create<mhlo::MulOp>(loc, result_ty, accumulator, x);
    Value term = chlo::getConstantLike(rewriter, loc, coefficient, x);
    accumulator = rewriter.create<mhlo::AddOp>(loc, result_ty, scaled, term);
  }
  return accumulator;
}
|
||||||
|
|
||||||
|
// Emits an f32 approximation of erf as a quotient of two polynomials. The
// operand is first clamped to [-4, 4]; both polynomials are evaluated on the
// square of the clamped value, and the numerator polynomial is additionally
// multiplied by the clamped value itself.
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
                                     Location loc, Value operand) {
  // Coefficients of the numerator polynomial (in x^2), in evaluation order.
  const std::vector<float> kAlpha{
      -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
      -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
      -1.60960333262415e-02f};
  // Coefficients of the denominator polynomial (in x^2), in evaluation order.
  const std::vector<float> kBeta{
      -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
      -7.37332916720468e-03f, -1.42647390514189e-02f};

  // Restrict the argument to the interval [-4, 4].
  Value min_arg = chlo::getConstantLike(rewriter, loc, -4.0, operand);
  Value max_arg = chlo::getConstantLike(rewriter, loc, 4.0, operand);
  Type operand_ty = operand.getType();
  Value clamped = rewriter.create<mhlo::ClampOp>(loc, operand_ty, min_arg,
                                                 operand, max_arg);
  Value clamped_sq = rewriter.create<mhlo::MulOp>(loc, clamped, clamped);

  // Build both polynomials on the squared, clamped argument.
  Value numerator_poly =
      MaterializePolynomialApproximation(rewriter, loc, clamped_sq, kAlpha);
  Value denominator =
      MaterializePolynomialApproximation(rewriter, loc, clamped_sq, kBeta);

  // erf(x) ~= x * alpha(x^2) / beta(x^2).
  Value numerator =
      rewriter.create<mhlo::MulOp>(loc, clamped, numerator_poly);
  Value quotient = rewriter.create<mhlo::DivOp>(loc, numerator, denominator);
  return quotient;
}
|
||||||
|
|
||||||
|
// Conversion pattern that replaces an ErfOp whose element type is f32 with the
// ops produced by MaterializeErfApproximationF32. Any other element type is
// left untouched by reporting a match failure.
struct ConvertErfOp : public OpConversionPattern<ErfOp> {
  using OpConversionPattern<ErfOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ErfOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only f32 is handled for now; bail out on every other element type.
    if (!getElementTypeOrSelf(op.getType()).isF32()) return failure();

    ErfOp::Adaptor adaptor(operands);
    Value approximation = MaterializeErfApproximationF32(
        rewriter, op.getLoc(), adaptor.operand());
    rewriter.replaceOp(op, approximation);
    return success();
  }
};
|
||||||
|
|
||||||
// Converts binary ops that statically are determined to not broadcast directly
|
// Converts binary ops that statically are determined to not broadcast directly
|
||||||
// to the corresponding mhlo non-broadcasting op.
|
// to the corresponding mhlo non-broadcasting op.
|
||||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||||
|
@ -226,7 +284,7 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||||
context, patterns, 5);
|
context, patterns, 5);
|
||||||
|
|
||||||
// Other patterns.
|
// Other patterns.
|
||||||
patterns->insert<ConvertConstantLikeOp>(context);
|
patterns->insert<ConvertConstantLikeOp, ConvertErfOp>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace chlo
|
} // namespace chlo
|
||||||
|
|
|
@ -35,3 +35,54 @@ func @conj(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>> {
|
||||||
return %1 : tensor<3xcomplex<f32>>
|
return %1 : tensor<3xcomplex<f32>>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies the expected lowering of "chlo.erf" on a tensor<f32> argument: the
// argument is clamped to [-4, 4] and squared, two polynomials in the squared
// value are built step by step (multiply by the square, then add the next
// constant, starting from a zero constant), and the result is the clamped
// argument times the first polynomial, divided by the second polynomial.
// CHECK-LABEL: @erf_f32
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>
|
||||||
|
func @erf_f32(%arg : tensor<f32>) -> tensor<f32> {
|
||||||
|
// CHECK: %[[TMP_0:.*]] = mhlo.constant dense<-4.000000e+00>
|
||||||
|
// CHECK: %[[TMP_1:.*]] = mhlo.constant dense<4.000000e+00>
|
||||||
|
// CHECK: %[[TMP_2:.*]] = "mhlo.clamp"(%[[TMP_0]], %[[ARG]], %[[TMP_1]])
|
||||||
|
// CHECK: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_2]]
|
||||||
|
// CHECK: %[[TMP_4:.*]] = mhlo.constant dense<0.000000e+00>
|
||||||
|
// CHECK: %[[TMP_5:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_3]]
|
||||||
|
// CHECK: %[[TMP_6:.*]] = mhlo.constant dense<-2.72614237E-10>
|
||||||
|
// CHECK: %[[TMP_7:.*]] = mhlo.add %[[TMP_5]], %[[TMP_6]]
|
||||||
|
// CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_7]], %[[TMP_3]]
|
||||||
|
// CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.77068146E-8>
|
||||||
|
// CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]]
|
||||||
|
// CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_3]]
|
||||||
|
// CHECK: %[[TMP_12:.*]] = mhlo.constant dense<-2.10102394E-6>
|
||||||
|
// CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]]
|
||||||
|
// CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_3]]
|
||||||
|
// CHECK: %[[TMP_15:.*]] = mhlo.constant dense<-5.69250624E-5>
|
||||||
|
// CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]]
|
||||||
|
// CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[TMP_16]], %[[TMP_3]]
|
||||||
|
// CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-7.34990637E-4>
|
||||||
|
// CHECK: %[[TMP_19:.*]] = mhlo.add %[[TMP_17]], %[[TMP_18]]
|
||||||
|
// CHECK: %[[TMP_20:.*]] = mhlo.multiply %[[TMP_19]], %[[TMP_3]]
|
||||||
|
// CHECK: %[[TMP_21:.*]] = mhlo.constant dense<-2.954600e-03>
|
||||||
|
// CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_20]], %[[TMP_21]]
|
||||||
|
// CHECK: %[[TMP_23:.*]] = mhlo.multiply %[[TMP_22]], %[[TMP_3]]
|
||||||
|
// CHECK: %[[TMP_24:.*]] = mhlo.constant dense<-0.0160960332>
|
||||||
|
// CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_23]], %[[TMP_24]]
|
||||||
|
// CHECK: %[[TMP_26:.*]] = mhlo.constant dense<0.000000e+00>
|
||||||
|
// CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_3]]
|
||||||
|
// CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-1.45660715E-5>
|
||||||
|
// CHECK: %[[TMP_29:.*]] = mhlo.add %[[TMP_27]], %[[TMP_28]]
|
||||||
|
// CHECK: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_29]], %[[TMP_3]]
|
||||||
|
// CHECK: %[[TMP_31:.*]] = mhlo.constant dense<-2.13374049E-4>
|
||||||
|
// CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_30]], %[[TMP_31]]
|
||||||
|
// CHECK: %[[TMP_33:.*]] = mhlo.multiply %[[TMP_32]], %[[TMP_3]]
|
||||||
|
// CHECK: %[[TMP_34:.*]] = mhlo.constant dense<-0.00168282702>
|
||||||
|
// CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_33]], %[[TMP_34]]
|
||||||
|
// CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_35]], %[[TMP_3]]
|
||||||
|
// CHECK: %[[TMP_37:.*]] = mhlo.constant dense<-0.00737332925>
|
||||||
|
// CHECK: %[[TMP_38:.*]] = mhlo.add %[[TMP_36]], %[[TMP_37]]
|
||||||
|
// CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_38]], %[[TMP_3]]
|
||||||
|
// CHECK: %[[TMP_40:.*]] = mhlo.constant dense<-0.0142647391>
|
||||||
|
// CHECK: %[[TMP_41:.*]] = mhlo.add %[[TMP_39]], %[[TMP_40]]
|
||||||
|
// CHECK: %[[TMP_42:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_25]]
|
||||||
|
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[TMP_42]], %[[TMP_41]]
|
||||||
|
// CHECK: return %[[RESULT]]
|
||||||
|
%1 = "chlo.erf"(%arg) : (tensor<f32>) -> tensor<f32>
|
||||||
|
return %1 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue