Upstream mhlo.dot_general lowering to Linalg to MHLO repo
PiperOrigin-RevId: 351514250
This commit is contained in:
parent
97a618f91a
commit
300a7c11ce
|
@ -99,6 +99,14 @@ SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
|
|||
return dyn_sizes;
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) {
|
||||
SmallVector<int64_t, 4> ret;
|
||||
for (const APInt& element : elements) {
|
||||
ret.push_back(element.getLimitedValue());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename OpTy, bool isLHLO = true>
|
||||
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
||||
public:
|
||||
|
@ -1131,6 +1139,81 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
|
|||
}
|
||||
};
|
||||
|
||||
SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
|
||||
OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
|
||||
SmallVector<Value, 8> dyn_shape;
|
||||
if (result_type.isDynamicDim(0))
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
|
||||
if (result_type.isDynamicDim(1))
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 1));
|
||||
if (result_type.isDynamicDim(2))
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, rhs, 2));
|
||||
return dyn_shape;
|
||||
}
|
||||
|
||||
class DotGeneralOpOnTensorsConversion
|
||||
: public OpConversionPattern<mhlo::DotGeneralOp> {
|
||||
public:
|
||||
using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::DotGeneralOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
|
||||
return failure();
|
||||
}
|
||||
mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers();
|
||||
auto lhs_bathcing_dims =
|
||||
Extract1DVector(dim_numbers.lhs_batching_dimensions());
|
||||
auto rhs_bathcing_dims =
|
||||
Extract1DVector(dim_numbers.rhs_batching_dimensions());
|
||||
auto lhs_contracting_dims =
|
||||
Extract1DVector(dim_numbers.lhs_contracting_dimensions());
|
||||
auto rhs_contracting_dims =
|
||||
Extract1DVector(dim_numbers.rhs_contracting_dimensions());
|
||||
if (lhs_bathcing_dims.size() != 1 || lhs_bathcing_dims[0] != 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected lhs batching dimensions exactly {0}");
|
||||
}
|
||||
if (rhs_bathcing_dims.size() != 1 || rhs_bathcing_dims[0] != 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected rhs batching dimensions exactly {0}");
|
||||
}
|
||||
if (lhs_contracting_dims.size() != 1 || lhs_contracting_dims[0] != 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected lhs contracting dimensions exactly {2}");
|
||||
}
|
||||
if (rhs_contracting_dims.size() != 1 || rhs_contracting_dims[0] != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected rhs contracting dimensions exactly {1}");
|
||||
}
|
||||
Location loc = op.getLoc();
|
||||
mhlo::DotGeneralOp::Adaptor adaptor(args);
|
||||
Type result_type = op.getResult().getType();
|
||||
auto shaped_type = result_type.cast<ShapedType>();
|
||||
SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
|
||||
rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type);
|
||||
auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
|
||||
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
|
||||
auto init_tensor = rewriter.create<DynamicTensorFromElementsOp>(
|
||||
loc, result_type, dyn_shape);
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
SmallVector<Type, 4> arg_types(shaped_type.getRank(),
|
||||
rewriter.getIndexType());
|
||||
Region& region = init_tensor.body();
|
||||
Block* block = rewriter.createBlock(®ion, region.begin(), arg_types);
|
||||
rewriter.setInsertionPointToEnd(block);
|
||||
rewriter.create<YieldOp>(loc, zero);
|
||||
}
|
||||
auto linalg_op = rewriter.create<linalg::BatchMatmulOp>(
|
||||
loc, /*resultTensorTypes=*/TypeRange{result_type},
|
||||
/*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
|
||||
/*outputBuffers=*/ValueRange{init_tensor});
|
||||
rewriter.replaceOp(op, linalg_op.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
|
@ -1299,7 +1382,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||
ReverseConverter<mhlo::ReverseOp, false>,
|
||||
TransposeConverter<mhlo::TransposeOp, false>,
|
||||
DotOpOnTensorsConversion>(context);
|
||||
DotOpOnTensorsConversion, DotGeneralOpOnTensorsConversion>(
|
||||
context);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||
|
|
|
@ -871,3 +871,24 @@ func @dot_dot(%arg0: tensor<?xf32>,
|
|||
// CHECK: linalg.dot
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
|
||||
// CHECK-SAME: outs(%[[INIT]] : tensor<f32>)
|
||||
|
||||
// -----
|
||||
|
||||
func @dot_general(%arg0: tensor<?x?x3xf32>,
|
||||
%arg1: tensor<?x3x?xf32>) -> tensor<?x?x?xf32> {
|
||||
%0 ="mhlo.dot_general"(%arg0, %arg1) {
|
||||
dot_dimension_numbers = {
|
||||
lhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
|
||||
rhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||
},
|
||||
precision_config = ["DEFAULT", "DEFAULT"]
|
||||
} : (tensor<?x?x3xf32>, tensor<?x3x?xf32>) -> tensor<?x?x?xf32>
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
}
|
||||
// CHECK: func @dot_general(%[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
|
||||
// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
|
||||
// CHECK: linalg.batch_matmul
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xf32>, tensor<?x3x?xf32>)
|
||||
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?xf32>)
|
||||
|
|
Loading…
Reference in New Issue