Support hlo to lhlo buffer placement through shape.assuming ops.

PiperOrigin-RevId: 336287728
This commit is contained in:
Tres Popp 2020-10-09 07:13:14 -07:00 committed by TensorFlow MLIR Team
parent d986bd7ad7
commit f6af1fc134
2 changed files with 28 additions and 0 deletions

View File

@ -20,6 +20,8 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
@ -448,6 +450,10 @@ struct HloLegalizeToLhlo
return std::all_of(op.operand_type_begin(), op.operand_type_end(),
isMemRefType);
});
target.addDynamicallyLegalOp<shape::AssumingOp>([&](shape::AssumingOp op) {
return std::all_of(op.result_type_begin(), op.result_type_end(),
isMemRefType);
});
auto kind = results_escape_function
? BufferAssignmentTypeConverter::KeepAsFunctionResult
@ -460,6 +466,7 @@ struct HloLegalizeToLhlo
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,
&patterns);
populateShapeTypeConversionPatterns(&context, &converter, &patterns);
if (failed(applyPartialConversion(getOperation(), target, patterns)))
signalPassFailure();
}

View File

@ -612,3 +612,24 @@ func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
tensor_store %result_tensor, %result: memref<2x2xi1>
return
}
// -----
// Test that assuming ops propagate memref types.
// BOTH-LABEL: func @shape_assuming_memref
func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f16>
%1 = shape.const_witness true
// BOTH: shape.assuming %{{.*}} -> (memref<?xf16>)
%2 = shape.assuming %1 -> (tensor<?xf16>) {
%3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
%4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
%5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
// BOTH: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
%7 = mhlo.maximum %5, %6 : tensor<?xf16>
// BOTH: shape.assuming_yield %{{.*}} : memref<?xf16>
shape.assuming_yield %7 : tensor<?xf16>
}
return %2 : tensor<?xf16>
}