PR #40925: [MLIR] Update lhlo.const to linalg lowering to use affine.store inste…

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/40925

…ad of std.store

The xla_lhlo.const lowering uses std.store to store a constant to
0-d memrefs. Update it to affine.store since such an access is trivially
affine (no indices). An affine.store can always be lowered to std.store.
Copybara import of the project:

--
9e18ede72fbbca107177bd742921e4cbf77adc82 by Uday Bondhugula <uday@polymagelabs.com>:

[MLIR] Update lhlo.const to linalg lowering to use affine.store instead of std.store

The xla_lhlo.const lowering uses std.store to store a constant to
0-d memrefs. Update it to affine.store since such an access is trivially
affine (no indices). An affine.store can always be lowered to std.store.

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/40925 from polymage-labs:lhlo_to_linalg_affine_store 9e18ede72fbbca107177bd742921e4cbf77adc82
PiperOrigin-RevId: 320623152
This commit is contained in:
Uday Bondhugula 2020-07-10 17:03:44 +00:00 committed by Mehdi Amini
parent 6eaccefdab
commit d166b66cba
2 changed files with 6 additions and 3 deletions

View File

@ -16,6 +16,7 @@ limitations under the License.
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
#include "third_party/absl/memory/memory.h" #include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
@ -692,7 +693,8 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
if (valueAttr.getType().getRank() != 0) return failure(); if (valueAttr.getType().getRank() != 0) return failure();
auto stdConstOp = auto stdConstOp =
rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({})); rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
rewriter.create<mlir::StoreOp>(loc, stdConstOp, constOp.getOperand()); rewriter.create<mlir::AffineStoreOp>(loc, stdConstOp, constOp.getOperand(),
ValueRange());
rewriter.eraseOp(constOp); rewriter.eraseOp(constOp);
return success(); return success();
} }
@ -827,7 +829,8 @@ struct LhloLegalizeToLinalg
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>(); target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
AffineDialect>();
auto func = getFunction(); auto func = getFunction();
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns); populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);

View File

@ -329,7 +329,7 @@ func @constant(%value: memref<i32>) {
return return
} }
// CHECK: %[[CONSTANT:.*]] = constant 10 : i32 // CHECK: %[[CONSTANT:.*]] = constant 10 : i32
// CHECK: store %[[CONSTANT]], %{{.*}}[] : memref<i32> // CHECK: affine.store %[[CONSTANT]], %{{.*}}[] : memref<i32>
// ----- // -----