From 67a770e4e066814ca147bb7a3c4b7d411767fcfa Mon Sep 17 00:00:00 2001
From: Benjamin Kramer
Date: Thu, 11 Mar 2021 02:00:50 -0800
Subject: [PATCH] [HLO:MLIR] Make binary op type reification emit shape_of instead of tensor ops

This gives cleaner code and allows shape optimizations to happen on the
result.

PiperOrigin-RevId: 362242975
---
 lib/Dialect/mhlo/IR/hlo_ops.cc                 | 18 +++++-------------
 .../mhlo/transforms/hlo_legalize_to_lhlo.cc    |  1 +
 tests/hlo-legalize-to-lhlo.mlir                | 12 ++----------
 tests/mhlo_infer_shape_type_methods.mlir       | 10 ++--------
 .../move_up_dynamic_broadcasts_for_fusion.mlir |  6 ++++--
 5 files changed, 14 insertions(+), 33 deletions(-)

diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index c2a0fa8..fd43805 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -3151,19 +3151,11 @@ LogicalResult deriveShapeFromFirstOperand(
     return failure();
   }
   auto loc = op->getLoc();
-  SmallVector<Value, 4> shape_values;
-  shape_values.reserve(operand_type.getRank());
-  for (auto element : llvm::enumerate(operand_type.getShape())) {
-    if (element.value() == ShapedType::kDynamicSize) {
-      shape_values.push_back(
-          builder->create<DimOp>(loc, operand, element.index()));
-    } else {
-      shape_values.push_back(
-          builder->create<ConstantIndexOp>(loc, element.value()));
-    }
-  }
-  *reifiedReturnShapes = SmallVector<Value, 1>{
-      builder->create<tensor::FromElementsOp>(loc, shape_values)};
+  // Some users rely on the result type being a static shape.
+  auto shape_type =
+      RankedTensorType::get(operand_type.getRank(), builder->getIndexType());
+  reifiedReturnShapes->assign(
+      {builder->create<shape::ShapeOfOp>(loc, shape_type, operand)});
   return success();
 }
 
diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index abaa53f..cc9ce58 100644
--- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -577,6 +577,7 @@ struct HloLegalizeToLhlo
     ConversionTarget target(context);
     target.addLegalDialect<lmhlo::LmhloDialect>();
     target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<shape::ShapeDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
     target.addIllegalDialect<mhlo::MhloDialect>();
     // Declare tensor_load and tensor_store illegal.
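Illustration (not part of the patch): for a binary op on a dynamically shaped
operand, reification previously materialized the result shape dimension by
dimension, e.g.

  %c0 = constant 0 : index
  %d0 = dim %arg0, %c0 : tensor<?x?xf32>
  %c1 = constant 1 : index
  %d1 = dim %arg0, %c1 : tensor<?x?xf32>
  %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>

and now emits a single op that shape-dialect patterns can reason about:

  %shape = shape.shape_of %arg0 : tensor<?x?xf32> -> tensor<2xindex>

The value names and the tensor<?x?xf32> type here are illustrative; the updated
CHECK lines in the tests below show the same before/after shapes.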
diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir
index b0cad9d..7d577c9 100644
--- a/tests/hlo-legalize-to-lhlo.mlir
+++ b/tests/hlo-legalize-to-lhlo.mlir
@@ -466,11 +466,7 @@ func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
 func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %result = "mhlo.add"(%lhs, %rhs)
       : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
-  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]] : tensor<2xindex>
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
   // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
   // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
   // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
@@ -486,11 +482,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
 func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %result = "mhlo.tanh"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
-  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]] : tensor<2xindex>
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
   // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
   // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
   // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
 
diff --git a/tests/mhlo_infer_shape_type_methods.mlir b/tests/mhlo_infer_shape_type_methods.mlir
index 53d6fe4..5de60df 100644
--- a/tests/mhlo_infer_shape_type_methods.mlir
+++ b/tests/mhlo_infer_shape_type_methods.mlir
@@ -5,10 +5,7 @@
 // CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>,
 func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
     -> tensor<2xindex> {
-  // CHECK: %[[C2:.*]] = constant 2 : index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[DIM:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1>
-  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xindex>
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<2x?xi1> -> tensor<2xindex>
   // CHECK: return %[[SHAPE]] : tensor<2xindex>
   %0 = "mhlo.select"(%pred, %a, %b)
       : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
@@ -21,10 +18,7 @@ func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
 // CHECK-LABEL: @compare
 // CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>,
 func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xindex> {
-  // CHECK: %[[C2:.*]] = constant 2 : index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[DIM:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32>
-  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xindex>
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<2x?xf32> -> tensor<2xindex>
   // CHECK: return %[[SHAPE]] : tensor<2xindex>
   %0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"}
       : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index a07210c..a9acfbd 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -4,7 +4,8 @@
 // CHECK-LABEL: @shape_of_unary
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
 func @shape_of_unary(%arg : tensor<?x32xi16>) {
-  // CHECK-NOT: shape_of
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x32xi16> -> tensor<2xindex>
+  // CHECK: "use"(%[[SHAPE]])
   %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
   %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<2xindex>
   "use"(%1) : (tensor<2xindex>) -> ()
@@ -17,7 +18,8 @@ func @shape_of_unary(%arg : tensor<?x32xi16>) {
 // CHECK-LABEL: @shape_of_nary
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>)
 func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
-  // CHECK-NOT: shape_of
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
+  // CHECK: "use"(%[[SHAPE]])
   %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
   %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<2xindex>
   "use"(%1) : (tensor<2xindex>) -> ()
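Note (not part of the patch) on the shape optimizations the commit message
refers to, as a sketch: once reified shapes are shape.shape_of ops, existing
shape-dialect rewrites apply to them directly. The two tests above show one
such rewrite, moving a shape computation from an op's result to its operand.
Likewise, shape.shape_of over a statically shaped value folds to a constant
shape (types here illustrative):

  %shape = shape.shape_of %arg : tensor<4x8xf32> -> tensor<2xindex>
  // folds to:
  %shape = shape.const_shape [4, 8] : tensor<2xindex>

The dim/constant/tensor.from_elements chains emitted previously were opaque to
these shape-dialect patterns.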