Implement lowering of chlo::zeta to mhlo dialect.
PiperOrigin-RevId: 355395581
commit 6cd1875ee4
parent 9d8a53c452
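For reference, chlo.zeta computes the Hurwitz zeta function of its two operands, and the lowering added below evaluates it with a truncated series plus an Euler-Maclaurin correction, per the comments in the new conversion pattern. A minimal statement of the function being approximated (editorial sketch in LaTeX; x and q follow the operand names used throughout this change):

    \zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}, \qquad x > 1,\; q > 0.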
@@ -36,7 +36,7 @@ class HloClientDialect : public Dialect {
   void initialize();

  public:
-  explicit HloClientDialect(MLIRContext *context)
+  explicit HloClientDialect(MLIRContext* context)
       : Dialect(getDialectNamespace(), context,
                 TypeID::get<HloClientDialect>()) {
     initialize();
@@ -74,6 +74,8 @@ Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val);
 Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
                               bool negative);

+Value getConstantLikeSmallestFiniteValue(OpBuilder& b, Location loc, Value val);
+
 }  // namespace chlo
 }  // namespace mlir

@@ -256,7 +256,7 @@ def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp<
   }];
 }

-def HLOCLient_BroadcastZetaOp : HLOClient_BroadcastBinaryElementwiseOp<
+def HLOClient_BroadcastZetaOp : HLOClient_BroadcastBinaryElementwiseOp<
     "broadcast_zeta",
     [NoSideEffect, SameOperandsAndResultElementType]> {
   let summary = "Hurwitz zeta function";
@@ -352,15 +352,14 @@ def HLOClient_ZetaOp : HLOClient_Op<"zeta",
   }];

   let arguments = (ins
-    HLO_FpTensor:$lhs,
-    HLO_FpTensor:$rhs
+    HLO_FpTensor:$x,
+    HLO_FpTensor:$q
   );

   let results = (outs HLO_FpTensor);

   let assemblyFormat = [{
-    $lhs `,` $rhs attr-dict `:`
-    `(` type($lhs) `,` type($rhs) `)` `->` type(results)
+    $x `,` $q attr-dict `:` `(` type($x) `,` type($q) `)` `->` type(results)
   }];
 }

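With the operands renamed to x and q and the operand types spelled out, the op now prints as in the updated test at the bottom of this change; a minimal example of the new assembly format (SSA value names are illustrative):

    %0 = chlo.zeta %x, %q : (tensor<f16>, tensor<f16>) -> tensor<f16>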
@@ -15,9 +15,13 @@ limitations under the License.

 include "mlir/Pass/PassBase.td"

-def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> {
+def ChloLegalizeToHloPass : FunctionPass<"chlo-legalize-to-hlo"> {
   let summary = "Legalize CHLO to HLO.";
   let constructor = "createChloLegalizeToHloPass()";
+  let options = [
+    Option<"broadcast_only_", "broadcast-only", "bool",
+           /*default=*/"false", "Only lower broadcasting chlo to non-broadcasting equivalents">,
+  ];
 }

 def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
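The new option is what the broadcast-only test below turns on; a typical invocation mirroring the updated RUN line (the input file name is illustrative):

    mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" input.mlir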
@@ -45,7 +45,8 @@ std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();

 /// Lowers from the CHLO dialect to the HLO dialect.
-std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
+std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
+    bool broadcast_only = false);

 /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
 /// buffers if necessary.
@@ -99,8 +99,14 @@ void PopulateTrigonometricToApproximationPatterns(

 namespace chlo {

+// Populates a collection of conversion patterns for legalizing broadcasting
+// client-HLO to their non-broadcasting counterparts.
+void PopulateChloBroadcastingPatterns(MLIRContext *context,
+                                      OwningRewritePatternList *patterns);
+
 // Populates a collection of conversion patterns for legalizing client-HLO to
-// HLO.
+// HLO. Includes decomposition of operations and inserting of explicit
+// broadcasts.
 void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
                                        OwningRewritePatternList *patterns);

@@ -46,6 +46,13 @@ Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
       b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
 }

+Value getConstantLikeSmallestFiniteValue(OpBuilder& b, Location loc,
+                                         Value val) {
+  auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
+  return getConstantLike(
+      b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val);
+}
+
 Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
                       Value val) {
   Type ty = getElementTypeOrSelf(val.getType());
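Side note: llvm::APFloat::getSmallest is the smallest positive (subnormal) value of the given semantics, so for f32 the helper materializes 2^{-149} \approx 1.4013 \times 10^{-45}; that is the dense<1.401300e-45> splat checked in the zeta test at the end of this change, where it gates the fallback to the plain series estimate.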
@@ -33,6 +33,7 @@ limitations under the License.
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
@@ -766,6 +767,168 @@ struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
   }
 };

+Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter,
+                                 Location loc, Value x, Value q) {
+  static const std::array<double, 12> kZetaCoeffs{
+      -7.1661652561756670113e18,
+      1.8152105401943546773e17,
+      -4.5979787224074726105e15,
+      1.1646782814350067249e14,
+      -2.950130727918164224e12,
+      7.47242496e10,
+      -1.8924375803183791606e9,
+      47900160.0,
+      -1209600.0,
+      30240.0,
+      -720.0,
+      12.0,
+  };
+
+  // For speed we'll always use 9 iterations for the initial series estimate,
+  // and a 12 term expansion for the Euler-Maclaurin formula.
+  Value a = q;
+  Value zero_like_a = chlo::getConstantLike(rewriter, loc, 0.0, a);
+  Value neg_power = zero_like_a;
+  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
+  Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x);
+  Value one_like_a = chlo::getConstantLike(rewriter, loc, 1.0, a);
+  for (int i = 0; i < 9; ++i) {
+    a = rewriter.create<mhlo::AddOp>(loc, a, one_like_a);
+    neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
+    initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, neg_power);
+  }
+  a = rewriter.create<mhlo::AddOp>(loc, a, one_like_a);
+  neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
+  Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x);
+  Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x);
+  Value neg_power_mul_a = rewriter.create<mhlo::MulOp>(loc, neg_power, a);
+  Value neg_power_mul_a_div_x_minus_one =
+      rewriter.create<mhlo::DivOp>(loc, neg_power_mul_a, x_minus_one);
+  Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum,
+                                         neg_power_mul_a_div_x_minus_one);
+  Value a_inverse_square = rewriter.create<mhlo::DivOp>(
+      loc, one_like_a, rewriter.create<mhlo::MulOp>(loc, a, a));
+
+  Value horner_sum = zero_like_a;
+  Value factor = one_like_a;
+  // Use Horner's rule for this.
+  // Note this differs from Cephes which does a 'naive' polynomial evaluation.
+  // Using Horner's rule allows to avoid some NaN's and Infs from happening,
+  // resulting in more numerically stable code.
+  for (int i = 0; i < 11; ++i) {
+    Value factor_lhs = rewriter.create<mhlo::SubOp>(
+        loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x));
+    Value factor_rhs = rewriter.create<mhlo::SubOp>(
+        loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x));
+    factor = rewriter.create<mhlo::MulOp>(loc, factor_lhs, factor_rhs);
+    horner_sum = rewriter.create<mhlo::MulOp>(
+        loc, factor,
+        rewriter.create<mhlo::MulOp>(
+            loc, a_inverse_square,
+            rewriter.create<mhlo::AddOp>(
+                loc, horner_sum,
+                chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
+  }
+  Value zero_point_five_like_neg_power =
+      chlo::getConstantLike(rewriter, loc, .5, neg_power);
+  Value x_div_a = rewriter.create<mhlo::DivOp>(loc, x, a);
+  s = rewriter.create<mhlo::AddOp>(
+      loc, s,
+      rewriter.create<mhlo::MulOp>(
+          loc, neg_power,
+          rewriter.create<mhlo::AddOp>(
+              loc, zero_point_five_like_neg_power,
+              rewriter.create<mhlo::MulOp>(
+                  loc, x_div_a,
+                  rewriter.create<mhlo::AddOp>(
+                      loc,
+                      chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11],
+                                            a),
+                      horner_sum)))));
+  const double nan = std::numeric_limits<double>::quiet_NaN();
+  const double inf = std::numeric_limits<double>::infinity();
+  // Use the initial zeta sum without the correction term coming
+  // from Euler-Maclaurin if it is accurate enough.
+  const StringAttr kLT = rewriter.getStringAttr(
+      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
+  Value abs_neg_power = rewriter.create<mhlo::AbsOp>(loc, neg_power);
+  Value abs_initial_sum = rewriter.create<mhlo::AbsOp>(loc, initial_sum);
+  Value output = rewriter.create<mhlo::SelectOp>(
+      loc,
+      rewriter.create<mhlo::CompareOp>(
+          loc, abs_neg_power,
+          rewriter.create<mhlo::MulOp>(
+              loc, abs_initial_sum,
+              chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
+          kLT),
+      initial_sum, s);
+  // This is the harmonic series.
+  const StringAttr kEQ = rewriter.getStringAttr(
+      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
+  Value inf_like_x = chlo::getConstantLike(rewriter, loc, inf, x);
+  output = rewriter.create<mhlo::SelectOp>(
+      loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kEQ),
+      inf_like_x, output);
+  // Function is not defined for x < 1.
+  Value nan_like_x = chlo::getConstantLike(rewriter, loc, nan, x);
+  output = rewriter.create<mhlo::SelectOp>(
+      loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT),
+      nan_like_x, output);
+  // If q <= 0, then when q is an integer or x is not an integer, this is
+  // NaN.
+  const StringAttr kLE = rewriter.getStringAttr(
+      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
+  const StringAttr kNE = rewriter.getStringAttr(
+      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
+  Value zero_like_q = chlo::getConstantLike(rewriter, loc, 0.0, q);
+  Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero_like_q, kLE);
+  Value domain_error = rewriter.create<mhlo::AndOp>(
+      loc, q_le_zero,
+      rewriter.create<mhlo::CompareOp>(
+          loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE));
+  Value negative_integer_q = rewriter.create<mhlo::AndOp>(
+      loc, q_le_zero,
+      rewriter.create<mhlo::CompareOp>(
+          loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ));
+  output = rewriter.create<mhlo::SelectOp>(loc, negative_integer_q, inf_like_x,
+                                           output);
+  output =
+      rewriter.create<mhlo::SelectOp>(loc, domain_error, nan_like_x, output);
+  return output;
+}
+
+struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
+  using OpConversionPattern<ZetaOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      ZetaOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    ZetaOpAdaptor adaptor(operands);
+    Location loc = op.getLoc();
+
+    // Zeta is only defined on tensors of float elements and statically
+    // verified that both have the same type. So it suffices to look at one
+    // here.
+    auto elm_type = adaptor.x().getType().cast<ShapedType>().getElementType();
+
+    bool needs_upcast = elm_type.isF16() || elm_type.isBF16();
+
+    Value x = adaptor.x();
+    Value q = adaptor.q();
+
+    if (needs_upcast) {
+      x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
+      q = rewriter.create<mhlo::ConvertOp>(loc, q, rewriter.getF32Type());
+    }
+    Value result = MaterializeZetaComputation(rewriter, loc, x, q);
+    if (needs_upcast) {
+      result = rewriter.create<mhlo::ConvertOp>(loc, result, elm_type);
+    }
+    rewriter.replaceOp(op, {result});
+
+    return success();
+  }
+};
+
 // Converts binary ops that statically are determined to not broadcast directly
 // to the corresponding mhlo non-broadcasting op.
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
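Reading MaterializeZetaComputation end to end: initial_sum is the first ten terms of the series, s then adds the leading tail term a^{1-x}/(x-1), and the Horner loop folds in the remaining correction, with the reciprocals 1 / kZetaCoeffs[i] serving as the Bernoulli-based coefficients B_{2j}/(2j)! that the Euler-Maclaurin comment refers to. As an editorial summary in LaTeX (with a = q + 10 and c_i = kZetaCoeffs[i]), the emitted graph computes

    s = \sum_{n=0}^{9} (q + n)^{-x} + \frac{a \cdot a^{-x}}{x - 1}
        + a^{-x}\left(\tfrac{1}{2} + \frac{x}{a}\bigl(\tfrac{1}{c_{11}} + H\bigr)\right),
    \qquad H \leftarrow \bigl(x - (22 - 2i)\bigr)\bigl(x - (21 - 2i)\bigr)\,
        \frac{1}{a^{2}}\bigl(H + \tfrac{1}{c_i}\bigr),\quad i = 0, \dots, 10,

followed by selects that fall back to the bare series when the correction underflows, pin x = 1 to +inf and x < 1 to NaN, and handle q <= 0 (integral q gives +inf, non-integral x gives NaN).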
@@ -904,10 +1067,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp
 #include "generated_chlo_legalize_to_hlo.inc"
 }  // namespace

-void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
-                                       OwningRewritePatternList *patterns) {
-  populateWithGenerated(context, *patterns);
-
+void PopulateChloBroadcastingPatterns(MLIRContext *context,
+                                      OwningRewritePatternList *patterns) {
   // Instantiate conversion templates for conforming binary elementwise ops
   // that do not have different dtypes between operands and results and do
   // not have special attributes that need to be preserved.
@@ -915,6 +1076,12 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
       context, patterns, 10);
   PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
       context, patterns, 5);
+}
+
+void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
+                                       OwningRewritePatternList *patterns) {
+  populateWithGenerated(context, *patterns);
+  PopulateChloBroadcastingPatterns(context, patterns);

   // Other patterns.
   // clang-format off
@@ -922,7 +1089,8 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
                    ConvertDigammaOp,
                    ConvertErfOp,
                    ConvertErfcOp,
-                   ConvertLgammaOp>(context);
+                   ConvertLgammaOp,
+                   ConvertZetaOp>(context);
   // clang-format on
 }

@@ -15,6 +15,7 @@ limitations under the License.

 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -29,7 +30,13 @@ namespace mhlo {
 namespace {

 struct ChloLegalizeToHloPass
-    : public PassWrapper<ChloLegalizeToHloPass, FunctionPass> {
+    : public ChloLegalizeToHloPassBase<ChloLegalizeToHloPass> {
+  explicit ChloLegalizeToHloPass(bool broadcast_only)
+      : ChloLegalizeToHloPassBase<
+            ChloLegalizeToHloPass>::ChloLegalizeToHloPassBase() {
+    this->broadcast_only_ = broadcast_only;
+  }
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>();
   }
@@ -45,12 +52,16 @@ struct ChloLegalizeToHloPass
         MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect,
         mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();

-    // TODO(herhut): This is temporary while Zeta cannot be lowered to hlo.
-    conversionTarget.addLegalOp<chlo::ZetaOp>();
+    if (broadcast_only_) {
+      chlo::PopulateChloBroadcastingPatterns(&getContext(),
+                                             &conversionPatterns);
+      conversionTarget.addLegalOp<chlo::ZetaOp>();
+    } else {
+      chlo::PopulateLegalizeChloToHloPatterns(&getContext(),
+                                              &conversionPatterns);
+    }

-    chlo::PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);
-
-    if (failed(applyPartialConversion(getFunction(), conversionTarget,
+    if (failed(applyPartialConversion(getOperation(), conversionTarget,
                                       std::move(conversionPatterns)))) {
       return signalPassFailure();
     }
@@ -59,10 +70,9 @@ struct ChloLegalizeToHloPass

 }  // namespace

-std::unique_ptr<FunctionPass> createChloLegalizeToHloPass() {
-  return std::make_unique<ChloLegalizeToHloPass>();
+std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(bool broadcast_only) {
+  return std::make_unique<ChloLegalizeToHloPass>(broadcast_only);
 }

 }  // namespace mhlo
 }  // namespace mlir

@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="broadcast-only=true" -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s

 // Check the non-broadcast case for each registered op, then just check a
 // representative op for detailed broadcast semantics.
@@ -1105,3 +1105,184 @@ func @digamma_f16(%arg : tensor<f16>) -> tensor<f16> {
   %1 = chlo.digamma %arg : tensor<f16> -> tensor<f16>
   return %1 : tensor<f16>
 }
+
+// CHECK-LABEL: func @zeta_f16(
+// CHECK-SAME:  %[[VAL_0:.*]]: tensor<f16>,
+// CHECK-SAME:  %[[VAL_1:.*]]: tensor<f16>) -> tensor<f16> {
+func @zeta_f16(%arg0: tensor<f16>, %arg1: tensor<f16>) -> tensor<f16> {
+  %0 = chlo.zeta %arg0, %arg1 : (tensor<f16>, tensor<f16>) -> tensor<f16>
+  // CHECK: %[[VAL_2:.*]] = "mhlo.convert"(%[[VAL_0]]) : (tensor<f16>) -> tensor<f32>
+  // CHECK: %[[VAL_3:.*]] = "mhlo.convert"(%[[VAL_1]]) : (tensor<f16>) -> tensor<f32>
+  // CHECK: %[[VAL_4:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_5:.*]] = "mhlo.negate"(%[[VAL_2]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[VAL_6:.*]] = mhlo.power %[[VAL_3]], %[[VAL_5]] : tensor<f32>
+  // CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_8:.*]] = mhlo.add %[[VAL_3]], %[[VAL_7]] : tensor<f32>
+  // CHECK: %[[VAL_9:.*]] = mhlo.power %[[VAL_8]], %[[VAL_5]] : tensor<f32>
+  // CHECK: %[[VAL_10:.*]] = mhlo.add %[[VAL_6]], %[[VAL_9]] : tensor<f32>
+  // CHECK: %[[VAL_11:.*]] = mhlo.add %[[VAL_8]], %[[VAL_7]] : tensor<f32>
+  // CHECK: %[[VAL_12:.*]] = mhlo.power %[[VAL_11]], %[[VAL_5]] : tensor<f32>
+  // CHECK: %[[VAL_13:.*]] = mhlo.add %[[VAL_10]], %[[VAL_12]] : tensor<f32>
+  // CHECK: %[[VAL_14:.*]] = mhlo.add %[[VAL_11]], %[[VAL_7]] : tensor<f32>
+  // CHECK: %[[VAL_15:.*]] = mhlo.power %[[VAL_14]], %[[VAL_5]] : tensor<f32>
+  // CHECK: %[[VAL_16:.*]] = mhlo.add %[[VAL_13]], %[[VAL_15]] : tensor<f32>
+  // CHECK: %[[VAL_17:.*]] = mhlo.add %[[VAL_14]], %[[VAL_7]] : tensor<f32>
+  // CHECK: %[[VAL_18:.*]] = mhlo.power %[[VAL_17]], %[[VAL_5]] : tensor<f32>
+  // CHECK: %[[VAL_19:.*]] = mhlo.add %[[VAL_16]], %[[VAL_18]] : tensor<f32>
+  // CHECK: %[[VAL_20:.*]] = mhlo.add %[[VAL_17]], %[[VAL_7]] : tensor<f32>
+  // CHECK: %[[VAL_21:.*]] = mhlo.power %[[VAL_20]], %[[VAL_5]] : tensor<f32>
+  // CHECK: %[[VAL_22:.*]] = mhlo.add %[[VAL_19]], %[[VAL_21]] : tensor<f32>
+  // CHECK: %[[VAL_23:.*]] = mhlo.add %[[VAL_20]], %[[VAL_7]] : tensor<f32>
+  // CHECK: %[[VAL_24:.*]] = mhlo.power %[[VAL_23]], %[[VAL_5]] : tensor<f32>
+  // CHECK: %[[VAL_25:.*]] = mhlo.add %[[VAL_22]], %[[VAL_24]] : tensor<f32>
+  // CHECK: %[[VAL_26:.*]] = mhlo.add %[[VAL_23]], %[[VAL_7]] : tensor<f32>
+  // CHECK: %[[VAL_27:.*]] = mhlo.power %[[VAL_26]], %[[VAL_5]] : tensor<f32>
+  // CHECK: %[[VAL_28:.*]] = mhlo.add %[[VAL_25]], %[[VAL_27]] : tensor<f32>
+  // CHECK: %[[VAL_29:.*]] = mhlo.add %[[VAL_26]], %[[VAL_7]] : tensor<f32>
+  // CHECK: %[[VAL_30:.*]] = mhlo.power %[[VAL_29]], %[[VAL_5]] : tensor<f32>
+  // CHECK: %[[VAL_31:.*]] = mhlo.add %[[VAL_28]], %[[VAL_30]] : tensor<f32>
+  // CHECK: %[[VAL_32:.*]] = mhlo.add %[[VAL_29]], %[[VAL_7]] : tensor<f32>
+  // CHECK: %[[VAL_33:.*]] = mhlo.power %[[VAL_32]], %[[VAL_5]] : tensor<f32>
+  // CHECK: %[[VAL_34:.*]] = mhlo.add %[[VAL_31]], %[[VAL_33]] : tensor<f32>
+  // CHECK: %[[VAL_35:.*]] = mhlo.add %[[VAL_32]], %[[VAL_7]] : tensor<f32>
+  // CHECK: %[[VAL_36:.*]] = mhlo.power %[[VAL_35]], %[[VAL_5]] : tensor<f32>
+  // CHECK: %[[VAL_37:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_38:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_37]] : tensor<f32>
+  // CHECK: %[[VAL_39:.*]] = mhlo.multiply %[[VAL_36]], %[[VAL_35]] : tensor<f32>
+  // CHECK: %[[VAL_40:.*]] = mhlo.divide %[[VAL_39]], %[[VAL_38]] : tensor<f32>
+  // CHECK: %[[VAL_41:.*]] = mhlo.add %[[VAL_34]], %[[VAL_40]] : tensor<f32>
+  // CHECK: %[[VAL_42:.*]] = mhlo.multiply %[[VAL_35]], %[[VAL_35]] : tensor<f32>
+  // CHECK: %[[VAL_43:.*]] = mhlo.divide %[[VAL_7]], %[[VAL_42]] : tensor<f32>
+  // CHECK: %[[VAL_44:.*]] = mhlo.constant dense<2.200000e+01> : tensor<f32>
+  // CHECK: %[[VAL_45:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_44]] : tensor<f32>
+  // CHECK: %[[VAL_46:.*]] = mhlo.constant dense<2.100000e+01> : tensor<f32>
+  // CHECK: %[[VAL_47:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_46]] : tensor<f32>
+  // CHECK: %[[VAL_48:.*]] = mhlo.multiply %[[VAL_45]], %[[VAL_47]] : tensor<f32>
+  // CHECK: %[[VAL_49:.*]] = mhlo.constant dense<-1.39544646E-19> : tensor<f32>
+  // CHECK: %[[VAL_50:.*]] = mhlo.add %[[VAL_4]], %[[VAL_49]] : tensor<f32>
+  // CHECK: %[[VAL_51:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_50]] : tensor<f32>
+  // CHECK: %[[VAL_52:.*]] = mhlo.multiply %[[VAL_48]], %[[VAL_51]] : tensor<f32>
+  // CHECK: %[[VAL_53:.*]] = mhlo.constant dense<2.000000e+01> : tensor<f32>
+  // CHECK: %[[VAL_54:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_53]] : tensor<f32>
+  // CHECK: %[[VAL_55:.*]] = mhlo.constant dense<1.900000e+01> : tensor<f32>
+  // CHECK: %[[VAL_56:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_55]] : tensor<f32>
+  // CHECK: %[[VAL_57:.*]] = mhlo.multiply %[[VAL_54]], %[[VAL_56]] : tensor<f32>
+  // CHECK: %[[VAL_58:.*]] = mhlo.constant dense<5.50900303E-18> : tensor<f32>
+  // CHECK: %[[VAL_59:.*]] = mhlo.add %[[VAL_52]], %[[VAL_58]] : tensor<f32>
+  // CHECK: %[[VAL_60:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_59]] : tensor<f32>
+  // CHECK: %[[VAL_61:.*]] = mhlo.multiply %[[VAL_57]], %[[VAL_60]] : tensor<f32>
+  // CHECK: %[[VAL_62:.*]] = mhlo.constant dense<1.800000e+01> : tensor<f32>
+  // CHECK: %[[VAL_63:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_62]] : tensor<f32>
+  // CHECK: %[[VAL_64:.*]] = mhlo.constant dense<1.700000e+01> : tensor<f32>
+  // CHECK: %[[VAL_65:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_64]] : tensor<f32>
+  // CHECK: %[[VAL_66:.*]] = mhlo.multiply %[[VAL_63]], %[[VAL_65]] : tensor<f32>
+  // CHECK: %[[VAL_67:.*]] = mhlo.constant dense<-2.17486866E-16> : tensor<f32>
+  // CHECK: %[[VAL_68:.*]] = mhlo.add %[[VAL_61]], %[[VAL_67]] : tensor<f32>
+  // CHECK: %[[VAL_69:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_68]] : tensor<f32>
+  // CHECK: %[[VAL_70:.*]] = mhlo.multiply %[[VAL_66]], %[[VAL_69]] : tensor<f32>
+  // CHECK: %[[VAL_71:.*]] = mhlo.constant dense<1.600000e+01> : tensor<f32>
+  // CHECK: %[[VAL_72:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_71]] : tensor<f32>
+  // CHECK: %[[VAL_73:.*]] = mhlo.constant dense<1.500000e+01> : tensor<f32>
+  // CHECK: %[[VAL_74:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_73]] : tensor<f32>
+  // CHECK: %[[VAL_75:.*]] = mhlo.multiply %[[VAL_72]], %[[VAL_74]] : tensor<f32>
+  // CHECK: %[[VAL_76:.*]] = mhlo.constant dense<8.58606213E-15> : tensor<f32>
+  // CHECK: %[[VAL_77:.*]] = mhlo.add %[[VAL_70]], %[[VAL_76]] : tensor<f32>
+  // CHECK: %[[VAL_78:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_77]] : tensor<f32>
+  // CHECK: %[[VAL_79:.*]] = mhlo.multiply %[[VAL_75]], %[[VAL_78]] : tensor<f32>
+  // CHECK: %[[VAL_80:.*]] = mhlo.constant dense<1.400000e+01> : tensor<f32>
+  // CHECK: %[[VAL_81:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_80]] : tensor<f32>
+  // CHECK: %[[VAL_82:.*]] = mhlo.constant dense<1.300000e+01> : tensor<f32>
+  // CHECK: %[[VAL_83:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_82]] : tensor<f32>
+  // CHECK: %[[VAL_84:.*]] = mhlo.multiply %[[VAL_81]], %[[VAL_83]] : tensor<f32>
+  // CHECK: %[[VAL_85:.*]] = mhlo.constant dense<-3.3896803E-13> : tensor<f32>
+  // CHECK: %[[VAL_86:.*]] = mhlo.add %[[VAL_79]], %[[VAL_85]] : tensor<f32>
+  // CHECK: %[[VAL_87:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_86]] : tensor<f32>
+  // CHECK: %[[VAL_88:.*]] = mhlo.multiply %[[VAL_84]], %[[VAL_87]] : tensor<f32>
+  // CHECK: %[[VAL_89:.*]] = mhlo.constant dense<1.200000e+01> : tensor<f32>
+  // CHECK: %[[VAL_90:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_89]] : tensor<f32>
+  // CHECK: %[[VAL_91:.*]] = mhlo.constant dense<1.100000e+01> : tensor<f32>
+  // CHECK: %[[VAL_92:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_91]] : tensor<f32>
+  // CHECK: %[[VAL_93:.*]] = mhlo.multiply %[[VAL_90]], %[[VAL_92]] : tensor<f32>
+  // CHECK: %[[VAL_94:.*]] = mhlo.constant dense<1.33825364E-11> : tensor<f32>
+  // CHECK: %[[VAL_95:.*]] = mhlo.add %[[VAL_88]], %[[VAL_94]] : tensor<f32>
+  // CHECK: %[[VAL_96:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_95]] : tensor<f32>
+  // CHECK: %[[VAL_97:.*]] = mhlo.multiply %[[VAL_93]], %[[VAL_96]] : tensor<f32>
+  // CHECK: %[[VAL_98:.*]] = mhlo.constant dense<1.000000e+01> : tensor<f32>
+  // CHECK: %[[VAL_99:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_98]] : tensor<f32>
+  // CHECK: %[[VAL_100:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_101:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_100]] : tensor<f32>
+  // CHECK: %[[VAL_102:.*]] = mhlo.multiply %[[VAL_99]], %[[VAL_101]] : tensor<f32>
+  // CHECK: %[[VAL_103:.*]] = mhlo.constant dense<-5.28419031E-10> : tensor<f32>
+  // CHECK: %[[VAL_104:.*]] = mhlo.add %[[VAL_97]], %[[VAL_103]] : tensor<f32>
+  // CHECK: %[[VAL_105:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_104]] : tensor<f32>
+  // CHECK: %[[VAL_106:.*]] = mhlo.multiply %[[VAL_102]], %[[VAL_105]] : tensor<f32>
+  // CHECK: %[[VAL_107:.*]] = mhlo.constant dense<8.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_108:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_107]] : tensor<f32>
+  // CHECK: %[[VAL_109:.*]] = mhlo.constant dense<7.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_110:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_109]] : tensor<f32>
+  // CHECK: %[[VAL_111:.*]] = mhlo.multiply %[[VAL_108]], %[[VAL_110]] : tensor<f32>
+  // CHECK: %[[VAL_112:.*]] = mhlo.constant dense<2.08767563E-8> : tensor<f32>
+  // CHECK: %[[VAL_113:.*]] = mhlo.add %[[VAL_106]], %[[VAL_112]] : tensor<f32>
+  // CHECK: %[[VAL_114:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_113]] : tensor<f32>
+  // CHECK: %[[VAL_115:.*]] = mhlo.multiply %[[VAL_111]], %[[VAL_114]] : tensor<f32>
+  // CHECK: %[[VAL_116:.*]] = mhlo.constant dense<6.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_117:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_116]] : tensor<f32>
+  // CHECK: %[[VAL_118:.*]] = mhlo.constant dense<5.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_119:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_118]] : tensor<f32>
+  // CHECK: %[[VAL_120:.*]] = mhlo.multiply %[[VAL_117]], %[[VAL_119]] : tensor<f32>
+  // CHECK: %[[VAL_121:.*]] = mhlo.constant dense<-8.26719599E-7> : tensor<f32>
+  // CHECK: %[[VAL_122:.*]] = mhlo.add %[[VAL_115]], %[[VAL_121]] : tensor<f32>
+  // CHECK: %[[VAL_123:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_122]] : tensor<f32>
+  // CHECK: %[[VAL_124:.*]] = mhlo.multiply %[[VAL_120]], %[[VAL_123]] : tensor<f32>
+  // CHECK: %[[VAL_125:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_126:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_125]] : tensor<f32>
+  // CHECK: %[[VAL_127:.*]] = mhlo.constant dense<3.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_128:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_127]] : tensor<f32>
+  // CHECK: %[[VAL_129:.*]] = mhlo.multiply %[[VAL_126]], %[[VAL_128]] : tensor<f32>
+  // CHECK: %[[VAL_130:.*]] = mhlo.constant dense<3.30687835E-5> : tensor<f32>
+  // CHECK: %[[VAL_131:.*]] = mhlo.add %[[VAL_124]], %[[VAL_130]] : tensor<f32>
+  // CHECK: %[[VAL_132:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_131]] : tensor<f32>
+  // CHECK: %[[VAL_133:.*]] = mhlo.multiply %[[VAL_129]], %[[VAL_132]] : tensor<f32>
+  // CHECK: %[[VAL_134:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_135:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_134]] : tensor<f32>
+  // CHECK: %[[VAL_136:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_137:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_136]] : tensor<f32>
+  // CHECK: %[[VAL_138:.*]] = mhlo.multiply %[[VAL_135]], %[[VAL_137]] : tensor<f32>
+  // CHECK: %[[VAL_139:.*]] = mhlo.constant dense<-0.00138888892> : tensor<f32>
+  // CHECK: %[[VAL_140:.*]] = mhlo.add %[[VAL_133]], %[[VAL_139]] : tensor<f32>
+  // CHECK: %[[VAL_141:.*]] = mhlo.multiply %[[VAL_43]], %[[VAL_140]] : tensor<f32>
+  // CHECK: %[[VAL_142:.*]] = mhlo.multiply %[[VAL_138]], %[[VAL_141]] : tensor<f32>
+  // CHECK: %[[VAL_143:.*]] = mhlo.constant dense<5.000000e-01> : tensor<f32>
+  // CHECK: %[[VAL_144:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_35]] : tensor<f32>
+  // CHECK: %[[VAL_145:.*]] = mhlo.constant dense<0.0833333358> : tensor<f32>
+  // CHECK: %[[VAL_146:.*]] = mhlo.add %[[VAL_145]], %[[VAL_142]] : tensor<f32>
+  // CHECK: %[[VAL_147:.*]] = mhlo.multiply %[[VAL_144]], %[[VAL_146]] : tensor<f32>
+  // CHECK: %[[VAL_148:.*]] = mhlo.add %[[VAL_143]], %[[VAL_147]] : tensor<f32>
+  // CHECK: %[[VAL_149:.*]] = mhlo.multiply %[[VAL_36]], %[[VAL_148]] : tensor<f32>
+  // CHECK: %[[VAL_150:.*]] = mhlo.add %[[VAL_41]], %[[VAL_149]] : tensor<f32>
+  // CHECK: %[[VAL_151:.*]] = "mhlo.abs"(%[[VAL_36]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[VAL_152:.*]] = "mhlo.abs"(%[[VAL_34]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[VAL_153:.*]] = mhlo.constant dense<1.401300e-45> : tensor<f32>
+  // CHECK: %[[VAL_154:.*]] = mhlo.multiply %[[VAL_152]], %[[VAL_153]] : tensor<f32>
+  // CHECK: %[[VAL_155:.*]] = "mhlo.compare"(%[[VAL_151]], %[[VAL_154]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK: %[[VAL_156:.*]] = "mhlo.select"(%[[VAL_155]], %[[VAL_34]], %[[VAL_150]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: %[[VAL_157:.*]] = mhlo.constant dense<0x7F800000> : tensor<f32>
+  // CHECK: %[[VAL_158:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_37]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK: %[[VAL_159:.*]] = "mhlo.select"(%[[VAL_158]], %[[VAL_157]], %[[VAL_156]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: %[[VAL_160:.*]] = mhlo.constant dense<0x7FC00000> : tensor<f32>
+  // CHECK: %[[VAL_161:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_37]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK: %[[VAL_162:.*]] = "mhlo.select"(%[[VAL_161]], %[[VAL_160]], %[[VAL_159]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: %[[VAL_163:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_164:.*]] = "mhlo.compare"(%[[VAL_3]], %[[VAL_163]]) {comparison_direction = "LE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK: %[[VAL_165:.*]] = "mhlo.floor"(%[[VAL_2]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[VAL_166:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_165]]) {comparison_direction = "NE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK: %[[VAL_167:.*]] = mhlo.and %[[VAL_164]], %[[VAL_166]] : tensor<i1>
+  // CHECK: %[[VAL_169:.*]] = "mhlo.floor"(%[[VAL_3]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[VAL_170:.*]] = "mhlo.compare"(%[[VAL_3]], %[[VAL_169]]) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK: %[[VAL_171:.*]] = mhlo.and %[[VAL_164]], %[[VAL_170]] : tensor<i1>
+  // CHECK: %[[VAL_172:.*]] = "mhlo.select"(%[[VAL_171]], %[[VAL_157]], %[[VAL_162]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: %[[VAL_173:.*]] = "mhlo.select"(%[[VAL_167]], %[[VAL_160]], %[[VAL_172]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: %[[VAL_174:.*]] = "mhlo.convert"(%[[VAL_173]]) : (tensor<f32>) -> tensor<f16>
+  return %0 : tensor<f16>
+  // CHECK: return %[[VAL_174]] : tensor<f16>
+}