PR #49454: [MLIR][DISC] Upgrade to use the new `reifyReturnTypeShapes` interface.
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49454 The new interface is more safe to be used during dialect conversion (e.g. converting from tensor world to buffer world). Copybara import of the project: -- a6968072d59bec3c3bbaef0121d297e807c37c91 by Wenyi Zhao <reyizero@gmail.com>: [MLIR][DISC] Upgrade to use the new `reifyReturnTypeShapes` interface. The new interface is more safe to be used during dialect conversion (e.g. converting from tensor world to buffer world). -- 55e7c6b7f2f99b99e226645a57e2433fae3e90ed by Wenyi Zhao <reyizero@gmail.com>: minor fix PiperOrigin-RevId: 375500273
This commit is contained in:
parent
28c4112f35
commit
b93e54d8a4
|
@ -441,7 +441,7 @@ class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||||
LogicalResult reifyReturnTypeShapes(OpBuilder& builder, ValueRange operands,
|
LogicalResult reifyReturnTypeShapes(OpBuilder& builder, ValueRange operands,
|
||||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||||
&reifiedReturnShapes);
|
operands, &reifiedReturnShapes);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,7 +79,7 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
|
||||||
//
|
//
|
||||||
// and returns %4 as the shape value.
|
// and returns %4 as the shape value.
|
||||||
LogicalResult deriveShapeFromFirstOperand(
|
LogicalResult deriveShapeFromFirstOperand(
|
||||||
OpBuilder *builder, Operation *op,
|
OpBuilder *builder, Operation *op, ValueRange operands,
|
||||||
SmallVectorImpl<Value> *reifiedReturnShapes);
|
SmallVectorImpl<Value> *reifiedReturnShapes);
|
||||||
|
|
||||||
// Type derivation function that returns a tensor type with a new element type.
|
// Type derivation function that returns a tensor type with a new element type.
|
||||||
|
|
|
@ -142,6 +142,7 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||||
OpBuilder& builder, ValueRange operands,
|
OpBuilder& builder, ValueRange operands,
|
||||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||||
|
operands,
|
||||||
&reifiedReturnShapes);
|
&reifiedReturnShapes);
|
||||||
}
|
}
|
||||||
bool inferInputOutputShapeEquality(int input, int output) {
|
bool inferInputOutputShapeEquality(int input, int output) {
|
||||||
|
@ -456,6 +457,7 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
|
||||||
OpBuilder& builder, ValueRange operands,
|
OpBuilder& builder, ValueRange operands,
|
||||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||||
|
operands,
|
||||||
&reifiedReturnShapes);
|
&reifiedReturnShapes);
|
||||||
}
|
}
|
||||||
bool inferInputsShapeEquality(int lhs, int rhs) {
|
bool inferInputsShapeEquality(int lhs, int rhs) {
|
||||||
|
|
|
@ -155,11 +155,11 @@ LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
|
LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
|
||||||
OpBuilder& builder, Operation* op,
|
OpBuilder& builder, Operation* op, ValueRange operands,
|
||||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
auto lhs = op->getOperand(0);
|
auto lhs = operands[0];
|
||||||
auto rhs = op->getOperand(1);
|
auto rhs = operands[1];
|
||||||
|
|
||||||
// Check for "numpy"-style rank broadcast.
|
// Check for "numpy"-style rank broadcast.
|
||||||
auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
|
auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
|
||||||
|
@ -204,10 +204,10 @@ LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
|
||||||
inferedReturnShapes);
|
inferedReturnShapes);
|
||||||
}
|
}
|
||||||
LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
|
LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
|
||||||
OpBuilder& builder, ValueRange,
|
OpBuilder& builder, ValueRange operands,
|
||||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
|
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
|
||||||
reifiedReturnShapes);
|
operands, reifiedReturnShapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -236,10 +236,10 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
|
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
|
||||||
OpBuilder& builder, ValueRange,
|
OpBuilder& builder, ValueRange operands,
|
||||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
|
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
|
||||||
reifiedReturnShapes);
|
operands, reifiedReturnShapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -295,10 +295,10 @@ LogicalResult IsPosInfOp::inferReturnTypes(
|
||||||
inferedReturnShapes); \
|
inferedReturnShapes); \
|
||||||
} \
|
} \
|
||||||
LogicalResult Op::reifyReturnTypeShapes( \
|
LogicalResult Op::reifyReturnTypeShapes( \
|
||||||
OpBuilder& builder, ValueRange, \
|
OpBuilder& builder, ValueRange operands, \
|
||||||
SmallVectorImpl<Value>& reifiedReturnShapes) { \
|
SmallVectorImpl<Value>& reifiedReturnShapes) { \
|
||||||
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), \
|
return ReifyBroadcastBinaryOpReturnTypeShapes( \
|
||||||
reifiedReturnShapes); \
|
builder, getOperation(), operands, reifiedReturnShapes); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define BROADCAST_BINARY_OP_DEFS(Op) \
|
#define BROADCAST_BINARY_OP_DEFS(Op) \
|
||||||
|
|
|
@ -1002,8 +1002,10 @@ LogicalResult DynamicBroadcastInDimOp::inferReturnTypeComponents(
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
|
LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
|
||||||
OpBuilder&, ValueRange, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
OpBuilder&, ValueRange operands,
|
||||||
reifiedReturnShapes.push_back(output_dimensions());
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
|
DynamicBroadcastInDimOp::Adaptor adaptor(operands);
|
||||||
|
reifiedReturnShapes.push_back(adaptor.output_dimensions());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2137,7 +2139,7 @@ LogicalResult SelectOp::inferReturnTypeComponents(
|
||||||
LogicalResult SelectOp::reifyReturnTypeShapes(
|
LogicalResult SelectOp::reifyReturnTypeShapes(
|
||||||
OpBuilder& builder, ValueRange operands,
|
OpBuilder& builder, ValueRange operands,
|
||||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
return deriveShapeFromFirstOperand(&builder, getOperation(),
|
return deriveShapeFromFirstOperand(&builder, getOperation(), operands,
|
||||||
&reifiedReturnShapes);
|
&reifiedReturnShapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3278,7 +3280,7 @@ LogicalResult CompareOp::inferReturnTypeComponents(
|
||||||
LogicalResult CompareOp::reifyReturnTypeShapes(
|
LogicalResult CompareOp::reifyReturnTypeShapes(
|
||||||
OpBuilder& builder, ValueRange operands,
|
OpBuilder& builder, ValueRange operands,
|
||||||
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
return deriveShapeFromFirstOperand(&builder, getOperation(),
|
return deriveShapeFromFirstOperand(&builder, getOperation(), operands,
|
||||||
&reifiedReturnShapes);
|
&reifiedReturnShapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3629,9 +3631,9 @@ void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult deriveShapeFromFirstOperand(
|
LogicalResult deriveShapeFromFirstOperand(
|
||||||
OpBuilder* builder, Operation* op,
|
OpBuilder* builder, Operation* op, ValueRange operands,
|
||||||
SmallVectorImpl<Value>* reifiedReturnShapes) {
|
SmallVectorImpl<Value>* reifiedReturnShapes) {
|
||||||
Value operand = op->getOperand(0);
|
Value operand = operands.front();
|
||||||
ShapedType operand_type = operand.getType().dyn_cast<ShapedType>();
|
ShapedType operand_type = operand.getType().dyn_cast<ShapedType>();
|
||||||
if (!operand_type) {
|
if (!operand_type) {
|
||||||
op->emitOpError() << "first operand is not a shaped type";
|
op->emitOpError() << "first operand is not a shaped type";
|
||||||
|
|
|
@ -94,6 +94,8 @@ Value InsertAlloc(Location loc, OpResult result,
|
||||||
/// to the `results` vector.
|
/// to the `results` vector.
|
||||||
LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
|
LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
|
||||||
ConversionPatternRewriter& rewriter) {
|
ConversionPatternRewriter& rewriter) {
|
||||||
|
size_t num_operands = results.size();
|
||||||
|
SmallVector<Value, 2> tensor_operands;
|
||||||
for (auto result : llvm::enumerate(op->getResults())) {
|
for (auto result : llvm::enumerate(op->getResults())) {
|
||||||
RankedTensorType resultType =
|
RankedTensorType resultType =
|
||||||
result.value().getType().dyn_cast<RankedTensorType>();
|
result.value().getType().dyn_cast<RankedTensorType>();
|
||||||
|
@ -106,9 +108,19 @@ LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
|
||||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||||
if (!shape_type_op) return failure();
|
if (!shape_type_op) return failure();
|
||||||
|
|
||||||
|
if (tensor_operands.empty()) {
|
||||||
|
for (auto operand : ArrayRef<Value>(results).take_front(num_operands)) {
|
||||||
|
auto tp = operand.getType().cast<ShapedType>();
|
||||||
|
tensor_operands.push_back(rewriter.create<memref::TensorLoadOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(tp.getShape(), tp.getElementType()),
|
||||||
|
operand));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<Value, 1> results_shape;
|
SmallVector<Value, 1> results_shape;
|
||||||
auto status = shape_type_op.reifyReturnTypeShapes(
|
auto status = shape_type_op.reifyReturnTypeShapes(rewriter, tensor_operands,
|
||||||
rewriter, shape_type_op->getOperands(), results_shape);
|
results_shape);
|
||||||
if (failed(status)) return failure();
|
if (failed(status)) return failure();
|
||||||
results.push_back(
|
results.push_back(
|
||||||
InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
|
InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
|
||||||
|
@ -391,8 +403,8 @@ struct HloToLhloDotGeneralOpConverter
|
||||||
} else {
|
} else {
|
||||||
SmallVector<Value, 1> results_shape;
|
SmallVector<Value, 1> results_shape;
|
||||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||||
if (failed(shape_type_op.reifyReturnTypeShapes(
|
if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, operands,
|
||||||
rewriter, shape_type_op->getOperands(), results_shape)))
|
results_shape)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
bufferArgs[2] = InsertDynamicAllocAndDealloc(
|
bufferArgs[2] = InsertDynamicAllocAndDealloc(
|
||||||
|
@ -585,9 +597,9 @@ struct HloLegalizeToLhlo
|
||||||
target.addLegalDialect<shape::ShapeDialect>();
|
target.addLegalDialect<shape::ShapeDialect>();
|
||||||
target.addLegalDialect<tensor::TensorDialect>();
|
target.addLegalDialect<tensor::TensorDialect>();
|
||||||
target.addIllegalDialect<mhlo::MhloDialect>();
|
target.addIllegalDialect<mhlo::MhloDialect>();
|
||||||
// Declare tensor_load and tensor_store illegal.
|
// Declare tensor_store illegal. tensor_load may be used to reify output
|
||||||
target.addIllegalOp<mlir::memref::TensorLoadOp,
|
// shape computation during dialect conversion and will be handled later.
|
||||||
mlir::memref::TensorStoreOp>();
|
target.addIllegalOp<mlir::memref::TensorStoreOp>();
|
||||||
// buffer_cast is illegal if it has uses.
|
// buffer_cast is illegal if it has uses.
|
||||||
// TODO(b/175670649) Make buffer_cast illegal.
|
// TODO(b/175670649) Make buffer_cast illegal.
|
||||||
target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
|
target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s
|
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -canonicalize %s -o - | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @func_op_unranked_arg_result
|
// CHECK-LABEL: func @func_op_unranked_arg_result
|
||||||
func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
|
|
@ -466,7 +466,7 @@ func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
|
||||||
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
%result = "mhlo.add"(%lhs, %rhs)
|
%result = "mhlo.add"(%lhs, %rhs)
|
||||||
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
|
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[INPUT:.*]] : tensor<?x?xf32> -> tensor<2xindex>
|
||||||
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
|
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
|
||||||
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
|
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
|
||||||
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
|
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
|
||||||
|
@ -482,7 +482,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
%result = "mhlo.tanh"(%arg0)
|
%result = "mhlo.tanh"(%arg0)
|
||||||
: (tensor<?x?xf32>) -> tensor<?x?xf32>
|
: (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
|
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[INPUT:.*]] : tensor<?x?xf32> -> tensor<2xindex>
|
||||||
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
|
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
|
||||||
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
|
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
|
||||||
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
|
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
|
||||||
|
|
Loading…
Reference in New Issue