diff --git a/BUILD b/BUILD index 0b19c55..1146494 100644 --- a/BUILD +++ b/BUILD @@ -595,6 +595,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", ], @@ -666,6 +667,7 @@ cc_library( "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", @@ -876,6 +878,7 @@ cc_library( deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", diff --git a/WORKSPACE b/WORKSPACE index 91ee744..9f5846f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -15,9 +15,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "6f04addc8b2eee0d88b97facfa4fb7424b4b74bd" +LLVM_COMMIT = "2bfe27da171e8a6dddac6c444c4bca003103941a" -LLVM_SHA256 = "76c531cd1e701ddf2765a95afb3cd3bdb9f7151f86536632c3072c25d9862e8e" +LLVM_SHA256 = "e94b8ec0d1b854da11dbf1587701e798df97be630f7019d0658452fae5e03bcf" LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT) diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index aa70f60..7312638 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1,2 +1,2 @@ -6f04addc8b2eee0d88b97facfa4fb7424b4b74bd +2bfe27da171e8a6dddac6c444c4bca003103941a diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index bbe156d..7aa956f 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -24,6 +24,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinTypes.h" @@ -173,7 +174,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -246,7 +247,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -358,7 +359,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -367,7 +368,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -434,7 +435,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -463,7 +464,7 @@ inline Value MapLhloOpToStdScalarOp( Value one = b->create(loc, b->getFloatAttr(ty, 1.0)); Value x = args.front(); Value neg_x = b->create(loc, x); - Value exp_neg_x = b->create<::mlir::ExpOp>(loc, neg_x); + Value exp_neg_x = b->create<::mlir::math::ExpOp>(loc, neg_x); Value one_add_exp_neg_x = b->create(loc, one, exp_neg_x); return b->create(loc, one, one_add_exp_neg_x); } @@ -473,7 +474,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -579,7 +580,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -593,8 +594,8 @@ inline Value MapLhloOpToStdScalarOp(Location loc, // Floating point can use std::powf auto result_type = result_types.front(); if (result_type.isa<::mlir::FloatType>()) - return MapLhloOpToStdScalarOpImpl<::mlir::PowFOp>{}(loc, result_types, args, - b); + return MapLhloOpToStdScalarOpImpl<::mlir::math::PowFOp>{}(loc, result_types, + args, b); assert(result_type.isa<::mlir::IntegerType>() && "only float and integer `pow` is supported right now"); @@ -746,7 +747,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -766,7 +767,7 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}( + return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } diff --git a/lib/Dialect/mhlo/transforms/CMakeLists.txt b/lib/Dialect/mhlo/transforms/CMakeLists.txt index 50866b7..d200be6 100644 --- a/lib/Dialect/mhlo/transforms/CMakeLists.txt +++ b/lib/Dialect/mhlo/transforms/CMakeLists.txt @@ -94,6 +94,7 @@ add_mlir_library(MhloToLhloConversion LmhloDialect MLIRIR MLIRPass + MLIRMath ) add_mlir_library(MhloToStandard diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index c5c9edd..a8841f8 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -1456,14 +1457,15 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, struct LhloLegalizeToLinalgPass : public PassWrapper { void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); + registry.insert(); } void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addLegalDialect(); + math::MathDialect, StandardOpsDialect, + AffineDialect>(); auto func = getFunction(); populateLHLOToLinalgConversionPattern(func.getContext(), &patterns); @@ -1477,15 +1479,15 @@ struct HloLegalizeToLinalgPass : public PassWrapper { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); + complex::ComplexDialect, math::MathDialect>(); } void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addLegalDialect(); + math::MathDialect, StandardOpsDialect, + tensor::TensorDialect, scf::SCFDialect>(); auto func = getFunction(); mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); diff --git a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc index 9c1dc9c..64be60f 100644 --- a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc +++ b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc @@ -18,6 +18,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" @@ -74,10 +75,10 @@ class ApproximateOnExtendedF32Lowering : public OpRewritePattern { }; class ApproximateTanhLowering - : public ApproximateOnExtendedF32Lowering { + : public ApproximateOnExtendedF32Lowering { public: explicit ApproximateTanhLowering(MLIRContext *ctx) - : ApproximateOnExtendedF32Lowering(ctx) {} + : ApproximateOnExtendedF32Lowering(ctx) {} // Emits the fast tanh approximation that is also used by XLA. Value emitApproximation(ValueRange args, Location loc, diff --git a/tests/end2end/legalize-trigonometric-to-approximation.mlir b/tests/end2end/legalize-trigonometric-to-approximation.mlir index 57c1d59..601dacb 100644 --- a/tests/end2end/legalize-trigonometric-to-approximation.mlir +++ b/tests/end2end/legalize-trigonometric-to-approximation.mlir @@ -13,7 +13,7 @@ func @print_f32(%arg : f32) { } func @tanh_f32(%arg : f32) { - %res = tanh %arg : f32 + %res = math.tanh %arg : f32 call @print_f32(%res) : (f32) -> () return } diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 34dcfbf..59e7255 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -182,7 +182,7 @@ func @float_logistic(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: ^bb0(%[[ARG:.*]]: f32, %{{.*}}: f32): // CHECK: %[[C1:.*]] = constant 1.{{.*}}e+00 // CHECK: %[[NEG_ARG:.*]] = negf %[[ARG]] - // CHECK: %[[EXP_NEG_ARG:.*]] = exp %[[NEG_ARG]] + // CHECK: %[[EXP_NEG_ARG:.*]] = math.exp %[[NEG_ARG]] // CHECK: %[[ONE_ADD_EXP_NEG_ARG:.*]] = addf %[[C1]], %[[EXP_NEG_ARG]] // CHECK: %[[RESULT:.*]] = divf %[[C1]], %[[ONE_ADD_EXP_NEG_ARG]] // CHECK: linalg.yield %[[RESULT]] @@ -834,7 +834,7 @@ func @float_pow(%lhs: tensor<2x2xf32>, // CHECK: ^{{[a-z0-9_]*}} // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32 - // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = powf %[[ARG0]], %[[ARG1]] + // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = math.powf %[[ARG0]], %[[ARG1]] // CHECK: linalg.yield %[[RESULT]] %0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> diff --git a/tests/legalize-trigonometric-to-approximation.mlir b/tests/legalize-trigonometric-to-approximation.mlir index 959b8c2..7178c6a 100644 --- a/tests/legalize-trigonometric-to-approximation.mlir +++ b/tests/legalize-trigonometric-to-approximation.mlir @@ -1,7 +1,7 @@ // RUN: mlir-hlo-opt --mhlo-legalize-trigonometric-to-approximation --split-input-file %s | FileCheck %s func @tanh_f64(%arg0 : f64) -> f64 { - %res = tanh %arg0 : f64 + %res = math.tanh %arg0 : f64 return %res : f64 } @@ -11,7 +11,7 @@ func @tanh_f64(%arg0 : f64) -> f64 { // ----- func @tanh_f32(%arg0 : f32) -> f32 { - %res = tanh %arg0 : f32 + %res = math.tanh %arg0 : f32 return %res : f32 } @@ -66,7 +66,7 @@ func @tanh_f32(%arg0 : f32) -> f32 { // ----- func @tanh_f16(%arg0 : f16) -> f16 { - %res = tanh %arg0 : f16 + %res = math.tanh %arg0 : f16 return %res : f16 } @@ -125,7 +125,7 @@ func @tanh_f16(%arg0 : f16) -> f16 { // CHECK-LABEL: @atan2_f64 func @atan2_f64(%arg0 : f64, %arg1 : f64) -> f64 { // CHECK: atan2 - %res = atan2 %arg0, %arg1 : f64 + %res = math.atan2 %arg0, %arg1 : f64 return %res : f64 } @@ -134,6 +134,6 @@ func @atan2_f64(%arg0 : f64, %arg1 : f64) -> f64 { // CHECK-LABEL: @atan_f64 func @atan_f64(%arg : f64) -> f64 { // CHECK: atan - %res = atan %arg : f64 + %res = math.atan %arg : f64 return %res : f64 } diff --git a/tests/lhlo-fuse-linalg.mlir b/tests/lhlo-fuse-linalg.mlir index 54aceaf..a46668c 100644 --- a/tests/lhlo-fuse-linalg.mlir +++ b/tests/lhlo-fuse-linalg.mlir @@ -92,7 +92,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, ins(%1 : memref<100x10xf32>) outs(%arg2 : memref<100x10xf32>) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors - %2 = exp %arg3 : f32 + %2 = math.exp %arg3 : f32 linalg.yield %2 : f32 } dealloc %1 : memref<100x10xf32> diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index e31369c..e0ab3e2 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -10,7 +10,7 @@ func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32): -// CHECK-NEXT: %[[RESULT:.*]] = powf %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = math.powf %[[LHS_IN]], %[[RHS_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // ----- @@ -115,7 +115,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): -// CHECK-NEXT: %[[RESULT:.*]] = exp %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = math.exp %[[OPERAND_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // ----- @@ -127,7 +127,7 @@ func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): -// CHECK-NEXT: %[[RESULT:.*]] = log %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = math.log %[[OPERAND_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // ----- @@ -522,7 +522,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): -// CHECK-NEXT: %[[RESULT:.*]] = cos %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = math.cos %[[OPERAND_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // ----- @@ -536,7 +536,7 @@ func @sin(%input: memref<2x2xf32>, } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): -// CHECK-NEXT: %[[RESULT:.*]] = sin %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = math.sin %[[OPERAND_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // ----- @@ -612,7 +612,7 @@ func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): -// CHECK-NEXT: %[[RESULT:.*]] = rsqrt %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = math.rsqrt %[[OPERAND_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // ----- @@ -676,7 +676,7 @@ func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): -// CHECK-NEXT: %[[RESULT:.*]] = sqrt %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = math.sqrt %[[OPERAND_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // ----- @@ -688,7 +688,7 @@ func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): -// CHECK-NEXT: %[[RESULT:.*]] = tanh %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = math.tanh %[[OPERAND_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // -----