Add support for lowering mhlo.iota to Linalg.

PiperOrigin-RevId: 322799853
This commit is contained in:
Hanhan Wang 2020-07-23 16:18:01 +00:00 committed by Mehdi Amini
parent 4251630426
commit 8f262ae8f5
2 changed files with 62 additions and 44 deletions

View File

@ -640,25 +640,25 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
} }
}; };
class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> { template <typename OpTy, bool isLHLO = true>
class IotaConverter : public OpConversionPattern<OpTy> {
public: public:
using OpConversionPattern<lmhlo::IotaOp>::OpConversionPattern; using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
lmhlo::IotaOp iotaOp, ArrayRef<Value> args, OpTy iotaOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto resultMemrefType = ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp);
iotaOp.getOperand().getType().dyn_cast<MemRefType>(); if (!resultShapedType) return failure();
if (!resultMemrefType) return failure();
auto resultElementType = resultMemrefType.getElementType(); auto resultElementType = resultShapedType.getElementType();
if (!resultElementType.isSignlessIntOrFloat()) return failure(); if (!resultElementType.isSignlessIntOrFloat()) return failure();
// Construct the indexing maps needed for linalg.generic ops. // Construct the indexing maps needed for linalg.generic ops.
unsigned nloops = resultMemrefType.getRank(); unsigned nloops = resultShapedType.getRank();
rewriter.create<linalg::IndexedGenericOp>( auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(), ArrayRef<Type>{}, args, iotaOp.getLoc(), isLHLO ? ArrayRef<Type>{} : resultShapedType, args,
0, // args_in 0, // args_in
1, // args_out 1, // args_out
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
@ -669,14 +669,16 @@ class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()], nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
nestedBuilder.getIntegerType( nestedBuilder.getIntegerType(
resultElementType.getIntOrFloatBitWidth())); resultElementType.getIntOrFloatBitWidth()));
if (resultElementType.isa<FloatType>()) { if (resultElementType.template isa<FloatType>()) {
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp, castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
resultElementType); resultElementType);
} }
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp); nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
}); });
if (isLHLO)
rewriter.replaceOp(iotaOp, llvm::None); rewriter.replaceOp(iotaOp, llvm::None);
else
rewriter.replaceOp(iotaOp, linalgOp.output_tensors());
return success(); return success();
} }
}; };
@ -768,7 +770,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>, patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
ConstConverter, ConstConverter,
ConvToLinalgConverter, ConvToLinalgConverter,
IotaConverter, IotaConverter<lmhlo::IotaOp>,
LhloBroadcastInDimConverter, LhloBroadcastInDimConverter,
PointwiseToLinalgConverter<lmhlo::AbsOp>, PointwiseToLinalgConverter<lmhlo::AbsOp>,
PointwiseToLinalgConverter<lmhlo::AddOp>, PointwiseToLinalgConverter<lmhlo::AddOp>,
@ -870,36 +872,37 @@ namespace mhlo {
void populateHLOToLinalgConversionPattern(MLIRContext* context, void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>, patterns
HloBroadcastInDimConverter, ->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
PointwiseToLinalgConverter<mhlo::AbsOp, false>, HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>, PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>, PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::CeilOp, false>, PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<mhlo::CompareOp, false>, PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<mhlo::ComplexOp, false>, PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<mhlo::ConvertOp, false>, PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
PointwiseToLinalgConverter<mhlo::CopyOp, false>, PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
PointwiseToLinalgConverter<mhlo::CosOp, false>, PointwiseToLinalgConverter<mhlo::CopyOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>, PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>, PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>, PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>, PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>, PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>, PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>, PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>, PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>, PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>, PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>, PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::SelectOp, false>, PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>, PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>, PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>, PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>, PointwiseToLinalgConverter<mhlo::SubOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>, PointwiseToLinalgConverter<mhlo::TanhOp, false>,
ReverseConverter<mhlo::ReverseOp, false>, ReshapeOpConverter<mhlo::ReshapeOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context); ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context);
} }
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() { std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {

View File

@ -557,3 +557,18 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
} }
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// -----
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota
func @iota() -> tensor<7x10xf32> {
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>)
return %result : tensor<7x10xf32>
}
// CHECK: linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index):
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32