Upstream mhlo.dot lowering to Linalg to the MHLO repo.

We prototyped the lowering from mhlo.dot to linalg.matmul in IREE. Since Linalg
now supports matmul on tensors, we can move the lowering logic to the tensor
world and upstream it to legalize_to_linalg.cc. Depending on the operand ranks,
the patch lowers mhlo.dot to linalg.matmul, linalg.matvec, or linalg.dot on
tensors.
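
For example (a sketch with illustrative SSA names, matching the tests added
in this patch), a matrix-matrix dot

  %0 = "mhlo.dot"(%lhs, %rhs)
      : (tensor<2x3xf32>, tensor<3x?xf32>) -> tensor<2x?xf32>

now lowers to

  %0 = linalg.matmul ins(%lhs, %rhs : tensor<2x3xf32>, tensor<3x?xf32>)
                     outs(%init : tensor<2x?xf32>) -> tensor<2x?xf32>

where %init is a zero-filled init tensor created with
dynamic_tensor_from_elements.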

PiperOrigin-RevId: 351184911
Hanhan Wang 2021-01-11 10:33:14 -08:00 committed by TensorFlow MLIR Team
parent 180f917446
commit 8f58f844e5
2 changed files with 160 additions and 1 deletion

legalize_to_linalg.cc:

@@ -1014,6 +1014,123 @@ class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
  }
};
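
// mhlo.dot supports three operand-rank combinations, classified by the enum
// below. Illustrative signatures (shapes borrowed from the tests added in
// this patch):
//   kVectorDot:    (tensor<?xf32>,   tensor<?xf32>)   -> tensor<f32>
//   kMatrixVector: (tensor<?x3xf32>, tensor<3xf32>)   -> tensor<?xf32>
//   kMatrixMatrix: (tensor<2x3xf32>, tensor<3x?xf32>) -> tensor<2x?xf32>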
enum class DotOperationType {
  kVectorDot = 0,
  kMatrixVector = 1,
  kMatrixMatrix = 2,
  kUnsupported = 3
};

DotOperationType GetDotOperationType(mhlo::DotOp dot_op) {
  ArrayRef<int64_t> lhs_shape =
      dot_op.lhs().getType().cast<ShapedType>().getShape();
  ArrayRef<int64_t> rhs_shape =
      dot_op.rhs().getType().cast<ShapedType>().getShape();
  auto shape_matches = [](int64_t a, int64_t b) {
    return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize ||
           a == b;
  };
  if (lhs_shape.size() == 1 && rhs_shape.size() == 1 &&
      shape_matches(lhs_shape[0], rhs_shape[0])) {
    return DotOperationType::kVectorDot;
  }
  if (lhs_shape.size() == 2 && rhs_shape.size() == 1 &&
      shape_matches(lhs_shape[1], rhs_shape[0])) {
    return DotOperationType::kMatrixVector;
  }
  if (lhs_shape.size() == 2 && rhs_shape.size() == 2 &&
      shape_matches(lhs_shape[1], rhs_shape[0])) {
    return DotOperationType::kMatrixMatrix;
  }
  return DotOperationType::kUnsupported;
}
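
// Emits one dim query per dynamic dimension of the result. For example (a
// sketch with illustrative SSA names), a kMatrixMatrix result of type
// tensor<2x?xf32> is dynamic only in dimension 1, whose size comes from the
// rhs:
//   %c1 = constant 1 : index
//   %d1 = dim %rhs, %c1 : tensor<3x?xf32>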
SmallVector<Value, 8> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
                                                 Value lhs, Value rhs,
                                                 ShapedType result_type,
                                                 DotOperationType type) {
  SmallVector<Value, 8> dyn_shape;
  switch (type) {
    case DotOperationType::kMatrixMatrix: {
      if (result_type.isDynamicDim(0))
        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
      if (result_type.isDynamicDim(1))
        dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
      break;
    }
    case DotOperationType::kMatrixVector: {
      if (result_type.isDynamicDim(0))
        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
      break;
    }
    case DotOperationType::kVectorDot:
    case DotOperationType::kUnsupported:
    default: {
      break;
    }
  }
  return dyn_shape;
}
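
// Lowers mhlo.dot to linalg.dot/matvec/matmul on tensors. The init tensor is
// created with dynamic_tensor_from_elements and zero-filled so the Linalg op
// can accumulate into it. Sketch of the emitted IR (illustrative names):
//   %init = dynamic_tensor_from_elements %d1 {
//   ^bb0(%i: index, %j: index):
//     yield %zero : f32
//   } : tensor<2x?xf32>
//   %0 = linalg.matmul ins(%lhs, %rhs : tensor<2x3xf32>, tensor<3x?xf32>)
//                      outs(%init : tensor<2x?xf32>) -> tensor<2x?xf32>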
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
 public:
  using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::DotOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
      return failure();
    }
    Location loc = op.getLoc();
    mhlo::DotOp::Adaptor adaptor(args);
    Type result_type = op.getResult().getType();
    auto shaped_type = result_type.cast<ShapedType>();
    DotOperationType op_type = GetDotOperationType(op);
    SmallVector<Value, 8> dyn_shape = GetDotOpInitTensorDynSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type, op_type);
    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
    auto init_tensor = rewriter.create<DynamicTensorFromElementsOp>(
        loc, result_type, dyn_shape);
    {
      // Zero-fill the init tensor: its body region yields the scalar zero
      // for every index.
      OpBuilder::InsertionGuard guard(rewriter);
      SmallVector<Type, 4> arg_types(shaped_type.getRank(),
                                     rewriter.getIndexType());
      Region& region = init_tensor.body();
      Block* block = rewriter.createBlock(&region, region.begin(), arg_types);
      rewriter.setInsertionPointToEnd(block);
      rewriter.create<YieldOp>(loc, zero);
    }
    linalg::LinalgOp linalg_op;
    switch (op_type) {
      case DotOperationType::kMatrixMatrix: {
        linalg_op = rewriter.create<linalg::MatmulOp>(
            loc, TypeRange{result_type},
            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
        break;
      }
      case DotOperationType::kMatrixVector: {
        linalg_op = rewriter.create<linalg::MatvecOp>(
            loc, TypeRange{result_type},
            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
        break;
      }
      case DotOperationType::kVectorDot: {
        linalg_op = rewriter.create<linalg::DotOp>(
            loc, TypeRange{result_type},
            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{init_tensor});
        break;
      }
      case DotOperationType::kUnsupported:
      default: {
        return op.emitError("unsupported dot operation type");
      }
    }
    rewriter.replaceOp(op, linalg_op->getResults());
    return success();
  }
};
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
  // clang-format off

@@ -1181,7 +1298,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
      PointwiseToLinalgConverter<mhlo::XorOp, false>,
      ReshapeOpConverter<mhlo::ReshapeOp, false>,
      ReverseConverter<mhlo::ReverseOp, false>,
      TransposeConverter<mhlo::TransposeOp, false>,
      DotOpOnTensorsConversion>(context);
}

std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {

Tests:

@@ -830,3 +830,44 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor<?xf32> {
// CHECK-SAME: ins([[CST]] : tensor<f32>) outs([[INIT]] : tensor<?xf32>)
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32
// -----

func @dot_matmul(%arg0: tensor<2x3xf32>,
                 %arg1: tensor<3x?xf32>) -> tensor<2x?xf32> {
  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>,
                                   tensor<3x?xf32>) -> tensor<2x?xf32>
  return %0 : tensor<2x?xf32>
}
// CHECK: func @dot_matmul(%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<2x?xf32>)

// -----

func @dot_matvec(%arg0: tensor<?x3xf32>,
                 %arg1: tensor<3xf32>) -> tensor<?xf32> {
  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?x3xf32>,
                                   tensor<3xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
// CHECK: func @dot_matvec(%[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
// CHECK: linalg.matvec
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x3xf32>, tensor<3xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?xf32>)

// -----

func @dot_dot(%arg0: tensor<?xf32>,
              %arg1: tensor<?xf32>) -> tensor<f32> {
  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<f32>
  return %0 : tensor<f32>
}
// CHECK: func @dot_dot(%[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>)
// CHECK: %[[INIT:.*]] = dynamic_tensor_from_elements
// CHECK: linalg.dot
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<f32>)