[MHLO:Linalg] Add support for lowering concat of unsigned tensors
Nothing about concat here really. Just need to plumb through the type conversion. PiperOrigin-RevId: 372012957
This commit is contained in:
parent
5a60793b31
commit
ac68145565
|
@ -953,7 +953,9 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
|
auto result_type =
|
||||||
|
this->typeConverter->convertType(op.getResult().getType())
|
||||||
|
.dyn_cast<RankedTensorType>();
|
||||||
if (!result_type) return failure();
|
if (!result_type) return failure();
|
||||||
|
|
||||||
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||||
|
|
|
@ -2243,6 +2243,65 @@ func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>)
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @concatenate_unsigned(
|
||||||
|
// CHECK-SAME: %[[A_UNSIGNED:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK-SAME: %[[B_UNSIGNED:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK-SAME: %[[C_UNSIGNED:[a-zA-Z0-9_]*]]
|
||||||
|
// CHECK-DAG: %[[A_SIGNLESS:.*]] = unrealized_conversion_cast %[[A_UNSIGNED]] : tensor<?x?xui32> to tensor<?x?xi32>
|
||||||
|
// CHECK-DAG: %[[B_SIGNLESS:.*]] = unrealized_conversion_cast %[[B_UNSIGNED]] : tensor<?x?xui32> to tensor<?x?xi32>
|
||||||
|
// CHECK-DAG: %[[C_SIGNLESS:.*]] = unrealized_conversion_cast %[[C_UNSIGNED]] : tensor<?x?xui32> to tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = constant 0 : index
|
||||||
|
// CHECK: %[[VAL_4:.*]] = constant 0 : index
|
||||||
|
// CHECK: %[[VAL_5:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_4]] : tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = constant 1 : index
|
||||||
|
// CHECK: %[[VAL_7:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_6]] : tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = constant 1 : index
|
||||||
|
// CHECK: %[[VAL_9:.*]] = memref.dim %[[B_SIGNLESS]], %[[VAL_8]] : tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = addi %[[VAL_7]], %[[VAL_9]] : index
|
||||||
|
// CHECK: %[[VAL_11:.*]] = constant 1 : index
|
||||||
|
// CHECK: %[[VAL_12:.*]] = memref.dim %[[C_SIGNLESS]], %[[VAL_11]] : tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
|
||||||
|
// CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
|
||||||
|
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
|
||||||
|
// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32):
|
||||||
|
// CHECK: %[[VAL_19:.*]] = constant 1 : index
|
||||||
|
// CHECK: %[[VAL_20:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_19]] : tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
|
||||||
|
// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index
|
||||||
|
// CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
|
||||||
|
// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index
|
||||||
|
// CHECK: %[[VAL_25:.*]] = tensor.extract %[[A_SIGNLESS]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
|
||||||
|
// CHECK: scf.yield %[[VAL_25]] : i32
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_26:.*]] = constant 1 : index
|
||||||
|
// CHECK: %[[VAL_27:.*]] = memref.dim %[[B_SIGNLESS]], %[[VAL_26]] : tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
|
||||||
|
// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index
|
||||||
|
// CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
|
||||||
|
// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index
|
||||||
|
// CHECK: %[[VAL_32:.*]] = tensor.extract %[[B_SIGNLESS]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
|
||||||
|
// CHECK: scf.yield %[[VAL_32]] : i32
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index
|
||||||
|
// CHECK: %[[VAL_34:.*]] = tensor.extract %[[C_SIGNLESS]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
|
||||||
|
// CHECK: scf.yield %[[VAL_34]] : i32
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: scf.yield %[[VAL_35:.*]] : i32
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: linalg.yield %[[VAL_36:.*]] : i32
|
||||||
|
// CHECK: } -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<?x?xi32> to tensor<?x?xui32>
|
||||||
|
// CHECK: return %[[RET_UNSIGNED]] : tensor<?x?xui32>
|
||||||
|
// CHECK: }
|
||||||
|
func @concatenate_unsigned(%a: tensor<?x?xui32>, %b: tensor<?x?xui32>, %c: tensor<?x?xui32>) -> tensor<?x?xui32> {
|
||||||
|
%concat = "mhlo.concatenate"(%a, %b, %c) {
|
||||||
|
dimension = 1
|
||||||
|
} : (tensor<?x?xui32>, tensor<?x?xui32>, tensor<?x?xui32>) -> tensor<?x?xui32>
|
||||||
|
return %concat : tensor<?x?xui32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: unsigned_divide
|
// CHECK-LABEL: unsigned_divide
|
||||||
func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
|
func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
|
|
Loading…
Reference in New Issue