[MHLO:Linalg] Add support for lowering concat of unsigned tensors

Nothing about concat here really. Just need to plumb through the type
conversion.

PiperOrigin-RevId: 372012957
This commit is contained in:
Geoffrey Martin-Noble 2021-05-04 15:56:56 -07:00 committed by TensorFlow MLIR Team
parent 5a60793b31
commit ac68145565
2 changed files with 62 additions and 1 deletions

View File

@ -953,7 +953,9 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
return success();
}
auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
auto result_type =
this->typeConverter->convertType(op.getResult().getType())
.dyn_cast<RankedTensorType>();
if (!result_type) return failure();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

View File

@ -2243,6 +2243,65 @@ func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>)
// -----
// CHECK-LABEL: func @concatenate_unsigned(
// CHECK-SAME: %[[A_UNSIGNED:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[B_UNSIGNED:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[C_UNSIGNED:[a-zA-Z0-9_]*]]
// CHECK-DAG: %[[A_SIGNLESS:.*]] = unrealized_conversion_cast %[[A_UNSIGNED]] : tensor<?x?xui32> to tensor<?x?xi32>
// CHECK-DAG: %[[B_SIGNLESS:.*]] = unrealized_conversion_cast %[[B_UNSIGNED]] : tensor<?x?xui32> to tensor<?x?xi32>
// CHECK-DAG: %[[C_SIGNLESS:.*]] = unrealized_conversion_cast %[[C_UNSIGNED]] : tensor<?x?xui32> to tensor<?x?xi32>
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 0 : index
// CHECK: %[[VAL_5:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_4]] : tensor<?x?xi32>
// CHECK: %[[VAL_6:.*]] = constant 1 : index
// CHECK: %[[VAL_7:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_6]] : tensor<?x?xi32>
// CHECK: %[[VAL_8:.*]] = constant 1 : index
// CHECK: %[[VAL_9:.*]] = memref.dim %[[B_SIGNLESS]], %[[VAL_8]] : tensor<?x?xi32>
// CHECK: %[[VAL_10:.*]] = addi %[[VAL_7]], %[[VAL_9]] : index
// CHECK: %[[VAL_11:.*]] = constant 1 : index
// CHECK: %[[VAL_12:.*]] = memref.dim %[[C_SIGNLESS]], %[[VAL_11]] : tensor<?x?xi32>
// CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
// CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32):
// CHECK: %[[VAL_19:.*]] = constant 1 : index
// CHECK: %[[VAL_20:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_19]] : tensor<?x?xi32>
// CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index
// CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index
// CHECK: %[[VAL_25:.*]] = tensor.extract %[[A_SIGNLESS]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_25]] : i32
// CHECK: } else {
// CHECK: %[[VAL_26:.*]] = constant 1 : index
// CHECK: %[[VAL_27:.*]] = memref.dim %[[B_SIGNLESS]], %[[VAL_26]] : tensor<?x?xi32>
// CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index
// CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index
// CHECK: %[[VAL_32:.*]] = tensor.extract %[[B_SIGNLESS]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_32]] : i32
// CHECK: } else {
// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index
// CHECK: %[[VAL_34:.*]] = tensor.extract %[[C_SIGNLESS]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
// CHECK: scf.yield %[[VAL_34]] : i32
// CHECK: }
// CHECK: scf.yield %[[VAL_35:.*]] : i32
// CHECK: }
// CHECK: linalg.yield %[[VAL_36:.*]] : i32
// CHECK: } -> tensor<?x?xi32>
// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<?x?xi32> to tensor<?x?xui32>
// CHECK: return %[[RET_UNSIGNED]] : tensor<?x?xui32>
// CHECK: }
func @concatenate_unsigned(%a: tensor<?x?xui32>, %b: tensor<?x?xui32>, %c: tensor<?x?xui32>) -> tensor<?x?xui32> {
%concat = "mhlo.concatenate"(%a, %b, %c) {
dimension = 1
} : (tensor<?x?xui32>, tensor<?x?xui32>, tensor<?x?xui32>) -> tensor<?x?xui32>
return %concat : tensor<?x?xui32>
}
// -----
// CHECK-LABEL: unsigned_divide
func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
// CHECK: linalg.generic