[NFC] Make naming style consistent.

Use lowercase with underscores between words instead of camelStyle.

PiperOrigin-RevId: 338722328
This commit is contained in:
Hanhan Wang 2020-10-23 12:22:21 -07:00 committed by TensorFlow MLIR Team
parent 31c1c3aa1f
commit 444fae9bac
1 changed files with 243 additions and 240 deletions

View File

@ -60,13 +60,13 @@ ShapedType getHloOpResultType(Operation* op) {
template <bool isLHLO = true> template <bool isLHLO = true>
bool verifyHloOpBufferOrTensorSemantics(Operation* op) { bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
auto verifyType = [&](Value val) -> bool { auto verify_type = [&](Value val) -> bool {
return (isLHLO && val.getType().isa<MemRefType>()) || return (isLHLO && val.getType().isa<MemRefType>()) ||
(!isLHLO && val.getType().isa<RankedTensorType>()); (!isLHLO && val.getType().isa<RankedTensorType>());
}; };
if (!llvm::all_of(op->getOperands(), verifyType)) return false; if (!llvm::all_of(op->getOperands(), verify_type)) return false;
return isLHLO ? op->getResults().empty() return isLHLO ? op->getResults().empty()
: llvm::all_of(op->getResults(), verifyType); : llvm::all_of(op->getResults(), verify_type);
} }
template <typename OpTy, bool isLHLO = true> template <typename OpTy, bool isLHLO = true>
@ -99,51 +99,51 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
<< nloops << " parallel iterators: " << *(op.getOperation()); << nloops << " parallel iterators: " << *(op.getOperation());
// Construct the indexing maps needed for linalg.generic ops. // Construct the indexing maps needed for linalg.generic ops.
SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes; SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
// This doesnt account for implicit broadcast, but the working assumption // This doesnt account for implicit broadcast, but the working assumption
// in HLO/LHLO is that are broadcasts are made explicit. // in HLO/LHLO is that are broadcasts are made explicit.
if (isLHLO && !nloops) return failure(); if (isLHLO && !nloops) return failure();
int numInputs = (isLHLO ? args.size() - 1 : args.size()); int num_inputs = (isLHLO ? args.size() - 1 : args.size());
ValueRange inputs(args.take_front(numInputs)); ValueRange inputs(args.take_front(num_inputs));
for (Value in : inputs) for (Value in : inputs)
bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType())); body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
ValueRange outputBuffers(args.take_back(args.size() - numInputs)); ValueRange output_buffers(args.take_back(args.size() - num_inputs));
for (Value out : outputBuffers) for (Value out : output_buffers)
bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType())); body_result_types.emplace_back(getElementTypeOrSelf(out.getType()));
if (!isLHLO) { if (!isLHLO) {
// HLO operations have return as tensor types. // HLO operations have return as tensor types.
assert(bodyResultTypes.empty() && assert(body_result_types.empty() &&
"When lowering HLO ops result can't be part of arguments"); "When lowering HLO ops result can't be part of arguments");
Value result = op.getOperation()->getResult(0); Value result = op.getOperation()->getResult(0);
bodyResultTypes.push_back(getElementTypeOrSelf(result)); body_result_types.push_back(getElementTypeOrSelf(result));
opResultTypes.push_back(result.getType()); op_result_types.push_back(result.getType());
} }
AffineMap commonIndexingMap = AffineMap common_indexing_map =
nloops ? rewriter.getMultiDimIdentityMap(nloops) nloops ? rewriter.getMultiDimIdentityMap(nloops)
: AffineMap::get(nloops, 0, rewriter.getContext()); : AffineMap::get(nloops, 0, rewriter.getContext());
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1), SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
commonIndexingMap); common_indexing_map);
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, opResultTypes, inputs, outputBuffers, loc, op_result_types, inputs, output_buffers,
/*initTensors=*/ValueRange{}, indexing_maps, /*initTensors=*/ValueRange{}, indexing_maps,
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in lmhlo namespace. // TODO(ravishankarm) : For now use the method in lmhlo namespace.
// That method needs to be moved out of there. // That method needs to be moved out of there.
Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>( Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
op, bodyResultTypes, op, body_result_types,
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter); llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, opResult); nested_builder.create<linalg::YieldOp>(loc, op_result);
}); });
rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
return success(); return success();
} }
}; };
@ -157,10 +157,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
LhloOp lhlo_op, ArrayRef<Value> args, LhloOp lhlo_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = lhlo_op.getLoc(); auto loc = lhlo_op.getLoc();
auto argType = auto arg_type =
lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>(); lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
if (!argType || !argType.getElementType().isSignlessIntOrFloat() || if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
(argType.getRank() != 0)) { (arg_type.getRank() != 0)) {
return failure(); return failure();
} }
@ -168,10 +168,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs()); auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs()); auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of lmhlo namespace. // TODO(ravishankarm) : Move this method out of lmhlo namespace.
Value opResult = lmhlo::HloOpToStdScalarOp::map<LhloOp>( Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs}, lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter); &rewriter);
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out()); rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
rewriter.eraseOp(lhlo_op); rewriter.eraseOp(lhlo_op);
return success(); return success();
} }
@ -192,52 +192,52 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
lmhlo::ConvOp op, ArrayRef<Value> args, lmhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information. // Check validity of dimension information.
if (const mhlo::ConvDimensionNumbers& dimensionNumbers = if (const mhlo::ConvDimensionNumbers& dimension_numbers =
op.dimension_numbers()) { op.dimension_numbers()) {
const int inputSpatialRank = const int input_spatial_rank =
llvm::size(dimensionNumbers.input_spatial_dimensions()); llvm::size(dimension_numbers.input_spatial_dimensions());
// The dimensions for input should follow the order of // The dimensions for input should follow the order of
// batch_count, spatial_dims..., input_feature_count. // batch_count, spatial_dims..., input_feature_count.
if (dimensionNumbers.input_batch_dimension().getInt() != 0 || if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
dimensionNumbers.input_feature_dimension().getInt() != dimension_numbers.input_feature_dimension().getInt() !=
(inputSpatialRank + 1)) (input_spatial_rank + 1))
return failure(); return failure();
const int kernelSpatialRank = const int kernel_spatial_rank =
llvm::size(dimensionNumbers.kernel_spatial_dimensions()); llvm::size(dimension_numbers.kernel_spatial_dimensions());
// The dimensions for filter should follow the order of // The dimensions for filter should follow the order of
// spatial_dims..., input_feature_count, num_output_feature_count. // spatial_dims..., input_feature_count, num_output_feature_count.
if (dimensionNumbers.kernel_input_feature_dimension().getInt() != if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
kernelSpatialRank || kernel_spatial_rank ||
dimensionNumbers.kernel_output_feature_dimension().getInt() != dimension_numbers.kernel_output_feature_dimension().getInt() !=
(kernelSpatialRank + 1)) (kernel_spatial_rank + 1))
return failure(); return failure();
const int outputSpatialRank = const int output_spatial_rank =
llvm::size(dimensionNumbers.output_spatial_dimensions()); llvm::size(dimension_numbers.output_spatial_dimensions());
// The dimensions for output should follow the order of // The dimensions for output should follow the order of
// batch_count, spatial_dims.., output_feature_count. // batch_count, spatial_dims.., output_feature_count.
if (dimensionNumbers.output_batch_dimension().getInt() != 0 || if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
dimensionNumbers.output_feature_dimension().getInt() != dimension_numbers.output_feature_dimension().getInt() !=
(outputSpatialRank + 1)) (output_spatial_rank + 1))
return failure(); return failure();
if (inputSpatialRank != outputSpatialRank || if (input_spatial_rank != output_spatial_rank ||
inputSpatialRank != kernelSpatialRank) input_spatial_rank != kernel_spatial_rank)
return failure(); return failure();
auto inputSpatialDim = auto input_spatial_dim =
dimensionNumbers.input_spatial_dimensions().begin(); dimension_numbers.input_spatial_dimensions().begin();
auto kernelSpatialDim = auto kernel_spatial_dim =
dimensionNumbers.kernel_spatial_dimensions().begin(); dimension_numbers.kernel_spatial_dimensions().begin();
auto outputSpatialDim = auto output_spatial_dim =
dimensionNumbers.output_spatial_dimensions().begin(); dimension_numbers.output_spatial_dimensions().begin();
// Check if spatial dims are ordered correctly. // Check if spatial dims are ordered correctly.
for (int i = 0; i < inputSpatialRank; ++i) { for (int i = 0; i < input_spatial_rank; ++i) {
const int dim = i + 1; const int dim = i + 1;
if ((*inputSpatialDim++).getZExtValue() != dim || if ((*input_spatial_dim++).getZExtValue() != dim ||
(*outputSpatialDim++).getZExtValue() != dim || (*output_spatial_dim++).getZExtValue() != dim ||
(*kernelSpatialDim++).getZExtValue() != i) (*kernel_spatial_dim++).getZExtValue() != i)
return failure(); return failure();
} }
} }
@ -248,33 +248,33 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
} }
llvm::SmallVector<Attribute, 4> strides; llvm::SmallVector<Attribute, 4> strides;
if (auto windowStrides = op.window_strides()) { if (auto window_strides = op.window_strides()) {
auto range = windowStrides->getAttributeValues(); auto range = window_strides->getAttributeValues();
strides.assign(range.begin(), range.end()); strides.assign(range.begin(), range.end());
} }
auto stridesArg = ArrayAttr::get(strides, op.getContext()); auto strides_arg = ArrayAttr::get(strides, op.getContext());
llvm::SmallVector<Attribute, 2> dilation; llvm::SmallVector<Attribute, 2> dilation;
if (auto rhsDilation = op.rhs_dilation()) { if (auto rhs_dilation = op.rhs_dilation()) {
auto range = rhsDilation->getAttributeValues(); auto range = rhs_dilation->getAttributeValues();
dilation.assign(range.begin(), range.end()); dilation.assign(range.begin(), range.end());
} else { } else {
// Default dilation of 1. // Default dilation of 1.
dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1)); dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
} }
auto dilationArg = ArrayAttr::get(dilation, op.getContext()); auto dilation_arg = ArrayAttr::get(dilation, op.getContext());
// Set padding only if it is non-zero. // Set padding only if it is non-zero.
DenseIntElementsAttr padding = op.paddingAttr(); DenseIntElementsAttr padding = op.paddingAttr();
if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) { if (!padding ||
return !intVal.isNullValue(); !llvm::any_of(padding.getValues<APInt>(),
})) { [](APInt int_val) { return !int_val.isNullValue(); })) {
padding = nullptr; padding = nullptr;
} }
// The order of input and filter are switched with linalg.conv. // The order of input and filter are switched with linalg.conv.
rewriter.replaceOpWithNewOp<linalg::ConvOp>( rewriter.replaceOpWithNewOp<linalg::ConvOp>(
op, args[1], args[0], args[2], stridesArg, dilationArg, padding); op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
return success(); return success();
} }
}; };
@ -293,25 +293,25 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
OpTy op, ArrayRef<Value> args, OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure(); if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
auto resultType = getHloOpResultType<isLHLO>(op); auto result_type = getHloOpResultType<isLHLO>(op);
SmallVector<AffineMap, 2> indexing_maps = SmallVector<AffineMap, 2> indexing_maps =
Derived::getIndexingMaps(op, &rewriter); Derived::getIndexingMaps(op, &rewriter);
if (indexing_maps.empty()) return failure(); if (indexing_maps.empty()) return failure();
auto nloops = resultType.getRank(); auto nloops = result_type.getRank();
auto loc = op.getLoc(); auto loc = op.getLoc();
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, loc,
/*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : resultType, /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
/*inputs=*/args.front(), /*inputs=*/args.front(),
/*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{}, /*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
/*initTensor=*/ValueRange{}, indexing_maps, /*initTensor=*/ValueRange{}, indexing_maps,
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin()); nested_builder.create<linalg::YieldOp>(loc, *args.begin());
}); });
rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
return success(); return success();
} }
}; };
@ -325,32 +325,32 @@ class BroadcastConverter
using DataMovementOpConverter<BroadcastConverter, OpTy, using DataMovementOpConverter<BroadcastConverter, OpTy,
isLHLO>::DataMovementOpConverter; isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp, static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
Builder* b) { Builder* b) {
ShapedType inputType = ShapedType input_type =
broadcastOp.operand().getType().template cast<ShapedType>(); broadcast_op.operand().getType().template cast<ShapedType>();
unsigned inputRank = inputType.getRank(); unsigned input_rank = input_type.getRank();
unsigned nloops = getHloOpResultType<isLHLO>(broadcastOp).getRank(); unsigned nloops = getHloOpResultType<isLHLO>(broadcast_op).getRank();
// BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
// the input's dimensions. // the input's dimensions.
unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes()); unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
SmallVector<AffineExpr, 4> inputDimExprs; SmallVector<AffineExpr, 4> input_dim_exprs;
inputDimExprs.reserve(inputRank); input_dim_exprs.reserve(input_rank);
for (int i = 0; i < inputRank; ++i) { for (int i = 0; i < input_rank; ++i) {
inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i)); input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
} }
AffineMap inputMap; AffineMap input_map;
MLIRContext* context = b->getContext(); MLIRContext* context = b->getContext();
if (inputDimExprs.empty()) { if (input_dim_exprs.empty()) {
// The input is a scalar, i.e. this is a scalar broadcast op. // The input is a scalar, i.e. this is a scalar broadcast op.
inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context); input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
} else { } else {
inputMap = input_map =
AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context); AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
} }
return {inputMap, b->getMultiDimIdentityMap(nloops)}; return {input_map, b->getMultiDimIdentityMap(nloops)};
} }
}; };
@ -363,34 +363,34 @@ class HloBroadcastInDimConverter
false>::DataMovementOpConverter; false>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps( static SmallVector<AffineMap, 2> getIndexingMaps(
mhlo::BroadcastInDimOp broadcastOp, Builder* b) { mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
auto resultType = getHloOpResultType<false>(broadcastOp); auto result_type = getHloOpResultType<false>(broadcast_op);
auto operandType = auto operand_type =
broadcastOp.operand().getType().template cast<ShapedType>(); broadcast_op.operand().getType().template cast<ShapedType>();
unsigned nloops = resultType.getRank(); unsigned nloops = result_type.getRank();
// The input is a scalar, i.e. this is a scalar broadcast op. // The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) { if (operand_type.getRank() == 0) {
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)}; b->getMultiDimIdentityMap(nloops)};
} }
auto operandShape = operandType.getShape(); auto operand_shape = operand_type.getShape();
SmallVector<AffineExpr, 4> dimExprs; SmallVector<AffineExpr, 4> dim_exprs;
dimExprs.reserve(nloops); dim_exprs.reserve(nloops);
if (broadcastOp.broadcast_dimensions()) { if (broadcast_op.broadcast_dimensions()) {
for (const auto& broadcastDim : for (const auto& broadcastDim :
enumerate(broadcastOp.broadcast_dimensions().getIntValues())) { enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
int size = broadcastDim.value().getSExtValue(); int size = broadcastDim.value().getSExtValue();
bool expansion_needed = operandShape[broadcastDim.index()] == 1 && bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
resultType.getShape()[size] != 1; result_type.getShape()[size] != 1;
dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0) dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
: b->getAffineDimExpr(size)); : b->getAffineDimExpr(size));
} }
} }
return { return {
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)}; b->getMultiDimIdentityMap(nloops)};
} }
}; };
@ -430,8 +430,8 @@ class LhloBroadcastInDimConverter
/*outputBuffers=*/ValueRange{operand_adaptor.output()}, /*outputBuffers=*/ValueRange{operand_adaptor.output()},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, val); nested_builder.create<linalg::YieldOp>(loc, val);
}); });
} else { } else {
@ -441,8 +441,8 @@ class LhloBroadcastInDimConverter
loc, /*inputs=*/ValueRange{operand}, loc, /*inputs=*/ValueRange{operand},
/*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps, /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin()); nested_builder.create<linalg::YieldOp>(loc, *args.begin());
}); });
} }
rewriter.replaceOp(op, llvm::None); rewriter.replaceOp(op, llvm::None);
@ -520,35 +520,35 @@ class LhloBroadcastInDimConverter
} }
SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op, SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
ArrayRef<int64_t> broadcastDims, ArrayRef<int64_t> broadcast_dims,
ArrayRef<int64_t> resultShape, ArrayRef<int64_t> result_shape,
MemRefType operandType, MemRefType operand_type,
Builder* b) const { Builder* b) const {
unsigned nloops = resultShape.size(); unsigned nloops = result_shape.size();
// The input is a scalar, i.e. this is a scalar broadcast op. // The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) { if (operand_type.getRank() == 0) {
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)}; b->getMultiDimIdentityMap(nloops)};
} }
auto operandShape = operandType.getShape(); auto operand_shape = operand_type.getShape();
SmallVector<AffineExpr, 4> dimExprs; SmallVector<AffineExpr, 4> dim_exprs;
dimExprs.reserve(nloops); dim_exprs.reserve(nloops);
for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) { for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) {
int size = broadcastDim.value(); int size = broadcast_dim.value();
bool expansion_needed = bool expansion_needed =
operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1; operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1;
if (expansion_needed) { if (expansion_needed) {
op.emitOpError( op.emitOpError(
"BroadcastInDimOp lowering to Linalg does not support size-1 " "BroadcastInDimOp lowering to Linalg does not support size-1 "
"dimensions expansion."); "dimensions expansion.");
} }
dimExprs.push_back(b->getAffineDimExpr(size)); dim_exprs.push_back(b->getAffineDimExpr(size));
} }
return { return {
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)}; b->getMultiDimIdentityMap(nloops)};
} }
}; };
@ -561,17 +561,17 @@ class TransposeConverter
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy, using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter; isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) { static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType = auto result_type =
getHloOpResultType<isLHLO>(op).template cast<ShapedType>(); getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank(); auto nloops = result_type.getRank();
SmallVector<AffineExpr, 2> inputExprs; SmallVector<AffineExpr, 2> input_exprs;
inputExprs.resize(resultType.getRank()); input_exprs.resize(result_type.getRank());
for (auto permutation : llvm::enumerate(op.permutation())) { for (auto permutation : llvm::enumerate(op.permutation())) {
inputExprs[permutation.value().getZExtValue()] = input_exprs[permutation.value().getZExtValue()] =
b->getAffineDimExpr(permutation.index()); b->getAffineDimExpr(permutation.index());
} }
return { return {
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)}; b->getMultiDimIdentityMap(nloops)};
} }
}; };
@ -584,101 +584,104 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern; using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
OpTy reshapeOp, ArrayRef<Value> args, OpTy reshape_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshapeOp)) if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
return failure(); return failure();
ShapedType operandType = ShapedType operand_type =
reshapeOp.operand().getType().template cast<ShapedType>(); reshape_op.operand().getType().template cast<ShapedType>();
ShapedType resultType = getHloOpResultType<isLHLO>(reshapeOp); ShapedType result_type = getHloOpResultType<isLHLO>(reshape_op);
if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
return failure(); return failure();
// Compute the reassociation maps for the linalg operation. // Compute the reassociation maps for the linalg operation.
ArrayRef<int64_t> srcShape = ArrayRef<int64_t> src_shape =
(operandType.getRank() > resultType.getRank() ? operandType.getShape() (operand_type.getRank() > result_type.getRank()
: resultType.getShape()); ? operand_type.getShape()
ArrayRef<int64_t> dstShape = : result_type.getShape());
(operandType.getRank() > resultType.getRank() ? resultType.getShape() ArrayRef<int64_t> dst_shape =
: operandType.getShape()); (operand_type.getRank() > result_type.getRank()
unsigned currSrcDim = 0, currDstDim = 0; ? result_type.getShape()
SmallVector<linalg::ReassociationExprs, 4> reassociationMap( : operand_type.getShape());
dstShape.size()); unsigned curr_src_dim = 0, curr_dst_dim = 0;
bool isExpandingOrCollapsing = true; SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) { dst_shape.size());
int64_t dstSize = dstShape[currDstDim]; bool is_expanding_or_collapsing = true;
int64_t srcSize = srcShape[currSrcDim]; while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
while (srcSize < dstSize && currSrcDim < srcShape.size()) { int64_t dst_size = dst_shape[curr_dst_dim];
reassociationMap[currDstDim].push_back( int64_t src_size = src_shape[curr_src_dim];
rewriter.getAffineDimExpr(currSrcDim++)); while (src_size < dst_size && curr_src_dim < src_shape.size()) {
srcSize *= srcShape[currSrcDim]; reassociation_map[curr_dst_dim].push_back(
rewriter.getAffineDimExpr(curr_src_dim++));
src_size *= src_shape[curr_src_dim];
} }
if (srcSize == dstSize) { if (src_size == dst_size) {
reassociationMap[currDstDim].push_back( reassociation_map[curr_dst_dim].push_back(
rewriter.getAffineDimExpr(currSrcDim++)); rewriter.getAffineDimExpr(curr_src_dim++));
// If the next dim in dstShape is not 1, treat subsequent dims in // If the next dim in dst_shape is not 1, treat subsequent dims in
// srcShape which are 1 to be collapsed. // src_shape which are 1 to be collapsed.
if (currDstDim == dstShape.size() - 1 || if (curr_dst_dim == dst_shape.size() - 1 ||
dstShape[currDstDim + 1] != 1) { dst_shape[curr_dst_dim + 1] != 1) {
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) { while (curr_src_dim < src_shape.size() &&
reassociationMap[currDstDim].push_back( src_shape[curr_src_dim] == 1) {
rewriter.getAffineDimExpr(currSrcDim++)); reassociation_map[curr_dst_dim].push_back(
rewriter.getAffineDimExpr(curr_src_dim++));
} }
} }
} else { } else {
isExpandingOrCollapsing = false; is_expanding_or_collapsing = false;
break; break;
} }
currDstDim++; curr_dst_dim++;
} }
if (currSrcDim != srcShape.size() || currDstDim != dstShape.size()) if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
isExpandingOrCollapsing = false; is_expanding_or_collapsing = false;
if (!isExpandingOrCollapsing) { if (!is_expanding_or_collapsing) {
auto getIdentityExprs = [&rewriter](int n) { auto get_identity_exprs = [&rewriter](int n) {
SmallVector<AffineExpr, 4> exprs; SmallVector<AffineExpr, 4> exprs;
for (int i = 0; i < n; ++i) for (int i = 0; i < n; ++i)
exprs.push_back(rewriter.getAffineDimExpr(i)); exprs.push_back(rewriter.getAffineDimExpr(i));
return exprs; return exprs;
}; };
Location loc = reshapeOp.getLoc(); Location loc = reshape_op.getLoc();
int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1, int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
std::multiplies<int64_t>()); 1, std::multiplies<int64_t>());
auto elemType = operandType.getElementType(); auto elem_type = operand_type.getElementType();
SmallVector<linalg::ReassociationExprs, 4> collapsingMap = { SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
getIdentityExprs(dstShape.size())}; get_identity_exprs(dst_shape.size())};
SmallVector<linalg::ReassociationExprs, 4> expandingMap = { SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
getIdentityExprs(srcShape.size())}; get_identity_exprs(src_shape.size())};
if (isLHLO) { if (isLHLO) {
auto collapsedType = MemRefType::get({totalElems}, elemType); auto collapsed_type = MemRefType::get({total_elems}, elem_type);
Value collapsedOp = rewriter.create<linalg::ReshapeOp>( Value collapsed_op = rewriter.create<linalg::ReshapeOp>(
loc, collapsedType, args[0], collapsingMap); loc, collapsed_type, args[0], collapsing_map);
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>( Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
loc, resultType, collapsedOp, expandingMap); loc, result_type, collapsed_op, expanding_map);
rewriter.replaceOpWithNewOp<linalg::CopyOp>( rewriter.replaceOpWithNewOp<linalg::CopyOp>(
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr, reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
/*outputPermutation =*/nullptr); /*outputPermutation =*/nullptr);
} else { } else {
auto collapsedType = RankedTensorType::get({totalElems}, elemType); auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>( Value collapsed_op = rewriter.create<linalg::TensorReshapeOp>(
loc, collapsedType, args[0], collapsingMap); loc, collapsed_type, args[0], collapsing_map);
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>( rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshapeOp, resultType, collapsedOp, expandingMap); reshape_op, result_type, collapsed_op, expanding_map);
} }
return success(); return success();
} }
if (isLHLO) { if (isLHLO) {
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>( Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
reshapeOp.getLoc(), resultType, args[0], reassociationMap); reshape_op.getLoc(), result_type, args[0], reassociation_map);
rewriter.replaceOpWithNewOp<linalg::CopyOp>( rewriter.replaceOpWithNewOp<linalg::CopyOp>(
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr, reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
/*outputPermutation =*/nullptr); /*outputPermutation =*/nullptr);
} else { } else {
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>( rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshapeOp, resultType, args[0], reassociationMap); reshape_op, result_type, args[0], reassociation_map);
} }
return success(); return success();
} }
@ -690,42 +693,42 @@ class IotaConverter : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern; using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
OpTy iotaOp, ArrayRef<Value> args, OpTy iota_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp); ShapedType result_shaped_type = getHloOpResultType<isLHLO>(iota_op);
if (!resultShapedType) return failure(); if (!result_shaped_type) return failure();
auto resultElementType = resultShapedType.getElementType(); auto result_element_type = result_shaped_type.getElementType();
if (!resultElementType.isSignlessIntOrFloat()) return failure(); if (!result_element_type.isSignlessIntOrFloat()) return failure();
// Construct the indexing maps needed for linalg.generic ops. // Construct the indexing maps needed for linalg.generic ops.
unsigned nloops = resultShapedType.getRank(); unsigned nloops = result_shaped_type.getRank();
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>( auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(), iota_op.getLoc(),
/*resultTensorTypes=*/ /*resultTensorTypes=*/
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{resultShapedType}, isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
/*inputs=*/ValueRange{}, /*inputs=*/ValueRange{},
/*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{}, /*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
/*initTensors=*/ValueRange{}, /*initTensors=*/ValueRange{},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops), GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs, [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
ValueRange args) { ValueRange args) {
Value castOp = nestedBuilder.create<IndexCastOp>( Value cast_op = nested_builder.create<IndexCastOp>(
nestedLoc, ivs[iotaOp.iota_dimension()], nested_loc, ivs[iota_op.iota_dimension()],
nestedBuilder.getIntegerType( nested_builder.getIntegerType(
resultElementType.getIntOrFloatBitWidth())); result_element_type.getIntOrFloatBitWidth()));
if (resultElementType.template isa<FloatType>()) { if (result_element_type.template isa<FloatType>()) {
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp, cast_op = nested_builder.create<SIToFPOp>(nested_loc, cast_op,
resultElementType); result_element_type);
} }
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp); nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
}); });
if (isLHLO) if (isLHLO)
rewriter.replaceOp(iotaOp, llvm::None); rewriter.replaceOp(iota_op, llvm::None);
else else
rewriter.replaceOp(iotaOp, linalgOp.result_tensors()); rewriter.replaceOp(iota_op, linalg_op.result_tensors());
return success(); return success();
} }
}; };
@ -735,16 +738,16 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern; using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
lmhlo::ConstOp constOp, ArrayRef<Value> args, lmhlo::ConstOp const_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = constOp.getLoc(); auto loc = const_op.getLoc();
auto valueAttr = constOp.value().cast<DenseElementsAttr>(); auto value_attr = const_op.value().cast<DenseElementsAttr>();
if (valueAttr.getType().getRank() != 0) return failure(); if (value_attr.getType().getRank() != 0) return failure();
auto stdConstOp = auto std_const_op =
rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({})); rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
rewriter.create<mlir::AffineStoreOp>(loc, stdConstOp, constOp.getOperand(), rewriter.create<mlir::AffineStoreOp>(loc, std_const_op,
ValueRange()); const_op.getOperand(), ValueRange());
rewriter.eraseOp(constOp); rewriter.eraseOp(const_op);
return success(); return success();
} }
}; };
@ -758,21 +761,21 @@ class ReverseConverter
using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy, using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter; isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) { static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType = auto result_type =
getHloOpResultType<isLHLO>(op).template cast<ShapedType>(); getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank(); auto nloops = result_type.getRank();
SmallVector<AffineExpr, 2> inputExprs; SmallVector<AffineExpr, 2> input_exprs;
inputExprs.reserve(nloops); input_exprs.reserve(nloops);
for (int i = 0; i < nloops; ++i) for (int i = 0; i < nloops; ++i)
inputExprs.push_back(b->getAffineDimExpr(i)); input_exprs.push_back(b->getAffineDimExpr(i));
for (auto dim : op.dimensions()) { for (auto dim : op.dimensions()) {
int i = dim.getZExtValue(); int i = dim.getZExtValue();
if (resultType.isDynamicDim(i)) return {}; if (result_type.isDynamicDim(i)) return {};
int n = resultType.getShape()[i]; int n = result_type.getShape()[i];
inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i]; input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
} }
return { return {
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)}; b->getMultiDimIdentityMap(nloops)};
} }
}; };
@ -782,31 +785,31 @@ class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern; using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
lmhlo::SliceOp sliceOp, ArrayRef<Value> args, lmhlo::SliceOp slice_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto loc = sliceOp.getLoc(); auto loc = slice_op.getLoc();
auto argType = auto arg_type =
sliceOp.getOperand(0).getType().template dyn_cast<ShapedType>(); slice_op.getOperand(0).getType().template dyn_cast<ShapedType>();
if (!argType || !argType.hasRank()) { if (!arg_type || !arg_type.hasRank()) {
emitError(loc, "lhlo to linalg conversion expects known-rank args"); emitError(loc, "lhlo to linalg conversion expects known-rank args");
return failure(); return failure();
} }
SmallVector<Value, 3> ranges; SmallVector<Value, 3> ranges;
for (int i = 0, e = argType.getRank(); i < e; ++i) { for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
Value start_index = rewriter.create<ConstantIndexOp>( Value start_index = rewriter.create<ConstantIndexOp>(
loc, sliceOp.start_indices().getValue<int64_t>(i)); loc, slice_op.start_indices().getValue<int64_t>(i));
Value limit_index = rewriter.create<ConstantIndexOp>( Value limit_index = rewriter.create<ConstantIndexOp>(
loc, sliceOp.limit_indices().getValue<int64_t>(i)); loc, slice_op.limit_indices().getValue<int64_t>(i));
Value stride = rewriter.create<ConstantIndexOp>( Value stride = rewriter.create<ConstantIndexOp>(
loc, sliceOp.strides().getValue<int64_t>(i)); loc, slice_op.strides().getValue<int64_t>(i));
ranges.push_back(rewriter.create<linalg::RangeOp>(loc, start_index, ranges.push_back(rewriter.create<linalg::RangeOp>(loc, start_index,
limit_index, stride)); limit_index, stride));
} }
auto linalg_slice = auto linalg_slice =
rewriter.create<linalg::SliceOp>(loc, sliceOp.getOperand(0), ranges); rewriter.create<linalg::SliceOp>(loc, slice_op.getOperand(0), ranges);
rewriter.create<linalg::CopyOp>(loc, linalg_slice, sliceOp.getOperand(1)); rewriter.create<linalg::CopyOp>(loc, linalg_slice, slice_op.getOperand(1));
rewriter.eraseOp(sliceOp); rewriter.eraseOp(slice_op);
return success(); return success();
} }
}; };