[MLIR][HLO] Reify shape extents as `index` values

PiperOrigin-RevId: 361519167
Authored by A. Unique TensorFlower on 2021-03-08 02:41:10 -08:00; committed by TensorFlow MLIR Team
parent 5a415de33b
commit 55eda81407
3 changed files with 25 additions and 37 deletions
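In short, shapes reified via deriveShapeFromFirstOperand are now built from index-typed extents instead of i64 values, so the index_cast round trips around dim disappear. Below is a minimal illustrative sketch (function names are hypothetical; the IR mirrors the updated FileCheck expectations in this commit) of the reified shape for a tensor<2x?xf32> operand before and after the change:

// Before: extents are i64, so every dynamic dimension needs an index_cast
// and static extents are materialized as i64 constants.
func @reified_shape_before(%arg0: tensor<2x?xf32>) -> tensor<2xi64> {
  %c2 = constant 2 : i64
  %c1 = constant 1 : index
  %dim = dim %arg0, %c1 : tensor<2x?xf32>
  %dim_i64 = index_cast %dim : index to i64
  %shape = tensor.from_elements %c2, %dim_i64 : tensor<2xi64>
  return %shape : tensor<2xi64>
}

// After: extents are index values and feed tensor.from_elements directly.
func @reified_shape_after(%arg0: tensor<2x?xf32>) -> tensor<2xindex> {
  %c2 = constant 2 : index
  %c1 = constant 1 : index
  %dim = dim %arg0, %c1 : tensor<2x?xf32>
  %shape = tensor.from_elements %c2, %dim : tensor<2xindex>
  return %shape : tensor<2xindex>
}

Consumers of the reified shape, such as the lmhlo bufferization tests below, can then use the extracted index extents for alloc without casting back from i64.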


@@ -3175,15 +3175,13 @@ LogicalResult deriveShapeFromFirstOperand(
   auto loc = op->getLoc();
   SmallVector<Value, 4> shape_values;
   shape_values.reserve(operand_type.getRank());
-  auto shape_scalar_type = builder->getIntegerType(64);
   for (auto element : llvm::enumerate(operand_type.getShape())) {
     if (element.value() == ShapedType::kDynamicSize) {
-      Value dim = builder->create<DimOp>(loc, operand, element.index());
       shape_values.push_back(
-          builder->create<IndexCastOp>(loc, dim, shape_scalar_type));
+          builder->create<DimOp>(loc, operand, element.index()));
     } else {
-      shape_values.push_back(builder->create<ConstantOp>(
-          loc, builder->getI64IntegerAttr(element.value())));
+      shape_values.push_back(
+          builder->create<ConstantIndexOp>(loc, element.value()));
     }
   }
   *reifiedReturnShapes = SmallVector<Value, 1>{


@@ -468,16 +468,12 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
       : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
   // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
-  // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
-  // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
-  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
-  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
-  // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
-  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
-  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
-  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
+  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]] : tensor<2xindex>
+  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
+  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
+  // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
   // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
   return %result : tensor<?x?xf32>
   // CHECK: return %[[RESULT]]
@@ -492,16 +488,12 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
       : (tensor<?x?xf32>) -> tensor<?x?xf32>
   // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
-  // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
-  // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
-  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
-  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
-  // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
-  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
-  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
-  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
+  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]] : tensor<2xindex>
+  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
+  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
+  // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
   // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
   return %result : tensor<?x?xf32>
   // CHECK: return %[[RESULT]]


@@ -4,34 +4,32 @@
 // CHECK-LABEL: @select
 // CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>,
 func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
-    -> tensor<2xi64> {
-  // CHECK: %[[C2:.*]] = constant 2 : i64
+    -> tensor<2xindex> {
+  // CHECK: %[[C2:.*]] = constant 2 : index
   // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1>
-  // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
-  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
-  // CHECK: return %[[SHAPE]] : tensor<2xi64>
+  // CHECK: %[[DIM:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1>
+  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xindex>
+  // CHECK: return %[[SHAPE]] : tensor<2xindex>
   %0 = "mhlo.select"(%pred, %a, %b)
       : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
   %1 = "mhlo_test.reify_return_type_shapes"(%0)
-      : (tensor<2x?xf32>) -> tensor<2xi64>
-  return %1 : tensor<2xi64>
+      : (tensor<2x?xf32>) -> tensor<2xindex>
+  return %1 : tensor<2xindex>
 }
 
 // -----
 
 // CHECK-LABEL: @compare
 // CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>,
-func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> {
-  // CHECK: %[[C2:.*]] = constant 2 : i64
+func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xindex> {
+  // CHECK: %[[C2:.*]] = constant 2 : index
   // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32>
-  // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
-  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
-  // CHECK: return %[[SHAPE]] : tensor<2xi64>
+  // CHECK: %[[DIM:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32>
+  // CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xindex>
+  // CHECK: return %[[SHAPE]] : tensor<2xindex>
   %0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"}
       : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>
   %1 = "mhlo_test.reify_return_type_shapes"(%0)
-      : (tensor<2x?xi1>) -> tensor<2xi64>
-  return %1 : tensor<2xi64>
+      : (tensor<2x?xi1>) -> tensor<2xindex>
+  return %1 : tensor<2xindex>
 }