2020-07-07 04:57:00 +08:00
|
|
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
==============================================================================*/
|
|
|
|
|
|
|
|
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
|
|
|
|
|
2020-08-22 14:26:35 +08:00
|
|
|
#include <cassert>
#include <numeric>
|
|
|
|
|
2020-09-23 00:06:55 +08:00
|
|
|
#include "llvm/ADT/STLExtras.h"
|
2021-01-28 21:44:49 +08:00
|
|
|
#include "llvm/ADT/SetVector.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
|
|
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
|
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
|
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
2021-02-13 00:30:51 +08:00
|
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
2020-12-22 07:26:38 +08:00
|
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
2021-01-12 02:11:39 +08:00
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir/IR/AffineExpr.h"
|
|
|
|
#include "mlir/IR/Attributes.h"
|
|
|
|
#include "mlir/IR/Builders.h"
|
2020-12-02 05:17:12 +08:00
|
|
|
#include "mlir/IR/BuiltinOps.h"
|
2020-12-15 16:58:42 +08:00
|
|
|
#include "mlir/IR/BuiltinTypes.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir/IR/Location.h"
|
|
|
|
#include "mlir/IR/MLIRContext.h"
|
2021-01-28 21:44:49 +08:00
|
|
|
#include "mlir/IR/Matchers.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir/IR/Operation.h"
|
2020-09-23 00:06:55 +08:00
|
|
|
#include "mlir/IR/OperationSupport.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir/IR/PatternMatch.h"
|
2020-09-23 00:06:55 +08:00
|
|
|
#include "mlir/IR/TypeUtilities.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir/Pass/Pass.h"
|
2021-01-28 21:44:49 +08:00
|
|
|
#include "mlir/Pass/PassManager.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
namespace mlir {
|
|
|
|
namespace {
|
|
|
|
|
2021-01-28 21:44:49 +08:00
|
|
|
/// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes
|
|
|
|
/// are "parallel" except the last `nReduction` elements, where are "reduction"
|
|
|
|
/// attributes.
|
|
|
|
SmallVector<StringRef, 3> GetParallelAndReductionIterators(
|
|
|
|
unsigned nLoops, unsigned nReduction) {
|
|
|
|
SmallVector<StringRef, 3> res(nLoops - nReduction,
|
|
|
|
getParallelIteratorTypeName());
|
|
|
|
res.append(nReduction, getReductionIteratorTypeName());
|
|
|
|
return res;
|
|
|
|
}
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
/// Returns `nParallelLoops` copies of the "parallel" iterator type name.
SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
  StringRef parallel = getParallelIteratorTypeName();
  return SmallVector<StringRef, 3>(nParallelLoops, parallel);
}
|
|
|
|
|
|
|
|
template <bool isLHLO = true>
|
2020-12-15 02:46:04 +08:00
|
|
|
Value GetResultValue(Operation* op) {
|
2020-07-07 04:57:00 +08:00
|
|
|
return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
|
|
|
|
}
|
|
|
|
|
|
|
|
template <bool isLHLO = true>
|
2020-12-15 02:46:04 +08:00
|
|
|
ShapedType GetHloOpResultType(Operation* op) {
|
|
|
|
return GetResultValue<isLHLO>(op).getType().template cast<ShapedType>();
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
template <bool isLHLO = true>
|
2020-12-15 02:46:04 +08:00
|
|
|
bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
|
2020-10-24 03:22:21 +08:00
|
|
|
auto verify_type = [&](Value val) -> bool {
|
2020-07-07 04:57:00 +08:00
|
|
|
return (isLHLO && val.getType().isa<MemRefType>()) ||
|
|
|
|
(!isLHLO && val.getType().isa<RankedTensorType>());
|
|
|
|
};
|
2020-10-24 03:22:21 +08:00
|
|
|
if (!llvm::all_of(op->getOperands(), verify_type)) return false;
|
2020-07-07 04:57:00 +08:00
|
|
|
return isLHLO ? op->getResults().empty()
|
2020-10-24 03:22:21 +08:00
|
|
|
: llvm::all_of(op->getResults(), verify_type);
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
|
2020-12-24 15:53:08 +08:00
|
|
|
/// Creates a `linalg.init_tensor` with the static shape and element type of
/// `type`. `dyn_sizes` supplies the sizes of the dynamic dimensions (callers in
/// this file build it with ExtractDynamicSizes).
Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
                    ArrayRef<Value> dyn_sizes) {
  ArrayRef<int64_t> static_shape = type.getShape();
  Type element_type = type.getElementType();
  auto init_op = b.create<linalg::InitTensorOp>(loc, dyn_sizes, static_shape,
                                                element_type);
  return init_op.getResult();
}
|
|
|
|
|
|
|
|
/// Creates one `dim` op for each dynamic dimension of `tensor`, in dimension
/// order, and returns their results. Returns an empty vector when `tensor` is
/// not a RankedTensorType.
SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
                                          Value tensor) {
  SmallVector<Value, 2> sizes;
  auto ranked = tensor.getType().dyn_cast<RankedTensorType>();
  if (!ranked) return sizes;

  ArrayRef<int64_t> shape = ranked.getShape();
  for (size_t dim = 0, rank = shape.size(); dim < rank; ++dim) {
    if (shape[dim] == ShapedType::kDynamicSize)
      sizes.push_back(b.create<DimOp>(loc, tensor, dim));
  }
  return sizes;
}
|
|
|
|
|
2021-01-13 14:07:29 +08:00
|
|
|
/// Returns the elements of `elements` in iteration order. Each APInt is read
/// with `getLimitedValue()` and stored as int64_t.
SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) {
  SmallVector<int64_t, 4> values;
  for (auto it = elements.begin(), end = elements.end(); it != end; ++it) {
    APInt element = *it;
    values.push_back(element.getLimitedValue());
  }
  return values;
}
|
|
|
|
|
2021-01-28 21:44:49 +08:00
|
|
|
/// Returns the constant value associated with the init value if the defining
|
|
|
|
/// operation is a constant.
|
|
|
|
Attribute GetInitValueAsConst(Value init) {
|
|
|
|
DenseElementsAttr attr;
|
|
|
|
if (!matchPattern(init, m_Constant(&attr))) return {};
|
|
|
|
auto type = attr.getType().dyn_cast<ShapedType>();
|
|
|
|
if (!type || type.getRank() != 0) return {};
|
|
|
|
return attr.getValue({});
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns a permutation AffineMap that puts all reduction dimensions to the
|
|
|
|
/// last. The order of parallel loops and reduction loops are all sorted. E.g.,
|
|
|
|
/// if `rank` is 4 and `reductionDims` is {1, 3}, then
|
|
|
|
/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
|
|
|
|
/// the AffineMap is returned.
|
|
|
|
AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
|
|
|
|
ArrayRef<int64_t> reduction_dims) {
|
|
|
|
llvm::SmallSetVector<int, 4> s;
|
|
|
|
for (auto dim : reduction_dims) s.insert(dim);
|
|
|
|
|
|
|
|
SmallVector<unsigned, 4> permutation;
|
|
|
|
for (int i = 0; i < rank; ++i)
|
|
|
|
if (!s.count(i)) permutation.push_back(i);
|
|
|
|
for (auto dim : reduction_dims) permutation.push_back(dim);
|
|
|
|
|
|
|
|
auto map = AffineMap::getPermutationMap(permutation, context);
|
|
|
|
return inversePermutation(map);
|
|
|
|
}
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
template <typename OpTy, bool isLHLO = true>
|
|
|
|
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
|
|
|
OpTy op, ArrayRef<Value> args,
|
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
|
|
|
auto loc = op.getLoc();
|
2020-09-23 00:06:55 +08:00
|
|
|
ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>();
|
|
|
|
if (!t0) return failure();
|
|
|
|
|
|
|
|
unsigned nloops = t0.getRank();
|
|
|
|
auto fail = [&](ShapedType t) {
|
|
|
|
return !t || !t.hasRank() || t.getRank() != nloops ||
|
|
|
|
!(t.getElementType().isSignlessIntOrFloat() ||
|
|
|
|
t.getElementType().isa<ComplexType>());
|
|
|
|
};
|
|
|
|
if (llvm::any_of(args,
|
|
|
|
[&](Value v) {
|
|
|
|
return fail(v.getType().dyn_cast<ShapedType>());
|
|
|
|
}) ||
|
|
|
|
llvm::any_of(op.getOperation()->getResultTypes(),
|
|
|
|
[&](Type t) { return fail(t.dyn_cast<ShapedType>()); }))
|
|
|
|
return emitError(loc,
|
|
|
|
"lhlo to linalg conversion expects ranked args of "
|
|
|
|
"signless int, float or complex element type with ")
|
|
|
|
<< nloops << " parallel iterators: " << *(op.getOperation());
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Construct the indexing maps needed for linalg.generic ops.
|
2020-10-24 03:22:21 +08:00
|
|
|
SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// This doesnt account for implicit broadcast, but the working assumption
|
2020-09-23 00:06:55 +08:00
|
|
|
// in HLO/LHLO is that are broadcasts are made explicit.
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
if (isLHLO && !nloops) return failure();
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
int num_inputs = (isLHLO ? args.size() - 1 : args.size());
|
2020-09-23 00:06:55 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
ValueRange inputs(args.take_front(num_inputs));
|
2020-09-23 00:06:55 +08:00
|
|
|
for (Value in : inputs)
|
2020-10-24 03:22:21 +08:00
|
|
|
body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
|
2020-09-23 00:06:55 +08:00
|
|
|
|
2020-12-24 15:53:08 +08:00
|
|
|
SmallVector<Value, 4> output_buffers;
|
|
|
|
if (isLHLO) {
|
|
|
|
output_buffers.append(args.begin() + num_inputs, args.end());
|
|
|
|
} else {
|
2020-07-07 04:57:00 +08:00
|
|
|
Value result = op.getOperation()->getResult(0);
|
2020-12-24 15:53:08 +08:00
|
|
|
ShapedType result_type = result.getType().template cast<ShapedType>();
|
|
|
|
auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
|
|
|
|
output_buffers.push_back(
|
2021-02-04 07:02:21 +08:00
|
|
|
GetInitTensor(rewriter, loc, result_type, dyn_sizes));
|
2020-10-24 03:22:21 +08:00
|
|
|
op_result_types.push_back(result.getType());
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-12-24 15:53:08 +08:00
|
|
|
body_result_types = llvm::to_vector<4>(llvm::map_range(
|
|
|
|
output_buffers, [](Value v) { return getElementTypeOrSelf(v); }));
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
AffineMap common_indexing_map =
|
2020-09-23 00:06:55 +08:00
|
|
|
nloops ? rewriter.getMultiDimIdentityMap(nloops)
|
|
|
|
: AffineMap::get(nloops, 0, rewriter.getContext());
|
|
|
|
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
|
2020-10-24 03:22:21 +08:00
|
|
|
common_indexing_map);
|
2020-09-23 00:06:55 +08:00
|
|
|
|
2020-12-16 20:57:35 +08:00
|
|
|
bool failed = false;
|
2020-10-24 03:22:21 +08:00
|
|
|
auto linalg_op = rewriter.create<linalg::GenericOp>(
|
2020-12-24 15:53:08 +08:00
|
|
|
loc, op_result_types, inputs, output_buffers, indexing_maps,
|
2020-07-07 04:57:00 +08:00
|
|
|
GetNParallelLoopsAttrs(nloops),
|
2020-10-24 03:22:21 +08:00
|
|
|
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
2020-07-09 01:05:32 +08:00
|
|
|
// TODO(ravishankarm) : For now use the method in lmhlo namespace.
|
2020-07-07 04:57:00 +08:00
|
|
|
// That method needs to be moved out of there.
|
2020-10-24 03:22:21 +08:00
|
|
|
Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
|
|
|
|
op, body_result_types,
|
2020-09-23 00:06:55 +08:00
|
|
|
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
|
2020-12-16 20:57:35 +08:00
|
|
|
if (op_result == nullptr) {
|
|
|
|
failed = true;
|
|
|
|
} else {
|
|
|
|
nested_builder.create<linalg::YieldOp>(loc, op_result);
|
|
|
|
}
|
2020-07-07 04:57:00 +08:00
|
|
|
});
|
2020-12-16 20:57:35 +08:00
|
|
|
if (failed) return failure();
|
2020-10-24 03:22:21 +08:00
|
|
|
rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
|
2020-07-07 04:57:00 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
template <typename LhloOp>
|
|
|
|
class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<LhloOp>::OpConversionPattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
|
|
|
LhloOp lhlo_op, ArrayRef<Value> args,
|
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
|
|
|
auto loc = lhlo_op.getLoc();
|
2020-10-24 03:22:21 +08:00
|
|
|
auto arg_type =
|
2020-07-07 04:57:00 +08:00
|
|
|
lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
|
2020-10-24 03:22:21 +08:00
|
|
|
if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
|
|
|
|
(arg_type.getRank() != 0)) {
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create two loads from the input.
|
|
|
|
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
|
|
|
|
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
|
2020-07-09 01:05:32 +08:00
|
|
|
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
|
2020-10-24 03:22:21 +08:00
|
|
|
Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
|
|
|
|
lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
|
2020-07-07 04:57:00 +08:00
|
|
|
&rewriter);
|
2020-10-24 03:22:21 +08:00
|
|
|
rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
|
2020-07-07 04:57:00 +08:00
|
|
|
rewriter.eraseOp(lhlo_op);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2020-07-09 01:05:32 +08:00
|
|
|
// lmhlo.convolution conversion pattern.
|
2020-07-07 04:57:00 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-07-09 01:05:32 +08:00
|
|
|
/// Converts lmhlo.convolution operation to a linalg.conv op.
|
|
|
|
struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
|
2020-07-07 04:57:00 +08:00
|
|
|
public:
|
2020-07-09 01:05:32 +08:00
|
|
|
using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// This code has been adapted from IREE's
|
2020-07-07 12:51:24 +08:00
|
|
|
// (https://github.com/google/iree/) mhlo -> linalg conversion.
|
2020-07-07 04:57:00 +08:00
|
|
|
LogicalResult matchAndRewrite(
|
2020-07-09 01:05:32 +08:00
|
|
|
lmhlo::ConvOp op, ArrayRef<Value> args,
|
2020-07-07 04:57:00 +08:00
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
|
|
|
// Check validity of dimension information.
|
2020-10-24 03:22:21 +08:00
|
|
|
if (const mhlo::ConvDimensionNumbers& dimension_numbers =
|
2020-07-07 04:57:00 +08:00
|
|
|
op.dimension_numbers()) {
|
2020-10-24 03:22:21 +08:00
|
|
|
const int input_spatial_rank =
|
|
|
|
llvm::size(dimension_numbers.input_spatial_dimensions());
|
2020-07-07 04:57:00 +08:00
|
|
|
// The dimensions for input should follow the order of
|
|
|
|
// batch_count, spatial_dims..., input_feature_count.
|
2020-10-24 03:22:21 +08:00
|
|
|
if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
|
|
|
|
dimension_numbers.input_feature_dimension().getInt() !=
|
|
|
|
(input_spatial_rank + 1))
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
const int kernel_spatial_rank =
|
|
|
|
llvm::size(dimension_numbers.kernel_spatial_dimensions());
|
2020-07-07 04:57:00 +08:00
|
|
|
// The dimensions for filter should follow the order of
|
|
|
|
// spatial_dims..., input_feature_count, num_output_feature_count.
|
2020-10-24 03:22:21 +08:00
|
|
|
if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
|
|
|
|
kernel_spatial_rank ||
|
|
|
|
dimension_numbers.kernel_output_feature_dimension().getInt() !=
|
|
|
|
(kernel_spatial_rank + 1))
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
const int output_spatial_rank =
|
|
|
|
llvm::size(dimension_numbers.output_spatial_dimensions());
|
2020-07-07 04:57:00 +08:00
|
|
|
// The dimensions for output should follow the order of
|
|
|
|
// batch_count, spatial_dims.., output_feature_count.
|
2020-10-24 03:22:21 +08:00
|
|
|
if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
|
|
|
|
dimension_numbers.output_feature_dimension().getInt() !=
|
|
|
|
(output_spatial_rank + 1))
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
if (input_spatial_rank != output_spatial_rank ||
|
|
|
|
input_spatial_rank != kernel_spatial_rank)
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
auto input_spatial_dim =
|
|
|
|
dimension_numbers.input_spatial_dimensions().begin();
|
|
|
|
auto kernel_spatial_dim =
|
|
|
|
dimension_numbers.kernel_spatial_dimensions().begin();
|
|
|
|
auto output_spatial_dim =
|
|
|
|
dimension_numbers.output_spatial_dimensions().begin();
|
2020-07-07 04:57:00 +08:00
|
|
|
// Check if spatial dims are ordered correctly.
|
2020-10-24 03:22:21 +08:00
|
|
|
for (int i = 0; i < input_spatial_rank; ++i) {
|
2020-07-07 04:57:00 +08:00
|
|
|
const int dim = i + 1;
|
2020-10-24 03:22:21 +08:00
|
|
|
if ((*input_spatial_dim++).getZExtValue() != dim ||
|
|
|
|
(*output_spatial_dim++).getZExtValue() != dim ||
|
|
|
|
(*kernel_spatial_dim++).getZExtValue() != i)
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: LHS dilation for deconvolution not supported yet.
|
2020-12-11 08:38:26 +08:00
|
|
|
// TODO(jurahul): Window reversal is not supported yet.
|
|
|
|
if (op.lhs_dilation() || op.hasWindowReversal()) {
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
llvm::SmallVector<Attribute, 4> strides;
|
2020-10-24 03:22:21 +08:00
|
|
|
if (auto window_strides = op.window_strides()) {
|
|
|
|
auto range = window_strides->getAttributeValues();
|
2020-07-07 04:57:00 +08:00
|
|
|
strides.assign(range.begin(), range.end());
|
|
|
|
}
|
2021-02-09 00:01:22 +08:00
|
|
|
auto strides_arg = ArrayAttr::get(op.getContext(), strides);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
llvm::SmallVector<Attribute, 2> dilation;
|
2020-10-24 03:22:21 +08:00
|
|
|
if (auto rhs_dilation = op.rhs_dilation()) {
|
|
|
|
auto range = rhs_dilation->getAttributeValues();
|
2020-07-07 04:57:00 +08:00
|
|
|
dilation.assign(range.begin(), range.end());
|
|
|
|
} else {
|
|
|
|
// Default dilation of 1.
|
|
|
|
dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
|
|
|
|
}
|
2021-02-09 00:01:22 +08:00
|
|
|
auto dilation_arg = ArrayAttr::get(op.getContext(), dilation);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Set padding only if it is non-zero.
|
|
|
|
DenseIntElementsAttr padding = op.paddingAttr();
|
2020-10-24 03:22:21 +08:00
|
|
|
if (!padding ||
|
|
|
|
!llvm::any_of(padding.getValues<APInt>(),
|
|
|
|
[](APInt int_val) { return !int_val.isNullValue(); })) {
|
2020-07-07 04:57:00 +08:00
|
|
|
padding = nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
// The order of input and filter are switched with linalg.conv.
|
|
|
|
rewriter.replaceOpWithNewOp<linalg::ConvOp>(
|
2020-10-24 03:22:21 +08:00
|
|
|
op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
|
2020-07-07 04:57:00 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-07-09 11:32:16 +08:00
|
|
|
/// Base class for lowering HLO operations that have one operand and one result,
|
2020-07-07 04:57:00 +08:00
|
|
|
/// and are semantically equivalent to a copy of the input to the output (like
|
|
|
|
/// transpose, some reshape, etc.). The derived classes need to provide a method
|
|
|
|
/// `getIndexingMaps` that returns AffineMaps for the index maps of the input
|
|
|
|
/// and the output.
|
|
|
|
template <typename Derived, typename OpTy, bool isLHLO = true>
|
|
|
|
class DataMovementOpConverter : public OpConversionPattern<OpTy> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
|
|
|
OpTy op, ArrayRef<Value> args,
|
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
2020-12-15 02:46:04 +08:00
|
|
|
if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
|
|
|
|
auto result_type = GetHloOpResultType<isLHLO>(op);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
SmallVector<AffineMap, 2> indexing_maps =
|
|
|
|
Derived::getIndexingMaps(op, &rewriter);
|
|
|
|
if (indexing_maps.empty()) return failure();
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
auto nloops = result_type.getRank();
|
2020-07-07 04:57:00 +08:00
|
|
|
auto loc = op.getLoc();
|
2020-12-24 15:53:08 +08:00
|
|
|
// TODO(pifon): technically, the op itself could have size operands (e.g.
|
|
|
|
// broadcast into a dynamic dimension).Handle this case.
|
|
|
|
auto dyn_sizes = isLHLO ? SmallVector<Value, 2>()
|
|
|
|
: ExtractDynamicSizes(rewriter, loc, args[0]);
|
2020-10-24 03:22:21 +08:00
|
|
|
auto linalg_op = rewriter.create<linalg::GenericOp>(
|
2020-09-23 00:06:55 +08:00
|
|
|
loc,
|
2020-10-24 03:22:21 +08:00
|
|
|
/*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
|
2020-09-23 00:06:55 +08:00
|
|
|
/*inputs=*/args.front(),
|
2020-12-24 15:53:08 +08:00
|
|
|
/*outputBuffers=*/
|
2021-02-04 07:02:21 +08:00
|
|
|
isLHLO
|
|
|
|
? ValueRange{args.back()}
|
|
|
|
: ValueRange{GetInitTensor(rewriter, loc, result_type, dyn_sizes)},
|
2020-12-24 15:53:08 +08:00
|
|
|
indexing_maps, GetNParallelLoopsAttrs(nloops),
|
2020-10-24 03:22:21 +08:00
|
|
|
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
|
|
|
nested_builder.create<linalg::YieldOp>(loc, *args.begin());
|
2020-07-07 04:57:00 +08:00
|
|
|
});
|
2020-10-24 03:22:21 +08:00
|
|
|
rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
|
2020-07-07 04:57:00 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
/// Pattern to convert BroadcastOp to Linalg ops.
|
|
|
|
template <typename OpTy, bool isLHLO = true>
|
|
|
|
class BroadcastConverter
|
|
|
|
: public DataMovementOpConverter<BroadcastConverter<OpTy, isLHLO>, OpTy,
|
|
|
|
isLHLO> {
|
|
|
|
public:
|
|
|
|
using DataMovementOpConverter<BroadcastConverter, OpTy,
|
|
|
|
isLHLO>::DataMovementOpConverter;
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
|
2020-07-07 04:57:00 +08:00
|
|
|
Builder* b) {
|
2020-10-24 03:22:21 +08:00
|
|
|
ShapedType input_type =
|
|
|
|
broadcast_op.operand().getType().template cast<ShapedType>();
|
|
|
|
unsigned input_rank = input_type.getRank();
|
2020-12-15 02:46:04 +08:00
|
|
|
unsigned nloops = GetHloOpResultType<isLHLO>(broadcast_op).getRank();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
|
|
|
|
// the input's dimensions.
|
2020-10-24 03:22:21 +08:00
|
|
|
unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
|
|
|
|
SmallVector<AffineExpr, 4> input_dim_exprs;
|
|
|
|
input_dim_exprs.reserve(input_rank);
|
|
|
|
for (int i = 0; i < input_rank; ++i) {
|
|
|
|
input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
AffineMap input_map;
|
2020-07-07 04:57:00 +08:00
|
|
|
MLIRContext* context = b->getContext();
|
2020-10-24 03:22:21 +08:00
|
|
|
if (input_dim_exprs.empty()) {
|
2020-07-07 04:57:00 +08:00
|
|
|
// The input is a scalar, i.e. this is a scalar broadcast op.
|
2020-10-24 03:22:21 +08:00
|
|
|
input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
|
2020-07-07 04:57:00 +08:00
|
|
|
} else {
|
2020-10-24 03:22:21 +08:00
|
|
|
input_map =
|
|
|
|
AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
return {input_map, b->getMultiDimIdentityMap(nloops)};
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
class HloBroadcastInDimConverter
|
|
|
|
: public DataMovementOpConverter<HloBroadcastInDimConverter,
|
2020-07-07 12:51:24 +08:00
|
|
|
mhlo::BroadcastInDimOp, false> {
|
2020-07-07 04:57:00 +08:00
|
|
|
public:
|
|
|
|
using DataMovementOpConverter<HloBroadcastInDimConverter,
|
2020-07-07 12:51:24 +08:00
|
|
|
mhlo::BroadcastInDimOp,
|
2020-07-07 04:57:00 +08:00
|
|
|
false>::DataMovementOpConverter;
|
|
|
|
|
|
|
|
static SmallVector<AffineMap, 2> getIndexingMaps(
|
2020-10-24 03:22:21 +08:00
|
|
|
mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
|
2020-12-15 02:46:04 +08:00
|
|
|
auto result_type = GetHloOpResultType<false>(broadcast_op);
|
2020-10-24 03:22:21 +08:00
|
|
|
auto operand_type =
|
|
|
|
broadcast_op.operand().getType().template cast<ShapedType>();
|
|
|
|
unsigned nloops = result_type.getRank();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// The input is a scalar, i.e. this is a scalar broadcast op.
|
2020-10-24 03:22:21 +08:00
|
|
|
if (operand_type.getRank() == 0) {
|
2020-07-07 04:57:00 +08:00
|
|
|
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
|
|
|
|
b->getMultiDimIdentityMap(nloops)};
|
|
|
|
}
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
auto operand_shape = operand_type.getShape();
|
|
|
|
SmallVector<AffineExpr, 4> dim_exprs;
|
|
|
|
dim_exprs.reserve(nloops);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
if (broadcast_op.broadcast_dimensions()) {
|
2020-07-07 04:57:00 +08:00
|
|
|
for (const auto& broadcastDim :
|
2020-10-24 03:22:21 +08:00
|
|
|
enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
|
2020-07-07 04:57:00 +08:00
|
|
|
int size = broadcastDim.value().getSExtValue();
|
2020-10-24 03:22:21 +08:00
|
|
|
bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
|
|
|
|
result_type.getShape()[size] != 1;
|
|
|
|
dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
|
|
|
|
: b->getAffineDimExpr(size));
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return {
|
2020-10-24 03:22:21 +08:00
|
|
|
AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
|
2020-07-07 04:57:00 +08:00
|
|
|
b->getMultiDimIdentityMap(nloops)};
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2021-01-12 02:11:39 +08:00
|
|
|
class HloDynamicBroadcastInDimConverter
|
|
|
|
: public OpConversionPattern<mhlo::DynamicBroadcastInDimOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<mhlo::DynamicBroadcastInDimOp>::OpConversionPattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
|
|
|
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
|
|
|
// Convert only if the producer is an HLO constant. Ideally the pattern
|
|
|
|
// (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be converted
|
|
|
|
// to an Tensor-dialect op similar to TF ConstantLikeOp.
|
|
|
|
if (!op.operand().getDefiningOp<mhlo::ConstOp>()) return failure();
|
|
|
|
|
|
|
|
mhlo::DynamicBroadcastInDimOp::Adaptor adaptor(op);
|
|
|
|
Value operand = adaptor.operand();
|
|
|
|
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
|
|
|
|
if (!operand_type || operand_type.getRank() != 0) return failure();
|
|
|
|
|
|
|
|
Value shape = adaptor.output_dimensions();
|
|
|
|
auto shape_type = shape.getType().cast<RankedTensorType>();
|
|
|
|
int64_t result_rank = shape_type.getDimSize(0);
|
|
|
|
|
|
|
|
SmallVector<Value, 2> dyn_dims;
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
for (int i = 0; i < result_rank; ++i) {
|
|
|
|
Value index = rewriter.create<ConstantIndexOp>(loc, i);
|
|
|
|
dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index));
|
|
|
|
}
|
2021-01-22 19:02:13 +08:00
|
|
|
auto result_type = op.getType().dyn_cast<RankedTensorType>();
|
|
|
|
if (!result_type) return failure();
|
2021-01-12 02:11:39 +08:00
|
|
|
|
|
|
|
int64_t nloops = result_type.getRank();
|
|
|
|
Value init = rewriter.create<linalg::InitTensorOp>(
|
|
|
|
loc, dyn_dims, result_type.getShape(), result_type.getElementType());
|
|
|
|
Operation* generic = rewriter.create<linalg::GenericOp>(
|
|
|
|
loc, TypeRange{init.getType()}, ValueRange{operand},
|
|
|
|
/*outputBuffers=*/ValueRange{init},
|
|
|
|
llvm::makeArrayRef(
|
|
|
|
{AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, {},
|
|
|
|
rewriter.getContext()),
|
|
|
|
rewriter.getMultiDimIdentityMap(nloops)}),
|
|
|
|
GetNParallelLoopsAttrs(nloops),
|
|
|
|
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
|
|
|
nested_builder.create<linalg::YieldOp>(loc, *args.begin());
|
|
|
|
});
|
|
|
|
rewriter.replaceOp(op, generic->getResults());
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
class LhloBroadcastInDimConverter
|
2020-07-09 01:05:32 +08:00
|
|
|
: public OpConversionPattern<lmhlo::BroadcastInDimOp> {
|
2020-07-07 04:57:00 +08:00
|
|
|
public:
|
2020-07-09 01:05:32 +08:00
|
|
|
using OpConversionPattern<lmhlo::BroadcastInDimOp>::OpConversionPattern;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
2020-07-09 01:05:32 +08:00
|
|
|
lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
2020-07-07 04:57:00 +08:00
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
2020-07-09 01:05:32 +08:00
|
|
|
lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
2020-07-07 04:57:00 +08:00
|
|
|
auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
|
|
|
|
auto result_shape = result_type.getShape();
|
|
|
|
|
|
|
|
auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter);
|
|
|
|
|
|
|
|
Value operand = std::get<0>(operand_and_dims);
|
|
|
|
auto broadcast_dims = std::get<1>(operand_and_dims);
|
|
|
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
auto nloops = result_type.getRank();
|
|
|
|
auto operand_type = operand.getType().cast<MemRefType>();
|
|
|
|
|
|
|
|
// For a degenerate case, i.e. broadcasting with expansion of
|
|
|
|
// memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`.
|
|
|
|
// Instead the value is loaded and used directly in `linalg.yield`.
|
|
|
|
if (operand_type.getRank() == 1 &&
|
|
|
|
operand_type.getDimSize(0) <
|
|
|
|
result_type.getDimSize(broadcast_dims.front())) {
|
|
|
|
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
|
|
|
Value val =
|
|
|
|
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
|
|
|
|
rewriter.create<linalg::GenericOp>(
|
2020-09-23 00:06:55 +08:00
|
|
|
loc, /*inputs=*/ValueRange{},
|
|
|
|
/*outputBuffers=*/ValueRange{operand_adaptor.output()},
|
2020-07-07 04:57:00 +08:00
|
|
|
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
|
|
|
GetNParallelLoopsAttrs(nloops),
|
2020-10-24 03:22:21 +08:00
|
|
|
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
|
|
|
nested_builder.create<linalg::YieldOp>(loc, val);
|
2020-07-07 04:57:00 +08:00
|
|
|
});
|
|
|
|
|
|
|
|
} else {
|
|
|
|
auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
|
|
|
|
operand_type, &rewriter);
|
|
|
|
rewriter.create<linalg::GenericOp>(
|
2020-09-23 00:06:55 +08:00
|
|
|
loc, /*inputs=*/ValueRange{operand},
|
|
|
|
/*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
|
2020-07-07 04:57:00 +08:00
|
|
|
GetNParallelLoopsAttrs(nloops),
|
2020-10-24 03:22:21 +08:00
|
|
|
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
|
|
|
nested_builder.create<linalg::YieldOp>(loc, *args.begin());
|
2020-07-07 04:57:00 +08:00
|
|
|
});
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, llvm::None);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Inserts 'linalg.reshape' if there is a size-1 dim expansion.
|
|
|
|
std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
|
2020-07-09 01:05:32 +08:00
|
|
|
lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
2020-07-07 04:57:00 +08:00
|
|
|
ConversionPatternRewriter& rewriter) const {
|
2020-07-09 01:05:32 +08:00
|
|
|
lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
2020-07-07 04:57:00 +08:00
|
|
|
Value operand = operand_adaptor.operand();
|
|
|
|
auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
|
|
|
|
auto operand_shape = operand_type.getShape();
|
|
|
|
|
|
|
|
Value result = operand_adaptor.output();
|
|
|
|
auto result_type = result.getType().cast<MemRefType>();
|
|
|
|
auto result_shape = result_type.getShape();
|
|
|
|
|
|
|
|
SmallVector<int64_t, 2> operand_strides;
|
|
|
|
int64_t operand_offset;
|
|
|
|
if (failed(getStridesAndOffset(operand_type, operand_strides,
|
|
|
|
operand_offset))) {
|
|
|
|
op.emitOpError() << "Failed to get offset and strides.";
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
|
|
|
|
SmallVector<linalg::ReassociationIndices, 4> collapsed_dims_list;
|
|
|
|
linalg::ReassociationIndices collapsed_dims;
|
|
|
|
for (const auto& item :
|
|
|
|
enumerate(op.broadcast_dimensions().getIntValues())) {
|
|
|
|
size_t index = item.index();
|
|
|
|
int dim = item.value().getSExtValue();
|
|
|
|
|
|
|
|
collapsed_dims.push_back(index);
|
|
|
|
|
|
|
|
bool expansion_needed =
|
|
|
|
operand_shape[index] == 1 && result_shape[dim] != 1;
|
|
|
|
if (expansion_needed) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
new_shape.push_back(operand_shape[index]);
|
|
|
|
new_strides.push_back(operand_strides[index]);
|
|
|
|
broadcast_dims.push_back(dim);
|
|
|
|
|
|
|
|
collapsed_dims_list.push_back(collapsed_dims);
|
|
|
|
collapsed_dims.clear();
|
|
|
|
}
|
|
|
|
// If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1]
|
|
|
|
// and all dimensions need expansion. Such memref will be reshaped to a 1D
|
|
|
|
// memref with a single element. New shape and strides needs to be updated
|
|
|
|
// accordingly.
|
|
|
|
if (collapsed_dims_list.empty()) {
|
|
|
|
collapsed_dims_list.push_back({});
|
|
|
|
new_shape.push_back(1);
|
|
|
|
new_strides.push_back(1);
|
|
|
|
broadcast_dims.push_back(0);
|
|
|
|
}
|
|
|
|
for (const auto& dims : collapsed_dims) {
|
|
|
|
collapsed_dims_list.back().push_back(dims);
|
|
|
|
}
|
|
|
|
|
|
|
|
// `linalg.reshape` is inserted only if necessary, i.e. when the rank can be
|
|
|
|
// reduced.
|
|
|
|
if (new_shape.size() < operand_shape.size()) {
|
|
|
|
auto new_memref_type = MemRefType::get(
|
|
|
|
new_shape, operand_type.getElementType(),
|
|
|
|
makeStridedLinearLayoutMap(new_strides, operand_offset,
|
|
|
|
rewriter.getContext()));
|
|
|
|
operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
|
|
|
|
operand_adaptor.operand(),
|
|
|
|
collapsed_dims_list);
|
|
|
|
}
|
|
|
|
return std::make_pair(operand, broadcast_dims);
|
|
|
|
}
|
|
|
|
|
2020-07-09 01:05:32 +08:00
|
|
|
// Builds the indexing maps of the broadcast linalg.generic: element 0 reads
// the operand, element 1 writes the result with the identity map over
// `result_shape.size()` loops.
//
// A rank-0 operand gets a map without results (scalar broadcast). Otherwise
// operand dimension i is indexed by loop dimension broadcast_dims[i]. An
// operand dimension of size 1 broadcast into a result dimension of another
// size is reported with an error on `op`, but the maps are still returned.
SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
                                          ArrayRef<int64_t> broadcast_dims,
                                          ArrayRef<int64_t> result_shape,
                                          MemRefType operand_type,
                                          Builder* b) const {
  const unsigned num_loops = result_shape.size();
  auto* ctx = b->getContext();

  if (operand_type.getRank() == 0) {
    // Scalar broadcast: the input is read without any index.
    return {AffineMap::get(num_loops, /*symbolCount=*/0, ctx),
            b->getMultiDimIdentityMap(num_loops)};
  }

  ArrayRef<int64_t> in_shape = operand_type.getShape();
  SmallVector<AffineExpr, 4> operand_exprs;
  operand_exprs.reserve(num_loops);
  for (size_t pos = 0, e = broadcast_dims.size(); pos < e; ++pos) {
    int target_dim = broadcast_dims[pos];
    if (in_shape[pos] == 1 && result_shape[target_dim] != 1) {
      op.emitOpError(
          "BroadcastInDimOp lowering to Linalg does not support size-1 "
          "dimensions expansion.");
    }
    operand_exprs.push_back(b->getAffineDimExpr(target_dim));
  }

  AffineMap input_map =
      AffineMap::get(num_loops, /*symbolCount=*/0, operand_exprs, ctx);
  return {input_map, b->getMultiDimIdentityMap(num_loops)};
}
|
|
|
|
};
|
|
|
|
|
|
|
|
template <typename OpTy, bool isLHLO = true>
|
|
|
|
class TransposeConverter
|
|
|
|
: public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
|
|
|
|
isLHLO> {
|
|
|
|
public:
|
|
|
|
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
|
|
|
|
isLHLO>::DataMovementOpConverter;
|
|
|
|
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
|
2020-10-24 03:22:21 +08:00
|
|
|
auto result_type =
|
2020-12-15 02:46:04 +08:00
|
|
|
GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
|
2020-10-24 03:22:21 +08:00
|
|
|
auto nloops = result_type.getRank();
|
|
|
|
SmallVector<AffineExpr, 2> input_exprs;
|
|
|
|
input_exprs.resize(result_type.getRank());
|
2020-07-07 04:57:00 +08:00
|
|
|
for (auto permutation : llvm::enumerate(op.permutation())) {
|
2020-10-24 03:22:21 +08:00
|
|
|
input_exprs[permutation.value().getZExtValue()] =
|
2020-07-07 04:57:00 +08:00
|
|
|
b->getAffineDimExpr(permutation.index());
|
|
|
|
}
|
|
|
|
return {
|
2020-10-24 03:22:21 +08:00
|
|
|
AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
|
2020-07-07 04:57:00 +08:00
|
|
|
b->getMultiDimIdentityMap(nloops)};
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
// Converts reshape ops that can be proven to be either a collapse of dimensions
|
|
|
|
// or expansion of dimensions of the operand.
|
|
|
|
template <typename OpTy, bool isLHLO = true>
|
|
|
|
class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
2020-10-24 03:22:21 +08:00
|
|
|
OpTy reshape_op, ArrayRef<Value> args,
|
2020-07-07 04:57:00 +08:00
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
2020-12-15 02:46:04 +08:00
|
|
|
if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
2021-01-28 21:36:37 +08:00
|
|
|
typename OpTy::Adaptor operands(args);
|
2020-10-24 03:22:21 +08:00
|
|
|
ShapedType operand_type =
|
2021-01-28 21:36:37 +08:00
|
|
|
operands.operand().getType().template cast<ShapedType>();
|
2020-12-15 02:46:04 +08:00
|
|
|
ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
|
|
|
|
// Compute the reassociation maps for the linalg operation.
|
2020-10-24 03:22:21 +08:00
|
|
|
ArrayRef<int64_t> src_shape =
|
|
|
|
(operand_type.getRank() > result_type.getRank()
|
|
|
|
? operand_type.getShape()
|
|
|
|
: result_type.getShape());
|
|
|
|
ArrayRef<int64_t> dst_shape =
|
|
|
|
(operand_type.getRank() > result_type.getRank()
|
|
|
|
? result_type.getShape()
|
|
|
|
: operand_type.getShape());
|
|
|
|
unsigned curr_src_dim = 0, curr_dst_dim = 0;
|
|
|
|
SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
|
|
|
|
dst_shape.size());
|
2021-01-28 21:36:37 +08:00
|
|
|
|
|
|
|
// First scan all dimensions in the source shapes to see whether we have a
|
|
|
|
// perfect case where consecutive dimensions in source are collapsed. For
|
|
|
|
// such case we can just generate one single linalg.reshape.
|
|
|
|
bool is_collapsing_source = true;
|
2020-10-24 03:22:21 +08:00
|
|
|
while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
|
|
|
|
int64_t dst_size = dst_shape[curr_dst_dim];
|
|
|
|
int64_t src_size = src_shape[curr_src_dim];
|
|
|
|
while (src_size < dst_size && curr_src_dim < src_shape.size()) {
|
|
|
|
reassociation_map[curr_dst_dim].push_back(
|
|
|
|
rewriter.getAffineDimExpr(curr_src_dim++));
|
|
|
|
src_size *= src_shape[curr_src_dim];
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
if (src_size == dst_size) {
|
|
|
|
reassociation_map[curr_dst_dim].push_back(
|
|
|
|
rewriter.getAffineDimExpr(curr_src_dim++));
|
|
|
|
// If the next dim in dst_shape is not 1, treat subsequent dims in
|
|
|
|
// src_shape which are 1 to be collapsed.
|
|
|
|
if (curr_dst_dim == dst_shape.size() - 1 ||
|
|
|
|
dst_shape[curr_dst_dim + 1] != 1) {
|
|
|
|
while (curr_src_dim < src_shape.size() &&
|
|
|
|
src_shape[curr_src_dim] == 1) {
|
|
|
|
reassociation_map[curr_dst_dim].push_back(
|
|
|
|
rewriter.getAffineDimExpr(curr_src_dim++));
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
2021-01-28 21:36:37 +08:00
|
|
|
is_collapsing_source = false;
|
2020-08-22 14:26:35 +08:00
|
|
|
break;
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
curr_dst_dim++;
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
|
2021-01-28 21:36:37 +08:00
|
|
|
is_collapsing_source = false;
|
2020-08-22 14:26:35 +08:00
|
|
|
|
2021-01-28 21:36:37 +08:00
|
|
|
// Otherwise, we need to first reduce all source dimensions into one and
|
|
|
|
// then expand to the destination dimensions.
|
|
|
|
if (!is_collapsing_source) {
|
2020-10-24 03:22:21 +08:00
|
|
|
auto get_identity_exprs = [&rewriter](int n) {
|
2020-08-22 14:26:35 +08:00
|
|
|
SmallVector<AffineExpr, 4> exprs;
|
|
|
|
for (int i = 0; i < n; ++i)
|
|
|
|
exprs.push_back(rewriter.getAffineDimExpr(i));
|
|
|
|
return exprs;
|
|
|
|
};
|
2020-10-24 03:22:21 +08:00
|
|
|
Location loc = reshape_op.getLoc();
|
|
|
|
int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
|
|
|
|
1, std::multiplies<int64_t>());
|
|
|
|
auto elem_type = operand_type.getElementType();
|
|
|
|
SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
|
2021-01-28 21:36:37 +08:00
|
|
|
// Use operand_type here because we need to collapse all operands
|
|
|
|
// dimensions.
|
|
|
|
get_identity_exprs(operand_type.getShape().size())};
|
2020-10-24 03:22:21 +08:00
|
|
|
SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
|
2021-01-28 21:36:37 +08:00
|
|
|
// Use result_type here because we need to expand to all result
|
|
|
|
// dimensions.
|
|
|
|
get_identity_exprs(result_type.getShape().size())};
|
2020-08-22 14:26:35 +08:00
|
|
|
|
|
|
|
if (isLHLO) {
|
2020-10-24 03:22:21 +08:00
|
|
|
auto collapsed_type = MemRefType::get({total_elems}, elem_type);
|
|
|
|
Value collapsed_op = rewriter.create<linalg::ReshapeOp>(
|
|
|
|
loc, collapsed_type, args[0], collapsing_map);
|
|
|
|
Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
|
|
|
|
loc, result_type, collapsed_op, expanding_map);
|
2021-02-15 20:17:20 +08:00
|
|
|
rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
|
|
|
|
args[1]);
|
2020-08-22 14:26:35 +08:00
|
|
|
} else {
|
2020-10-24 03:22:21 +08:00
|
|
|
auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
|
|
|
|
Value collapsed_op = rewriter.create<linalg::TensorReshapeOp>(
|
|
|
|
loc, collapsed_type, args[0], collapsing_map);
|
2020-08-22 14:26:35 +08:00
|
|
|
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
2020-10-24 03:22:21 +08:00
|
|
|
reshape_op, result_type, collapsed_op, expanding_map);
|
2020-08-22 14:26:35 +08:00
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
if (isLHLO) {
|
2020-10-24 03:22:21 +08:00
|
|
|
Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
|
|
|
|
reshape_op.getLoc(), result_type, args[0], reassociation_map);
|
2021-02-15 20:17:20 +08:00
|
|
|
rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
|
|
|
|
args[1]);
|
2020-07-07 04:57:00 +08:00
|
|
|
} else {
|
|
|
|
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
2020-10-24 03:22:21 +08:00
|
|
|
reshape_op, result_type, args[0], reassociation_map);
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-07-24 00:18:01 +08:00
|
|
|
template <typename OpTy, bool isLHLO = true>
|
|
|
|
class IotaConverter : public OpConversionPattern<OpTy> {
|
2020-07-07 04:57:00 +08:00
|
|
|
public:
|
2020-07-24 00:18:01 +08:00
|
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
2020-10-24 03:22:21 +08:00
|
|
|
OpTy iota_op, ArrayRef<Value> args,
|
2020-07-07 04:57:00 +08:00
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
2020-12-15 02:46:04 +08:00
|
|
|
ShapedType result_shaped_type = GetHloOpResultType<isLHLO>(iota_op);
|
2020-10-24 03:22:21 +08:00
|
|
|
if (!result_shaped_type) return failure();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
auto result_element_type = result_shaped_type.getElementType();
|
|
|
|
if (!result_element_type.isSignlessIntOrFloat()) return failure();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Construct the indexing maps needed for linalg.generic ops.
|
2020-10-24 03:22:21 +08:00
|
|
|
unsigned nloops = result_shaped_type.getRank();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-12-24 15:53:08 +08:00
|
|
|
Location loc = iota_op.getLoc();
|
2021-02-04 07:02:21 +08:00
|
|
|
auto dyn_sizes = isLHLO
|
|
|
|
? SmallVector<Value, 2>()
|
|
|
|
: ExtractDynamicSizes(rewriter, loc,
|
|
|
|
GetResultValue<isLHLO>(iota_op));
|
2020-10-24 03:22:21 +08:00
|
|
|
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
|
2020-12-24 15:53:08 +08:00
|
|
|
loc,
|
2020-09-23 00:06:55 +08:00
|
|
|
/*resultTensorTypes=*/
|
2020-10-24 03:22:21 +08:00
|
|
|
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
|
2020-09-23 00:06:55 +08:00
|
|
|
/*inputs=*/ValueRange{},
|
2020-12-24 15:53:08 +08:00
|
|
|
/*outputBuffers=*/
|
|
|
|
isLHLO ? ValueRange{args}
|
2021-02-04 07:02:21 +08:00
|
|
|
: ValueRange{GetInitTensor(rewriter, loc, result_shaped_type,
|
|
|
|
dyn_sizes)},
|
2020-07-07 04:57:00 +08:00
|
|
|
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
|
|
|
GetNParallelLoopsAttrs(nloops),
|
2020-10-24 03:22:21 +08:00
|
|
|
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
|
2020-07-07 04:57:00 +08:00
|
|
|
ValueRange args) {
|
2020-10-24 03:22:21 +08:00
|
|
|
Value cast_op = nested_builder.create<IndexCastOp>(
|
|
|
|
nested_loc, ivs[iota_op.iota_dimension()],
|
|
|
|
nested_builder.getIntegerType(
|
|
|
|
result_element_type.getIntOrFloatBitWidth()));
|
|
|
|
if (result_element_type.template isa<FloatType>()) {
|
|
|
|
cast_op = nested_builder.create<SIToFPOp>(nested_loc, cast_op,
|
|
|
|
result_element_type);
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
|
2020-07-07 04:57:00 +08:00
|
|
|
});
|
2020-07-24 00:18:01 +08:00
|
|
|
if (isLHLO)
|
2020-10-24 03:22:21 +08:00
|
|
|
rewriter.replaceOp(iota_op, llvm::None);
|
2020-07-24 00:18:01 +08:00
|
|
|
else
|
2020-10-24 03:22:21 +08:00
|
|
|
rewriter.replaceOp(iota_op, linalg_op.result_tensors());
|
2020-07-07 04:57:00 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-12-10 15:24:23 +08:00
|
|
|
template <typename OpTy>
|
|
|
|
class ConstConverter : public OpConversionPattern<OpTy> {
|
2020-07-07 04:57:00 +08:00
|
|
|
public:
|
2020-12-10 15:24:23 +08:00
|
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
2020-12-10 15:24:23 +08:00
|
|
|
OpTy const_op, ArrayRef<Value> /*args*/,
|
2020-07-07 04:57:00 +08:00
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
2020-12-10 15:24:23 +08:00
|
|
|
Location loc = const_op.getLoc();
|
|
|
|
auto value_attr = const_op.value().template cast<DenseElementsAttr>();
|
2020-10-24 03:22:21 +08:00
|
|
|
if (value_attr.getType().getRank() != 0) return failure();
|
2020-12-10 15:24:23 +08:00
|
|
|
ReplaceConstOp(loc, const_op, value_attr, rewriter);
|
2020-07-07 04:57:00 +08:00
|
|
|
return success();
|
|
|
|
}
|
2020-12-10 15:24:23 +08:00
|
|
|
|
|
|
|
private:
|
|
|
|
void ReplaceConstOp(Location loc, mhlo::ConstOp op,
|
|
|
|
DenseElementsAttr value_attr,
|
|
|
|
ConversionPatternRewriter& rewriter) const {
|
|
|
|
Value std_tensor_const = rewriter.create<mlir::ConstantOp>(loc, value_attr);
|
|
|
|
rewriter.replaceOp(op, {std_tensor_const});
|
|
|
|
}
|
|
|
|
void ReplaceConstOp(Location loc, lmhlo::ConstOp op,
|
|
|
|
DenseElementsAttr value_attr,
|
|
|
|
ConversionPatternRewriter& rewriter) const {
|
|
|
|
Value std_scalar_const =
|
|
|
|
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
|
|
|
|
rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, op.getOperand(),
|
|
|
|
llvm::None);
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
}
|
2020-07-07 04:57:00 +08:00
|
|
|
};
|
|
|
|
|
2020-10-29 07:37:38 +08:00
|
|
|
// Lowers lmhlo.reduce to linalg.generic using the first operand/output.
//
// The output buffer is filled with the (loaded) init value first. The generic
// op iterates over the full operand rank; dimensions listed in `dimensions`
// become reduction iterators and are dropped from the output map. The
// buffer-based reduction body is inlined into the generic op and adapted to
// scalar block arguments through alloca'd temporaries.
class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = reduce_op.getLoc();
    lmhlo::ReduceOp::Adaptor adaptor(args);
    auto input_type =
        adaptor.operands()[0].getType().template dyn_cast<ShapedType>();
    if (!input_type || !input_type.hasRank()) {
      emitError(loc, "lhlo to linalg conversion expects known-rank args");
      return failure();
    }

    // Seed the output buffer with the init value.
    Value init = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
    rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init);

    DenseIntElementsAttr dims_attr = reduce_op.dimensions();
    SmallVector<int, 4> reduced_dims;
    for (const auto& d : dims_attr.getIntValues())
      reduced_dims.push_back(d.getSExtValue());

    auto* ctx = rewriter.getContext();
    SmallVector<AffineExpr, 2> in_exprs;
    SmallVector<AffineExpr, 2> out_exprs;
    SmallVector<StringRef, 4> iterator_types;
    const int rank = input_type.getRank();
    for (int d = 0; d != rank; ++d) {
      AffineExpr dim_expr = mlir::getAffineDimExpr(d, ctx);
      in_exprs.push_back(dim_expr);
      if (llvm::is_contained(reduced_dims, d)) {
        iterator_types.push_back(getReductionIteratorTypeName());
      } else {
        iterator_types.push_back(getParallelIteratorTypeName());
        out_exprs.push_back(dim_expr);
      }
    }

    auto maps = AffineMap::inferFromExprList({in_exprs, out_exprs});

    auto generic = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/ArrayRef<Type>{},
        /*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(), maps,
        iterator_types);
    rewriter.inlineRegionBefore(reduce_op.body(), generic.region(),
                                generic.region().end());
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Block* body = generic.getBody();
      rewriter.setInsertionPoint(&body->front());

      // The inlined region works on buffers while linalg.generic passes
      // scalars, so give the old body three scratch buffers to operate on.
      auto buf_type = body->getArgument(0).getType().cast<MemRefType>();
      Value lhs_buf = rewriter.create<AllocaOp>(loc, buf_type);
      Value rhs_buf = rewriter.create<AllocaOp>(loc, buf_type);
      Value out_buf = rewriter.create<AllocaOp>(loc, buf_type);

      // Rewrite the block signature
      //   (memref<X>, memref<X>, memref<X>) -> ()
      // into
      //   (X, X) -> X
      TypeConverter::SignatureConversion conversion(3);
      conversion.remapInput(0, lhs_buf);
      conversion.remapInput(1, rhs_buf);
      conversion.remapInput(2, out_buf);
      conversion.addInputs(
          {buf_type.getElementType(), buf_type.getElementType()});
      Block* entry =
          rewriter.applySignatureConversion(&generic.region(), conversion);

      // Spill the scalar arguments into the scratch buffers.
      rewriter.setInsertionPointAfter(out_buf.getDefiningOp());
      rewriter.create<StoreOp>(loc, entry->getArgument(0), lhs_buf);
      rewriter.create<StoreOp>(loc, entry->getArgument(1), rhs_buf);
      rewriter.replaceOp(entry->getTerminator(), {});

      // Reload the computed value and yield it.
      rewriter.setInsertionPointToEnd(entry);
      auto reduced = rewriter.create<LoadOp>(loc, out_buf);
      rewriter.create<linalg::YieldOp>(loc, ValueRange{reduced});
    }

    rewriter.replaceOp(reduce_op, generic.getOperation()->getResults());
    return success();
  }
};
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
// TODO(b/156787842): Support the lowering for dynamic shapes.
|
|
|
|
template <typename OpTy, bool isLHLO = true>
|
|
|
|
class ReverseConverter
|
|
|
|
: public DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
|
|
|
|
isLHLO> {
|
|
|
|
public:
|
|
|
|
using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
|
|
|
|
isLHLO>::DataMovementOpConverter;
|
|
|
|
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
|
2020-10-24 03:22:21 +08:00
|
|
|
auto result_type =
|
2020-12-15 02:46:04 +08:00
|
|
|
GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
|
2020-10-24 03:22:21 +08:00
|
|
|
auto nloops = result_type.getRank();
|
|
|
|
SmallVector<AffineExpr, 2> input_exprs;
|
|
|
|
input_exprs.reserve(nloops);
|
2020-07-07 04:57:00 +08:00
|
|
|
for (int i = 0; i < nloops; ++i)
|
2020-10-24 03:22:21 +08:00
|
|
|
input_exprs.push_back(b->getAffineDimExpr(i));
|
2020-07-07 04:57:00 +08:00
|
|
|
for (auto dim : op.dimensions()) {
|
|
|
|
int i = dim.getZExtValue();
|
2020-10-24 03:22:21 +08:00
|
|
|
if (result_type.isDynamicDim(i)) return {};
|
|
|
|
int n = result_type.getShape()[i];
|
|
|
|
input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
return {
|
2020-10-24 03:22:21 +08:00
|
|
|
AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
|
2020-07-07 04:57:00 +08:00
|
|
|
b->getMultiDimIdentityMap(nloops)};
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-07-09 01:05:32 +08:00
|
|
|
// Lowers lmhlo.slice to a subview of operand 0 followed by linalg.copy into
// operand 1. Per dimension: offset = start, size = limit - start,
// stride = stride.
class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
 public:
  using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::SliceOp slice_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = slice_op.getLoc();
    auto input_type =
        slice_op.getOperand(0).getType().template dyn_cast<ShapedType>();
    if (!input_type || !input_type.hasRank()) {
      emitError(loc, "lhlo to linalg conversion expects known-rank args");
      return failure();
    }

    SmallVector<OpFoldResult, 3> offsets, sizes, strides;
    for (int d = 0, rank = input_type.getRank(); d < rank; ++d) {
      int64_t start = slice_op.start_indices().getValue<int64_t>(d);
      int64_t limit = slice_op.limit_indices().getValue<int64_t>(d);
      int64_t step = slice_op.strides().getValue<int64_t>(d);
      offsets.push_back(rewriter.getI64IntegerAttr(start));
      sizes.push_back(rewriter.getI64IntegerAttr(limit - start));
      strides.push_back(rewriter.getI64IntegerAttr(step));
    }

    auto view = rewriter.create<SubViewOp>(loc, slice_op.getOperand(0),
                                           offsets, sizes, strides);
    rewriter.create<linalg::CopyOp>(loc, view, slice_op.getOperand(1));
    rewriter.eraseOp(slice_op);
    return success();
  }
};
|
|
|
|
|
2021-01-12 02:33:14 +08:00
|
|
|
// Kind of linalg named op chosen for an mhlo.dot (see GetDotOperationType).
// The numeric values are spelled out explicitly.
enum class DotOperationType {
  kVectorDot = 0,     // vector . vector
  kMatrixVector = 1,  // matrix . vector
  kMatrixMatrix = 2,  // matrix . matrix
  kUnsupported = 3    // no matching named op
};
|
|
|
|
|
|
|
|
// Classifies `dot_op` by the ranks of its operands. Two dimension sizes are
// compatible when either is dynamic or they are equal:
//   rank-1 lhs, rank-1 rhs, lhs[0] ~ rhs[0] -> kVectorDot
//   rank-2 lhs, rank-1 rhs, lhs[1] ~ rhs[0] -> kMatrixVector
//   rank-2 lhs, rank-2 rhs, lhs[1] ~ rhs[0] -> kMatrixMatrix
// and kUnsupported otherwise.
DotOperationType GetDotOperationType(mhlo::DotOp dot_op) {
  ArrayRef<int64_t> lhs_shape =
      dot_op.lhs().getType().cast<ShapedType>().getShape();
  ArrayRef<int64_t> rhs_shape =
      dot_op.rhs().getType().cast<ShapedType>().getShape();
  auto shape_matches = [](int64_t a, int64_t b) {
    return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize ||
           a == b;
  };
  if (lhs_shape.size() == 1 && rhs_shape.size() == 1 &&
      shape_matches(lhs_shape[0], rhs_shape[0])) {
    return DotOperationType::kVectorDot;
  }
  if (lhs_shape.size() == 2 && rhs_shape.size() == 1 &&
      shape_matches(lhs_shape[1], rhs_shape[0])) {
    return DotOperationType::kMatrixVector;
  }
  // Both operands must be rank 2; checking the lhs rank also keeps
  // `lhs_shape[1]` in bounds.
  if (lhs_shape.size() == 2 && rhs_shape.size() == 2 &&
      shape_matches(lhs_shape[1], rhs_shape[0])) {
    return DotOperationType::kMatrixMatrix;
  }
  return DotOperationType::kUnsupported;
}
|
|
|
|
|
2021-02-04 07:02:21 +08:00
|
|
|
// Returns `dim` values for the dynamic sizes of a dot's init tensor, in
// result-dimension order: for kMatrixMatrix lhs dim 0 then rhs dim 1, for
// kMatrixVector lhs dim 0, each only when dynamic; nothing for other kinds.
SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
                                                 Value lhs, Value rhs,
                                                 DotOperationType type) {
  SmallVector<Value, 2> dyn_shape;
  // Appends dim(`source`, `idx`) when that dimension of `source` is dynamic.
  auto append_if_dynamic = [&](Value source, unsigned idx) {
    if (source.getType().cast<ShapedType>().isDynamicDim(idx))
      dyn_shape.push_back(b.create<DimOp>(loc, source, idx));
  };

  if (type == DotOperationType::kMatrixMatrix) {
    append_if_dynamic(lhs, 0);
    append_if_dynamic(rhs, 1);
  } else if (type == DotOperationType::kMatrixVector) {
    append_if_dynamic(lhs, 0);
  }
  return dyn_shape;
}
|
|
|
|
|
|
|
|
// Lowers mhlo.dot on tensors to linalg.matmul, linalg.matvec or linalg.dot
// (chosen by GetDotOperationType), accumulating into a zero-filled init
// tensor. Unsupported operand combinations produce an error.
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
 public:
  using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::DotOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
      return failure();
    }
    Location loc = op.getLoc();
    mhlo::DotOp::Adaptor adaptor(args);
    Type result_type = op.getResult().getType();
    auto shaped_type = result_type.cast<ShapedType>();
    const DotOperationType dot_kind = GetDotOperationType(op);

    // Build the zero-filled accumulator.
    Value zero = rewriter.create<ConstantOp>(
        loc, rewriter.getZeroAttr(shaped_type.getElementType()));
    SmallVector<Value, 2> dyn_sizes = GetDotOpInitTensorDynSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), dot_kind);
    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_sizes);
    Value acc =
        rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);

    linalg::LinalgOp named_op;
    if (dot_kind == DotOperationType::kMatrixMatrix) {
      named_op = rewriter.create<linalg::MatmulOp>(
          loc, TypeRange{result_type},
          ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{acc});
    } else if (dot_kind == DotOperationType::kMatrixVector) {
      named_op = rewriter.create<linalg::MatvecOp>(
          loc, TypeRange{result_type},
          ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{acc});
    } else if (dot_kind == DotOperationType::kVectorDot) {
      named_op = rewriter.create<linalg::DotOp>(
          loc, TypeRange{result_type},
          ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{acc});
    } else {
      return op.emitError("unsupported dot operation type");
    }
    rewriter.replaceOp(op, named_op->getResults());
    return success();
  }
};
|
|
|
|
|
2021-01-13 14:07:29 +08:00
|
|
|
// Returns `dim` values for the dynamic dimensions of a rank-3 batch-matmul
// result: result dims 0 and 1 are read from lhs dims 0 and 1, result dim 2
// from rhs dim 2.
SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
    OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
  SmallVector<Value, 8> dyn_shape;
  // Appends dim(`source`, `source_dim`) if result dimension `result_dim` is
  // dynamic.
  auto append_if_dynamic = [&](unsigned result_dim, Value source,
                               unsigned source_dim) {
    if (result_type.isDynamicDim(result_dim))
      dyn_shape.push_back(b.create<DimOp>(loc, source, source_dim));
  };
  append_if_dynamic(0, lhs, 0);
  append_if_dynamic(1, lhs, 1);
  append_if_dynamic(2, rhs, 2);
  return dyn_shape;
}
|
|
|
|
|
|
|
|
class DotGeneralOpOnTensorsConversion
|
|
|
|
: public OpConversionPattern<mhlo::DotGeneralOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
|
|
|
|
// Rewrites a tensor-semantics `mhlo.dot_general` into `linalg.batch_matmul`.
//
// Only one dimension numbering is accepted: batching dimensions exactly {0}
// on both operands, lhs contracting dimensions exactly {2} and rhs
// contracting dimensions exactly {1}. Anything else is reported as a match
// failure. The batch matmul accumulates into a zero-filled init tensor that
// has the result's shape.
LogicalResult matchAndRewrite(
    mhlo::DotGeneralOp op, ArrayRef<Value> args,
    ConversionPatternRewriter& rewriter) const final {
  if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op))
    return failure();

  mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers();
  auto lhs_batching_dims =
      Extract1DVector(dim_numbers.lhs_batching_dimensions());
  auto rhs_batching_dims =
      Extract1DVector(dim_numbers.rhs_batching_dimensions());
  auto lhs_contracting_dims =
      Extract1DVector(dim_numbers.lhs_contracting_dimensions());
  auto rhs_contracting_dims =
      Extract1DVector(dim_numbers.rhs_contracting_dimensions());

  // True iff `dims` holds exactly one entry and that entry equals `expected`.
  auto is_single_dim = [](const auto& dims, int64_t expected) {
    return dims.size() == 1 && dims[0] == expected;
  };
  if (!is_single_dim(lhs_batching_dims, 0))
    return rewriter.notifyMatchFailure(
        op, "expected lhs batching dimensions exactly {0}");
  if (!is_single_dim(rhs_batching_dims, 0))
    return rewriter.notifyMatchFailure(
        op, "expected rhs batching dimensions exactly {0}");
  if (!is_single_dim(lhs_contracting_dims, 2))
    return rewriter.notifyMatchFailure(
        op, "expected lhs contracting dimensions exactly {2}");
  if (!is_single_dim(rhs_contracting_dims, 1))
    return rewriter.notifyMatchFailure(
        op, "expected rhs contracting dimensions exactly {1}");

  Location loc = op.getLoc();
  mhlo::DotGeneralOp::Adaptor adaptor(args);
  Type result_type = op.getResult().getType();
  auto result_shaped_type = result_type.cast<ShapedType>();

  // Materialize the accumulator: an init tensor of the result shape (with
  // dynamic extents queried from the operands) filled with zero.
  SmallVector<Value, 8> dynamic_sizes = GetDotGeneralOpInitTensorDynSizes(
      rewriter, loc, adaptor.lhs(), adaptor.rhs(), result_shaped_type);
  Value zero = rewriter.create<ConstantOp>(
      loc, rewriter.getZeroAttr(result_shaped_type.getElementType()));
  auto init_tensor =
      GetInitTensor(rewriter, loc, result_shaped_type, dynamic_sizes);
  Value accumulator =
      rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);

  auto batch_matmul = rewriter.create<linalg::BatchMatmulOp>(
      loc, /*resultTensorTypes=*/TypeRange{result_type},
      /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
      /*outputBuffers=*/ValueRange{accumulator});
  rewriter.replaceOp(op, batch_matmul.getResults());
  return success();
}
|
|
|
|
};
|
|
|
|
|
2021-01-28 21:44:49 +08:00
|
|
|
template <typename OpTy>
|
|
|
|
struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> {
|
|
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
|
|
|
LogicalResult matchAndRewrite(
|
|
|
|
OpTy op, ArrayRef<Value> args,
|
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
|
|
|
// Only convert the body of reduction ops to std ops.
|
|
|
|
auto parent_op = op.getOperation()->getParentRegion()->getParentOp();
|
|
|
|
if (!isa<mhlo::ReduceOp, linalg::GenericOp, linalg::IndexedGenericOp>(
|
|
|
|
parent_op)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
if (!op.getResult().getType().template isa<TensorType>()) return failure();
|
|
|
|
if (llvm::all_of(args, [](Value arg) {
|
|
|
|
return arg.getType().template isa<TensorType>();
|
|
|
|
})) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(op, args[0].getType(),
|
|
|
|
args, &rewriter);
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
// Returns the dynamic sizes needed to build the init tensor of a reduction.
//
// Dimensions of `arg` listed in `reduction_dims` are skipped. Each remaining
// dimension of `arg` is paired, in order, with the next dimension of
// `result_type`. For every paired result dimension that is dynamic, a DimOp
// querying that dimension of `arg` is created at `loc`, and its value is
// appended to the returned list.
//
// `arg` must have a RankedTensorType (it is cast unconditionally).
SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
    OpBuilder& b, Location loc, Value arg, ShapedType result_type,
    ArrayRef<int64_t> reduction_dims) {
  // Keep the dimensions in their declared 64-bit type instead of narrowing.
  llvm::SmallSetVector<int64_t, 4> reduced_dims;
  for (int64_t dim : reduction_dims) reduced_dims.insert(dim);

  SmallVector<Value, 8> dyn_shape;
  int64_t rank = arg.getType().cast<RankedTensorType>().getRank();
  unsigned result_dim = 0;
  for (int64_t i = 0; i < rank; ++i) {
    if (reduced_dims.count(i)) continue;
    // `result_dim` advances for every kept (parallel) dimension, dynamic or not.
    if (result_type.isDynamicDim(result_dim++))
      dyn_shape.push_back(b.create<DimOp>(loc, arg, i));
  }
  return dyn_shape;
}
|
|
|
|
|
|
|
|
class ReduceRegionReturnOpConversion
|
|
|
|
: public OpConversionPattern<mhlo::ReturnOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern;
|
|
|
|
LogicalResult matchAndRewrite(
|
|
|
|
mhlo::ReturnOp op, ArrayRef<Value> args,
|
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
|
|
|
rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, args);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
// Lowers a tensor-semantics `mhlo.reduce` with exactly one input and one init
// value to a `linalg.generic` on tensors.
//
// The input is read through a transposing indexing map that moves all
// reduction loops innermost. The output is an init tensor of the result shape
// filled with the init value. The reduce body region is moved into the
// generic op and its block signature is converted from tensors to element
// types.
class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern<mhlo::ReduceOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReduceOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = op.getLoc();
    mhlo::ReduceOp::Adaptor adaptor(args);

    // Only the single-input form (one operand + one init value) is supported.
    // Report a match failure rather than emitting an error: an error emitted
    // from a pattern becomes a hard diagnostic even though partial conversion
    // is allowed to leave the op in place.
    if (op.getNumOperands() != 2) {
      return rewriter.notifyMatchFailure(op, "expects exactly two operands");
    }

    Value src = adaptor.operands()[0];
    auto src_type = src.getType().cast<ShapedType>();
    // ShapedType::getRank() must not be queried on an unranked type.
    if (!src_type.hasRank()) {
      return rewriter.notifyMatchFailure(op, "expects known-rank args");
    }
    int src_rank = src_type.getRank();
    // Rank-0 inputs were never lowered by this pattern; keep rejecting them.
    if (src_rank == 0) {
      return rewriter.notifyMatchFailure(op, "expects non-zero-rank args");
    }

    // If the init value is a constant, inline it as a scalar constant.
    // Otherwise read the scalar out of its tensor with tensor.extract (no
    // indices, so the init value is expected to be a 0-d tensor).
    Value init_value = adaptor.init_values()[0];
    Attribute init_const_val = GetInitValueAsConst(init_value);
    if (init_const_val) {
      init_value = rewriter.create<ConstantOp>(
          init_value.getDefiningOp()->getLoc(), init_const_val);
    } else {
      init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
    }

    // Indexing maps for the generic op: [src, dst]. `src` is transposed so
    // that the reduction loops are innermost, which makes it easier to fully
    // utilize processors.
    SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());
    SmallVector<AffineMap, 3> indexing_maps;
    indexing_maps.emplace_back(GetTransposeMapForReduction(
        rewriter.getContext(), src_rank, reduction_dims));

    // `dst` drops the reduction loops. Since they are now the innermost
    // loops, `dst` keeps the leading `src_rank - reduction_dims.size()` loop
    // dimensions. No inverse permutation is needed here because they are
    // the same.
    SmallVector<AffineExpr, 4> exprs;
    int num_parallel_dims =
        src_rank - static_cast<int>(reduction_dims.size());
    for (int i = 0; i < num_parallel_dims; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    indexing_maps.emplace_back(AffineMap::get(src_rank, /*symbolCount=*/0,
                                              exprs, rewriter.getContext()));

    // Output accumulator: init tensor of the result shape, filled with the
    // init value.
    auto result_type = op.getResult(0).getType().cast<ShapedType>();
    SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
        rewriter, loc, src, result_type, reduction_dims);
    auto init_tensor = GetInitTensor(rewriter, loc, result_type, dyn_shape);
    Value filled_tensor =
        rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
            .getResult(0);

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/op.getResultTypes(),
        /*inputs=*/ValueRange{src},
        /*outputBuffers=*/ValueRange{filled_tensor}, indexing_maps,
        GetParallelAndReductionIterators(src_rank, reduction_dims.size()));

    // Convert the signature of the body. The reduce op region has a
    // signature (lhs, rhs) -> output, all of the same tensor type t. The
    // block's two arguments are converted to the element type, e.g.
    // "(tensor<f32>, tensor<f32>) -> tensor<f32>" becomes a block taking
    // "(f32, f32)".
    Region& region = linalg_op.region();
    rewriter.inlineRegionBefore(op.body(), region, region.end());
    TypeConverter::SignatureConversion signature_converter(2);
    signature_converter.addInputs(0, src_type.getElementType());
    signature_converter.addInputs(1, src_type.getElementType());
    rewriter.applySignatureConversion(&region, signature_converter);
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
// Registers every LHLO -> Linalg conversion pattern into `patterns`.
//
// Patterns are inserted in a fixed order, grouped by kind:
//   1. structural converters (broadcast, constants, convolution, iota),
//   2. element-wise converters (one PointwiseToLinalgConverter per op),
//   3. remaining converters (reduce, reshape, reverse, scalar, slice,
//      transpose).
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
  patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
                   ConstConverter<lmhlo::ConstOp>,
                   ConvToLinalgConverter,
                   IotaConverter<lmhlo::IotaOp>,
                   LhloBroadcastInDimConverter>(context);

  patterns->insert<PointwiseToLinalgConverter<lmhlo::AbsOp>,
                   PointwiseToLinalgConverter<lmhlo::AddOp>,
                   PointwiseToLinalgConverter<lmhlo::AndOp>,
                   PointwiseToLinalgConverter<lmhlo::Atan2Op>,
                   PointwiseToLinalgConverter<lmhlo::CeilOp>,
                   PointwiseToLinalgConverter<lmhlo::ClampOp>,
                   PointwiseToLinalgConverter<lmhlo::CompareOp>,
                   PointwiseToLinalgConverter<lmhlo::ComplexOp>,
                   PointwiseToLinalgConverter<lmhlo::ConvertOp>,
                   // TODO(ataei): Remove this pattern, CopyOp is folded away.
                   PointwiseToLinalgConverter<lmhlo::CopyOp>,
                   PointwiseToLinalgConverter<lmhlo::CosOp>,
                   PointwiseToLinalgConverter<lmhlo::DivOp>,
                   PointwiseToLinalgConverter<lmhlo::ExpOp>,
                   PointwiseToLinalgConverter<lmhlo::FloorOp>,
                   PointwiseToLinalgConverter<lmhlo::ImagOp>,
                   PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
                   PointwiseToLinalgConverter<lmhlo::LogOp>,
                   PointwiseToLinalgConverter<lmhlo::LogisticOp>,
                   PointwiseToLinalgConverter<lmhlo::Log1pOp>,
                   PointwiseToLinalgConverter<lmhlo::MaxOp>,
                   PointwiseToLinalgConverter<lmhlo::MinOp>,
                   PointwiseToLinalgConverter<lmhlo::MulOp>,
                   PointwiseToLinalgConverter<lmhlo::NegOp>,
                   PointwiseToLinalgConverter<lmhlo::NotOp>,
                   PointwiseToLinalgConverter<lmhlo::OrOp>,
                   PointwiseToLinalgConverter<lmhlo::PowOp>,
                   PointwiseToLinalgConverter<lmhlo::RealOp>,
                   PointwiseToLinalgConverter<lmhlo::RemOp>,
                   PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
                   PointwiseToLinalgConverter<lmhlo::SelectOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftLeftOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftRightArithmeticOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftRightLogicalOp>,
                   PointwiseToLinalgConverter<lmhlo::SignOp>,
                   PointwiseToLinalgConverter<lmhlo::SinOp>,
                   PointwiseToLinalgConverter<lmhlo::SqrtOp>,
                   PointwiseToLinalgConverter<lmhlo::SubOp>,
                   PointwiseToLinalgConverter<lmhlo::TanhOp>,
                   PointwiseToLinalgConverter<lmhlo::XorOp>>(context);

  patterns->insert<ReduceConverter,
                   ReshapeOpConverter<lmhlo::ReshapeOp>,
                   ReverseConverter<lmhlo::ReverseOp>,
                   ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
                   ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
                   SliceConverter,
                   TransposeConverter<lmhlo::TransposeOp>>(context);
}
|
|
|
|
|
|
|
|
// Converts LHLO ops to Linalg generic.
|
2020-07-09 01:05:32 +08:00
|
|
|
// Sample result for lmhlo::AddOp.
|
2020-07-07 04:57:00 +08:00
|
|
|
//
|
2020-07-09 01:05:32 +08:00
|
|
|
// "lmhlo.add"(%arg1, %arg2, %out) :
|
2020-07-07 04:57:00 +08:00
|
|
|
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
|
|
//
|
|
|
|
// will be converted to
|
|
|
|
//
|
|
|
|
// #map0 = (d0, d1) -> (d0, d1)
|
|
|
|
// "linalg.generic"(%arg1, %arg2, %out) ( {
|
|
|
|
// ^bb0(%arg4: f32, %arg5: f32):
|
|
|
|
// %0 = addf %arg4, %arg5 : f32
|
|
|
|
// "linalg.yield"(%0) : (f32) -> ()
|
|
|
|
// }) {
|
|
|
|
// indexing_maps = [#map0, #map0, #map0],
|
|
|
|
// iterator_types = ["parallel", "parallel"],
|
|
|
|
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
2020-07-29 07:12:08 +08:00
|
|
|
struct LhloLegalizeToLinalgPass
|
|
|
|
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
|
2020-08-26 11:30:05 +08:00
|
|
|
void getDependentDialects(DialectRegistry& registry) const override {
|
2021-02-13 00:30:51 +08:00
|
|
|
registry.insert<AffineDialect, linalg::LinalgDialect, math::MathDialect>();
|
2020-08-26 11:30:05 +08:00
|
|
|
}
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
void runOnFunction() override {
|
|
|
|
OwningRewritePatternList patterns;
|
|
|
|
ConversionTarget target(getContext());
|
2021-01-21 17:21:23 +08:00
|
|
|
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
2021-02-13 00:30:51 +08:00
|
|
|
math::MathDialect, StandardOpsDialect,
|
|
|
|
AffineDialect>();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
auto func = getFunction();
|
|
|
|
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
2020-10-27 21:55:28 +08:00
|
|
|
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
2020-07-07 04:57:00 +08:00
|
|
|
signalPassFailure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-07-29 07:12:08 +08:00
|
|
|
struct HloLegalizeToLinalgPass
|
|
|
|
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
2020-08-26 11:30:05 +08:00
|
|
|
void getDependentDialects(DialectRegistry& registry) const override {
|
2021-01-21 17:21:23 +08:00
|
|
|
registry.insert<linalg::LinalgDialect, scf::SCFDialect,
|
2021-02-13 00:30:51 +08:00
|
|
|
complex::ComplexDialect, math::MathDialect>();
|
2020-08-26 11:30:05 +08:00
|
|
|
}
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
void runOnFunction() override {
|
|
|
|
OwningRewritePatternList patterns;
|
|
|
|
ConversionTarget target(getContext());
|
2021-01-21 17:21:23 +08:00
|
|
|
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
2021-02-13 00:30:51 +08:00
|
|
|
math::MathDialect, StandardOpsDialect,
|
|
|
|
tensor::TensorDialect, scf::SCFDialect>();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
auto func = getFunction();
|
2020-07-07 12:51:24 +08:00
|
|
|
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
2020-10-27 21:55:28 +08:00
|
|
|
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
2020-07-07 04:57:00 +08:00
|
|
|
signalPassFailure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
2020-07-09 01:05:32 +08:00
|
|
|
namespace lmhlo {
|
2020-07-07 04:57:00 +08:00
|
|
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
|
2020-07-29 07:12:08 +08:00
|
|
|
return std::make_unique<LhloLegalizeToLinalgPass>();
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-07-09 01:05:32 +08:00
|
|
|
} // namespace lmhlo
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-07-07 12:51:24 +08:00
|
|
|
namespace mhlo {
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Registers every MHLO (tensor) -> Linalg conversion pattern into `patterns`.
//
// Patterns are inserted in a fixed order, grouped by kind:
//   1. structural converters (broadcasts, constants, iota),
//   2. element-wise converters (one PointwiseToLinalgConverter per op),
//   3. shape/contraction/reduction converters,
//   4. patterns that rewrite the scalar body of a lowered reduction.
void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                          OwningRewritePatternList* patterns) {
  patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
                   ConstConverter<mhlo::ConstOp>,
                   HloDynamicBroadcastInDimConverter,
                   HloBroadcastInDimConverter,
                   IotaConverter<mhlo::IotaOp, false>>(context);

  patterns->insert<PointwiseToLinalgConverter<mhlo::AbsOp, false>,
                   PointwiseToLinalgConverter<mhlo::AddOp, false>,
                   PointwiseToLinalgConverter<mhlo::AndOp, false>,
                   PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
                   PointwiseToLinalgConverter<mhlo::CeilOp, false>,
                   PointwiseToLinalgConverter<mhlo::ClampOp, false>,
                   PointwiseToLinalgConverter<mhlo::CompareOp, false>,
                   PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
                   PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
                   PointwiseToLinalgConverter<mhlo::CopyOp, false>,
                   PointwiseToLinalgConverter<mhlo::CosOp, false>,
                   PointwiseToLinalgConverter<mhlo::DivOp, false>,
                   PointwiseToLinalgConverter<mhlo::ExpOp, false>,
                   PointwiseToLinalgConverter<mhlo::FloorOp, false>,
                   PointwiseToLinalgConverter<mhlo::ImagOp, false>,
                   PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
                   PointwiseToLinalgConverter<mhlo::LogOp, false>,
                   PointwiseToLinalgConverter<mhlo::LogisticOp, false>,
                   PointwiseToLinalgConverter<mhlo::Log1pOp, false>,
                   PointwiseToLinalgConverter<mhlo::MaxOp, false>,
                   PointwiseToLinalgConverter<mhlo::MinOp, false>,
                   PointwiseToLinalgConverter<mhlo::MulOp, false>,
                   PointwiseToLinalgConverter<mhlo::NegOp, false>,
                   PointwiseToLinalgConverter<mhlo::NotOp, false>,
                   PointwiseToLinalgConverter<mhlo::OrOp, false>,
                   PointwiseToLinalgConverter<mhlo::PowOp, false>,
                   PointwiseToLinalgConverter<mhlo::RealOp, false>,
                   PointwiseToLinalgConverter<mhlo::RemOp, false>,
                   PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
                   PointwiseToLinalgConverter<mhlo::SelectOp, false>,
                   PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
                   PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp,
                                              false>,
                   PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
                   PointwiseToLinalgConverter<mhlo::SignOp, false>,
                   PointwiseToLinalgConverter<mhlo::SinOp, false>,
                   PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
                   PointwiseToLinalgConverter<mhlo::SubOp, false>,
                   PointwiseToLinalgConverter<mhlo::TanhOp, false>,
                   PointwiseToLinalgConverter<mhlo::XorOp, false>>(context);

  patterns->insert<ReshapeOpConverter<mhlo::ReshapeOp, false>,
                   ReverseConverter<mhlo::ReverseOp, false>,
                   TransposeConverter<mhlo::TransposeOp, false>,
                   DotOpOnTensorsConversion,
                   DotGeneralOpOnTensorsConversion,
                   ReduceOnTensorsConversion>(context);

  patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
                   ReduceRegionXLAOpConversion<mhlo::MinOp>,
                   ReduceRegionXLAOpConversion<mhlo::MaxOp>,
                   ReduceRegionReturnOpConversion>(context);
}
|
|
|
|
|
|
|
|
// Returns a new instance of the function pass that legalizes MHLO ops on
// tensors to Linalg (HloLegalizeToLinalgPass).
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<HloLegalizeToLinalgPass>();
  return pass;
}
|
2020-07-07 12:51:24 +08:00
|
|
|
} // namespace mhlo
|
2020-07-07 04:57:00 +08:00
|
|
|
} // namespace mlir
|