Updates LLVM usage to match
[b24436ac96bd](https://github.com/llvm/llvm-project/commit/b24436ac96bd)

PiperOrigin-RevId: 364615807
This commit is contained in:
Stella Laurenzo 2021-03-23 12:18:57 -07:00 committed by TensorFlow MLIR Team
parent 8987dfd1d6
commit 7f2bf48b8b
23 changed files with 41 additions and 45 deletions

View File

@ -15,9 +15,9 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
LLVM_COMMIT = "0776eca7a4e76bfadc311f3607be3a4f0c0e989a"
LLVM_COMMIT = "b24436ac96bdf3f2c545fc85dc8af239d618c9c4"
LLVM_SHA256 = "93ac9bca36b7121d29a0415ce29c614293206622cf24b674e34e83545427c8bc"
LLVM_SHA256 = "6af626445defe88eb4ccaa1ebdc6f7642775a8c8a64f2213157b4a16c26a2319"
LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT)

View File

@ -1,2 +1,2 @@
0776eca7a4e76bfadc311f3607be3a4f0c0e989a
b24436ac96bdf3f2c545fc85dc8af239d618c9c4

View File

@ -24,8 +24,6 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class OwningRewritePatternList;
namespace mhlo {
// Collection of rewrite patterns for lowering a general dot product.

View File

@ -1239,7 +1239,7 @@ void PopulateChloBroadcastingPatterns(MLIRContext *context,
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
populateWithGenerated(context, *patterns);
populateWithGenerated(*patterns);
PopulateChloBroadcastingPatterns(context, patterns);
// Other patterns.

View File

@ -43,7 +43,7 @@ struct ChloLegalizeToHloPass
void runOnFunction() override {
ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns;
OwningRewritePatternList conversionPatterns(&getContext());
conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
// Consider the mhlo dialect legal for tests. Also add helper dialects

View File

@ -572,8 +572,8 @@ struct HloLegalizeToLhlo
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {}
void runOnOperation() override {
OwningRewritePatternList patterns;
auto& context = getContext();
OwningRewritePatternList patterns(&context);
ConversionTarget target(context);
target.addLegalDialect<lmhlo::LmhloDialect>();
target.addLegalDialect<StandardOpsDialect>();
@ -608,16 +608,14 @@ struct HloLegalizeToLhlo
});
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
populateFuncOpTypeConversionPattern(patterns, &context, converter);
populateCallOpTypeConversionPattern(patterns, &context, converter);
populateBranchOpInterfaceTypeConversionPattern(patterns, &context,
converter);
populateReturnOpTypeConversionPattern(patterns, &context, converter);
populateEliminateBufferizeMaterializationsPatterns(&context, converter,
patterns);
populateFuncOpTypeConversionPattern(patterns, converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
populateReturnOpTypeConversionPattern(patterns, converter);
populateEliminateBufferizeMaterializationsPatterns(converter, patterns);
populateShapeStructuralTypeConversionsAndLegality(&context, converter,
patterns, target);
populateShapeStructuralTypeConversionsAndLegality(converter, patterns,
target);
// TODO(b/175789537) Remove this pattern.
patterns.insert<HloToLhloTensorStoreOpLegacyConverter>(&context);

View File

@ -131,7 +131,7 @@ struct LegalizeGatherToTorchIndexSelectPass
: public PassWrapper<LegalizeGatherToTorchIndexSelectPass, FunctionPass> {
/// Perform the lowering of standard dialect operations to approximations.
void runOnFunction() override {
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&getContext());
PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}

View File

@ -1969,7 +1969,7 @@ struct LhloLegalizeToLinalgPass
}
void runOnFunction() override {
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
math::MathDialect, memref::MemRefDialect,
@ -1991,7 +1991,7 @@ struct HloLegalizeToLinalgPass
}
void runOnFunction() override {
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
math::MathDialect, StandardOpsDialect,

View File

@ -193,13 +193,13 @@ std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
mlir::MLIRContext *ctx) {
mlir::populateWithGenerated(ctx, *patterns);
mlir::populateWithGenerated(*patterns);
patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
}
/// Perform the lowering to standard dialect.
void LegalizeToStandardPass::runOnFunction() {
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&getContext());
mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext());
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}

View File

@ -155,7 +155,7 @@ struct LegalizeTrigonometricToApproximationPass
FunctionPass> {
/// Perform the lowering of standard dialect operations to approximations.
void runOnFunction() override {
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&getContext());
PopulateTrigonometricToApproximationPatterns(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}

View File

@ -155,9 +155,9 @@ struct LhloLegalizeToAffinePass
registry.insert<AffineDialect>();
}
void runOnFunction() override {
OwningRewritePatternList patterns;
auto func = getFunction();
populateLHLOToAffineConversionPattern(func.getContext(), &patterns);
OwningRewritePatternList patterns(&getContext());
populateLHLOToAffineConversionPattern(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
};

View File

@ -178,7 +178,7 @@ struct LhloLegalizeToGpuPass
}
void runOnFunction() override {
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
StandardOpsDialect, gpu::GPUDialect, scf::SCFDialect,

View File

@ -706,7 +706,7 @@ struct LhloLegalizeToParallelLoopsPass
void runOnFunction() override {
auto func = getFunction();
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&getContext());
// clang-format off
patterns.insert<
ReduceOpConverter,

View File

@ -59,7 +59,7 @@ namespace {
void PopulateComplexLoweringPatterns(MLIRContext* context,
OwningRewritePatternList* patterns) {
populateWithGenerated(context, *patterns);
populateWithGenerated(*patterns);
}
} // end namespace mhlo
} // end namespace mlir
@ -67,7 +67,7 @@ void PopulateComplexLoweringPatterns(MLIRContext* context,
// Lowers the complex operations that can be represented using other operations.
void LowerComplexPass::runOnFunction() {
// Add lowering patterns to the list.
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&getContext());
mlir::mhlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

View File

@ -182,7 +182,7 @@ struct LegalizeGeneralDotPass
: public PassWrapper<LegalizeGeneralDotPass, FunctionPass> {
/// Lower all general dots that can be represented as a non-batched matmul.
void runOnFunction() override {
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&getContext());
mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}

View File

@ -31,7 +31,7 @@ struct TestMaterializeBroadcastsPass
: public PassWrapper<TestMaterializeBroadcastsPass, FunctionPass> {
void runOnFunction() override {
ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns;
OwningRewritePatternList conversionPatterns(&getContext());
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<MhloDialect>();

View File

@ -146,7 +146,7 @@ struct MoveUpDynamicBroadcastsForFusionPass
PopulateMoveUpDynamicBroadcastsForFusionLegality(&target);
// Populate rewrite patterns.
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&ctx);
mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns);
// Apply transformation.

View File

@ -39,7 +39,7 @@ class OptimizeMhloPass : public PassWrapper<OptimizeMhloPass, FunctionPass> {
// Lowers the complex operations that can be represented using other operations.
void OptimizeMhloPass::runOnFunction() {
// Add lowering patterns to the list.
mlir::OwningRewritePatternList patterns;
mlir::OwningRewritePatternList patterns(&getContext());
mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

View File

@ -84,7 +84,7 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
struct TestInferShapedTypeMethodsPass
: public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&getContext());
patterns.insert<ReifyReturnTypeShapesPattern>(&getContext());
patterns.insert<InferReturnTypeComponentsPattern>(&getContext());
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

View File

@ -557,7 +557,7 @@ struct TransformUnrankedHloPass
});
// Populate rewrite patterns.
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&ctx);
mhlo::PopulateTransformUnrankedHloPatterns(&ctx, &patterns);
// Apply transformation.

View File

@ -30,7 +30,7 @@ namespace {
struct TestUnfuseBatchNormPass
: public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
void runOnOperation() override {
OwningRewritePatternList patterns;
OwningRewritePatternList patterns(&getContext());
PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

View File

@ -88,25 +88,25 @@ func @compare_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xi1>
// CHECK-LABEL: func @int_constant
func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<i32>
// CHECK-DAG: [[CST0:%.+]] = constant dense<0>
%0 = "mhlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32>
// CHECK-DAG: [[CST1:%.+]] = constant dense<1>
%1 = "mhlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32>
// CHECK-DAG: [[CST2:%.+]] = constant dense<[
%2 = "mhlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
// CHECK: return [[CST0]], [[CST1]], [[CST2]] : tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
return %0, %1, %2: tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
}
// CHECK-LABEL: func @float_constant
func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<f32>
// CHECK-DAG: [[CST0:%.+]] = constant dense<0.000000e+00>
%0 = "mhlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32>
// CHECK-DAG: [[CST1:%.+]] = constant dense<1.000000e+00>
%1 = "mhlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32>
// CHECK-DAG: [[CST2:%.+]] = constant dense<[
%2 = "mhlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
// CHECK: return [[CST0]], [[CST1]], [[CST2]] : tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
return %0, %1, %2: tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
}

View File

@ -99,7 +99,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
return
}
// CHECK-LABEL: func @fusion
// CHECK: %[[C1:.*]] = constant 1
// CHECK: %[[C1:.*]] = constant 1 :
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]