diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 477cfda..538a475 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -441,7 +441,7 @@ class HLOClient_UnaryElementwiseOp traits, LogicalResult reifyReturnTypeShapes(OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), - &reifiedReturnShapes); + operands, &reifiedReturnShapes); } }]; } diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index 21e9c9f..427ef62 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -79,7 +79,7 @@ class TokenType : public Type::TypeBase { // // and returns %4 as the shape value. LogicalResult deriveShapeFromFirstOperand( - OpBuilder *builder, Operation *op, + OpBuilder *builder, Operation *op, ValueRange operands, SmallVectorImpl *reifiedReturnShapes); // Type derivation function that returns a tensor type with a new element type. diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index da97103..80fd274 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -142,6 +142,7 @@ class HLO_UnaryElementwiseOp traits, OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), + operands, &reifiedReturnShapes); } bool inferInputOutputShapeEquality(int input, int output) { @@ -456,6 +457,7 @@ class HLO_BinaryElementwiseOp traits> : OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), + operands, &reifiedReturnShapes); } bool inferInputsShapeEquality(int lhs, int rhs) { diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc index 64ee93c..a6cc839 100644 --- a/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -155,11 +155,11 @@ LogicalResult InferBroadcastBinaryOpReturnTypeComponents( } LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes( - OpBuilder& builder, Operation* op, + OpBuilder& builder, Operation* op, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { auto loc = op->getLoc(); - auto lhs = op->getOperand(0); - auto rhs = op->getOperand(1); + auto lhs = operands[0]; + auto rhs = operands[1]; // Check for "numpy"-style rank broadcast. auto broadcast_dimensions = op->getAttr("broadcast_dimensions") @@ -204,10 +204,10 @@ LogicalResult BroadcastComplexOp::inferReturnTypeComponents( inferedReturnShapes); } LogicalResult BroadcastComplexOp::reifyReturnTypeShapes( - OpBuilder& builder, ValueRange, + OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), - reifiedReturnShapes); + operands, reifiedReturnShapes); } //===----------------------------------------------------------------------===// @@ -236,10 +236,10 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents( } LogicalResult BroadcastCompareOp::reifyReturnTypeShapes( - OpBuilder& builder, ValueRange, + OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), - reifiedReturnShapes); + operands, reifiedReturnShapes); } //===----------------------------------------------------------------------===// @@ -295,10 +295,10 @@ LogicalResult IsPosInfOp::inferReturnTypes( inferedReturnShapes); \ } \ LogicalResult Op::reifyReturnTypeShapes( \ - OpBuilder& builder, ValueRange, \ + OpBuilder& builder, ValueRange operands, \ SmallVectorImpl& reifiedReturnShapes) { \ - return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), \ - reifiedReturnShapes); \ + return ReifyBroadcastBinaryOpReturnTypeShapes( \ + builder, getOperation(), operands, reifiedReturnShapes); \ } #define BROADCAST_BINARY_OP_DEFS(Op) \ diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 56aa2bd..eb2bb35 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1002,8 +1002,10 @@ LogicalResult DynamicBroadcastInDimOp::inferReturnTypeComponents( } LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes( - OpBuilder&, ValueRange, SmallVectorImpl& reifiedReturnShapes) { - reifiedReturnShapes.push_back(output_dimensions()); + OpBuilder&, ValueRange operands, + SmallVectorImpl& reifiedReturnShapes) { + DynamicBroadcastInDimOp::Adaptor adaptor(operands); + reifiedReturnShapes.push_back(adaptor.output_dimensions()); return success(); } @@ -2137,7 +2139,7 @@ LogicalResult SelectOp::inferReturnTypeComponents( LogicalResult SelectOp::reifyReturnTypeShapes( OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { - return deriveShapeFromFirstOperand(&builder, getOperation(), + return deriveShapeFromFirstOperand(&builder, getOperation(), operands, &reifiedReturnShapes); } @@ -3278,7 +3280,7 @@ LogicalResult CompareOp::inferReturnTypeComponents( LogicalResult CompareOp::reifyReturnTypeShapes( OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { - return deriveShapeFromFirstOperand(&builder, getOperation(), + return deriveShapeFromFirstOperand(&builder, getOperation(), operands, &reifiedReturnShapes); } @@ -3629,9 +3631,9 @@ void MhloDialect::printType(Type type, DialectAsmPrinter& os) const { //===----------------------------------------------------------------------===// LogicalResult deriveShapeFromFirstOperand( - OpBuilder* builder, Operation* op, + OpBuilder* builder, Operation* op, ValueRange operands, SmallVectorImpl* reifiedReturnShapes) { - Value operand = op->getOperand(0); + Value operand = operands.front(); ShapedType operand_type = operand.getType().dyn_cast(); if (!operand_type) { op->emitOpError() << "first operand is not a shaped type"; diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index e36a821..fc64bda 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -94,6 +94,8 @@ Value InsertAlloc(Location loc, OpResult result, /// to the `results` vector. LogicalResult ConvertResults(Operation* op, SmallVectorImpl& results, ConversionPatternRewriter& rewriter) { + size_t num_operands = results.size(); + SmallVector tensor_operands; for (auto result : llvm::enumerate(op->getResults())) { RankedTensorType resultType = result.value().getType().dyn_cast(); @@ -106,9 +108,19 @@ LogicalResult ConvertResults(Operation* op, SmallVectorImpl& results, auto shape_type_op = dyn_cast(op); if (!shape_type_op) return failure(); + if (tensor_operands.empty()) { + for (auto operand : ArrayRef(results).take_front(num_operands)) { + auto tp = operand.getType().cast(); + tensor_operands.push_back(rewriter.create( + op->getLoc(), + RankedTensorType::get(tp.getShape(), tp.getElementType()), + operand)); + } + } + SmallVector results_shape; - auto status = shape_type_op.reifyReturnTypeShapes( - rewriter, shape_type_op->getOperands(), results_shape); + auto status = shape_type_op.reifyReturnTypeShapes(rewriter, tensor_operands, + results_shape); if (failed(status)) return failure(); results.push_back( InsertDynamicAllocAndDealloc(op->getLoc(), result.value(), @@ -391,8 +403,8 @@ struct HloToLhloDotGeneralOpConverter } else { SmallVector results_shape; auto shape_type_op = dyn_cast(op); - if (failed(shape_type_op.reifyReturnTypeShapes( - rewriter, shape_type_op->getOperands(), results_shape))) + if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, operands, + results_shape))) return failure(); bufferArgs[2] = InsertDynamicAllocAndDealloc( @@ -585,9 +597,9 @@ struct HloLegalizeToLhlo target.addLegalDialect(); target.addLegalDialect(); target.addIllegalDialect(); - // Declare tensor_load and tensor_store illegal. - target.addIllegalOp(); + // Declare tensor_store illegal. tensor_load may be used to reify output + // shape computation during dialect conversion and will be handled later. + target.addIllegalOp(); // buffer_cast is illegal if it has uses. // TODO(b/175670649) Make buffer_cast illegal. target.addDynamicallyLegalOp( diff --git a/tests/hlo-legalize-to-lhlo-unranked.mlir b/tests/hlo-legalize-to-lhlo-unranked.mlir index 1c6aeac..a2c1a8d 100644 --- a/tests/hlo-legalize-to-lhlo-unranked.mlir +++ b/tests/hlo-legalize-to-lhlo-unranked.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -canonicalize %s -o - | FileCheck %s // CHECK-LABEL: func @func_op_unranked_arg_result func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> { diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 467c4b6..504cd33 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -466,7 +466,7 @@ func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>) func @add_dyn(%lhs: tensor, %rhs: tensor) -> tensor { %result = "mhlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref -> tensor<2xindex> + // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[INPUT:.*]] : tensor -> tensor<2xindex> // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex> // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex> // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]]) @@ -482,7 +482,7 @@ func @add_dyn(%lhs: tensor, %rhs: tensor) -> tensor { func @tanh_dyn(%arg0: tensor) -> tensor { %result = "mhlo.tanh"(%arg0) : (tensor) -> tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref -> tensor<2xindex> + // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[INPUT:.*]] : tensor -> tensor<2xindex> // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex> // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex> // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])