PR #49454: [MLIR][DISC] Upgrade to use the new `reifyReturnTypeShapes` interface.
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49454

The new interface is safer to use during dialect conversion (e.g. when converting from the tensor world to the buffer world).

Copybara import of the project:

-- a6968072d59bec3c3bbaef0121d297e807c37c91 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Upgrade to use the new `reifyReturnTypeShapes` interface.

The new interface is safer to use during dialect conversion (e.g. when converting from the tensor world to the buffer world).

-- 55e7c6b7f2f99b99e226645a57e2433fae3e90ed by Wenyi Zhao <reyizero@gmail.com>:

minor fix

PiperOrigin-RevId: 375500273
This commit is contained in: parent 28c4112f35 · commit b93e54d8a4
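Before the hunks, it is worth spelling out what changed. `reifyReturnTypeShapes` (from MLIR's `InferShapedTypeOpInterface`) used to read operands off the operation itself; the new form receives them explicitly as a `ValueRange`. During dialect conversion the rewriter hands patterns the *remapped* operands, which can differ from what the op still stores, so reading `op->getOperand(...)` mid-conversion is unsafe. A minimal sketch of the two styles (free functions for illustration; only the signature shapes come from this patch):

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/Operation.h"
    #include "mlir/Support/LogicalResult.h"
    #include "llvm/ADT/SmallVector.h"

    using namespace mlir;

    // Old style: operands come from the op. Mid-conversion these may still
    // be tensor-world values that the rewriter has already scheduled for
    // replacement with buffers.
    LogicalResult reifyOldStyle(OpBuilder& builder, Operation* op,
                                SmallVectorImpl<Value>& reifiedReturnShapes) {
      Value first = op->getOperand(0);  // potentially stale during conversion
      // ... derive the result shape from `first` ...
      return success();
    }

    // New style: the caller (e.g. a ConversionPatternRewriter) supplies the
    // remapped operands, so the implementation never consults the op's own
    // operand list.
    LogicalResult reifyNewStyle(OpBuilder& builder, Operation* op,
                                ValueRange operands,
                                SmallVectorImpl<Value>& reifiedReturnShapes) {
      Value first = operands.front();  // always the values the caller intends
      // ... derive the result shape from `first` ...
      return success();
    }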
@@ -441,7 +441,7 @@ class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
     LogicalResult reifyReturnTypeShapes(OpBuilder& builder, ValueRange operands,
                                         SmallVectorImpl<Value>& reifiedReturnShapes) {
       return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
-                                                       &reifiedReturnShapes);
+                                                       operands, &reifiedReturnShapes);
     }
   }];
 }
@@ -79,7 +79,7 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
 //
 // and returns %4 as the shape value.
 LogicalResult deriveShapeFromFirstOperand(
-    OpBuilder *builder, Operation *op,
+    OpBuilder *builder, Operation *op, ValueRange operands,
     SmallVectorImpl<Value> *reifiedReturnShapes);

 // Type derivation function that returns a tensor type with a new element type.
@@ -142,6 +142,7 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
         OpBuilder& builder, ValueRange operands,
         SmallVectorImpl<Value>& reifiedReturnShapes) {
       return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
+                                                       operands,
                                                        &reifiedReturnShapes);
     }
     bool inferInputOutputShapeEquality(int input, int output) {
@@ -456,6 +457,7 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
         OpBuilder& builder, ValueRange operands,
         SmallVectorImpl<Value>& reifiedReturnShapes) {
       return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
+                                                       operands,
                                                        &reifiedReturnShapes);
     }
     bool inferInputsShapeEquality(int lhs, int rhs) {
@@ -155,11 +155,11 @@ LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
 }

 LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
-    OpBuilder& builder, Operation* op,
+    OpBuilder& builder, Operation* op, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   auto loc = op->getLoc();
-  auto lhs = op->getOperand(0);
-  auto rhs = op->getOperand(1);
+  auto lhs = operands[0];
+  auto rhs = operands[1];

   // Check for "numpy"-style rank broadcast.
   auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
@@ -204,10 +204,10 @@ LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
                                inferedReturnShapes);
 }
 LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
-    OpBuilder& builder, ValueRange,
+    OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
-                                                reifiedReturnShapes);
+                                                operands, reifiedReturnShapes);
 }

 //===----------------------------------------------------------------------===//
@@ -236,10 +236,10 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
 }

 LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
-    OpBuilder& builder, ValueRange,
+    OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
-                                                reifiedReturnShapes);
+                                                operands, reifiedReturnShapes);
 }

 //===----------------------------------------------------------------------===//
@@ -295,10 +295,10 @@ LogicalResult IsPosInfOp::inferReturnTypes(
                                  inferedReturnShapes);                       \
   }                                                                          \
   LogicalResult Op::reifyReturnTypeShapes(                                   \
-      OpBuilder& builder, ValueRange,                                        \
+      OpBuilder& builder, ValueRange operands,                               \
       SmallVectorImpl<Value>& reifiedReturnShapes) {                         \
-    return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),   \
-                                                  reifiedReturnShapes);      \
+    return ReifyBroadcastBinaryOpReturnTypeShapes(                           \
+        builder, getOperation(), operands, reifiedReturnShapes);             \
   }

 #define BROADCAST_BINARY_OP_DEFS(Op)                                         \
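The block above is a helper macro; each broadcasting binary op then stamps out its definitions with a one-line instantiation. A simplified sketch of how `BROADCAST_BINARY_OP_DEFS` gets used (the real macro also defines `inferReturnTypeComponents`; the instantiations shown are illustrative chlo broadcast ops):

    // Simplified: the macro expands to the reifyReturnTypeShapes definition
    // from the hunk above, once per op.
    #define BROADCAST_BINARY_OP_DEFS(Op)                                     \
      LogicalResult Op::reifyReturnTypeShapes(                               \
          OpBuilder& builder, ValueRange operands,                           \
          SmallVectorImpl<Value>& reifiedReturnShapes) {                     \
        return ReifyBroadcastBinaryOpReturnTypeShapes(                       \
            builder, getOperation(), operands, reifiedReturnShapes);         \
      }

    BROADCAST_BINARY_OP_DEFS(BroadcastAddOp);
    BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);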
@@ -1002,8 +1002,10 @@ LogicalResult DynamicBroadcastInDimOp::inferReturnTypeComponents(
 }

 LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
-    OpBuilder&, ValueRange, SmallVectorImpl<Value>& reifiedReturnShapes) {
-  reifiedReturnShapes.push_back(output_dimensions());
+    OpBuilder&, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  DynamicBroadcastInDimOp::Adaptor adaptor(operands);
+  reifiedReturnShapes.push_back(adaptor.output_dimensions());
   return success();
 }

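The interesting detail here is the `Adaptor`: ODS generates one per op, wrapping a raw `ValueRange` in the op's named operand accessors. That lets the implementation keep calling `output_dimensions()` while resolving it against the passed-in (remapped) values rather than the op's stored operands. A sketch of the distinction:

    #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

    using namespace mlir;
    using namespace mlir::mhlo;

    // Sketch: same accessor name, different source of truth.
    void compareAccessors(DynamicBroadcastInDimOp op, ValueRange operands) {
      // Resolves against the op's own operand list; may be stale while a
      // dialect conversion is rewriting the surrounding IR.
      Value from_op = op.output_dimensions();

      // Resolves against the explicitly supplied values, using the op's
      // operand layout to pick the right index.
      DynamicBroadcastInDimOp::Adaptor adaptor(operands);
      Value from_adaptor = adaptor.output_dimensions();
    }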
@@ -2137,7 +2139,7 @@ LogicalResult SelectOp::inferReturnTypeComponents(
 LogicalResult SelectOp::reifyReturnTypeShapes(
     OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
-  return deriveShapeFromFirstOperand(&builder, getOperation(),
+  return deriveShapeFromFirstOperand(&builder, getOperation(), operands,
                                      &reifiedReturnShapes);
 }

@@ -3278,7 +3280,7 @@ LogicalResult CompareOp::inferReturnTypeComponents(
 LogicalResult CompareOp::reifyReturnTypeShapes(
     OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
-  return deriveShapeFromFirstOperand(&builder, getOperation(),
+  return deriveShapeFromFirstOperand(&builder, getOperation(), operands,
                                      &reifiedReturnShapes);
 }

@@ -3629,9 +3631,9 @@ void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
 //===----------------------------------------------------------------------===//

 LogicalResult deriveShapeFromFirstOperand(
-    OpBuilder* builder, Operation* op,
+    OpBuilder* builder, Operation* op, ValueRange operands,
     SmallVectorImpl<Value>* reifiedReturnShapes) {
-  Value operand = op->getOperand(0);
+  Value operand = operands.front();
   ShapedType operand_type = operand.getType().dyn_cast<ShapedType>();
   if (!operand_type) {
     op->emitOpError() << "first operand is not a shaped type";
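The hunk is cut off before the body completes. For orientation, a reification along these lines produces the `shape.shape_of` that the updated tests below check for; this is a hedged sketch consistent with the diff, not the verbatim remainder of the function:

    #include "mlir/Dialect/Shape/IR/Shape.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/Operation.h"
    #include "llvm/ADT/SmallVector.h"

    using namespace mlir;

    // Sketch: derive the result shape of an op from its first operand by
    // emitting shape.shape_of on the remapped value.
    LogicalResult deriveShapeSketch(
        OpBuilder* builder, Operation* op, ValueRange operands,
        SmallVectorImpl<Value>* reifiedReturnShapes) {
      Value operand = operands.front();
      auto operand_type = operand.getType().dyn_cast<ShapedType>();
      if (!operand_type)
        return op->emitOpError() << "first operand is not a shaped type";
      // For a ranked tensor<?x?xf32> this yields a tensor<2xindex> value.
      Value shape = builder->create<shape::ShapeOfOp>(op->getLoc(), operand);
      reifiedReturnShapes->push_back(shape);
      return success();
    }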
@@ -94,6 +94,8 @@ Value InsertAlloc(Location loc, OpResult result,
 /// to the `results` vector.
 LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
                              ConversionPatternRewriter& rewriter) {
+  size_t num_operands = results.size();
+  SmallVector<Value, 2> tensor_operands;
   for (auto result : llvm::enumerate(op->getResults())) {
     RankedTensorType resultType =
         result.value().getType().dyn_cast<RankedTensorType>();
@@ -106,9 +108,19 @@ LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
     auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
     if (!shape_type_op) return failure();

+    if (tensor_operands.empty()) {
+      for (auto operand : ArrayRef<Value>(results).take_front(num_operands)) {
+        auto tp = operand.getType().cast<ShapedType>();
+        tensor_operands.push_back(rewriter.create<memref::TensorLoadOp>(
+            op->getLoc(),
+            RankedTensorType::get(tp.getShape(), tp.getElementType()),
+            operand));
+      }
+    }
+
     SmallVector<Value, 1> results_shape;
-    auto status = shape_type_op.reifyReturnTypeShapes(
-        rewriter, shape_type_op->getOperands(), results_shape);
+    auto status = shape_type_op.reifyReturnTypeShapes(rewriter, tensor_operands,
+                                                      results_shape);
     if (failed(status)) return failure();
     results.push_back(
         InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
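The `num_operands`/`tensor_operands` bookkeeping from the previous hunk matters here: `results` arrives pre-seeded with the already-bufferized operands, so its first `num_operands` entries are memrefs, while `reifyReturnTypeShapes` wants tensor-world values. The loop therefore re-materializes tensors from those buffers. Condensed into a standalone helper (name hypothetical; the logic mirrors the loop body above):

    #include "mlir/Dialect/MemRef/IR/MemRef.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinTypes.h"

    using namespace mlir;

    // Hypothetical helper: wrap an already-converted buffer back into a
    // ranked tensor so shape reification can consume it. Canonicalization
    // later folds these memref.tensor_load ops away (see the RUN-line
    // change below).
    static Value materializeTensor(OpBuilder& b, Location loc, Value buffer) {
      auto tp = buffer.getType().cast<ShapedType>();  // a MemRefType here
      return b.create<memref::TensorLoadOp>(
          loc, RankedTensorType::get(tp.getShape(), tp.getElementType()),
          buffer);
    }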
@@ -391,8 +403,8 @@ struct HloToLhloDotGeneralOpConverter
     } else {
       SmallVector<Value, 1> results_shape;
       auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
-      if (failed(shape_type_op.reifyReturnTypeShapes(
-              rewriter, shape_type_op->getOperands(), results_shape)))
+      if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, operands,
+                                                     results_shape)))
         return failure();

       bufferArgs[2] = InsertDynamicAllocAndDealloc(
@@ -585,9 +597,9 @@ struct HloLegalizeToLhlo
     target.addLegalDialect<shape::ShapeDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
     target.addIllegalDialect<mhlo::MhloDialect>();
-    // Declare tensor_load and tensor_store illegal.
-    target.addIllegalOp<mlir::memref::TensorLoadOp,
-                        mlir::memref::TensorStoreOp>();
+    // Declare tensor_store illegal. tensor_load may be used to reify output
+    // shape computation during dialect conversion and will be handled later.
+    target.addIllegalOp<mlir::memref::TensorStoreOp>();
     // buffer_cast is illegal if it has uses.
     // TODO(b/175670649) Make buffer_cast illegal.
     target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
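The hunk's last context line is truncated by the diff window. Based on the "buffer_cast is illegal if it has uses" comment, the call it opens plausibly completes as sketched below; the predicate body is an assumption, not part of this diff:

    #include "mlir/Dialect/MemRef/IR/MemRef.h"
    #include "mlir/Transforms/DialectConversion.h"

    // Assumed completion (predicate inferred from the comment above): the
    // op counts as legal exactly when nothing uses it anymore.
    void configureBufferCastLegality(mlir::ConversionTarget& target) {
      target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
          [](mlir::memref::BufferCastOp op) { return op->use_empty(); });
    }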
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -canonicalize %s -o - | FileCheck %s

 // CHECK-LABEL: func @func_op_unranked_arg_result
 func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
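The only change to this test is the extra `-canonicalize` stage in the RUN line: the conversion now intentionally leaves behind `memref.tensor_load` ops created for shape reification, and canonicalization folds them away. This is also why the `shape.shape_of` checks in the following hunks now match a tensor operand instead of a memref. Programmatically the added stage corresponds to a sketch like this (other passes elided):

    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"

    // Sketch: append the canonicalizer, mirroring the new -canonicalize
    // flag in the RUN line.
    void addCleanup(mlir::PassManager& pm) {
      pm.addPass(mlir::createCanonicalizerPass());
    }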
@@ -466,7 +466,7 @@ func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
 func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %result = "mhlo.add"(%lhs, %rhs)
       : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[INPUT:.*]] : tensor<?x?xf32> -> tensor<2xindex>
   // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
   // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
   // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
@@ -482,7 +482,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
 func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %result = "mhlo.tanh"(%arg0)
       : (tensor<?x?xf32>) -> tensor<?x?xf32>
-  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[INPUT:.*]] : tensor<?x?xf32> -> tensor<2xindex>
   // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
   // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
   // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])