/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. #include "third_party/absl/memory/memory.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineExpr.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { namespace { SmallVector GetNParallelLoopsAttrs(unsigned nParallelLoops) { static constexpr StringRef kParallelIterType = "parallel"; return SmallVector(nParallelLoops, kParallelIterType); } template Value getResultValue(Operation* op) { return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0); } template ShapedType getHloOpResultType(Operation* op) { return getResultValue(op).getType().template cast(); } template bool verifyHloOpBufferOrTensorSemantics(Operation* op) { auto verifyType = [&](Value val) -> bool { return (isLHLO && val.getType().isa()) || (!isLHLO && val.getType().isa()); }; if (!llvm::all_of(op->getOperands(), verifyType)) return false; return isLHLO ? op->getResults().empty() : llvm::all_of(op->getResults(), verifyType); } template class PointwiseToLinalgConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( OpTy op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); auto argType = op.getOperation()->getOperand(0).getType().template cast(); if (!argType.hasRank()) { emitError(loc, "lhlo to linalg conversion expects ranked args"); return failure(); } auto elemTy = argType.getElementType(); if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa()) { return failure(); } // Construct the indexing maps needed for linalg.generic ops. SmallVector indexing_maps; SmallVector bodyArgTypes, bodyResultTypes, opResultTypes; // This doesnt account for implicit broadcast, but the working assumption // here is that are broadcasts have been made explicit. unsigned nloops = argType.getRank(); if (isLHLO && !nloops) return failure(); int operandCount = (isLHLO ? args.size() - 1 : args.size()); auto verifyArgOrResultType = [&](Value val) -> ShapedType { auto shapedType = val.getType().dyn_cast(); if (!shapedType || (!shapedType.isa() && !shapedType.isa()) || shapedType.getRank() != nloops) return nullptr; indexing_maps.emplace_back( nloops ? rewriter.getMultiDimIdentityMap(nloops) : AffineMap::get(nloops, 0, rewriter.getContext())); return shapedType; }; for (const auto& arg : llvm::enumerate(args)) { auto shapedType = verifyArgOrResultType(arg.value()); if (!shapedType) return failure(); auto& result_or_body_arg = arg.index() < operandCount ? bodyArgTypes : bodyResultTypes; result_or_body_arg.emplace_back(shapedType.getElementType()); } if (!isLHLO) { // HLO operations have return as tensor types. assert(bodyResultTypes.empty() && "When lowering HLO ops result can't be part of arguments"); Value result = op.getOperation()->getResult(0); auto shapedType = verifyArgOrResultType(result); if (!shapedType) return failure(); bodyResultTypes.push_back(shapedType.getElementType()); opResultTypes.push_back(shapedType); } int64_t args_count = bodyArgTypes.size(); int64_t results_count = bodyResultTypes.size(); auto linalgOp = rewriter.create( loc, opResultTypes, args, args_count, results_count, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { // TODO(ravishankarm) : For now use the method in lmhlo namespace. // That method needs to be moved out of there. Value opResult = lmhlo::HloOpToStdScalarOp::map( op, bodyResultTypes, llvm::to_vector<2>(args.take_front(args_count)), &rewriter); nestedBuilder.create(loc, opResult); }); rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); return success(); } }; template class ScalarPointwiseToStandardConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( LhloOp lhlo_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = lhlo_op.getLoc(); auto argType = lhlo_op.getOperand(0).getType().template dyn_cast(); if (!argType || !argType.getElementType().isSignlessIntOrFloat() || (argType.getRank() != 0)) { return failure(); } // Create two loads from the input. auto lhs = rewriter.create(loc, lhlo_op.lhs()); auto rhs = rewriter.create(loc, lhlo_op.rhs()); // TODO(ravishankarm) : Move this method out of lmhlo namespace. Value opResult = lmhlo::HloOpToStdScalarOp::map( lhlo_op, argType.getElementType(), llvm::ArrayRef{lhs, rhs}, &rewriter); rewriter.create(loc, opResult, lhlo_op.out()); rewriter.eraseOp(lhlo_op); return success(); } }; //===----------------------------------------------------------------------===// // lmhlo.convolution conversion pattern. //===----------------------------------------------------------------------===// /// Converts lmhlo.convolution operation to a linalg.conv op. struct ConvToLinalgConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; // This code has been adapted from IREE's // (https://github.com/google/iree/) mhlo -> linalg conversion. LogicalResult matchAndRewrite( lmhlo::ConvOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { // Check validity of dimension information. if (const lmhlo::ConvDimensionNumbers& dimensionNumbers = op.dimension_numbers()) { const int inputSpatialRank = llvm::size(dimensionNumbers.input_spatial_dimensions()); // The dimensions for input should follow the order of // batch_count, spatial_dims..., input_feature_count. if (dimensionNumbers.input_batch_dimension().getInt() != 0 || dimensionNumbers.input_feature_dimension().getInt() != (inputSpatialRank + 1)) return failure(); const int kernelSpatialRank = llvm::size(dimensionNumbers.kernel_spatial_dimensions()); // The dimensions for filter should follow the order of // spatial_dims..., input_feature_count, num_output_feature_count. if (dimensionNumbers.kernel_input_feature_dimension().getInt() != kernelSpatialRank || dimensionNumbers.kernel_output_feature_dimension().getInt() != (kernelSpatialRank + 1)) return failure(); const int outputSpatialRank = llvm::size(dimensionNumbers.output_spatial_dimensions()); // The dimensions for output should follow the order of // batch_count, spatial_dims.., output_feature_count. if (dimensionNumbers.output_batch_dimension().getInt() != 0 || dimensionNumbers.output_feature_dimension().getInt() != (outputSpatialRank + 1)) return failure(); if (inputSpatialRank != outputSpatialRank || inputSpatialRank != kernelSpatialRank) return failure(); auto inputSpatialDim = dimensionNumbers.input_spatial_dimensions().begin(); auto kernelSpatialDim = dimensionNumbers.kernel_spatial_dimensions().begin(); auto outputSpatialDim = dimensionNumbers.output_spatial_dimensions().begin(); // Check if spatial dims are ordered correctly. for (int i = 0; i < inputSpatialRank; ++i) { const int dim = i + 1; if ((*inputSpatialDim++).getZExtValue() != dim || (*outputSpatialDim++).getZExtValue() != dim || (*kernelSpatialDim++).getZExtValue() != i) return failure(); } } // TODO: LHS dilation for deconvolution not supported yet. if (op.lhs_dilation()) { return failure(); } llvm::SmallVector strides; if (auto windowStrides = op.window_strides()) { auto range = windowStrides->getAttributeValues(); strides.assign(range.begin(), range.end()); } auto stridesArg = ArrayAttr::get(strides, op.getContext()); llvm::SmallVector dilation; if (auto rhsDilation = op.rhs_dilation()) { auto range = rhsDilation->getAttributeValues(); dilation.assign(range.begin(), range.end()); } else { // Default dilation of 1. dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1)); } auto dilationArg = ArrayAttr::get(dilation, op.getContext()); // Set padding only if it is non-zero. DenseIntElementsAttr padding = op.paddingAttr(); if (!padding || !llvm::any_of(padding.getValues(), [](APInt intVal) { return !intVal.isNullValue(); })) { padding = nullptr; } // The order of input and filter are switched with linalg.conv. rewriter.replaceOpWithNewOp( op, args[1], args[0], args[2], stridesArg, dilationArg, padding); return success(); } }; /// Base class for lowering HLO operations that have one operand and one result, /// and are semantically equivalent to a copy of the input to the output (like /// transpose, some reshape, etc.). The derived classes need to provide a method /// `getIndexingMaps` that returns AffineMaps for the index maps of the input /// and the output. template class DataMovementOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( OpTy op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { if (!verifyHloOpBufferOrTensorSemantics(op)) return failure(); auto resultType = getHloOpResultType(op); SmallVector indexing_maps = Derived::getIndexingMaps(op, &rewriter); if (indexing_maps.empty()) return failure(); auto nloops = resultType.getRank(); auto loc = op.getLoc(); auto linalgOp = rewriter.create( loc, isLHLO ? ArrayRef{} : resultType, args, /*argsIn=*/1, /*argsOut=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(loc, *args.begin()); }); rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); return success(); } }; /// Pattern to convert BroadcastOp to Linalg ops. template class BroadcastConverter : public DataMovementOpConverter, OpTy, isLHLO> { public: using DataMovementOpConverter::DataMovementOpConverter; static SmallVector getIndexingMaps(OpTy broadcastOp, Builder* b) { ShapedType inputType = broadcastOp.operand().getType().template cast(); unsigned inputRank = inputType.getRank(); unsigned nloops = getHloOpResultType(broadcastOp).getRank(); // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to // the input's dimensions. unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes()); SmallVector inputDimExprs; inputDimExprs.reserve(inputRank); for (int i = 0; i < inputRank; ++i) { inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i)); } AffineMap inputMap; MLIRContext* context = b->getContext(); if (inputDimExprs.empty()) { // The input is a scalar, i.e. this is a scalar broadcast op. inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context); } else { inputMap = AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context); } return {inputMap, b->getMultiDimIdentityMap(nloops)}; } }; class HloBroadcastInDimConverter : public DataMovementOpConverter { public: using DataMovementOpConverter::DataMovementOpConverter; static SmallVector getIndexingMaps( mhlo::BroadcastInDimOp broadcastOp, Builder* b) { auto resultType = getHloOpResultType(broadcastOp); auto operandType = broadcastOp.operand().getType().template cast(); unsigned nloops = resultType.getRank(); // The input is a scalar, i.e. this is a scalar broadcast op. if (operandType.getRank() == 0) { return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), b->getMultiDimIdentityMap(nloops)}; } auto operandShape = operandType.getShape(); SmallVector dimExprs; dimExprs.reserve(nloops); if (broadcastOp.broadcast_dimensions()) { for (const auto& broadcastDim : enumerate(broadcastOp.broadcast_dimensions().getIntValues())) { int size = broadcastDim.value().getSExtValue(); bool expansion_needed = operandShape[broadcastDim.index()] == 1 && resultType.getShape()[size] != 1; dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0) : b->getAffineDimExpr(size)); } } return { AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), b->getMultiDimIdentityMap(nloops)}; } }; class LhloBroadcastInDimConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( lmhlo::BroadcastInDimOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); auto result_type = operand_adaptor.output().getType().cast(); auto result_shape = result_type.getShape(); auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter); Value operand = std::get<0>(operand_and_dims); auto broadcast_dims = std::get<1>(operand_and_dims); auto loc = op.getLoc(); auto nloops = result_type.getRank(); auto operand_type = operand.getType().cast(); // For a degenerate case, i.e. broadcasting with expansion of // memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`. // Instead the value is loaded and used directly in `linalg.yield`. if (operand_type.getRank() == 1 && operand_type.getDimSize(0) < result_type.getDimSize(broadcast_dims.front())) { Value zero = rewriter.create(loc, 0); Value val = rewriter.create(loc, operand, llvm::makeArrayRef({zero})); rewriter.create( loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()), /*argsIn=*/0, /*argsOut=*/1, llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(loc, val); }); } else { auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape, operand_type, &rewriter); rewriter.create( loc, llvm::None, llvm::makeArrayRef({operand, operand_adaptor.output()}), /*argsIn=*/1, /*argsOut=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(loc, *args.begin()); }); } rewriter.replaceOp(op, llvm::None); return success(); } // Inserts 'linalg.reshape' if there is a size-1 dim expansion. std::pair> InsertReshapeIfNecessary( lmhlo::BroadcastInDimOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const { lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); Value operand = operand_adaptor.operand(); auto operand_type = operand_adaptor.operand().getType().cast(); auto operand_shape = operand_type.getShape(); Value result = operand_adaptor.output(); auto result_type = result.getType().cast(); auto result_shape = result_type.getShape(); SmallVector operand_strides; int64_t operand_offset; if (failed(getStridesAndOffset(operand_type, operand_strides, operand_offset))) { op.emitOpError() << "Failed to get offset and strides."; } SmallVector new_shape, new_strides, broadcast_dims; SmallVector collapsed_dims_list; linalg::ReassociationIndices collapsed_dims; for (const auto& item : enumerate(op.broadcast_dimensions().getIntValues())) { size_t index = item.index(); int dim = item.value().getSExtValue(); collapsed_dims.push_back(index); bool expansion_needed = operand_shape[index] == 1 && result_shape[dim] != 1; if (expansion_needed) { continue; } new_shape.push_back(operand_shape[index]); new_strides.push_back(operand_strides[index]); broadcast_dims.push_back(dim); collapsed_dims_list.push_back(collapsed_dims); collapsed_dims.clear(); } // If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1] // and all dimensions need expansion. Such memref will be reshaped to a 1D // memref with a single element. New shape and strides needs to be updated // accordingly. if (collapsed_dims_list.empty()) { collapsed_dims_list.push_back({}); new_shape.push_back(1); new_strides.push_back(1); broadcast_dims.push_back(0); } for (const auto& dims : collapsed_dims) { collapsed_dims_list.back().push_back(dims); } // `linalg.reshape` is inserted only if necessary, i.e. when the rank can be // reduced. if (new_shape.size() < operand_shape.size()) { auto new_memref_type = MemRefType::get( new_shape, operand_type.getElementType(), makeStridedLinearLayoutMap(new_strides, operand_offset, rewriter.getContext())); operand = rewriter.create(op.getLoc(), new_memref_type, operand_adaptor.operand(), collapsed_dims_list); } return std::make_pair(operand, broadcast_dims); } SmallVector getIndexingMaps(lmhlo::BroadcastInDimOp op, ArrayRef broadcastDims, ArrayRef resultShape, MemRefType operandType, Builder* b) const { unsigned nloops = resultShape.size(); // The input is a scalar, i.e. this is a scalar broadcast op. if (operandType.getRank() == 0) { return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), b->getMultiDimIdentityMap(nloops)}; } auto operandShape = operandType.getShape(); SmallVector dimExprs; dimExprs.reserve(nloops); for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) { int size = broadcastDim.value(); bool expansion_needed = operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1; if (expansion_needed) { op.emitOpError( "BroadcastInDimOp lowering to Linalg does not support size-1 " "dimensions expansion."); } dimExprs.push_back(b->getAffineDimExpr(size)); } return { AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), b->getMultiDimIdentityMap(nloops)}; } }; template class TransposeConverter : public DataMovementOpConverter, OpTy, isLHLO> { public: using DataMovementOpConverter, OpTy, isLHLO>::DataMovementOpConverter; static SmallVector getIndexingMaps(OpTy op, Builder* b) { auto resultType = getHloOpResultType(op).template cast(); auto nloops = resultType.getRank(); SmallVector inputExprs; inputExprs.resize(resultType.getRank()); for (auto permutation : llvm::enumerate(op.permutation())) { inputExprs[permutation.value().getZExtValue()] = b->getAffineDimExpr(permutation.index()); } return { AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), b->getMultiDimIdentityMap(nloops)}; } }; // Converts reshape ops that can be proven to be either a collapse of dimensions // or expansion of dimensions of the operand. template class ReshapeOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( OpTy reshapeOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { if (!verifyHloOpBufferOrTensorSemantics(reshapeOp)) return failure(); ShapedType operandType = reshapeOp.operand().getType().template cast(); ShapedType resultType = getHloOpResultType(reshapeOp); if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) return failure(); // Compute the reassociation maps for the linalg operation. ArrayRef srcShape = (operandType.getRank() > resultType.getRank() ? operandType.getShape() : resultType.getShape()); ArrayRef dstShape = (operandType.getRank() > resultType.getRank() ? resultType.getShape() : operandType.getShape()); unsigned currSrcDim = 0, currDstDim = 0; SmallVector reassociationMap( dstShape.size()); while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) { int64_t dstSize = dstShape[currDstDim]; int64_t srcSize = srcShape[currSrcDim]; while (srcSize < dstSize && currSrcDim < srcShape.size()) { reassociationMap[currDstDim].push_back( rewriter.getAffineDimExpr(currSrcDim++)); srcSize *= srcShape[currSrcDim]; } if (srcSize == dstSize) { reassociationMap[currDstDim].push_back( rewriter.getAffineDimExpr(currSrcDim++)); // If the next dim in dstShape is not 1, treat subsequent dims in // srcShape which are 1 to be collapsed. if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) { while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) { reassociationMap[currDstDim].push_back( rewriter.getAffineDimExpr(currSrcDim++)); } } } else { return failure(); } currDstDim++; } if (currSrcDim != srcShape.size()) return failure(); if (isLHLO) { Value reshapeBuffer = rewriter.create( reshapeOp.getLoc(), resultType, args[0], reassociationMap); rewriter.replaceOpWithNewOp( reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr, /*outputPermutation =*/nullptr); } else { rewriter.replaceOpWithNewOp( reshapeOp, resultType, args[0], reassociationMap); } return success(); } }; template class IotaConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( OpTy iotaOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { ShapedType resultShapedType = getHloOpResultType(iotaOp); if (!resultShapedType) return failure(); auto resultElementType = resultShapedType.getElementType(); if (!resultElementType.isSignlessIntOrFloat()) return failure(); // Construct the indexing maps needed for linalg.generic ops. unsigned nloops = resultShapedType.getRank(); auto linalgOp = rewriter.create( iotaOp.getLoc(), isLHLO ? ArrayRef{} : resultShapedType, args, 0, // args_in 1, // args_out llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs, ValueRange args) { Value castOp = nestedBuilder.create( nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()], nestedBuilder.getIntegerType( resultElementType.getIntOrFloatBitWidth())); if (resultElementType.template isa()) { castOp = nestedBuilder.create(nestedLoc, castOp, resultElementType); } nestedBuilder.create(nestedLoc, castOp); }); if (isLHLO) rewriter.replaceOp(iotaOp, llvm::None); else rewriter.replaceOp(iotaOp, linalgOp.output_tensors()); return success(); } }; class ConstConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( lmhlo::ConstOp constOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = constOp.getLoc(); auto valueAttr = constOp.value().cast(); if (valueAttr.getType().getRank() != 0) return failure(); auto stdConstOp = rewriter.create(loc, valueAttr.getValue({})); rewriter.create(loc, stdConstOp, constOp.getOperand(), ValueRange()); rewriter.eraseOp(constOp); return success(); } }; // TODO(b/156787842): Support the lowering for dynamic shapes. template class ReverseConverter : public DataMovementOpConverter, OpTy, isLHLO> { public: using DataMovementOpConverter, OpTy, isLHLO>::DataMovementOpConverter; static SmallVector getIndexingMaps(OpTy op, Builder* b) { auto resultType = getHloOpResultType(op).template cast(); auto nloops = resultType.getRank(); SmallVector inputExprs; inputExprs.reserve(nloops); for (int i = 0; i < nloops; ++i) inputExprs.push_back(b->getAffineDimExpr(i)); for (auto dim : op.dimensions()) { int i = dim.getZExtValue(); if (resultType.isDynamicDim(i)) return {}; int n = resultType.getShape()[i]; inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i]; } return { AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), b->getMultiDimIdentityMap(nloops)}; } }; class SliceConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( lmhlo::SliceOp sliceOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = sliceOp.getLoc(); auto argType = sliceOp.getOperand(0).getType().template dyn_cast(); if (!argType || !argType.hasRank()) { emitError(loc, "lhlo to linalg conversion expects known-rank args"); return failure(); } SmallVector ranges; for (int i = 0, e = argType.getRank(); i < e; ++i) { Value start_index = rewriter.create( loc, sliceOp.start_indices().getValue(i)); Value limit_index = rewriter.create( loc, sliceOp.limit_indices().getValue(i)); Value stride = rewriter.create( loc, sliceOp.strides().getValue(i)); ranges.push_back(rewriter.create(loc, start_index, limit_index, stride)); } auto linalg_slice = rewriter.create(loc, sliceOp.getOperand(0), ranges); rewriter.create(loc, linalg_slice, sliceOp.getOperand(1)); rewriter.eraseOp(sliceOp); return success(); } }; void populateLHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off patterns->insert, ConstConverter, ConvToLinalgConverter, IotaConverter, LhloBroadcastInDimConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, // TODO(ataei): Remove this pattern, CopyOp is folded away. PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, ReshapeOpConverter, ReverseConverter, ScalarPointwiseToStandardConverter, SliceConverter >(context); // clang-format on } // Converts LHLO ops to Linalg generic. // Sample result for lmhlo::AddOp. // // "lmhlo.add"(%arg1, %arg2, %out) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // // will be converted to // // #map0 = (d0, d1) -> (d0, d1) // "linalg.generic"(%arg1, %arg2, %out) ( { // ^bb0(%arg4: f32, %arg5: f32): // %0 = addf %arg4, %arg5 : f32 // "linalg.yield"(%0) : (f32) -> () // }) { // args_in = 2, // args_out = 1, // indexing_maps = [#map0, #map0, #map0], // iterator_types = ["parallel", "parallel"], // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () struct LhloLegalizeToLinalg : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addLegalDialect(); auto func = getFunction(); populateLHLOToLinalgConversionPattern(func.getContext(), &patterns); if (failed(applyPartialConversion(func, target, patterns, nullptr))) { signalPassFailure(); } } }; struct HloLegalizeToLinalg : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addLegalDialect(); auto func = getFunction(); mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); if (failed(applyPartialConversion(func, target, patterns, nullptr))) { signalPassFailure(); } } }; } // namespace namespace lmhlo { std::unique_ptr> createLegalizeLhloToLinalgPass() { return absl::make_unique(); } static PassRegistration legalize_lhlo_pass( "lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect"); } // namespace lmhlo namespace mhlo { void populateHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { patterns ->insert, HloBroadcastInDimConverter, IotaConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, ReshapeOpConverter, ReverseConverter, TransposeConverter>(context); } std::unique_ptr> createLegalizeHloToLinalgPass() { return absl::make_unique(); } static PassRegistration legalize_hlo_pass( "hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect"); } // namespace mhlo } // namespace mlir