diff --git a/WORKSPACE b/WORKSPACE index da69d1f..2b586a2 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -15,9 +15,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "0776eca7a4e76bfadc311f3607be3a4f0c0e989a" +LLVM_COMMIT = "b24436ac96bdf3f2c545fc85dc8af239d618c9c4" -LLVM_SHA256 = "93ac9bca36b7121d29a0415ce29c614293206622cf24b674e34e83545427c8bc" +LLVM_SHA256 = "6af626445defe88eb4ccaa1ebdc6f7642775a8c8a64f2213157b4a16c26a2319" LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT) diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index 695a15e..2b0574e 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1,2 +1,2 @@ -0776eca7a4e76bfadc311f3607be3a4f0c0e989a +b24436ac96bdf3f2c545fc85dc8af239d618c9c4 diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index 2c7c1cf..672711d 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -24,8 +24,6 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" namespace mlir { -class OwningRewritePatternList; - namespace mhlo { // Collection of rewrite patterns for lowering a general dot product. diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 90d0c91..da883a6 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -1239,7 +1239,7 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context, void PopulateLegalizeChloToHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - populateWithGenerated(context, *patterns); + populateWithGenerated(*patterns); PopulateChloBroadcastingPatterns(context, patterns); // Other patterns. diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index 5daf607..c31fd28 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -43,7 +43,7 @@ struct ChloLegalizeToHloPass void runOnFunction() override { ConversionTarget conversionTarget(getContext()); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); conversionTarget.addIllegalDialect(); // Consider the mhlo dialect legal for tests. Also add helper dialects diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 38f817b..ea9ad5a 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -572,8 +572,8 @@ struct HloLegalizeToLhlo HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {} void runOnOperation() override { - OwningRewritePatternList patterns; auto& context = getContext(); + OwningRewritePatternList patterns(&context); ConversionTarget target(context); target.addLegalDialect(); target.addLegalDialect(); @@ -608,16 +608,14 @@ struct HloLegalizeToLhlo }); populateHLOToLHLOConversionPattern(&context, &converter, &patterns); - populateFuncOpTypeConversionPattern(patterns, &context, converter); - populateCallOpTypeConversionPattern(patterns, &context, converter); - populateBranchOpInterfaceTypeConversionPattern(patterns, &context, - converter); - populateReturnOpTypeConversionPattern(patterns, &context, converter); - populateEliminateBufferizeMaterializationsPatterns(&context, converter, - patterns); + populateFuncOpTypeConversionPattern(patterns, converter); + populateCallOpTypeConversionPattern(patterns, converter); + populateBranchOpInterfaceTypeConversionPattern(patterns, converter); + populateReturnOpTypeConversionPattern(patterns, converter); + populateEliminateBufferizeMaterializationsPatterns(converter, patterns); - populateShapeStructuralTypeConversionsAndLegality(&context, converter, - patterns, target); + populateShapeStructuralTypeConversionsAndLegality(converter, patterns, + target); // TODO(b/175789537) Remove this pattern. patterns.insert(&context); diff --git a/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc b/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc index 0f6efbc..2416076 100644 --- a/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc +++ b/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc @@ -131,7 +131,7 @@ struct LegalizeGatherToTorchIndexSelectPass : public PassWrapper { /// Perform the lowering of standard dialect operations to approximations. void runOnFunction() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 493c6ad..c7fbd80 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1969,7 +1969,7 @@ struct LhloLegalizeToLinalgPass } void runOnFunction() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); ConversionTarget target(getContext()); target.addLegalDialect> createLegalizeToStdPass() { void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, mlir::MLIRContext *ctx) { - mlir::populateWithGenerated(ctx, *patterns); + mlir::populateWithGenerated(*patterns); patterns->insert(ctx); } /// Perform the lowering to standard dialect. void LegalizeToStandardPass::runOnFunction() { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext()); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } diff --git a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc index a6e0b78..180bf9d 100644 --- a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc +++ b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc @@ -155,7 +155,7 @@ struct LegalizeTrigonometricToApproximationPass FunctionPass> { /// Perform the lowering of standard dialect operations to approximations. void runOnFunction() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); PopulateTrigonometricToApproximationPatterns(&getContext(), &patterns); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index 72af982..0235d2d 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -155,9 +155,9 @@ struct LhloLegalizeToAffinePass registry.insert(); } void runOnFunction() override { - OwningRewritePatternList patterns; auto func = getFunction(); - populateLHLOToAffineConversionPattern(func.getContext(), &patterns); + OwningRewritePatternList patterns(&getContext()); + populateLHLOToAffineConversionPattern(&getContext(), &patterns); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } }; diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index 91c50b0..915fef9 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -178,7 +178,7 @@ struct LhloLegalizeToGpuPass } void runOnFunction() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); ConversionTarget target(getContext()); target.addLegalDialect { /// Lower all general dots that can be represented as a non-batched matmul. void runOnFunction() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext()); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } diff --git a/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc b/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc index d410a26..4cb618d 100644 --- a/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc +++ b/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc @@ -31,7 +31,7 @@ struct TestMaterializeBroadcastsPass : public PassWrapper { void runOnFunction() override { ConversionTarget conversionTarget(getContext()); - OwningRewritePatternList conversionPatterns; + OwningRewritePatternList conversionPatterns(&getContext()); // Consider the mhlo dialect legal for tests. conversionTarget.addLegalDialect(); diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc index 6bf23f4..34a3ffc 100644 --- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc +++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc @@ -146,7 +146,7 @@ struct MoveUpDynamicBroadcastsForFusionPass PopulateMoveUpDynamicBroadcastsForFusionLegality(&target); // Populate rewrite patterns. - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&ctx); mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns); // Apply transformation. diff --git a/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc b/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc index 0d181e4..539b70a 100644 --- a/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc +++ b/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc @@ -39,7 +39,7 @@ class OptimizeMhloPass : public PassWrapper { // Lowers the complex operations that can be represented using other operations. void OptimizeMhloPass::runOnFunction() { // Add lowering patterns to the list. - mlir::OwningRewritePatternList patterns; + mlir::OwningRewritePatternList patterns(&getContext()); mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); diff --git a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc index 9cf7c00..de06621 100644 --- a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc +++ b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc @@ -84,7 +84,7 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern { struct TestInferShapedTypeMethodsPass : public PassWrapper { void runOnFunction() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index ded4ee3..e3984d1 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -557,7 +557,7 @@ struct TransformUnrankedHloPass }); // Populate rewrite patterns. - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&ctx); mhlo::PopulateTransformUnrankedHloPatterns(&ctx, &patterns); // Apply transformation. diff --git a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc index e2b684a..33e0a12 100644 --- a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc +++ b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc @@ -30,7 +30,7 @@ namespace { struct TestUnfuseBatchNormPass : public PassWrapper> { void runOnOperation() override { - OwningRewritePatternList patterns; + OwningRewritePatternList patterns(&getContext()); PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/tests/legalize-to-std.mlir b/tests/legalize-to-std.mlir index 14a310b..0581f2f 100644 --- a/tests/legalize-to-std.mlir +++ b/tests/legalize-to-std.mlir @@ -88,25 +88,25 @@ func @compare_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xi1> // CHECK-LABEL: func @int_constant func @int_constant() -> (tensor, tensor<2x3xi32>, tensor<2x3xi32>) { - // CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor + // CHECK-DAG: [[CST0:%.+]] = constant dense<0> %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> (tensor) - // CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32> + // CHECK-DAG: [[CST1:%.+]] = constant dense<1> %1 = "mhlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) - // CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32> + // CHECK-DAG: [[CST2:%.+]] = constant dense<[ %2 = "mhlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) - // CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor, tensor<2x3xi32>, tensor<2x3xi32> + // CHECK: return [[CST0]], [[CST1]], [[CST2]] : tensor, tensor<2x3xi32>, tensor<2x3xi32> return %0, %1, %2: tensor, tensor<2x3xi32>, tensor<2x3xi32> } // CHECK-LABEL: func @float_constant func @float_constant() -> (tensor, tensor<2x3xf32>, tensor<2x3xf32>) { - // CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor + // CHECK-DAG: [[CST0:%.+]] = constant dense<0.000000e+00> %0 = "mhlo.constant"() {value = dense<0.0> : tensor} : () -> (tensor) - // CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32> + // CHECK-DAG: [[CST1:%.+]] = constant dense<1.000000e+00> %1 = "mhlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) - // CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32> + // CHECK-DAG: [[CST2:%.+]] = constant dense<[ %2 = "mhlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) - // CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor, tensor<2x3xf32>, tensor<2x3xf32> + // CHECK: return [[CST0]], [[CST1]], [[CST2]] : tensor, tensor<2x3xf32>, tensor<2x3xf32> return %0, %1, %2: tensor, tensor<2x3xf32>, tensor<2x3xf32> } diff --git a/tests/lhlo-fuse-linalg.mlir b/tests/lhlo-fuse-linalg.mlir index 2e5d494..05a3984 100644 --- a/tests/lhlo-fuse-linalg.mlir +++ b/tests/lhlo-fuse-linalg.mlir @@ -99,7 +99,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, return } // CHECK-LABEL: func @fusion -// CHECK: %[[C1:.*]] = constant 1 +// CHECK: %[[C1:.*]] = constant 1 : // CHECK-NOT: linalg.generic // CHECK: scf.for {{.*}} step %[[C1]] // CHECK: scf.for {{.*}} step %[[C1]]