PR #40925: [MLIR] Update lhlo.const to linalg lowering to use affine.store inste…
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/40925 …ad of std.store The xla_lhlo.const lowering uses std.store to store a constant to 0-d memrefs. Update it to affine.store since such an access is trivially affine (no indices). An affine.store can always be lowered to std.store. Copybara import of the project: -- 9e18ede72fbbca107177bd742921e4cbf77adc82 by Uday Bondhugula <uday@polymagelabs.com>: [MLIR] Update lhlo.const to linalg lowering to use affine.store instead of std.store The xla_lhlo.const lowering uses std.store to store a constant to 0-d memrefs. Update it to affine.store since such an access is trivially affine (no indices). An affine.store can always be lowered to std.store. COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/40925 from polymage-labs:lhlo_to_linalg_affine_store 9e18ede72fbbca107177bd742921e4cbf77adc82 PiperOrigin-RevId: 320623152
This commit is contained in:
parent
6eaccefdab
commit
d166b66cba
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
|
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
|
||||||
|
|
||||||
#include "third_party/absl/memory/memory.h"
|
#include "third_party/absl/memory/memory.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@ -692,7 +693,8 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
|
||||||
if (valueAttr.getType().getRank() != 0) return failure();
|
if (valueAttr.getType().getRank() != 0) return failure();
|
||||||
auto stdConstOp =
|
auto stdConstOp =
|
||||||
rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
|
rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
|
||||||
rewriter.create<mlir::StoreOp>(loc, stdConstOp, constOp.getOperand());
|
rewriter.create<mlir::AffineStoreOp>(loc, stdConstOp, constOp.getOperand(),
|
||||||
|
ValueRange());
|
||||||
rewriter.eraseOp(constOp);
|
rewriter.eraseOp(constOp);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -827,7 +829,8 @@ struct LhloLegalizeToLinalg
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
|
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||||
|
AffineDialect>();
|
||||||
|
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||||
|
|
|
@ -329,7 +329,7 @@ func @constant(%value: memref<i32>) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK: %[[CONSTANT:.*]] = constant 10 : i32
|
// CHECK: %[[CONSTANT:.*]] = constant 10 : i32
|
||||||
// CHECK: store %[[CONSTANT]], %{{.*}}[] : memref<i32>
|
// CHECK: affine.store %[[CONSTANT]], %{{.*}}[] : memref<i32>
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue