diff --git a/BUILD b/BUILD index 1146494..b4dfa36 100644 --- a/BUILD +++ b/BUILD @@ -977,6 +977,7 @@ cc_library( ":chlo_legalize_to_hlo_inc_gen", ":hlo", ":map_chlo_to_hlo_op", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 808b0af..4a2fd8c 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "llvm/ADT/SmallVector.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h" @@ -96,7 +97,8 @@ Value MaterializePolynomialApproximation(ConversionPatternRewriter &rewriter, // argument and derive the final approximation for all |x| >= 1. // This implementation is based on Cephes. Value MaterializeErfcApproximationF64ForMagnituteGEOne( - ConversionPatternRewriter &rewriter, Location loc, Value x) { + ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { + Value x = args.front(); assert(x.getType().cast().getElementType().isF64() && "expect f64 element type"); const double kMaxlog = 7.09782712893383996843E2; @@ -179,7 +181,8 @@ Value MaterializeErfcApproximationF64ForMagnituteGEOne( // Precondition is |x| <= 1. Use erfc approximation, otherwise. // This implementation is based on Cephes. Value MaterializeErfApproximationF64ForMagnituteLEOne( - ConversionPatternRewriter &rewriter, Location loc, Value x) { + ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { + Value x = args.front(); assert(x.getType().cast().getElementType().isF64() && "expect f64 element type"); const std::vector kErfTCoefficients{ @@ -204,7 +207,8 @@ Value MaterializeErfApproximationF64ForMagnituteLEOne( // This implementation is based on Cephes. Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter, - Location loc, Value x) { + Location loc, ValueRange args) { + Value x = args.front(); assert(x.getType().cast().getElementType().isF64() && "expect f64 element type"); @@ -230,7 +234,8 @@ Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter, } Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter, - Location loc, Value x) { + Location loc, ValueRange args) { + Value x = args.front(); assert(x.getType().cast().getElementType().isF64() && "expect f64 element type"); @@ -261,7 +266,8 @@ Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter, // argument and derive the final approximation for all |x| >= 1. // This implementation is based on Cephes. Value MaterializeErfcApproximationF32ForMagnitudeGEOne( - ConversionPatternRewriter &rewriter, Location loc, Value x) { + ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { + Value x = args.front(); assert(x.getType().cast().getElementType().isF32() && "expect f32 element type"); const double kMaxlog = 88.72283905206835; @@ -325,7 +331,8 @@ Value MaterializeErfcApproximationF32ForMagnitudeGEOne( // Precondition is |x| <= 1. Use erfc approximation, otherwise. // This implementation is based on Cephes. Value MaterializeErfApproximationF32ForMagnitudeLEOne( - ConversionPatternRewriter &rewriter, Location loc, Value x) { + ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { + Value x = args.front(); assert(x.getType().cast().getElementType().isF32() && "expect f32 element type"); const std::vector kErfTCoefficients{ @@ -344,8 +351,9 @@ Value MaterializeErfApproximationF32ForMagnitudeLEOne( // This is the same approximation as used in Eigen. Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter, - Location loc, Value operand) { - assert(operand.getType().cast().getElementType().isF32() && + Location loc, ValueRange args) { + Value x = args.front(); + assert(x.getType().cast().getElementType().isF32() && "expect f32 element type"); const std::vector kAlpha{ -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f, @@ -358,10 +366,9 @@ Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter, }; // Clamp argument between -4 and 4. - Value lb = chlo::getConstantLike(rewriter, loc, -4.0, operand); - Value ub = chlo::getConstantLike(rewriter, loc, 4.0, operand); - Value x = - rewriter.create(loc, operand.getType(), lb, operand, ub); + Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x); + Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x); + x = rewriter.create(loc, x.getType(), lb, x, ub); Value x_sq = rewriter.create(loc, x, x); // Materialize polynomial approximation for x in [-4, 4] as @@ -375,7 +382,8 @@ Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter, } Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter, - Location loc, Value x) { + Location loc, ValueRange args) { + Value x = args.front(); assert(x.getType().cast().getElementType().isF32() && "expect f32 element type"); @@ -401,18 +409,30 @@ Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter, } Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc, - Value arg, FloatType min_precision_ty, + ValueRange args, FloatType min_precision_ty, Value callback(ConversionPatternRewriter &, - Location, Value)) { - auto original_ty = getElementTypeOrSelf(arg.getType()).cast(); + Location, ValueRange)) { + auto original_ty = + getElementTypeOrSelf(args.front().getType()).cast(); bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth(); + + // Upcast arguments if necessary. + llvm::SmallVector casted_args; if (needs_upcast) { - arg = rewriter.create(loc, arg, min_precision_ty); + for (Value a : args) { + casted_args.push_back( + rewriter.create(loc, a, min_precision_ty)); + } + args = casted_args; } - Value result = callback(rewriter, loc, arg); + + Value result = callback(rewriter, loc, args); + + // Cast back if necessary. if (needs_upcast) { result = rewriter.create(loc, result, original_ty); } + return result; } @@ -434,9 +454,9 @@ struct ConvertErfOp : public OpConversionPattern { return success(); } - rewriter.replaceOp( - op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(), - &MaterializeErfApproximationF32)); + rewriter.replaceOp(op, MaterializeWithUpcast( + rewriter, loc, operands, rewriter.getF32Type(), + &MaterializeErfApproximationF32)); return success(); } }; @@ -459,9 +479,9 @@ struct ConvertErfcOp : public OpConversionPattern { return success(); } - rewriter.replaceOp( - op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(), - &MaterializeErfcApproximationF32)); + rewriter.replaceOp(op, MaterializeWithUpcast( + rewriter, loc, operands, rewriter.getF32Type(), + &MaterializeErfcApproximationF32)); return success(); } }; @@ -491,12 +511,13 @@ constexpr std::array kLanczosCoefficients = { // a(z) = kBaseLanczosCoeff // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc, - Value x) { + ValueRange args) { // If the input is less than 0.5 use Euler's reflection formula. // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) // Let z be // z = -x if x < 1/2 // z = x - 1 otheriwse + Value x = args.front(); const StringAttr kLT = rewriter.getStringAttr( mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); Value half = getConstantLike(rewriter, loc, 0.5, x); @@ -635,12 +656,13 @@ Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc, // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) // a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc, - Value x) { + ValueRange args) { // If the input is less than 0.5 use Euler's reflection formula. // digamma(x) = digamma(1 - x) - pi * cot(pi * x) // Let z be // z = -x if x < 1/2 // z = x - 1 otheriwse + Value x = args.front(); const StringAttr kLT = rewriter.getStringAttr( mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); Value half = getConstantLike(rewriter, loc, 0.5, x); @@ -739,36 +761,11 @@ Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc, digamma); } -struct ConvertLgammaOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - LgammaOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - LgammaOp::Adaptor transformed(operands); - FloatType min_precision_ty = rewriter.getF32Type(); - rewriter.replaceOp( - op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(), - min_precision_ty, &MaterializeLgamma)); - return success(); - } -}; - -struct ConvertDigammaOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - DigammaOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - DigammaOp::Adaptor transformed(operands); - FloatType min_precision_ty = rewriter.getF32Type(); - rewriter.replaceOp( - op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(), - min_precision_ty, &MaterializeDigamma)); - return success(); - } -}; - -Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter, - Location loc, Value x, Value q) { +Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc, + ValueRange args) { + assert(args.size() == 2); + Value x = args[0]; + Value q = args[1]; static const std::array kZetaCoeffs{ -7.1661652561756670113e18, 1.8152105401943546773e17, @@ -897,34 +894,42 @@ Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter, return output; } +struct ConvertLgammaOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + LgammaOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + FloatType min_precision_ty = rewriter.getF32Type(); + rewriter.replaceOp( + op, MaterializeWithUpcast(rewriter, op.getLoc(), operands, + min_precision_ty, &MaterializeLgamma)); + return success(); + } +}; + +struct ConvertDigammaOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + DigammaOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + FloatType min_precision_ty = rewriter.getF32Type(); + rewriter.replaceOp( + op, MaterializeWithUpcast(rewriter, op.getLoc(), operands, + min_precision_ty, &MaterializeDigamma)); + return success(); + } +}; + struct ConvertZetaOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( ZetaOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - ZetaOpAdaptor adaptor(operands); Location loc = op.getLoc(); - - // Zeta is only defined on tensors of float elements and statically - // verified that both have the same type. So it suffices to look at one - // here. - auto elm_type = adaptor.x().getType().cast().getElementType(); - - bool needs_upcast = elm_type.isF16() || elm_type.isBF16(); - - Value x = adaptor.x(); - Value q = adaptor.q(); - - if (needs_upcast) { - x = rewriter.create(loc, x, rewriter.getF32Type()); - q = rewriter.create(loc, q, rewriter.getF32Type()); - } - Value result = MaterializeZetaComputation(rewriter, loc, x, q); - if (needs_upcast) { - result = rewriter.create(loc, result, elm_type); - } - rewriter.replaceOp(op, {result}); - + FloatType min_precision_ty = rewriter.getF32Type(); + rewriter.replaceOp( + op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty, + &MaterializeZeta)); return success(); } };