From 444fae9bac6db42a9ade58bb75671f72a3080337 Mon Sep 17 00:00:00 2001
From: Hanhan Wang
Date: Fri, 23 Oct 2020 12:22:21 -0700
Subject: [PATCH] [NFC] Make naming style consistent.

Use lowercase with underscores between words instead of camelStyle.

PiperOrigin-RevId: 338722328
---
 .../mhlo/transforms/legalize_to_linalg.cc     | 483 +++++++++---------
 1 file changed, 243 insertions(+), 240 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index b64d662..834384e 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -60,13 +60,13 @@ ShapedType getHloOpResultType(Operation* op) {
 
 template 
 bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
-  auto verifyType = [&](Value val) -> bool {
+  auto verify_type = [&](Value val) -> bool {
     return (isLHLO && val.getType().isa()) ||
            (!isLHLO && val.getType().isa());
   };
-  if (!llvm::all_of(op->getOperands(), verifyType)) return false;
+  if (!llvm::all_of(op->getOperands(), verify_type)) return false;
   return isLHLO ? op->getResults().empty()
-                : llvm::all_of(op->getResults(), verifyType);
+                : llvm::all_of(op->getResults(), verify_type);
 }
 
 template 
@@ -99,51 +99,51 @@ class PointwiseToLinalgConverter : public OpConversionPattern {
              << nloops << " parallel iterators: " << *(op.getOperation());
 
     // Construct the indexing maps needed for linalg.generic ops.
-    SmallVector bodyArgTypes, bodyResultTypes, opResultTypes;
+    SmallVector body_arg_types, body_result_types, op_result_types;
 
     // This doesnt account for implicit broadcast, but the working assumption
    // in HLO/LHLO is that are broadcasts are made explicit.
     if (isLHLO && !nloops) return failure();
 
-    int numInputs = (isLHLO ? args.size() - 1 : args.size());
+    int num_inputs = (isLHLO ? args.size() - 1 : args.size());
 
-    ValueRange inputs(args.take_front(numInputs));
+    ValueRange inputs(args.take_front(num_inputs));
     for (Value in : inputs)
-      bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
+      body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
 
-    ValueRange outputBuffers(args.take_back(args.size() - numInputs));
-    for (Value out : outputBuffers)
-      bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType()));
+    ValueRange output_buffers(args.take_back(args.size() - num_inputs));
+    for (Value out : output_buffers)
+      body_result_types.emplace_back(getElementTypeOrSelf(out.getType()));
 
     if (!isLHLO) {
       // HLO operations have return as tensor types.
-      assert(bodyResultTypes.empty() &&
+      assert(body_result_types.empty() &&
             "When lowering HLO ops result can't be part of arguments");
       Value result = op.getOperation()->getResult(0);
-      bodyResultTypes.push_back(getElementTypeOrSelf(result));
-      opResultTypes.push_back(result.getType());
+      body_result_types.push_back(getElementTypeOrSelf(result));
+      op_result_types.push_back(result.getType());
     }
 
-    AffineMap commonIndexingMap =
+    AffineMap common_indexing_map =
         nloops ? rewriter.getMultiDimIdentityMap(nloops)
                : AffineMap::get(nloops, 0, rewriter.getContext());
     SmallVector indexing_maps(args.size() + (isLHLO ? 0 : 1),
-                              commonIndexingMap);
+                              common_indexing_map);
 
-    auto linalgOp = rewriter.create(
-        loc, opResultTypes, inputs, outputBuffers,
+    auto linalg_op = rewriter.create(
+        loc, op_result_types, inputs, output_buffers,
         /*initTensors=*/ValueRange{}, indexing_maps,
         GetNParallelLoopsAttrs(nloops),
-        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
+        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
           // TODO(ravishankarm) : For now use the method in lmhlo namespace.
          // That method needs to be moved out of there.
-          Value opResult = lmhlo::HloOpToStdScalarOp::map(
-              op, bodyResultTypes,
+          Value op_result = lmhlo::HloOpToStdScalarOp::map(
+              op, body_result_types,
              llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
-          nestedBuilder.create(loc, opResult);
+          nested_builder.create(loc, op_result);
         });
-    rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
+    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
     return success();
   }
 };
@@ -157,10 +157,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern {
       LhloOp lhlo_op, ArrayRef args,
       ConversionPatternRewriter& rewriter) const final {
     auto loc = lhlo_op.getLoc();
-    auto argType =
+    auto arg_type =
         lhlo_op.getOperand(0).getType().template dyn_cast();
-    if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
-        (argType.getRank() != 0)) {
+    if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
+        (arg_type.getRank() != 0)) {
       return failure();
     }
 
@@ -168,10 +168,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern {
     auto lhs = rewriter.create(loc, lhlo_op.lhs());
     auto rhs = rewriter.create(loc, lhlo_op.rhs());
     // TODO(ravishankarm) : Move this method out of lmhlo namespace.
-    Value opResult = lmhlo::HloOpToStdScalarOp::map(
-        lhlo_op, argType.getElementType(), llvm::ArrayRef{lhs, rhs},
+    Value op_result = lmhlo::HloOpToStdScalarOp::map(
+        lhlo_op, arg_type.getElementType(), llvm::ArrayRef{lhs, rhs},
         &rewriter);
-    rewriter.create(loc, opResult, lhlo_op.out());
+    rewriter.create(loc, op_result, lhlo_op.out());
     rewriter.eraseOp(lhlo_op);
     return success();
   }
@@ -192,52 +192,52 @@ struct ConvToLinalgConverter : public OpConversionPattern {
       lmhlo::ConvOp op, ArrayRef args,
       ConversionPatternRewriter& rewriter) const final {
     // Check validity of dimension information.
-    if (const mhlo::ConvDimensionNumbers& dimensionNumbers =
+    if (const mhlo::ConvDimensionNumbers& dimension_numbers =
             op.dimension_numbers()) {
-      const int inputSpatialRank =
-          llvm::size(dimensionNumbers.input_spatial_dimensions());
+      const int input_spatial_rank =
+          llvm::size(dimension_numbers.input_spatial_dimensions());
       // The dimensions for input should follow the order of
       // batch_count, spatial_dims..., input_feature_count.
-      if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
-          dimensionNumbers.input_feature_dimension().getInt() !=
-              (inputSpatialRank + 1))
+      if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
+          dimension_numbers.input_feature_dimension().getInt() !=
+              (input_spatial_rank + 1))
         return failure();
 
-      const int kernelSpatialRank =
-          llvm::size(dimensionNumbers.kernel_spatial_dimensions());
+      const int kernel_spatial_rank =
+          llvm::size(dimension_numbers.kernel_spatial_dimensions());
       // The dimensions for filter should follow the order of
       // spatial_dims..., input_feature_count, num_output_feature_count.
-      if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
-              kernelSpatialRank ||
-          dimensionNumbers.kernel_output_feature_dimension().getInt() !=
-              (kernelSpatialRank + 1))
+      if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
+              kernel_spatial_rank ||
+          dimension_numbers.kernel_output_feature_dimension().getInt() !=
+              (kernel_spatial_rank + 1))
         return failure();
 
-      const int outputSpatialRank =
-          llvm::size(dimensionNumbers.output_spatial_dimensions());
+      const int output_spatial_rank =
+          llvm::size(dimension_numbers.output_spatial_dimensions());
       // The dimensions for output should follow the order of
       // batch_count, spatial_dims.., output_feature_count.
-      if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
-          dimensionNumbers.output_feature_dimension().getInt() !=
-              (outputSpatialRank + 1))
+      if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
+          dimension_numbers.output_feature_dimension().getInt() !=
+              (output_spatial_rank + 1))
         return failure();
 
-      if (inputSpatialRank != outputSpatialRank ||
-          inputSpatialRank != kernelSpatialRank)
+      if (input_spatial_rank != output_spatial_rank ||
+          input_spatial_rank != kernel_spatial_rank)
         return failure();
 
-      auto inputSpatialDim =
-          dimensionNumbers.input_spatial_dimensions().begin();
-      auto kernelSpatialDim =
-          dimensionNumbers.kernel_spatial_dimensions().begin();
-      auto outputSpatialDim =
-          dimensionNumbers.output_spatial_dimensions().begin();
+      auto input_spatial_dim =
+          dimension_numbers.input_spatial_dimensions().begin();
+      auto kernel_spatial_dim =
+          dimension_numbers.kernel_spatial_dimensions().begin();
+      auto output_spatial_dim =
+          dimension_numbers.output_spatial_dimensions().begin();
       // Check if spatial dims are ordered correctly.
-      for (int i = 0; i < inputSpatialRank; ++i) {
+      for (int i = 0; i < input_spatial_rank; ++i) {
         const int dim = i + 1;
-        if ((*inputSpatialDim++).getZExtValue() != dim ||
-            (*outputSpatialDim++).getZExtValue() != dim ||
-            (*kernelSpatialDim++).getZExtValue() != i)
+        if ((*input_spatial_dim++).getZExtValue() != dim ||
+            (*output_spatial_dim++).getZExtValue() != dim ||
+            (*kernel_spatial_dim++).getZExtValue() != i)
           return failure();
       }
     }
@@ -248,33 +248,33 @@ struct ConvToLinalgConverter : public OpConversionPattern {
     }
 
     llvm::SmallVector strides;
-    if (auto windowStrides = op.window_strides()) {
-      auto range = windowStrides->getAttributeValues();
+    if (auto window_strides = op.window_strides()) {
+      auto range = window_strides->getAttributeValues();
       strides.assign(range.begin(), range.end());
     }
-    auto stridesArg = ArrayAttr::get(strides, op.getContext());
+    auto strides_arg = ArrayAttr::get(strides, op.getContext());
 
     llvm::SmallVector dilation;
-    if (auto rhsDilation = op.rhs_dilation()) {
-      auto range = rhsDilation->getAttributeValues();
+    if (auto rhs_dilation = op.rhs_dilation()) {
+      auto range = rhs_dilation->getAttributeValues();
       dilation.assign(range.begin(), range.end());
     } else {
       // Default dilation of 1.
       dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
     }
-    auto dilationArg = ArrayAttr::get(dilation, op.getContext());
+    auto dilation_arg = ArrayAttr::get(dilation, op.getContext());
 
     // Set padding only if it is non-zero.
     DenseIntElementsAttr padding = op.paddingAttr();
-    if (!padding || !llvm::any_of(padding.getValues(), [](APInt intVal) {
-          return !intVal.isNullValue();
-        })) {
+    if (!padding ||
+        !llvm::any_of(padding.getValues(),
+                      [](APInt int_val) { return !int_val.isNullValue(); })) {
       padding = nullptr;
     }
 
     // The order of input and filter are switched with linalg.conv.
     rewriter.replaceOpWithNewOp(
-        op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
+        op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
     return success();
   }
 };
@@ -293,25 +293,25 @@ class DataMovementOpConverter : public OpConversionPattern {
       OpTy op, ArrayRef args,
       ConversionPatternRewriter& rewriter) const final {
     if (!verifyHloOpBufferOrTensorSemantics(op)) return failure();
-    auto resultType = getHloOpResultType(op);
+    auto result_type = getHloOpResultType(op);
 
     SmallVector indexing_maps =
         Derived::getIndexingMaps(op, &rewriter);
     if (indexing_maps.empty()) return failure();
 
-    auto nloops = resultType.getRank();
+    auto nloops = result_type.getRank();
     auto loc = op.getLoc();
-    auto linalgOp = rewriter.create(
+    auto linalg_op = rewriter.create(
         loc,
-        /*resultTensorTypes=*/isLHLO ? ArrayRef{} : resultType,
+        /*resultTensorTypes=*/isLHLO ? ArrayRef{} : result_type,
         /*inputs=*/args.front(),
         /*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
        /*initTensor=*/ValueRange{}, indexing_maps,
        GetNParallelLoopsAttrs(nloops),
-        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
-          nestedBuilder.create(loc, *args.begin());
+        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
+          nested_builder.create(loc, *args.begin());
         });
-    rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
+    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
     return success();
   }
 };
@@ -325,32 +325,32 @@ class BroadcastConverter
   using DataMovementOpConverter::DataMovementOpConverter;
 
-  static SmallVector getIndexingMaps(OpTy broadcastOp,
+  static SmallVector getIndexingMaps(OpTy broadcast_op,
                                      Builder* b) {
-    ShapedType inputType =
-        broadcastOp.operand().getType().template cast();
-    unsigned inputRank = inputType.getRank();
-    unsigned nloops = getHloOpResultType(broadcastOp).getRank();
+    ShapedType input_type =
+        broadcast_op.operand().getType().template cast();
+    unsigned input_rank = input_type.getRank();
+    unsigned nloops = getHloOpResultType(broadcast_op).getRank();
 
     // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
     // the input's dimensions.
-    unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes());
-    SmallVector inputDimExprs;
-    inputDimExprs.reserve(inputRank);
-    for (int i = 0; i < inputRank; ++i) {
-      inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i));
+    unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
+    SmallVector input_dim_exprs;
+    input_dim_exprs.reserve(input_rank);
+    for (int i = 0; i < input_rank; ++i) {
+      input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
     }
 
-    AffineMap inputMap;
+    AffineMap input_map;
     MLIRContext* context = b->getContext();
-    if (inputDimExprs.empty()) {
+    if (input_dim_exprs.empty()) {
       // The input is a scalar, i.e. this is a scalar broadcast op.
-      inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context);
+      input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
     } else {
-      inputMap =
-          AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
+      input_map =
+          AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
     }
-    return {inputMap, b->getMultiDimIdentityMap(nloops)};
+    return {input_map, b->getMultiDimIdentityMap(nloops)};
   }
 };
 
@@ -363,34 +363,34 @@ class HloBroadcastInDimConverter
                                   false>::DataMovementOpConverter;
 
   static SmallVector getIndexingMaps(
-      mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
-    auto resultType = getHloOpResultType(broadcastOp);
-    auto operandType =
-        broadcastOp.operand().getType().template cast();
-    unsigned nloops = resultType.getRank();
+      mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
+    auto result_type = getHloOpResultType(broadcast_op);
+    auto operand_type =
+        broadcast_op.operand().getType().template cast();
+    unsigned nloops = result_type.getRank();
 
     // The input is a scalar, i.e. this is a scalar broadcast op.
-    if (operandType.getRank() == 0) {
+    if (operand_type.getRank() == 0) {
       return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
               b->getMultiDimIdentityMap(nloops)};
     }
 
-    auto operandShape = operandType.getShape();
-    SmallVector dimExprs;
-    dimExprs.reserve(nloops);
+    auto operand_shape = operand_type.getShape();
+    SmallVector dim_exprs;
+    dim_exprs.reserve(nloops);
 
-    if (broadcastOp.broadcast_dimensions()) {
+    if (broadcast_op.broadcast_dimensions()) {
       for (const auto& broadcastDim :
-           enumerate(broadcastOp.broadcast_dimensions().getIntValues())) {
+           enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
         int size = broadcastDim.value().getSExtValue();
-        bool expansion_needed = operandShape[broadcastDim.index()] == 1 &&
-                                resultType.getShape()[size] != 1;
-        dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
-                                            : b->getAffineDimExpr(size));
+        bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
+                                result_type.getShape()[size] != 1;
+        dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
+                                             : b->getAffineDimExpr(size));
       }
     }
     return {
-        AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
+        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)};
   }
 };
@@ -430,8 +430,8 @@ class LhloBroadcastInDimConverter
           /*outputBuffers=*/ValueRange{operand_adaptor.output()},
           llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
           GetNParallelLoopsAttrs(nloops),
-          [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
-            nestedBuilder.create(loc, val);
+          [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
+            nested_builder.create(loc, val);
           });
 
     } else {
@@ -441,8 +441,8 @@ class LhloBroadcastInDimConverter
           loc, /*inputs=*/ValueRange{operand},
           /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
          GetNParallelLoopsAttrs(nloops),
-          [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
-            nestedBuilder.create(loc, *args.begin());
+          [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
+            nested_builder.create(loc, *args.begin());
           });
     }
     rewriter.replaceOp(op, llvm::None);
@@ -520,35 +520,35 @@ class LhloBroadcastInDimConverter
   }
 
   SmallVector getIndexingMaps(lmhlo::BroadcastInDimOp op,
-                              ArrayRef broadcastDims,
-                              ArrayRef resultShape,
-                              MemRefType operandType,
+                              ArrayRef broadcast_dims,
+                              ArrayRef result_shape,
+                              MemRefType operand_type,
                               Builder* b) const {
-    unsigned nloops = resultShape.size();
+    unsigned nloops = result_shape.size();
 
     // The input is a scalar, i.e. this is a scalar broadcast op.
-    if (operandType.getRank() == 0) {
+    if (operand_type.getRank() == 0) {
       return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
               b->getMultiDimIdentityMap(nloops)};
     }
 
-    auto operandShape = operandType.getShape();
-    SmallVector dimExprs;
-    dimExprs.reserve(nloops);
+    auto operand_shape = operand_type.getShape();
+    SmallVector dim_exprs;
+    dim_exprs.reserve(nloops);
 
-    for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) {
-      int size = broadcastDim.value();
+    for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) {
+      int size = broadcast_dim.value();
       bool expansion_needed =
-          operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1;
+          operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1;
       if (expansion_needed) {
         op.emitOpError(
             "BroadcastInDimOp lowering to Linalg does not support size-1 "
            "dimensions expansion.");
       }
-      dimExprs.push_back(b->getAffineDimExpr(size));
+      dim_exprs.push_back(b->getAffineDimExpr(size));
     }
     return {
-        AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
+        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)};
   }
 };
@@ -561,17 +561,17 @@ class TransposeConverter
   using DataMovementOpConverter, OpTy, isLHLO>::DataMovementOpConverter;
   static SmallVector getIndexingMaps(OpTy op, Builder* b) {
-    auto resultType =
+    auto result_type =
         getHloOpResultType(op).template cast();
-    auto nloops = resultType.getRank();
-    SmallVector inputExprs;
-    inputExprs.resize(resultType.getRank());
+    auto nloops = result_type.getRank();
+    SmallVector input_exprs;
+    input_exprs.resize(result_type.getRank());
     for (auto permutation : llvm::enumerate(op.permutation())) {
-      inputExprs[permutation.value().getZExtValue()] =
+      input_exprs[permutation.value().getZExtValue()] =
           b->getAffineDimExpr(permutation.index());
     }
     return {
-        AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
+        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)};
   }
 };
@@ -584,101 +584,104 @@ class ReshapeOpConverter : public OpConversionPattern {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      OpTy reshapeOp, ArrayRef args,
+      OpTy reshape_op, ArrayRef args,
       ConversionPatternRewriter& rewriter) const final {
-    if (!verifyHloOpBufferOrTensorSemantics(reshapeOp))
+    if (!verifyHloOpBufferOrTensorSemantics(reshape_op))
       return failure();
-    ShapedType operandType =
-        reshapeOp.operand().getType().template cast();
-    ShapedType resultType = getHloOpResultType(reshapeOp);
+    ShapedType operand_type =
+        reshape_op.operand().getType().template cast();
+    ShapedType result_type = getHloOpResultType(reshape_op);
 
-    if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
+    if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
       return failure();
 
     // Compute the reassociation maps for the linalg operation.
-    ArrayRef srcShape =
-        (operandType.getRank() > resultType.getRank() ? operandType.getShape()
-                                                       : resultType.getShape());
-    ArrayRef dstShape =
-        (operandType.getRank() > resultType.getRank() ? resultType.getShape()
-                                                       : operandType.getShape());
-    unsigned currSrcDim = 0, currDstDim = 0;
-    SmallVector reassociationMap(
-        dstShape.size());
-    bool isExpandingOrCollapsing = true;
-    while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
-      int64_t dstSize = dstShape[currDstDim];
-      int64_t srcSize = srcShape[currSrcDim];
-      while (srcSize < dstSize && currSrcDim < srcShape.size()) {
-        reassociationMap[currDstDim].push_back(
-            rewriter.getAffineDimExpr(currSrcDim++));
-        srcSize *= srcShape[currSrcDim];
+    ArrayRef src_shape =
+        (operand_type.getRank() > result_type.getRank()
+             ? operand_type.getShape()
+             : result_type.getShape());
+    ArrayRef dst_shape =
+        (operand_type.getRank() > result_type.getRank()
+             ? result_type.getShape()
+             : operand_type.getShape());
+    unsigned curr_src_dim = 0, curr_dst_dim = 0;
+    SmallVector reassociation_map(
+        dst_shape.size());
+    bool is_expanding_or_collapsing = true;
+    while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
+      int64_t dst_size = dst_shape[curr_dst_dim];
+      int64_t src_size = src_shape[curr_src_dim];
+      while (src_size < dst_size && curr_src_dim < src_shape.size()) {
+        reassociation_map[curr_dst_dim].push_back(
+            rewriter.getAffineDimExpr(curr_src_dim++));
+        src_size *= src_shape[curr_src_dim];
       }
-      if (srcSize == dstSize) {
-        reassociationMap[currDstDim].push_back(
-            rewriter.getAffineDimExpr(currSrcDim++));
-        // If the next dim in dstShape is not 1, treat subsequent dims in
-        // srcShape which are 1 to be collapsed.
+      if (src_size == dst_size) {
+        reassociation_map[curr_dst_dim].push_back(
+            rewriter.getAffineDimExpr(curr_src_dim++));
+        // If the next dim in dst_shape is not 1, treat subsequent dims in
+        // src_shape which are 1 to be collapsed.
+        if (curr_dst_dim == dst_shape.size() - 1 ||
+            dst_shape[curr_dst_dim + 1] != 1) {
+          while (curr_src_dim < src_shape.size() &&
+                 src_shape[curr_src_dim] == 1) {
+            reassociation_map[curr_dst_dim].push_back(
+                rewriter.getAffineDimExpr(curr_src_dim++));
           }
         }
       } else {
-        isExpandingOrCollapsing = false;
+        is_expanding_or_collapsing = false;
         break;
       }
-      currDstDim++;
+      curr_dst_dim++;
     }
-    if (currSrcDim != srcShape.size() || currDstDim != dstShape.size())
-      isExpandingOrCollapsing = false;
+    if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
+      is_expanding_or_collapsing = false;
 
-    if (!isExpandingOrCollapsing) {
-      auto getIdentityExprs = [&rewriter](int n) {
+    if (!is_expanding_or_collapsing) {
+      auto get_identity_exprs = [&rewriter](int n) {
        SmallVector exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
-      Location loc = reshapeOp.getLoc();
-      int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
-                                           std::multiplies());
-      auto elemType = operandType.getElementType();
-      SmallVector collapsingMap = {
-          getIdentityExprs(dstShape.size())};
-      SmallVector expandingMap = {
-          getIdentityExprs(srcShape.size())};
+      Location loc = reshape_op.getLoc();
+      int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
+                                            1, std::multiplies());
+      auto elem_type = operand_type.getElementType();
+      SmallVector collapsing_map = {
+          get_identity_exprs(dst_shape.size())};
+      SmallVector expanding_map = {
+          get_identity_exprs(src_shape.size())};
 
       if (isLHLO) {
-        auto collapsedType = MemRefType::get({totalElems}, elemType);
-        Value collapsedOp = rewriter.create(
-            loc, collapsedType, args[0], collapsingMap);
-        Value reshapeBuffer = rewriter.create(
-            loc, resultType, collapsedOp, expandingMap);
+        auto collapsed_type = MemRefType::get({total_elems}, elem_type);
+        Value collapsed_op = rewriter.create(
+            loc, collapsed_type, args[0], collapsing_map);
+        Value reshape_buffer = rewriter.create(
+            loc, result_type, collapsed_op, expanding_map);
         rewriter.replaceOpWithNewOp(
-            reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
+            reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
             /*outputPermutation =*/nullptr);
       } else {
-        auto collapsedType = RankedTensorType::get({totalElems}, elemType);
-        Value collapsedOp = rewriter.create(
-            loc, collapsedType, args[0], collapsingMap);
+        auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
+        Value collapsed_op = rewriter.create(
+            loc, collapsed_type, args[0], collapsing_map);
         rewriter.replaceOpWithNewOp(
-            reshapeOp, resultType, collapsedOp, expandingMap);
+            reshape_op, result_type, collapsed_op, expanding_map);
       }
       return success();
     }
 
     if (isLHLO) {
-      Value reshapeBuffer = rewriter.create(
-          reshapeOp.getLoc(), resultType, args[0], reassociationMap);
+      Value reshape_buffer = rewriter.create(
+          reshape_op.getLoc(), result_type, args[0], reassociation_map);
       rewriter.replaceOpWithNewOp(
-          reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
+          reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
          /*outputPermutation =*/nullptr);
     } else {
       rewriter.replaceOpWithNewOp(
-          reshapeOp, resultType, args[0], reassociationMap);
+          reshape_op, result_type, args[0], reassociation_map);
     }
     return success();
   }
@@ -690,42 +693,42 @@ class IotaConverter : public OpConversionPattern {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      OpTy iotaOp, ArrayRef args,
+      OpTy iota_op, ArrayRef args,
       ConversionPatternRewriter& rewriter) const final {
-    ShapedType resultShapedType = getHloOpResultType(iotaOp);
-    if (!resultShapedType) return failure();
+    ShapedType result_shaped_type = getHloOpResultType(iota_op);
+    if (!result_shaped_type) return failure();
 
-    auto resultElementType = resultShapedType.getElementType();
-    if (!resultElementType.isSignlessIntOrFloat()) return failure();
+    auto result_element_type = result_shaped_type.getElementType();
+    if (!result_element_type.isSignlessIntOrFloat()) return failure();
 
     // Construct the indexing maps needed for linalg.generic ops.
-    unsigned nloops = resultShapedType.getRank();
+    unsigned nloops = result_shaped_type.getRank();
 
-    auto linalgOp = rewriter.create(
-        iotaOp.getLoc(),
+    auto linalg_op = rewriter.create(
+        iota_op.getLoc(),
         /*resultTensorTypes=*/
-        isLHLO ? ArrayRef{} : ArrayRef{resultShapedType},
+        isLHLO ? ArrayRef{} : ArrayRef{result_shaped_type},
         /*inputs=*/ValueRange{},
         /*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
        /*initTensors=*/ValueRange{},
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        GetNParallelLoopsAttrs(nloops),
-        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
+        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
             ValueRange args) {
-          Value castOp = nestedBuilder.create(
-              nestedLoc, ivs[iotaOp.iota_dimension()],
-              nestedBuilder.getIntegerType(
-                  resultElementType.getIntOrFloatBitWidth()));
-          if (resultElementType.template isa()) {
-            castOp = nestedBuilder.create(nestedLoc, castOp,
-                                          resultElementType);
+          Value cast_op = nested_builder.create(
+              nested_loc, ivs[iota_op.iota_dimension()],
+              nested_builder.getIntegerType(
+                  result_element_type.getIntOrFloatBitWidth()));
+          if (result_element_type.template isa()) {
+            cast_op = nested_builder.create(nested_loc, cast_op,
+                                            result_element_type);
          }
-          nestedBuilder.create(nestedLoc, castOp);
+          nested_builder.create(nested_loc, cast_op);
         });
     if (isLHLO)
-      rewriter.replaceOp(iotaOp, llvm::None);
+      rewriter.replaceOp(iota_op, llvm::None);
     else
-      rewriter.replaceOp(iotaOp, linalgOp.result_tensors());
+      rewriter.replaceOp(iota_op, linalg_op.result_tensors());
     return success();
   }
 };
@@ -735,16 +738,16 @@ class ConstConverter : public OpConversionPattern {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      lmhlo::ConstOp constOp, ArrayRef args,
+      lmhlo::ConstOp const_op, ArrayRef args,
       ConversionPatternRewriter& rewriter) const final {
-    auto loc = constOp.getLoc();
-    auto valueAttr = constOp.value().cast();
-    if (valueAttr.getType().getRank() != 0) return failure();
-    auto stdConstOp =
-        rewriter.create(loc, valueAttr.getValue({}));
-    rewriter.create(loc, stdConstOp, constOp.getOperand(),
-                    ValueRange());
-    rewriter.eraseOp(constOp);
+    auto loc = const_op.getLoc();
+    auto value_attr = const_op.value().cast();
+    if (value_attr.getType().getRank() != 0) return failure();
+    auto std_const_op =
+        rewriter.create(loc, value_attr.getValue({}));
+    rewriter.create(loc, std_const_op,
+                    const_op.getOperand(), ValueRange());
+    rewriter.eraseOp(const_op);
     return success();
   }
 };
@@ -758,21 +761,21 @@ class ReverseConverter
   using DataMovementOpConverter, OpTy, isLHLO>::DataMovementOpConverter;
   static SmallVector getIndexingMaps(OpTy op, Builder* b) {
-    auto resultType =
+    auto result_type =
         getHloOpResultType(op).template cast();
-    auto nloops = resultType.getRank();
-    SmallVector inputExprs;
-    inputExprs.reserve(nloops);
+    auto nloops = result_type.getRank();
+    SmallVector input_exprs;
+    input_exprs.reserve(nloops);
     for (int i = 0; i < nloops; ++i)
-      inputExprs.push_back(b->getAffineDimExpr(i));
+      input_exprs.push_back(b->getAffineDimExpr(i));
     for (auto dim : op.dimensions()) {
       int i = dim.getZExtValue();
-      if (resultType.isDynamicDim(i)) return {};
-      int n = resultType.getShape()[i];
-      inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
+      if (result_type.isDynamicDim(i)) return {};
+      int n = result_type.getShape()[i];
+      input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
     }
     return {
-        AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
+        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)};
   }
 };
@@ -782,31 +785,31 @@ class SliceConverter : public OpConversionPattern {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      lmhlo::SliceOp sliceOp, ArrayRef args,
+      lmhlo::SliceOp slice_op, ArrayRef args,
       ConversionPatternRewriter& rewriter) const final {
-    auto loc = sliceOp.getLoc();
-    auto argType =
-        sliceOp.getOperand(0).getType().template dyn_cast();
-    if (!argType || !argType.hasRank()) {
+    auto loc = slice_op.getLoc();
+    auto arg_type =
+        slice_op.getOperand(0).getType().template dyn_cast();
+    if (!arg_type || !arg_type.hasRank()) {
       emitError(loc, "lhlo to linalg conversion expects known-rank args");
       return failure();
     }
 
     SmallVector ranges;
-    for (int i = 0, e = argType.getRank(); i < e; ++i) {
+    for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
       Value start_index = rewriter.create(
-          loc, sliceOp.start_indices().getValue(i));
+          loc, slice_op.start_indices().getValue(i));
       Value limit_index = rewriter.create(
-          loc, sliceOp.limit_indices().getValue(i));
+          loc, slice_op.limit_indices().getValue(i));
       Value stride = rewriter.create(
-          loc, sliceOp.strides().getValue(i));
+          loc, slice_op.strides().getValue(i));
       ranges.push_back(rewriter.create(loc, start_index,
                                        limit_index, stride));
     }
     auto linalg_slice =
-        rewriter.create(loc, sliceOp.getOperand(0), ranges);
-    rewriter.create(loc, linalg_slice, sliceOp.getOperand(1));
-    rewriter.eraseOp(sliceOp);
+        rewriter.create(loc, slice_op.getOperand(0), ranges);
+    rewriter.create(loc, linalg_slice, slice_op.getOperand(1));
+    rewriter.eraseOp(slice_op);
     return success();
   }
 };