[NFC] Make naming style consistent.
Use lowercase with underscores between words instead of camelStyle. PiperOrigin-RevId: 338722328
This commit is contained in:
parent
31c1c3aa1f
commit
444fae9bac
|
@ -60,13 +60,13 @@ ShapedType getHloOpResultType(Operation* op) {
|
|||
|
||||
template <bool isLHLO = true>
|
||||
bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
|
||||
auto verifyType = [&](Value val) -> bool {
|
||||
auto verify_type = [&](Value val) -> bool {
|
||||
return (isLHLO && val.getType().isa<MemRefType>()) ||
|
||||
(!isLHLO && val.getType().isa<RankedTensorType>());
|
||||
};
|
||||
if (!llvm::all_of(op->getOperands(), verifyType)) return false;
|
||||
if (!llvm::all_of(op->getOperands(), verify_type)) return false;
|
||||
return isLHLO ? op->getResults().empty()
|
||||
: llvm::all_of(op->getResults(), verifyType);
|
||||
: llvm::all_of(op->getResults(), verify_type);
|
||||
}
|
||||
|
||||
template <typename OpTy, bool isLHLO = true>
|
||||
|
@ -99,51 +99,51 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
|||
<< nloops << " parallel iterators: " << *(op.getOperation());
|
||||
|
||||
// Construct the indexing maps needed for linalg.generic ops.
|
||||
SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
|
||||
SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
|
||||
|
||||
// This doesnt account for implicit broadcast, but the working assumption
|
||||
// in HLO/LHLO is that are broadcasts are made explicit.
|
||||
|
||||
if (isLHLO && !nloops) return failure();
|
||||
|
||||
int numInputs = (isLHLO ? args.size() - 1 : args.size());
|
||||
int num_inputs = (isLHLO ? args.size() - 1 : args.size());
|
||||
|
||||
ValueRange inputs(args.take_front(numInputs));
|
||||
ValueRange inputs(args.take_front(num_inputs));
|
||||
for (Value in : inputs)
|
||||
bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
|
||||
body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
|
||||
|
||||
ValueRange outputBuffers(args.take_back(args.size() - numInputs));
|
||||
for (Value out : outputBuffers)
|
||||
bodyResultTypes.emplace_back(getElementTypeOrSelf(out.getType()));
|
||||
ValueRange output_buffers(args.take_back(args.size() - num_inputs));
|
||||
for (Value out : output_buffers)
|
||||
body_result_types.emplace_back(getElementTypeOrSelf(out.getType()));
|
||||
|
||||
if (!isLHLO) {
|
||||
// HLO operations have return as tensor types.
|
||||
assert(bodyResultTypes.empty() &&
|
||||
assert(body_result_types.empty() &&
|
||||
"When lowering HLO ops result can't be part of arguments");
|
||||
Value result = op.getOperation()->getResult(0);
|
||||
bodyResultTypes.push_back(getElementTypeOrSelf(result));
|
||||
opResultTypes.push_back(result.getType());
|
||||
body_result_types.push_back(getElementTypeOrSelf(result));
|
||||
op_result_types.push_back(result.getType());
|
||||
}
|
||||
|
||||
AffineMap commonIndexingMap =
|
||||
AffineMap common_indexing_map =
|
||||
nloops ? rewriter.getMultiDimIdentityMap(nloops)
|
||||
: AffineMap::get(nloops, 0, rewriter.getContext());
|
||||
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
|
||||
commonIndexingMap);
|
||||
common_indexing_map);
|
||||
|
||||
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||
loc, opResultTypes, inputs, outputBuffers,
|
||||
auto linalg_op = rewriter.create<linalg::GenericOp>(
|
||||
loc, op_result_types, inputs, output_buffers,
|
||||
/*initTensors=*/ValueRange{}, indexing_maps,
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
||||
// TODO(ravishankarm) : For now use the method in lmhlo namespace.
|
||||
// That method needs to be moved out of there.
|
||||
Value opResult = lmhlo::HloOpToStdScalarOp::map<OpTy>(
|
||||
op, bodyResultTypes,
|
||||
Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
|
||||
op, body_result_types,
|
||||
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
|
||||
nestedBuilder.create<linalg::YieldOp>(loc, opResult);
|
||||
nested_builder.create<linalg::YieldOp>(loc, op_result);
|
||||
});
|
||||
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
|
||||
rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -157,10 +157,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
|||
LhloOp lhlo_op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = lhlo_op.getLoc();
|
||||
auto argType =
|
||||
auto arg_type =
|
||||
lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
|
||||
if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
|
||||
(argType.getRank() != 0)) {
|
||||
if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
|
||||
(arg_type.getRank() != 0)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
@ -168,10 +168,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
|||
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
|
||||
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
|
||||
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
|
||||
Value opResult = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
|
||||
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
|
||||
Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
|
||||
lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
|
||||
&rewriter);
|
||||
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
|
||||
rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
|
||||
rewriter.eraseOp(lhlo_op);
|
||||
return success();
|
||||
}
|
||||
|
@ -192,52 +192,52 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
|
|||
lmhlo::ConvOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
// Check validity of dimension information.
|
||||
if (const mhlo::ConvDimensionNumbers& dimensionNumbers =
|
||||
if (const mhlo::ConvDimensionNumbers& dimension_numbers =
|
||||
op.dimension_numbers()) {
|
||||
const int inputSpatialRank =
|
||||
llvm::size(dimensionNumbers.input_spatial_dimensions());
|
||||
const int input_spatial_rank =
|
||||
llvm::size(dimension_numbers.input_spatial_dimensions());
|
||||
// The dimensions for input should follow the order of
|
||||
// batch_count, spatial_dims..., input_feature_count.
|
||||
if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
|
||||
dimensionNumbers.input_feature_dimension().getInt() !=
|
||||
(inputSpatialRank + 1))
|
||||
if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
|
||||
dimension_numbers.input_feature_dimension().getInt() !=
|
||||
(input_spatial_rank + 1))
|
||||
return failure();
|
||||
|
||||
const int kernelSpatialRank =
|
||||
llvm::size(dimensionNumbers.kernel_spatial_dimensions());
|
||||
const int kernel_spatial_rank =
|
||||
llvm::size(dimension_numbers.kernel_spatial_dimensions());
|
||||
// The dimensions for filter should follow the order of
|
||||
// spatial_dims..., input_feature_count, num_output_feature_count.
|
||||
if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
|
||||
kernelSpatialRank ||
|
||||
dimensionNumbers.kernel_output_feature_dimension().getInt() !=
|
||||
(kernelSpatialRank + 1))
|
||||
if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
|
||||
kernel_spatial_rank ||
|
||||
dimension_numbers.kernel_output_feature_dimension().getInt() !=
|
||||
(kernel_spatial_rank + 1))
|
||||
return failure();
|
||||
|
||||
const int outputSpatialRank =
|
||||
llvm::size(dimensionNumbers.output_spatial_dimensions());
|
||||
const int output_spatial_rank =
|
||||
llvm::size(dimension_numbers.output_spatial_dimensions());
|
||||
// The dimensions for output should follow the order of
|
||||
// batch_count, spatial_dims.., output_feature_count.
|
||||
if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
|
||||
dimensionNumbers.output_feature_dimension().getInt() !=
|
||||
(outputSpatialRank + 1))
|
||||
if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
|
||||
dimension_numbers.output_feature_dimension().getInt() !=
|
||||
(output_spatial_rank + 1))
|
||||
return failure();
|
||||
|
||||
if (inputSpatialRank != outputSpatialRank ||
|
||||
inputSpatialRank != kernelSpatialRank)
|
||||
if (input_spatial_rank != output_spatial_rank ||
|
||||
input_spatial_rank != kernel_spatial_rank)
|
||||
return failure();
|
||||
|
||||
auto inputSpatialDim =
|
||||
dimensionNumbers.input_spatial_dimensions().begin();
|
||||
auto kernelSpatialDim =
|
||||
dimensionNumbers.kernel_spatial_dimensions().begin();
|
||||
auto outputSpatialDim =
|
||||
dimensionNumbers.output_spatial_dimensions().begin();
|
||||
auto input_spatial_dim =
|
||||
dimension_numbers.input_spatial_dimensions().begin();
|
||||
auto kernel_spatial_dim =
|
||||
dimension_numbers.kernel_spatial_dimensions().begin();
|
||||
auto output_spatial_dim =
|
||||
dimension_numbers.output_spatial_dimensions().begin();
|
||||
// Check if spatial dims are ordered correctly.
|
||||
for (int i = 0; i < inputSpatialRank; ++i) {
|
||||
for (int i = 0; i < input_spatial_rank; ++i) {
|
||||
const int dim = i + 1;
|
||||
if ((*inputSpatialDim++).getZExtValue() != dim ||
|
||||
(*outputSpatialDim++).getZExtValue() != dim ||
|
||||
(*kernelSpatialDim++).getZExtValue() != i)
|
||||
if ((*input_spatial_dim++).getZExtValue() != dim ||
|
||||
(*output_spatial_dim++).getZExtValue() != dim ||
|
||||
(*kernel_spatial_dim++).getZExtValue() != i)
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
@ -248,33 +248,33 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
|
|||
}
|
||||
|
||||
llvm::SmallVector<Attribute, 4> strides;
|
||||
if (auto windowStrides = op.window_strides()) {
|
||||
auto range = windowStrides->getAttributeValues();
|
||||
if (auto window_strides = op.window_strides()) {
|
||||
auto range = window_strides->getAttributeValues();
|
||||
strides.assign(range.begin(), range.end());
|
||||
}
|
||||
auto stridesArg = ArrayAttr::get(strides, op.getContext());
|
||||
auto strides_arg = ArrayAttr::get(strides, op.getContext());
|
||||
|
||||
llvm::SmallVector<Attribute, 2> dilation;
|
||||
if (auto rhsDilation = op.rhs_dilation()) {
|
||||
auto range = rhsDilation->getAttributeValues();
|
||||
if (auto rhs_dilation = op.rhs_dilation()) {
|
||||
auto range = rhs_dilation->getAttributeValues();
|
||||
dilation.assign(range.begin(), range.end());
|
||||
} else {
|
||||
// Default dilation of 1.
|
||||
dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
|
||||
}
|
||||
auto dilationArg = ArrayAttr::get(dilation, op.getContext());
|
||||
auto dilation_arg = ArrayAttr::get(dilation, op.getContext());
|
||||
|
||||
// Set padding only if it is non-zero.
|
||||
DenseIntElementsAttr padding = op.paddingAttr();
|
||||
if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) {
|
||||
return !intVal.isNullValue();
|
||||
})) {
|
||||
if (!padding ||
|
||||
!llvm::any_of(padding.getValues<APInt>(),
|
||||
[](APInt int_val) { return !int_val.isNullValue(); })) {
|
||||
padding = nullptr;
|
||||
}
|
||||
|
||||
// The order of input and filter are switched with linalg.conv.
|
||||
rewriter.replaceOpWithNewOp<linalg::ConvOp>(
|
||||
op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
|
||||
op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -293,25 +293,25 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
|
|||
OpTy op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
|
||||
auto resultType = getHloOpResultType<isLHLO>(op);
|
||||
auto result_type = getHloOpResultType<isLHLO>(op);
|
||||
|
||||
SmallVector<AffineMap, 2> indexing_maps =
|
||||
Derived::getIndexingMaps(op, &rewriter);
|
||||
if (indexing_maps.empty()) return failure();
|
||||
|
||||
auto nloops = resultType.getRank();
|
||||
auto nloops = result_type.getRank();
|
||||
auto loc = op.getLoc();
|
||||
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||
auto linalg_op = rewriter.create<linalg::GenericOp>(
|
||||
loc,
|
||||
/*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : resultType,
|
||||
/*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
|
||||
/*inputs=*/args.front(),
|
||||
/*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
|
||||
/*initTensor=*/ValueRange{}, indexing_maps,
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
|
||||
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
||||
nested_builder.create<linalg::YieldOp>(loc, *args.begin());
|
||||
});
|
||||
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
|
||||
rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -325,32 +325,32 @@ class BroadcastConverter
|
|||
using DataMovementOpConverter<BroadcastConverter, OpTy,
|
||||
isLHLO>::DataMovementOpConverter;
|
||||
|
||||
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
|
||||
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
|
||||
Builder* b) {
|
||||
ShapedType inputType =
|
||||
broadcastOp.operand().getType().template cast<ShapedType>();
|
||||
unsigned inputRank = inputType.getRank();
|
||||
unsigned nloops = getHloOpResultType<isLHLO>(broadcastOp).getRank();
|
||||
ShapedType input_type =
|
||||
broadcast_op.operand().getType().template cast<ShapedType>();
|
||||
unsigned input_rank = input_type.getRank();
|
||||
unsigned nloops = getHloOpResultType<isLHLO>(broadcast_op).getRank();
|
||||
|
||||
// BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
|
||||
// the input's dimensions.
|
||||
unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes());
|
||||
SmallVector<AffineExpr, 4> inputDimExprs;
|
||||
inputDimExprs.reserve(inputRank);
|
||||
for (int i = 0; i < inputRank; ++i) {
|
||||
inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i));
|
||||
unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
|
||||
SmallVector<AffineExpr, 4> input_dim_exprs;
|
||||
input_dim_exprs.reserve(input_rank);
|
||||
for (int i = 0; i < input_rank; ++i) {
|
||||
input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
|
||||
}
|
||||
|
||||
AffineMap inputMap;
|
||||
AffineMap input_map;
|
||||
MLIRContext* context = b->getContext();
|
||||
if (inputDimExprs.empty()) {
|
||||
if (input_dim_exprs.empty()) {
|
||||
// The input is a scalar, i.e. this is a scalar broadcast op.
|
||||
inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context);
|
||||
input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
|
||||
} else {
|
||||
inputMap =
|
||||
AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
|
||||
input_map =
|
||||
AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
|
||||
}
|
||||
return {inputMap, b->getMultiDimIdentityMap(nloops)};
|
||||
return {input_map, b->getMultiDimIdentityMap(nloops)};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -363,34 +363,34 @@ class HloBroadcastInDimConverter
|
|||
false>::DataMovementOpConverter;
|
||||
|
||||
static SmallVector<AffineMap, 2> getIndexingMaps(
|
||||
mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
|
||||
auto resultType = getHloOpResultType<false>(broadcastOp);
|
||||
auto operandType =
|
||||
broadcastOp.operand().getType().template cast<ShapedType>();
|
||||
unsigned nloops = resultType.getRank();
|
||||
mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
|
||||
auto result_type = getHloOpResultType<false>(broadcast_op);
|
||||
auto operand_type =
|
||||
broadcast_op.operand().getType().template cast<ShapedType>();
|
||||
unsigned nloops = result_type.getRank();
|
||||
|
||||
// The input is a scalar, i.e. this is a scalar broadcast op.
|
||||
if (operandType.getRank() == 0) {
|
||||
if (operand_type.getRank() == 0) {
|
||||
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
|
||||
b->getMultiDimIdentityMap(nloops)};
|
||||
}
|
||||
|
||||
auto operandShape = operandType.getShape();
|
||||
SmallVector<AffineExpr, 4> dimExprs;
|
||||
dimExprs.reserve(nloops);
|
||||
auto operand_shape = operand_type.getShape();
|
||||
SmallVector<AffineExpr, 4> dim_exprs;
|
||||
dim_exprs.reserve(nloops);
|
||||
|
||||
if (broadcastOp.broadcast_dimensions()) {
|
||||
if (broadcast_op.broadcast_dimensions()) {
|
||||
for (const auto& broadcastDim :
|
||||
enumerate(broadcastOp.broadcast_dimensions().getIntValues())) {
|
||||
enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
|
||||
int size = broadcastDim.value().getSExtValue();
|
||||
bool expansion_needed = operandShape[broadcastDim.index()] == 1 &&
|
||||
resultType.getShape()[size] != 1;
|
||||
dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
|
||||
bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
|
||||
result_type.getShape()[size] != 1;
|
||||
dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
|
||||
: b->getAffineDimExpr(size));
|
||||
}
|
||||
}
|
||||
return {
|
||||
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
|
||||
AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
|
||||
b->getMultiDimIdentityMap(nloops)};
|
||||
}
|
||||
};
|
||||
|
@ -430,8 +430,8 @@ class LhloBroadcastInDimConverter
|
|||
/*outputBuffers=*/ValueRange{operand_adaptor.output()},
|
||||
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||
nestedBuilder.create<linalg::YieldOp>(loc, val);
|
||||
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
||||
nested_builder.create<linalg::YieldOp>(loc, val);
|
||||
});
|
||||
|
||||
} else {
|
||||
|
@ -441,8 +441,8 @@ class LhloBroadcastInDimConverter
|
|||
loc, /*inputs=*/ValueRange{operand},
|
||||
/*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
|
||||
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
||||
nested_builder.create<linalg::YieldOp>(loc, *args.begin());
|
||||
});
|
||||
}
|
||||
rewriter.replaceOp(op, llvm::None);
|
||||
|
@ -520,35 +520,35 @@ class LhloBroadcastInDimConverter
|
|||
}
|
||||
|
||||
SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
|
||||
ArrayRef<int64_t> broadcastDims,
|
||||
ArrayRef<int64_t> resultShape,
|
||||
MemRefType operandType,
|
||||
ArrayRef<int64_t> broadcast_dims,
|
||||
ArrayRef<int64_t> result_shape,
|
||||
MemRefType operand_type,
|
||||
Builder* b) const {
|
||||
unsigned nloops = resultShape.size();
|
||||
unsigned nloops = result_shape.size();
|
||||
|
||||
// The input is a scalar, i.e. this is a scalar broadcast op.
|
||||
if (operandType.getRank() == 0) {
|
||||
if (operand_type.getRank() == 0) {
|
||||
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
|
||||
b->getMultiDimIdentityMap(nloops)};
|
||||
}
|
||||
|
||||
auto operandShape = operandType.getShape();
|
||||
SmallVector<AffineExpr, 4> dimExprs;
|
||||
dimExprs.reserve(nloops);
|
||||
auto operand_shape = operand_type.getShape();
|
||||
SmallVector<AffineExpr, 4> dim_exprs;
|
||||
dim_exprs.reserve(nloops);
|
||||
|
||||
for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) {
|
||||
int size = broadcastDim.value();
|
||||
for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) {
|
||||
int size = broadcast_dim.value();
|
||||
bool expansion_needed =
|
||||
operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1;
|
||||
operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1;
|
||||
if (expansion_needed) {
|
||||
op.emitOpError(
|
||||
"BroadcastInDimOp lowering to Linalg does not support size-1 "
|
||||
"dimensions expansion.");
|
||||
}
|
||||
dimExprs.push_back(b->getAffineDimExpr(size));
|
||||
dim_exprs.push_back(b->getAffineDimExpr(size));
|
||||
}
|
||||
return {
|
||||
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
|
||||
AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
|
||||
b->getMultiDimIdentityMap(nloops)};
|
||||
}
|
||||
};
|
||||
|
@ -561,17 +561,17 @@ class TransposeConverter
|
|||
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
|
||||
isLHLO>::DataMovementOpConverter;
|
||||
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
|
||||
auto resultType =
|
||||
auto result_type =
|
||||
getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
|
||||
auto nloops = resultType.getRank();
|
||||
SmallVector<AffineExpr, 2> inputExprs;
|
||||
inputExprs.resize(resultType.getRank());
|
||||
auto nloops = result_type.getRank();
|
||||
SmallVector<AffineExpr, 2> input_exprs;
|
||||
input_exprs.resize(result_type.getRank());
|
||||
for (auto permutation : llvm::enumerate(op.permutation())) {
|
||||
inputExprs[permutation.value().getZExtValue()] =
|
||||
input_exprs[permutation.value().getZExtValue()] =
|
||||
b->getAffineDimExpr(permutation.index());
|
||||
}
|
||||
return {
|
||||
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
|
||||
AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
|
||||
b->getMultiDimIdentityMap(nloops)};
|
||||
}
|
||||
};
|
||||
|
@ -584,101 +584,104 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
|||
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
OpTy reshapeOp, ArrayRef<Value> args,
|
||||
OpTy reshape_op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
|
||||
if (!verifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
|
||||
return failure();
|
||||
ShapedType operandType =
|
||||
reshapeOp.operand().getType().template cast<ShapedType>();
|
||||
ShapedType resultType = getHloOpResultType<isLHLO>(reshapeOp);
|
||||
ShapedType operand_type =
|
||||
reshape_op.operand().getType().template cast<ShapedType>();
|
||||
ShapedType result_type = getHloOpResultType<isLHLO>(reshape_op);
|
||||
|
||||
if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// Compute the reassociation maps for the linalg operation.
|
||||
ArrayRef<int64_t> srcShape =
|
||||
(operandType.getRank() > resultType.getRank() ? operandType.getShape()
|
||||
: resultType.getShape());
|
||||
ArrayRef<int64_t> dstShape =
|
||||
(operandType.getRank() > resultType.getRank() ? resultType.getShape()
|
||||
: operandType.getShape());
|
||||
unsigned currSrcDim = 0, currDstDim = 0;
|
||||
SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
|
||||
dstShape.size());
|
||||
bool isExpandingOrCollapsing = true;
|
||||
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
|
||||
int64_t dstSize = dstShape[currDstDim];
|
||||
int64_t srcSize = srcShape[currSrcDim];
|
||||
while (srcSize < dstSize && currSrcDim < srcShape.size()) {
|
||||
reassociationMap[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
srcSize *= srcShape[currSrcDim];
|
||||
ArrayRef<int64_t> src_shape =
|
||||
(operand_type.getRank() > result_type.getRank()
|
||||
? operand_type.getShape()
|
||||
: result_type.getShape());
|
||||
ArrayRef<int64_t> dst_shape =
|
||||
(operand_type.getRank() > result_type.getRank()
|
||||
? result_type.getShape()
|
||||
: operand_type.getShape());
|
||||
unsigned curr_src_dim = 0, curr_dst_dim = 0;
|
||||
SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
|
||||
dst_shape.size());
|
||||
bool is_expanding_or_collapsing = true;
|
||||
while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
|
||||
int64_t dst_size = dst_shape[curr_dst_dim];
|
||||
int64_t src_size = src_shape[curr_src_dim];
|
||||
while (src_size < dst_size && curr_src_dim < src_shape.size()) {
|
||||
reassociation_map[curr_dst_dim].push_back(
|
||||
rewriter.getAffineDimExpr(curr_src_dim++));
|
||||
src_size *= src_shape[curr_src_dim];
|
||||
}
|
||||
if (srcSize == dstSize) {
|
||||
reassociationMap[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
// If the next dim in dstShape is not 1, treat subsequent dims in
|
||||
// srcShape which are 1 to be collapsed.
|
||||
if (currDstDim == dstShape.size() - 1 ||
|
||||
dstShape[currDstDim + 1] != 1) {
|
||||
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
|
||||
reassociationMap[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
if (src_size == dst_size) {
|
||||
reassociation_map[curr_dst_dim].push_back(
|
||||
rewriter.getAffineDimExpr(curr_src_dim++));
|
||||
// If the next dim in dst_shape is not 1, treat subsequent dims in
|
||||
// src_shape which are 1 to be collapsed.
|
||||
if (curr_dst_dim == dst_shape.size() - 1 ||
|
||||
dst_shape[curr_dst_dim + 1] != 1) {
|
||||
while (curr_src_dim < src_shape.size() &&
|
||||
src_shape[curr_src_dim] == 1) {
|
||||
reassociation_map[curr_dst_dim].push_back(
|
||||
rewriter.getAffineDimExpr(curr_src_dim++));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
isExpandingOrCollapsing = false;
|
||||
is_expanding_or_collapsing = false;
|
||||
break;
|
||||
}
|
||||
currDstDim++;
|
||||
curr_dst_dim++;
|
||||
}
|
||||
if (currSrcDim != srcShape.size() || currDstDim != dstShape.size())
|
||||
isExpandingOrCollapsing = false;
|
||||
if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
|
||||
is_expanding_or_collapsing = false;
|
||||
|
||||
if (!isExpandingOrCollapsing) {
|
||||
auto getIdentityExprs = [&rewriter](int n) {
|
||||
if (!is_expanding_or_collapsing) {
|
||||
auto get_identity_exprs = [&rewriter](int n) {
|
||||
SmallVector<AffineExpr, 4> exprs;
|
||||
for (int i = 0; i < n; ++i)
|
||||
exprs.push_back(rewriter.getAffineDimExpr(i));
|
||||
return exprs;
|
||||
};
|
||||
Location loc = reshapeOp.getLoc();
|
||||
int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
|
||||
std::multiplies<int64_t>());
|
||||
auto elemType = operandType.getElementType();
|
||||
SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
|
||||
getIdentityExprs(dstShape.size())};
|
||||
SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
|
||||
getIdentityExprs(srcShape.size())};
|
||||
Location loc = reshape_op.getLoc();
|
||||
int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
|
||||
1, std::multiplies<int64_t>());
|
||||
auto elem_type = operand_type.getElementType();
|
||||
SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
|
||||
get_identity_exprs(dst_shape.size())};
|
||||
SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
|
||||
get_identity_exprs(src_shape.size())};
|
||||
|
||||
if (isLHLO) {
|
||||
auto collapsedType = MemRefType::get({totalElems}, elemType);
|
||||
Value collapsedOp = rewriter.create<linalg::ReshapeOp>(
|
||||
loc, collapsedType, args[0], collapsingMap);
|
||||
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
|
||||
loc, resultType, collapsedOp, expandingMap);
|
||||
auto collapsed_type = MemRefType::get({total_elems}, elem_type);
|
||||
Value collapsed_op = rewriter.create<linalg::ReshapeOp>(
|
||||
loc, collapsed_type, args[0], collapsing_map);
|
||||
Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
|
||||
loc, result_type, collapsed_op, expanding_map);
|
||||
rewriter.replaceOpWithNewOp<linalg::CopyOp>(
|
||||
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
|
||||
reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
|
||||
/*outputPermutation =*/nullptr);
|
||||
} else {
|
||||
auto collapsedType = RankedTensorType::get({totalElems}, elemType);
|
||||
Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
|
||||
loc, collapsedType, args[0], collapsingMap);
|
||||
auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
|
||||
Value collapsed_op = rewriter.create<linalg::TensorReshapeOp>(
|
||||
loc, collapsed_type, args[0], collapsing_map);
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
||||
reshapeOp, resultType, collapsedOp, expandingMap);
|
||||
reshape_op, result_type, collapsed_op, expanding_map);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
if (isLHLO) {
|
||||
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
|
||||
reshapeOp.getLoc(), resultType, args[0], reassociationMap);
|
||||
Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
|
||||
reshape_op.getLoc(), result_type, args[0], reassociation_map);
|
||||
rewriter.replaceOpWithNewOp<linalg::CopyOp>(
|
||||
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
|
||||
reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
|
||||
/*outputPermutation =*/nullptr);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
||||
reshapeOp, resultType, args[0], reassociationMap);
|
||||
reshape_op, result_type, args[0], reassociation_map);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -690,42 +693,42 @@ class IotaConverter : public OpConversionPattern<OpTy> {
|
|||
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
OpTy iotaOp, ArrayRef<Value> args,
|
||||
OpTy iota_op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp);
|
||||
if (!resultShapedType) return failure();
|
||||
ShapedType result_shaped_type = getHloOpResultType<isLHLO>(iota_op);
|
||||
if (!result_shaped_type) return failure();
|
||||
|
||||
auto resultElementType = resultShapedType.getElementType();
|
||||
if (!resultElementType.isSignlessIntOrFloat()) return failure();
|
||||
auto result_element_type = result_shaped_type.getElementType();
|
||||
if (!result_element_type.isSignlessIntOrFloat()) return failure();
|
||||
|
||||
// Construct the indexing maps needed for linalg.generic ops.
|
||||
unsigned nloops = resultShapedType.getRank();
|
||||
unsigned nloops = result_shaped_type.getRank();
|
||||
|
||||
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
|
||||
iotaOp.getLoc(),
|
||||
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
|
||||
iota_op.getLoc(),
|
||||
/*resultTensorTypes=*/
|
||||
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{resultShapedType},
|
||||
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
|
||||
/*inputs=*/ValueRange{},
|
||||
/*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
|
||||
/*initTensors=*/ValueRange{},
|
||||
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
|
||||
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
|
||||
ValueRange args) {
|
||||
Value castOp = nestedBuilder.create<IndexCastOp>(
|
||||
nestedLoc, ivs[iotaOp.iota_dimension()],
|
||||
nestedBuilder.getIntegerType(
|
||||
resultElementType.getIntOrFloatBitWidth()));
|
||||
if (resultElementType.template isa<FloatType>()) {
|
||||
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
|
||||
resultElementType);
|
||||
Value cast_op = nested_builder.create<IndexCastOp>(
|
||||
nested_loc, ivs[iota_op.iota_dimension()],
|
||||
nested_builder.getIntegerType(
|
||||
result_element_type.getIntOrFloatBitWidth()));
|
||||
if (result_element_type.template isa<FloatType>()) {
|
||||
cast_op = nested_builder.create<SIToFPOp>(nested_loc, cast_op,
|
||||
result_element_type);
|
||||
}
|
||||
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
|
||||
nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
|
||||
});
|
||||
if (isLHLO)
|
||||
rewriter.replaceOp(iotaOp, llvm::None);
|
||||
rewriter.replaceOp(iota_op, llvm::None);
|
||||
else
|
||||
rewriter.replaceOp(iotaOp, linalgOp.result_tensors());
|
||||
rewriter.replaceOp(iota_op, linalg_op.result_tensors());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -735,16 +738,16 @@ class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
|
|||
using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
lmhlo::ConstOp constOp, ArrayRef<Value> args,
|
||||
lmhlo::ConstOp const_op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = constOp.getLoc();
|
||||
auto valueAttr = constOp.value().cast<DenseElementsAttr>();
|
||||
if (valueAttr.getType().getRank() != 0) return failure();
|
||||
auto stdConstOp =
|
||||
rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
|
||||
rewriter.create<mlir::AffineStoreOp>(loc, stdConstOp, constOp.getOperand(),
|
||||
ValueRange());
|
||||
rewriter.eraseOp(constOp);
|
||||
auto loc = const_op.getLoc();
|
||||
auto value_attr = const_op.value().cast<DenseElementsAttr>();
|
||||
if (value_attr.getType().getRank() != 0) return failure();
|
||||
auto std_const_op =
|
||||
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
|
||||
rewriter.create<mlir::AffineStoreOp>(loc, std_const_op,
|
||||
const_op.getOperand(), ValueRange());
|
||||
rewriter.eraseOp(const_op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -758,21 +761,21 @@ class ReverseConverter
|
|||
using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
|
||||
isLHLO>::DataMovementOpConverter;
|
||||
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
|
||||
auto resultType =
|
||||
auto result_type =
|
||||
getHloOpResultType<isLHLO>(op).template cast<ShapedType>();
|
||||
auto nloops = resultType.getRank();
|
||||
SmallVector<AffineExpr, 2> inputExprs;
|
||||
inputExprs.reserve(nloops);
|
||||
auto nloops = result_type.getRank();
|
||||
SmallVector<AffineExpr, 2> input_exprs;
|
||||
input_exprs.reserve(nloops);
|
||||
for (int i = 0; i < nloops; ++i)
|
||||
inputExprs.push_back(b->getAffineDimExpr(i));
|
||||
input_exprs.push_back(b->getAffineDimExpr(i));
|
||||
for (auto dim : op.dimensions()) {
|
||||
int i = dim.getZExtValue();
|
||||
if (resultType.isDynamicDim(i)) return {};
|
||||
int n = resultType.getShape()[i];
|
||||
inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
|
||||
if (result_type.isDynamicDim(i)) return {};
|
||||
int n = result_type.getShape()[i];
|
||||
input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
|
||||
}
|
||||
return {
|
||||
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
|
||||
AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
|
||||
b->getMultiDimIdentityMap(nloops)};
|
||||
}
|
||||
};
|
||||
|
@ -782,31 +785,31 @@ class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
|
|||
using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
lmhlo::SliceOp sliceOp, ArrayRef<Value> args,
|
||||
lmhlo::SliceOp slice_op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = sliceOp.getLoc();
|
||||
auto argType =
|
||||
sliceOp.getOperand(0).getType().template dyn_cast<ShapedType>();
|
||||
if (!argType || !argType.hasRank()) {
|
||||
auto loc = slice_op.getLoc();
|
||||
auto arg_type =
|
||||
slice_op.getOperand(0).getType().template dyn_cast<ShapedType>();
|
||||
if (!arg_type || !arg_type.hasRank()) {
|
||||
emitError(loc, "lhlo to linalg conversion expects known-rank args");
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<Value, 3> ranges;
|
||||
for (int i = 0, e = argType.getRank(); i < e; ++i) {
|
||||
for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
|
||||
Value start_index = rewriter.create<ConstantIndexOp>(
|
||||
loc, sliceOp.start_indices().getValue<int64_t>(i));
|
||||
loc, slice_op.start_indices().getValue<int64_t>(i));
|
||||
Value limit_index = rewriter.create<ConstantIndexOp>(
|
||||
loc, sliceOp.limit_indices().getValue<int64_t>(i));
|
||||
loc, slice_op.limit_indices().getValue<int64_t>(i));
|
||||
Value stride = rewriter.create<ConstantIndexOp>(
|
||||
loc, sliceOp.strides().getValue<int64_t>(i));
|
||||
loc, slice_op.strides().getValue<int64_t>(i));
|
||||
ranges.push_back(rewriter.create<linalg::RangeOp>(loc, start_index,
|
||||
limit_index, stride));
|
||||
}
|
||||
auto linalg_slice =
|
||||
rewriter.create<linalg::SliceOp>(loc, sliceOp.getOperand(0), ranges);
|
||||
rewriter.create<linalg::CopyOp>(loc, linalg_slice, sliceOp.getOperand(1));
|
||||
rewriter.eraseOp(sliceOp);
|
||||
rewriter.create<linalg::SliceOp>(loc, slice_op.getOperand(0), ranges);
|
||||
rewriter.create<linalg::CopyOp>(loc, linalg_slice, slice_op.getOperand(1));
|
||||
rewriter.eraseOp(slice_op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue