From 5d3cc2105eaf164d7b9e737315ae84aaa66d915e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Aug 2020 11:10:20 -0700 Subject: [PATCH] [MLIR][HLO] Remove redundant casts from unranked to ranked transformation The transformation of unranked to ranked operations no longer generates cast operations for shapes and sizes. Instead, we use the newly introduced support for extent tensor and index types directly. PiperOrigin-RevId: 325057440 --- .../mhlo/transforms/transform_unranked_hlo.cc | 53 ++++++++----------- tests/mhlo-transform-unranked.mlir | 20 +++---- 2 files changed, 29 insertions(+), 44 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 8db5d84..b6e55a9 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -46,7 +46,6 @@ namespace { sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \ sep fn(ShiftRightLogicalOp) sep fn(SubOp) -// TODO(frgossen): Make it variadic. template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { target->addDynamicallyLegalOp([](OpTy op) { @@ -75,28 +74,24 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern { // Generate IR to flatten the operand. auto loc = op.getLoc(); - Value shape = rewriter.create(loc, operand); - Value numElements = rewriter.create(loc, shape); - Value numElementsAsIndex = - rewriter.create(loc, numElements); - Value flatShapeAsDimTensor = - rewriter.create(loc, numElementsAsIndex); + Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); + Value shape = + rewriter.create(loc, extentTensorTy, operand); + Type indexTy = rewriter.getIndexType(); + Value numElements = + rewriter.create(loc, indexTy, shape); + Value flatShape = rewriter.create(loc, numElements); auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, operandTy.getElementType()); Value flatOperand = rewriter.create( - loc, flatTensorTy, operand, flatShapeAsDimTensor); + loc, flatTensorTy, operand, flatShape); // Generate IR for the actual operation. Value flatResult = rewriter.create(loc, flatTensorTy, flatOperand); // Generate IR to restore the original shape. - auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, - rewriter.getIndexType()); - Value shapeAsExtentTensor = - rewriter.create(loc, extentTensorTy, shape); - Value result = rewriter.create( - loc, operandTy, flatResult, shapeAsExtentTensor); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp(op, operandTy, + flatResult, shape); return success(); } @@ -122,17 +117,18 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern { } // Flatten operands. - Type shapeTy = shape::ShapeType::get(rewriter.getContext()); auto loc = op.getLoc(); - Value shapeLhs = rewriter.create(loc, op.lhs()); - Value shapeRhs = rewriter.create(loc, op.rhs()); - Value shape = rewriter.create(loc, shapeTy, + Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); + Value shapeLhs = + rewriter.create(loc, extentTensorTy, op.lhs()); + Value shapeRhs = + rewriter.create(loc, extentTensorTy, op.rhs()); + Value shape = rewriter.create(loc, extentTensorTy, ValueRange{shapeLhs, shapeRhs}); - Value numElements = rewriter.create(loc, shape); - Value numElementsAsIndex = - rewriter.create(loc, numElements); - Value flatShape = - rewriter.create(loc, numElementsAsIndex); + Type indexTy = rewriter.getIndexType(); + Value numElements = + rewriter.create(loc, indexTy, shape); + Value flatShape = rewriter.create(loc, numElements); TensorType lhsTy = op.lhs().getType().template cast(); Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, lhsTy.getElementType()); @@ -148,13 +144,8 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern { Value flatResult = rewriter.create(loc, flatLhs, flatRhs); // Restore original shape. - auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, - rewriter.getIndexType()); - Value shapeAsExtentTensor = - rewriter.create(loc, extentTensorTy, shape); - Value result = rewriter.create( - loc, op.getType(), flatResult, shapeAsExtentTensor); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp(op, op.getType(), flatResult, + shape); return success(); } diff --git a/tests/mhlo-transform-unranked.mlir b/tests/mhlo-transform-unranked.mlir index 6cc07e0..56a7cf7 100644 --- a/tests/mhlo-transform-unranked.mlir +++ b/tests/mhlo-transform-unranked.mlir @@ -7,8 +7,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { // Flatten operand shape. %shape = shape.shape_of %a : tensor<*xf32> -> tensor %num_elements = shape.num_elements %shape : tensor -> index - %num_elements_as_index = shape.size_to_index %num_elements : index - %flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex> + %flat_shape = tensor_from_elements(%num_elements) : tensor<1xindex> %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor @@ -16,8 +15,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { %flat_b = "mhlo.sqrt"(%flat_a) : (tensor) -> tensor // Restore original shape. - %shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor -> tensor - %b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor) + %b = "mhlo.dynamic_reshape"(%flat_b, %shape) : (tensor, tensor) -> tensor<*xf32> return %b : tensor<*xf32> @@ -29,14 +27,12 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @sqrt // CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] - // CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] - // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> + // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> // CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor) -> tensor - // CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor - // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK-NEXT: return %[[B]] : tensor<*xf32> %b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32> return %b : tensor<*xf32> @@ -75,13 +71,11 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]] // CHECK: %[[SHAPE:.*]] = "shape.any"(%[[SHAPE_A]], %[[SHAPE_B]]) // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] - // CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] - // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> + // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor - // CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK: return %[[RESULT]] : tensor<*xf32> %result = mhlo.add %a, %b : tensor<*xf32> return %result : tensor<*xf32>