Upstream mhlo.dot_general lowering to Linalg to MHLO repo
PiperOrigin-RevId: 351514250
This commit is contained in:
parent
97a618f91a
commit
300a7c11ce
|
@ -99,6 +99,14 @@ SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
|
||||||
return dyn_sizes;
|
return dyn_sizes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) {
|
||||||
|
SmallVector<int64_t, 4> ret;
|
||||||
|
for (const APInt& element : elements) {
|
||||||
|
ret.push_back(element.getLimitedValue());
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename OpTy, bool isLHLO = true>
|
template <typename OpTy, bool isLHLO = true>
|
||||||
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
||||||
public:
|
public:
|
||||||
|
@ -1131,6 +1139,81 @@ class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
|
||||||
|
OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
|
||||||
|
SmallVector<Value, 8> dyn_shape;
|
||||||
|
if (result_type.isDynamicDim(0))
|
||||||
|
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
|
||||||
|
if (result_type.isDynamicDim(1))
|
||||||
|
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 1));
|
||||||
|
if (result_type.isDynamicDim(2))
|
||||||
|
dyn_shape.push_back(b.create<DimOp>(loc, rhs, 2));
|
||||||
|
return dyn_shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
class DotGeneralOpOnTensorsConversion
|
||||||
|
: public OpConversionPattern<mhlo::DotGeneralOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mhlo::DotGeneralOp op, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers();
|
||||||
|
auto lhs_bathcing_dims =
|
||||||
|
Extract1DVector(dim_numbers.lhs_batching_dimensions());
|
||||||
|
auto rhs_bathcing_dims =
|
||||||
|
Extract1DVector(dim_numbers.rhs_batching_dimensions());
|
||||||
|
auto lhs_contracting_dims =
|
||||||
|
Extract1DVector(dim_numbers.lhs_contracting_dimensions());
|
||||||
|
auto rhs_contracting_dims =
|
||||||
|
Extract1DVector(dim_numbers.rhs_contracting_dimensions());
|
||||||
|
if (lhs_bathcing_dims.size() != 1 || lhs_bathcing_dims[0] != 0) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "expected lhs batching dimensions exactly {0}");
|
||||||
|
}
|
||||||
|
if (rhs_bathcing_dims.size() != 1 || rhs_bathcing_dims[0] != 0) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "expected rhs batching dimensions exactly {0}");
|
||||||
|
}
|
||||||
|
if (lhs_contracting_dims.size() != 1 || lhs_contracting_dims[0] != 2) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "expected lhs contracting dimensions exactly {2}");
|
||||||
|
}
|
||||||
|
if (rhs_contracting_dims.size() != 1 || rhs_contracting_dims[0] != 1) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "expected rhs contracting dimensions exactly {1}");
|
||||||
|
}
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
mhlo::DotGeneralOp::Adaptor adaptor(args);
|
||||||
|
Type result_type = op.getResult().getType();
|
||||||
|
auto shaped_type = result_type.cast<ShapedType>();
|
||||||
|
SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
|
||||||
|
rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type);
|
||||||
|
auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
|
||||||
|
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
|
||||||
|
auto init_tensor = rewriter.create<DynamicTensorFromElementsOp>(
|
||||||
|
loc, result_type, dyn_shape);
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
SmallVector<Type, 4> arg_types(shaped_type.getRank(),
|
||||||
|
rewriter.getIndexType());
|
||||||
|
Region& region = init_tensor.body();
|
||||||
|
Block* block = rewriter.createBlock(®ion, region.begin(), arg_types);
|
||||||
|
rewriter.setInsertionPointToEnd(block);
|
||||||
|
rewriter.create<YieldOp>(loc, zero);
|
||||||
|
}
|
||||||
|
auto linalg_op = rewriter.create<linalg::BatchMatmulOp>(
|
||||||
|
loc, /*resultTensorTypes=*/TypeRange{result_type},
|
||||||
|
/*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
|
||||||
|
/*outputBuffers=*/ValueRange{init_tensor});
|
||||||
|
rewriter.replaceOp(op, linalg_op.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
@ -1299,7 +1382,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||||
ReverseConverter<mhlo::ReverseOp, false>,
|
ReverseConverter<mhlo::ReverseOp, false>,
|
||||||
TransposeConverter<mhlo::TransposeOp, false>,
|
TransposeConverter<mhlo::TransposeOp, false>,
|
||||||
DotOpOnTensorsConversion>(context);
|
DotOpOnTensorsConversion, DotGeneralOpOnTensorsConversion>(
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||||
|
|
|
@ -871,3 +871,24 @@ func @dot_dot(%arg0: tensor<?xf32>,
|
||||||
// CHECK: linalg.dot
|
// CHECK: linalg.dot
|
||||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
|
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
|
||||||
// CHECK-SAME: outs(%[[INIT]] : tensor<f32>)
|
// CHECK-SAME: outs(%[[INIT]] : tensor<f32>)
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @dot_general(%arg0: tensor<?x?x3xf32>,
|
||||||
|
%arg1: tensor<?x3x?xf32>) -> tensor<?x?x?xf32> {
|
||||||
|
%0 ="mhlo.dot_general"(%arg0, %arg1) {
|
||||||
|
dot_dimension_numbers = {
|
||||||
|
lhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||||
|
lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
|
||||||
|
rhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||||
|
rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||||
|
},
|
||||||
|
precision_config = ["DEFAULT", "DEFAULT"]
|
||||||
|
} : (tensor<?x?x3xf32>, tensor<?x3x?xf32>) -> tensor<?x?x?xf32>
|
||||||
|
return %0 : tensor<?x?x?xf32>
|
||||||
|
}
|
||||||
|
// CHECK: func @dot_general(%[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
|
||||||
|
// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
|
||||||
|
// CHECK: linalg.batch_matmul
|
||||||
|
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x3xf32>, tensor<?x3x?xf32>)
|
||||||
|
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?xf32>)
|
||||||
|
|
Loading…
Reference in New Issue