Updates LLVM usage to match
[b24436ac96bd](https://github.com/llvm/llvm-project/commit/b24436ac96bd)

PiperOrigin-RevId: 364615807
This commit is contained in:
Stella Laurenzo 2021-03-23 12:18:57 -07:00 committed by TensorFlow MLIR Team
parent 8987dfd1d6
commit 7f2bf48b8b
23 changed files with 41 additions and 45 deletions

View File

@ -15,9 +15,9 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
LLVM_COMMIT = "0776eca7a4e76bfadc311f3607be3a4f0c0e989a" LLVM_COMMIT = "b24436ac96bdf3f2c545fc85dc8af239d618c9c4"
LLVM_SHA256 = "93ac9bca36b7121d29a0415ce29c614293206622cf24b674e34e83545427c8bc" LLVM_SHA256 = "6af626445defe88eb4ccaa1ebdc6f7642775a8c8a64f2213157b4a16c26a2319"
LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT) LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)

View File

@ -1,2 +1,2 @@
0776eca7a4e76bfadc311f3607be3a4f0c0e989a b24436ac96bdf3f2c545fc85dc8af239d618c9c4

View File

@ -24,8 +24,6 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
class OwningRewritePatternList;
namespace mhlo { namespace mhlo {
// Collection of rewrite patterns for lowering a general dot product. // Collection of rewrite patterns for lowering a general dot product.

View File

@ -1239,7 +1239,7 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
void PopulateLegalizeChloToHloPatterns(MLIRContext *context, void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) { OwningRewritePatternList *patterns) {
populateWithGenerated(context, *patterns); populateWithGenerated(*patterns);
PopulateChloBroadcastingPatterns(context, patterns); PopulateChloBroadcastingPatterns(context, patterns);
// Other patterns. // Other patterns.

View File

@ -43,7 +43,7 @@ struct ChloLegalizeToHloPass
void runOnFunction() override { void runOnFunction() override {
ConversionTarget conversionTarget(getContext()); ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns; OwningRewritePatternList conversionPatterns(&getContext());
conversionTarget.addIllegalDialect<chlo::HloClientDialect>(); conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
// Consider the mhlo dialect legal for tests. Also add helper dialects // Consider the mhlo dialect legal for tests. Also add helper dialects

View File

@ -572,8 +572,8 @@ struct HloLegalizeToLhlo
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {} HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {}
void runOnOperation() override { void runOnOperation() override {
OwningRewritePatternList patterns;
auto& context = getContext(); auto& context = getContext();
OwningRewritePatternList patterns(&context);
ConversionTarget target(context); ConversionTarget target(context);
target.addLegalDialect<lmhlo::LmhloDialect>(); target.addLegalDialect<lmhlo::LmhloDialect>();
target.addLegalDialect<StandardOpsDialect>(); target.addLegalDialect<StandardOpsDialect>();
@ -608,16 +608,14 @@ struct HloLegalizeToLhlo
}); });
populateHLOToLHLOConversionPattern(&context, &converter, &patterns); populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
populateFuncOpTypeConversionPattern(patterns, &context, converter); populateFuncOpTypeConversionPattern(patterns, converter);
populateCallOpTypeConversionPattern(patterns, &context, converter); populateCallOpTypeConversionPattern(patterns, converter);
populateBranchOpInterfaceTypeConversionPattern(patterns, &context, populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
converter); populateReturnOpTypeConversionPattern(patterns, converter);
populateReturnOpTypeConversionPattern(patterns, &context, converter); populateEliminateBufferizeMaterializationsPatterns(converter, patterns);
populateEliminateBufferizeMaterializationsPatterns(&context, converter,
patterns);
populateShapeStructuralTypeConversionsAndLegality(&context, converter, populateShapeStructuralTypeConversionsAndLegality(converter, patterns,
patterns, target); target);
// TODO(b/175789537) Remove this pattern. // TODO(b/175789537) Remove this pattern.
patterns.insert<HloToLhloTensorStoreOpLegacyConverter>(&context); patterns.insert<HloToLhloTensorStoreOpLegacyConverter>(&context);

View File

@ -131,7 +131,7 @@ struct LegalizeGatherToTorchIndexSelectPass
: public PassWrapper<LegalizeGatherToTorchIndexSelectPass, FunctionPass> { : public PassWrapper<LegalizeGatherToTorchIndexSelectPass, FunctionPass> {
/// Perform the lowering of standard dialect operations to approximations. /// Perform the lowering of standard dialect operations to approximations.
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&getContext());
PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns); PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
} }

View File

@ -1969,7 +1969,7 @@ struct LhloLegalizeToLinalgPass
} }
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&getContext());
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect, target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
math::MathDialect, memref::MemRefDialect, math::MathDialect, memref::MemRefDialect,
@ -1991,7 +1991,7 @@ struct HloLegalizeToLinalgPass
} }
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&getContext());
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect, target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
math::MathDialect, StandardOpsDialect, math::MathDialect, StandardOpsDialect,

View File

@ -193,13 +193,13 @@ std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
mlir::MLIRContext *ctx) { mlir::MLIRContext *ctx) {
mlir::populateWithGenerated(ctx, *patterns); mlir::populateWithGenerated(*patterns);
patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx); patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
} }
/// Perform the lowering to standard dialect. /// Perform the lowering to standard dialect.
void LegalizeToStandardPass::runOnFunction() { void LegalizeToStandardPass::runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&getContext());
mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext()); mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext());
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
} }

View File

@ -155,7 +155,7 @@ struct LegalizeTrigonometricToApproximationPass
FunctionPass> { FunctionPass> {
/// Perform the lowering of standard dialect operations to approximations. /// Perform the lowering of standard dialect operations to approximations.
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&getContext());
PopulateTrigonometricToApproximationPatterns(&getContext(), &patterns); PopulateTrigonometricToApproximationPatterns(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
} }

View File

@ -155,9 +155,9 @@ struct LhloLegalizeToAffinePass
registry.insert<AffineDialect>(); registry.insert<AffineDialect>();
} }
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns;
auto func = getFunction(); auto func = getFunction();
populateLHLOToAffineConversionPattern(func.getContext(), &patterns); OwningRewritePatternList patterns(&getContext());
populateLHLOToAffineConversionPattern(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns)); (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
} }
}; };

View File

@ -178,7 +178,7 @@ struct LhloLegalizeToGpuPass
} }
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&getContext());
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect, target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
StandardOpsDialect, gpu::GPUDialect, scf::SCFDialect, StandardOpsDialect, gpu::GPUDialect, scf::SCFDialect,

View File

@ -706,7 +706,7 @@ struct LhloLegalizeToParallelLoopsPass
void runOnFunction() override { void runOnFunction() override {
auto func = getFunction(); auto func = getFunction();
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&getContext());
// clang-format off // clang-format off
patterns.insert< patterns.insert<
ReduceOpConverter, ReduceOpConverter,

View File

@ -59,7 +59,7 @@ namespace {
void PopulateComplexLoweringPatterns(MLIRContext* context, void PopulateComplexLoweringPatterns(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
populateWithGenerated(context, *patterns); populateWithGenerated(*patterns);
} }
} // end namespace mhlo } // end namespace mhlo
} // end namespace mlir } // end namespace mlir
@ -67,7 +67,7 @@ void PopulateComplexLoweringPatterns(MLIRContext* context,
// Lowers the complex operations that can be represented using other operations. // Lowers the complex operations that can be represented using other operations.
void LowerComplexPass::runOnFunction() { void LowerComplexPass::runOnFunction() {
// Add lowering patterns to the list. // Add lowering patterns to the list.
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&getContext());
mlir::mhlo::PopulateComplexLoweringPatterns(&getContext(), &patterns); mlir::mhlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

View File

@ -182,7 +182,7 @@ struct LegalizeGeneralDotPass
: public PassWrapper<LegalizeGeneralDotPass, FunctionPass> { : public PassWrapper<LegalizeGeneralDotPass, FunctionPass> {
/// Lower all general dots that can be represented as a non-batched matmul. /// Lower all general dots that can be represented as a non-batched matmul.
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&getContext());
mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext()); mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
} }

View File

@ -31,7 +31,7 @@ struct TestMaterializeBroadcastsPass
: public PassWrapper<TestMaterializeBroadcastsPass, FunctionPass> { : public PassWrapper<TestMaterializeBroadcastsPass, FunctionPass> {
void runOnFunction() override { void runOnFunction() override {
ConversionTarget conversionTarget(getContext()); ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns; OwningRewritePatternList conversionPatterns(&getContext());
// Consider the mhlo dialect legal for tests. // Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<MhloDialect>(); conversionTarget.addLegalDialect<MhloDialect>();

View File

@ -146,7 +146,7 @@ struct MoveUpDynamicBroadcastsForFusionPass
PopulateMoveUpDynamicBroadcastsForFusionLegality(&target); PopulateMoveUpDynamicBroadcastsForFusionLegality(&target);
// Populate rewrite patterns. // Populate rewrite patterns.
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&ctx);
mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns); mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns);
// Apply transformation. // Apply transformation.

View File

@ -39,7 +39,7 @@ class OptimizeMhloPass : public PassWrapper<OptimizeMhloPass, FunctionPass> {
// Lowers the complex operations that can be represented using other operations. // Lowers the complex operations that can be represented using other operations.
void OptimizeMhloPass::runOnFunction() { void OptimizeMhloPass::runOnFunction() {
// Add lowering patterns to the list. // Add lowering patterns to the list.
mlir::OwningRewritePatternList patterns; mlir::OwningRewritePatternList patterns(&getContext());
mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns); mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

View File

@ -84,7 +84,7 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
struct TestInferShapedTypeMethodsPass struct TestInferShapedTypeMethodsPass
: public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> { : public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&getContext());
patterns.insert<ReifyReturnTypeShapesPattern>(&getContext()); patterns.insert<ReifyReturnTypeShapesPattern>(&getContext());
patterns.insert<InferReturnTypeComponentsPattern>(&getContext()); patterns.insert<InferReturnTypeComponentsPattern>(&getContext());
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

View File

@ -557,7 +557,7 @@ struct TransformUnrankedHloPass
}); });
// Populate rewrite patterns. // Populate rewrite patterns.
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&ctx);
mhlo::PopulateTransformUnrankedHloPatterns(&ctx, &patterns); mhlo::PopulateTransformUnrankedHloPatterns(&ctx, &patterns);
// Apply transformation. // Apply transformation.

View File

@ -30,7 +30,7 @@ namespace {
struct TestUnfuseBatchNormPass struct TestUnfuseBatchNormPass
: public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> { : public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
void runOnOperation() override { void runOnOperation() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns(&getContext());
PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
} }

View File

@ -88,25 +88,25 @@ func @compare_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xi1>
// CHECK-LABEL: func @int_constant // CHECK-LABEL: func @int_constant
func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) { func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<i32> // CHECK-DAG: [[CST0:%.+]] = constant dense<0>
%0 = "mhlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>) %0 = "mhlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32> // CHECK-DAG: [[CST1:%.+]] = constant dense<1>
%1 = "mhlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) %1 = "mhlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32> // CHECK-DAG: [[CST2:%.+]] = constant dense<[
%2 = "mhlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) %2 = "mhlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32> // CHECK: return [[CST0]], [[CST1]], [[CST2]] : tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
return %0, %1, %2: tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32> return %0, %1, %2: tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
} }
// CHECK-LABEL: func @float_constant // CHECK-LABEL: func @float_constant
func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) { func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<f32> // CHECK-DAG: [[CST0:%.+]] = constant dense<0.000000e+00>
%0 = "mhlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>) %0 = "mhlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32> // CHECK-DAG: [[CST1:%.+]] = constant dense<1.000000e+00>
%1 = "mhlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) %1 = "mhlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32> // CHECK-DAG: [[CST2:%.+]] = constant dense<[
%2 = "mhlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) %2 = "mhlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32> // CHECK: return [[CST0]], [[CST1]], [[CST2]] : tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
return %0, %1, %2: tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32> return %0, %1, %2: tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
} }

View File

@ -99,7 +99,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
return return
} }
// CHECK-LABEL: func @fusion // CHECK-LABEL: func @fusion
// CHECK: %[[C1:.*]] = constant 1 // CHECK: %[[C1:.*]] = constant 1 :
// CHECK-NOT: linalg.generic // CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]] // CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]] // CHECK: scf.for {{.*}} step %[[C1]]