diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index c0d6e30..79530c0 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -16,6 +16,7 @@ limitations under the License. // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. #include "third_party/absl/memory/memory.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" @@ -692,7 +693,8 @@ class ConstConverter : public OpConversionPattern { if (valueAttr.getType().getRank() != 0) return failure(); auto stdConstOp = rewriter.create(loc, valueAttr.getValue({})); - rewriter.create(loc, stdConstOp, constOp.getOperand()); + rewriter.create(loc, stdConstOp, constOp.getOperand(), + ValueRange()); rewriter.eraseOp(constOp); return success(); } @@ -827,7 +829,8 @@ struct LhloLegalizeToLinalg void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); auto func = getFunction(); populateLHLOToLinalgConversionPattern(func.getContext(), &patterns); diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 6981466..dd88e5c 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -329,7 +329,7 @@ func @constant(%value: memref) { return } // CHECK: %[[CONSTANT:.*]] = constant 10 : i32 -// CHECK: store %[[CONSTANT]], %{{.*}}[] : memref +// CHECK: affine.store %[[CONSTANT]], %{{.*}}[] : memref // -----