Updates LLVM usage to match
[2bfe27da171e](https://github.com/llvm/llvm-project/commit/2bfe27da171e)

PiperOrigin-RevId: 357196336
This commit is contained in:
A. Unique TensorFlower 2021-02-12 08:30:51 -08:00 committed by TensorFlow MLIR Team
parent e993082b97
commit 4060a86fe2
12 changed files with 47 additions and 39 deletions

3
BUILD
View File

@ -595,6 +595,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
],
@ -666,6 +667,7 @@ cc_library(
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
@ -876,6 +878,7 @@ cc_library(
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",

View File

@ -15,9 +15,9 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
LLVM_COMMIT = "6f04addc8b2eee0d88b97facfa4fb7424b4b74bd"
LLVM_COMMIT = "2bfe27da171e8a6dddac6c444c4bca003103941a"
LLVM_SHA256 = "76c531cd1e701ddf2765a95afb3cd3bdb9f7151f86536632c3072c25d9862e8e"
LLVM_SHA256 = "e94b8ec0d1b854da11dbf1587701e798df97be630f7019d0658452fae5e03bcf"
LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)

View File

@ -1,2 +1,2 @@
6f04addc8b2eee0d88b97facfa4fb7424b4b74bd
2bfe27da171e8a6dddac6c444c4bca003103941a

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
@ -173,7 +174,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>{}(
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::Atan2Op>{}(
loc, result_types, args, b);
}
@ -246,7 +247,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpOp>{}(
loc, result_types, args, b);
}
@ -358,7 +359,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::CosOp>{}(
loc, result_types, args, b);
}
@ -367,7 +368,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::SinOp>{}(
loc, result_types, args, b);
}
@ -434,7 +435,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::LogOp>{}(
loc, result_types, args, b);
}
@ -463,7 +464,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::LogisticOp>(
Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0));
Value x = args.front();
Value neg_x = b->create<NegFOp>(loc, x);
Value exp_neg_x = b->create<::mlir::ExpOp>(loc, neg_x);
Value exp_neg_x = b->create<::mlir::math::ExpOp>(loc, neg_x);
Value one_add_exp_neg_x = b->create<AddFOp>(loc, one, exp_neg_x);
return b->create<DivFOp>(loc, one, one_add_exp_neg_x);
}
@ -473,7 +474,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Log1pOp>{}(
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::Log1pOp>{}(
loc, result_types, args, b);
}
@ -579,7 +580,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::RsqrtOp>{}(
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::RsqrtOp>{}(
loc, result_types, args, b);
}
@ -593,8 +594,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
// Floating point can use std::powf
auto result_type = result_types.front();
if (result_type.isa<::mlir::FloatType>())
return MapLhloOpToStdScalarOpImpl<::mlir::PowFOp>{}(loc, result_types, args,
b);
return MapLhloOpToStdScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types,
args, b);
assert(result_type.isa<::mlir::IntegerType>() &&
"only float and integer `pow` is supported right now");
@ -746,7 +747,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SqrtOp>{}(
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::SqrtOp>{}(
loc, result_types, args, b);
}
@ -766,7 +767,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::TanhOp>{}(
loc, result_types, args, b);
}

View File

@ -94,6 +94,7 @@ add_mlir_library(MhloToLhloConversion
LmhloDialect
MLIRIR
MLIRPass
MLIRMath
)
add_mlir_library(MhloToStandard

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -1456,14 +1457,15 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
struct LhloLegalizeToLinalgPass
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect>();
registry.insert<AffineDialect, linalg::LinalgDialect, math::MathDialect>();
}
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
StandardOpsDialect, AffineDialect>();
math::MathDialect, StandardOpsDialect,
AffineDialect>();
auto func = getFunction();
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
@ -1477,15 +1479,15 @@ struct HloLegalizeToLinalgPass
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<linalg::LinalgDialect, scf::SCFDialect,
complex::ComplexDialect>();
complex::ComplexDialect, math::MathDialect>();
}
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
StandardOpsDialect, tensor::TensorDialect,
scf::SCFDialect>();
math::MathDialect, StandardOpsDialect,
tensor::TensorDialect, scf::SCFDialect>();
auto func = getFunction();
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
@ -74,10 +75,10 @@ class ApproximateOnExtendedF32Lowering : public OpRewritePattern<OpTy> {
};
class ApproximateTanhLowering
: public ApproximateOnExtendedF32Lowering<TanhOp> {
: public ApproximateOnExtendedF32Lowering<math::TanhOp> {
public:
explicit ApproximateTanhLowering(MLIRContext *ctx)
: ApproximateOnExtendedF32Lowering<TanhOp>(ctx) {}
: ApproximateOnExtendedF32Lowering<math::TanhOp>(ctx) {}
// Emits the fast tanh approximation that is also used by XLA.
Value emitApproximation(ValueRange args, Location loc,

View File

@ -13,7 +13,7 @@ func @print_f32(%arg : f32) {
}
func @tanh_f32(%arg : f32) {
%res = tanh %arg : f32
%res = math.tanh %arg : f32
call @print_f32(%res) : (f32) -> ()
return
}

View File

@ -182,7 +182,7 @@ func @float_logistic(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: ^bb0(%[[ARG:.*]]: f32, %{{.*}}: f32):
// CHECK: %[[C1:.*]] = constant 1.{{.*}}e+00
// CHECK: %[[NEG_ARG:.*]] = negf %[[ARG]]
// CHECK: %[[EXP_NEG_ARG:.*]] = exp %[[NEG_ARG]]
// CHECK: %[[EXP_NEG_ARG:.*]] = math.exp %[[NEG_ARG]]
// CHECK: %[[ONE_ADD_EXP_NEG_ARG:.*]] = addf %[[C1]], %[[EXP_NEG_ARG]]
// CHECK: %[[RESULT:.*]] = divf %[[C1]], %[[ONE_ADD_EXP_NEG_ARG]]
// CHECK: linalg.yield %[[RESULT]]
@ -834,7 +834,7 @@ func @float_pow(%lhs: tensor<2x2xf32>,
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = powf %[[ARG0]], %[[ARG1]]
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = math.powf %[[ARG0]], %[[ARG1]]
// CHECK: linalg.yield %[[RESULT]]
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>

View File

@ -1,7 +1,7 @@
// RUN: mlir-hlo-opt --mhlo-legalize-trigonometric-to-approximation --split-input-file %s | FileCheck %s
func @tanh_f64(%arg0 : f64) -> f64 {
%res = tanh %arg0 : f64
%res = math.tanh %arg0 : f64
return %res : f64
}
@ -11,7 +11,7 @@ func @tanh_f64(%arg0 : f64) -> f64 {
// -----
func @tanh_f32(%arg0 : f32) -> f32 {
%res = tanh %arg0 : f32
%res = math.tanh %arg0 : f32
return %res : f32
}
@ -66,7 +66,7 @@ func @tanh_f32(%arg0 : f32) -> f32 {
// -----
func @tanh_f16(%arg0 : f16) -> f16 {
%res = tanh %arg0 : f16
%res = math.tanh %arg0 : f16
return %res : f16
}
@ -125,7 +125,7 @@ func @tanh_f16(%arg0 : f16) -> f16 {
// CHECK-LABEL: @atan2_f64
func @atan2_f64(%arg0 : f64, %arg1 : f64) -> f64 {
// CHECK: atan2
%res = atan2 %arg0, %arg1 : f64
%res = math.atan2 %arg0, %arg1 : f64
return %res : f64
}
@ -134,6 +134,6 @@ func @atan2_f64(%arg0 : f64, %arg1 : f64) -> f64 {
// CHECK-LABEL: @atan_f64
func @atan_f64(%arg : f64) -> f64 {
// CHECK: atan
%res = atan %arg : f64
%res = math.atan %arg : f64
return %res : f64
}

View File

@ -92,7 +92,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
ins(%1 : memref<100x10xf32>)
outs(%arg2 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%2 = exp %arg3 : f32
%2 = math.exp %arg3 : f32
linalg.yield %2 : f32
}
dealloc %1 : memref<100x10xf32>

View File

@ -10,7 +10,7 @@ func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = powf %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = math.powf %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@ -115,7 +115,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = exp %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = math.exp %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@ -127,7 +127,7 @@ func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = log %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = math.log %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@ -522,7 +522,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = cos %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = math.cos %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@ -536,7 +536,7 @@ func @sin(%input: memref<2x2xf32>,
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = sin %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = math.sin %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@ -612,7 +612,7 @@ func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = rsqrt %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = math.rsqrt %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@ -676,7 +676,7 @@ func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = sqrt %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = math.sqrt %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
@ -688,7 +688,7 @@ func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = tanh %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = math.tanh %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----