From 55eda81407508a3391aa7d875515263dfe6044ee Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Mar 2021 02:41:10 -0800 Subject: [PATCH] [MLIR][HLO] Reify shape extents as `index` values PiperOrigin-RevId: 361519167 --- lib/Dialect/mhlo/IR/hlo_ops.cc | 8 +++---- tests/hlo-legalize-to-lhlo.mlir | 24 +++++++------------ tests/mhlo_infer_shape_type_methods.mlir | 30 +++++++++++------------- 3 files changed, 25 insertions(+), 37 deletions(-) diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index cf5f5d2..9132418 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -3175,15 +3175,13 @@ LogicalResult deriveShapeFromFirstOperand( auto loc = op->getLoc(); SmallVector shape_values; shape_values.reserve(operand_type.getRank()); - auto shape_scalar_type = builder->getIntegerType(64); for (auto element : llvm::enumerate(operand_type.getShape())) { if (element.value() == ShapedType::kDynamicSize) { - Value dim = builder->create(loc, operand, element.index()); shape_values.push_back( - builder->create(loc, dim, shape_scalar_type)); + builder->create(loc, operand, element.index())); } else { - shape_values.push_back(builder->create( - loc, builder->getI64IntegerAttr(element.value()))); + shape_values.push_back( + builder->create(loc, element.value())); } } *reifiedReturnShapes = SmallVector{ diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 5c49724..b0cad9d 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -468,16 +468,12 @@ func @add_dyn(%lhs: tensor, %rhs: tensor) -> tensor { : (tensor, tensor) -> tensor // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref - // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64 // CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref - // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 - // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64> - // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64> - // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index - // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64> - // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index - // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) + // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]] : tensor<2xindex> + // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex> + // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex> + // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]]) // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref, memref, memref) -> () return %result : tensor // CHECK: return %[[RESULT]] @@ -492,16 +488,12 @@ func @tanh_dyn(%arg0: tensor) -> tensor { : (tensor) -> tensor // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref - // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64 // CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref - // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 - // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64> - // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64> - // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index - // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64> - // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index - // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) + // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]] : tensor<2xindex> + // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex> + // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex> + // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]]) // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref, memref) -> () return %result : tensor // CHECK: return %[[RESULT]] diff --git a/tests/mhlo_infer_shape_type_methods.mlir b/tests/mhlo_infer_shape_type_methods.mlir index c40eb3e..53d6fe4 100644 --- a/tests/mhlo_infer_shape_type_methods.mlir +++ b/tests/mhlo_infer_shape_type_methods.mlir @@ -4,34 +4,32 @@ // CHECK-LABEL: @select // CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>, func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>) - -> tensor<2xi64> { - // CHECK: %[[C2:.*]] = constant 2 : i64 + -> tensor<2xindex> { + // CHECK: %[[C2:.*]] = constant 2 : index // CHECK: %[[C1:.*]] = constant 1 : index - // CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1> - // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64 - // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xi64> - // CHECK: return %[[SHAPE]] : tensor<2xi64> + // CHECK: %[[DIM:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1> + // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xindex> + // CHECK: return %[[SHAPE]] : tensor<2xindex> %0 = "mhlo.select"(%pred, %a, %b) : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32> %1 = "mhlo_test.reify_return_type_shapes"(%0) - : (tensor<2x?xf32>) -> tensor<2xi64> - return %1 : tensor<2xi64> + : (tensor<2x?xf32>) -> tensor<2xindex> + return %1 : tensor<2xindex> } // ----- // CHECK-LABEL: @compare // CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>, -func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> { - // CHECK: %[[C2:.*]] = constant 2 : i64 +func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xindex> { + // CHECK: %[[C2:.*]] = constant 2 : index // CHECK: %[[C1:.*]] = constant 1 : index - // CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32> - // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64 - // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xi64> - // CHECK: return %[[SHAPE]] : tensor<2xi64> + // CHECK: %[[DIM:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32> + // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xindex> + // CHECK: return %[[SHAPE]] : tensor<2xindex> %0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"} : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1> %1 = "mhlo_test.reify_return_type_shapes"(%0) - : (tensor<2x?xi1>) -> tensor<2xi64> - return %1 : tensor<2xi64> + : (tensor<2x?xi1>) -> tensor<2xindex> + return %1 : tensor<2xindex> }