mlir-hlo/lib/Dialect/mhlo/IR/hlo_ops.cc

2253 lines
82 KiB
C++
Raw Normal View History

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the MHLO dialect.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <functional>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
#include "mlir-hlo/utils/convert_op_folder.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"
namespace mlir {
#include "hlo_patterns.cc.inc"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc"
namespace mhlo {
Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
Type type, Location loc) {
// HLO dialect constants only support ElementsAttr unlike standard dialect
// constant which supports all attributes.
if (value.isa<ElementsAttr>())
return builder.create<mhlo::ConstOp>(loc, type, value.cast<ElementsAttr>());
return nullptr;
}
template <typename T>
static LogicalResult Verify(T op) {
return success();
}
namespace {
//===----------------------------------------------------------------------===//
// Utilities for the canonicalize patterns
//===----------------------------------------------------------------------===//
// Returns 1D 64-bit dense elements attribute with the given values.
DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
Builder* builder) {
RankedTensorType ty = RankedTensorType::get(
{static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
return DenseIntElementsAttr::get(ty, values);
}
// Given the start indices and slice sizes for a dynamic-slice that can be
// converted to a static slice, returns the limits for the static slice.
DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
DenseIntElementsAttr slice_sizes,
Builder* builder) {
SmallVector<int64_t, 4> slice_limits;
for (int64_t i = 0; i < slice_sizes.getNumElements(); ++i) {
int64_t start_index = start_indices.getValue<IntegerAttr>(i).getInt();
int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
slice_limits.push_back(start_index + slice_size);
}
return GetI64ElementsAttr(slice_limits, builder);
}
#include "mhlo_canonicalize.inc"
} // namespace
//===----------------------------------------------------------------------===//
// ConstOp
//===----------------------------------------------------------------------===//
static void Print(ConstOp op, OpAsmPrinter* printer) {
// Print op name.
*printer << op.getOperationName();
// Elide attribute value while printing the attribute dictionary.
SmallVector<StringRef, 1> elided_attrs;
elided_attrs.push_back("value");
printer->printOptionalAttrDict(op.getAttrs(), elided_attrs);
*printer << ' ' << op.value();
}
static ParseResult ParseConstOp(OpAsmParser* parser, OperationState* result) {
if (parser->parseOptionalAttrDict(result->attributes)) return failure();
// If colon is not present after attribute dictionary, it should be short form
// and attribute 'value' is outside the dictionary.
if (failed(parser->parseOptionalColon())) {
Attribute value;
if (parser->parseAttribute(value, "value", result->attributes))
return failure();
return parser->addTypeToList(value.getType(), result->types);
}
// Long form should have type of the result after colon.
Type ty;
if (parser->parseType(ty)) return failure();
result->types.push_back(ty);
return success();
}
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
// Return the held attribute value.
return value();
}
// Builds a constant op with the specified attribute `value`.
void ConstOp::build(OpBuilder& builder, OperationState& result,
Attribute value) {
Type type;
if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
type = elemAttr.getType();
} else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
value.isa<IntegerAttr>()) {
// All XLA types must be tensor types. In the build() method, we want to
// provide more flexibility by allowing attributes of scalar types. But we
// need to wrap it up with ElementsAttr to construct valid XLA constants.
type = RankedTensorType::get(/*shape=*/{}, value.getType());
value = DenseElementsAttr::get(type.cast<TensorType>(), value);
}
// TODO: support other XLA specific types.
assert(type && "unsupported attribute type for building mhlo.constant");
result.types.push_back(type);
result.addAttribute("value", value);
}
//===----------------------------------------------------------------------===//
// DotGeneralOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(DotGeneralOp op) {
auto dot_dimension_numbers = op.dot_dimension_numbers();
int64_t lhs_batching_dimensions_size = llvm::size(
dot_dimension_numbers.lhs_batching_dimensions().getValues<int64_t>());
int64_t rhs_batching_dimensions_size = llvm::size(
dot_dimension_numbers.rhs_batching_dimensions().getValues<int64_t>());
if (lhs_batching_dimensions_size != rhs_batching_dimensions_size) {
return op.emitError()
<< "lhs and rhs should have the same number of batching dimensions";
}
int64_t lhs_contracting_dimensions_size = llvm::size(
dot_dimension_numbers.lhs_contracting_dimensions().getValues<int64_t>());
int64_t rhs_contracting_dimensions_size = llvm::size(
dot_dimension_numbers.rhs_contracting_dimensions().getValues<int64_t>());
if (lhs_contracting_dimensions_size != rhs_contracting_dimensions_size) {
return op.emitError() << "lhs and rhs should have the same number of "
"contracting dimensions";
}
return success();
}
//===----------------------------------------------------------------------===//
// IotaOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(IotaOp op) {
auto shape = op.getType().cast<ShapedType>();
if (!shape.hasRank()) return success();
if (shape.getRank() == 0)
return op.emitOpError() << "does not support scalars.";
auto iota_dimension = op.iota_dimension().getSExtValue();
if (iota_dimension >= shape.getRank() || iota_dimension < 0)
return op.emitOpError() << "iota dimension cannot go beyond the output "
"rank or be negative.";
return success();
}
// Iota operations across multiple dimensions can be reduced to an iota and a
// ranked broadcast.
struct IotaBroadcast : public OpRewritePattern<IotaOp> {
using OpRewritePattern<IotaOp>::OpRewritePattern;
LogicalResult matchAndRewrite(IotaOp iota,
PatternRewriter& rewriter) const override {
auto result_ty = iota.getType().cast<ShapedType>();
if (!result_ty.hasRank() || result_ty.getRank() < 2) {
return failure();
}
auto iota_dimension = iota.iota_dimension();
auto iota_type = RankedTensorType::get(
{result_ty.getDimSize(iota_dimension.getLimitedValue())},
result_ty.getElementType());
auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
rewriter.getI64IntegerAttr(0));
auto broadcast_attr = DenseIntElementsAttr::get(
RankedTensorType::get({1}, rewriter.getIntegerType(64)),
{iota_dimension});
rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, result_ty, new_iota,
broadcast_attr);
return success();
}
};
void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<IotaBroadcast>(context);
}
OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
auto dimension = iota_dimension().getLimitedValue();
auto result_ty = getResult().getType().cast<ShapedType>();
if (result_ty.hasRank() && result_ty.getDimSize(dimension) == 1) {
Builder builder(getContext());
return builder.getZeroAttr(result_ty);
}
return {};
}
//===----------------------------------------------------------------------===//
// DynamicIotaOp
//===----------------------------------------------------------------------===//
namespace {
struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> {
using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicIotaOp iota,
PatternRewriter& rewriter) const override {
auto result_ty = iota.getType().cast<ShapedType>();
if (!result_ty.hasStaticShape()) {
return failure();
}
rewriter.replaceOpWithNewOp<IotaOp>(iota, result_ty, iota.iota_dimension());
return success();
}
};
// Dynamic Iota operations across multiple dimensions can be reduced to an iota
// and a ranked broadcast.
struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicIotaOp iota,
PatternRewriter& rewriter) const override {
auto result_ty = iota.getType().cast<ShapedType>();
if (!result_ty.hasRank() || result_ty.getRank() < 2) {
return failure();
}
auto iota_dimension = iota.iota_dimension();
auto iota_dimension_int = iota_dimension.getLimitedValue();
auto converted_shape = rewriter.create<IndexCastOp>(
iota.getLoc(),
RankedTensorType::get(
iota.output_shape().getType().cast<ShapedType>().getShape(),
rewriter.getI64Type()),
iota.output_shape());
auto sliced_shape = rewriter.create<SliceOp>(
iota.getLoc(), converted_shape,
GetI64ElementsAttr(iota_dimension_int, &rewriter),
GetI64ElementsAttr(iota_dimension_int + 1, &rewriter),
GetI64ElementsAttr(1, &rewriter));
auto converted_sliced_shape = rewriter.create<IndexCastOp>(
iota.getLoc(),
RankedTensorType::get(
{1},
iota.output_shape().getType().cast<ShapedType>().getElementType()),
sliced_shape);
auto iota_type = RankedTensorType::get(
{result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType());
auto new_iota = rewriter.create<DynamicIotaOp>(
iota.getLoc(), iota_type, converted_sliced_shape,
rewriter.getI64IntegerAttr(0));
auto broadcast_attr = DenseIntElementsAttr::get(
RankedTensorType::get({1}, rewriter.getIntegerType(64)),
{iota_dimension});
rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
iota, result_ty, new_iota, iota.output_shape(), broadcast_attr);
return success();
}
};
} // namespace
void DynamicIotaOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicIotaIsStatic>(context);
results.insert<DynamicIotaBroadcast>(context);
}
//===----------------------------------------------------------------------===//
// AbsOp
//===----------------------------------------------------------------------===//
void AbsOp::build(OpBuilder& builder, OperationState& result, Value operand) {
auto shaped_type = operand.getType().cast<ShapedType>();
Type new_type;
if (!shaped_type.getElementType().isa<ComplexType>()) {
new_type = operand.getType();
} else if (shaped_type.hasRank()) {
new_type = RankedTensorType::get(shaped_type.getShape(), operand.getType());
} else {
new_type = UnrankedTensorType::get(operand.getType());
}
return AbsOp::build(builder, result, new_type, operand);
}
//===----------------------------------------------------------------------===//
// CollectivePermuteOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(CollectivePermuteOp op) {
// Check that source target pair is Nx2 tensor.
auto type = op.source_target_pairs().getType().dyn_cast<RankedTensorType>();
if (type.getRank() != 2)
return op.emitError() << "expect source_target_pairs attribute to be of "
"rank 2, but got rank "
<< type.getRank();
if (type.getShape()[1] != 2)
return op.emitError()
<< "expect source_target_pairs attribute of shape (N, 2), but got ("
<< type.getShape() << ")";
// Check source target pairs for duplicate sources or targets
llvm::DenseSet<int64_t> sources;
llvm::DenseSet<int64_t> targets;
for (auto i = op.source_target_pairs().begin(),
e = op.source_target_pairs().end();
i != e; ++i) {
auto val = (*i).getSExtValue();
if (i.getIndex() % 2 == 0) {
bool is_unique = sources.insert(val).second;
if (!is_unique) return op.emitError() << "duplicate sources not allowed.";
} else {
bool is_unique = targets.insert(val).second;
if (!is_unique) return op.emitError() << "duplicate targets not allowed.";
}
}
return success();
}
//===----------------------------------------------------------------------===//
// ConvertOp
//===----------------------------------------------------------------------===//
void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
Type result_element_ty) {
Type result_ty;
Type operand_ty = operand.getType();
if (auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>()) {
result_ty = RankedTensorType::get(ranked_ty.getShape(), result_element_ty);
} else {
result_ty = UnrankedTensorType::get(result_element_ty);
}
build(builder, result, result_ty, operand);
}
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
if (getOperand().getType() == getResult().getType()) return getOperand();
// If the result has non-static shape, a convert op is necessary to go from
// static shape to non-static shape.
if (!getResult().getType().cast<TensorType>().hasStaticShape()) return {};
// If the operand is constant, we can do the conversion now.
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
return hlo::ConvertElementsAttr(elementsAttr,
getElementTypeOrSelf(getResult()));
}
return {};
}
//===----------------------------------------------------------------------===//
// DequantizeOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(DequantizeOp op) {
auto input_type = op.input().getType().dyn_cast<ShapedType>();
auto output_type = op.output().getType().dyn_cast<ShapedType>();
if (!input_type || !output_type) {
return op.emitError() << "ranked input and output.";
}
auto input_shape = input_type.getShape();
auto output_shape = output_type.getShape().vec();
if (op.transpose_output()) {
std::reverse(output_shape.begin(), output_shape.end());
}
// Check the input rank and output rank are same, and also the lower
// dimensions are same.
if (input_shape.size() != output_shape.size() ||
!std::equal(input_shape.begin(),
std::next(input_shape.begin(), input_shape.size() - 1),
output_shape.begin())) {
return op.emitError() << "mismatched dimensions.";
}
// Check that the last dimension of the output is 2x or 4x of that of the
// input depending on the unpacked input is 16 or 8 bits.
int input_last_dim = *input_shape.rbegin();
int output_last_dim = *output_shape.rbegin();
int scale_factor = op.is_16bits() ? 2 : 4;
if (output_last_dim != scale_factor * input_last_dim) {
return op.emitError() << "last dimension of output should be "
<< scale_factor << "x of the input.";
}
return success();
}
//===----------------------------------------------------------------------===//
// GetTupleElementOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(GetTupleElementOp op) {
auto indexVal = op.index().getZExtValue();
auto operandType = op.getOperand().getType().cast<TupleType>();
if (indexVal >= operandType.size()) {
return op.emitOpError(
llvm::formatv("index {0} is out of bounds of operand with size {1}",
indexVal, operandType.size()));
}
auto expectedType = operandType.getType(indexVal);
if (op.getType() != expectedType) {
return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
op.getType(), expectedType));
}
return success();
}
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
if (auto tupleOp =
dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
return tupleOp.getOperand(index().getLimitedValue());
}
return {};
}
//===----------------------------------------------------------------------===//
// TupleOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TupleOp op) {
SmallVector<Type, 8> operandTypes = {op.operand_type_begin(),
op.operand_type_end()};
auto expectedType = TupleType::get(operandTypes, op.getContext());
if (op.getType() != expectedType) {
return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
op.getType(), expectedType));
}
return success();
}
namespace {
// Pattern for unpacking and repacking the same tuple.
struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
using OpRewritePattern<TupleOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TupleOp op,
PatternRewriter& rewriter) const override {
if (op.val().empty()) return failure();
Value first_element = op.val().front();
auto first_element_op =
dyn_cast_or_null<GetTupleElementOp>(first_element.getDefiningOp());
if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
return failure();
Value tuple_predecessor = first_element_op.getOperand();
if (tuple_predecessor.getType() != op.getType()) return failure();
for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) {
auto element_op = dyn_cast_or_null<GetTupleElementOp>(
element_and_idx.value().getDefiningOp());
if (!element_op ||
element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
element_op.getOperand() != tuple_predecessor)
return failure();
}
rewriter.replaceOp(op, tuple_predecessor);
return success();
}
};
} // namespace
void TupleOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<UnpackRepackSameTuple>(context);
}
//===----------------------------------------------------------------------===//
// AllToAllOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(AllToAllOp op) {
// If operand is ranked, size of split dimension should be a multiple of split
// count.
auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
if (!type) return success();
auto split_dim_size = type.getDimSize(op.split_dimension().getSExtValue());
auto split_count = op.split_count().getSExtValue();
if (split_dim_size % split_count != 0) {
return op.emitError() << "split dimension has size " << split_dim_size
<< ", expected to be a multiple of split_count "
<< split_count;
}
return success();
}
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
// TODO(b/129012527) These should be expressed as type constraints.
static LogicalResult Verify(BroadcastOp op) {
auto sizes = op.broadcast_sizes();
auto sizesType = sizes.getType();
auto sizesRank = sizesType.getRank();
if (sizesRank != 1) {
return op.emitOpError(llvm::formatv(
"broadcast_sizes has rank {0} instead of rank 1", sizesRank));
}
auto resultType = op.getResult().getType().cast<RankedTensorType>();
auto resultRank = resultType.getRank();
auto operandType = op.operand().getType().cast<RankedTensorType>();
auto operandRank = operandType.getRank();
auto sizesSize = sizesType.getNumElements();
auto expectedRank = operandRank + sizesSize;
if (resultRank != expectedRank) {
return op.emitOpError(
llvm::formatv("result rank ({0}) does not match operand rank "
"({1}) plus size of broadcast_sizes ({2})",
resultRank, operandRank, sizesSize));
}
llvm::SmallVector<int64_t, 10> expectedShape(sizes.getValues<int64_t>());
auto operandShape = operandType.getShape();
expectedShape.insert(expectedShape.end(), operandShape.begin(),
operandShape.end());
auto resultShape = resultType.getShape();
if (resultShape != llvm::makeArrayRef(expectedShape)) {
return op.emitOpError(llvm::formatv(
"result has shape [{0}] instead of [{1}]",
llvm::make_range(resultShape.begin(), resultShape.end()),
llvm::make_range(expectedShape.begin(), expectedShape.end())));
}
return success();
}
//===----------------------------------------------------------------------===//
// BroadcastInDimOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(BroadcastInDimOp op) {
auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
auto operandRank = operandType.getRank();
if (!op.broadcast_dimensions()) {
if (operandRank == 0) {
return success();
}
return op.emitOpError(
llvm::formatv("broadcast_dimensions is absent, but required because "
"operand has non-zero rank ({0})",
operandRank));
}
auto dimensions = op.broadcast_dimensions();
auto dimensionsType = op.broadcast_dimensions().getType();
auto dimensionsRank = dimensionsType.getRank();
if (dimensionsRank != 1) {
return op.emitOpError(llvm::formatv(
"broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank));
}
auto dimensionsSize = dimensionsType.getNumElements();
if (dimensionsSize != operandRank) {
return op.emitOpError(llvm::formatv(
"broadcast_dimensions size ({0}) does not match operand rank ({1})",
dimensionsSize, operandRank));
}
auto resultType = op.getResult().getType().cast<RankedTensorType>();
auto resultRank = resultType.getRank();
if (resultRank < operandRank) {
return op.emitOpError(
llvm::formatv("result rank ({0}) is less than operand rank ({1})",
resultRank, operandRank));
}
for (int i = 0; i != dimensionsSize; ++i) {
auto dimIndex = dimensions.getValue<int64_t>(i);
if (dimIndex >= resultRank) {
return op.emitOpError(
llvm::formatv("broadcast_dimensions contains invalid value {0} for "
"result result with rank {1}",
dimIndex, resultRank));
}
auto dimSize = operandType.getDimSize(i);
auto resultDimSize = resultType.getDimSize(dimIndex);
if (dimSize != 1 && dimSize != resultDimSize) {
return op.emitOpError(
llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
"1 or size of result dimension {2} ({3})",
i, dimSize, dimIndex, resultDimSize));
}
}
return success();
}
OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute> attrs) {
auto type = getType().cast<RankedTensorType>();
if (type == getOperand().getType()) {
auto broadcast_values = broadcast_dimensions().getValues<int64_t>();
if (!std::equal(broadcast_values.begin(), broadcast_values.end(),
llvm::seq<int64_t>(0, type.getRank()).begin())) {
return {};
}
return getOperand();
}
// Constant fold when an operand is a splat tensor attribute.
if (!attrs[0] || !type.hasStaticShape()) return {};
auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>();
if (!splatOperandAttr) return {};
// MLIR core bug (https://bugs.llvm.org/show_bug.cgi?id=46588): dense element
// attribute iterator not implemented for complex element types.
if (type.getElementType().isa<ComplexType>()) return {};
return SplatElementsAttr::get(type, splatOperandAttr.getSplatValue());
}
//===----------------------------------------------------------------------===//
// DynamicBroadcastInDimOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(DynamicBroadcastInDimOp op) {
auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
// If either the operand or result are unranked, there is very little
// to verify statically.
if (!operandType || !resultType) {
return success();
}
auto outputDimensionsType =
op.output_dimensions().getType().cast<RankedTensorType>();
auto outputDimensionsSize = outputDimensionsType.getDimSize(0);
auto operandRank = operandType.getRank();
auto resultRank = resultType.getRank();
// Verify broadcast_dimensions.
auto bcastDimensions = op.broadcast_dimensions();
auto bcastDimensionsType = op.broadcast_dimensions().getType();
auto bcastDimensionsRank = bcastDimensionsType.getRank();
// TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1.
if (bcastDimensionsRank != 1) {
return op.emitOpError(
llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1",
bcastDimensionsRank));
}
auto bcastDimensionsSize = bcastDimensionsType.getNumElements();
if (bcastDimensionsSize != operandRank) {
return op.emitOpError(llvm::formatv(
"broadcast_dimensions size ({0}) does not match operand rank ({1})",
bcastDimensionsSize, operandRank));
}
if (resultRank < operandRank) {
return op.emitOpError(
llvm::formatv("result rank ({0}) is less than operand rank ({1})",
resultRank, operandRank));
}
for (int i = 0; i != bcastDimensionsSize; ++i) {
auto dimIndex = bcastDimensions.getValue<int64_t>(i);
if (dimIndex >= resultRank) {
return op.emitOpError(
llvm::formatv("broadcast_dimensions contains invalid value {0} for "
"result result with rank {1}",
dimIndex, resultRank));
}
auto dimSize = operandType.getDimSize(i);
auto resultDimSize = resultType.getDimSize(dimIndex);
if (dimSize != 1 && dimSize != resultDimSize) {
return op.emitOpError(
llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
"1 or size of result dimension {2} ({3})",
i, dimSize, dimIndex, resultDimSize));
}
}
if (outputDimensionsSize != resultRank) {
return op.emitOpError(
llvm::formatv("result rank ({0}) is not equal to number of output "
"dimensions ({1})",
resultRank, outputDimensionsSize));
}
return success();
}
// If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary
// BroadcastInDimOp.
class DynamicBroadcastInDimOpNotActuallyDynamic
: public OpRewritePattern<DynamicBroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
PatternRewriter& rewriter) const override {
auto type = op.getType().dyn_cast<RankedTensorType>();
if (!type || !type.hasStaticShape()) {
return rewriter.notifyMatchFailure(op, "requires static shape");
}
rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
op, op.getType(), op.operand(), op.broadcast_dimensions());
return success();
}
};
void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2>(
context);
}
//===----------------------------------------------------------------------===//
// ClampOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(ClampOp op) {
auto operandType = op.operand().getType().cast<RankedTensorType>();
auto operandShape = operandType.getShape();
auto minType = op.min().getType().cast<RankedTensorType>();
auto minShape = minType.getShape();
if (minShape != operandShape && minType.getRank() != 0) {
return op.emitOpError(llvm::formatv(
"min shape [{0}] is not scalar and does not match operand shape [{1}]",
llvm::make_range(minShape.begin(), minShape.end()),
llvm::make_range(operandShape.begin(), operandShape.end())));
}
auto maxType = op.max().getType().cast<RankedTensorType>();
auto maxShape = maxType.getShape();
if (maxShape != operandShape && maxType.getRank() != 0) {
return op.emitOpError(llvm::formatv(
"max shape [{0}] is not scalar and does not match operand shape [{1}]",
llvm::make_range(maxShape.begin(), maxShape.end()),
llvm::make_range(operandShape.begin(), operandShape.end())));
}
return success();
}
//===----------------------------------------------------------------------===//
// ComplexOp
//===----------------------------------------------------------------------===//
void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
Value rhs) {
auto type = lhs.getType();
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
Type result_ty;
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty);
} else if (type.isa<UnrankedTensorType>()) {
result_ty = UnrankedTensorType::get(element_ty);
} else {
result_ty = element_ty;
}
build(builder, state, result_ty, lhs, rhs);
}
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp());
auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp());
if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
return real_op.getOperand();
}
return {};
}
namespace {
Type CreateRealType(Type type) {
auto element_ty = getElementTypeOrSelf(type);
if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
element_ty = complex_ty.getElementType();
}
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
return RankedTensorType::get(ranked_type.getShape(), element_ty);
} else if (type.dyn_cast<UnrankedTensorType>()) {
return UnrankedTensorType::get(element_ty);
}
return element_ty;
}
} // namespace
void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) {
build(builder, state, CreateRealType(val.getType()), val);
}
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
if (auto complex_op =
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
return complex_op.getOperand(1);
}
return {};
}
void RealOp::build(OpBuilder& builder, OperationState& state, Value val) {
build(builder, state, CreateRealType(val.getType()), val);
}
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
if (auto complex_op =
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
return complex_op.getOperand(0);
}
return {};
}
//===----------------------------------------------------------------------===//
// ConcatenateOp
//===----------------------------------------------------------------------===//
namespace {
class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatenateOp op,
PatternRewriter& rewriter) const override {
auto axis = op.dimension().getLimitedValue();
llvm::SmallVector<Value, 6> new_operands;
for (auto operand : op.getOperands()) {
auto ty = operand.getType().cast<ShapedType>();
if (ty.getDimSize(axis) != 0) {
new_operands.push_back(operand);
}
}
if (!new_operands.empty() && new_operands.size() < op.getNumOperands()) {
rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
new_operands, op.dimension());
return success();
}
return failure();
}
};
} // namespace
LogicalResult ConcatenateOp::inferReturnTypes(
MLIRContext*, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type>& inferredReturnTypes) {
if (operands.empty()) {
return failure();
}
auto dimension_attr = attributes.get("dimension").cast<IntegerAttr>();
auto dimension = dimension_attr.getInt();
auto first_type = (*operands.begin()).getType().cast<ShapedType>();
auto out_element = first_type.getElementType();
for (auto operand : operands.getTypes()) {
auto element_type = getElementTypeOrSelf(operand);
if (element_type != out_element) {
return failure();
}
}
// If an input is unranked the output shape is unranked.
if (!first_type.hasRank()) {
inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
return success();
}
auto out_shape = llvm::to_vector<6>(first_type.getShape());
out_shape[dimension] = 0;
for (auto operand : operands.getTypes()) {
auto type = operand.cast<ShapedType>();
if (!type.hasRank()) {
inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
return success();
}
// If the dimension is dynamic we know the output dimension is dynamic.
auto dim = type.getShape()[dimension];
if (dim == -1) {
out_shape[dimension] = -1;
break;
}
out_shape[dimension] += dim;
}
inferredReturnTypes.push_back(RankedTensorType::get(out_shape, out_element));
return success();
}
void ConcatenateOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<ConcatenateOperandRemoval>(context);
}
template <typename T>
static Attribute foldConcatenateHelper(ConcatenateOp* op,
ArrayRef<Attribute> operands) {
auto axis = op->dimension().getLimitedValue();
auto type = op->getType().cast<ShapedType>();
SmallVector<T, 6> values;
auto shape = type.getShape();
size_t top_size = 1;
for (int i = 0, e = axis; i < e; i++) {
top_size = top_size * shape[i];
}
for (size_t i = 0; i < top_size; i++) {
for (auto operand : operands) {
DenseElementsAttr attr = operand.cast<DenseElementsAttr>();
size_t bottom_size = attr.getNumElements() / top_size;
auto iter = attr.getValues<T>().begin() + i * bottom_size;
values.append(iter, iter + bottom_size);
}
}
return DenseElementsAttr::get(type, values);
}
static Attribute foldConcatenate(ConcatenateOp* op,
ArrayRef<Attribute> operands) {
for (auto operand : operands) {
if (!operand) return {};
}
auto type = op->getResult().getType().cast<ShapedType>();
auto etype = type.getElementType();
if (etype.isa<IntegerType>()) {
return foldConcatenateHelper<APInt>(op, operands);
}
if (etype.isa<FloatType>()) {
return foldConcatenateHelper<APFloat>(op, operands);
}
return {};
}
OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
if (getNumOperands() == 1) return getOperand(0);
ShapedType type = getResult().getType().cast<ShapedType>();
if (!type.hasStaticShape()) return {};
auto axis = dimension().getLimitedValue();
if (auto attr = foldConcatenate(this, operands)) {
return attr;
}
llvm::SmallVector<Value, 6> new_operands;
for (auto operand : getOperands()) {
auto ty = operand.getType().cast<ShapedType>();
if (ty.getDimSize(axis) != 0) {
return {};
}
}
return DenseElementsAttr::get(type, ArrayRef<Attribute>());
}
static LogicalResult Verify(ConcatenateOp op) {
Type element_type = getElementTypeOrSelf(op.getOperand(0).getType());
RankedTensorType first_ranked_type;
int num_operands = op.getNumOperands();
for (int i = 0; i < num_operands; i++) {
auto second_type = op.getOperand(i).getType().dyn_cast<ShapedType>();
if (second_type.getElementType() != element_type) {
return op.emitOpError(
llvm::formatv("operands (0) and ({0}) do not match element type", i));
}
if (!second_type.hasRank()) {
continue;
}
if (!first_ranked_type) {
first_ranked_type = second_type.cast<RankedTensorType>();
continue;
}
if (first_ranked_type.getRank() != second_type.getRank()) {
return op.emitOpError(
llvm::formatv("operands (0) and ({0}) do not match rank", i));
}
auto first_shape = second_type.getShape();
auto second_shape = second_type.getShape();
for (int d = 0; d < first_ranked_type.getRank(); ++d) {
if (first_shape[d] != second_shape[d] && d != op.dimension()) {
return op.emitOpError(llvm::formatv(
"operands (0) and ({0}) non-concat dimensions do not match "
"({1}) != ({2})",
i, llvm::make_range(first_shape.begin(), first_shape.end()),
llvm::make_range(second_shape.begin(), second_shape.end())));
}
}
}
return success();
}
//===----------------------------------------------------------------------===//
// DynamicReshapeOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(DynamicReshapeOp op) {
auto result_type = op.result().getType().dyn_cast<RankedTensorType>();
auto output_shape_type =
op.output_shape().getType().dyn_cast<RankedTensorType>();
if (result_type && output_shape_type && output_shape_type.hasStaticShape() &&
output_shape_type.getDimSize(0) != result_type.getRank()) {
return op.emitError() << "output should have a rank equal to the number of "
"elements in output_shape";
}
return success();
}
namespace {
class DynamicReshapeOpNotActuallyDynamic
: public OpRewritePattern<DynamicReshapeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicReshapeOp op,
PatternRewriter& rewriter) const override {
auto type = op.result().getType().dyn_cast<RankedTensorType>();
if (!type || !type.hasStaticShape()) {
return rewriter.notifyMatchFailure(op, "requires static shape tensor");
}
rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), op.operand());
return success();
}
};
} // namespace
void DynamicReshapeOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicReshapeOpNotActuallyDynamic>(context);
}
//===----------------------------------------------------------------------===//
// DynamicSliceOp
//===----------------------------------------------------------------------===//
namespace {
// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops.
// This canonicalization is applied the case when the `begin` input values are
// compile time constants and thus can be made into a tensor.
struct DynamicSliceToSlice : public OpRewritePattern<DynamicSliceOp> {
using OpRewritePattern<DynamicSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicSliceOp dynamic_slice,
PatternRewriter& rewriter) const override {
Value input = dynamic_slice.operand();
auto input_tensor = input.getType().dyn_cast<RankedTensorType>();
if (!input_tensor) return failure();
SmallVector<int64_t, 4> temp_start_indices;
for (Value start : dynamic_slice.start_indices()) {
APInt val;
if (!matchPattern(start, m_ConstantInt(&val))) {
return failure();
}
temp_start_indices.push_back(*(val.getRawData()));
}
// At this point we've determined that the start indices are all constants;
// pack them into a single tensor.
auto loc = dynamic_slice.getLoc();
int64_t input_rank = input_tensor.getRank();
auto slice_start_indices =
GetI64ElementsAttr(temp_start_indices, &rewriter);
DenseIntElementsAttr slice_limits = BuildSliceLimits(
slice_start_indices, dynamic_slice.slice_sizes(), &rewriter);
DenseIntElementsAttr slice_strides =
GetI64ElementsAttr(SmallVector<int64_t, 4>(input_rank, 1), &rewriter);
auto result = rewriter.create<SliceOp>(loc, input, slice_start_indices,
slice_limits, slice_strides);
rewriter.replaceOp(dynamic_slice, {result});
return success();
}
};
} // namespace
void DynamicSliceOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicSliceToSlice>(context);
}
// Verifies that the number of slice sizes and the number of start indices match
static LogicalResult Verify(DynamicSliceOp op) {
int num_slice_sizes = op.slice_sizes().getNumElements();
int num_start_indices = op.start_indices().size();
if (num_start_indices != num_slice_sizes) {
return op.emitOpError()
<< "has mismatched number of slice sizes (" << num_slice_sizes
<< ") and number of start indices (" << num_start_indices << ")";
}
return success();
}
//===----------------------------------------------------------------------===//
// InfeedOp
//===----------------------------------------------------------------------===//
// Checks that the result type is of the form `tuple< any_type, token >`.
static LogicalResult Verify(InfeedOp op) {
auto result_ty = op.getResult().getType().cast<TupleType>();
auto subtypes = result_ty.getTypes();
if (subtypes.size() != 2)
return op.emitOpError()
<< "result is expected to be a tuple of size 2, but got "
<< subtypes.size();
if (!subtypes[1].isa<TokenType>())
return op.emitOpError() << "second element of result tuple is expected to "
"be of token type, but got "
<< subtypes[1];
return success();
}
//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(MapOp op) {
// Checks if the number of `operands` match the arity of the map `computation`
// region.
auto& computation_block = op.computation().front();
auto computation_args = computation_block.getArguments();
if (op.operands().size() != computation_args.size())
return op.emitOpError()
<< "expects number of operands to match the arity "
"of map computation, but got: "
<< op.operands().size() << " and " << computation_args.size();
// The parameters of computation should all be scalars and match the element
// type of operands.
auto operand_type = op.operands()[0].getType().cast<TensorType>();
auto operand_elem_ty = operand_type.getElementType();
for (auto indexed_arg : llvm::enumerate(computation_args)) {
auto arg_type = indexed_arg.value().getType().dyn_cast<TensorType>();
if (!arg_type || arg_type.getRank() != 0)
return op.emitOpError()
<< "computation arguments must be 0-rank tensor, but got: arg #"
<< indexed_arg.index() << " of type "
<< indexed_arg.value().getType();
if (arg_type.getElementType() != operand_elem_ty) {
return op.emitOpError()
<< "element type of operands and computation arguments must "
"match, but got: "
<< operand_elem_ty << " and " << arg_type.getElementType();
}
}
// Mapped computation must return single output
auto computation_outputs = computation_block.getTerminator()->getOperands();
if (computation_outputs.size() != 1)
return op.emitOpError()
<< "computation must return single output, but got: "
<< computation_outputs.size();
// The output of computation must be scalar and have the same element type
// as op result.
auto computation_output_type =
computation_outputs[0].getType().dyn_cast<TensorType>();
if (!computation_output_type || computation_output_type.getRank() != 0)
return op.emitOpError()
<< "computation must return 0-rank tensor, but got: "
<< computation_outputs[0].getType();
auto result_type = op.getType().cast<TensorType>();
if (computation_output_type.getElementType() != result_type.getElementType())
return op.emitOpError() << "element type of result and computation output "
"must match, but got: "
<< result_type.getElementType() << " and "
<< computation_output_type.getElementType();
// Checks that the requested map dimension numbers are monotonically
// increasing.
auto values = op.dimensions().getValues<int64_t>();
auto dimensions = std::vector<int64_t>{values.begin(), values.end()};
for (int i = 0, e = dimensions.size(); i < e; ++i) {
if (dimensions[i] != i)
return op.emitOpError() << "requires monotonically increasing dimension "
"numbers, but got: "
<< op.dimensions();
}
// Checks that number of dimensions of operands matches the size of
// `dimensions` since we currently only support mapping across all
// dimensions: i.e., scalar map functions.
if (operand_type.hasRank()) {
if (dimensions.size() != operand_type.getShape().size())
return op.emitOpError()
<< "applied to a subset of dimensions currently not supported: "
"operand dimensions = "
<< operand_type.getShape().size()
<< ", requested map dimensions size = " << dimensions.size();
}
return success();
}
//===----------------------------------------------------------------------===//
// RecvOp
//===----------------------------------------------------------------------===//
// Checks that the result type is of the form `tuple<any_type, mhlo::token>`
static LogicalResult Verify(RecvOp op) {
auto result_ty = op.getResult().getType().cast<TupleType>();
auto subtypes = result_ty.getTypes();
if (subtypes.size() != 2)
return op.emitOpError()
<< "result is expected to be a tuple of size 2, but got "
<< subtypes.size();
if (!subtypes[1].isa<TokenType>())
return op.emitOpError() << "second element of result tuple is expected to "
"be of token type, but got "
<< subtypes[1];
return success();
}
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
OpFoldResult CopyOp::fold(ArrayRef<Attribute> operands) { return getOperand(); }
//===----------------------------------------------------------------------===//
// ReverseOp
//===----------------------------------------------------------------------===//
OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
auto input = operand();
// No dimensions to reverse.
if (dimensions().getNumElements() == 0) return input;
llvm::SmallVector<APInt, 5> new_dims;
new_dims.reserve(dimensions().getNumElements());
auto shaped_type = input.getType().cast<ShapedType>();
for (auto dim : dimensions().getValues<APInt>()) {
if (shaped_type.getDimSize(dim.getLimitedValue()) != 1) {
return nullptr;
}
}
return input;
}
//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//
// Returns the result type after reducing operand of the given type across the
// specified dimensions.
static TensorType GetReduceResultType(Type operand_ty,
DenseIntElementsAttr dimensions,
Builder* builder) {
Type element_ty = getElementTypeOrSelf(operand_ty);
auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>();
if (!ranked_ty) return UnrankedTensorType::get(element_ty);
int64_t rank = ranked_ty.getRank();
llvm::SmallVector<bool, 4> dims_mask(rank, false);
for (int64_t dim : dimensions.getValues<int64_t>()) dims_mask[dim] = true;
SmallVector<int64_t, 4> shape;
for (int64_t i = 0; i < rank; ++i) {
if (!dims_mask[i]) shape.push_back(ranked_ty.getDimSize(i));
}
return RankedTensorType::get(shape, element_ty);
}
void ReduceOp::build(OpBuilder& builder, OperationState& state,
ValueRange operands, ValueRange init_values,
DenseIntElementsAttr dimensions) {
SmallVector<Type, 1> result_ty;
result_ty.reserve(operands.size());
for (Value operand : operands) {
result_ty.push_back(
GetReduceResultType(operand.getType(), dimensions, &builder));
}
build(builder, state, result_ty, operands, init_values, dimensions);
}
LogicalResult ReduceOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult>& results) {
// No dimensions to reduce.
if (dimensions().getNumElements() == 0) {
for (Value input : this->operands()) {
results.push_back(input);
}
return success();
}
return failure();
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(SelectOp op) {
// TODO(jpienaar): Update to allow broadcastable and unranked inputs. This
// corresponds to the client side HLO.
return success();
}
// Makes it such that a SelectOp that is a non-root operation in a DRR infers
// the return type based on operand type.
LogicalResult SelectOp::inferReturnTypes(
MLIRContext*, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type>& inferredReturnTypes) {
auto x_type = operands[1].getType();
auto y_type = operands[2].getType();
auto x_tensor = x_type.cast<TensorType>();
auto y_tensor = y_type.cast<TensorType>();
// Check for type compatibility in the select op. This requires that the two
// non-predicate operands:
// (a) have the same element type
// (b) have compatible shapes (i.e. the same shape and/or at least one
// dynamic shape)
if (x_tensor.getElementType() != y_tensor.getElementType() ||
failed(mlir::verifyCompatibleShape(x_type, y_type))) {
return emitOptionalError(location, "incompatible operand types: ", x_type,
" and ", y_type);
}
// TODO(lucyfox): Support output shape inference when operands have compatible
// shapes. (The output shape should be the most general of the operand shapes
// at each dimension.) For now, handle the straightforward cases and fail
// otherwise. When this is fully implemented, this logic should move into
// reusable functionality in MLIR Core.
Type output_type;
if (x_type == y_type || !x_tensor.hasRank()) {
output_type = x_type;
} else if (!y_tensor.hasRank()) {
output_type = y_type;
} else {
return emitOptionalError(location,
"currently unsupported operand types: ", x_type,
" and ", y_type);
}
inferredReturnTypes.assign({output_type});
return success();
}
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(PadOp op) {
auto input_type = op.operand().getType().cast<RankedTensorType>();
auto pad_type = op.padding_value().getType().cast<RankedTensorType>();
if (pad_type.getRank() != 0) {
return op.emitOpError(
llvm::formatv("padding value type should be a rank-0 "
"tensor, is rank {0}",
pad_type.getRank()));
}
const auto& padding_low = op.edge_padding_low();
if (padding_low.getType().getNumElements() != input_type.getRank()) {
return op.emitOpError(llvm::formatv(
"edge_padding_low length ({0}) must match operand rank ({1})",
padding_low.getType().getNumElements(), input_type.getRank()));
}
const auto& padding_high = op.edge_padding_high();
if (padding_high.getType().getNumElements() != input_type.getRank()) {
return op.emitOpError(llvm::formatv(
"edge_padding_high length ({0}) must match operand rank ({1})",
padding_high.getType().getNumElements(), input_type.getRank()));
}
const auto& padding_interior = op.interior_padding();
if (padding_interior.getType().getNumElements() != input_type.getRank()) {
return op.emitOpError(llvm::formatv(
"interior_padding length ({0}) must match operand rank ({1})",
padding_interior.getType().getNumElements(), input_type.getRank()));
}
auto input_shape = input_type.getShape();
auto output_shape =
op.getResult().getType().cast<RankedTensorType>().getShape();
if (input_shape.size() != output_shape.size()) {
return op.emitOpError(
llvm::formatv("operand rank ({0}) and result rank({0}) should match",
input_shape.size(), output_shape.size()));
}
for (int i = 0, e = input_shape.size(); i < e; i++) {
int padding_low_val = padding_low.getValue<IntegerAttr>(i).getInt();
int padding_high_val = padding_high.getValue<IntegerAttr>(i).getInt();
int padding_interior_val =
padding_interior.getValue<IntegerAttr>(i).getInt();
int expected_output =
input_shape[i] + padding_low_val + padding_high_val +
std::max<int64_t>(input_shape[i] - 1, 0LL) * padding_interior_val;
if (expected_output != output_shape[i]) {
return op.emitOpError(llvm::formatv(
"expected output shape's dimension #{0} to be {1} but found {2}", i,
expected_output, output_shape[i]));
}
}
return success();
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(ReshapeOp op) {
// If the operand type is dynamically shaped there is nothing to verify.
auto operand_ty = op.operand().getType().dyn_cast<RankedTensorType>();
if (!operand_ty || !operand_ty.hasStaticShape()) return success();
// If the operand type is statically shaped (not required) the number of
// elements must match that of the result type.
auto result_ty = op.getType().cast<RankedTensorType>();
assert(result_ty && result_ty.hasStaticShape() &&
"result type must be statically shaped");
int64_t num_result_elements = result_ty.getNumElements();
int64_t num_operand_elements = operand_ty.getNumElements();
if (num_result_elements != num_operand_elements)
return op.emitOpError()
<< "number of output elements (" << num_result_elements
<< ") doesn't match expected number of elements ("
<< num_operand_elements << ")";
return success();
}
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
if (getOperand().getType() == getType()) {
return getOperand();
}
if (auto prev_op =
dyn_cast_or_null<ReshapeOp>(getOperand().getDefiningOp())) {
setOperand(prev_op.getOperand());
return getResult();
}
if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
return elements.reshape(getResult().getType().cast<ShapedType>());
}
return {};
}
//===----------------------------------------------------------------------===//
// Case Op
//===----------------------------------------------------------------------===//
static LogicalResult Verify(CaseOp op) {
auto num_branches = op.branches().size();
if (op.branch_operands().size() != num_branches)
return op.emitOpError() << "expects number of branches " << num_branches
<< " to be same as number of branch operands "
<< op.branch_operands().size();
MutableArrayRef<Region> branches = op.branches();
OperandRange branch_operands = op.branch_operands();
for (unsigned i = 0; i < num_branches; ++i) {
mlir::Region& branch_region = branches[i];
if (branch_region.empty())
return op.emitOpError() << "cannot have empty regions";
mlir::Block& entry_block = branch_region.front();
if (entry_block.getNumArguments() != 1)
return op.emitOpError()
<< "expects branch regions to have single argument, but found "
<< entry_block.getNumArguments() << " for branch " << i;
auto operand = branch_operands[i];
if (entry_block.getArgument(0).getType() != operand.getType())
return op.emitOpError()
<< "expects operand " << i + 1 << " to be of type "
<< entry_block.getArgument(0).getType() << ", but found "
<< operand.getType();
WalkResult walker = branch_region.walk([&](ReturnOp return_op) {
if (return_op.getOperands().getTypes() != op.getResultTypes())
return WalkResult::interrupt();
return WalkResult::advance();
});
if (walker.wasInterrupted())
return op.emitOpError()
<< "branch " << i
<< " returned values do not match op result types";
}
return success();
}
//===----------------------------------------------------------------------===//
// BinaryOps
//===----------------------------------------------------------------------===//
namespace {
// Updates the element type of a (presumed) tensor type 'x', returning either
// a permuted UnrankedTensorType or RankedTensorType.
static Type UpdateResultElementType(Builder* builder, Type x,
Type element_type) {
auto x_ranked = x.dyn_cast<RankedTensorType>();
if (!x_ranked) {
return UnrankedTensorType::get(element_type);
}
auto shape_x = x_ranked.getShape();
return RankedTensorType::get(shape_x, element_type);
}
} // namespace
template <typename Op, typename ElementType = Type, typename ValType,
typename Convert>
static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
if (!attrs[0] || !attrs[1]) return {};
DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
if (!lhs || !rhs) return {};
ShapedType type = op->getType().template cast<ShapedType>();
if (!type.hasStaticShape()) {
return {};
}
Type etype = type.getElementType();
// Evaluate for integer values.
if (!etype.isa<ElementType>()) {
return {};
}
SmallVector<ValType, 6> values;
values.reserve(lhs.getNumElements());
for (const auto zip :
llvm::zip(lhs.getValues<ValType>(), rhs.getValues<ValType>())) {
values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
}
return DenseElementsAttr::get(type, values);
}
template <typename T>
struct divide : std::divides<T> {};
template <>
struct divide<APInt> {
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
};
template <typename T>
struct max {
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
};
template <>
struct max<APInt> {
APInt operator()(const APInt& a, const APInt& b) const {
return llvm::APIntOps::smax(a, b);
}
};
template <typename T>
struct min {
T operator()(const T& a, const T& b) const { return std::min<T>(a, b); }
};
template <>
struct min<APInt> {
APInt operator()(const APInt& a, const APInt& b) const {
return llvm::APIntOps::smin(a, b);
}
};
#define BINARY_FOLDER(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
return BinaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
if (getElementTypeOrSelf(getType()).isa<IntegerType>()) \
return BinaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs); \
return {}; \
}
// Addition, subtraction and multiplication use the std:: versions of the ops.
// Due to the other ops behaving differently in signed vs unsigned integers,
// APInts need a special implementation. Currently, it replicates signed int
// op behavior.
BINARY_FOLDER(AddOp, std::plus);
BINARY_FOLDER(SubOp, std::minus);
BINARY_FOLDER(MulOp, std::multiplies);
BINARY_FOLDER(DivOp, divide);
BINARY_FOLDER(MaxOp, max);
BINARY_FOLDER(MinOp, min);
#undef BINARY_FOLDER
//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
void SliceOp::build(OpBuilder& builder, OperationState& result, Value operand,
DenseIntElementsAttr start_indices,
DenseIntElementsAttr limit_indices,
DenseIntElementsAttr strides) {
return build(builder, result,
InferOutputTypes(&builder, operand, start_indices, limit_indices,
strides),
operand, start_indices, limit_indices, strides);
}
template <typename I, typename E>
static void SliceElements(I values, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> starts, ArrayRef<int64_t> limits,
ArrayRef<int64_t> strides,
llvm::SmallVectorImpl<E>* out_values) {
assert(starts.size() == limits.size());
assert(starts.size() == strides.size());
if (starts.empty()) return;
int64_t start = starts.front();
int64_t limit = limits.front();
int64_t stride = strides.front();
if (starts.size() == 1) {
for (int i = start; i < limit; i += stride) {
out_values->push_back(*(values + i));
}
return;
}
for (; start < limit; start += stride) {
auto begin = values + start * sizes.front();
SliceElements<I, E>(begin, sizes.drop_front(), starts.drop_front(),
limits.drop_front(), strides.drop_front(), out_values);
}
}
template <typename I, typename E>
static Attribute FoldSlice(SliceOp* op, I values) {
auto start = llvm::to_vector<6>(op->start_indices().getValues<int64_t>());
auto limit = llvm::to_vector<6>(op->limit_indices().getValues<int64_t>());
auto stride = llvm::to_vector<6>(op->strides().getValues<int64_t>());
auto result_type = op->operand().getType().cast<ShapedType>();
if (!result_type.hasStaticShape()) return {};
auto shape = result_type.getShape();
int64_t count = result_type.getNumElements();
// Compute the striding for each dimension.
llvm::SmallVector<int64_t, 6> sizes;
sizes.reserve(shape.size());
for (auto v : shape) {
count = count / v;
sizes.push_back(count);
}
llvm::SmallVector<E, 6> out_values;
out_values.reserve(result_type.getNumElements());
SliceElements<I, E>(values, sizes, start, limit, stride, &out_values);
return DenseElementsAttr::get(op->getResult().getType().cast<ShapedType>(),
out_values);
}
OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
// Check if the SliceOp is a NoOp operation.
auto operand_shape = getOperand().getType().cast<ShapedType>().getShape();
auto result_type = getResult().getType().cast<ShapedType>();
auto result_shape = result_type.getShape();
if (result_type.hasStaticShape() && (operand_shape == result_shape)) {
return getOperand();
}
if (operands.empty() || !operands.front()) return {};
// Evaluate for statically valued inputs.
DenseElementsAttr elements = operands.front().dyn_cast<DenseElementsAttr>();
if (!elements) return {};
auto etype = elements.getType().getElementType();
if (etype.isa<IntegerType>()) {
return FoldSlice<DenseElementsAttr::IntElementIterator, APInt>(
this, elements.getIntValues().begin());
} else if (etype.isa<FloatType>()) {
return FoldSlice<
llvm::mapped_iterator<DenseElementsAttr::IntElementIterator,
std::function<APFloat(const APInt&)>>,
APFloat>(this, elements.getFloatValues().begin());
}
return {};
}
namespace {
// In cases where a concat is fed into a slice, it is possible the concat
// can be simplified or bypassed. This checks which inputs to the concat are
// used by the slice, either reducing the number of concatenated values or
// entirely removes the concat.
struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
using OpRewritePattern<SliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SliceOp slice,
PatternRewriter& rewriter) const override {
auto result_ty = slice.getType().cast<ShapedType>();
if (!result_ty.hasStaticShape()) {
return failure();
}
auto slice_input = slice.operand();
auto slice_input_ty = slice_input.getType().cast<ShapedType>();
auto concat = dyn_cast_or_null<ConcatenateOp>(slice_input.getDefiningOp());
if (!concat) {
return failure();
}
auto dimension = concat.dimension().getSExtValue();
auto start = slice.start_indices().getIntValues();
auto limit = slice.limit_indices().getIntValues();
auto slice_start = (*(start.begin() + dimension)).getSExtValue();
auto slice_limit = (*(limit.begin() + dimension)).getSExtValue();
// We need to determine what inputs from the concat affect the slice, and
// how the bounds of the slice need to be updated for the minimally required
// inputs.
int64_t running_size = 0;
int64_t front_offset = slice_input_ty.getShape()[dimension];
auto subset_start = concat.operand_end();
auto subset_end = concat.operand_end();
for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) {
auto input = *it;
ShapedType input_ty = input.getType().cast<ShapedType>();
if (input_ty.isDynamicDim(dimension)) {
return failure();
}
auto dim_size = input_ty.getShape()[dimension];
// If this position is in the slice its the start of the subset and we
// need to update the start and limit values.
if (running_size + dim_size > slice_start &&
subset_start == concat.operand_end()) {
subset_start = it;
front_offset = running_size;
}
// Determine the last required offset.
if (running_size < slice_limit) {
subset_end = it + 1;
}
running_size += dim_size;
}
auto subset_size = subset_end - subset_start;
// We need all inputs so no optimization.
if (subset_size == concat.getNumOperands()) {
return failure();
}
if (subset_size > 1 && !concat.getResult().hasOneUse()) {
return failure();
}
auto concat_range = OperandRange(subset_start, subset_end);
auto new_concat = rewriter.create<ConcatenateOp>(
concat.getLoc(), concat_range, concat.dimension());
llvm::SmallVector<APInt, 6> new_start(start);
llvm::SmallVector<APInt, 6> new_limit(limit);
new_start[dimension] -= front_offset;
new_limit[dimension] -= front_offset;
auto attr_type = slice.start_indices().getType().cast<ShapedType>();
auto create = rewriter.create<SliceOp>(
slice.getLoc(), new_concat,
DenseIntElementsAttr::get(attr_type, new_start),
DenseIntElementsAttr::get(attr_type, new_limit), slice.strides());
rewriter.replaceOp(slice, create.getResult());
return success();
}
};
} // namespace
void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<SimplifyConcatSlice>(context);
}
// Returns output dimension size for slice result for the given arguments.
// Returns -1 if arguments are illegal.
static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
int64_t stride) {
if (input_dim == -1 || start < 0 || start > end || end > input_dim ||
stride == 0)
return -1;
return llvm::divideCeil(end - start, stride);
}
Type SliceOp::InferOutputTypes(Builder* builder, Value operand,
DenseIntElementsAttr start_indices,
DenseIntElementsAttr limit_indices,
DenseIntElementsAttr strides) {
Type ty = operand.getType();
RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
if (!ranked_ty) return ty;
int64_t rank = ranked_ty.getRank();
// Illegal attributes.
ShapedType attr_ty = start_indices.getType();
if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
!attr_ty.getElementType().isSignlessInteger(64) ||
limit_indices.getType() != attr_ty || strides.getType() != attr_ty)
return ty;
SmallVector<int64_t, 4> start(start_indices.getValues<int64_t>());
SmallVector<int64_t, 4> limit(limit_indices.getValues<int64_t>());
SmallVector<int64_t, 4> stride_vals(strides.getValues<int64_t>());
SmallVector<int64_t, 4> shape;
shape.reserve(rank);
for (int64_t i = 0, e = rank; i != e; i++) {
shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i],
stride_vals[i]));
}
return RankedTensorType::get(shape, ranked_ty.getElementType());
}
//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//
void SortOp::build(OpBuilder& builder, OperationState& state,
ValueRange operands, int64_t dimension, bool is_stable) {
state.addOperands(operands);
state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
state.addAttribute("is_stable", builder.getBoolAttr(dimension));
SmallVector<Type, 2> element_types;
element_types.reserve(operands.size());
for (Value operand : operands) element_types.push_back(operand.getType());
state.addTypes(builder.getTupleType(element_types));
state.addRegion();
}
static LogicalResult Verify(SortOp op) {
Operation::operand_range operands = op.operands();
if (operands.empty()) return op.emitOpError("requires at least one input");
// TODO(antiagainst): verify partionally dynamic shapes
if (llvm::all_of(operands, [](Value operand) {
return operand.getType().cast<ShapedType>().hasRank();
})) {
ArrayRef<int64_t> input_shape =
(*operands.begin()).getType().cast<ShapedType>().getShape();
if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
return operand.getType().cast<ShapedType>().getShape() != input_shape;
}))
return op.emitOpError("requires all inputs to have the same dimensions");
int64_t rank = input_shape.size();
int64_t cmp_dim = op.dimension().getSExtValue();
if (cmp_dim < -rank || cmp_dim >= rank)
return op.emitOpError("dimension attribute value must be in range [-")
<< rank << ", " << rank << "), but found " << cmp_dim;
}
Block& block = op.comparator().front();
size_t num_operands = op.getOperation()->getNumOperands();
if (block.getNumArguments() != 2 * num_operands)
return op.emitOpError("comparator block should have ")
<< 2 * num_operands << " arguments";
for (auto indexed_operand : llvm::enumerate(operands)) {
int index = indexed_operand.index();
Type element_type =
indexed_operand.value().getType().cast<ShapedType>().getElementType();
Type tensor_type = RankedTensorType::get({}, element_type);
for (int i : {2 * index, 2 * index + 1}) {
Type arg_type = block.getArgument(i).getType();
if (arg_type != tensor_type)
return op.emitOpError("comparator block argument #")
<< i << " should be of type " << tensor_type << " but got "
<< arg_type;
}
}
return success();
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
for (auto it : llvm::enumerate(permutation().getValues<APInt>())) {
if (it.index() != it.value()) {
return {};
}
}
return getOperand();
}
static LogicalResult Verify(TransposeOp op) {
// permutation is an attribute of the op so it has static shape.
auto permutationType = op.permutation().getType();
auto permutationRank = permutationType.getRank();
if (permutationRank != 1) {
return op.emitOpError(llvm::formatv(
"permutation has rank {0} instead of rank 1", permutationRank));
}
auto permutationSize = permutationType.getNumElements();
auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
if (operandType) {
auto operandRank = operandType.getRank();
if (operandRank != permutationSize) {
return op.emitOpError(llvm::formatv(
"operand rank ({0}) does not match permutation size ({1})",
operandRank, permutationSize));
}
}
auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
if (resultType) {
auto resultRank = resultType.getRank();
if (resultRank != permutationSize) {
return op.emitOpError(llvm::formatv(
"result rank ({0}) does not match permutation size ({1})", resultRank,
permutationSize));
}
}
if (!resultType || !operandType) return success();
auto operandRank = operandType.getRank();
SmallVector<int64_t, 4> expectedShape(operandRank);
for (int i = 0; i != operandRank; ++i) {
auto permutedDim = op.permutation().getValue<IntegerAttr>(i).getInt();
expectedShape[i] = operandType.getDimSize(permutedDim);
}
auto expectedType =
RankedTensorType::get(expectedShape, resultType.getElementType());
if (failed(verifyCompatibleShape(resultType, expectedType))) {
return op.emitOpError(llvm::formatv(
"result type {0} is incompatible with the expected type {1}",
resultType, expectedType));
}
return success();
}
//===----------------------------------------------------------------------===//
// TriangularSolveOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TriangularSolveOp op) {
auto a_type = op.a().getType().dyn_cast<RankedTensorType>();
// Skip verifier if a is unranked tensor.
if (!a_type) return success();
// Check that a should have rank >= 2
auto a_rank = a_type.getRank();
if (a_rank < 2)
return op.emitOpError()
<< "operand 'a' must have rank >= 2, but got " << a_type;
// The two minor dimensions of a must have same size.
if (a_type.getDimSize(a_rank - 2) != a_type.getDimSize(a_rank - 1))
return op.emitOpError() << "two minor dimensions of operand 'a' must have "
"equal size, but got "
<< a_type;
auto b_type = op.b().getType().dyn_cast<RankedTensorType>();
// If b is unranked skip remaining checks.
if (!b_type) return success();
// Check that a and b have same rank.
auto b_rank = b_type.getRank();
if (a_rank != b_rank)
return op.emitOpError() << "operands must have equal rank, but got "
<< a_type << " and " << b_type;
// The shared dimension of a and b should match.
if (a_type.getDimSize(a_rank - 1) !=
b_type.getDimSize(b_rank - (op.left_side() ? 2 : 1)))
return op.emitOpError() << "shared dimension of operands 'a' and 'b' does "
"not match, but got "
<< a_type << " and " << b_type;
// The leading batch dimensions of a and b must be equal.
auto a_batch_dims = a_type.getShape().drop_back(2);
auto b_batch_dims = b_type.getShape().drop_back(2);
if (a_batch_dims != b_batch_dims)
return op.emitOpError()
<< "leading batch dimensions of the operands must be same, but got "
<< a_type << " and " << b_type;
// Result and argument b must have same shape.
auto result_type = op.getType().dyn_cast<RankedTensorType>();
if (!result_type) return success();
if (result_type != b_type)
return op.emitOpError()
<< "result and operand 'b' must have same shape, but got "
<< result_type << " and " << b_type;
return success();
}
//===----------------------------------------------------------------------===//
// GetTupleElementOp
//===----------------------------------------------------------------------===//
void GetTupleElementOp::build(OpBuilder& builder, OperationState& result,
Value tuple, int32_t index) {
if (auto tuple_type = tuple.getType().dyn_cast<TupleType>()) {
auto element_type = tuple_type.getType(index);
build(builder, result, element_type, tuple,
builder.getI32IntegerAttr(index));
return;
}
build(builder, result, tuple.getType(), tuple,
builder.getI32IntegerAttr(index));
}
//===----------------------------------------------------------------------===//
// TupleOp
//===----------------------------------------------------------------------===//
void TupleOp::build(OpBuilder& builder, OperationState& result,
ValueRange values) {
SmallVector<Type, 4> types;
types.reserve(values.size());
for (auto val : values) {
types.push_back(val.getType());
}
build(builder, result, builder.getTupleType(types), values);
}
//===----------------------------------------------------------------------===//
// UnaryEinsumOp
//===----------------------------------------------------------------------===//
void UnaryEinsumOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<UnaryEinsumToEinsum>(context);
}
//===----------------------------------------------------------------------===//
// CompareOp
//===----------------------------------------------------------------------===//
void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
Value rhs, StringAttr comparison_direction) {
auto new_type =
UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
build(builder, result, new_type, lhs, rhs, comparison_direction);
}
#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
//===----------------------------------------------------------------------===//
// mhlo Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct HLOInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
// We don't have any special restrictions on what can be inlined into
// destination regions (e.g. while/conditional bodies). Always allow it.
bool isLegalToInline(Region* dest, Region* src,
BlockAndValueMapping& valueMapping) const final {
return true;
}
// Operations in mhlo dialect are always legal to inline since they are
// pure.
bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
return true;
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// mhlo Dialect Constructor
//===----------------------------------------------------------------------===//
MhloDialect::MhloDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
>();
addInterfaces<HLOInlinerInterface>();
addTypes<TokenType>();
}
Type MhloDialect::parseType(DialectAsmParser& parser) const {
StringRef data_type;
if (parser.parseKeyword(&data_type)) return Type();
if (data_type == "token") return TokenType::get(getContext());
parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type;
return nullptr;
}
void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
if (type.isa<TokenType>()) {
os << "token";
return;
}
os << "<unknown mhlo type>";
}
//===----------------------------------------------------------------------===//
// Shape inference
//===----------------------------------------------------------------------===//
LogicalResult deriveShapeFromFirstOperand(
OpBuilder* builder, Operation* op,
SmallVectorImpl<Value>* reifiedReturnShapes) {
Value operand = op->getOperand(0);
ShapedType operand_type = operand.getType().dyn_cast<ShapedType>();
if (!operand_type) {
op->emitOpError() << "first operand is not a shaped type";
return failure();
}
auto loc = op->getLoc();
SmallVector<Value, 4> shape_values;
shape_values.reserve(operand_type.getRank());
auto shape_scalar_type = builder->getIntegerType(64);
for (auto element : llvm::enumerate(operand_type.getShape())) {
if (element.value() == ShapedType::kDynamicSize) {
Value dim = builder->create<DimOp>(loc, operand, element.index());
shape_values.push_back(
builder->create<IndexCastOp>(loc, dim, shape_scalar_type));
} else {
shape_values.push_back(builder->create<ConstantOp>(
loc, builder->getI64IntegerAttr(element.value())));
}
}
*reifiedReturnShapes = SmallVector<Value, 1>{
builder->create<TensorFromElementsOp>(loc, shape_values)};
return success();
}
} // namespace mhlo
} // namespace mlir