diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index b8cac4c..a2b54f2 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -1148,8 +1148,7 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
   return dyn_shape;
 }
 
-template <DotOperationType op_type, typename LinalgOp, int input_bit_width,
-          int output_bit_width>
+template <DotOperationType op_type, typename LinalgOp>
 class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
  public:
   using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
@@ -1159,28 +1158,13 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
     if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
       return failure();
     }
+    if (GetDotOperationType(op) != op_type) return failure();
     mhlo::DotOp::Adaptor adaptor(args);
-    auto lhs_el_type =
-        adaptor.lhs().getType().cast<ShapedType>().getElementType();
-    auto rhs_el_type =
-        adaptor.lhs().getType().cast<ShapedType>().getElementType();
-    if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<IntegerType, FloatType>() ||
-        lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
-      return failure();
-    }
-
+    Location loc = op.getLoc();
     auto output_type = op.getType().cast<ShapedType>();
     auto output_el_type = output_type.getElementType();
-    if (!output_el_type.isa<IntegerType, FloatType>() ||
-        output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
-      return failure();
-    }
-
-    if (GetDotOperationType(op) != op_type) return failure();
-
-    Location loc = op.getLoc();
     auto zero_attr = rewriter.getZeroAttr(output_el_type);
     Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
     SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
@@ -1207,8 +1191,6 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
   return dyn_shape;
 }
 
-template <typename LinalgOp, int input_bit_width,
-          int output_bit_width>
 class DotGeneralOpOnTensorsConversion
     : public OpConversionPattern<mhlo::DotGeneralOp> {
  public:
@@ -1247,23 +1229,10 @@ class DotGeneralOpOnTensorsConversion
     }
     mhlo::DotGeneralOp::Adaptor adaptor(args);
-    auto lhs_el_type =
-        adaptor.lhs().getType().cast<ShapedType>().getElementType();
-    auto rhs_el_type =
-        adaptor.lhs().getType().cast<ShapedType>().getElementType();
-    if (lhs_el_type != rhs_el_type || !lhs_el_type.isa<IntegerType, FloatType>() ||
-        lhs_el_type.getIntOrFloatBitWidth() != input_bit_width) {
-      return failure();
-    }
-
-    auto output_type = op.getType().cast<ShapedType>();
-    auto output_el_type = output_type.getElementType();
-    if (!output_el_type.isa<IntegerType, FloatType>() ||
-        output_el_type.getIntOrFloatBitWidth() != output_bit_width) {
-      return failure();
-    }
     Location loc = op.getLoc();
+    auto output_type = op.getType().cast<ShapedType>();
+    auto output_el_type = output_type.getElementType();
     SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
         rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
     auto zero_attr = rewriter.getZeroAttr(output_el_type);
@@ -1271,7 +1240,7 @@ class DotGeneralOpOnTensorsConversion
     auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
     Value zero_tensor =
         rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
-    Operation* linalg_op = rewriter.create<LinalgOp>(
+    Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
         loc, /*resultTensorTypes=*/TypeRange{op.getType()},
         /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
         /*outputBuffers=*/ValueRange{zero_tensor});
@@ -1709,49 +1678,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
       ReverseConverter<mhlo::ReverseOp, false>,
       SliceConverter<mhlo::SliceOp, false>,
       TransposeConverter<mhlo::TransposeOp, false>,
-      DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
-                               linalg::MatmulI8I8I32Op, 8, 32>,
-      DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
-                               linalg::MatmulI16I16I32Op, 16, 32>,
-      DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
-                               linalg::MatmulI32I32I32Op, 32, 32>,
-      DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
-                               linalg::MatmulOp, 32, 32>,
-      DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
-                               linalg::MatvecI8I8I32Op, 8, 32>,
-      DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
-                               linalg::MatvecI16I16I32Op, 16, 32>,
-      DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
-                               linalg::MatvecI32I32I32Op, 32, 32>,
-      DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
-                               linalg::MatvecOp, 32, 32>,
-      DotOpOnTensorsConversion<DotOperationType::kVectorDot,
-                               linalg::DotI8I8I32Op, 8, 32>,
-      DotOpOnTensorsConversion<DotOperationType::kVectorDot,
-                               linalg::DotI16I16I32Op, 16, 32>,
-      DotOpOnTensorsConversion<DotOperationType::kVectorDot,
-                               linalg::DotI32I32I32Op, 32, 32>,
-      DotOpOnTensorsConversion<DotOperationType::kVectorDot,
-                               linalg::DotOp, 32, 32>,
-      DotGeneralOpOnTensorsConversion<linalg::BatchMatmulI8I8I32Op, 8, 32>,
-      DotGeneralOpOnTensorsConversion<linalg::BatchMatmulI16I16I32Op, 16, 32>,
-      DotGeneralOpOnTensorsConversion<linalg::BatchMatmulI32I32I32Op, 32, 32>,
-      DotGeneralOpOnTensorsConversion<linalg::BatchMatmulOp, 32, 32>,
+      DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
+                               linalg::MatmulOp>,
+      DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
+                               linalg::MatvecOp>,
+      DotOpOnTensorsConversion<DotOperationType::kVectorDot, linalg::DotOp>,
+      DotGeneralOpOnTensorsConversion,
       NormalConvOpOnTensorsConversion,
       ReduceOnTensorsConversion,
       PadOpOnTensorsConversion>(context);
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index f55d6fc..af42311 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -1004,7 +1004,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
 // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
-// CHECK: linalg.matmul_i8_i8_i32
+// CHECK: linalg.matmul
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
 
@@ -1022,7 +1022,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
 // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
-// CHECK: linalg.matmul_i16_i16_i32
+// CHECK: linalg.matmul
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
 
@@ -1040,7 +1040,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
 // CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
-// CHECK: linalg.matmul_i32_i32_i32
+// CHECK: linalg.matmul
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>)
 
@@ -1131,7 +1131,7 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
 // CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
-// CHECK: linalg.batch_matmul_i8_i8_i32
+// CHECK: linalg.batch_matmul
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi8>, tensor<?x3x?xi8>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
 
@@ -1160,7 +1160,7 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
 // CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
 // CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
-// CHECK: linalg.batch_matmul_i16_i16_i32
+// CHECK: linalg.batch_matmul
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xi16>, tensor<?x3x?xi16>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xi32>)
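
Net effect, as the updated FileCheck tests above indicate (a sketch, not taken from the patch; the pass invocation and exact output IR are assumptions based on the test file's usual mlir-hlo-opt -hlo-legalize-to-linalg setup): an integer mhlo.dot such as

  func @dot_i8(%lhs: tensor<2x3xi8>, %rhs: tensor<3x4xi8>) -> tensor<2x4xi32> {
    %0 = "mhlo.dot"(%lhs, %rhs) : (tensor<2x3xi8>, tensor<3x4xi8>) -> tensor<2x4xi32>
    return %0 : tensor<2x4xi32>
  }

previously lowered to the bit-width-specific linalg.matmul_i8_i8_i32. With this patch it lowers to the plain named op, which now carries the mixed i8-operand / i32-accumulator types itself, roughly:

  %c0 = constant 0 : i32
  %init = linalg.init_tensor [2, 4] : tensor<2x4xi32>
  %fill = linalg.fill(%init, %c0) : tensor<2x4xi32>, i32 -> tensor<2x4xi32>
  %0 = linalg.matmul ins(%lhs, %rhs : tensor<2x3xi8>, tensor<3x4xi8>)
                     outs(%fill : tensor<2x4xi32>) -> tensor<2x4xi32>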