[NFC] Make naming style consistent.

Use lowercase with underscores between words (snake_case) instead of camelCase.

PiperOrigin-RevId: 338722328
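The renames are purely mechanical; for instance, this pair from the first hunk below is representative of the whole change:

-  auto verifyType = [&](Value val) -> bool {
+  auto verify_type = [&](Value val) -> bool {

Function, class, and template parameter names (e.g. getIndexingMaps, isLHLO) keep their existing style; only local variables, parameters, and lambdas are renamed.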
This commit is contained in:

parent 31c1c3aa1f
commit 444fae9bac
@@ -60,13 +60,13 @@ ShapedType getHloOpResultType(Operation* op) {
 
 template <bool isLHLO = true>
 bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
-  auto verifyType = [&](Value val) -> bool {
+  auto verify_type = [&](Value val) -> bool {
     return (isLHLO && val.getType().isa<MemRefType>()) ||
            (!isLHLO && val.getType().isa<RankedTensorType>());
   };
-  if (!llvm::all_of(op->getOperands(), verifyType)) return false;
+  if (!llvm::all_of(op->getOperands(), verify_type)) return false;
   return isLHLO ? op->getResults().empty()
-                : llvm::all_of(op->getResults(), verifyType);
+                : llvm::all_of(op->getResults(), verify_type);
 }
 
 template <typename OpTy, bool isLHLO = true>
@@ -99,51 +99,51 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
              << nloops << " parallel iterators: " << *(op.getOperation());
 
     // Construct the indexing maps needed for linalg.generic ops.
-    SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
+    SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
 
     // This doesnt account for implicit broadcast, but the working assumption
     // in HLO/LHLO is that are broadcasts are made explicit.
 
     if (isLHLO && !nloops) return failure();
 
-    int numInputs = (isLHLO ? args.size() - 1 : args.size());
+    int num_inputs = (isLHLO ? args.size() - 1 : args.size());
 
-    ValueRange inputs(args.take_front(numInputs));
+    ValueRange inputs(args.take_front(num_inputs));
     for (Value in : inputs)
-      bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
+      body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
 
-    ValueRange outputBuffers(args.take_back(args.size() - numInputs));
-    for (Value out : outputBuffers)
-      bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType()));
+    ValueRange output_buffers(args.take_back(args.size() - num_inputs));
+    for (Value out : output_buffers)
+      body_result_types.emplace_back(getElementTypeOrSelf(out.getType()));
 
     if (!isLHLO) {
       // HLO operations have return as tensor types.
-      assert(bodyResultTypes.empty() &&
+      assert(body_result_types.empty() &&
              "When lowering HLO ops result can't be part of arguments");
       Value result = op.getOperation()->getResult(0);
-      bodyResultTypes.push_back(getElementTypeOrSelf(result));
-      opResultTypes.push_back(result.getType());
+      body_result_types.push_back(getElementTypeOrSelf(result));
+      op_result_types.push_back(result.getType());
     }
 
-    AffineMap commonIndexingMap =
+    AffineMap common_indexing_map =
         nloops ? rewriter.getMultiDimIdentityMap(nloops)
                : AffineMap::get(nloops, 0, rewriter.getContext());
     SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
-                                            commonIndexingMap);
+                                            common_indexing_map);
 
-    auto linalgOp = rewriter.create<linalg::GenericOp>(
-        loc, opResultTypes, inputs, outputBuffers,
+    auto linalg_op = rewriter.create<linalg::GenericOp>(
+        loc, op_result_types, inputs, output_buffers,
         /*initTensors=*/ValueRange{}, indexing_maps,
         GetNParallelLoopsAttrs(nloops),
-        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
+        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
           // TODO(ravishankarm) : For now use the method in lmhlo namespace.
           // That method needs to be moved out of there.
-          Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>(
-              op, bodyResultTypes,
+          Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
+              op, body_result_types,
               llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
-          nestedBuilder.create<linalg::YieldOp>(loc, opResult);
+          nested_builder.create<linalg::YieldOp>(loc, op_result);
         });
-    rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
+    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
     return success();
   }
 };
@@ -157,10 +157,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
       LhloOp lhlo_op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
     auto loc = lhlo_op.getLoc();
-    auto argType =
+    auto arg_type =
         lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
-    if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
-        (argType.getRank() != 0)) {
+    if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
+        (arg_type.getRank() != 0)) {
       return failure();
     }
 
@@ -168,10 +168,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
     auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
     auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
     // TODO(ravishankarm) : Move this method out of lmhlo namespace.
-    Value opResult = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
-        lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
+    Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
+        lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
         &rewriter);
-    rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
+    rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
     rewriter.eraseOp(lhlo_op);
     return success();
   }
@@ -192,52 +192,52 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
       lmhlo::ConvOp op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
     // Check validity of dimension information.
-    if (const mhlo::ConvDimensionNumbers& dimensionNumbers =
+    if (const mhlo::ConvDimensionNumbers& dimension_numbers =
             op.dimension_numbers()) {
-      const int inputSpatialRank =
-          llvm::size(dimensionNumbers.input_spatial_dimensions());
+      const int input_spatial_rank =
+          llvm::size(dimension_numbers.input_spatial_dimensions());
       // The dimensions for input should follow the order of
       // batch_count, spatial_dims..., input_feature_count.
-      if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
-          dimensionNumbers.input_feature_dimension().getInt() !=
-              (inputSpatialRank + 1))
+      if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
+          dimension_numbers.input_feature_dimension().getInt() !=
+              (input_spatial_rank + 1))
         return failure();
 
-      const int kernelSpatialRank =
-          llvm::size(dimensionNumbers.kernel_spatial_dimensions());
+      const int kernel_spatial_rank =
+          llvm::size(dimension_numbers.kernel_spatial_dimensions());
       // The dimensions for filter should follow the order of
       // spatial_dims..., input_feature_count, num_output_feature_count.
-      if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
-              kernelSpatialRank ||
-          dimensionNumbers.kernel_output_feature_dimension().getInt() !=
-              (kernelSpatialRank + 1))
+      if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
+              kernel_spatial_rank ||
+          dimension_numbers.kernel_output_feature_dimension().getInt() !=
+              (kernel_spatial_rank + 1))
         return failure();
 
-      const int outputSpatialRank =
-          llvm::size(dimensionNumbers.output_spatial_dimensions());
+      const int output_spatial_rank =
+          llvm::size(dimension_numbers.output_spatial_dimensions());
       // The dimensions for output should follow the order of
       // batch_count, spatial_dims.., output_feature_count.
-      if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
-          dimensionNumbers.output_feature_dimension().getInt() !=
-              (outputSpatialRank + 1))
+      if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
+          dimension_numbers.output_feature_dimension().getInt() !=
+              (output_spatial_rank + 1))
         return failure();
 
-      if (inputSpatialRank != outputSpatialRank ||
-          inputSpatialRank != kernelSpatialRank)
+      if (input_spatial_rank != output_spatial_rank ||
+          input_spatial_rank != kernel_spatial_rank)
         return failure();
 
-      auto inputSpatialDim =
-          dimensionNumbers.input_spatial_dimensions().begin();
-      auto kernelSpatialDim =
-          dimensionNumbers.kernel_spatial_dimensions().begin();
-      auto outputSpatialDim =
-          dimensionNumbers.output_spatial_dimensions().begin();
+      auto input_spatial_dim =
+          dimension_numbers.input_spatial_dimensions().begin();
+      auto kernel_spatial_dim =
+          dimension_numbers.kernel_spatial_dimensions().begin();
+      auto output_spatial_dim =
+          dimension_numbers.output_spatial_dimensions().begin();
       // Check if spatial dims are ordered correctly.
-      for (int i = 0; i < inputSpatialRank; ++i) {
+      for (int i = 0; i < input_spatial_rank; ++i) {
         const int dim = i + 1;
-        if ((*inputSpatialDim++).getZExtValue() != dim ||
-            (*outputSpatialDim++).getZExtValue() != dim ||
-            (*kernelSpatialDim++).getZExtValue() != i)
+        if ((*input_spatial_dim++).getZExtValue() != dim ||
+            (*output_spatial_dim++).getZExtValue() != dim ||
+            (*kernel_spatial_dim++).getZExtValue() != i)
           return failure();
       }
     }
@@ -248,33 +248,33 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
     }
 
     llvm::SmallVector<Attribute, 4> strides;
-    if (auto windowStrides = op.window_strides()) {
-      auto range = windowStrides->getAttributeValues();
+    if (auto window_strides = op.window_strides()) {
+      auto range = window_strides->getAttributeValues();
       strides.assign(range.begin(), range.end());
     }
-    auto stridesArg = ArrayAttr::get(strides, op.getContext());
+    auto strides_arg = ArrayAttr::get(strides, op.getContext());
 
     llvm::SmallVector<Attribute, 2> dilation;
-    if (auto rhsDilation = op.rhs_dilation()) {
-      auto range = rhsDilation->getAttributeValues();
+    if (auto rhs_dilation = op.rhs_dilation()) {
+      auto range = rhs_dilation->getAttributeValues();
       dilation.assign(range.begin(), range.end());
     } else {
       // Default dilation of 1.
       dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
     }
-    auto dilationArg = ArrayAttr::get(dilation, op.getContext());
+    auto dilation_arg = ArrayAttr::get(dilation, op.getContext());
 
     // Set padding only if it is non-zero.
     DenseIntElementsAttr padding = op.paddingAttr();
-    if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) {
-          return !intVal.isNullValue();
-        })) {
+    if (!padding ||
+        !llvm::any_of(padding.getValues<APInt>(),
+                      [](APInt int_val) { return !int_val.isNullValue(); })) {
       padding = nullptr;
     }
 
     // The order of input and filter are switched with linalg.conv.
     rewriter.replaceOpWithNewOp<linalg::ConvOp>(
-        op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
+        op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
     return success();
   }
 };
@@ -293,25 +293,25 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
       OpTy op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
     if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
-    auto resultType = getHloOpResultType<isLHLO>(op);
+    auto result_type = getHloOpResultType<isLHLO>(op);
 
     SmallVector<AffineMap, 2> indexing_maps =
         Derived::getIndexingMaps(op, &rewriter);
     if (indexing_maps.empty()) return failure();
 
-    auto nloops = resultType.getRank();
+    auto nloops = result_type.getRank();
     auto loc = op.getLoc();
-    auto linalgOp = rewriter.create<linalg::GenericOp>(
+    auto linalg_op = rewriter.create<linalg::GenericOp>(
         loc,
-        /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : resultType,
+        /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
         /*inputs=*/args.front(),
         /*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
         /*initTensor=*/ValueRange{}, indexing_maps,
         GetNParallelLoopsAttrs(nloops),
-        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
-          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
+        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
+          nested_builder.create<linalg::YieldOp>(loc, *args.begin());
         });
-    rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
+    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
     return success();
   }
 };
@@ -325,32 +325,32 @@ class BroadcastConverter
   using DataMovementOpConverter<BroadcastConverter, OpTy,
                                 isLHLO>::DataMovementOpConverter;
 
-  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
+  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
                                                    Builder* b) {
-    ShapedType inputType =
-        broadcastOp.operand().getType().template cast<ShapedType>();
-    unsigned inputRank = inputType.getRank();
-    unsigned nloops = getHloOpResultType<isLHLO>(broadcastOp).getRank();
+    ShapedType input_type =
+        broadcast_op.operand().getType().template cast<ShapedType>();
+    unsigned input_rank = input_type.getRank();
+    unsigned nloops = getHloOpResultType<isLHLO>(broadcast_op).getRank();
 
     // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
     // the input's dimensions.
-    unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes());
-    SmallVector<AffineExpr, 4> inputDimExprs;
-    inputDimExprs.reserve(inputRank);
-    for (int i = 0; i < inputRank; ++i) {
-      inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i));
+    unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
+    SmallVector<AffineExpr, 4> input_dim_exprs;
+    input_dim_exprs.reserve(input_rank);
+    for (int i = 0; i < input_rank; ++i) {
+      input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
     }
 
-    AffineMap inputMap;
+    AffineMap input_map;
     MLIRContext* context = b->getContext();
-    if (inputDimExprs.empty()) {
+    if (input_dim_exprs.empty()) {
       // The input is a scalar, i.e. this is a scalar broadcast op.
-      inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context);
+      input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
     } else {
-      inputMap =
-          AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
+      input_map =
+          AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
     }
-    return {inputMap, b->getMultiDimIdentityMap(nloops)};
+    return {input_map, b->getMultiDimIdentityMap(nloops)};
   }
 };
 
@@ -363,34 +363,34 @@ class HloBroadcastInDimConverter
                                 false>::DataMovementOpConverter;
 
   static SmallVector<AffineMap, 2> getIndexingMaps(
-      mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
-    auto resultType = getHloOpResultType<false>(broadcastOp);
-    auto operandType =
-        broadcastOp.operand().getType().template cast<ShapedType>();
-    unsigned nloops = resultType.getRank();
+      mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
+    auto result_type = getHloOpResultType<false>(broadcast_op);
+    auto operand_type =
+        broadcast_op.operand().getType().template cast<ShapedType>();
+    unsigned nloops = result_type.getRank();
 
     // The input is a scalar, i.e. this is a scalar broadcast op.
-    if (operandType.getRank() == 0) {
+    if (operand_type.getRank() == 0) {
       return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
               b->getMultiDimIdentityMap(nloops)};
     }
 
-    auto operandShape = operandType.getShape();
-    SmallVector<AffineExpr, 4> dimExprs;
-    dimExprs.reserve(nloops);
+    auto operand_shape = operand_type.getShape();
+    SmallVector<AffineExpr, 4> dim_exprs;
+    dim_exprs.reserve(nloops);
 
-    if (broadcastOp.broadcast_dimensions()) {
+    if (broadcast_op.broadcast_dimensions()) {
       for (const auto& broadcastDim :
-           enumerate(broadcastOp.broadcast_dimensions().getIntValues())) {
+           enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
         int size = broadcastDim.value().getSExtValue();
-        bool expansion_needed = operandShape[broadcastDim.index()] == 1 &&
-                                resultType.getShape()[size] != 1;
-        dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
-                                            : b->getAffineDimExpr(size));
+        bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
+                                result_type.getShape()[size] != 1;
+        dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
+                                             : b->getAffineDimExpr(size));
       }
     }
     return {
-        AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
+        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)};
   }
 };
@@ -430,8 +430,8 @@ class LhloBroadcastInDimConverter
           /*outputBuffers=*/ValueRange{operand_adaptor.output()},
           llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
           GetNParallelLoopsAttrs(nloops),
-          [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
-            nestedBuilder.create<linalg::YieldOp>(loc, val);
+          [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
+            nested_builder.create<linalg::YieldOp>(loc, val);
           });
 
     } else {
@@ -441,8 +441,8 @@ class LhloBroadcastInDimConverter
           loc, /*inputs=*/ValueRange{operand},
           /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
           GetNParallelLoopsAttrs(nloops),
-          [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
-            nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
+          [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
+            nested_builder.create<linalg::YieldOp>(loc, *args.begin());
           });
     }
     rewriter.replaceOp(op, llvm::None);
@@ -520,35 +520,35 @@ class LhloBroadcastInDimConverter
   }
 
   SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
-                                            ArrayRef<int64_t> broadcastDims,
-                                            ArrayRef<int64_t> resultShape,
-                                            MemRefType operandType,
+                                            ArrayRef<int64_t> broadcast_dims,
+                                            ArrayRef<int64_t> result_shape,
+                                            MemRefType operand_type,
                                             Builder* b) const {
-    unsigned nloops = resultShape.size();
+    unsigned nloops = result_shape.size();
 
     // The input is a scalar, i.e. this is a scalar broadcast op.
-    if (operandType.getRank() == 0) {
+    if (operand_type.getRank() == 0) {
      return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
              b->getMultiDimIdentityMap(nloops)};
     }
 
-    auto operandShape = operandType.getShape();
-    SmallVector<AffineExpr, 4> dimExprs;
-    dimExprs.reserve(nloops);
+    auto operand_shape = operand_type.getShape();
+    SmallVector<AffineExpr, 4> dim_exprs;
+    dim_exprs.reserve(nloops);
 
-    for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) {
-      int size = broadcastDim.value();
+    for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) {
+      int size = broadcast_dim.value();
       bool expansion_needed =
-          operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1;
+          operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1;
       if (expansion_needed) {
         op.emitOpError(
             "BroadcastInDimOp lowering to Linalg does not support size-1 "
             "dimensions expansion.");
       }
-      dimExprs.push_back(b->getAffineDimExpr(size));
+      dim_exprs.push_back(b->getAffineDimExpr(size));
     }
     return {
-        AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
+        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)};
   }
 };
@@ -561,17 +561,17 @@ class TransposeConverter
   using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                 isLHLO>::DataMovementOpConverter;
   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
-    auto resultType =
+    auto result_type =
         getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
-    auto nloops = resultType.getRank();
-    SmallVector<AffineExpr, 2> inputExprs;
-    inputExprs.resize(resultType.getRank());
+    auto nloops = result_type.getRank();
+    SmallVector<AffineExpr, 2> input_exprs;
+    input_exprs.resize(result_type.getRank());
     for (auto permutation : llvm::enumerate(op.permutation())) {
-      inputExprs[permutation.value().getZExtValue()] =
+      input_exprs[permutation.value().getZExtValue()] =
          b->getAffineDimExpr(permutation.index());
     }
     return {
-        AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
+        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)};
   }
 };
@@ -584,101 +584,104 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
   using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      OpTy reshapeOp, ArrayRef<Value> args,
+      OpTy reshape_op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
+    if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
       return failure();
-    ShapedType operandType =
-        reshapeOp.operand().getType().template cast<ShapedType>();
-    ShapedType resultType = getHloOpResultType<isLHLO>(reshapeOp);
+    ShapedType operand_type =
+        reshape_op.operand().getType().template cast<ShapedType>();
+    ShapedType result_type = getHloOpResultType<isLHLO>(reshape_op);
 
-    if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
+    if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
       return failure();
 
     // Compute the reassociation maps for the linalg operation.
-    ArrayRef<int64_t> srcShape =
-        (operandType.getRank() > resultType.getRank() ? operandType.getShape()
-                                                      : resultType.getShape());
-    ArrayRef<int64_t> dstShape =
-        (operandType.getRank() > resultType.getRank() ? resultType.getShape()
-                                                      : operandType.getShape());
-    unsigned currSrcDim = 0, currDstDim = 0;
-    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
-        dstShape.size());
-    bool isExpandingOrCollapsing = true;
-    while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
-      int64_t dstSize = dstShape[currDstDim];
-      int64_t srcSize = srcShape[currSrcDim];
-      while (srcSize < dstSize && currSrcDim < srcShape.size()) {
-        reassociationMap[currDstDim].push_back(
-            rewriter.getAffineDimExpr(currSrcDim++));
-        srcSize *= srcShape[currSrcDim];
+    ArrayRef<int64_t> src_shape =
+        (operand_type.getRank() > result_type.getRank()
+             ? operand_type.getShape()
+             : result_type.getShape());
+    ArrayRef<int64_t> dst_shape =
+        (operand_type.getRank() > result_type.getRank()
+             ? result_type.getShape()
+             : operand_type.getShape());
+    unsigned curr_src_dim = 0, curr_dst_dim = 0;
+    SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
+        dst_shape.size());
+    bool is_expanding_or_collapsing = true;
+    while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
+      int64_t dst_size = dst_shape[curr_dst_dim];
+      int64_t src_size = src_shape[curr_src_dim];
+      while (src_size < dst_size && curr_src_dim < src_shape.size()) {
+        reassociation_map[curr_dst_dim].push_back(
+            rewriter.getAffineDimExpr(curr_src_dim++));
+        src_size *= src_shape[curr_src_dim];
       }
-      if (srcSize == dstSize) {
-        reassociationMap[currDstDim].push_back(
-            rewriter.getAffineDimExpr(currSrcDim++));
-        // If the next dim in dstShape is not 1, treat subsequent dims in
-        // srcShape which are 1 to be collapsed.
-        if (currDstDim == dstShape.size() - 1 ||
-            dstShape[currDstDim + 1] != 1) {
-          while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
-            reassociationMap[currDstDim].push_back(
-                rewriter.getAffineDimExpr(currSrcDim++));
+      if (src_size == dst_size) {
+        reassociation_map[curr_dst_dim].push_back(
+            rewriter.getAffineDimExpr(curr_src_dim++));
+        // If the next dim in dst_shape is not 1, treat subsequent dims in
+        // src_shape which are 1 to be collapsed.
+        if (curr_dst_dim == dst_shape.size() - 1 ||
+            dst_shape[curr_dst_dim + 1] != 1) {
+          while (curr_src_dim < src_shape.size() &&
+                 src_shape[curr_src_dim] == 1) {
+            reassociation_map[curr_dst_dim].push_back(
+                rewriter.getAffineDimExpr(curr_src_dim++));
           }
         }
       } else {
-        isExpandingOrCollapsing = false;
+        is_expanding_or_collapsing = false;
         break;
       }
-      currDstDim++;
+      curr_dst_dim++;
     }
-    if (currSrcDim != srcShape.size() || currDstDim != dstShape.size())
-      isExpandingOrCollapsing = false;
+    if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
+      is_expanding_or_collapsing = false;
 
-    if (!isExpandingOrCollapsing) {
-      auto getIdentityExprs = [&rewriter](int n) {
+    if (!is_expanding_or_collapsing) {
+      auto get_identity_exprs = [&rewriter](int n) {
         SmallVector<AffineExpr, 4> exprs;
         for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
-      Location loc = reshapeOp.getLoc();
-      int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
-                                           std::multiplies<int64_t>());
-      auto elemType = operandType.getElementType();
-      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
-          getIdentityExprs(dstShape.size())};
-      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
-          getIdentityExprs(srcShape.size())};
+      Location loc = reshape_op.getLoc();
+      int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
+                                            1, std::multiplies<int64_t>());
+      auto elem_type = operand_type.getElementType();
+      SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
+          get_identity_exprs(dst_shape.size())};
+      SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
+          get_identity_exprs(src_shape.size())};
 
       if (isLHLO) {
-        auto collapsedType = MemRefType::get({totalElems}, elemType);
-        Value collapsedOp = rewriter.create<linalg::ReshapeOp>(
-            loc, collapsedType, args[0], collapsingMap);
-        Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
-            loc, resultType, collapsedOp, expandingMap);
+        auto collapsed_type = MemRefType::get({total_elems}, elem_type);
+        Value collapsed_op = rewriter.create<linalg::ReshapeOp>(
+            loc, collapsed_type, args[0], collapsing_map);
+        Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
+            loc, result_type, collapsed_op, expanding_map);
         rewriter.replaceOpWithNewOp<linalg::CopyOp>(
-            reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
+            reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
             /*outputPermutation =*/nullptr);
       } else {
-        auto collapsedType = RankedTensorType::get({totalElems}, elemType);
-        Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
-            loc, collapsedType, args[0], collapsingMap);
+        auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
+        Value collapsed_op = rewriter.create<linalg::TensorReshapeOp>(
+            loc, collapsed_type, args[0], collapsing_map);
         rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
-            reshapeOp, resultType, collapsedOp, expandingMap);
+            reshape_op, result_type, collapsed_op, expanding_map);
       }
       return success();
     }
 
     if (isLHLO) {
-      Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
-          reshapeOp.getLoc(), resultType, args[0], reassociationMap);
+      Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
+          reshape_op.getLoc(), result_type, args[0], reassociation_map);
       rewriter.replaceOpWithNewOp<linalg::CopyOp>(
-          reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
+          reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
          /*outputPermutation =*/nullptr);
     } else {
      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
-          reshapeOp, resultType, args[0], reassociationMap);
+          reshape_op, result_type, args[0], reassociation_map);
     }
     return success();
   }
@@ -690,42 +693,42 @@ class IotaConverter : public OpConversionPattern<OpTy> {
   using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      OpTy iotaOp, ArrayRef<Value> args,
+      OpTy iota_op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp);
-    if (!resultShapedType) return failure();
+    ShapedType result_shaped_type = getHloOpResultType<isLHLO>(iota_op);
+    if (!result_shaped_type) return failure();
 
-    auto resultElementType = resultShapedType.getElementType();
-    if (!resultElementType.isSignlessIntOrFloat()) return failure();
+    auto result_element_type = result_shaped_type.getElementType();
+    if (!result_element_type.isSignlessIntOrFloat()) return failure();
 
     // Construct the indexing maps needed for linalg.generic ops.
-    unsigned nloops = resultShapedType.getRank();
+    unsigned nloops = result_shaped_type.getRank();
 
-    auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
-        iotaOp.getLoc(),
+    auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
+        iota_op.getLoc(),
         /*resultTensorTypes=*/
-        isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{resultShapedType},
+        isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
         /*inputs=*/ValueRange{},
         /*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
         /*initTensors=*/ValueRange{},
         llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
         GetNParallelLoopsAttrs(nloops),
-        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
+        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
             ValueRange args) {
-          Value castOp = nestedBuilder.create<IndexCastOp>(
-              nestedLoc, ivs[iotaOp.iota_dimension()],
-              nestedBuilder.getIntegerType(
-                  resultElementType.getIntOrFloatBitWidth()));
-          if (resultElementType.template isa<FloatType>()) {
-            castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
-                                                    resultElementType);
+          Value cast_op = nested_builder.create<IndexCastOp>(
+              nested_loc, ivs[iota_op.iota_dimension()],
+              nested_builder.getIntegerType(
+                  result_element_type.getIntOrFloatBitWidth()));
+          if (result_element_type.template isa<FloatType>()) {
+            cast_op = nested_builder.create<SIToFPOp>(nested_loc, cast_op,
+                                                      result_element_type);
           }
-          nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
+          nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
         });
     if (isLHLO)
-      rewriter.replaceOp(iotaOp, llvm::None);
+      rewriter.replaceOp(iota_op, llvm::None);
     else
-      rewriter.replaceOp(iotaOp, linalgOp.result_tensors());
+      rewriter.replaceOp(iota_op, linalg_op.result_tensors());
     return success();
   }
 };
@@ -735,16 +738,16 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
   using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      lmhlo::ConstOp constOp, ArrayRef<Value> args,
+      lmhlo::ConstOp const_op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    auto loc = constOp.getLoc();
-    auto valueAttr = constOp.value().cast<DenseElementsAttr>();
-    if (valueAttr.getType().getRank() != 0) return failure();
-    auto stdConstOp =
-        rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
-    rewriter.create<mlir::AffineStoreOp>(loc, stdConstOp, constOp.getOperand(),
-                                         ValueRange());
-    rewriter.eraseOp(constOp);
+    auto loc = const_op.getLoc();
+    auto value_attr = const_op.value().cast<DenseElementsAttr>();
+    if (value_attr.getType().getRank() != 0) return failure();
+    auto std_const_op =
+        rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
+    rewriter.create<mlir::AffineStoreOp>(loc, std_const_op,
+                                         const_op.getOperand(), ValueRange());
+    rewriter.eraseOp(const_op);
     return success();
   }
 };
@@ -758,21 +761,21 @@ class ReverseConverter
   using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
                                 isLHLO>::DataMovementOpConverter;
   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
-    auto resultType =
+    auto result_type =
         getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
-    auto nloops = resultType.getRank();
-    SmallVector<AffineExpr, 2> inputExprs;
-    inputExprs.reserve(nloops);
+    auto nloops = result_type.getRank();
+    SmallVector<AffineExpr, 2> input_exprs;
+    input_exprs.reserve(nloops);
     for (int i = 0; i < nloops; ++i)
-      inputExprs.push_back(b->getAffineDimExpr(i));
+      input_exprs.push_back(b->getAffineDimExpr(i));
     for (auto dim : op.dimensions()) {
       int i = dim.getZExtValue();
-      if (resultType.isDynamicDim(i)) return {};
-      int n = resultType.getShape()[i];
-      inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
+      if (result_type.isDynamicDim(i)) return {};
+      int n = result_type.getShape()[i];
+      input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
     }
     return {
-        AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
+        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)};
   }
 };
@@ -782,31 +785,31 @@ class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
   using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      lmhlo::SliceOp sliceOp, ArrayRef<Value> args,
+      lmhlo::SliceOp slice_op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    auto loc = sliceOp.getLoc();
-    auto argType =
-        sliceOp.getOperand(0).getType().template dyn_cast<ShapedType>();
-    if (!argType || !argType.hasRank()) {
+    auto loc = slice_op.getLoc();
+    auto arg_type =
+        slice_op.getOperand(0).getType().template dyn_cast<ShapedType>();
+    if (!arg_type || !arg_type.hasRank()) {
       emitError(loc, "lhlo to linalg conversion expects known-rank args");
       return failure();
     }
 
     SmallVector<Value, 3> ranges;
-    for (int i = 0, e = argType.getRank(); i < e; ++i) {
+    for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
       Value start_index = rewriter.create<ConstantIndexOp>(
-          loc, sliceOp.start_indices().getValue<int64_t>(i));
+          loc, slice_op.start_indices().getValue<int64_t>(i));
       Value limit_index = rewriter.create<ConstantIndexOp>(
-          loc, sliceOp.limit_indices().getValue<int64_t>(i));
+          loc, slice_op.limit_indices().getValue<int64_t>(i));
       Value stride = rewriter.create<ConstantIndexOp>(
-          loc, sliceOp.strides().getValue<int64_t>(i));
+          loc, slice_op.strides().getValue<int64_t>(i));
       ranges.push_back(rewriter.create<linalg::RangeOp>(loc, start_index,
                                                         limit_index, stride));
     }
     auto linalg_slice =
-        rewriter.create<linalg::SliceOp>(loc, sliceOp.getOperand(0), ranges);
-    rewriter.create<linalg::CopyOp>(loc, linalg_slice, sliceOp.getOperand(1));
-    rewriter.eraseOp(sliceOp);
+        rewriter.create<linalg::SliceOp>(loc, slice_op.getOperand(0), ranges);
+    rewriter.create<linalg::CopyOp>(loc, linalg_slice, slice_op.getOperand(1));
+    rewriter.eraseOp(slice_op);
     return success();
   }
 };