Support hlo to lhlo buffer placement through shape.assuming ops.
PiperOrigin-RevId: 336287728
parent d986bd7ad7
commit f6af1fc134
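Before this change, bufferization stopped at shape.assuming boundaries, because the op's results and its shape.assuming_yield terminator stay tensor-typed. The change converts those types along with everything else. A minimal before/after sketch of the intended rewrite (illustrative function and value names, simplified from the test added below):

// Before bufferization: the assuming region produces a tensor.
func @sketch(%arg0: tensor<?xf16>, %w: !shape.witness) -> tensor<?xf16> {
  %0 = shape.assuming %w -> (tensor<?xf16>) {
    shape.assuming_yield %arg0 : tensor<?xf16>
  }
  return %0 : tensor<?xf16>
}

// After: result and yield types become memrefs, so buffer placement can
// proceed across the region boundary.
func @sketch(%arg0: memref<?xf16>, %w: !shape.witness) -> memref<?xf16> {
  %0 = shape.assuming %w -> (memref<?xf16>) {
    shape.assuming_yield %arg0 : memref<?xf16>
  }
  return %0 : memref<?xf16>
}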
@@ -20,6 +20,8 @@ limitations under the License.
 #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
@@ -448,6 +450,10 @@ struct HloLegalizeToLhlo
       return std::all_of(op.operand_type_begin(), op.operand_type_end(),
                          isMemRefType);
     });
+    target.addDynamicallyLegalOp<shape::AssumingOp>([&](shape::AssumingOp op) {
+      return std::all_of(op.result_type_begin(), op.result_type_end(),
+                         isMemRefType);
+    });
 
     auto kind = results_escape_function
                     ? BufferAssignmentTypeConverter::KeepAsFunctionResult
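The new dynamic legality rule mirrors the operand-based checks above it: shape.assuming counts as legal only once std::all_of finds a memref type in every result position, so partially converted ops remain targets of the partial conversion. A hypothetical mixed-type example (illustrative names) that would still be illegal under this rule:

func @mixed(%m: memref<?xf16>, %t: tensor<?xi32>, %w: !shape.witness)
    -> (memref<?xf16>, tensor<?xi32>) {
  // Still illegal: the second result type has not been bufferized yet.
  %0:2 = shape.assuming %w -> (memref<?xf16>, tensor<?xi32>) {
    shape.assuming_yield %m, %t : memref<?xf16>, tensor<?xi32>
  }
  return %0#0, %0#1 : memref<?xf16>, tensor<?xi32>
}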
@@ -460,6 +466,7 @@ struct HloLegalizeToLhlo
     populateWithBufferAssignmentOpConversionPatterns<
         mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,
                                                        &patterns);
+    populateShapeTypeConversionPatterns(&context, &converter, &patterns);
     if (failed(applyPartialConversion(getOperation(), target, patterns)))
       signalPassFailure();
   }
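populateShapeTypeConversionPatterns supplies the patterns that actually rewrite shape.assuming and shape.assuming_yield to the converted types, complementing the hlo-to-lhlo patterns that turn tensor-level mhlo ops into lmhlo ops writing through an explicit output buffer. The latter shape of the lowering, as exercised by the test below (an illustrative sketch, not the pass's exact output; %d stands for the operands' dynamic dimension):

// Tensor form inside the region:
//   %7 = mhlo.maximum %5, %6 : tensor<?xf16>
// Buffer form: the result becomes an explicit output operand.
func @max_sketch(%a: memref<?xf16>, %b: memref<?xf16>, %d: index) {
  %out = alloc(%d) : memref<?xf16>
  "lmhlo.maximum"(%a, %b, %out) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
  return
}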
@@ -612,3 +612,24 @@ func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
   tensor_store %result_tensor, %result: memref<2x2xi1>
   return
 }
+
+// -----
+
+// Test that assuming ops propagate memref types.
+// BOTH-LABEL: func @shape_assuming_memref
+func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f16>
+  %1 = shape.const_witness true
+  // BOTH: shape.assuming %{{.*}} -> (memref<?xf16>)
+  %2 = shape.assuming %1 -> (tensor<?xf16>) {
+    %3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
+    %4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
+    %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
+    %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
+    // BOTH: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
+    %7 = mhlo.maximum %5, %6 : tensor<?xf16>
+    // BOTH: shape.assuming_yield %{{.*}} : memref<?xf16>
+    shape.assuming_yield %7 : tensor<?xf16>
+  }
+  return %2 : tensor<?xf16>
+}
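The CHECK lines use the BOTH prefix, which suggests the test file is run through the pass under more than one configuration (for example with and without results-escape-function) and these checks apply under both. A hedged sketch of a typical driver line for such a test; the exact flags and prefixes in the real file may differ:

// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s | FileCheck --check-prefixes=BOTH %s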