PR #49454: [MLIR][DISC] Upgrade to use the new `reifyReturnTypeShapes` interface.

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49454

The new interface is safer to use during dialect conversion
(e.g. converting from the tensor world to the buffer world).
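To make the motivation concrete, here is a minimal sketch of the contract change, not the actual mhlo code (the helper names and the use of shape.shape_of are assumptions for illustration). During bufferization an op's own operands may already have been rewritten to memrefs, so a reify hook that reads them emits buffer-typed shape IR; a hook that uses the caller-supplied operands can instead be handed freshly materialized tensor values:

#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Before: reads through the op. Mid-conversion, op->getOperand(0) may
// already be a memref-typed value, so the reified shape.shape_of is
// computed on a buffer and the resulting IR is invalid in the tensor world.
LogicalResult reifyViaOp(Operation *op, OpBuilder &b,
                         SmallVectorImpl<Value> &shapes) {
  shapes.push_back(
      b.create<shape::ShapeOfOp>(op->getLoc(), op->getOperand(0)));
  return success();
}

// After: uses the operands handed in by the caller. A conversion pattern
// can pass tensor-typed values it materialized itself, so the shape
// computation stays in the tensor world regardless of conversion progress.
LogicalResult reifyViaOperands(Operation *op, OpBuilder &b,
                               ValueRange operands,
                               SmallVectorImpl<Value> &shapes) {
  shapes.push_back(
      b.create<shape::ShapeOfOp>(op->getLoc(), operands.front()));
  return success();
}

This is why every implementation changed below threads the `operands` argument through instead of calling `op->getOperand(...)`.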
Copybara import of the project:

--
a6968072d59bec3c3bbaef0121d297e807c37c91 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Upgrade to use the new `reifyReturnTypeShapes` interface.

The new interface is safer to use during dialect conversion
(e.g. converting from the tensor world to the buffer world).

--
55e7c6b7f2f99b99e226645a57e2433fae3e90ed by Wenyi Zhao <reyizero@gmail.com>:

minor fix

PiperOrigin-RevId: 375500273
Commit b93e54d8a4 (parent 28c4112f35)
Author: wyzhao, 2021-05-24 10:10:21 -07:00; committed by TensorFlow MLIR Team
8 changed files with 44 additions and 28 deletions


@@ -441,7 +441,7 @@ class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
     LogicalResult reifyReturnTypeShapes(OpBuilder& builder, ValueRange operands,
         SmallVectorImpl<Value>& reifiedReturnShapes) {
       return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
-                                                       &reifiedReturnShapes);
+                                                       operands, &reifiedReturnShapes);
     }
   }];
 }


@@ -79,7 +79,7 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
 //
 // and returns %4 as the shape value.
 LogicalResult deriveShapeFromFirstOperand(
-    OpBuilder *builder, Operation *op,
+    OpBuilder *builder, Operation *op, ValueRange operands,
     SmallVectorImpl<Value> *reifiedReturnShapes);

 // Type derivation function that returns a tensor type with a new element type.


@@ -142,6 +142,7 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
         OpBuilder& builder, ValueRange operands,
         SmallVectorImpl<Value>& reifiedReturnShapes) {
       return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
+                                                       operands,
                                                        &reifiedReturnShapes);
     }
     bool inferInputOutputShapeEquality(int input, int output) {
@@ -456,6 +457,7 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
         OpBuilder& builder, ValueRange operands,
         SmallVectorImpl<Value>& reifiedReturnShapes) {
       return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
+                                                       operands,
                                                        &reifiedReturnShapes);
     }
     bool inferInputsShapeEquality(int lhs, int rhs) {


@@ -155,11 +155,11 @@ LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
 }

 LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
-    OpBuilder& builder, Operation* op,
+    OpBuilder& builder, Operation* op, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   auto loc = op->getLoc();
-  auto lhs = op->getOperand(0);
-  auto rhs = op->getOperand(1);
+  auto lhs = operands[0];
+  auto rhs = operands[1];

   // Check for "numpy"-style rank broadcast.
   auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
@@ -204,10 +204,10 @@ LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
                                                inferedReturnShapes);
 }

 LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
-    OpBuilder& builder, ValueRange,
+    OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
-                                                reifiedReturnShapes);
+                                                operands, reifiedReturnShapes);
 }

 //===----------------------------------------------------------------------===//
@@ -236,10 +236,10 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
                                                inferedReturnShapes);
 }

 LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
-    OpBuilder& builder, ValueRange,
+    OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
-                                                reifiedReturnShapes);
+                                                operands, reifiedReturnShapes);
 }

 //===----------------------------------------------------------------------===//
@@ -295,10 +295,10 @@ LogicalResult IsPosInfOp::inferReturnTypes(
                                          inferedReturnShapes);                \
   }                                                                           \
   LogicalResult Op::reifyReturnTypeShapes(                                    \
-      OpBuilder& builder, ValueRange,                                         \
+      OpBuilder& builder, ValueRange operands,                                \
       SmallVectorImpl<Value>& reifiedReturnShapes) {                          \
-    return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),    \
-                                                  reifiedReturnShapes);       \
+    return ReifyBroadcastBinaryOpReturnTypeShapes(                            \
+        builder, getOperation(), operands, reifiedReturnShapes);              \
   }

 #define BROADCAST_BINARY_OP_DEFS(Op)                                          \


@@ -1002,8 +1002,10 @@ LogicalResult DynamicBroadcastInDimOp::inferReturnTypeComponents(
 }

 LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
-    OpBuilder&, ValueRange, SmallVectorImpl<Value>& reifiedReturnShapes) {
-  reifiedReturnShapes.push_back(output_dimensions());
+    OpBuilder&, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  DynamicBroadcastInDimOp::Adaptor adaptor(operands);
+  reifiedReturnShapes.push_back(adaptor.output_dimensions());
   return success();
 }

@@ -2137,7 +2139,7 @@ LogicalResult SelectOp::inferReturnTypeComponents(
 LogicalResult SelectOp::reifyReturnTypeShapes(
     OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
-  return deriveShapeFromFirstOperand(&builder, getOperation(),
+  return deriveShapeFromFirstOperand(&builder, getOperation(), operands,
                                      &reifiedReturnShapes);
 }

@@ -3278,7 +3280,7 @@ LogicalResult CompareOp::inferReturnTypeComponents(
 LogicalResult CompareOp::reifyReturnTypeShapes(
     OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
-  return deriveShapeFromFirstOperand(&builder, getOperation(),
+  return deriveShapeFromFirstOperand(&builder, getOperation(), operands,
                                      &reifiedReturnShapes);
 }

@@ -3629,9 +3631,9 @@ void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
 //===----------------------------------------------------------------------===//

 LogicalResult deriveShapeFromFirstOperand(
-    OpBuilder* builder, Operation* op,
+    OpBuilder* builder, Operation* op, ValueRange operands,
     SmallVectorImpl<Value>* reifiedReturnShapes) {
-  Value operand = op->getOperand(0);
+  Value operand = operands.front();
   ShapedType operand_type = operand.getType().dyn_cast<ShapedType>();
   if (!operand_type) {
     op->emitOpError() << "first operand is not a shaped type";


@@ -94,6 +94,8 @@ Value InsertAlloc(Location loc, OpResult result,
 /// to the `results` vector.
 LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
                              ConversionPatternRewriter& rewriter) {
+  size_t num_operands = results.size();
+  SmallVector<Value, 2> tensor_operands;
   for (auto result : llvm::enumerate(op->getResults())) {
     RankedTensorType resultType =
         result.value().getType().dyn_cast<RankedTensorType>();
@@ -106,9 +108,19 @@ LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
     auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
     if (!shape_type_op) return failure();

+    if (tensor_operands.empty()) {
+      for (auto operand : ArrayRef<Value>(results).take_front(num_operands)) {
+        auto tp = operand.getType().cast<ShapedType>();
+        tensor_operands.push_back(rewriter.create<memref::TensorLoadOp>(
+            op->getLoc(),
+            RankedTensorType::get(tp.getShape(), tp.getElementType()),
+            operand));
+      }
+    }
+
     SmallVector<Value, 1> results_shape;
-    auto status = shape_type_op.reifyReturnTypeShapes(
-        rewriter, shape_type_op->getOperands(), results_shape);
+    auto status = shape_type_op.reifyReturnTypeShapes(rewriter, tensor_operands,
+                                                      results_shape);
     if (failed(status)) return failure();
     results.push_back(
         InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
@@ -391,8 +403,8 @@ struct HloToLhloDotGeneralOpConverter
     } else {
       SmallVector<Value, 1> results_shape;
       auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
-      if (failed(shape_type_op.reifyReturnTypeShapes(
-              rewriter, shape_type_op->getOperands(), results_shape)))
+      if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, operands,
+                                                     results_shape)))
         return failure();

       bufferArgs[2] = InsertDynamicAllocAndDealloc(
@@ -585,9 +597,9 @@ struct HloLegalizeToLhlo
     target.addLegalDialect<shape::ShapeDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
     target.addIllegalDialect<mhlo::MhloDialect>();
-    // Declare tensor_load and tensor_store illegal.
-    target.addIllegalOp<mlir::memref::TensorLoadOp,
-                        mlir::memref::TensorStoreOp>();
+    // Declare tensor_store illegal. tensor_load may be used to reify output
+    // shape computation during dialect conversion and will be handled later.
+    target.addIllegalOp<mlir::memref::TensorStoreOp>();
     // buffer_cast is illegal if it has uses.
     // TODO(b/175670649) Make buffer_cast illegal.
     target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
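The caller-side pattern in ConvertResults above is the crux of the change: before invoking the reify hook, the pass rebuilds tensor-typed operands from the already-bufferized values. A simplified consolidation of that pattern (a sketch; `ReifyWithTensorOperands` and `buffer_operands` are names invented for the example, and error handling from the surrounding pattern is omitted):

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

LogicalResult ReifyWithTensorOperands(Operation *op,
                                      ValueRange buffer_operands,
                                      ConversionPatternRewriter &rewriter,
                                      SmallVectorImpl<Value> &results_shape) {
  auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
  if (!shape_type_op) return failure();

  // Read each buffer back as a tensor so the reify hook sees the tensor
  // types it expects. Assumes ranked operands, as in the pass above.
  SmallVector<Value, 2> tensor_operands;
  for (Value operand : buffer_operands) {
    auto tp = operand.getType().cast<ShapedType>();
    tensor_operands.push_back(rewriter.create<memref::TensorLoadOp>(
        op->getLoc(),
        RankedTensorType::get(tp.getShape(), tp.getElementType()), operand));
  }
  return shape_type_op.reifyReturnTypeShapes(rewriter, tensor_operands,
                                             results_shape);
}

The inserted tensor_load ops exist only to feed the shape computation; presumably the extra -canonicalize in the updated RUN line below is what folds them away once bufferization completes.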


@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -canonicalize %s -o - | FileCheck %s

 // CHECK-LABEL: func @func_op_unranked_arg_result
 func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {


@@ -466,7 +466,7 @@ func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
 func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %result = "mhlo.add"(%lhs, %rhs)
       : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[INPUT:.*]] : tensor<?x?xf32> -> tensor<2xindex>
   // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
   // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
   // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
@@ -482,7 +482,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
 func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %result = "mhlo.tanh"(%arg0)
       : (tensor<?x?xf32>) -> tensor<?x?xf32>
-  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[INPUT:.*]] : tensor<?x?xf32> -> tensor<2xindex>
   // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
   // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
   // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])