[NFC] Make naming style consistent.

Use lowercase with underscores between words instead of camelStyle.

PiperOrigin-RevId: 338722328
This commit is contained in:
Hanhan Wang 2020-10-23 12:22:21 -07:00 committed by TensorFlow MLIR Team
parent 31c1c3aa1f
commit 444fae9bac
1 changed files with 243 additions and 240 deletions

View File

@ -60,13 +60,13 @@ ShapedType getHloOpResultType(Operation* op) {
template <bool isLHLO = true>
bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
auto verifyType = [&](Value val) -> bool {
auto verify_type = [&](Value val) -> bool {
return (isLHLO && val.getType().isa<MemRefType>()) ||
(!isLHLO && val.getType().isa<RankedTensorType>());
};
if (!llvm::all_of(op->getOperands(), verifyType)) return false;
if (!llvm::all_of(op->getOperands(), verify_type)) return false;
return isLHLO ? op->getResults().empty()
: llvm::all_of(op->getResults(), verifyType);
: llvm::all_of(op->getResults(), verify_type);
}
template <typename OpTy, bool isLHLO = true>
@ -99,51 +99,51 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
<< nloops << " parallel iterators: " << *(op.getOperation());
// Construct the indexing maps needed for linalg.generic ops.
SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
// This doesnt account for implicit broadcast, but the working assumption
// in HLO/LHLO is that are broadcasts are made explicit.
if (isLHLO && !nloops) return failure();
int numInputs = (isLHLO ? args.size() - 1 : args.size());
int num_inputs = (isLHLO ? args.size() - 1 : args.size());
ValueRange inputs(args.take_front(numInputs));
ValueRange inputs(args.take_front(num_inputs));
for (Value in : inputs)
bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
ValueRange outputBuffers(args.take_back(args.size() - numInputs));
for (Value out : outputBuffers)
bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType()));
ValueRange output_buffers(args.take_back(args.size() - num_inputs));
for (Value out : output_buffers)
body_result_types.emplace_back(getElementTypeOrSelf(out.getType()));
if (!isLHLO) {
// HLO operations have return as tensor types.
assert(bodyResultTypes.empty() &&
assert(body_result_types.empty() &&
"When lowering HLO ops result can't be part of arguments");
Value result = op.getOperation()->getResult(0);
bodyResultTypes.push_back(getElementTypeOrSelf(result));
opResultTypes.push_back(result.getType());
body_result_types.push_back(getElementTypeOrSelf(result));
op_result_types.push_back(result.getType());
}
AffineMap commonIndexingMap =
AffineMap common_indexing_map =
nloops ? rewriter.getMultiDimIdentityMap(nloops)
: AffineMap::get(nloops, 0, rewriter.getContext());
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
commonIndexingMap);
common_indexing_map);
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, opResultTypes, inputs, outputBuffers,
auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, op_result_types, inputs, output_buffers,
/*initTensors=*/ValueRange{}, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in lmhlo namespace.
// That method needs to be moved out of there.
Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>(
op, bodyResultTypes,
Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
op, body_result_types,
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, opResult);
nested_builder.create<linalg::YieldOp>(loc, op_result);
});
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
return success();
}
};
@ -157,10 +157,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
LhloOp lhlo_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = lhlo_op.getLoc();
auto argType =
auto arg_type =
lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
(argType.getRank() != 0)) {
if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
(arg_type.getRank() != 0)) {
return failure();
}
@ -168,10 +168,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
Value opResult = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter);
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
rewriter.eraseOp(lhlo_op);
return success();
}
@ -192,52 +192,52 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
lmhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information.
if (const mhlo::ConvDimensionNumbers& dimensionNumbers =
if (const mhlo::ConvDimensionNumbers& dimension_numbers =
op.dimension_numbers()) {
const int inputSpatialRank =
llvm::size(dimensionNumbers.input_spatial_dimensions());
const int input_spatial_rank =
llvm::size(dimension_numbers.input_spatial_dimensions());
// The dimensions for input should follow the order of
// batch_count, spatial_dims..., input_feature_count.
if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
dimensionNumbers.input_feature_dimension().getInt() !=
(inputSpatialRank + 1))
if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
dimension_numbers.input_feature_dimension().getInt() !=
(input_spatial_rank + 1))
return failure();
const int kernelSpatialRank =
llvm::size(dimensionNumbers.kernel_spatial_dimensions());
const int kernel_spatial_rank =
llvm::size(dimension_numbers.kernel_spatial_dimensions());
// The dimensions for filter should follow the order of
// spatial_dims..., input_feature_count, num_output_feature_count.
if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
kernelSpatialRank ||
dimensionNumbers.kernel_output_feature_dimension().getInt() !=
(kernelSpatialRank + 1))
if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
kernel_spatial_rank ||
dimension_numbers.kernel_output_feature_dimension().getInt() !=
(kernel_spatial_rank + 1))
return failure();
const int outputSpatialRank =
llvm::size(dimensionNumbers.output_spatial_dimensions());
const int output_spatial_rank =
llvm::size(dimension_numbers.output_spatial_dimensions());
// The dimensions for output should follow the order of
// batch_count, spatial_dims.., output_feature_count.
if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
dimensionNumbers.output_feature_dimension().getInt() !=
(outputSpatialRank + 1))
if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
dimension_numbers.output_feature_dimension().getInt() !=
(output_spatial_rank + 1))
return failure();
if (inputSpatialRank != outputSpatialRank ||
inputSpatialRank != kernelSpatialRank)
if (input_spatial_rank != output_spatial_rank ||
input_spatial_rank != kernel_spatial_rank)
return failure();
auto inputSpatialDim =
dimensionNumbers.input_spatial_dimensions().begin();
auto kernelSpatialDim =
dimensionNumbers.kernel_spatial_dimensions().begin();
auto outputSpatialDim =
dimensionNumbers.output_spatial_dimensions().begin();
auto input_spatial_dim =
dimension_numbers.input_spatial_dimensions().begin();
auto kernel_spatial_dim =
dimension_numbers.kernel_spatial_dimensions().begin();
auto output_spatial_dim =
dimension_numbers.output_spatial_dimensions().begin();
// Check if spatial dims are ordered correctly.
for (int i = 0; i < inputSpatialRank; ++i) {
for (int i = 0; i < input_spatial_rank; ++i) {
const int dim = i + 1;
if ((*inputSpatialDim++).getZExtValue() != dim ||
(*outputSpatialDim++).getZExtValue() != dim ||
(*kernelSpatialDim++).getZExtValue() != i)
if ((*input_spatial_dim++).getZExtValue() != dim ||
(*output_spatial_dim++).getZExtValue() != dim ||
(*kernel_spatial_dim++).getZExtValue() != i)
return failure();
}
}
@ -248,33 +248,33 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
}
llvm::SmallVector<Attribute, 4> strides;
if (auto windowStrides = op.window_strides()) {
auto range = windowStrides->getAttributeValues();
if (auto window_strides = op.window_strides()) {
auto range = window_strides->getAttributeValues();
strides.assign(range.begin(), range.end());
}
auto stridesArg = ArrayAttr::get(strides, op.getContext());
auto strides_arg = ArrayAttr::get(strides, op.getContext());
llvm::SmallVector<Attribute, 2> dilation;
if (auto rhsDilation = op.rhs_dilation()) {
auto range = rhsDilation->getAttributeValues();
if (auto rhs_dilation = op.rhs_dilation()) {
auto range = rhs_dilation->getAttributeValues();
dilation.assign(range.begin(), range.end());
} else {
// Default dilation of 1.
dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
}
auto dilationArg = ArrayAttr::get(dilation, op.getContext());
auto dilation_arg = ArrayAttr::get(dilation, op.getContext());
// Set padding only if it is non-zero.
DenseIntElementsAttr padding = op.paddingAttr();
if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) {
return !intVal.isNullValue();
})) {
if (!padding ||
!llvm::any_of(padding.getValues<APInt>(),
[](APInt int_val) { return !int_val.isNullValue(); })) {
padding = nullptr;
}
// The order of input and filter are switched with linalg.conv.
rewriter.replaceOpWithNewOp<linalg::ConvOp>(
op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
return success();
}
};
@ -293,25 +293,25 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
auto resultType = getHloOpResultType<isLHLO>(op);
auto result_type = getHloOpResultType<isLHLO>(op);
SmallVector<AffineMap, 2> indexing_maps =
Derived::getIndexingMaps(op, &rewriter);
if (indexing_maps.empty()) return failure();
auto nloops = resultType.getRank();
auto nloops = result_type.getRank();
auto loc = op.getLoc();
auto linalgOp = rewriter.create<linalg::GenericOp>(
auto linalg_op = rewriter.create<linalg::GenericOp>(
loc,
/*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : resultType,
/*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
/*inputs=*/args.front(),
/*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
/*initTensor=*/ValueRange{}, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
nested_builder.create<linalg::YieldOp>(loc, *args.begin());
});
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
return success();
}
};
@ -325,32 +325,32 @@ class BroadcastConverter
using DataMovementOpConverter<BroadcastConverter, OpTy,
isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
Builder* b) {
ShapedType inputType =
broadcastOp.operand().getType().template cast<ShapedType>();
unsigned inputRank = inputType.getRank();
unsigned nloops = getHloOpResultType<isLHLO>(broadcastOp).getRank();
ShapedType input_type =
broadcast_op.operand().getType().template cast<ShapedType>();
unsigned input_rank = input_type.getRank();
unsigned nloops = getHloOpResultType<isLHLO>(broadcast_op).getRank();
// BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
// the input's dimensions.
unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes());
SmallVector<AffineExpr, 4> inputDimExprs;
inputDimExprs.reserve(inputRank);
for (int i = 0; i < inputRank; ++i) {
inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i));
unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
SmallVector<AffineExpr, 4> input_dim_exprs;
input_dim_exprs.reserve(input_rank);
for (int i = 0; i < input_rank; ++i) {
input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
}
AffineMap inputMap;
AffineMap input_map;
MLIRContext* context = b->getContext();
if (inputDimExprs.empty()) {
if (input_dim_exprs.empty()) {
// The input is a scalar, i.e. this is a scalar broadcast op.
inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context);
input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
} else {
inputMap =
AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
input_map =
AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
}
return {inputMap, b->getMultiDimIdentityMap(nloops)};
return {input_map, b->getMultiDimIdentityMap(nloops)};
}
};
@ -363,34 +363,34 @@ class HloBroadcastInDimConverter
false>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(
mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
auto resultType = getHloOpResultType<false>(broadcastOp);
auto operandType =
broadcastOp.operand().getType().template cast<ShapedType>();
unsigned nloops = resultType.getRank();
mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
auto result_type = getHloOpResultType<false>(broadcast_op);
auto operand_type =
broadcast_op.operand().getType().template cast<ShapedType>();
unsigned nloops = result_type.getRank();
// The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) {
if (operand_type.getRank() == 0) {
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
auto operandShape = operandType.getShape();
SmallVector<AffineExpr, 4> dimExprs;
dimExprs.reserve(nloops);
auto operand_shape = operand_type.getShape();
SmallVector<AffineExpr, 4> dim_exprs;
dim_exprs.reserve(nloops);
if (broadcastOp.broadcast_dimensions()) {
if (broadcast_op.broadcast_dimensions()) {
for (const auto& broadcastDim :
enumerate(broadcastOp.broadcast_dimensions().getIntValues())) {
enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
int size = broadcastDim.value().getSExtValue();
bool expansion_needed = operandShape[broadcastDim.index()] == 1 &&
resultType.getShape()[size] != 1;
dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
: b->getAffineDimExpr(size));
bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
result_type.getShape()[size] != 1;
dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
: b->getAffineDimExpr(size));
}
}
return {
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
};
@ -430,8 +430,8 @@ class LhloBroadcastInDimConverter
/*outputBuffers=*/ValueRange{operand_adaptor.output()},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, val);
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
nested_builder.create<linalg::YieldOp>(loc, val);
});
} else {
@ -441,8 +441,8 @@ class LhloBroadcastInDimConverter
loc, /*inputs=*/ValueRange{operand},
/*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
nested_builder.create<linalg::YieldOp>(loc, *args.begin());
});
}
rewriter.replaceOp(op, llvm::None);
@ -520,35 +520,35 @@ class LhloBroadcastInDimConverter
}
SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
ArrayRef<int64_t> broadcastDims,
ArrayRef<int64_t> resultShape,
MemRefType operandType,
ArrayRef<int64_t> broadcast_dims,
ArrayRef<int64_t> result_shape,
MemRefType operand_type,
Builder* b) const {
unsigned nloops = resultShape.size();
unsigned nloops = result_shape.size();
// The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) {
if (operand_type.getRank() == 0) {
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
auto operandShape = operandType.getShape();
SmallVector<AffineExpr, 4> dimExprs;
dimExprs.reserve(nloops);
auto operand_shape = operand_type.getShape();
SmallVector<AffineExpr, 4> dim_exprs;
dim_exprs.reserve(nloops);
for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) {
int size = broadcastDim.value();
for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) {
int size = broadcast_dim.value();
bool expansion_needed =
operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1;
operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1;
if (expansion_needed) {
op.emitOpError(
"BroadcastInDimOp lowering to Linalg does not support size-1 "
"dimensions expansion.");
}
dimExprs.push_back(b->getAffineDimExpr(size));
dim_exprs.push_back(b->getAffineDimExpr(size));
}
return {
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
};
@ -561,17 +561,17 @@ class TransposeConverter
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType =
auto result_type =
getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank();
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.resize(resultType.getRank());
auto nloops = result_type.getRank();
SmallVector<AffineExpr, 2> input_exprs;
input_exprs.resize(result_type.getRank());
for (auto permutation : llvm::enumerate(op.permutation())) {
inputExprs[permutation.value().getZExtValue()] =
input_exprs[permutation.value().getZExtValue()] =
b->getAffineDimExpr(permutation.index());
}
return {
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
};
@ -584,101 +584,104 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
OpTy reshapeOp, ArrayRef<Value> args,
OpTy reshape_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
return failure();
ShapedType operandType =
reshapeOp.operand().getType().template cast<ShapedType>();
ShapedType resultType = getHloOpResultType<isLHLO>(reshapeOp);
ShapedType operand_type =
reshape_op.operand().getType().template cast<ShapedType>();
ShapedType result_type = getHloOpResultType<isLHLO>(reshape_op);
if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
return failure();
// Compute the reassociation maps for the linalg operation.
ArrayRef<int64_t> srcShape =
(operandType.getRank() > resultType.getRank() ? operandType.getShape()
: resultType.getShape());
ArrayRef<int64_t> dstShape =
(operandType.getRank() > resultType.getRank() ? resultType.getShape()
: operandType.getShape());
unsigned currSrcDim = 0, currDstDim = 0;
SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
dstShape.size());
bool isExpandingOrCollapsing = true;
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
int64_t dstSize = dstShape[currDstDim];
int64_t srcSize = srcShape[currSrcDim];
while (srcSize < dstSize && currSrcDim < srcShape.size()) {
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
srcSize *= srcShape[currSrcDim];
ArrayRef<int64_t> src_shape =
(operand_type.getRank() > result_type.getRank()
? operand_type.getShape()
: result_type.getShape());
ArrayRef<int64_t> dst_shape =
(operand_type.getRank() > result_type.getRank()
? result_type.getShape()
: operand_type.getShape());
unsigned curr_src_dim = 0, curr_dst_dim = 0;
SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
dst_shape.size());
bool is_expanding_or_collapsing = true;
while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
int64_t dst_size = dst_shape[curr_dst_dim];
int64_t src_size = src_shape[curr_src_dim];
while (src_size < dst_size && curr_src_dim < src_shape.size()) {
reassociation_map[curr_dst_dim].push_back(
rewriter.getAffineDimExpr(curr_src_dim++));
src_size *= src_shape[curr_src_dim];
}
if (srcSize == dstSize) {
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
// If the next dim in dstShape is not 1, treat subsequent dims in
// srcShape which are 1 to be collapsed.
if (currDstDim == dstShape.size() - 1 ||
dstShape[currDstDim + 1] != 1) {
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
if (src_size == dst_size) {
reassociation_map[curr_dst_dim].push_back(
rewriter.getAffineDimExpr(curr_src_dim++));
// If the next dim in dst_shape is not 1, treat subsequent dims in
// src_shape which are 1 to be collapsed.
if (curr_dst_dim == dst_shape.size() - 1 ||
dst_shape[curr_dst_dim + 1] != 1) {
while (curr_src_dim < src_shape.size() &&
src_shape[curr_src_dim] == 1) {
reassociation_map[curr_dst_dim].push_back(
rewriter.getAffineDimExpr(curr_src_dim++));
}
}
} else {
isExpandingOrCollapsing = false;
is_expanding_or_collapsing = false;
break;
}
currDstDim++;
curr_dst_dim++;
}
if (currSrcDim != srcShape.size() || currDstDim != dstShape.size())
isExpandingOrCollapsing = false;
if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
is_expanding_or_collapsing = false;
if (!isExpandingOrCollapsing) {
auto getIdentityExprs = [&rewriter](int n) {
if (!is_expanding_or_collapsing) {
auto get_identity_exprs = [&rewriter](int n) {
SmallVector<AffineExpr, 4> exprs;
for (int i = 0; i < n; ++i)
exprs.push_back(rewriter.getAffineDimExpr(i));
return exprs;
};
Location loc = reshapeOp.getLoc();
int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
std::multiplies<int64_t>());
auto elemType = operandType.getElementType();
SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
getIdentityExprs(dstShape.size())};
SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
getIdentityExprs(srcShape.size())};
Location loc = reshape_op.getLoc();
int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
1, std::multiplies<int64_t>());
auto elem_type = operand_type.getElementType();
SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
get_identity_exprs(dst_shape.size())};
SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
get_identity_exprs(src_shape.size())};
if (isLHLO) {
auto collapsedType = MemRefType::get({totalElems}, elemType);
Value collapsedOp = rewriter.create<linalg::ReshapeOp>(
loc, collapsedType, args[0], collapsingMap);
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
loc, resultType, collapsedOp, expandingMap);
auto collapsed_type = MemRefType::get({total_elems}, elem_type);
Value collapsed_op = rewriter.create<linalg::ReshapeOp>(
loc, collapsed_type, args[0], collapsing_map);
Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
loc, result_type, collapsed_op, expanding_map);
rewriter.replaceOpWithNewOp<linalg::CopyOp>(
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
/*outputPermutation =*/nullptr);
} else {
auto collapsedType = RankedTensorType::get({totalElems}, elemType);
Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
loc, collapsedType, args[0], collapsingMap);
auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
Value collapsed_op = rewriter.create<linalg::TensorReshapeOp>(
loc, collapsed_type, args[0], collapsing_map);
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshapeOp, resultType, collapsedOp, expandingMap);
reshape_op, result_type, collapsed_op, expanding_map);
}
return success();
}
if (isLHLO) {
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
reshapeOp.getLoc(), resultType, args[0], reassociationMap);
Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
reshape_op.getLoc(), result_type, args[0], reassociation_map);
rewriter.replaceOpWithNewOp<linalg::CopyOp>(
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
/*outputPermutation =*/nullptr);
} else {
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshapeOp, resultType, args[0], reassociationMap);
reshape_op, result_type, args[0], reassociation_map);
}
return success();
}
@ -690,42 +693,42 @@ class IotaConverter : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
OpTy iotaOp, ArrayRef<Value> args,
OpTy iota_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp);
if (!resultShapedType) return failure();
ShapedType result_shaped_type = getHloOpResultType<isLHLO>(iota_op);
if (!result_shaped_type) return failure();
auto resultElementType = resultShapedType.getElementType();
if (!resultElementType.isSignlessIntOrFloat()) return failure();
auto result_element_type = result_shaped_type.getElementType();
if (!result_element_type.isSignlessIntOrFloat()) return failure();
// Construct the indexing maps needed for linalg.generic ops.
unsigned nloops = resultShapedType.getRank();
unsigned nloops = result_shaped_type.getRank();
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(),
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
iota_op.getLoc(),
/*resultTensorTypes=*/
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{resultShapedType},
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
/*inputs=*/ValueRange{},
/*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
/*initTensors=*/ValueRange{},
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
ValueRange args) {
Value castOp = nestedBuilder.create<IndexCastOp>(
nestedLoc, ivs[iotaOp.iota_dimension()],
nestedBuilder.getIntegerType(
resultElementType.getIntOrFloatBitWidth()));
if (resultElementType.template isa<FloatType>()) {
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
resultElementType);
Value cast_op = nested_builder.create<IndexCastOp>(
nested_loc, ivs[iota_op.iota_dimension()],
nested_builder.getIntegerType(
result_element_type.getIntOrFloatBitWidth()));
if (result_element_type.template isa<FloatType>()) {
cast_op = nested_builder.create<SIToFPOp>(nested_loc, cast_op,
result_element_type);
}
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
});
if (isLHLO)
rewriter.replaceOp(iotaOp, llvm::None);
rewriter.replaceOp(iota_op, llvm::None);
else
rewriter.replaceOp(iotaOp, linalgOp.result_tensors());
rewriter.replaceOp(iota_op, linalg_op.result_tensors());
return success();
}
};
@ -735,16 +738,16 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
lmhlo::ConstOp constOp, ArrayRef<Value> args,
lmhlo::ConstOp const_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = constOp.getLoc();
auto valueAttr = constOp.value().cast<DenseElementsAttr>();
if (valueAttr.getType().getRank() != 0) return failure();
auto stdConstOp =
rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
rewriter.create<mlir::AffineStoreOp>(loc, stdConstOp, constOp.getOperand(),
ValueRange());
rewriter.eraseOp(constOp);
auto loc = const_op.getLoc();
auto value_attr = const_op.value().cast<DenseElementsAttr>();
if (value_attr.getType().getRank() != 0) return failure();
auto std_const_op =
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
rewriter.create<mlir::AffineStoreOp>(loc, std_const_op,
const_op.getOperand(), ValueRange());
rewriter.eraseOp(const_op);
return success();
}
};
@ -758,21 +761,21 @@ class ReverseConverter
using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType =
auto result_type =
getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank();
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.reserve(nloops);
auto nloops = result_type.getRank();
SmallVector<AffineExpr, 2> input_exprs;
input_exprs.reserve(nloops);
for (int i = 0; i < nloops; ++i)
inputExprs.push_back(b->getAffineDimExpr(i));
input_exprs.push_back(b->getAffineDimExpr(i));
for (auto dim : op.dimensions()) {
int i = dim.getZExtValue();
if (resultType.isDynamicDim(i)) return {};
int n = resultType.getShape()[i];
inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
if (result_type.isDynamicDim(i)) return {};
int n = result_type.getShape()[i];
input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
}
return {
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
};
@ -782,31 +785,31 @@ class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
lmhlo::SliceOp sliceOp, ArrayRef<Value> args,
lmhlo::SliceOp slice_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = sliceOp.getLoc();
auto argType =
sliceOp.getOperand(0).getType().template dyn_cast<ShapedType>();
if (!argType || !argType.hasRank()) {
auto loc = slice_op.getLoc();
auto arg_type =
slice_op.getOperand(0).getType().template dyn_cast<ShapedType>();
if (!arg_type || !arg_type.hasRank()) {
emitError(loc, "lhlo to linalg conversion expects known-rank args");
return failure();
}
SmallVector<Value, 3> ranges;
for (int i = 0, e = argType.getRank(); i < e; ++i) {
for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
Value start_index = rewriter.create<ConstantIndexOp>(
loc, sliceOp.start_indices().getValue<int64_t>(i));
loc, slice_op.start_indices().getValue<int64_t>(i));
Value limit_index = rewriter.create<ConstantIndexOp>(
loc, sliceOp.limit_indices().getValue<int64_t>(i));
loc, slice_op.limit_indices().getValue<int64_t>(i));
Value stride = rewriter.create<ConstantIndexOp>(
loc, sliceOp.strides().getValue<int64_t>(i));
loc, slice_op.strides().getValue<int64_t>(i));
ranges.push_back(rewriter.create<linalg::RangeOp>(loc, start_index,
limit_index, stride));
}
auto linalg_slice =
rewriter.create<linalg::SliceOp>(loc, sliceOp.getOperand(0), ranges);
rewriter.create<linalg::CopyOp>(loc, linalg_slice, sliceOp.getOperand(1));
rewriter.eraseOp(sliceOp);
rewriter.create<linalg::SliceOp>(loc, slice_op.getOperand(0), ranges);
rewriter.create<linalg::CopyOp>(loc, linalg_slice, slice_op.getOperand(1));
rewriter.eraseOp(slice_op);
return success();
}
};