mlir-hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc

146 lines
5.1 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the LMHLO dialect.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace mlir {
namespace lmhlo {
/// Constructs the LMHLO dialect for `context`, passing the dialect namespace
/// and this class's TypeID to the base `Dialect`, and registers the dialect's
/// operations.
LmhloDialect::LmhloDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context, TypeID::get<LmhloDialect>()) {
// The template arguments of addOperations<> come from the included .inc file:
// with GET_OP_LIST defined it is expected to expand to the list of op classes
// (presumably generated by TableGen -- confirm against the build rules).
addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
>();
}
//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//
/// Verifies the type constraints of an AllReduceOp.
///
/// Two properties are checked, in this order:
///   1. Every operand has exactly the same type as the result at the same
///      index.
///   2. All operands share a single memref element type.
///
/// \param op the operation to verify.
/// \returns `success()` if both properties hold (trivially so when the op has
///          no operands); otherwise the failure produced by the emitted op
///          error describing the first violation found.
static LogicalResult Verify(AllReduceOp op) {
  // AllReduce has variadic operands and results that have the same size.
  // Each member of the operand list should have the same type as the
  // corresponding member of the result list.
  for (auto it : llvm::enumerate(
           llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
    Type operandType = std::get<0>(it.value());
    Type resultType = std::get<1>(it.value());
    if (operandType != resultType)
      return op.emitOpError("requires operand #")
             << it.index() << " (type: " << operandType << ") and result #"
             << it.index() << " (type: " << resultType
             << ") to have same type";
  }

  // With no operands there is no element type to compare against, and calling
  // front() on an empty range is invalid, so the remaining check is vacuously
  // satisfied.
  if (op.operands().empty()) return success();

  // Since AllReduce has a single reduction computation attached to it (which
  // is applied over all the operands and results), they all need to have the
  // same element type. Since we already checked that each operand and its
  // corresponding result have the same type, it is sufficient to check just
  // the memref element type of each operand.
  // NOTE(review): cast<MemRefType> assumes every operand is a memref --
  // presumably guaranteed by the op definition; confirm.
  Type elementType =
      op.operands().front().getType().cast<MemRefType>().getElementType();
  bool allMatch = llvm::all_of(
      op.operands().drop_front().getTypes(), [elementType](Type type) {
        return type.cast<MemRefType>().getElementType() == elementType;
      });
  if (!allMatch)
    return op.emitOpError("requires all operands to have same element type");
  return success();
}
//===----------------------------------------------------------------------===//
// ConstOp.
//===----------------------------------------------------------------------===//
/// Rewrite pattern that erases a ConstOp (lmhlo.constant) whose output memref
/// is defined by an AllocOp and is used by nothing other than this op itself
/// and DeallocOps.
// TODO: This can be generalized to an arbitrary op by making use of memory
// effects (write memory effect).
struct EraseConstOp : public OpRewritePattern<ConstOp> {
  using OpRewritePattern<ConstOp>::OpRewritePattern;

  /// Erases `op` and returns success when the conditions above hold; returns
  /// failure (leaving the IR untouched) otherwise.
  LogicalResult matchAndRewrite(ConstOp op,
                                PatternRewriter& rewriter) const override {
    Value buffer = op.output();

    // Only buffers produced by an AllocOp are candidates.
    if (!buffer.getDefiningOp<AllocOp>()) return failure();

    // Every user of the buffer must be this op or a DeallocOp.
    Operation* self = op.getOperation();
    const bool onlyTrivialUsers =
        llvm::all_of(buffer.getUsers(), [self](Operation* user) {
          return user == self || isa<DeallocOp>(user);
        });
    if (!onlyTrivialUsers) return failure();

    rewriter.eraseOp(op);
    return success();
  }
};
/// Adds the canonicalization patterns for ConstOp to `results`; currently this
/// is only EraseConstOp, constructed with `context`.
void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<EraseConstOp>(context);
}
} // namespace lmhlo
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
namespace mlir {
namespace lmhlo {
// TODO(cheshire): Support folding, reuse code from hlo_ops.cc.
/// Populates `result` for a FusionOp: attaches `attributes`, adds one region,
/// and calls ensureTerminator on that region using `builder` and the location
/// stored in `result`.
void FusionOp::build(OpBuilder& builder, OperationState& result,
                     ArrayRef<NamedAttribute> attributes) {
  result.addAttributes(attributes);

  // The op carries a single region that holds the fused computation.
  Region& body = *result.addRegion();
  FusionOp::ensureTerminator(body, builder, result.location);
}
} // namespace lmhlo
} // namespace mlir