[MLIR][HLO] Reify shape extents as `index` values
PiperOrigin-RevId: 361519167
This commit is contained in:
parent
5a415de33b
commit
55eda81407
|
@ -3175,15 +3175,13 @@ LogicalResult deriveShapeFromFirstOperand(
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
SmallVector<Value, 4> shape_values;
|
SmallVector<Value, 4> shape_values;
|
||||||
shape_values.reserve(operand_type.getRank());
|
shape_values.reserve(operand_type.getRank());
|
||||||
auto shape_scalar_type = builder->getIntegerType(64);
|
|
||||||
for (auto element : llvm::enumerate(operand_type.getShape())) {
|
for (auto element : llvm::enumerate(operand_type.getShape())) {
|
||||||
if (element.value() == ShapedType::kDynamicSize) {
|
if (element.value() == ShapedType::kDynamicSize) {
|
||||||
Value dim = builder->create<DimOp>(loc, operand, element.index());
|
|
||||||
shape_values.push_back(
|
shape_values.push_back(
|
||||||
builder->create<IndexCastOp>(loc, dim, shape_scalar_type));
|
builder->create<DimOp>(loc, operand, element.index()));
|
||||||
} else {
|
} else {
|
||||||
shape_values.push_back(builder->create<ConstantOp>(
|
shape_values.push_back(
|
||||||
loc, builder->getI64IntegerAttr(element.value())));
|
builder->create<ConstantIndexOp>(loc, element.value()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*reifiedReturnShapes = SmallVector<Value, 1>{
|
*reifiedReturnShapes = SmallVector<Value, 1>{
|
||||||
|
|
|
@ -468,16 +468,12 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||||
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
|
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||||
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
|
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]] : tensor<2xindex>
|
||||||
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
|
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
|
||||||
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
|
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
|
||||||
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
|
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
|
||||||
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
|
|
||||||
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
|
||||||
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
|
||||||
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
return %result : tensor<?x?xf32>
|
return %result : tensor<?x?xf32>
|
||||||
// CHECK: return %[[RESULT]]
|
// CHECK: return %[[RESULT]]
|
||||||
|
@ -492,16 +488,12 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
: (tensor<?x?xf32>) -> tensor<?x?xf32>
|
: (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||||
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
|
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||||
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
|
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]] : tensor<2xindex>
|
||||||
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
|
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
|
||||||
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64>
|
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
|
||||||
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
|
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
|
||||||
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64>
|
|
||||||
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
|
||||||
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
|
||||||
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||||
return %result : tensor<?x?xf32>
|
return %result : tensor<?x?xf32>
|
||||||
// CHECK: return %[[RESULT]]
|
// CHECK: return %[[RESULT]]
|
||||||
|
|
|
@ -4,34 +4,32 @@
|
||||||
// CHECK-LABEL: @select
|
// CHECK-LABEL: @select
|
||||||
// CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>,
|
// CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>,
|
||||||
func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
|
func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
|
||||||
-> tensor<2xi64> {
|
-> tensor<2xindex> {
|
||||||
// CHECK: %[[C2:.*]] = constant 2 : i64
|
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1>
|
// CHECK: %[[DIM:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1>
|
||||||
// CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
|
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xindex>
|
||||||
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
|
// CHECK: return %[[SHAPE]] : tensor<2xindex>
|
||||||
// CHECK: return %[[SHAPE]] : tensor<2xi64>
|
|
||||||
%0 = "mhlo.select"(%pred, %a, %b)
|
%0 = "mhlo.select"(%pred, %a, %b)
|
||||||
: (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
|
: (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
|
||||||
%1 = "mhlo_test.reify_return_type_shapes"(%0)
|
%1 = "mhlo_test.reify_return_type_shapes"(%0)
|
||||||
: (tensor<2x?xf32>) -> tensor<2xi64>
|
: (tensor<2x?xf32>) -> tensor<2xindex>
|
||||||
return %1 : tensor<2xi64>
|
return %1 : tensor<2xindex>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: @compare
|
// CHECK-LABEL: @compare
|
||||||
// CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>,
|
// CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>,
|
||||||
func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> {
|
func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xindex> {
|
||||||
// CHECK: %[[C2:.*]] = constant 2 : i64
|
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32>
|
// CHECK: %[[DIM:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32>
|
||||||
// CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
|
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xindex>
|
||||||
// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
|
// CHECK: return %[[SHAPE]] : tensor<2xindex>
|
||||||
// CHECK: return %[[SHAPE]] : tensor<2xi64>
|
|
||||||
%0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"}
|
%0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"}
|
||||||
: (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>
|
: (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>
|
||||||
%1 = "mhlo_test.reify_return_type_shapes"(%0)
|
%1 = "mhlo_test.reify_return_type_shapes"(%0)
|
||||||
: (tensor<2x?xi1>) -> tensor<2xi64>
|
: (tensor<2x?xi1>) -> tensor<2xindex>
|
||||||
return %1 : tensor<2xi64>
|
return %1 : tensor<2xindex>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue