2020-07-07 04:57:00 +08:00
|
|
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
==============================================================================*/
|
|
|
|
|
|
|
|
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
|
|
|
|
|
2020-08-22 14:26:35 +08:00
|
|
|
#include <numeric>
|
|
|
|
|
2020-09-23 00:06:55 +08:00
|
|
|
#include "llvm/ADT/STLExtras.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
|
|
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
|
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
|
|
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
|
|
#include "mlir/IR/AffineExpr.h"
|
|
|
|
#include "mlir/IR/Attributes.h"
|
|
|
|
#include "mlir/IR/Builders.h"
|
2020-12-02 05:17:12 +08:00
|
|
|
#include "mlir/IR/BuiltinOps.h"
|
2020-12-15 16:58:42 +08:00
|
|
|
#include "mlir/IR/BuiltinTypes.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir/IR/Location.h"
|
|
|
|
#include "mlir/IR/MLIRContext.h"
|
|
|
|
#include "mlir/IR/Operation.h"
|
2020-09-23 00:06:55 +08:00
|
|
|
#include "mlir/IR/OperationSupport.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir/IR/PatternMatch.h"
|
2020-09-23 00:06:55 +08:00
|
|
|
#include "mlir/IR/TypeUtilities.h"
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
namespace mlir {
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
// Returns `nParallelLoops` copies of the "parallel" iterator type, suitable
// for the iterator_types of a linalg.generic whose loops are all parallel.
SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
  static constexpr StringRef kParallelIterType = "parallel";
  SmallVector<StringRef, 3> iterator_types;
  iterator_types.assign(nParallelLoops, kParallelIterType);
  return iterator_types;
}
|
|
|
|
|
|
|
|
template <bool isLHLO = true>
|
2020-12-15 02:46:04 +08:00
|
|
|
Value GetResultValue(Operation* op) {
|
2020-07-07 04:57:00 +08:00
|
|
|
return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
|
|
|
|
}
|
|
|
|
|
|
|
|
template <bool isLHLO = true>
|
2020-12-15 02:46:04 +08:00
|
|
|
ShapedType GetHloOpResultType(Operation* op) {
|
|
|
|
return GetResultValue<isLHLO>(op).getType().template cast<ShapedType>();
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
template <bool isLHLO = true>
|
2020-12-15 02:46:04 +08:00
|
|
|
bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
|
2020-10-24 03:22:21 +08:00
|
|
|
auto verify_type = [&](Value val) -> bool {
|
2020-07-07 04:57:00 +08:00
|
|
|
return (isLHLO && val.getType().isa<MemRefType>()) ||
|
|
|
|
(!isLHLO && val.getType().isa<RankedTensorType>());
|
|
|
|
};
|
2020-10-24 03:22:21 +08:00
|
|
|
if (!llvm::all_of(op->getOperands(), verify_type)) return false;
|
2020-07-07 04:57:00 +08:00
|
|
|
return isLHLO ? op->getResults().empty()
|
2020-10-24 03:22:21 +08:00
|
|
|
: llvm::all_of(op->getResults(), verify_type);
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
template <typename OpTy, bool isLHLO = true>
|
|
|
|
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
|
|
|
OpTy op, ArrayRef<Value> args,
|
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
|
|
|
auto loc = op.getLoc();
|
2020-09-23 00:06:55 +08:00
|
|
|
ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>();
|
|
|
|
if (!t0) return failure();
|
|
|
|
|
|
|
|
unsigned nloops = t0.getRank();
|
|
|
|
auto fail = [&](ShapedType t) {
|
|
|
|
return !t || !t.hasRank() || t.getRank() != nloops ||
|
|
|
|
!(t.getElementType().isSignlessIntOrFloat() ||
|
|
|
|
t.getElementType().isa<ComplexType>());
|
|
|
|
};
|
|
|
|
if (llvm::any_of(args,
|
|
|
|
[&](Value v) {
|
|
|
|
return fail(v.getType().dyn_cast<ShapedType>());
|
|
|
|
}) ||
|
|
|
|
llvm::any_of(op.getOperation()->getResultTypes(),
|
|
|
|
[&](Type t) { return fail(t.dyn_cast<ShapedType>()); }))
|
|
|
|
return emitError(loc,
|
|
|
|
"lhlo to linalg conversion expects ranked args of "
|
|
|
|
"signless int, float or complex element type with ")
|
|
|
|
<< nloops << " parallel iterators: " << *(op.getOperation());
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Construct the indexing maps needed for linalg.generic ops.
|
2020-10-24 03:22:21 +08:00
|
|
|
SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// This doesnt account for implicit broadcast, but the working assumption
|
2020-09-23 00:06:55 +08:00
|
|
|
// in HLO/LHLO is that are broadcasts are made explicit.
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
if (isLHLO && !nloops) return failure();
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
int num_inputs = (isLHLO ? args.size() - 1 : args.size());
|
2020-09-23 00:06:55 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
ValueRange inputs(args.take_front(num_inputs));
|
2020-09-23 00:06:55 +08:00
|
|
|
for (Value in : inputs)
|
2020-10-24 03:22:21 +08:00
|
|
|
body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));
|
2020-09-23 00:06:55 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
ValueRange output_buffers(args.take_back(args.size() - num_inputs));
|
|
|
|
for (Value out : output_buffers)
|
|
|
|
body_result_types.emplace_back(getElementTypeOrSelf(out.getType()));
|
2020-09-23 00:06:55 +08:00
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
if (!isLHLO) {
|
|
|
|
// HLO operations have return as tensor types.
|
2020-10-24 03:22:21 +08:00
|
|
|
assert(body_result_types.empty() &&
|
2020-07-07 04:57:00 +08:00
|
|
|
"When lowering HLO ops result can't be part of arguments");
|
|
|
|
Value result = op.getOperation()->getResult(0);
|
2020-10-24 03:22:21 +08:00
|
|
|
body_result_types.push_back(getElementTypeOrSelf(result));
|
|
|
|
op_result_types.push_back(result.getType());
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
AffineMap common_indexing_map =
|
2020-09-23 00:06:55 +08:00
|
|
|
nloops ? rewriter.getMultiDimIdentityMap(nloops)
|
|
|
|
: AffineMap::get(nloops, 0, rewriter.getContext());
|
|
|
|
SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
|
2020-10-24 03:22:21 +08:00
|
|
|
common_indexing_map);
|
2020-09-23 00:06:55 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
auto linalg_op = rewriter.create<linalg::GenericOp>(
|
|
|
|
loc, op_result_types, inputs, output_buffers,
|
2020-09-23 00:06:55 +08:00
|
|
|
/*initTensors=*/ValueRange{}, indexing_maps,
|
2020-07-07 04:57:00 +08:00
|
|
|
GetNParallelLoopsAttrs(nloops),
|
2020-10-24 03:22:21 +08:00
|
|
|
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
2020-07-09 01:05:32 +08:00
|
|
|
// TODO(ravishankarm) : For now use the method in lmhlo namespace.
|
2020-07-07 04:57:00 +08:00
|
|
|
// That method needs to be moved out of there.
|
2020-10-24 03:22:21 +08:00
|
|
|
Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
|
|
|
|
op, body_result_types,
|
2020-09-23 00:06:55 +08:00
|
|
|
llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
|
2020-10-24 03:22:21 +08:00
|
|
|
nested_builder.create<linalg::YieldOp>(loc, op_result);
|
2020-07-07 04:57:00 +08:00
|
|
|
});
|
2020-10-24 03:22:21 +08:00
|
|
|
rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
|
2020-07-07 04:57:00 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
template <typename LhloOp>
|
|
|
|
class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<LhloOp>::OpConversionPattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
|
|
|
LhloOp lhlo_op, ArrayRef<Value> args,
|
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
|
|
|
auto loc = lhlo_op.getLoc();
|
2020-10-24 03:22:21 +08:00
|
|
|
auto arg_type =
|
2020-07-07 04:57:00 +08:00
|
|
|
lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
|
2020-10-24 03:22:21 +08:00
|
|
|
if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
|
|
|
|
(arg_type.getRank() != 0)) {
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create two loads from the input.
|
|
|
|
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
|
|
|
|
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
|
2020-07-09 01:05:32 +08:00
|
|
|
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
|
2020-10-24 03:22:21 +08:00
|
|
|
Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
|
|
|
|
lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
|
2020-07-07 04:57:00 +08:00
|
|
|
&rewriter);
|
2020-10-24 03:22:21 +08:00
|
|
|
rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
|
2020-07-07 04:57:00 +08:00
|
|
|
rewriter.eraseOp(lhlo_op);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2020-07-09 01:05:32 +08:00
|
|
|
// lmhlo.convolution conversion pattern.
|
2020-07-07 04:57:00 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-07-09 01:05:32 +08:00
|
|
|
/// Converts lmhlo.convolution operation to a linalg.conv op.
|
|
|
|
struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
|
2020-07-07 04:57:00 +08:00
|
|
|
public:
|
2020-07-09 01:05:32 +08:00
|
|
|
using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// This code has been adapted from IREE's
|
2020-07-07 12:51:24 +08:00
|
|
|
// (https://github.com/google/iree/) mhlo -> linalg conversion.
|
2020-07-07 04:57:00 +08:00
|
|
|
LogicalResult matchAndRewrite(
|
2020-07-09 01:05:32 +08:00
|
|
|
lmhlo::ConvOp op, ArrayRef<Value> args,
|
2020-07-07 04:57:00 +08:00
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
|
|
|
// Check validity of dimension information.
|
2020-10-24 03:22:21 +08:00
|
|
|
if (const mhlo::ConvDimensionNumbers& dimension_numbers =
|
2020-07-07 04:57:00 +08:00
|
|
|
op.dimension_numbers()) {
|
2020-10-24 03:22:21 +08:00
|
|
|
const int input_spatial_rank =
|
|
|
|
llvm::size(dimension_numbers.input_spatial_dimensions());
|
2020-07-07 04:57:00 +08:00
|
|
|
// The dimensions for input should follow the order of
|
|
|
|
// batch_count, spatial_dims..., input_feature_count.
|
2020-10-24 03:22:21 +08:00
|
|
|
if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
|
|
|
|
dimension_numbers.input_feature_dimension().getInt() !=
|
|
|
|
(input_spatial_rank + 1))
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
const int kernel_spatial_rank =
|
|
|
|
llvm::size(dimension_numbers.kernel_spatial_dimensions());
|
2020-07-07 04:57:00 +08:00
|
|
|
// The dimensions for filter should follow the order of
|
|
|
|
// spatial_dims..., input_feature_count, num_output_feature_count.
|
2020-10-24 03:22:21 +08:00
|
|
|
if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
|
|
|
|
kernel_spatial_rank ||
|
|
|
|
dimension_numbers.kernel_output_feature_dimension().getInt() !=
|
|
|
|
(kernel_spatial_rank + 1))
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
const int output_spatial_rank =
|
|
|
|
llvm::size(dimension_numbers.output_spatial_dimensions());
|
2020-07-07 04:57:00 +08:00
|
|
|
// The dimensions for output should follow the order of
|
|
|
|
// batch_count, spatial_dims.., output_feature_count.
|
2020-10-24 03:22:21 +08:00
|
|
|
if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
|
|
|
|
dimension_numbers.output_feature_dimension().getInt() !=
|
|
|
|
(output_spatial_rank + 1))
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
if (input_spatial_rank != output_spatial_rank ||
|
|
|
|
input_spatial_rank != kernel_spatial_rank)
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
auto input_spatial_dim =
|
|
|
|
dimension_numbers.input_spatial_dimensions().begin();
|
|
|
|
auto kernel_spatial_dim =
|
|
|
|
dimension_numbers.kernel_spatial_dimensions().begin();
|
|
|
|
auto output_spatial_dim =
|
|
|
|
dimension_numbers.output_spatial_dimensions().begin();
|
2020-07-07 04:57:00 +08:00
|
|
|
// Check if spatial dims are ordered correctly.
|
2020-10-24 03:22:21 +08:00
|
|
|
for (int i = 0; i < input_spatial_rank; ++i) {
|
2020-07-07 04:57:00 +08:00
|
|
|
const int dim = i + 1;
|
2020-10-24 03:22:21 +08:00
|
|
|
if ((*input_spatial_dim++).getZExtValue() != dim ||
|
|
|
|
(*output_spatial_dim++).getZExtValue() != dim ||
|
|
|
|
(*kernel_spatial_dim++).getZExtValue() != i)
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: LHS dilation for deconvolution not supported yet.
|
2020-12-11 08:38:26 +08:00
|
|
|
// TODO(jurahul): Window reversal is not supported yet.
|
|
|
|
if (op.lhs_dilation() || op.hasWindowReversal()) {
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
llvm::SmallVector<Attribute, 4> strides;
|
2020-10-24 03:22:21 +08:00
|
|
|
if (auto window_strides = op.window_strides()) {
|
|
|
|
auto range = window_strides->getAttributeValues();
|
2020-07-07 04:57:00 +08:00
|
|
|
strides.assign(range.begin(), range.end());
|
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
auto strides_arg = ArrayAttr::get(strides, op.getContext());
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
llvm::SmallVector<Attribute, 2> dilation;
|
2020-10-24 03:22:21 +08:00
|
|
|
if (auto rhs_dilation = op.rhs_dilation()) {
|
|
|
|
auto range = rhs_dilation->getAttributeValues();
|
2020-07-07 04:57:00 +08:00
|
|
|
dilation.assign(range.begin(), range.end());
|
|
|
|
} else {
|
|
|
|
// Default dilation of 1.
|
|
|
|
dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
|
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
auto dilation_arg = ArrayAttr::get(dilation, op.getContext());
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Set padding only if it is non-zero.
|
|
|
|
DenseIntElementsAttr padding = op.paddingAttr();
|
2020-10-24 03:22:21 +08:00
|
|
|
if (!padding ||
|
|
|
|
!llvm::any_of(padding.getValues<APInt>(),
|
|
|
|
[](APInt int_val) { return !int_val.isNullValue(); })) {
|
2020-07-07 04:57:00 +08:00
|
|
|
padding = nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
// The order of input and filter are switched with linalg.conv.
|
|
|
|
rewriter.replaceOpWithNewOp<linalg::ConvOp>(
|
2020-10-24 03:22:21 +08:00
|
|
|
op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
|
2020-07-07 04:57:00 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-07-09 11:32:16 +08:00
|
|
|
/// Base class for lowering HLO operations that have one operand and one result,
|
2020-07-07 04:57:00 +08:00
|
|
|
/// and are semantically equivalent to a copy of the input to the output (like
|
|
|
|
/// transpose, some reshape, etc.). The derived classes need to provide a method
|
|
|
|
/// `getIndexingMaps` that returns AffineMaps for the index maps of the input
|
|
|
|
/// and the output.
|
|
|
|
template <typename Derived, typename OpTy, bool isLHLO = true>
|
|
|
|
class DataMovementOpConverter : public OpConversionPattern<OpTy> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
|
|
|
OpTy op, ArrayRef<Value> args,
|
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
2020-12-15 02:46:04 +08:00
|
|
|
if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
|
|
|
|
auto result_type = GetHloOpResultType<isLHLO>(op);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
SmallVector<AffineMap, 2> indexing_maps =
|
|
|
|
Derived::getIndexingMaps(op, &rewriter);
|
|
|
|
if (indexing_maps.empty()) return failure();
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
auto nloops = result_type.getRank();
|
2020-07-07 04:57:00 +08:00
|
|
|
auto loc = op.getLoc();
|
2020-10-24 03:22:21 +08:00
|
|
|
auto linalg_op = rewriter.create<linalg::GenericOp>(
|
2020-09-23 00:06:55 +08:00
|
|
|
loc,
|
2020-10-24 03:22:21 +08:00
|
|
|
/*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
|
2020-09-23 00:06:55 +08:00
|
|
|
/*inputs=*/args.front(),
|
|
|
|
/*outputBuffers=*/isLHLO ? ValueRange{args.back()} : ValueRange{},
|
|
|
|
/*initTensor=*/ValueRange{}, indexing_maps,
|
|
|
|
GetNParallelLoopsAttrs(nloops),
|
2020-10-24 03:22:21 +08:00
|
|
|
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
|
|
|
nested_builder.create<linalg::YieldOp>(loc, *args.begin());
|
2020-07-07 04:57:00 +08:00
|
|
|
});
|
2020-10-24 03:22:21 +08:00
|
|
|
rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
|
2020-07-07 04:57:00 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
/// Pattern to convert BroadcastOp to Linalg ops.
|
|
|
|
template <typename OpTy, bool isLHLO = true>
|
|
|
|
class BroadcastConverter
|
|
|
|
: public DataMovementOpConverter<BroadcastConverter<OpTy, isLHLO>, OpTy,
|
|
|
|
isLHLO> {
|
|
|
|
public:
|
|
|
|
using DataMovementOpConverter<BroadcastConverter, OpTy,
|
|
|
|
isLHLO>::DataMovementOpConverter;
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
|
2020-07-07 04:57:00 +08:00
|
|
|
Builder* b) {
|
2020-10-24 03:22:21 +08:00
|
|
|
ShapedType input_type =
|
|
|
|
broadcast_op.operand().getType().template cast<ShapedType>();
|
|
|
|
unsigned input_rank = input_type.getRank();
|
2020-12-15 02:46:04 +08:00
|
|
|
unsigned nloops = GetHloOpResultType<isLHLO>(broadcast_op).getRank();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
|
|
|
|
// the input's dimensions.
|
2020-10-24 03:22:21 +08:00
|
|
|
unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
|
|
|
|
SmallVector<AffineExpr, 4> input_dim_exprs;
|
|
|
|
input_dim_exprs.reserve(input_rank);
|
|
|
|
for (int i = 0; i < input_rank; ++i) {
|
|
|
|
input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
AffineMap input_map;
|
2020-07-07 04:57:00 +08:00
|
|
|
MLIRContext* context = b->getContext();
|
2020-10-24 03:22:21 +08:00
|
|
|
if (input_dim_exprs.empty()) {
|
2020-07-07 04:57:00 +08:00
|
|
|
// The input is a scalar, i.e. this is a scalar broadcast op.
|
2020-10-24 03:22:21 +08:00
|
|
|
input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
|
2020-07-07 04:57:00 +08:00
|
|
|
} else {
|
2020-10-24 03:22:21 +08:00
|
|
|
input_map =
|
|
|
|
AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
return {input_map, b->getMultiDimIdentityMap(nloops)};
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
class HloBroadcastInDimConverter
|
|
|
|
: public DataMovementOpConverter<HloBroadcastInDimConverter,
|
2020-07-07 12:51:24 +08:00
|
|
|
mhlo::BroadcastInDimOp, false> {
|
2020-07-07 04:57:00 +08:00
|
|
|
public:
|
|
|
|
using DataMovementOpConverter<HloBroadcastInDimConverter,
|
2020-07-07 12:51:24 +08:00
|
|
|
mhlo::BroadcastInDimOp,
|
2020-07-07 04:57:00 +08:00
|
|
|
false>::DataMovementOpConverter;
|
|
|
|
|
|
|
|
static SmallVector<AffineMap, 2> getIndexingMaps(
|
2020-10-24 03:22:21 +08:00
|
|
|
mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
|
2020-12-15 02:46:04 +08:00
|
|
|
auto result_type = GetHloOpResultType<false>(broadcast_op);
|
2020-10-24 03:22:21 +08:00
|
|
|
auto operand_type =
|
|
|
|
broadcast_op.operand().getType().template cast<ShapedType>();
|
|
|
|
unsigned nloops = result_type.getRank();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// The input is a scalar, i.e. this is a scalar broadcast op.
|
2020-10-24 03:22:21 +08:00
|
|
|
if (operand_type.getRank() == 0) {
|
2020-07-07 04:57:00 +08:00
|
|
|
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
|
|
|
|
b->getMultiDimIdentityMap(nloops)};
|
|
|
|
}
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
auto operand_shape = operand_type.getShape();
|
|
|
|
SmallVector<AffineExpr, 4> dim_exprs;
|
|
|
|
dim_exprs.reserve(nloops);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
if (broadcast_op.broadcast_dimensions()) {
|
2020-07-07 04:57:00 +08:00
|
|
|
for (const auto& broadcastDim :
|
2020-10-24 03:22:21 +08:00
|
|
|
enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
|
2020-07-07 04:57:00 +08:00
|
|
|
int size = broadcastDim.value().getSExtValue();
|
2020-10-24 03:22:21 +08:00
|
|
|
bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
|
|
|
|
result_type.getShape()[size] != 1;
|
|
|
|
dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
|
|
|
|
: b->getAffineDimExpr(size));
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return {
|
2020-10-24 03:22:21 +08:00
|
|
|
AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
|
2020-07-07 04:57:00 +08:00
|
|
|
b->getMultiDimIdentityMap(nloops)};
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
class LhloBroadcastInDimConverter
|
2020-07-09 01:05:32 +08:00
|
|
|
: public OpConversionPattern<lmhlo::BroadcastInDimOp> {
|
2020-07-07 04:57:00 +08:00
|
|
|
public:
|
2020-07-09 01:05:32 +08:00
|
|
|
using OpConversionPattern<lmhlo::BroadcastInDimOp>::OpConversionPattern;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
2020-07-09 01:05:32 +08:00
|
|
|
lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
2020-07-07 04:57:00 +08:00
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
2020-07-09 01:05:32 +08:00
|
|
|
lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
2020-07-07 04:57:00 +08:00
|
|
|
auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
|
|
|
|
auto result_shape = result_type.getShape();
|
|
|
|
|
|
|
|
auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter);
|
|
|
|
|
|
|
|
Value operand = std::get<0>(operand_and_dims);
|
|
|
|
auto broadcast_dims = std::get<1>(operand_and_dims);
|
|
|
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
auto nloops = result_type.getRank();
|
|
|
|
auto operand_type = operand.getType().cast<MemRefType>();
|
|
|
|
|
|
|
|
// For a degenerate case, i.e. broadcasting with expansion of
|
|
|
|
// memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`.
|
|
|
|
// Instead the value is loaded and used directly in `linalg.yield`.
|
|
|
|
if (operand_type.getRank() == 1 &&
|
|
|
|
operand_type.getDimSize(0) <
|
|
|
|
result_type.getDimSize(broadcast_dims.front())) {
|
|
|
|
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
|
|
|
Value val =
|
|
|
|
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
|
|
|
|
rewriter.create<linalg::GenericOp>(
|
2020-09-23 00:06:55 +08:00
|
|
|
loc, /*inputs=*/ValueRange{},
|
|
|
|
/*outputBuffers=*/ValueRange{operand_adaptor.output()},
|
2020-07-07 04:57:00 +08:00
|
|
|
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
|
|
|
GetNParallelLoopsAttrs(nloops),
|
2020-10-24 03:22:21 +08:00
|
|
|
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
|
|
|
nested_builder.create<linalg::YieldOp>(loc, val);
|
2020-07-07 04:57:00 +08:00
|
|
|
});
|
|
|
|
|
|
|
|
} else {
|
|
|
|
auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
|
|
|
|
operand_type, &rewriter);
|
|
|
|
rewriter.create<linalg::GenericOp>(
|
2020-09-23 00:06:55 +08:00
|
|
|
loc, /*inputs=*/ValueRange{operand},
|
|
|
|
/*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
|
2020-07-07 04:57:00 +08:00
|
|
|
GetNParallelLoopsAttrs(nloops),
|
2020-10-24 03:22:21 +08:00
|
|
|
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
|
|
|
|
nested_builder.create<linalg::YieldOp>(loc, *args.begin());
|
2020-07-07 04:57:00 +08:00
|
|
|
});
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, llvm::None);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Inserts 'linalg.reshape' if there is a size-1 dim expansion.
|
|
|
|
std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
|
2020-07-09 01:05:32 +08:00
|
|
|
lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
2020-07-07 04:57:00 +08:00
|
|
|
ConversionPatternRewriter& rewriter) const {
|
2020-07-09 01:05:32 +08:00
|
|
|
lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
2020-07-07 04:57:00 +08:00
|
|
|
Value operand = operand_adaptor.operand();
|
|
|
|
auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
|
|
|
|
auto operand_shape = operand_type.getShape();
|
|
|
|
|
|
|
|
Value result = operand_adaptor.output();
|
|
|
|
auto result_type = result.getType().cast<MemRefType>();
|
|
|
|
auto result_shape = result_type.getShape();
|
|
|
|
|
|
|
|
SmallVector<int64_t, 2> operand_strides;
|
|
|
|
int64_t operand_offset;
|
|
|
|
if (failed(getStridesAndOffset(operand_type, operand_strides,
|
|
|
|
operand_offset))) {
|
|
|
|
op.emitOpError() << "Failed to get offset and strides.";
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
|
|
|
|
SmallVector<linalg::ReassociationIndices, 4> collapsed_dims_list;
|
|
|
|
linalg::ReassociationIndices collapsed_dims;
|
|
|
|
for (const auto& item :
|
|
|
|
enumerate(op.broadcast_dimensions().getIntValues())) {
|
|
|
|
size_t index = item.index();
|
|
|
|
int dim = item.value().getSExtValue();
|
|
|
|
|
|
|
|
collapsed_dims.push_back(index);
|
|
|
|
|
|
|
|
bool expansion_needed =
|
|
|
|
operand_shape[index] == 1 && result_shape[dim] != 1;
|
|
|
|
if (expansion_needed) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
new_shape.push_back(operand_shape[index]);
|
|
|
|
new_strides.push_back(operand_strides[index]);
|
|
|
|
broadcast_dims.push_back(dim);
|
|
|
|
|
|
|
|
collapsed_dims_list.push_back(collapsed_dims);
|
|
|
|
collapsed_dims.clear();
|
|
|
|
}
|
|
|
|
// If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1]
|
|
|
|
// and all dimensions need expansion. Such memref will be reshaped to a 1D
|
|
|
|
// memref with a single element. New shape and strides needs to be updated
|
|
|
|
// accordingly.
|
|
|
|
if (collapsed_dims_list.empty()) {
|
|
|
|
collapsed_dims_list.push_back({});
|
|
|
|
new_shape.push_back(1);
|
|
|
|
new_strides.push_back(1);
|
|
|
|
broadcast_dims.push_back(0);
|
|
|
|
}
|
|
|
|
for (const auto& dims : collapsed_dims) {
|
|
|
|
collapsed_dims_list.back().push_back(dims);
|
|
|
|
}
|
|
|
|
|
|
|
|
// `linalg.reshape` is inserted only if necessary, i.e. when the rank can be
|
|
|
|
// reduced.
|
|
|
|
if (new_shape.size() < operand_shape.size()) {
|
|
|
|
auto new_memref_type = MemRefType::get(
|
|
|
|
new_shape, operand_type.getElementType(),
|
|
|
|
makeStridedLinearLayoutMap(new_strides, operand_offset,
|
|
|
|
rewriter.getContext()));
|
|
|
|
operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
|
|
|
|
operand_adaptor.operand(),
|
|
|
|
collapsed_dims_list);
|
|
|
|
}
|
|
|
|
return std::make_pair(operand, broadcast_dims);
|
|
|
|
}
|
|
|
|
|
2020-07-09 01:05:32 +08:00
|
|
|
SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
|
2020-10-24 03:22:21 +08:00
|
|
|
ArrayRef<int64_t> broadcast_dims,
|
|
|
|
ArrayRef<int64_t> result_shape,
|
|
|
|
MemRefType operand_type,
|
2020-07-07 04:57:00 +08:00
|
|
|
Builder* b) const {
|
2020-10-24 03:22:21 +08:00
|
|
|
unsigned nloops = result_shape.size();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// The input is a scalar, i.e. this is a scalar broadcast op.
|
2020-10-24 03:22:21 +08:00
|
|
|
if (operand_type.getRank() == 0) {
|
2020-07-07 04:57:00 +08:00
|
|
|
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
|
|
|
|
b->getMultiDimIdentityMap(nloops)};
|
|
|
|
}
|
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
auto operand_shape = operand_type.getShape();
|
|
|
|
SmallVector<AffineExpr, 4> dim_exprs;
|
|
|
|
dim_exprs.reserve(nloops);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) {
|
|
|
|
int size = broadcast_dim.value();
|
2020-07-07 04:57:00 +08:00
|
|
|
bool expansion_needed =
|
2020-10-24 03:22:21 +08:00
|
|
|
operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1;
|
2020-07-07 04:57:00 +08:00
|
|
|
if (expansion_needed) {
|
|
|
|
op.emitOpError(
|
|
|
|
"BroadcastInDimOp lowering to Linalg does not support size-1 "
|
|
|
|
"dimensions expansion.");
|
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
dim_exprs.push_back(b->getAffineDimExpr(size));
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
return {
|
2020-10-24 03:22:21 +08:00
|
|
|
AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
|
2020-07-07 04:57:00 +08:00
|
|
|
b->getMultiDimIdentityMap(nloops)};
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
template <typename OpTy, bool isLHLO = true>
|
|
|
|
class TransposeConverter
|
|
|
|
: public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
|
|
|
|
isLHLO> {
|
|
|
|
public:
|
|
|
|
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
|
|
|
|
isLHLO>::DataMovementOpConverter;
|
|
|
|
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
|
2020-10-24 03:22:21 +08:00
|
|
|
auto result_type =
|
2020-12-15 02:46:04 +08:00
|
|
|
GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
|
2020-10-24 03:22:21 +08:00
|
|
|
auto nloops = result_type.getRank();
|
|
|
|
SmallVector<AffineExpr, 2> input_exprs;
|
|
|
|
input_exprs.resize(result_type.getRank());
|
2020-07-07 04:57:00 +08:00
|
|
|
for (auto permutation : llvm::enumerate(op.permutation())) {
|
2020-10-24 03:22:21 +08:00
|
|
|
input_exprs[permutation.value().getZExtValue()] =
|
2020-07-07 04:57:00 +08:00
|
|
|
b->getAffineDimExpr(permutation.index());
|
|
|
|
}
|
|
|
|
return {
|
2020-10-24 03:22:21 +08:00
|
|
|
AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
|
2020-07-07 04:57:00 +08:00
|
|
|
b->getMultiDimIdentityMap(nloops)};
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
// Converts reshape ops that can be proven to be either a collapse of dimensions
|
|
|
|
// or expansion of dimensions of the operand.
|
|
|
|
template <typename OpTy, bool isLHLO = true>
|
|
|
|
class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
2020-10-24 03:22:21 +08:00
|
|
|
OpTy reshape_op, ArrayRef<Value> args,
|
2020-07-07 04:57:00 +08:00
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
2020-12-15 02:46:04 +08:00
|
|
|
if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
2020-10-24 03:22:21 +08:00
|
|
|
ShapedType operand_type =
|
|
|
|
reshape_op.operand().getType().template cast<ShapedType>();
|
2020-12-15 02:46:04 +08:00
|
|
|
ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
|
2020-07-07 04:57:00 +08:00
|
|
|
return failure();
|
|
|
|
|
|
|
|
// Compute the reassociation maps for the linalg operation.
|
2020-10-24 03:22:21 +08:00
|
|
|
ArrayRef<int64_t> src_shape =
|
|
|
|
(operand_type.getRank() > result_type.getRank()
|
|
|
|
? operand_type.getShape()
|
|
|
|
: result_type.getShape());
|
|
|
|
ArrayRef<int64_t> dst_shape =
|
|
|
|
(operand_type.getRank() > result_type.getRank()
|
|
|
|
? result_type.getShape()
|
|
|
|
: operand_type.getShape());
|
|
|
|
unsigned curr_src_dim = 0, curr_dst_dim = 0;
|
|
|
|
SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
|
|
|
|
dst_shape.size());
|
|
|
|
bool is_expanding_or_collapsing = true;
|
|
|
|
while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
|
|
|
|
int64_t dst_size = dst_shape[curr_dst_dim];
|
|
|
|
int64_t src_size = src_shape[curr_src_dim];
|
|
|
|
while (src_size < dst_size && curr_src_dim < src_shape.size()) {
|
|
|
|
reassociation_map[curr_dst_dim].push_back(
|
|
|
|
rewriter.getAffineDimExpr(curr_src_dim++));
|
|
|
|
src_size *= src_shape[curr_src_dim];
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
if (src_size == dst_size) {
|
|
|
|
reassociation_map[curr_dst_dim].push_back(
|
|
|
|
rewriter.getAffineDimExpr(curr_src_dim++));
|
|
|
|
// If the next dim in dst_shape is not 1, treat subsequent dims in
|
|
|
|
// src_shape which are 1 to be collapsed.
|
|
|
|
if (curr_dst_dim == dst_shape.size() - 1 ||
|
|
|
|
dst_shape[curr_dst_dim + 1] != 1) {
|
|
|
|
while (curr_src_dim < src_shape.size() &&
|
|
|
|
src_shape[curr_src_dim] == 1) {
|
|
|
|
reassociation_map[curr_dst_dim].push_back(
|
|
|
|
rewriter.getAffineDimExpr(curr_src_dim++));
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
2020-10-24 03:22:21 +08:00
|
|
|
is_expanding_or_collapsing = false;
|
2020-08-22 14:26:35 +08:00
|
|
|
break;
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
curr_dst_dim++;
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
|
|
|
|
is_expanding_or_collapsing = false;
|
2020-08-22 14:26:35 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
if (!is_expanding_or_collapsing) {
|
|
|
|
auto get_identity_exprs = [&rewriter](int n) {
|
2020-08-22 14:26:35 +08:00
|
|
|
SmallVector<AffineExpr, 4> exprs;
|
|
|
|
for (int i = 0; i < n; ++i)
|
|
|
|
exprs.push_back(rewriter.getAffineDimExpr(i));
|
|
|
|
return exprs;
|
|
|
|
};
|
2020-10-24 03:22:21 +08:00
|
|
|
Location loc = reshape_op.getLoc();
|
|
|
|
int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
|
|
|
|
1, std::multiplies<int64_t>());
|
|
|
|
auto elem_type = operand_type.getElementType();
|
|
|
|
SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
|
|
|
|
get_identity_exprs(dst_shape.size())};
|
|
|
|
SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
|
|
|
|
get_identity_exprs(src_shape.size())};
|
2020-08-22 14:26:35 +08:00
|
|
|
|
|
|
|
if (isLHLO) {
|
2020-10-24 03:22:21 +08:00
|
|
|
auto collapsed_type = MemRefType::get({total_elems}, elem_type);
|
|
|
|
Value collapsed_op = rewriter.create<linalg::ReshapeOp>(
|
|
|
|
loc, collapsed_type, args[0], collapsing_map);
|
|
|
|
Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
|
|
|
|
loc, result_type, collapsed_op, expanding_map);
|
2020-08-22 14:26:35 +08:00
|
|
|
rewriter.replaceOpWithNewOp<linalg::CopyOp>(
|
2020-10-24 03:22:21 +08:00
|
|
|
reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
|
2020-08-22 14:26:35 +08:00
|
|
|
/*outputPermutation =*/nullptr);
|
|
|
|
} else {
|
2020-10-24 03:22:21 +08:00
|
|
|
auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
|
|
|
|
Value collapsed_op = rewriter.create<linalg::TensorReshapeOp>(
|
|
|
|
loc, collapsed_type, args[0], collapsing_map);
|
2020-08-22 14:26:35 +08:00
|
|
|
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
2020-10-24 03:22:21 +08:00
|
|
|
reshape_op, result_type, collapsed_op, expanding_map);
|
2020-08-22 14:26:35 +08:00
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
if (isLHLO) {
|
2020-10-24 03:22:21 +08:00
|
|
|
Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
|
|
|
|
reshape_op.getLoc(), result_type, args[0], reassociation_map);
|
2020-07-07 04:57:00 +08:00
|
|
|
rewriter.replaceOpWithNewOp<linalg::CopyOp>(
|
2020-10-24 03:22:21 +08:00
|
|
|
reshape_op, reshape_buffer, args[1], /*inputPermutation =*/nullptr,
|
2020-07-07 04:57:00 +08:00
|
|
|
/*outputPermutation =*/nullptr);
|
|
|
|
} else {
|
|
|
|
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
2020-10-24 03:22:21 +08:00
|
|
|
reshape_op, result_type, args[0], reassociation_map);
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-07-24 00:18:01 +08:00
|
|
|
template <typename OpTy, bool isLHLO = true>
|
|
|
|
class IotaConverter : public OpConversionPattern<OpTy> {
|
2020-07-07 04:57:00 +08:00
|
|
|
public:
|
2020-07-24 00:18:01 +08:00
|
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
2020-10-24 03:22:21 +08:00
|
|
|
OpTy iota_op, ArrayRef<Value> args,
|
2020-07-07 04:57:00 +08:00
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
2020-12-15 02:46:04 +08:00
|
|
|
ShapedType result_shaped_type = GetHloOpResultType<isLHLO>(iota_op);
|
2020-10-24 03:22:21 +08:00
|
|
|
if (!result_shaped_type) return failure();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
auto result_element_type = result_shaped_type.getElementType();
|
|
|
|
if (!result_element_type.isSignlessIntOrFloat()) return failure();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Construct the indexing maps needed for linalg.generic ops.
|
2020-10-24 03:22:21 +08:00
|
|
|
unsigned nloops = result_shaped_type.getRank();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-10-24 03:22:21 +08:00
|
|
|
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
|
|
|
|
iota_op.getLoc(),
|
2020-09-23 00:06:55 +08:00
|
|
|
/*resultTensorTypes=*/
|
2020-10-24 03:22:21 +08:00
|
|
|
isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
|
2020-09-23 00:06:55 +08:00
|
|
|
/*inputs=*/ValueRange{},
|
|
|
|
/*outputBuffers=*/isLHLO ? ValueRange{args} : ValueRange{},
|
|
|
|
/*initTensors=*/ValueRange{},
|
2020-07-07 04:57:00 +08:00
|
|
|
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
|
|
|
GetNParallelLoopsAttrs(nloops),
|
2020-10-24 03:22:21 +08:00
|
|
|
[&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
|
2020-07-07 04:57:00 +08:00
|
|
|
ValueRange args) {
|
2020-10-24 03:22:21 +08:00
|
|
|
Value cast_op = nested_builder.create<IndexCastOp>(
|
|
|
|
nested_loc, ivs[iota_op.iota_dimension()],
|
|
|
|
nested_builder.getIntegerType(
|
|
|
|
result_element_type.getIntOrFloatBitWidth()));
|
|
|
|
if (result_element_type.template isa<FloatType>()) {
|
|
|
|
cast_op = nested_builder.create<SIToFPOp>(nested_loc, cast_op,
|
|
|
|
result_element_type);
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-10-24 03:22:21 +08:00
|
|
|
nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
|
2020-07-07 04:57:00 +08:00
|
|
|
});
|
2020-07-24 00:18:01 +08:00
|
|
|
if (isLHLO)
|
2020-10-24 03:22:21 +08:00
|
|
|
rewriter.replaceOp(iota_op, llvm::None);
|
2020-07-24 00:18:01 +08:00
|
|
|
else
|
2020-10-24 03:22:21 +08:00
|
|
|
rewriter.replaceOp(iota_op, linalg_op.result_tensors());
|
2020-07-07 04:57:00 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-12-10 15:24:23 +08:00
|
|
|
template <typename OpTy>
|
|
|
|
class ConstConverter : public OpConversionPattern<OpTy> {
|
2020-07-07 04:57:00 +08:00
|
|
|
public:
|
2020-12-10 15:24:23 +08:00
|
|
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(
|
2020-12-10 15:24:23 +08:00
|
|
|
OpTy const_op, ArrayRef<Value> /*args*/,
|
2020-07-07 04:57:00 +08:00
|
|
|
ConversionPatternRewriter& rewriter) const final {
|
2020-12-10 15:24:23 +08:00
|
|
|
Location loc = const_op.getLoc();
|
|
|
|
auto value_attr = const_op.value().template cast<DenseElementsAttr>();
|
2020-10-24 03:22:21 +08:00
|
|
|
if (value_attr.getType().getRank() != 0) return failure();
|
2020-12-10 15:24:23 +08:00
|
|
|
ReplaceConstOp(loc, const_op, value_attr, rewriter);
|
2020-07-07 04:57:00 +08:00
|
|
|
return success();
|
|
|
|
}
|
2020-12-10 15:24:23 +08:00
|
|
|
|
|
|
|
private:
|
|
|
|
void ReplaceConstOp(Location loc, mhlo::ConstOp op,
|
|
|
|
DenseElementsAttr value_attr,
|
|
|
|
ConversionPatternRewriter& rewriter) const {
|
|
|
|
Value std_tensor_const = rewriter.create<mlir::ConstantOp>(loc, value_attr);
|
|
|
|
rewriter.replaceOp(op, {std_tensor_const});
|
|
|
|
}
|
|
|
|
void ReplaceConstOp(Location loc, lmhlo::ConstOp op,
|
|
|
|
DenseElementsAttr value_attr,
|
|
|
|
ConversionPatternRewriter& rewriter) const {
|
|
|
|
Value std_scalar_const =
|
|
|
|
rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
|
|
|
|
rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, op.getOperand(),
|
|
|
|
llvm::None);
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
}
|
2020-07-07 04:57:00 +08:00
|
|
|
};
|
|
|
|
|
2020-10-29 07:37:38 +08:00
|
|
|
// Lowers a single-operand lmhlo.reduce to a linalg.generic:
// - the output buffer is first filled with the (loaded) init value;
// - reduced dimensions become reduction iterators and the rest stay parallel;
// - the buffer-based reduction body is adapted to scalar block arguments via
//   stack allocations.
class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = reduce_op.getLoc();
    lmhlo::ReduceOp::Adaptor adaptor(args);
    // Everything below assumes exactly one reduced operand:
    // - operands()[0] is indexed;
    // - a single init value and a single output buffer are used;
    // - two indexing maps are built;
    // - the body is rewritten from a 3-memref signature.
    // Bail out cleanly instead of indexing an empty range or emitting
    // mismatched IR for variadic reductions.
    if (adaptor.operands().size() != 1 || adaptor.init_values().size() != 1 ||
        adaptor.out().size() != 1) {
      return rewriter.notifyMatchFailure(
          reduce_op, "expected a single operand, init value and output");
    }
    auto operand_shape =
        adaptor.operands()[0].getType().template dyn_cast<ShapedType>();
    if (!operand_shape || !operand_shape.hasRank()) {
      emitError(loc, "lhlo to linalg conversion expects known-rank args");
      return failure();
    }

    // First fill the output buffer with the init value.
    Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
    rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);

    DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
    SmallVector<int, 4> reduction_dims;
    for (const auto& dim : dimensions_attr.getIntValues()) {
      reduction_dims.push_back(dim.getSExtValue());
    }

    // The input is indexed by every loop; the output only by the parallel
    // (non-reduced) loops.
    SmallVector<AffineExpr, 2> src_exprs;
    SmallVector<AffineExpr, 2> dst_exprs;
    SmallVector<StringRef, 4> types;
    for (int i = 0, rank = operand_shape.getRank(); i != rank; ++i) {
      bool is_reduced = llvm::is_contained(reduction_dims, i);
      types.push_back(is_reduced ? getReductionIteratorTypeName()
                                 : getParallelIteratorTypeName());

      src_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (!is_reduced) {
        dst_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      }
    }

    auto maps = AffineMap::inferFromExprList({src_exprs, dst_exprs});

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/ArrayRef<Type>{},
        /*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(),
        /*initTensors=*/ValueRange{}, maps, types);
    rewriter.inlineRegionBefore(reduce_op.body(), linalg_op.region(),
                                linalg_op.region().end());
    {
      OpBuilder::InsertionGuard region_guard(rewriter);
      Block* block = linalg_op.getBody();
      rewriter.setInsertionPoint(&block->front());

      // The incoming region is operating on buffers, while linalg.generic
      // expects scalar SSA values. Add some allocs around the original op to
      // make it compatible.
      auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
      Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type);
      Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type);
      Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type);

      // Now turn the existing signature
      //   (memref<X>, memref<X>, memref<X>) -> ()
      // into
      //   (X, X) -> X
      TypeConverter::SignatureConversion signature_converter(3);
      signature_converter.remapInput(0, alloc_a);
      signature_converter.remapInput(1, alloc_b);
      signature_converter.remapInput(2, alloc_res);
      signature_converter.addInputs(
          {arg_type.getElementType(), arg_type.getElementType()});
      Block* entry_block = rewriter.applySignatureConversion(
          &linalg_op.region(), signature_converter);

      // Store the arguments into the newly allocated buffers.
      rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
      rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a);
      rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b);
      rewriter.replaceOp(entry_block->getTerminator(), {});

      // Load & yield the result.
      rewriter.setInsertionPointToEnd(entry_block);
      auto load_res = rewriter.create<LoadOp>(loc, alloc_res);
      rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
    }

    rewriter.replaceOp(reduce_op, linalg_op.getOperation()->getResults());
    return success();
  }
};
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
// TODO(b/156787842): Support the lowering for dynamic shapes.
|
|
|
|
template <typename OpTy, bool isLHLO = true>
|
|
|
|
class ReverseConverter
|
|
|
|
: public DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
|
|
|
|
isLHLO> {
|
|
|
|
public:
|
|
|
|
using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
|
|
|
|
isLHLO>::DataMovementOpConverter;
|
|
|
|
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
|
2020-10-24 03:22:21 +08:00
|
|
|
auto result_type =
|
2020-12-15 02:46:04 +08:00
|
|
|
GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
|
2020-10-24 03:22:21 +08:00
|
|
|
auto nloops = result_type.getRank();
|
|
|
|
SmallVector<AffineExpr, 2> input_exprs;
|
|
|
|
input_exprs.reserve(nloops);
|
2020-07-07 04:57:00 +08:00
|
|
|
for (int i = 0; i < nloops; ++i)
|
2020-10-24 03:22:21 +08:00
|
|
|
input_exprs.push_back(b->getAffineDimExpr(i));
|
2020-07-07 04:57:00 +08:00
|
|
|
for (auto dim : op.dimensions()) {
|
|
|
|
int i = dim.getZExtValue();
|
2020-10-24 03:22:21 +08:00
|
|
|
if (result_type.isDynamicDim(i)) return {};
|
|
|
|
int n = result_type.getShape()[i];
|
|
|
|
input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
return {
|
2020-10-24 03:22:21 +08:00
|
|
|
AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
|
2020-07-07 04:57:00 +08:00
|
|
|
b->getMultiDimIdentityMap(nloops)};
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-07-09 01:05:32 +08:00
|
|
|
class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
 public:
  using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;

  // Lowers lmhlo.slice to:
  // - a linalg.slice view of the input, with one linalg.range per dimension
  //   built from the static start/limit/stride attributes;
  // - a linalg.copy of that view into the output buffer.
  // The original op is then erased.
  LogicalResult matchAndRewrite(
      lmhlo::SliceOp slice_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = slice_op.getLoc();
    Value input = slice_op.getOperand(0);
    auto input_type = input.getType().template dyn_cast<ShapedType>();
    if (!input_type || !input_type.hasRank()) {
      emitError(loc, "lhlo to linalg conversion expects known-rank args");
      return failure();
    }

    SmallVector<Value, 3> ranges;
    int rank = input_type.getRank();
    for (int dim = 0; dim < rank; ++dim) {
      Value lower = rewriter.create<ConstantIndexOp>(
          loc, slice_op.start_indices().getValue<int64_t>(dim));
      Value upper = rewriter.create<ConstantIndexOp>(
          loc, slice_op.limit_indices().getValue<int64_t>(dim));
      Value step = rewriter.create<ConstantIndexOp>(
          loc, slice_op.strides().getValue<int64_t>(dim));
      Value range = rewriter.create<linalg::RangeOp>(loc, lower, upper, step);
      ranges.push_back(range);
    }

    Value view = rewriter.create<linalg::SliceOp>(loc, input, ranges);
    Value output = slice_op.getOperand(1);
    rewriter.create<linalg::CopyOp>(loc, view, output);
    rewriter.eraseOp(slice_op);
    return success();
  }
};
|
|
|
|
|
|
|
|
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
  // Registers every LHLO -> Linalg/Standard/Affine pattern of this file.
  // The patterns are inserted in one fixed overall order; the split into
  // several insert calls is only for readability.
  // clang-format off
  patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
                   ConstConverter<lmhlo::ConstOp>,
                   ConvToLinalgConverter,
                   IotaConverter<lmhlo::IotaOp>,
                   LhloBroadcastInDimConverter>(context);

  // Pointwise (elementwise) ops.
  patterns->insert<PointwiseToLinalgConverter<lmhlo::AbsOp>,
                   PointwiseToLinalgConverter<lmhlo::AddOp>,
                   PointwiseToLinalgConverter<lmhlo::AndOp>,
                   PointwiseToLinalgConverter<lmhlo::Atan2Op>,
                   PointwiseToLinalgConverter<lmhlo::CeilOp>,
                   PointwiseToLinalgConverter<lmhlo::CompareOp>,
                   PointwiseToLinalgConverter<lmhlo::ComplexOp>,
                   PointwiseToLinalgConverter<lmhlo::ConvertOp>,
                   // TODO(ataei): Remove this pattern, CopyOp is folded away.
                   PointwiseToLinalgConverter<lmhlo::CopyOp>,
                   PointwiseToLinalgConverter<lmhlo::CosOp>,
                   PointwiseToLinalgConverter<lmhlo::DivOp>,
                   PointwiseToLinalgConverter<lmhlo::ExpOp>,
                   PointwiseToLinalgConverter<lmhlo::FloorOp>,
                   PointwiseToLinalgConverter<lmhlo::ImagOp>,
                   PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
                   PointwiseToLinalgConverter<lmhlo::LogOp>,
                   PointwiseToLinalgConverter<lmhlo::MaxOp>,
                   PointwiseToLinalgConverter<lmhlo::MinOp>,
                   PointwiseToLinalgConverter<lmhlo::MulOp>,
                   PointwiseToLinalgConverter<lmhlo::NegOp>,
                   PointwiseToLinalgConverter<lmhlo::NotOp>,
                   PointwiseToLinalgConverter<lmhlo::OrOp>,
                   PointwiseToLinalgConverter<lmhlo::RealOp>,
                   PointwiseToLinalgConverter<lmhlo::RemOp>,
                   PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
                   PointwiseToLinalgConverter<lmhlo::SelectOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftLeftOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftRightArithmeticOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftRightLogicalOp>,
                   PointwiseToLinalgConverter<lmhlo::SignOp>,
                   PointwiseToLinalgConverter<lmhlo::SinOp>,
                   PointwiseToLinalgConverter<lmhlo::SqrtOp>,
                   PointwiseToLinalgConverter<lmhlo::SubOp>,
                   PointwiseToLinalgConverter<lmhlo::TanhOp>,
                   PointwiseToLinalgConverter<lmhlo::XorOp>>(context);

  patterns->insert<ReduceConverter,
                   ReshapeOpConverter<lmhlo::ReshapeOp>,
                   ReverseConverter<lmhlo::ReverseOp>,
                   ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
                   ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
                   SliceConverter,
                   TransposeConverter<lmhlo::TransposeOp>>(context);
  // clang-format on
}
|
|
|
|
|
|
|
|
// Converts LHLO ops to Linalg generic.
|
2020-07-09 01:05:32 +08:00
|
|
|
// Sample result for lmhlo::AddOp.
|
2020-07-07 04:57:00 +08:00
|
|
|
//
|
2020-07-09 01:05:32 +08:00
|
|
|
// "lmhlo.add"(%arg1, %arg2, %out) :
|
2020-07-07 04:57:00 +08:00
|
|
|
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
|
|
|
//
|
|
|
|
// will be converted to
|
|
|
|
//
|
|
|
|
// #map0 = (d0, d1) -> (d0, d1)
|
|
|
|
// "linalg.generic"(%arg1, %arg2, %out) ( {
|
|
|
|
// ^bb0(%arg4: f32, %arg5: f32):
|
|
|
|
// %0 = addf %arg4, %arg5 : f32
|
|
|
|
// "linalg.yield"(%0) : (f32) -> ()
|
|
|
|
// }) {
|
|
|
|
// indexing_maps = [#map0, #map0, #map0],
|
|
|
|
// iterator_types = ["parallel", "parallel"],
|
|
|
|
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
2020-07-29 07:12:08 +08:00
|
|
|
struct LhloLegalizeToLinalgPass
|
|
|
|
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
|
2020-08-26 11:30:05 +08:00
|
|
|
void getDependentDialects(DialectRegistry& registry) const override {
|
|
|
|
registry.insert<AffineDialect, linalg::LinalgDialect>();
|
|
|
|
}
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
void runOnFunction() override {
|
|
|
|
OwningRewritePatternList patterns;
|
|
|
|
ConversionTarget target(getContext());
|
2020-07-11 01:03:44 +08:00
|
|
|
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
|
|
|
AffineDialect>();
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
auto func = getFunction();
|
|
|
|
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
2020-10-27 21:55:28 +08:00
|
|
|
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
2020-07-07 04:57:00 +08:00
|
|
|
signalPassFailure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-07-29 07:12:08 +08:00
|
|
|
struct HloLegalizeToLinalgPass
|
|
|
|
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
2020-08-26 11:30:05 +08:00
|
|
|
void getDependentDialects(DialectRegistry& registry) const override {
|
|
|
|
registry.insert<linalg::LinalgDialect>();
|
|
|
|
}
|
|
|
|
|
2020-07-07 04:57:00 +08:00
|
|
|
void runOnFunction() override {
|
|
|
|
OwningRewritePatternList patterns;
|
|
|
|
ConversionTarget target(getContext());
|
|
|
|
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
|
|
|
|
|
|
|
|
auto func = getFunction();
|
2020-07-07 12:51:24 +08:00
|
|
|
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
2020-10-27 21:55:28 +08:00
|
|
|
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
2020-07-07 04:57:00 +08:00
|
|
|
signalPassFailure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
2020-07-09 01:05:32 +08:00
|
|
|
namespace lmhlo {
|
2020-07-07 04:57:00 +08:00
|
|
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
|
2020-07-29 07:12:08 +08:00
|
|
|
return std::make_unique<LhloLegalizeToLinalgPass>();
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-07-09 01:05:32 +08:00
|
|
|
} // namespace lmhlo
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-07-07 12:51:24 +08:00
|
|
|
namespace mhlo {
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|
|
|
OwningRewritePatternList* patterns) {
|
2020-07-24 00:18:01 +08:00
|
|
|
patterns
|
|
|
|
->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
|
2020-12-10 15:24:23 +08:00
|
|
|
ConstConverter<mhlo::ConstOp>, HloBroadcastInDimConverter,
|
|
|
|
IotaConverter<mhlo::IotaOp, false>,
|
2020-07-24 00:18:01 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
2020-10-05 20:06:35 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
|
2020-07-24 00:18:01 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::CosOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::DivOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
2020-08-31 23:15:32 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::FloorOp, false>,
|
2020-07-24 00:18:01 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
2020-12-08 22:38:26 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
|
2020-07-24 00:18:01 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
2020-09-29 20:58:52 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::NotOp, false>,
|
2020-12-08 22:38:26 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::OrOp, false>,
|
2020-07-24 00:18:01 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
|
2020-12-08 05:01:25 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
|
2020-12-09 04:07:39 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::SignOp, false>,
|
2020-07-24 00:18:01 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::SinOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
|
|
|
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
2020-12-08 22:38:26 +08:00
|
|
|
PointwiseToLinalgConverter<mhlo::XorOp, false>,
|
2020-07-24 00:18:01 +08:00
|
|
|
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
|
|
|
ReverseConverter<mhlo::ReverseOp, false>,
|
|
|
|
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
2020-07-29 07:12:08 +08:00
|
|
|
return std::make_unique<HloLegalizeToLinalgPass>();
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
2020-07-07 12:51:24 +08:00
|
|
|
} // namespace mhlo
|
2020-07-07 04:57:00 +08:00
|
|
|
} // namespace mlir
|