From f6af1fc1341b4d3b5d05a1c595b94573b5cb8a55 Mon Sep 17 00:00:00 2001
From: Tres Popp <tpopp@google.com>
Date: Fri, 9 Oct 2020 07:13:14 -0700
Subject: [PATCH] Support hlo to lhlo buffer placement through shape.assuming ops.

PiperOrigin-RevId: 336287728
---
 .../mhlo/transforms/hlo_legalize_to_lhlo.cc |  7 +++++++
 tests/hlo-legalize-to-lhlo.mlir             | 21 +++++++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index 3485aff..22338d2 100644
--- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -20,6 +20,8 @@ limitations under the License.
 #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
@@ -448,6 +450,10 @@ struct HloLegalizeToLhlo
       return std::all_of(op.operand_type_begin(), op.operand_type_end(),
                          isMemRefType);
     });
+    target.addDynamicallyLegalOp<shape::AssumingOp>([&](shape::AssumingOp op) {
+      return std::all_of(op.result_type_begin(), op.result_type_end(),
+                         isMemRefType);
+    });
 
     auto kind = results_escape_function
                     ? BufferAssignmentTypeConverter::KeepAsFunctionResult
@@ -460,6 +466,7 @@ struct HloLegalizeToLhlo
     populateWithBufferAssignmentOpConversionPatterns<
         mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,
                                                        &patterns);
+    populateShapeTypeConversionPatterns(&context, &converter, &patterns);
     if (failed(applyPartialConversion(getOperation(), target, patterns)))
       signalPassFailure();
   }
diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir
index 75e5c1b..3caa4f0 100644
--- a/tests/hlo-legalize-to-lhlo.mlir
+++ b/tests/hlo-legalize-to-lhlo.mlir
@@ -612,3 +612,24 @@ func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
   tensor_store %result_tensor, %result: memref<2x2xi1>
   return
 }
+
+// -----
+
+// Test that assuming ops propagate memref types.
+// BOTH-LABEL: func @shape_assuming_memref
+func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f16>
+  %1 = shape.const_witness true
+  // BOTH: shape.assuming %{{.*}} -> (memref<?xf16>)
+  %2 = shape.assuming %1 -> (tensor<?xf16>) {
+    %3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
+    %4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
+    %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
+    %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
+    // BOTH: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
+    %7 = mhlo.maximum %5, %6 : tensor<?xf16>
+    // BOTH: shape.assuming_yield %{{.*}} : memref<?xf16>
+    shape.assuming_yield %7 : tensor<?xf16>
+  }
+  return %2 : tensor<?xf16>
+}
\ No newline at end of file