PR #49454: [MLIR][DISC] Upgrade to use the new `reifyReturnTypeShapes` interface.

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49454

The new interface is more safe to be used during dialect conversion
(e.g. converting from tensor world to buffer world).
Copybara import of the project:

--
a6968072d59bec3c3bbaef0121d297e807c37c91 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Upgrade to use the new `reifyReturnTypeShapes` interface.

The new interface is more safe to be used during dialect conversion
(e.g. converting from tensor world to buffer world).

--
55e7c6b7f2f99b99e226645a57e2433fae3e90ed by Wenyi Zhao <reyizero@gmail.com>:

minor fix

PiperOrigin-RevId: 375500273
This commit is contained in:
wyzhao 2021-05-24 10:10:21 -07:00 committed by TensorFlow MLIR Team
parent 28c4112f35
commit b93e54d8a4
8 changed files with 44 additions and 28 deletions

View File

@ -441,7 +441,7 @@ class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
LogicalResult reifyReturnTypeShapes(OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes);
operands, &reifiedReturnShapes);
}
}];
}

View File

@ -79,7 +79,7 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
//
// and returns %4 as the shape value.
LogicalResult deriveShapeFromFirstOperand(
OpBuilder *builder, Operation *op,
OpBuilder *builder, Operation *op, ValueRange operands,
SmallVectorImpl<Value> *reifiedReturnShapes);
// Type derivation function that returns a tensor type with a new element type.

View File

@ -142,6 +142,7 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
operands,
&reifiedReturnShapes);
}
bool inferInputOutputShapeEquality(int input, int output) {
@ -456,6 +457,7 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
operands,
&reifiedReturnShapes);
}
bool inferInputsShapeEquality(int lhs, int rhs) {

View File

@ -155,11 +155,11 @@ LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
}
LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
OpBuilder& builder, Operation* op,
OpBuilder& builder, Operation* op, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
auto loc = op->getLoc();
auto lhs = op->getOperand(0);
auto rhs = op->getOperand(1);
auto lhs = operands[0];
auto rhs = operands[1];
// Check for "numpy"-style rank broadcast.
auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
@ -204,10 +204,10 @@ LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
inferedReturnShapes);
}
LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange,
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
reifiedReturnShapes);
operands, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
@ -236,10 +236,10 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
}
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange,
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
reifiedReturnShapes);
operands, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
@ -295,10 +295,10 @@ LogicalResult IsPosInfOp::inferReturnTypes(
inferedReturnShapes); \
} \
LogicalResult Op::reifyReturnTypeShapes( \
OpBuilder& builder, ValueRange, \
OpBuilder& builder, ValueRange operands, \
SmallVectorImpl<Value>& reifiedReturnShapes) { \
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), \
reifiedReturnShapes); \
return ReifyBroadcastBinaryOpReturnTypeShapes( \
builder, getOperation(), operands, reifiedReturnShapes); \
}
#define BROADCAST_BINARY_OP_DEFS(Op) \

View File

@ -1002,8 +1002,10 @@ LogicalResult DynamicBroadcastInDimOp::inferReturnTypeComponents(
}
LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
OpBuilder&, ValueRange, SmallVectorImpl<Value>& reifiedReturnShapes) {
reifiedReturnShapes.push_back(output_dimensions());
OpBuilder&, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
DynamicBroadcastInDimOp::Adaptor adaptor(operands);
reifiedReturnShapes.push_back(adaptor.output_dimensions());
return success();
}
@ -2137,7 +2139,7 @@ LogicalResult SelectOp::inferReturnTypeComponents(
LogicalResult SelectOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return deriveShapeFromFirstOperand(&builder, getOperation(),
return deriveShapeFromFirstOperand(&builder, getOperation(), operands,
&reifiedReturnShapes);
}
@ -3278,7 +3280,7 @@ LogicalResult CompareOp::inferReturnTypeComponents(
LogicalResult CompareOp::reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return deriveShapeFromFirstOperand(&builder, getOperation(),
return deriveShapeFromFirstOperand(&builder, getOperation(), operands,
&reifiedReturnShapes);
}
@ -3629,9 +3631,9 @@ void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
//===----------------------------------------------------------------------===//
LogicalResult deriveShapeFromFirstOperand(
OpBuilder* builder, Operation* op,
OpBuilder* builder, Operation* op, ValueRange operands,
SmallVectorImpl<Value>* reifiedReturnShapes) {
Value operand = op->getOperand(0);
Value operand = operands.front();
ShapedType operand_type = operand.getType().dyn_cast<ShapedType>();
if (!operand_type) {
op->emitOpError() << "first operand is not a shaped type";

View File

@ -94,6 +94,8 @@ Value InsertAlloc(Location loc, OpResult result,
/// to the `results` vector.
LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
ConversionPatternRewriter& rewriter) {
size_t num_operands = results.size();
SmallVector<Value, 2> tensor_operands;
for (auto result : llvm::enumerate(op->getResults())) {
RankedTensorType resultType =
result.value().getType().dyn_cast<RankedTensorType>();
@ -106,9 +108,19 @@ LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
if (tensor_operands.empty()) {
for (auto operand : ArrayRef<Value>(results).take_front(num_operands)) {
auto tp = operand.getType().cast<ShapedType>();
tensor_operands.push_back(rewriter.create<memref::TensorLoadOp>(
op->getLoc(),
RankedTensorType::get(tp.getShape(), tp.getElementType()),
operand));
}
}
SmallVector<Value, 1> results_shape;
auto status = shape_type_op.reifyReturnTypeShapes(
rewriter, shape_type_op->getOperands(), results_shape);
auto status = shape_type_op.reifyReturnTypeShapes(rewriter, tensor_operands,
results_shape);
if (failed(status)) return failure();
results.push_back(
InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
@ -391,8 +403,8 @@ struct HloToLhloDotGeneralOpConverter
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (failed(shape_type_op.reifyReturnTypeShapes(
rewriter, shape_type_op->getOperands(), results_shape)))
if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, operands,
results_shape)))
return failure();
bufferArgs[2] = InsertDynamicAllocAndDealloc(
@ -585,9 +597,9 @@ struct HloLegalizeToLhlo
target.addLegalDialect<shape::ShapeDialect>();
target.addLegalDialect<tensor::TensorDialect>();
target.addIllegalDialect<mhlo::MhloDialect>();
// Declare tensor_load and tensor_store illegal.
target.addIllegalOp<mlir::memref::TensorLoadOp,
mlir::memref::TensorStoreOp>();
// Declare tensor_store illegal. tensor_load may be used to reify output
// shape computation during dialect conversion and will be handled later.
target.addIllegalOp<mlir::memref::TensorStoreOp>();
// buffer_cast is illegal if it has uses.
// TODO(b/175670649) Make buffer_cast illegal.
target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -canonicalize %s -o - | FileCheck %s
// CHECK-LABEL: func @func_op_unranked_arg_result
func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {

View File

@ -466,7 +466,7 @@ func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
%result = "mhlo.add"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[INPUT:.*]] : tensor<?x?xf32> -> tensor<2xindex>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
@ -482,7 +482,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%result = "mhlo.tanh"(%arg0)
: (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[INPUT:.*]] : tensor<?x?xf32> -> tensor<2xindex>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])