diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index ea28d37..2dedf0e 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -953,7 +953,9 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
       return success();
     }
 
-    auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
+    auto result_type =
+        this->typeConverter->convertType(op.getResult().getType())
+            .dyn_cast<RankedTensorType>();
     if (!result_type) return failure();
 
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index 8c52cd8..470e754 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -2243,6 +2243,65 @@ func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>)
 
 // -----
 
+// CHECK-LABEL: func @concatenate_unsigned(
+// CHECK-SAME: %[[A_UNSIGNED:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[B_UNSIGNED:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[C_UNSIGNED:[a-zA-Z0-9_]*]]
+// CHECK-DAG: %[[A_SIGNLESS:.*]] = unrealized_conversion_cast %[[A_UNSIGNED]] : tensor<?x?xui32> to tensor<?x?xi32>
+// CHECK-DAG: %[[B_SIGNLESS:.*]] = unrealized_conversion_cast %[[B_UNSIGNED]] : tensor<?x?xui32> to tensor<?x?xi32>
+// CHECK-DAG: %[[C_SIGNLESS:.*]] = unrealized_conversion_cast %[[C_UNSIGNED]] : tensor<?x?xui32> to tensor<?x?xi32>
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 0 : index
+// CHECK: %[[VAL_5:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_4]] : tensor<?x?xi32>
+// CHECK: %[[VAL_6:.*]] = constant 1 : index
+// CHECK: %[[VAL_7:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_6]] : tensor<?x?xi32>
+// CHECK: %[[VAL_8:.*]] = constant 1 : index
+// CHECK: %[[VAL_9:.*]] = memref.dim %[[B_SIGNLESS]], %[[VAL_8]] : tensor<?x?xi32>
+// CHECK: %[[VAL_10:.*]] = addi %[[VAL_7]], %[[VAL_9]] : index
+// CHECK: %[[VAL_11:.*]] = constant 1 : index
+// CHECK: %[[VAL_12:.*]] = memref.dim %[[C_SIGNLESS]], %[[VAL_11]] : tensor<?x?xi32>
+// CHECK: %[[VAL_13:.*]] = addi %[[VAL_10]], %[[VAL_12]] : index
+// CHECK: %[[VAL_14:.*]] = linalg.init_tensor [%[[VAL_5]], %[[VAL_13]]] : tensor<?x?xi32>
+// CHECK: %[[RET_SIGNLESS:.*]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_14]] : tensor<?x?xi32>) {
+// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32):
+// CHECK: %[[VAL_19:.*]] = constant 1 : index
+// CHECK: %[[VAL_20:.*]] = memref.dim %[[A_SIGNLESS]], %[[VAL_19]] : tensor<?x?xi32>
+// CHECK: %[[VAL_21:.*]] = addi %[[VAL_3]], %[[VAL_20]] : index
+// CHECK: %[[VAL_22:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_21]] : index
+// CHECK: %[[VAL_23:.*]] = scf.if %[[VAL_22]] -> (i32) {
+// CHECK: %[[VAL_24:.*]] = subi %[[VAL_17]], %[[VAL_3]] : index
+// CHECK: %[[VAL_25:.*]] = tensor.extract %[[A_SIGNLESS]][%[[VAL_16]], %[[VAL_24]]] : tensor<?x?xi32>
+// CHECK: scf.yield %[[VAL_25]] : i32
+// CHECK: } else {
+// CHECK: %[[VAL_26:.*]] = constant 1 : index
+// CHECK: %[[VAL_27:.*]] = memref.dim %[[B_SIGNLESS]], %[[VAL_26]] : tensor<?x?xi32>
+// CHECK: %[[VAL_28:.*]] = addi %[[VAL_21]], %[[VAL_27]] : index
+// CHECK: %[[VAL_29:.*]] = cmpi ult, %[[VAL_17]], %[[VAL_28]] : index
+// CHECK: %[[VAL_30:.*]] = scf.if %[[VAL_29]] -> (i32) {
+// CHECK: %[[VAL_31:.*]] = subi %[[VAL_17]], %[[VAL_21]] : index
+// CHECK: %[[VAL_32:.*]] = tensor.extract %[[B_SIGNLESS]][%[[VAL_16]], %[[VAL_31]]] : tensor<?x?xi32>
+// CHECK: scf.yield %[[VAL_32]] : i32
+// CHECK: } else {
+// CHECK: %[[VAL_33:.*]] = subi %[[VAL_17]], %[[VAL_28]] : index
+// CHECK: %[[VAL_34:.*]] = tensor.extract %[[C_SIGNLESS]][%[[VAL_16]], %[[VAL_33]]] : tensor<?x?xi32>
+// CHECK: scf.yield %[[VAL_34]] : i32
+// CHECK: }
+// CHECK: scf.yield %[[VAL_35:.*]] : i32
+// CHECK: }
+// CHECK: linalg.yield %[[VAL_36:.*]] : i32
+// CHECK: } -> tensor<?x?xi32>
+// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<?x?xi32> to tensor<?x?xui32>
+// CHECK: return %[[RET_UNSIGNED]] : tensor<?x?xui32>
+// CHECK: }
+func @concatenate_unsigned(%a: tensor<?x?xui32>, %b: tensor<?x?xui32>, %c: tensor<?x?xui32>) -> tensor<?x?xui32> {
+  %concat = "mhlo.concatenate"(%a, %b, %c) {
+    dimension = 1
+  } : (tensor<?x?xui32>, tensor<?x?xui32>, tensor<?x?xui32>) -> tensor<?x?xui32>
+  return %concat : tensor<?x?xui32>
+}
+
+// -----
+
 // CHECK-LABEL: unsigned_divide
 func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
   // CHECK: linalg.generic