[NFC] Make naming style consistent.
Use lowercase with underscores between words instead of camelCase.

PiperOrigin-RevId: 338722328
This commit is contained in:
parent 31c1c3aa1f
commit 444fae9bac
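
The change is naming-only (NFC): local variables, parameters, and lambda arguments move from camelCase to lowercase_with_underscores, with no behavioral difference. As a minimal sketch of the convention being applied — the function and variable names below are hypothetical and are not taken from the file touched by this commit:

// Hypothetical example illustrating the naming convention only; these
// identifiers do not appear in the file touched by this commit.
#include <vector>

// Before: locals written in camelCase.
int SumPositiveBefore(const std::vector<int>& values) {
  int runningTotal = 0;              // camelCase local
  for (int currentValue : values) {  // camelCase local
    if (currentValue > 0) runningTotal += currentValue;
  }
  return runningTotal;
}

// After: the same logic with lowercase-and-underscore locals.
int SumPositiveAfter(const std::vector<int>& values) {
  int running_total = 0;              // lower_snake_case local
  for (int current_value : values) {  // lower_snake_case local
    if (current_value > 0) running_total += current_value;
  }
  return running_total;
}

The diff below applies the same mechanical rename to the lowering patterns shown.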
					
@@ -60,13 +60,13 @@ ShapedType getHloOpResultType(Operation* op) {
 
 template <bool isLHLO = true>
 bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
-  auto verifyType = [&](Value val) -> bool {
+  auto verify_type = [&](Value val) -> bool {
     return (isLHLO && val.getType().isa<MemRefType>()) ||
            (!isLHLO && val.getType().isa<RankedTensorType>());
   };
-  if (!llvm::all_of(op->getOperands(), verifyType)) return false;
+  if (!llvm::all_of(op->getOperands(), verify_type)) return false;
   return isLHLO ? op->getResults().empty()
-                : llvm::all_of(op->getResults(), verifyType);
+                : llvm::all_of(op->getResults(), verify_type);
 }
 
 template <typename OpTy, bool isLHLO = true>
@@ -99,51 +99,51 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
              << nloops << " parallel iterators: " << *(op.getOperation());
 
     // Construct the indexing maps needed for linalg.generic ops.
-    SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
+    SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
 
     // This doesnt account for implicit broadcast, but the working assumption
     // in HLO/LHLO is that are broadcasts are made explicit.
 
     if (isLHLO && !nloops) return failure();
 
-    int numInputs = (isLHLO ? args.size() - 1 : args.size());
+    int num_inputs = (isLHLO ? args.size() - 1 : args.size());
 
-    ValueRange inputs(args.take_front(numInputs));
+    ValueRange inputs(args.take_front(num_inputs));
     for (Value in : inputs)
-      bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
+      body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
 
-    ValueRange outputBuffers(args.take_back(args.size() - numInputs));
-    for (Value out : outputBuffers)
-      bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType()));
+    ValueRange output_buffers(args.take_back(args.size() - num_inputs));
+    for (Value out : output_buffers)
+      body_result_types.emplace_back(getElementTypeOrSelf(out.getType()));
 
     if (!isLHLO) {
       // HLO operations have return as tensor types.
-      assert(bodyResultTypes.empty() &&
+      assert(body_result_types.empty() &&
              "When lowering HLO ops result can't be part of arguments");
       Value result = op.getOperation()->getResult(0);
-      bodyResultTypes.push_back(getElementTypeOrSelf(result));
-      opResultTypes.push_back(result.getType());
+      body_result_types.push_back(getElementTypeOrSelf(result));
+      op_result_types.push_back(result.getType());
     }
 
-    AffineMap commonIndexingMap =
+    AffineMap common_indexing_map =
         nloops ? rewriter.getMultiDimIdentityMap(nloops)
                : AffineMap::get(nloops, 0, rewriter.getContext());
     SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
-                                            commonIndexingMap);
+                                            common_indexing_map);
 
-    auto linalgOp = rewriter.create<linalg::GenericOp>(
-        loc, opResultTypes, inputs, outputBuffers,
+    auto linalg_op = rewriter.create<linalg::GenericOp>(
+        loc, op_result_types, inputs, output_buffers,
         /*initTensors=*/ValueRange{}, indexing_maps,
         GetNParallelLoopsAttrs(nloops),
-        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
+        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
           // TODO(ravishankarm) : For now use the method in lmhlo namespace.
          // That method needs to be moved out of there.
-          Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>(
-              op, bodyResultTypes,
+          Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
+              op, body_result_types,
              llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
-          nestedBuilder.create<linalg::YieldOp>(loc, opResult);
+          nested_builder.create<linalg::YieldOp>(loc, op_result);
         });
-    rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
+    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
     return success();
   }
 };
@@ -157,10 +157,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
       LhloOp lhlo_op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
     auto loc = lhlo_op.getLoc();
-    auto argType =
+    auto arg_type =
         lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
-    if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
-        (argType.getRank() != 0)) {
+    if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
+        (arg_type.getRank() != 0)) {
       return failure();
     }
 
@@ -168,10 +168,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
     auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
     auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
     // TODO(ravishankarm) : Move this method out of lmhlo namespace.
-    Value opResult = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
-        lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
+    Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
+        lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
         &rewriter);
-    rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
+    rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
     rewriter.eraseOp(lhlo_op);
     return success();
   }
@@ -192,52 +192,52 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
       lmhlo::ConvOp op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
     // Check validity of dimension information.
-    if (const mhlo::ConvDimensionNumbers& dimensionNumbers =
+    if (const mhlo::ConvDimensionNumbers& dimension_numbers =
             op.dimension_numbers()) {
-      const int inputSpatialRank =
-          llvm::size(dimensionNumbers.input_spatial_dimensions());
+      const int input_spatial_rank =
+          llvm::size(dimension_numbers.input_spatial_dimensions());
       // The dimensions for input should follow the order of
       // batch_count, spatial_dims..., input_feature_count.
-      if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
-          dimensionNumbers.input_feature_dimension().getInt() !=
-              (inputSpatialRank + 1))
+      if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
+          dimension_numbers.input_feature_dimension().getInt() !=
+              (input_spatial_rank + 1))
         return failure();
 
-      const int kernelSpatialRank =
-          llvm::size(dimensionNumbers.kernel_spatial_dimensions());
+      const int kernel_spatial_rank =
+          llvm::size(dimension_numbers.kernel_spatial_dimensions());
       // The dimensions for filter should follow the order of
       // spatial_dims..., input_feature_count, num_output_feature_count.
-      if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
-              kernelSpatialRank ||
-          dimensionNumbers.kernel_output_feature_dimension().getInt() !=
-              (kernelSpatialRank + 1))
+      if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
+              kernel_spatial_rank ||
+          dimension_numbers.kernel_output_feature_dimension().getInt() !=
+              (kernel_spatial_rank + 1))
         return failure();
 
-      const int outputSpatialRank =
-          llvm::size(dimensionNumbers.output_spatial_dimensions());
+      const int output_spatial_rank =
+          llvm::size(dimension_numbers.output_spatial_dimensions());
       // The dimensions for output should follow the order of
       // batch_count, spatial_dims.., output_feature_count.
-      if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
-          dimensionNumbers.output_feature_dimension().getInt() !=
-              (outputSpatialRank + 1))
+      if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
+          dimension_numbers.output_feature_dimension().getInt() !=
+              (output_spatial_rank + 1))
         return failure();
 
-      if (inputSpatialRank != outputSpatialRank ||
-          inputSpatialRank != kernelSpatialRank)
+      if (input_spatial_rank != output_spatial_rank ||
+          input_spatial_rank != kernel_spatial_rank)
        return failure();
 
-      auto inputSpatialDim =
-          dimensionNumbers.input_spatial_dimensions().begin();
-      auto kernelSpatialDim =
-          dimensionNumbers.kernel_spatial_dimensions().begin();
-      auto outputSpatialDim =
-          dimensionNumbers.output_spatial_dimensions().begin();
+      auto input_spatial_dim =
+          dimension_numbers.input_spatial_dimensions().begin();
+      auto kernel_spatial_dim =
+          dimension_numbers.kernel_spatial_dimensions().begin();
+      auto output_spatial_dim =
+          dimension_numbers.output_spatial_dimensions().begin();
       // Check if spatial dims are ordered correctly.
-      for (int i = 0; i < inputSpatialRank; ++i) {
+      for (int i = 0; i < input_spatial_rank; ++i) {
         const int dim = i + 1;
-        if ((*inputSpatialDim++).getZExtValue() != dim ||
-            (*outputSpatialDim++).getZExtValue() != dim ||
-            (*kernelSpatialDim++).getZExtValue() != i)
+        if ((*input_spatial_dim++).getZExtValue() != dim ||
+            (*output_spatial_dim++).getZExtValue() != dim ||
+            (*kernel_spatial_dim++).getZExtValue() != i)
           return failure();
       }
     }
@@ -248,33 +248,33 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
     }
 
     llvm::SmallVector<Attribute, 4> strides;
-    if (auto windowStrides = op.window_strides()) {
-      auto range = windowStrides->getAttributeValues();
+    if (auto window_strides = op.window_strides()) {
+      auto range = window_strides->getAttributeValues();
       strides.assign(range.begin(), range.end());
     }
-    auto stridesArg = ArrayAttr::get(strides, op.getContext());
+    auto strides_arg = ArrayAttr::get(strides, op.getContext());
 
     llvm::SmallVector<Attribute, 2> dilation;
-    if (auto rhsDilation = op.rhs_dilation()) {
-      auto range = rhsDilation->getAttributeValues();
+    if (auto rhs_dilation = op.rhs_dilation()) {
+      auto range = rhs_dilation->getAttributeValues();
       dilation.assign(range.begin(), range.end());
     } else {
       // Default dilation of 1.
       dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
     }
-    auto dilationArg = ArrayAttr::get(dilation, op.getContext());
+    auto dilation_arg = ArrayAttr::get(dilation, op.getContext());
 
     // Set padding only if it is non-zero.
     DenseIntElementsAttr padding = op.paddingAttr();
-    if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) {
-          return !intVal.isNullValue();
-        })) {
+    if (!padding ||
+        !llvm::any_of(padding.getValues<APInt>(),
+                      [](APInt int_val) { return !int_val.isNullValue(); })) {
       padding = nullptr;
     }
 
     // The order of input and filter are switched with linalg.conv.
     rewriter.replaceOpWithNewOp<linalg::ConvOp>(
-        op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
+        op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
     return success();
   }
 };
@@ -293,25 +293,25 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
       OpTy op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
     if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
-    auto resultType = getHloOpResultType<isLHLO>(op);
+    auto result_type = getHloOpResultType<isLHLO>(op);
 
     SmallVector<AffineMap, 2> indexing_maps =
         Derived::getIndexingMaps(op, &rewriter);
     if (indexing_maps.empty()) return failure();
 
-    auto nloops = resultType.getRank();
+    auto nloops = result_type.getRank();
     auto loc = op.getLoc();
-    auto linalgOp = rewriter.create<linalg::GenericOp>(
+    auto linalg_op = rewriter.create<linalg::GenericOp>(
         loc,
-        /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : resultType,
+        /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
         /*inputs=*/args.front(),
         /*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
         /*initTensor=*/ValueRange{}, indexing_maps,
         GetNParallelLoopsAttrs(nloops),
-        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
-          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
+        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
+          nested_builder.create<linalg::YieldOp>(loc, *args.begin());
         });
-    rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
+    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
     return success();
   }
 };
@@ -325,32 +325,32 @@ class BroadcastConverter
   using DataMovementOpConverter<BroadcastConverter, OpTy,
                                 isLHLO>::DataMovementOpConverter;
 
-  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
+  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
                                                    Builder* b) {
-    ShapedType inputType =
-        broadcastOp.operand().getType().template cast<ShapedType>();
-    unsigned inputRank = inputType.getRank();
-    unsigned nloops = getHloOpResultType<isLHLO>(broadcastOp).getRank();
+    ShapedType input_type =
+        broadcast_op.operand().getType().template cast<ShapedType>();
+    unsigned input_rank = input_type.getRank();
+    unsigned nloops = getHloOpResultType<isLHLO>(broadcast_op).getRank();
 
     // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
     // the input's dimensions.
-    unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes());
-    SmallVector<AffineExpr, 4> inputDimExprs;
-    inputDimExprs.reserve(inputRank);
-    for (int i = 0; i < inputRank; ++i) {
-      inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i));
+    unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
+    SmallVector<AffineExpr, 4> input_dim_exprs;
+    input_dim_exprs.reserve(input_rank);
+    for (int i = 0; i < input_rank; ++i) {
+      input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
     }
 
-    AffineMap inputMap;
+    AffineMap input_map;
     MLIRContext* context = b->getContext();
-    if (inputDimExprs.empty()) {
+    if (input_dim_exprs.empty()) {
       // The input is a scalar, i.e. this is a scalar broadcast op.
-      inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context);
+      input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
     } else {
-      inputMap =
-          AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
+      input_map =
+          AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
     }
-    return {inputMap, b->getMultiDimIdentityMap(nloops)};
+    return {input_map, b->getMultiDimIdentityMap(nloops)};
   }
 };
 
@@ -363,34 +363,34 @@ class HloBroadcastInDimConverter
                                 false>::DataMovementOpConverter;
 
   static SmallVector<AffineMap, 2> getIndexingMaps(
-      mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
-    auto resultType = getHloOpResultType<false>(broadcastOp);
-    auto operandType =
-        broadcastOp.operand().getType().template cast<ShapedType>();
-    unsigned nloops = resultType.getRank();
+      mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
+    auto result_type = getHloOpResultType<false>(broadcast_op);
+    auto operand_type =
+        broadcast_op.operand().getType().template cast<ShapedType>();
+    unsigned nloops = result_type.getRank();
 
     // The input is a scalar, i.e. this is a scalar broadcast op.
-    if (operandType.getRank() == 0) {
+    if (operand_type.getRank() == 0) {
       return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
               b->getMultiDimIdentityMap(nloops)};
     }
 
-    auto operandShape = operandType.getShape();
-    SmallVector<AffineExpr, 4> dimExprs;
-    dimExprs.reserve(nloops);
+    auto operand_shape = operand_type.getShape();
+    SmallVector<AffineExpr, 4> dim_exprs;
+    dim_exprs.reserve(nloops);
 
-    if (broadcastOp.broadcast_dimensions()) {
+    if (broadcast_op.broadcast_dimensions()) {
       for (const auto& broadcastDim :
-           enumerate(broadcastOp.broadcast_dimensions().getIntValues())) {
+           enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
         int size = broadcastDim.value().getSExtValue();
-        bool expansion_needed = operandShape[broadcastDim.index()] == 1 &&
-                                resultType.getShape()[size] != 1;
-        dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
+        bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
+                                result_type.getShape()[size] != 1;
+        dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
                                              : b->getAffineDimExpr(size));
       }
     }
     return {
-        AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
+        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)};
   }
 };
@@ -430,8 +430,8 @@ class LhloBroadcastInDimConverter
           /*outputBuffers=*/ValueRange{operand_adaptor.output()},
           llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
           GetNParallelLoopsAttrs(nloops),
-          [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
-            nestedBuilder.create<linalg::YieldOp>(loc, val);
+          [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
+            nested_builder.create<linalg::YieldOp>(loc, val);
           });
 
     } else {
@@ -441,8 +441,8 @@ class LhloBroadcastInDimConverter
           loc, /*inputs=*/ValueRange{operand},
           /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
           GetNParallelLoopsAttrs(nloops),
-          [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
-            nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
+          [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
+            nested_builder.create<linalg::YieldOp>(loc, *args.begin());
           });
     }
     rewriter.replaceOp(op, llvm::None);
@@ -520,35 +520,35 @@ class LhloBroadcastInDimConverter
   }
 
   SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
-                                            ArrayRef<int64_t> broadcastDims,
-                                            ArrayRef<int64_t> resultShape,
-                                            MemRefType operandType,
+                                            ArrayRef<int64_t> broadcast_dims,
+                                            ArrayRef<int64_t> result_shape,
+                                            MemRefType operand_type,
                                             Builder* b) const {
-    unsigned nloops = resultShape.size();
+    unsigned nloops = result_shape.size();
 
     // The input is a scalar, i.e. this is a scalar broadcast op.
-    if (operandType.getRank() == 0) {
+    if (operand_type.getRank() == 0) {
       return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
               b->getMultiDimIdentityMap(nloops)};
     }
 
-    auto operandShape = operandType.getShape();
-    SmallVector<AffineExpr, 4> dimExprs;
-    dimExprs.reserve(nloops);
+    auto operand_shape = operand_type.getShape();
+    SmallVector<AffineExpr, 4> dim_exprs;
+    dim_exprs.reserve(nloops);
 
-    for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) {
-      int size = broadcastDim.value();
+    for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) {
+      int size = broadcast_dim.value();
       bool expansion_needed =
-          operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1;
+          operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1;
       if (expansion_needed) {
         op.emitOpError(
             "BroadcastInDimOp lowering to Linalg does not support size-1 "
            "dimensions expansion.");
       }
-      dimExprs.push_back(b->getAffineDimExpr(size));
+      dim_exprs.push_back(b->getAffineDimExpr(size));
     }
     return {
-        AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
+        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)};
   }
 };
@@ -561,17 +561,17 @@ class TransposeConverter
   using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                 isLHLO>::DataMovementOpConverter;
   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
-    auto resultType =
+    auto result_type =
         getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
-    auto nloops = resultType.getRank();
-    SmallVector<AffineExpr, 2> inputExprs;
-    inputExprs.resize(resultType.getRank());
+    auto nloops = result_type.getRank();
+    SmallVector<AffineExpr, 2> input_exprs;
+    input_exprs.resize(result_type.getRank());
     for (auto permutation : llvm::enumerate(op.permutation())) {
-      inputExprs[permutation.value().getZExtValue()] =
+      input_exprs[permutation.value().getZExtValue()] =
           b->getAffineDimExpr(permutation.index());
     }
     return {
-        AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
+        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)};
   }
 };
@@ -584,101 +584,104 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
   using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      OpTy reshapeOp, ArrayRef<Value> args,
+      OpTy reshape_op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
+    if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
       return failure();
-    ShapedType operandType =
-        reshapeOp.operand().getType().template cast<ShapedType>();
-    ShapedType resultType = getHloOpResultType<isLHLO>(reshapeOp);
+    ShapedType operand_type =
+        reshape_op.operand().getType().template cast<ShapedType>();
+    ShapedType result_type = getHloOpResultType<isLHLO>(reshape_op);
 
-    if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
+    if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
       return failure();
 
     // Compute the reassociation maps for the linalg operation.
-    ArrayRef<int64_t> srcShape =
-        (operandType.getRank() > resultType.getRank() ? operandType.getShape()
-                                                      : resultType.getShape());
-    ArrayRef<int64_t> dstShape =
-        (operandType.getRank() > resultType.getRank() ? resultType.getShape()
-                                                      : operandType.getShape());
-    unsigned currSrcDim = 0, currDstDim = 0;
-    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
-        dstShape.size());
-    bool isExpandingOrCollapsing = true;
-    while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
-      int64_t dstSize = dstShape[currDstDim];
-      int64_t srcSize = srcShape[currSrcDim];
-      while (srcSize < dstSize && currSrcDim < srcShape.size()) {
-        reassociationMap[currDstDim].push_back(
-            rewriter.getAffineDimExpr(currSrcDim++));
-        srcSize *= srcShape[currSrcDim];
+    ArrayRef<int64_t> src_shape =
+        (operand_type.getRank() > result_type.getRank()
+             ? operand_type.getShape()
+             : result_type.getShape());
+    ArrayRef<int64_t> dst_shape =
+        (operand_type.getRank() > result_type.getRank()
+             ? result_type.getShape()
+             : operand_type.getShape());
+    unsigned curr_src_dim = 0, curr_dst_dim = 0;
+    SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
+        dst_shape.size());
+    bool is_expanding_or_collapsing = true;
+    while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
+      int64_t dst_size = dst_shape[curr_dst_dim];
+      int64_t src_size = src_shape[curr_src_dim];
+      while (src_size < dst_size && curr_src_dim < src_shape.size()) {
+        reassociation_map[curr_dst_dim].push_back(
+            rewriter.getAffineDimExpr(curr_src_dim++));
+        src_size *= src_shape[curr_src_dim];
       }
-      if (srcSize == dstSize) {
-        reassociationMap[currDstDim].push_back(
-            rewriter.getAffineDimExpr(currSrcDim++));
-        // If the next dim in dstShape is not 1, treat subsequent dims in
-        // srcShape which are 1 to be collapsed.
-        if (currDstDim == dstShape.size() - 1 ||
-            dstShape[currDstDim + 1] != 1) {
-          while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
-            reassociationMap[currDstDim].push_back(
-                rewriter.getAffineDimExpr(currSrcDim++));
+      if (src_size == dst_size) {
+        reassociation_map[curr_dst_dim].push_back(
+            rewriter.getAffineDimExpr(curr_src_dim++));
+        // If the next dim in dst_shape is not 1, treat subsequent dims in
+        // src_shape which are 1 to be collapsed.
+        if (curr_dst_dim == dst_shape.size() - 1 ||
+            dst_shape[curr_dst_dim + 1] != 1) {
+          while (curr_src_dim < src_shape.size() &&
+                 src_shape[curr_src_dim] == 1) {
+            reassociation_map[curr_dst_dim].push_back(
+                rewriter.getAffineDimExpr(curr_src_dim++));
           }
         }
       } else {
-        isExpandingOrCollapsing = false;
+        is_expanding_or_collapsing = false;
         break;
       }
-      currDstDim++;
+      curr_dst_dim++;
     }
-    if (currSrcDim != srcShape.size() || currDstDim != dstShape.size())
-      isExpandingOrCollapsing = false;
+    if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
+      is_expanding_or_collapsing = false;
 
-    if (!isExpandingOrCollapsing) {
-      auto getIdentityExprs = [&rewriter](int n) {
+    if (!is_expanding_or_collapsing) {
+      auto get_identity_exprs = [&rewriter](int n) {
         SmallVector<AffineExpr, 4> exprs;
         for (int i = 0; i < n; ++i)
           exprs.push_back(rewriter.getAffineDimExpr(i));
         return exprs;
       };
-      Location loc = reshapeOp.getLoc();
-      int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
-                                           std::multiplies<int64_t>());
-      auto elemType = operandType.getElementType();
-      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
-          getIdentityExprs(dstShape.size())};
-      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
-          getIdentityExprs(srcShape.size())};
+      Location loc = reshape_op.getLoc();
+      int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
+                                            1, std::multiplies<int64_t>());
+      auto elem_type = operand_type.getElementType();
+      SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
+          get_identity_exprs(dst_shape.size())};
+      SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
+          get_identity_exprs(src_shape.size())};
 
       if (isLHLO) {
-        auto collapsedType = MemRefType::get({totalElems}, elemType);
-        Value collapsedOp = rewriter.create<linalg::ReshapeOp>(
-            loc, collapsedType, args[0], collapsingMap);
-        Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
-            loc, resultType, collapsedOp, expandingMap);
+        auto collapsed_type = MemRefType::get({total_elems}, elem_type);
+        Value collapsed_op = rewriter.create<linalg::ReshapeOp>(
+            loc, collapsed_type, args[0], collapsing_map);
+        Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
+            loc, result_type, collapsed_op, expanding_map);
         rewriter.replaceOpWithNewOp<linalg::CopyOp>(
-            reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
+            reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
             /*outputPermutation =*/nullptr);
       } else {
-        auto collapsedType = RankedTensorType::get({totalElems}, elemType);
-        Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
-            loc, collapsedType, args[0], collapsingMap);
+        auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
+        Value collapsed_op = rewriter.create<linalg::TensorReshapeOp>(
+            loc, collapsed_type, args[0], collapsing_map);
         rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
-            reshapeOp, resultType, collapsedOp, expandingMap);
+            reshape_op, result_type, collapsed_op, expanding_map);
       }
       return success();
     }
 
     if (isLHLO) {
-      Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
-          reshapeOp.getLoc(), resultType, args[0], reassociationMap);
+      Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
+          reshape_op.getLoc(), result_type, args[0], reassociation_map);
       rewriter.replaceOpWithNewOp<linalg::CopyOp>(
-          reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
+          reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
           /*outputPermutation =*/nullptr);
     } else {
       rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
-          reshapeOp, resultType, args[0], reassociationMap);
+          reshape_op, result_type, args[0], reassociation_map);
     }
     return success();
   }
@@ -690,42 +693,42 @@ class IotaConverter : public OpConversionPattern<OpTy> {
   using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      OpTy iotaOp, ArrayRef<Value> args,
+      OpTy iota_op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp);
-    if (!resultShapedType) return failure();
+    ShapedType result_shaped_type = getHloOpResultType<isLHLO>(iota_op);
+    if (!result_shaped_type) return failure();
 
-    auto resultElementType = resultShapedType.getElementType();
-    if (!resultElementType.isSignlessIntOrFloat()) return failure();
+    auto result_element_type = result_shaped_type.getElementType();
+    if (!result_element_type.isSignlessIntOrFloat()) return failure();
 
     // Construct the indexing maps needed for linalg.generic ops.
-    unsigned nloops = resultShapedType.getRank();
+    unsigned nloops = result_shaped_type.getRank();
 
-    auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
-        iotaOp.getLoc(),
+    auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
+        iota_op.getLoc(),
         /*resultTensorTypes=*/
-        isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{resultShapedType},
+        isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
         /*inputs=*/ValueRange{},
         /*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
         /*initTensors=*/ValueRange{},
         llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
         GetNParallelLoopsAttrs(nloops),
-        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
+        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
             ValueRange args) {
-          Value castOp = nestedBuilder.create<IndexCastOp>(
-              nestedLoc, ivs[iotaOp.iota_dimension()],
-              nestedBuilder.getIntegerType(
-                  resultElementType.getIntOrFloatBitWidth()));
-          if (resultElementType.template isa<FloatType>()) {
-            castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
-                                                    resultElementType);
+          Value cast_op = nested_builder.create<IndexCastOp>(
+              nested_loc, ivs[iota_op.iota_dimension()],
+              nested_builder.getIntegerType(
+                  result_element_type.getIntOrFloatBitWidth()));
+          if (result_element_type.template isa<FloatType>()) {
+            cast_op = nested_builder.create<SIToFPOp>(nested_loc, cast_op,
+                                                      result_element_type);
           }
-          nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
+          nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
         });
     if (isLHLO)
-      rewriter.replaceOp(iotaOp, llvm::None);
+      rewriter.replaceOp(iota_op, llvm::None);
     else
-      rewriter.replaceOp(iotaOp, linalgOp.result_tensors());
+      rewriter.replaceOp(iota_op, linalg_op.result_tensors());
     return success();
   }
 };
@@ -735,16 +738,16 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
   using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      lmhlo::ConstOp constOp, ArrayRef<Value> args,
+      lmhlo::ConstOp const_op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    auto loc = constOp.getLoc();
-    auto valueAttr = constOp.value().cast<DenseElementsAttr>();
-    if (valueAttr.getType().getRank() != 0) return failure();
-    auto stdConstOp =
-        rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
-    rewriter.create<mlir::AffineStoreOp>(loc, stdConstOp, constOp.getOperand(),
-                                         ValueRange());
-    rewriter.eraseOp(constOp);
+    auto loc = const_op.getLoc();
+    auto value_attr = const_op.value().cast<DenseElementsAttr>();
+    if (value_attr.getType().getRank() != 0) return failure();
+    auto std_const_op =
+        rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
+    rewriter.create<mlir::AffineStoreOp>(loc, std_const_op,
+                                         const_op.getOperand(), ValueRange());
+    rewriter.eraseOp(const_op);
     return success();
   }
 };
@@ -758,21 +761,21 @@ class ReverseConverter
   using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
                                 isLHLO>::DataMovementOpConverter;
   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
-    auto resultType =
+    auto result_type =
         getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
-    auto nloops = resultType.getRank();
-    SmallVector<AffineExpr, 2> inputExprs;
-    inputExprs.reserve(nloops);
+    auto nloops = result_type.getRank();
+    SmallVector<AffineExpr, 2> input_exprs;
+    input_exprs.reserve(nloops);
     for (int i = 0; i < nloops; ++i)
-      inputExprs.push_back(b->getAffineDimExpr(i));
+      input_exprs.push_back(b->getAffineDimExpr(i));
     for (auto dim : op.dimensions()) {
       int i = dim.getZExtValue();
-      if (resultType.isDynamicDim(i)) return {};
-      int n = resultType.getShape()[i];
-      inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
+      if (result_type.isDynamicDim(i)) return {};
+      int n = result_type.getShape()[i];
+      input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
     }
     return {
-        AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
+        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)};
   }
 };
@@ -782,31 +785,31 @@ class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
   using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      lmhlo::SliceOp sliceOp, ArrayRef<Value> args,
+      lmhlo::SliceOp slice_op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
-    auto loc = sliceOp.getLoc();
-    auto argType =
-        sliceOp.getOperand(0).getType().template dyn_cast<ShapedType>();
-    if (!argType || !argType.hasRank()) {
+    auto loc = slice_op.getLoc();
+    auto arg_type =
+        slice_op.getOperand(0).getType().template dyn_cast<ShapedType>();
+    if (!arg_type || !arg_type.hasRank()) {
      emitError(loc, "lhlo to linalg conversion expects known-rank args");
       return failure();
     }
 
     SmallVector<Value, 3> ranges;
-    for (int i = 0, e = argType.getRank(); i < e; ++i) {
+    for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
       Value start_index = rewriter.create<ConstantIndexOp>(
-          loc, sliceOp.start_indices().getValue<int64_t>(i));
+          loc, slice_op.start_indices().getValue<int64_t>(i));
       Value limit_index = rewriter.create<ConstantIndexOp>(
-          loc, sliceOp.limit_indices().getValue<int64_t>(i));
+          loc, slice_op.limit_indices().getValue<int64_t>(i));
       Value stride = rewriter.create<ConstantIndexOp>(
-          loc, sliceOp.strides().getValue<int64_t>(i));
+          loc, slice_op.strides().getValue<int64_t>(i));
       ranges.push_back(rewriter.create<linalg::RangeOp>(loc, start_index,
                                                         limit_index, stride));
     }
     auto linalg_slice =
-        rewriter.create<linalg::SliceOp>(loc, sliceOp.getOperand(0), ranges);
-    rewriter.create<linalg::CopyOp>(loc, linalg_slice, sliceOp.getOperand(1));
-    rewriter.eraseOp(sliceOp);
+        rewriter.create<linalg::SliceOp>(loc, slice_op.getOperand(0), ranges);
+    rewriter.create<linalg::CopyOp>(loc, linalg_slice, slice_op.getOperand(1));
+    rewriter.eraseOp(slice_op);
     return success();
   }
 };