Upstream mhlo.dot_general lowering to Linalg to MHLO repo

PiperOrigin-RevId: 351514250
This commit is contained in:
Hanhan Wang 2021-01-12 22:07:29 -08:00 committed by TensorFlow MLIR Team
parent 97a618f91a
commit 300a7c11ce
2 changed files with 106 additions and 1 deletions

View File

@ -99,6 +99,14 @@ SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
return dyn_sizes;
}
SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) {
SmallVector<int64_t, 4> ret;
for (const APInt& element : elements) {
ret.push_back(element.getLimitedValue());
}
return ret;
}
template <typename OpTy, bool isLHLO = true>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
public:
@ -1131,6 +1139,81 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
}
};
SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
SmallVector<Value, 8> dyn_shape;
if (result_type.isDynamicDim(0))
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
if (result_type.isDynamicDim(1))
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 1));
if (result_type.isDynamicDim(2))
dyn_shape.push_back(b.create<DimOp>(loc, rhs, 2));
return dyn_shape;
}
class DotGeneralOpOnTensorsConversion
: public OpConversionPattern<mhlo::DotGeneralOp> {
public:
using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::DotGeneralOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
return failure();
}
mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers();
auto lhs_bathcing_dims =
Extract1DVector(dim_numbers.lhs_batching_dimensions());
auto rhs_bathcing_dims =
Extract1DVector(dim_numbers.rhs_batching_dimensions());
auto lhs_contracting_dims =
Extract1DVector(dim_numbers.lhs_contracting_dimensions());
auto rhs_contracting_dims =
Extract1DVector(dim_numbers.rhs_contracting_dimensions());
if (lhs_bathcing_dims.size() != 1 || lhs_bathcing_dims[0] != 0) {
return rewriter.notifyMatchFailure(
op, "expected lhs batching dimensions exactly {0}");
}
if (rhs_bathcing_dims.size() != 1 || rhs_bathcing_dims[0] != 0) {
return rewriter.notifyMatchFailure(
op, "expected rhs batching dimensions exactly {0}");
}
if (lhs_contracting_dims.size() != 1 || lhs_contracting_dims[0] != 2) {
return rewriter.notifyMatchFailure(
op, "expected lhs contracting dimensions exactly {2}");
}
if (rhs_contracting_dims.size() != 1 || rhs_contracting_dims[0] != 1) {
return rewriter.notifyMatchFailure(
op, "expected rhs contracting dimensions exactly {1}");
}
Location loc = op.getLoc();
mhlo::DotGeneralOp::Adaptor adaptor(args);
Type result_type = op.getResult().getType();
auto shaped_type = result_type.cast<ShapedType>();
SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type);
auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
auto init_tensor = rewriter.create<DynamicTensorFromElementsOp>(
loc, result_type, dyn_shape);
{
OpBuilder::InsertionGuard guard(rewriter);
SmallVector<Type, 4> arg_types(shaped_type.getRank(),
rewriter.getIndexType());
Region& region = init_tensor.body();
Block* block = rewriter.createBlock(&region, region.begin(), arg_types);
rewriter.setInsertionPointToEnd(block);
rewriter.create<YieldOp>(loc, zero);
}
auto linalg_op = rewriter.create<linalg::BatchMatmulOp>(
loc, /*resultTensorTypes=*/TypeRange{result_type},
/*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
/*outputBuffers=*/ValueRange{init_tensor});
rewriter.replaceOp(op, linalg_op.getResults());
return success();
}
};
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
@ -1299,7 +1382,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>,
DotOpOnTensorsConversion>(context);
DotOpOnTensorsConversion, DotGeneralOpOnTensorsConversion>(
context);
}
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {

View File

@ -871,3 +871,24 @@ func @dot_dot(%arg0: tensor<?xf32>,
// CHECK: linalg.dot
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<f32>)
// -----
func @dot_general(%arg0: tensor<?x?x3xf32>,
%arg1: tensor<?x3x?xf32>) -> tensor<?x?x?xf32> {
%0 ="mhlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = {
lhs_batching_dimensions = dense<0> : tensor<1xi64>,
lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
rhs_batching_dimensions = dense<0> : tensor<1xi64>,
rhs_contracting_dimensions = dense<1> : tensor<1xi64>
},
precision_config = ["DEFAULT", "DEFAULT"]
} : (tensor<?x?x3xf32>, tensor<?x3x?xf32>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// CHECK: func @dot_general(%[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
// CHECK: linalg.batch_matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xf32>, tensor<?x3x?xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?xf32>)