diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 8db5d84..b6e55a9 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -46,7 +46,6 @@ namespace { sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \ sep fn(ShiftRightLogicalOp) sep fn(SubOp) -// TODO(frgossen): Make it variadic. template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { target->addDynamicallyLegalOp([](OpTy op) { @@ -75,28 +74,24 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern { // Generate IR to flatten the operand. auto loc = op.getLoc(); - Value shape = rewriter.create(loc, operand); - Value numElements = rewriter.create(loc, shape); - Value numElementsAsIndex = - rewriter.create(loc, numElements); - Value flatShapeAsDimTensor = - rewriter.create(loc, numElementsAsIndex); + Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); + Value shape = + rewriter.create(loc, extentTensorTy, operand); + Type indexTy = rewriter.getIndexType(); + Value numElements = + rewriter.create(loc, indexTy, shape); + Value flatShape = rewriter.create(loc, numElements); auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, operandTy.getElementType()); Value flatOperand = rewriter.create( - loc, flatTensorTy, operand, flatShapeAsDimTensor); + loc, flatTensorTy, operand, flatShape); // Generate IR for the actual operation. Value flatResult = rewriter.create(loc, flatTensorTy, flatOperand); // Generate IR to restore the original shape. - auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, - rewriter.getIndexType()); - Value shapeAsExtentTensor = - rewriter.create(loc, extentTensorTy, shape); - Value result = rewriter.create( - loc, operandTy, flatResult, shapeAsExtentTensor); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp(op, operandTy, + flatResult, shape); return success(); } @@ -122,17 +117,18 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern { } // Flatten operands. - Type shapeTy = shape::ShapeType::get(rewriter.getContext()); auto loc = op.getLoc(); - Value shapeLhs = rewriter.create(loc, op.lhs()); - Value shapeRhs = rewriter.create(loc, op.rhs()); - Value shape = rewriter.create(loc, shapeTy, + Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); + Value shapeLhs = + rewriter.create(loc, extentTensorTy, op.lhs()); + Value shapeRhs = + rewriter.create(loc, extentTensorTy, op.rhs()); + Value shape = rewriter.create(loc, extentTensorTy, ValueRange{shapeLhs, shapeRhs}); - Value numElements = rewriter.create(loc, shape); - Value numElementsAsIndex = - rewriter.create(loc, numElements); - Value flatShape = - rewriter.create(loc, numElementsAsIndex); + Type indexTy = rewriter.getIndexType(); + Value numElements = + rewriter.create(loc, indexTy, shape); + Value flatShape = rewriter.create(loc, numElements); TensorType lhsTy = op.lhs().getType().template cast(); Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, lhsTy.getElementType()); @@ -148,13 +144,8 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern { Value flatResult = rewriter.create(loc, flatLhs, flatRhs); // Restore original shape. - auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, - rewriter.getIndexType()); - Value shapeAsExtentTensor = - rewriter.create(loc, extentTensorTy, shape); - Value result = rewriter.create( - loc, op.getType(), flatResult, shapeAsExtentTensor); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp(op, op.getType(), flatResult, + shape); return success(); } diff --git a/tests/mhlo-transform-unranked.mlir b/tests/mhlo-transform-unranked.mlir index 6cc07e0..56a7cf7 100644 --- a/tests/mhlo-transform-unranked.mlir +++ b/tests/mhlo-transform-unranked.mlir @@ -7,8 +7,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { // Flatten operand shape. %shape = shape.shape_of %a : tensor<*xf32> -> tensor %num_elements = shape.num_elements %shape : tensor -> index - %num_elements_as_index = shape.size_to_index %num_elements : index - %flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex> + %flat_shape = tensor_from_elements(%num_elements) : tensor<1xindex> %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor @@ -16,8 +15,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { %flat_b = "mhlo.sqrt"(%flat_a) : (tensor) -> tensor // Restore original shape. - %shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor -> tensor - %b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor) + %b = "mhlo.dynamic_reshape"(%flat_b, %shape) : (tensor, tensor) -> tensor<*xf32> return %b : tensor<*xf32> @@ -29,14 +27,12 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @sqrt // CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] - // CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] - // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> + // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> // CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor) -> tensor - // CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor - // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK-NEXT: return %[[B]] : tensor<*xf32> %b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32> return %b : tensor<*xf32> @@ -75,13 +71,11 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]] // CHECK: %[[SHAPE:.*]] = "shape.any"(%[[SHAPE_A]], %[[SHAPE_B]]) // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] - // CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] - // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> + // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor - // CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK: return %[[RESULT]] : tensor<*xf32> %result = mhlo.add %a, %b : tensor<*xf32> return %result : tensor<*xf32>