Implement lowering of chlo::zeta to mhlo dialect.
PiperOrigin-RevId: 355395581
This commit is contained in:
		
parent 9d8a53c452
commit 6cd1875ee4
					
				|  | @ -36,7 +36,7 @@ class HloClientDialect : public Dialect { | ||||||
|   void initialize(); |   void initialize(); | ||||||
| 
 | 
 | ||||||
|  public: |  public: | ||||||
|   explicit HloClientDialect(MLIRContext *context) |   explicit HloClientDialect(MLIRContext* context) | ||||||
|       : Dialect(getDialectNamespace(), context, |       : Dialect(getDialectNamespace(), context, | ||||||
|                 TypeID::get<HloClientDialect>()) { |                 TypeID::get<HloClientDialect>()) { | ||||||
|     initialize(); |     initialize(); | ||||||
|  | @ -74,6 +74,8 @@ Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val); | ||||||
| Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val, | Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val, | ||||||
|                               bool negative); |                               bool negative); | ||||||
| 
 | 
 | ||||||
|  | Value getConstantLikeSmallestFiniteValue(OpBuilder& b, Location loc, Value val); | ||||||
|  | 
 | ||||||
| }  // namespace chlo
 | }  // namespace chlo
 | ||||||
| }  // namespace mlir
 | }  // namespace mlir
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -256,7 +256,7 @@ def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp< | ||||||
|   }]; |   }]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def HLOCLient_BroadcastZetaOp : HLOClient_BroadcastBinaryElementwiseOp< | def HLOClient_BroadcastZetaOp : HLOClient_BroadcastBinaryElementwiseOp< | ||||||
|     "broadcast_zeta", |     "broadcast_zeta", | ||||||
|     [NoSideEffect, SameOperandsAndResultElementType]> { |     [NoSideEffect, SameOperandsAndResultElementType]> { | ||||||
|   let summary = "Hurwitz zeta function"; |   let summary = "Hurwitz zeta function"; | ||||||
|  | @ -352,15 +352,14 @@ def HLOClient_ZetaOp : HLOClient_Op<"zeta", | ||||||
|   }]; |   }]; | ||||||
| 
 | 
 | ||||||
|   let arguments = (ins |   let arguments = (ins | ||||||
|     HLO_FpTensor:$lhs, |     HLO_FpTensor:$x, | ||||||
|     HLO_FpTensor:$rhs |     HLO_FpTensor:$q | ||||||
|   ); |   ); | ||||||
| 
 | 
 | ||||||
|   let results = (outs HLO_FpTensor); |   let results = (outs HLO_FpTensor); | ||||||
| 
 | 
 | ||||||
|   let assemblyFormat = [{ |   let assemblyFormat = [{ | ||||||
|     $lhs `,` $rhs attr-dict `:` |     $x `,` $q attr-dict `:` `(` type($x) `,` type($q) `)` `->` type(results) | ||||||
|       `(` type($lhs) `,` type($rhs) `)` `->` type(results) |  | ||||||
|   }]; |   }]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -15,9 +15,13 @@ limitations under the License. | ||||||
| 
 | 
 | ||||||
| include "mlir/Pass/PassBase.td" | include "mlir/Pass/PassBase.td" | ||||||
| 
 | 
 | ||||||
| def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> { | def ChloLegalizeToHloPass : FunctionPass<"chlo-legalize-to-hlo"> { | ||||||
|   let summary = "Legalize CHLO to HLO."; |   let summary = "Legalize CHLO to HLO."; | ||||||
|   let constructor = "createChloLegalizeToHloPass()"; |   let constructor = "createChloLegalizeToHloPass()"; | ||||||
|  |   let options = [ | ||||||
|  |     Option<"broadcast_only_", "broadcast-only", "bool", | ||||||
|  |            /*default=*/"false", "Only lower broadcasting chlo to non-broadcasting equivalents">, | ||||||
|  |   ]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> { | def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> { | ||||||
|  |  | ||||||
|  | @ -45,7 +45,8 @@ std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass(); | ||||||
| std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass(); | std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass(); | ||||||
| 
 | 
 | ||||||
| /// Lowers from the CHLO dialect to the HLO dialect.
 | /// Lowers from the CHLO dialect to the HLO dialect.
 | ||||||
| std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(); | std::unique_ptr<FunctionPass> createChloLegalizeToHloPass( | ||||||
|  |     bool broadcast_only = false); | ||||||
| 
 | 
 | ||||||
| /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
 | /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
 | ||||||
| /// buffers if necessary.
 | /// buffers if necessary.
 | ||||||
|  |  | ||||||
|  | @ -99,8 +99,14 @@ void PopulateTrigonometricToApproximationPatterns( | ||||||
| 
 | 
 | ||||||
| namespace chlo { | namespace chlo { | ||||||
| 
 | 
 | ||||||
|  | // Populates a collection of conversion patterns for legalizing broadcasting
 | ||||||
|  | // client-HLO to their non-broadcasting counterparts.
 | ||||||
|  | void PopulateChloBroadcastingPatterns(MLIRContext *context, | ||||||
|  |                                       OwningRewritePatternList *patterns); | ||||||
|  | 
 | ||||||
| // Populates a collection of conversion patterns for legalizing client-HLO to
 | // Populates a collection of conversion patterns for legalizing client-HLO to
 | ||||||
| // HLO.
 | // HLO. Includes decomposition of operations and inserting of explicit
 | ||||||
|  | // broadcasts.
 | ||||||
| void PopulateLegalizeChloToHloPatterns(MLIRContext *context, | void PopulateLegalizeChloToHloPatterns(MLIRContext *context, | ||||||
|                                        OwningRewritePatternList *patterns); |                                        OwningRewritePatternList *patterns); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -46,6 +46,13 @@ Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val, | ||||||
|       b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val); |       b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | Value getConstantLikeSmallestFiniteValue(OpBuilder& b, Location loc, | ||||||
|  |                                          Value val) { | ||||||
|  |   auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>(); | ||||||
|  |   return getConstantLike( | ||||||
|  |       b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant, | Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant, | ||||||
|                       Value val) { |                       Value val) { | ||||||
|   Type ty = getElementTypeOrSelf(val.getType()); |   Type ty = getElementTypeOrSelf(val.getType()); | ||||||
|  |  | ||||||
|  | @ -33,6 +33,7 @@ limitations under the License. | ||||||
| #include "mlir/Dialect/Tensor/IR/Tensor.h" | #include "mlir/Dialect/Tensor/IR/Tensor.h" | ||||||
| #include "mlir/IR/Attributes.h" | #include "mlir/IR/Attributes.h" | ||||||
| #include "mlir/IR/BuiltinTypes.h" | #include "mlir/IR/BuiltinTypes.h" | ||||||
|  | #include "mlir/IR/ImplicitLocOpBuilder.h" | ||||||
| #include "mlir/IR/MLIRContext.h" | #include "mlir/IR/MLIRContext.h" | ||||||
| #include "mlir/IR/OperationSupport.h" | #include "mlir/IR/OperationSupport.h" | ||||||
| #include "mlir/IR/PatternMatch.h" | #include "mlir/IR/PatternMatch.h" | ||||||
|  | @ -766,6 +767,168 @@ struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> { | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter, | ||||||
|  |                                  Location loc, Value x, Value q) { | ||||||
|  |   static const std::array<double, 12> kZetaCoeffs{ | ||||||
|  |       -7.1661652561756670113e18, | ||||||
|  |       1.8152105401943546773e17, | ||||||
|  |       -4.5979787224074726105e15, | ||||||
|  |       1.1646782814350067249e14, | ||||||
|  |       -2.950130727918164224e12, | ||||||
|  |       7.47242496e10, | ||||||
|  |       -1.8924375803183791606e9, | ||||||
|  |       47900160.0, | ||||||
|  |       -1209600.0, | ||||||
|  |       30240.0, | ||||||
|  |       -720.0, | ||||||
|  |       12.0, | ||||||
|  |   }; | ||||||
|  | 
 | ||||||
|  |   // For speed we'll always use 9 iterations for the initial series estimate,
 | ||||||
|  |   // and a 12 term expansion for the Euler-Maclaurin formula.
 | ||||||
|  |   Value a = q; | ||||||
|  |   Value zero_like_a = chlo::getConstantLike(rewriter, loc, 0.0, a); | ||||||
|  |   Value neg_power = zero_like_a; | ||||||
|  |   Value neg_x = rewriter.create<mhlo::NegOp>(loc, x); | ||||||
|  |   Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x); | ||||||
|  |   Value one_like_a = chlo::getConstantLike(rewriter, loc, 1.0, a); | ||||||
|  |   for (int i = 0; i < 9; ++i) { | ||||||
|  |     a = rewriter.create<mhlo::AddOp>(loc, a, one_like_a); | ||||||
|  |     neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x); | ||||||
|  |     initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, neg_power); | ||||||
|  |   } | ||||||
|  |   a = rewriter.create<mhlo::AddOp>(loc, a, one_like_a); | ||||||
|  |   neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x); | ||||||
|  |   Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x); | ||||||
|  |   Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x); | ||||||
|  |   Value neg_power_mul_a = rewriter.create<mhlo::MulOp>(loc, neg_power, a); | ||||||
|  |   Value neg_power_mul_a_div_x_minus_one = | ||||||
|  |       rewriter.create<mhlo::DivOp>(loc, neg_power_mul_a, x_minus_one); | ||||||
|  |   Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum, | ||||||
|  |                                          neg_power_mul_a_div_x_minus_one); | ||||||
|  |   Value a_inverse_square = rewriter.create<mhlo::DivOp>( | ||||||
|  |       loc, one_like_a, rewriter.create<mhlo::MulOp>(loc, a, a)); | ||||||
|  | 
 | ||||||
|  |   Value horner_sum = zero_like_a; | ||||||
|  |   Value factor = one_like_a; | ||||||
|  |   // Use Horner's rule for this.
 | ||||||
|  |   // Note this differs from Cephes which does a 'naive' polynomial evaluation.
 | ||||||
|  |   // Using Horner's rule allows to avoid some NaN's and Infs from happening,
 | ||||||
|  |   // resulting in more numerically stable code.
 | ||||||
|  |   for (int i = 0; i < 11; ++i) { | ||||||
|  |     Value factor_lhs = rewriter.create<mhlo::SubOp>( | ||||||
|  |         loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x)); | ||||||
|  |     Value factor_rhs = rewriter.create<mhlo::SubOp>( | ||||||
|  |         loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x)); | ||||||
|  |     factor = rewriter.create<mhlo::MulOp>(loc, factor_lhs, factor_rhs); | ||||||
|  |     horner_sum = rewriter.create<mhlo::MulOp>( | ||||||
|  |         loc, factor, | ||||||
|  |         rewriter.create<mhlo::MulOp>( | ||||||
|  |             loc, a_inverse_square, | ||||||
|  |             rewriter.create<mhlo::AddOp>( | ||||||
|  |                 loc, horner_sum, | ||||||
|  |                 chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a)))); | ||||||
|  |   } | ||||||
|  |   Value zero_point_five_like_neg_power = | ||||||
|  |       chlo::getConstantLike(rewriter, loc, .5, neg_power); | ||||||
|  |   Value x_div_a = rewriter.create<mhlo::DivOp>(loc, x, a); | ||||||
|  |   s = rewriter.create<mhlo::AddOp>( | ||||||
|  |       loc, s, | ||||||
|  |       rewriter.create<mhlo::MulOp>( | ||||||
|  |           loc, neg_power, | ||||||
|  |           rewriter.create<mhlo::AddOp>( | ||||||
|  |               loc, zero_point_five_like_neg_power, | ||||||
|  |               rewriter.create<mhlo::MulOp>( | ||||||
|  |                   loc, x_div_a, | ||||||
|  |                   rewriter.create<mhlo::AddOp>( | ||||||
|  |                       loc, | ||||||
|  |                       chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11], | ||||||
|  |                                             a), | ||||||
|  |                       horner_sum))))); | ||||||
|  |   const double nan = std::numeric_limits<double>::quiet_NaN(); | ||||||
|  |   const double inf = std::numeric_limits<double>::infinity(); | ||||||
|  |   // Use the initial zeta sum without the correction term coming
 | ||||||
|  |   // from Euler-Maclaurin if it is accurate enough.
 | ||||||
|  |   const StringAttr kLT = rewriter.getStringAttr( | ||||||
|  |       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); | ||||||
|  |   Value abs_neg_power = rewriter.create<mhlo::AbsOp>(loc, neg_power); | ||||||
|  |   Value abs_initial_sum = rewriter.create<mhlo::AbsOp>(loc, initial_sum); | ||||||
|  |   Value output = rewriter.create<mhlo::SelectOp>( | ||||||
|  |       loc, | ||||||
|  |       rewriter.create<mhlo::CompareOp>( | ||||||
|  |           loc, abs_neg_power, | ||||||
|  |           rewriter.create<mhlo::MulOp>( | ||||||
|  |               loc, abs_initial_sum, | ||||||
|  |               chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)), | ||||||
|  |           kLT), | ||||||
|  |       initial_sum, s); | ||||||
|  |   // This is the harmonic series.
 | ||||||
|  |   const StringAttr kEQ = rewriter.getStringAttr( | ||||||
|  |       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ)); | ||||||
|  |   Value inf_like_x = chlo::getConstantLike(rewriter, loc, inf, x); | ||||||
|  |   output = rewriter.create<mhlo::SelectOp>( | ||||||
|  |       loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kEQ), | ||||||
|  |       inf_like_x, output); | ||||||
|  |   // Function is not defined for x < 1.
 | ||||||
|  |   Value nan_like_x = chlo::getConstantLike(rewriter, loc, nan, x); | ||||||
|  |   output = rewriter.create<mhlo::SelectOp>( | ||||||
|  |       loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT), | ||||||
|  |       nan_like_x, output); | ||||||
|  |   // If q <= 0, then when q is an integer or x is not an integer, this is
 | ||||||
|  |   // NaN.
 | ||||||
|  |   const StringAttr kLE = rewriter.getStringAttr( | ||||||
|  |       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE)); | ||||||
|  |   const StringAttr kNE = rewriter.getStringAttr( | ||||||
|  |       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE)); | ||||||
|  |   Value zero_like_q = chlo::getConstantLike(rewriter, loc, 0.0, q); | ||||||
|  |   Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero_like_q, kLE); | ||||||
|  |   Value domain_error = rewriter.create<mhlo::AndOp>( | ||||||
|  |       loc, q_le_zero, | ||||||
|  |       rewriter.create<mhlo::CompareOp>( | ||||||
|  |           loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE)); | ||||||
|  |   Value negative_integer_q = rewriter.create<mhlo::AndOp>( | ||||||
|  |       loc, q_le_zero, | ||||||
|  |       rewriter.create<mhlo::CompareOp>( | ||||||
|  |           loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ)); | ||||||
|  |   output = rewriter.create<mhlo::SelectOp>(loc, negative_integer_q, inf_like_x, | ||||||
|  |                                            output); | ||||||
|  |   output = | ||||||
|  |       rewriter.create<mhlo::SelectOp>(loc, domain_error, nan_like_x, output); | ||||||
|  |   return output; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | struct ConvertZetaOp : public OpConversionPattern<ZetaOp> { | ||||||
|  |   using OpConversionPattern<ZetaOp>::OpConversionPattern; | ||||||
|  |   LogicalResult matchAndRewrite( | ||||||
|  |       ZetaOp op, ArrayRef<Value> operands, | ||||||
|  |       ConversionPatternRewriter &rewriter) const override { | ||||||
|  |     ZetaOpAdaptor adaptor(operands); | ||||||
|  |     Location loc = op.getLoc(); | ||||||
|  | 
 | ||||||
|  |     // Zeta is only defined on tensors of float elements and statically
 | ||||||
|  |     // verified that both have the same type. So it suffices to look at one
 | ||||||
|  |     // here.
 | ||||||
|  |     auto elm_type = adaptor.x().getType().cast<ShapedType>().getElementType(); | ||||||
|  | 
 | ||||||
|  |     bool needs_upcast = elm_type.isF16() || elm_type.isBF16(); | ||||||
|  | 
 | ||||||
|  |     Value x = adaptor.x(); | ||||||
|  |     Value q = adaptor.q(); | ||||||
|  | 
 | ||||||
|  |     if (needs_upcast) { | ||||||
|  |       x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type()); | ||||||
|  |       q = rewriter.create<mhlo::ConvertOp>(loc, q, rewriter.getF32Type()); | ||||||
|  |     } | ||||||
|  |     Value result = MaterializeZetaComputation(rewriter, loc, x, q); | ||||||
|  |     if (needs_upcast) { | ||||||
|  |       result = rewriter.create<mhlo::ConvertOp>(loc, result, elm_type); | ||||||
|  |     } | ||||||
|  |     rewriter.replaceOp(op, {result}); | ||||||
|  | 
 | ||||||
|  |     return success(); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
| // Converts binary ops that statically are determined to not broadcast directly
 | // Converts binary ops that statically are determined to not broadcast directly
 | ||||||
| // to the corresponding mhlo non-broadcasting op.
 | // to the corresponding mhlo non-broadcasting op.
 | ||||||
| template <typename ChloOpTy, typename HloOpTy, typename Adaptor> | template <typename ChloOpTy, typename HloOpTy, typename Adaptor> | ||||||
|  | @ -904,10 +1067,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp | ||||||
| #include "generated_chlo_legalize_to_hlo.inc" | #include "generated_chlo_legalize_to_hlo.inc" | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| void PopulateLegalizeChloToHloPatterns(MLIRContext *context, | void PopulateChloBroadcastingPatterns(MLIRContext *context, | ||||||
|                                        OwningRewritePatternList *patterns) { |                                       OwningRewritePatternList *patterns) { | ||||||
|   populateWithGenerated(context, *patterns); |  | ||||||
| 
 |  | ||||||
|   // Instantiate conversion templates for conforming binary elementwise ops
 |   // Instantiate conversion templates for conforming binary elementwise ops
 | ||||||
|   // that do not have different dtypes between operands and results and do
 |   // that do not have different dtypes between operands and results and do
 | ||||||
|   // not have special attributes that need to be preserved.
 |   // not have special attributes that need to be preserved.
 | ||||||
|  | @ -915,6 +1076,12 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, | ||||||
|       context, patterns, 10); |       context, patterns, 10); | ||||||
|   PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>( |   PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>( | ||||||
|       context, patterns, 5); |       context, patterns, 5); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void PopulateLegalizeChloToHloPatterns(MLIRContext *context, | ||||||
|  |                                        OwningRewritePatternList *patterns) { | ||||||
|  |   populateWithGenerated(context, *patterns); | ||||||
|  |   PopulateChloBroadcastingPatterns(context, patterns); | ||||||
| 
 | 
 | ||||||
|   // Other patterns.
 |   // Other patterns.
 | ||||||
|   // clang-format off
 |   // clang-format off
 | ||||||
|  | @ -922,7 +1089,8 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, | ||||||
|                    ConvertDigammaOp, |                    ConvertDigammaOp, | ||||||
|                    ConvertErfOp, |                    ConvertErfOp, | ||||||
|                    ConvertErfcOp, |                    ConvertErfcOp, | ||||||
|                    ConvertLgammaOp>(context); |                    ConvertLgammaOp, | ||||||
|  |                    ConvertZetaOp>(context); | ||||||
|   // clang-format on
 |   // clang-format on
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -15,6 +15,7 @@ limitations under the License. | ||||||
| 
 | 
 | ||||||
| #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" | #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" | ||||||
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||||||
|  | #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" | ||||||
| #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" | #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" | ||||||
| #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" | #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" | ||||||
| #include "mlir/Dialect/SCF/SCF.h" | #include "mlir/Dialect/SCF/SCF.h" | ||||||
|  | @ -29,7 +30,13 @@ namespace mhlo { | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| struct ChloLegalizeToHloPass | struct ChloLegalizeToHloPass | ||||||
|     : public PassWrapper<ChloLegalizeToHloPass, FunctionPass> { |     : public ChloLegalizeToHloPassBase<ChloLegalizeToHloPass> { | ||||||
|  |   explicit ChloLegalizeToHloPass(bool broadcast_only) | ||||||
|  |       : ChloLegalizeToHloPassBase< | ||||||
|  |             ChloLegalizeToHloPass>::ChloLegalizeToHloPassBase() { | ||||||
|  |     this->broadcast_only_ = broadcast_only; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   void getDependentDialects(DialectRegistry &registry) const override { |   void getDependentDialects(DialectRegistry &registry) const override { | ||||||
|     registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>(); |     registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>(); | ||||||
|   } |   } | ||||||
|  | @ -45,12 +52,16 @@ struct ChloLegalizeToHloPass | ||||||
|         MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect, |         MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect, | ||||||
|         mlir::shape::ShapeDialect, mlir::scf::SCFDialect>(); |         mlir::shape::ShapeDialect, mlir::scf::SCFDialect>(); | ||||||
| 
 | 
 | ||||||
|     // TODO(herhut): This is temporary while Zeta cannot be lowered to hlo.
 |     if (broadcast_only_) { | ||||||
|     conversionTarget.addLegalOp<chlo::ZetaOp>(); |       chlo::PopulateChloBroadcastingPatterns(&getContext(), | ||||||
|  |                                              &conversionPatterns); | ||||||
|  |       conversionTarget.addLegalOp<chlo::ZetaOp>(); | ||||||
|  |     } else { | ||||||
|  |       chlo::PopulateLegalizeChloToHloPatterns(&getContext(), | ||||||
|  |                                               &conversionPatterns); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     chlo::PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns); |     if (failed(applyPartialConversion(getOperation(), conversionTarget, | ||||||
| 
 |  | ||||||
|     if (failed(applyPartialConversion(getFunction(), conversionTarget, |  | ||||||
|                                       std::move(conversionPatterns)))) { |                                       std::move(conversionPatterns)))) { | ||||||
|       return signalPassFailure(); |       return signalPassFailure(); | ||||||
|     } |     } | ||||||
|  | @ -59,10 +70,9 @@ struct ChloLegalizeToHloPass | ||||||
| 
 | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| std::unique_ptr<FunctionPass> createChloLegalizeToHloPass() { | std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(bool broadcast_only) { | ||||||
|   return std::make_unique<ChloLegalizeToHloPass>(); |   return std::make_unique<ChloLegalizeToHloPass>(broadcast_only); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace mhlo
 | }  // namespace mhlo
 | ||||||
| }  // namespace mlir
 | }  // namespace mlir
 | ||||||
| 
 |  | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| // RUN: mlir-hlo-opt -chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s | // RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s | ||||||
| 
 | 
 | ||||||
| // Check the non-broadcast case for each registered op, then just check a | // Check the non-broadcast case for each registered op, then just check a | ||||||
| // representative op for detailed broadcast semantics. | // representative op for detailed broadcast semantics. | ||||||
|  |  | ||||||
|  | @ -1105,3 +1105,184 @@ func @digamma_f16(%arg : tensor<f16>) -> tensor<f16> { | ||||||
|   %1 = chlo.digamma %arg : tensor<f16> -> tensor<f16> |   %1 = chlo.digamma %arg : tensor<f16> -> tensor<f16> | ||||||
|   return %1 : tensor<f16> |   return %1 : tensor<f16> | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL:   func @zeta_f16( | ||||||
|  | // CHECK-SAME:                   %[[VAL_0:.*]]: tensor<f16>, | ||||||
|  | // CHECK-SAME:                   %[[VAL_1:.*]]: tensor<f16>) -> tensor<f16> { | ||||||
|  | func @zeta_f16(%arg0: tensor<f16>, %arg1: tensor<f16>) -> tensor<f16> { | ||||||
|  |   %0 = chlo.zeta %arg0, %arg1 : (tensor<f16>, tensor<f16>) -> tensor<f16> | ||||||
|  | // CHECK: %[[VAL_2:.*]] = "mhlo.convert"(%[[VAL_0]]) : (tensor<f16>) -> tensor<f32> | ||||||
|  | // CHECK: %[[VAL_3:.*]] = "mhlo.convert"(%[[VAL_1]]) : (tensor<f16>) -> tensor<f32> | ||||||
|  | // CHECK: %[[VAL_4:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_5:.*]] = "mhlo.negate"(%[[VAL_2]]) : (tensor<f32>) -> tensor<f32> | ||||||
|  | // CHECK: %[[VAL_6:.*]] = mhlo.power %[[VAL_3]], %[[VAL_5]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_8:.*]] = mhlo.add %[[VAL_3]], %[[VAL_7]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_9:.*]] = mhlo.power %[[VAL_8]], %[[VAL_5]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_10:.*]] = mhlo.add %[[VAL_6]], %[[VAL_9]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_11:.*]] = mhlo.add %[[VAL_8]], %[[VAL_7]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_12:.*]] = mhlo.power %[[VAL_11]], %[[VAL_5]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_13:.*]] = mhlo.add %[[VAL_10]], %[[VAL_12]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_14:.*]] = mhlo.add %[[VAL_11]], %[[VAL_7]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_15:.*]] = mhlo.power %[[VAL_14]], %[[VAL_5]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_16:.*]] = mhlo.add %[[VAL_13]], %[[VAL_15]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_17:.*]] = mhlo.add %[[VAL_14]], %[[VAL_7]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_18:.*]] = mhlo.power %[[VAL_17]], %[[VAL_5]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_19:.*]] = mhlo.add %[[VAL_16]], %[[VAL_18]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_20:.*]] = mhlo.add %[[VAL_17]], %[[VAL_7]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_21:.*]] = mhlo.power %[[VAL_20]], %[[VAL_5]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_22:.*]] = mhlo.add %[[VAL_19]], %[[VAL_21]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_23:.*]] = mhlo.add %[[VAL_20]], %[[VAL_7]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_24:.*]] = mhlo.power %[[VAL_23]], %[[VAL_5]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_25:.*]] = mhlo.add %[[VAL_22]], %[[VAL_24]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_26:.*]] = mhlo.add %[[VAL_23]], %[[VAL_7]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_27:.*]] = mhlo.power %[[VAL_26]], %[[VAL_5]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_28:.*]] = mhlo.add %[[VAL_25]], %[[VAL_27]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_29:.*]] = mhlo.add %[[VAL_26]], %[[VAL_7]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_30:.*]] = mhlo.power %[[VAL_29]], %[[VAL_5]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_31:.*]] = mhlo.add %[[VAL_28]], %[[VAL_30]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_32:.*]] = mhlo.add %[[VAL_29]], %[[VAL_7]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_33:.*]] = mhlo.power %[[VAL_32]], %[[VAL_5]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_34:.*]] = mhlo.add %[[VAL_31]], %[[VAL_33]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_35:.*]] = mhlo.add %[[VAL_32]], %[[VAL_7]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_36:.*]] = mhlo.power %[[VAL_35]], %[[VAL_5]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_37:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_38:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_37]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_39:.*]] = mhlo.multiply %[[VAL_36]], %[[VAL_35]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_40:.*]] = mhlo.divide %[[VAL_39]], %[[VAL_38]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_41:.*]] = mhlo.add %[[VAL_34]], %[[VAL_40]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_42:.*]] = mhlo.multiply %[[VAL_35]], %[[VAL_35]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_43:.*]] = mhlo.divide %[[VAL_7]], %[[VAL_42]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_44:.*]] = mhlo.constant dense<2.200000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_45:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_44]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_46:.*]] = mhlo.constant dense<2.100000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_47:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_46]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_48:.*]] = mhlo.multiply %[[VAL_45]], %[[VAL_47]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_49:.*]] = mhlo.constant dense<-1.39544646E-19> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_50:.*]] = mhlo.add %[[VAL_4]], %[[VAL_49]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_51:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_50]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_52:.*]] = mhlo.multiply %[[VAL_48]], %[[VAL_51]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_53:.*]] = mhlo.constant dense<2.000000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_54:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_53]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_55:.*]] = mhlo.constant dense<1.900000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_56:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_55]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_57:.*]] = mhlo.multiply %[[VAL_54]], %[[VAL_56]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_58:.*]] = mhlo.constant dense<5.50900303E-18> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_59:.*]] = mhlo.add %[[VAL_52]], %[[VAL_58]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_60:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_59]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_61:.*]] = mhlo.multiply %[[VAL_57]], %[[VAL_60]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_62:.*]] = mhlo.constant dense<1.800000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_63:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_62]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_64:.*]] = mhlo.constant dense<1.700000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_65:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_64]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_66:.*]] = mhlo.multiply %[[VAL_63]], %[[VAL_65]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_67:.*]] = mhlo.constant dense<-2.17486866E-16> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_68:.*]] = mhlo.add %[[VAL_61]], %[[VAL_67]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_69:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_68]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_70:.*]] = mhlo.multiply %[[VAL_66]], %[[VAL_69]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_71:.*]] = mhlo.constant dense<1.600000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_72:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_71]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_73:.*]] = mhlo.constant dense<1.500000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_74:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_73]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_75:.*]] = mhlo.multiply %[[VAL_72]], %[[VAL_74]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_76:.*]] = mhlo.constant dense<8.58606213E-15> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_77:.*]] = mhlo.add %[[VAL_70]], %[[VAL_76]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_78:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_77]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_79:.*]] = mhlo.multiply %[[VAL_75]], %[[VAL_78]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_80:.*]] = mhlo.constant dense<1.400000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_81:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_80]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_82:.*]] = mhlo.constant dense<1.300000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_83:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_82]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_84:.*]] = mhlo.multiply %[[VAL_81]], %[[VAL_83]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_85:.*]] = mhlo.constant dense<-3.3896803E-13> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_86:.*]] = mhlo.add %[[VAL_79]], %[[VAL_85]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_87:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_86]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_88:.*]] = mhlo.multiply %[[VAL_84]], %[[VAL_87]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_89:.*]] = mhlo.constant dense<1.200000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_90:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_89]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_91:.*]] = mhlo.constant dense<1.100000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_92:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_91]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_93:.*]] = mhlo.multiply %[[VAL_90]], %[[VAL_92]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_94:.*]] = mhlo.constant dense<1.33825364E-11> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_95:.*]] = mhlo.add %[[VAL_88]], %[[VAL_94]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_96:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_95]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_97:.*]] = mhlo.multiply %[[VAL_93]], %[[VAL_96]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_98:.*]] = mhlo.constant dense<1.000000e+01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_99:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_98]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_100:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_101:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_100]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_102:.*]] = mhlo.multiply %[[VAL_99]], %[[VAL_101]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_103:.*]] = mhlo.constant dense<-5.28419031E-10> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_104:.*]] = mhlo.add %[[VAL_97]], %[[VAL_103]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_105:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_104]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_106:.*]] = mhlo.multiply %[[VAL_102]], %[[VAL_105]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_107:.*]] = mhlo.constant dense<8.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_108:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_107]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_109:.*]] = mhlo.constant dense<7.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_110:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_109]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_111:.*]] = mhlo.multiply %[[VAL_108]], %[[VAL_110]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_112:.*]] = mhlo.constant dense<2.08767563E-8> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_113:.*]] = mhlo.add %[[VAL_106]], %[[VAL_112]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_114:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_113]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_115:.*]] = mhlo.multiply %[[VAL_111]], %[[VAL_114]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_116:.*]] = mhlo.constant dense<6.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_117:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_116]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_118:.*]] = mhlo.constant dense<5.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_119:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_118]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_120:.*]] = mhlo.multiply %[[VAL_117]], %[[VAL_119]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_121:.*]] = mhlo.constant dense<-8.26719599E-7> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_122:.*]] = mhlo.add %[[VAL_115]], %[[VAL_121]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_123:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_122]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_124:.*]] = mhlo.multiply %[[VAL_120]], %[[VAL_123]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_125:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_126:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_125]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_127:.*]] = mhlo.constant dense<3.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_128:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_127]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_129:.*]] = mhlo.multiply %[[VAL_126]], %[[VAL_128]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_130:.*]] = mhlo.constant dense<3.30687835E-5> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_131:.*]] = mhlo.add %[[VAL_124]], %[[VAL_130]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_132:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_131]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_133:.*]] = mhlo.multiply %[[VAL_129]], %[[VAL_132]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_134:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_135:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_134]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_136:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_137:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_136]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_138:.*]] = mhlo.multiply %[[VAL_135]], %[[VAL_137]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_139:.*]] = mhlo.constant dense<-0.00138888892> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_140:.*]] = mhlo.add %[[VAL_133]], %[[VAL_139]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_141:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_140]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_142:.*]] = mhlo.multiply %[[VAL_138]], %[[VAL_141]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_143:.*]] = mhlo.constant dense<5.000000e-01> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_144:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_35]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_145:.*]] = mhlo.constant dense<0.0833333358> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_146:.*]] = mhlo.add %[[VAL_145]], %[[VAL_142]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_147:.*]] = mhlo.multiply %[[VAL_144]], %[[VAL_146]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_148:.*]] = mhlo.add %[[VAL_143]], %[[VAL_147]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_149:.*]] = mhlo.multiply %[[VAL_36]], %[[VAL_148]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_150:.*]] = mhlo.add %[[VAL_41]], %[[VAL_149]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_151:.*]] = "mhlo.abs"(%[[VAL_36]]) : (tensor<f32>) -> tensor<f32> | ||||||
|  | // CHECK: %[[VAL_152:.*]] = "mhlo.abs"(%[[VAL_34]]) : (tensor<f32>) -> tensor<f32> | ||||||
|  | // CHECK: %[[VAL_153:.*]] = mhlo.constant dense<1.401300e-45> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_154:.*]] = mhlo.multiply %[[VAL_152]], %[[VAL_153]] : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_155:.*]] = "mhlo.compare"(%[[VAL_151]], %[[VAL_154]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> | ||||||
|  | // CHECK: %[[VAL_156:.*]] = "mhlo.select"(%[[VAL_155]], %[[VAL_34]], %[[VAL_150]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> | ||||||
|  | // CHECK: %[[VAL_157:.*]] = mhlo.constant dense<0x7F800000> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_158:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_37]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1> | ||||||
|  | // CHECK: %[[VAL_159:.*]] = "mhlo.select"(%[[VAL_158]], %[[VAL_157]], %[[VAL_156]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> | ||||||
|  | // CHECK: %[[VAL_160:.*]] = mhlo.constant dense<0x7FC00000> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_161:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_37]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> | ||||||
|  | // CHECK: %[[VAL_162:.*]] = "mhlo.select"(%[[VAL_161]], %[[VAL_160]], %[[VAL_159]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> | ||||||
|  | // CHECK: %[[VAL_163:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> | ||||||
|  | // CHECK: %[[VAL_164:.*]] = "mhlo.compare"(%[[VAL_3]], %[[VAL_163]]) {comparison_direction = "LE"} : (tensor<f32>, tensor<f32>) -> tensor<i1> | ||||||
|  | // CHECK: %[[VAL_165:.*]] = "mhlo.floor"(%[[VAL_2]]) : (tensor<f32>) -> tensor<f32> | ||||||
|  | // CHECK: %[[VAL_166:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_165]]) {comparison_direction = "NE"} : (tensor<f32>, tensor<f32>) -> tensor<i1> | ||||||
|  | // CHECK: %[[VAL_167:.*]] = mhlo.and %[[VAL_164]], %[[VAL_166]] : tensor<i1> | ||||||
|  | // CHECK: %[[VAL_169:.*]] = "mhlo.floor"(%[[VAL_3]]) : (tensor<f32>) -> tensor<f32> | ||||||
|  | // CHECK: %[[VAL_170:.*]] = "mhlo.compare"(%[[VAL_3]], %[[VAL_169]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1> | ||||||
|  | // CHECK: %[[VAL_171:.*]] = mhlo.and %[[VAL_164]], %[[VAL_170]] : tensor<i1> | ||||||
|  | // CHECK: %[[VAL_172:.*]] = "mhlo.select"(%[[VAL_171]], %[[VAL_157]], %[[VAL_162]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> | ||||||
|  | // CHECK: %[[VAL_173:.*]] = "mhlo.select"(%[[VAL_167]], %[[VAL_160]], %[[VAL_172]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> | ||||||
|  | // CHECK: %[[VAL_174:.*]] = "mhlo.convert"(%[[VAL_173]]) : (tensor<f32>) -> tensor<f16> | ||||||
|  |   return %0 : tensor<f16> | ||||||
|  | // CHECK: return %[[VAL_174]] : tensor<f16> | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue