/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. #include #include "llvm/ADT/STLExtras.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace { SmallVector GetNParallelLoopsAttrs(unsigned nParallelLoops) { static constexpr StringRef kParallelIterType = "parallel"; return SmallVector(nParallelLoops, kParallelIterType); } template Value GetResultValue(Operation* op) { return isLHLO ? 
op->getOperand(op->getNumOperands() - 1) : op->getResult(0); } template ShapedType GetHloOpResultType(Operation* op) { return GetResultValue(op).getType().template cast(); } template bool VerifyHloOpBufferOrTensorSemantics(Operation* op) { auto verify_type = [&](Value val) -> bool { return (isLHLO && val.getType().isa()) || (!isLHLO && val.getType().isa()); }; if (!llvm::all_of(op->getOperands(), verify_type)) return false; return isLHLO ? op->getResults().empty() : llvm::all_of(op->getResults(), verify_type); } template class PointwiseToLinalgConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( OpTy op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); ShapedType t0 = args[0].getType().template dyn_cast(); if (!t0) return failure(); unsigned nloops = t0.getRank(); auto fail = [&](ShapedType t) { return !t || !t.hasRank() || t.getRank() != nloops || !(t.getElementType().isSignlessIntOrFloat() || t.getElementType().isa()); }; if (llvm::any_of(args, [&](Value v) { return fail(v.getType().dyn_cast()); }) || llvm::any_of(op.getOperation()->getResultTypes(), [&](Type t) { return fail(t.dyn_cast()); })) return emitError(loc, "lhlo to linalg conversion expects ranked args of " "signless int, float or complex element type with ") << nloops << " parallel iterators: " << *(op.getOperation()); // Construct the indexing maps needed for linalg.generic ops. SmallVector body_arg_types, body_result_types, op_result_types; // This doesnt account for implicit broadcast, but the working assumption // in HLO/LHLO is that are broadcasts are made explicit. if (isLHLO && !nloops) return failure(); int num_inputs = (isLHLO ? 
args.size() - 1 : args.size()); ValueRange inputs(args.take_front(num_inputs)); for (Value in : inputs) body_arg_types.emplace_back(getElementTypeOrSelf(in.getType())); ValueRange output_buffers(args.take_back(args.size() - num_inputs)); for (Value out : output_buffers) body_result_types.emplace_back(getElementTypeOrSelf(out.getType())); if (!isLHLO) { // HLO operations have return as tensor types. assert(body_result_types.empty() && "When lowering HLO ops result can't be part of arguments"); Value result = op.getOperation()->getResult(0); body_result_types.push_back(getElementTypeOrSelf(result)); op_result_types.push_back(result.getType()); } AffineMap common_indexing_map = nloops ? rewriter.getMultiDimIdentityMap(nloops) : AffineMap::get(nloops, 0, rewriter.getContext()); SmallVector indexing_maps(args.size() + (isLHLO ? 0 : 1), common_indexing_map); bool failed = false; auto linalg_op = rewriter.create( loc, op_result_types, inputs, output_buffers, /*initTensors=*/ValueRange{}, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) { // TODO(ravishankarm) : For now use the method in lmhlo namespace. // That method needs to be moved out of there. 
Value op_result = lmhlo::HloOpToStdScalarOp::map( op, body_result_types, llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter); if (op_result == nullptr) { failed = true; } else { nested_builder.create(loc, op_result); } }); if (failed) return failure(); rewriter.replaceOp(op, linalg_op.getOperation()->getResults()); return success(); } }; template class ScalarPointwiseToStandardConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( LhloOp lhlo_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = lhlo_op.getLoc(); auto arg_type = lhlo_op.getOperand(0).getType().template dyn_cast(); if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() || (arg_type.getRank() != 0)) { return failure(); } // Create two loads from the input. auto lhs = rewriter.create(loc, lhlo_op.lhs()); auto rhs = rewriter.create(loc, lhlo_op.rhs()); // TODO(ravishankarm) : Move this method out of lmhlo namespace. Value op_result = lmhlo::HloOpToStdScalarOp::map( lhlo_op, arg_type.getElementType(), llvm::ArrayRef{lhs, rhs}, &rewriter); rewriter.create(loc, op_result, lhlo_op.out()); rewriter.eraseOp(lhlo_op); return success(); } }; //===----------------------------------------------------------------------===// // lmhlo.convolution conversion pattern. //===----------------------------------------------------------------------===// /// Converts lmhlo.convolution operation to a linalg.conv op. struct ConvToLinalgConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; // This code has been adapted from IREE's // (https://github.com/google/iree/) mhlo -> linalg conversion. LogicalResult matchAndRewrite( lmhlo::ConvOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { // Check validity of dimension information. 
if (const mhlo::ConvDimensionNumbers& dimension_numbers = op.dimension_numbers()) { const int input_spatial_rank = llvm::size(dimension_numbers.input_spatial_dimensions()); // The dimensions for input should follow the order of // batch_count, spatial_dims..., input_feature_count. if (dimension_numbers.input_batch_dimension().getInt() != 0 || dimension_numbers.input_feature_dimension().getInt() != (input_spatial_rank + 1)) return failure(); const int kernel_spatial_rank = llvm::size(dimension_numbers.kernel_spatial_dimensions()); // The dimensions for filter should follow the order of // spatial_dims..., input_feature_count, num_output_feature_count. if (dimension_numbers.kernel_input_feature_dimension().getInt() != kernel_spatial_rank || dimension_numbers.kernel_output_feature_dimension().getInt() != (kernel_spatial_rank + 1)) return failure(); const int output_spatial_rank = llvm::size(dimension_numbers.output_spatial_dimensions()); // The dimensions for output should follow the order of // batch_count, spatial_dims.., output_feature_count. if (dimension_numbers.output_batch_dimension().getInt() != 0 || dimension_numbers.output_feature_dimension().getInt() != (output_spatial_rank + 1)) return failure(); if (input_spatial_rank != output_spatial_rank || input_spatial_rank != kernel_spatial_rank) return failure(); auto input_spatial_dim = dimension_numbers.input_spatial_dimensions().begin(); auto kernel_spatial_dim = dimension_numbers.kernel_spatial_dimensions().begin(); auto output_spatial_dim = dimension_numbers.output_spatial_dimensions().begin(); // Check if spatial dims are ordered correctly. for (int i = 0; i < input_spatial_rank; ++i) { const int dim = i + 1; if ((*input_spatial_dim++).getZExtValue() != dim || (*output_spatial_dim++).getZExtValue() != dim || (*kernel_spatial_dim++).getZExtValue() != i) return failure(); } } // TODO: LHS dilation for deconvolution not supported yet. // TODO(jurahul): Window reversal is not supported yet. 
if (op.lhs_dilation() || op.hasWindowReversal()) { return failure(); } llvm::SmallVector strides; if (auto window_strides = op.window_strides()) { auto range = window_strides->getAttributeValues(); strides.assign(range.begin(), range.end()); } auto strides_arg = ArrayAttr::get(strides, op.getContext()); llvm::SmallVector dilation; if (auto rhs_dilation = op.rhs_dilation()) { auto range = rhs_dilation->getAttributeValues(); dilation.assign(range.begin(), range.end()); } else { // Default dilation of 1. dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1)); } auto dilation_arg = ArrayAttr::get(dilation, op.getContext()); // Set padding only if it is non-zero. DenseIntElementsAttr padding = op.paddingAttr(); if (!padding || !llvm::any_of(padding.getValues(), [](APInt int_val) { return !int_val.isNullValue(); })) { padding = nullptr; } // The order of input and filter are switched with linalg.conv. rewriter.replaceOpWithNewOp( op, args[1], args[0], args[2], strides_arg, dilation_arg, padding); return success(); } }; /// Base class for lowering HLO operations that have one operand and one result, /// and are semantically equivalent to a copy of the input to the output (like /// transpose, some reshape, etc.). The derived classes need to provide a method /// `getIndexingMaps` that returns AffineMaps for the index maps of the input /// and the output. template class DataMovementOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( OpTy op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { if (!VerifyHloOpBufferOrTensorSemantics(op)) return failure(); auto result_type = GetHloOpResultType(op); SmallVector indexing_maps = Derived::getIndexingMaps(op, &rewriter); if (indexing_maps.empty()) return failure(); auto nloops = result_type.getRank(); auto loc = op.getLoc(); auto linalg_op = rewriter.create( loc, /*resultTensorTypes=*/isLHLO ? 
ArrayRef{} : result_type, /*inputs=*/args.front(), /*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{}, /*initTensor=*/ValueRange{}, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) { nested_builder.create(loc, *args.begin()); }); rewriter.replaceOp(op, linalg_op.getOperation()->getResults()); return success(); } }; /// Pattern to convert BroadcastOp to Linalg ops. template class BroadcastConverter : public DataMovementOpConverter, OpTy, isLHLO> { public: using DataMovementOpConverter::DataMovementOpConverter; static SmallVector getIndexingMaps(OpTy broadcast_op, Builder* b) { ShapedType input_type = broadcast_op.operand().getType().template cast(); unsigned input_rank = input_type.getRank(); unsigned nloops = GetHloOpResultType(broadcast_op).getRank(); // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to // the input's dimensions. unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes()); SmallVector input_dim_exprs; input_dim_exprs.reserve(input_rank); for (int i = 0; i < input_rank; ++i) { input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i)); } AffineMap input_map; MLIRContext* context = b->getContext(); if (input_dim_exprs.empty()) { // The input is a scalar, i.e. this is a scalar broadcast op. input_map = AffineMap::get(nloops, /*symbolCount=*/0, context); } else { input_map = AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context); } return {input_map, b->getMultiDimIdentityMap(nloops)}; } }; class HloBroadcastInDimConverter : public DataMovementOpConverter { public: using DataMovementOpConverter::DataMovementOpConverter; static SmallVector getIndexingMaps( mhlo::BroadcastInDimOp broadcast_op, Builder* b) { auto result_type = GetHloOpResultType(broadcast_op); auto operand_type = broadcast_op.operand().getType().template cast(); unsigned nloops = result_type.getRank(); // The input is a scalar, i.e. 
this is a scalar broadcast op. if (operand_type.getRank() == 0) { return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), b->getMultiDimIdentityMap(nloops)}; } auto operand_shape = operand_type.getShape(); SmallVector dim_exprs; dim_exprs.reserve(nloops); if (broadcast_op.broadcast_dimensions()) { for (const auto& broadcastDim : enumerate(broadcast_op.broadcast_dimensions().getIntValues())) { int size = broadcastDim.value().getSExtValue(); bool expansion_needed = operand_shape[broadcastDim.index()] == 1 && result_type.getShape()[size] != 1; dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0) : b->getAffineDimExpr(size)); } } return { AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()), b->getMultiDimIdentityMap(nloops)}; } }; class LhloBroadcastInDimConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( lmhlo::BroadcastInDimOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); auto result_type = operand_adaptor.output().getType().cast(); auto result_shape = result_type.getShape(); auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter); Value operand = std::get<0>(operand_and_dims); auto broadcast_dims = std::get<1>(operand_and_dims); auto loc = op.getLoc(); auto nloops = result_type.getRank(); auto operand_type = operand.getType().cast(); // For a degenerate case, i.e. broadcasting with expansion of // memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`. // Instead the value is loaded and used directly in `linalg.yield`. 
if (operand_type.getRank() == 1 && operand_type.getDimSize(0) < result_type.getDimSize(broadcast_dims.front())) { Value zero = rewriter.create(loc, 0); Value val = rewriter.create(loc, operand, llvm::makeArrayRef({zero})); rewriter.create( loc, /*inputs=*/ValueRange{}, /*outputBuffers=*/ValueRange{operand_adaptor.output()}, llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) { nested_builder.create(loc, val); }); } else { auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape, operand_type, &rewriter); rewriter.create( loc, /*inputs=*/ValueRange{operand}, /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) { nested_builder.create(loc, *args.begin()); }); } rewriter.replaceOp(op, llvm::None); return success(); } // Inserts 'linalg.reshape' if there is a size-1 dim expansion. 
std::pair> InsertReshapeIfNecessary( lmhlo::BroadcastInDimOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const { lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); Value operand = operand_adaptor.operand(); auto operand_type = operand_adaptor.operand().getType().cast(); auto operand_shape = operand_type.getShape(); Value result = operand_adaptor.output(); auto result_type = result.getType().cast(); auto result_shape = result_type.getShape(); SmallVector operand_strides; int64_t operand_offset; if (failed(getStridesAndOffset(operand_type, operand_strides, operand_offset))) { op.emitOpError() << "Failed to get offset and strides."; } SmallVector new_shape, new_strides, broadcast_dims; SmallVector collapsed_dims_list; linalg::ReassociationIndices collapsed_dims; for (const auto& item : enumerate(op.broadcast_dimensions().getIntValues())) { size_t index = item.index(); int dim = item.value().getSExtValue(); collapsed_dims.push_back(index); bool expansion_needed = operand_shape[index] == 1 && result_shape[dim] != 1; if (expansion_needed) { continue; } new_shape.push_back(operand_shape[index]); new_strides.push_back(operand_strides[index]); broadcast_dims.push_back(dim); collapsed_dims_list.push_back(collapsed_dims); collapsed_dims.clear(); } // If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1] // and all dimensions need expansion. Such memref will be reshaped to a 1D // memref with a single element. New shape and strides needs to be updated // accordingly. if (collapsed_dims_list.empty()) { collapsed_dims_list.push_back({}); new_shape.push_back(1); new_strides.push_back(1); broadcast_dims.push_back(0); } for (const auto& dims : collapsed_dims) { collapsed_dims_list.back().push_back(dims); } // `linalg.reshape` is inserted only if necessary, i.e. when the rank can be // reduced. 
if (new_shape.size() < operand_shape.size()) { auto new_memref_type = MemRefType::get( new_shape, operand_type.getElementType(), makeStridedLinearLayoutMap(new_strides, operand_offset, rewriter.getContext())); operand = rewriter.create(op.getLoc(), new_memref_type, operand_adaptor.operand(), collapsed_dims_list); } return std::make_pair(operand, broadcast_dims); } SmallVector getIndexingMaps(lmhlo::BroadcastInDimOp op, ArrayRef broadcast_dims, ArrayRef result_shape, MemRefType operand_type, Builder* b) const { unsigned nloops = result_shape.size(); // The input is a scalar, i.e. this is a scalar broadcast op. if (operand_type.getRank() == 0) { return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), b->getMultiDimIdentityMap(nloops)}; } auto operand_shape = operand_type.getShape(); SmallVector dim_exprs; dim_exprs.reserve(nloops); for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) { int size = broadcast_dim.value(); bool expansion_needed = operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1; if (expansion_needed) { op.emitOpError( "BroadcastInDimOp lowering to Linalg does not support size-1 " "dimensions expansion."); } dim_exprs.push_back(b->getAffineDimExpr(size)); } return { AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()), b->getMultiDimIdentityMap(nloops)}; } }; template class TransposeConverter : public DataMovementOpConverter, OpTy, isLHLO> { public: using DataMovementOpConverter, OpTy, isLHLO>::DataMovementOpConverter; static SmallVector getIndexingMaps(OpTy op, Builder* b) { auto result_type = GetHloOpResultType(op).template cast(); auto nloops = result_type.getRank(); SmallVector input_exprs; input_exprs.resize(result_type.getRank()); for (auto permutation : llvm::enumerate(op.permutation())) { input_exprs[permutation.value().getZExtValue()] = b->getAffineDimExpr(permutation.index()); } return { AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()), 
b->getMultiDimIdentityMap(nloops)}; } }; // Converts reshape ops that can be proven to be either a collapse of dimensions // or expansion of dimensions of the operand. template class ReshapeOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( OpTy reshape_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { if (!VerifyHloOpBufferOrTensorSemantics(reshape_op)) return failure(); ShapedType operand_type = reshape_op.operand().getType().template cast(); ShapedType result_type = GetHloOpResultType(reshape_op); if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) return failure(); // Compute the reassociation maps for the linalg operation. ArrayRef src_shape = (operand_type.getRank() > result_type.getRank() ? operand_type.getShape() : result_type.getShape()); ArrayRef dst_shape = (operand_type.getRank() > result_type.getRank() ? result_type.getShape() : operand_type.getShape()); unsigned curr_src_dim = 0, curr_dst_dim = 0; SmallVector reassociation_map( dst_shape.size()); bool is_expanding_or_collapsing = true; while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) { int64_t dst_size = dst_shape[curr_dst_dim]; int64_t src_size = src_shape[curr_src_dim]; while (src_size < dst_size && curr_src_dim < src_shape.size()) { reassociation_map[curr_dst_dim].push_back( rewriter.getAffineDimExpr(curr_src_dim++)); src_size *= src_shape[curr_src_dim]; } if (src_size == dst_size) { reassociation_map[curr_dst_dim].push_back( rewriter.getAffineDimExpr(curr_src_dim++)); // If the next dim in dst_shape is not 1, treat subsequent dims in // src_shape which are 1 to be collapsed. 
if (curr_dst_dim == dst_shape.size() - 1 || dst_shape[curr_dst_dim + 1] != 1) { while (curr_src_dim < src_shape.size() && src_shape[curr_src_dim] == 1) { reassociation_map[curr_dst_dim].push_back( rewriter.getAffineDimExpr(curr_src_dim++)); } } } else { is_expanding_or_collapsing = false; break; } curr_dst_dim++; } if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size()) is_expanding_or_collapsing = false; if (!is_expanding_or_collapsing) { auto get_identity_exprs = [&rewriter](int n) { SmallVector exprs; for (int i = 0; i < n; ++i) exprs.push_back(rewriter.getAffineDimExpr(i)); return exprs; }; Location loc = reshape_op.getLoc(); int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(), 1, std::multiplies()); auto elem_type = operand_type.getElementType(); SmallVector collapsing_map = { get_identity_exprs(dst_shape.size())}; SmallVector expanding_map = { get_identity_exprs(src_shape.size())}; if (isLHLO) { auto collapsed_type = MemRefType::get({total_elems}, elem_type); Value collapsed_op = rewriter.create( loc, collapsed_type, args[0], collapsing_map); Value reshape_buffer = rewriter.create( loc, result_type, collapsed_op, expanding_map); rewriter.replaceOpWithNewOp( reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr, /*outputPermutation =*/nullptr); } else { auto collapsed_type = RankedTensorType::get({total_elems}, elem_type); Value collapsed_op = rewriter.create( loc, collapsed_type, args[0], collapsing_map); rewriter.replaceOpWithNewOp( reshape_op, result_type, collapsed_op, expanding_map); } return success(); } if (isLHLO) { Value reshape_buffer = rewriter.create( reshape_op.getLoc(), result_type, args[0], reassociation_map); rewriter.replaceOpWithNewOp( reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr, /*outputPermutation =*/nullptr); } else { rewriter.replaceOpWithNewOp( reshape_op, result_type, args[0], reassociation_map); } return success(); } }; template class IotaConverter : public 
OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( OpTy iota_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { ShapedType result_shaped_type = GetHloOpResultType(iota_op); if (!result_shaped_type) return failure(); auto result_element_type = result_shaped_type.getElementType(); if (!result_element_type.isSignlessIntOrFloat()) return failure(); // Construct the indexing maps needed for linalg.generic ops. unsigned nloops = result_shaped_type.getRank(); auto linalg_op = rewriter.create( iota_op.getLoc(), /*resultTensorTypes=*/ isLHLO ? ArrayRef{} : ArrayRef{result_shaped_type}, /*inputs=*/ValueRange{}, /*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{}, /*initTensors=*/ValueRange{}, llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs, ValueRange args) { Value cast_op = nested_builder.create( nested_loc, ivs[iota_op.iota_dimension()], nested_builder.getIntegerType( result_element_type.getIntOrFloatBitWidth())); if (result_element_type.template isa()) { cast_op = nested_builder.create(nested_loc, cast_op, result_element_type); } nested_builder.create(nested_loc, cast_op); }); if (isLHLO) rewriter.replaceOp(iota_op, llvm::None); else rewriter.replaceOp(iota_op, linalg_op.result_tensors()); return success(); } }; template class ConstConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( OpTy const_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { Location loc = const_op.getLoc(); auto value_attr = const_op.value().template cast(); if (value_attr.getType().getRank() != 0) return failure(); ReplaceConstOp(loc, const_op, value_attr, rewriter); return success(); } private: void ReplaceConstOp(Location loc, mhlo::ConstOp op, DenseElementsAttr value_attr, ConversionPatternRewriter& rewriter) 
const { Value std_tensor_const = rewriter.create(loc, value_attr); rewriter.replaceOp(op, {std_tensor_const}); } void ReplaceConstOp(Location loc, lmhlo::ConstOp op, DenseElementsAttr value_attr, ConversionPatternRewriter& rewriter) const { Value std_scalar_const = rewriter.create(loc, value_attr.getValue({})); rewriter.create(loc, std_scalar_const, op.getOperand(), llvm::None); rewriter.eraseOp(op); } }; class ReduceConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( lmhlo::ReduceOp reduce_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = reduce_op.getLoc(); lmhlo::ReduceOp::Adaptor adaptor(args); auto operand_shape = adaptor.operands()[0].getType().template dyn_cast(); if (!operand_shape || !operand_shape.hasRank()) { emitError(loc, "lhlo to linalg conversion expects known-rank args"); return failure(); } // First fill the output buffer with the init value. Value init_value = rewriter.create(loc, adaptor.init_values()[0]); rewriter.create(loc, adaptor.out()[0], init_value); DenseIntElementsAttr dimensions_attr = reduce_op.dimensions(); SmallVector reduction_dims; for (const auto& dim : dimensions_attr.getIntValues()) { reduction_dims.push_back(dim.getSExtValue()); } SmallVector src_exprs; SmallVector dst_exprs; SmallVector types; for (int i = 0, rank = operand_shape.getRank(); i != rank; ++i) { bool is_reduced = llvm::is_contained(reduction_dims, i); types.push_back(is_reduced ? 
getReductionIteratorTypeName() : getParallelIteratorTypeName()); src_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); if (!is_reduced) { dst_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); } } auto maps = AffineMap::inferFromExprList({src_exprs, dst_exprs}); auto linalg_op = rewriter.create( loc, /*resultTensorTypes=*/ArrayRef{}, /*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(), /*initTensors=*/ValueRange{}, maps, types); rewriter.inlineRegionBefore(reduce_op.body(), linalg_op.region(), linalg_op.region().end()); { OpBuilder::InsertionGuard region_guard(rewriter); Block* block = linalg_op.getBody(); rewriter.setInsertionPoint(&block->front()); // The incoming region is operating on buffers, while linalg.generic // expects scalar SSA values. Add some allocs around the original op to // make it compatible. auto arg_type = block->getArgument(0).getType().cast(); Value alloc_a = rewriter.create(loc, arg_type); Value alloc_b = rewriter.create(loc, arg_type); Value alloc_res = rewriter.create(loc, arg_type); // Now turn the existing signature // (memref, memref, memref) -> () // into // (X, X) -> X TypeConverter::SignatureConversion signature_converter(3); signature_converter.remapInput(0, alloc_a); signature_converter.remapInput(1, alloc_b); signature_converter.remapInput(2, alloc_res); signature_converter.addInputs( {arg_type.getElementType(), arg_type.getElementType()}); Block* entry_block = rewriter.applySignatureConversion( &linalg_op.region(), signature_converter); // Store the arguments into the newly allocated buffers. rewriter.setInsertionPointAfter(alloc_res.getDefiningOp()); rewriter.create(loc, entry_block->getArgument(0), alloc_a); rewriter.create(loc, entry_block->getArgument(1), alloc_b); rewriter.replaceOp(entry_block->getTerminator(), {}); // Load & yield the result. 
rewriter.setInsertionPointToEnd(entry_block); auto load_res = rewriter.create(loc, alloc_res); rewriter.create(loc, ValueRange{load_res}); } rewriter.replaceOp(reduce_op, linalg_op.getOperation()->getResults()); return success(); } }; // TODO(b/156787842): Support the lowering for dynamic shapes. template class ReverseConverter : public DataMovementOpConverter, OpTy, isLHLO> { public: using DataMovementOpConverter, OpTy, isLHLO>::DataMovementOpConverter; static SmallVector getIndexingMaps(OpTy op, Builder* b) { auto result_type = GetHloOpResultType(op).template cast(); auto nloops = result_type.getRank(); SmallVector input_exprs; input_exprs.reserve(nloops); for (int i = 0; i < nloops; ++i) input_exprs.push_back(b->getAffineDimExpr(i)); for (auto dim : op.dimensions()) { int i = dim.getZExtValue(); if (result_type.isDynamicDim(i)) return {}; int n = result_type.getShape()[i]; input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i]; } return { AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()), b->getMultiDimIdentityMap(nloops)}; } }; class SliceConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( lmhlo::SliceOp slice_op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = slice_op.getLoc(); auto arg_type = slice_op.getOperand(0).getType().template dyn_cast(); if (!arg_type || !arg_type.hasRank()) { emitError(loc, "lhlo to linalg conversion expects known-rank args"); return failure(); } SmallVector ranges; for (int i = 0, e = arg_type.getRank(); i < e; ++i) { Value start_index = rewriter.create( loc, slice_op.start_indices().getValue(i)); Value limit_index = rewriter.create( loc, slice_op.limit_indices().getValue(i)); Value stride = rewriter.create( loc, slice_op.strides().getValue(i)); ranges.push_back(rewriter.create(loc, start_index, limit_index, stride)); } auto linalg_slice = rewriter.create(loc, slice_op.getOperand(0), ranges); 
rewriter.create(loc, linalg_slice, slice_op.getOperand(1)); rewriter.eraseOp(slice_op); return success(); } }; void populateLHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off patterns->insert, ConstConverter, ConvToLinalgConverter, IotaConverter, LhloBroadcastInDimConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, // TODO(ataei): Remove this pattern, CopyOp is folded away. PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, ReduceConverter, ReshapeOpConverter, ReverseConverter, ScalarPointwiseToStandardConverter, ScalarPointwiseToStandardConverter, SliceConverter, TransposeConverter >(context); // clang-format on } // Converts LHLO ops to Linalg generic. // Sample result for lmhlo::AddOp. 
//
// "lmhlo.add"(%arg1, %arg2, %out) :
//      (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//
// will be converted to
//
// #map0 = (d0, d1) -> (d0, d1)
// "linalg.generic"(%arg1, %arg2, %out) ( {
//   ^bb0(%arg4: f32, %arg5: f32):
//     %0 = addf %arg4, %arg5 : f32
//     "linalg.yield"(%0) : (f32) -> ()
// }) {
//     indexing_maps = [#map0, #map0, #map0],
//     iterator_types = ["parallel", "parallel"],
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//
// The pass collects the patterns from populateLHLOToLinalgConversionPattern,
// runs them on the current function with a partial conversion, and signals
// pass failure if that conversion fails.
struct LhloLegalizeToLinalgPass
    : public PassWrapper {
  // NOTE(review): the dialects registered here and marked legal below are
  // given as template arguments, which are not visible in this copy.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert();
  }

  void runOnFunction() override {
    OwningRewritePatternList patterns;
    ConversionTarget target(getContext());
    target.addLegalDialect();

    auto func = getFunction();
    populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

// Tensor-level counterpart of LhloLegalizeToLinalgPass: applies the patterns
// from mhlo::populateHLOToLinalgConversionPattern with a partial conversion
// and signals pass failure if that conversion fails.
struct HloLegalizeToLinalgPass
    : public PassWrapper {
  // NOTE(review): registered/legal dialects are template arguments that are
  // not visible in this copy.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert();
  }

  void runOnFunction() override {
    OwningRewritePatternList patterns;
    ConversionTarget target(getContext());
    target.addLegalDialect();

    auto func = getFunction();
    mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

namespace lmhlo {
// Creates a new instance of the LHLO -> Linalg legalization pass.
std::unique_ptr> createLegalizeLhloToLinalgPass() {
  return std::make_unique();
}
}  // namespace lmhlo

namespace mhlo {

// Registers the HLO (tensor) lowering patterns into `patterns`.
// NOTE(review): the template arguments selecting the concrete op for each
// pattern are not visible in this copy of the list.
void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                          OwningRewritePatternList* patterns) {
  patterns
      ->insert, ConstConverter,
               HloBroadcastInDimConverter,
               IotaConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               PointwiseToLinalgConverter,
               ReshapeOpConverter,
               ReverseConverter,
               TransposeConverter>(context);
}

// Creates a new instance of the HLO -> Linalg legalization pass.
std::unique_ptr> createLegalizeHloToLinalgPass() {
  return std::make_unique();
}
}  // namespace mhlo
}  // namespace mlir