Rename XlaHloDialect class into MhloDialect following the recent dialect namespace renaming

PiperOrigin-RevId: 320213526
This commit is contained in:
Mehdi Amini 2020-07-08 17:19:13 +00:00 committed by Mehdi Amini
parent 94dcb90d38
commit a575636862
7 changed files with 12 additions and 13 deletions

View File

@ -39,9 +39,9 @@ class OpBuilder;
namespace mhlo { namespace mhlo {
class XlaHloDialect : public Dialect { class MhloDialect : public Dialect {
public: public:
explicit XlaHloDialect(MLIRContext *context); explicit MhloDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "mhlo"; } static StringRef getDialectNamespace() { return "mhlo"; }
// Registered hook to materialize a constant operation from a given attribute // Registered hook to materialize a constant operation from a given attribute

View File

@ -18,6 +18,6 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
// Static initialization for XLA dialect registration. // Static initialization for XLA dialect registration.
static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops; static mlir::DialectRegistration<mlir::mhlo::MhloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::chlo::HloClientDialect> chlo_ops; static mlir::DialectRegistration<mlir::chlo::HloClientDialect> chlo_ops;
static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops; static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops;

View File

@ -62,9 +62,8 @@ namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
namespace mhlo { namespace mhlo {
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder, Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
Attribute value, Type type, Type type, Location loc) {
Location loc) {
// HLO dialect constants only support ElementsAttr unlike standard dialect // HLO dialect constants only support ElementsAttr unlike standard dialect
// constant which supports all attributes. // constant which supports all attributes.
if (value.isa<ElementsAttr>()) if (value.isa<ElementsAttr>())
@ -2128,7 +2127,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
// mhlo Dialect Constructor // mhlo Dialect Constructor
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
XlaHloDialect::XlaHloDialect(MLIRContext* context) MhloDialect::MhloDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) { : Dialect(getDialectNamespace(), context) {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
@ -2140,7 +2139,7 @@ XlaHloDialect::XlaHloDialect(MLIRContext* context)
// allowUnknownOperations(); // allowUnknownOperations();
} }
Type XlaHloDialect::parseType(DialectAsmParser& parser) const { Type MhloDialect::parseType(DialectAsmParser& parser) const {
StringRef data_type; StringRef data_type;
if (parser.parseKeyword(&data_type)) return Type(); if (parser.parseKeyword(&data_type)) return Type();
@ -2149,7 +2148,7 @@ Type XlaHloDialect::parseType(DialectAsmParser& parser) const {
return nullptr; return nullptr;
} }
void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const { void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
if (type.isa<TokenType>()) { if (type.isa<TokenType>()) {
os << "token"; os << "token";
return; return;

View File

@ -33,7 +33,7 @@ struct TestChloLegalizeToHloPass
conversionTarget.addIllegalDialect<HloClientDialect>(); conversionTarget.addIllegalDialect<HloClientDialect>();
// Consider the mhlo dialect legal for tests. // Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<mhlo::XlaHloDialect>(); conversionTarget.addLegalDialect<mhlo::MhloDialect>();
// The conversion uses helpers from the Standard dialect. // The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>(); conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>(); conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();

View File

@ -413,7 +413,7 @@ struct HloLegalizeToLhlo
target.addIllegalOp<mlir::TensorStoreOp>(); target.addIllegalOp<mlir::TensorStoreOp>();
target.addLegalOp<ModuleTerminatorOp>(); target.addLegalOp<ModuleTerminatorOp>();
target.addLegalOp<TensorFromElementsOp>(); target.addLegalOp<TensorFromElementsOp>();
target.addIllegalDialect<mhlo::XlaHloDialect>(); target.addIllegalDialect<mhlo::MhloDialect>();
BufferAssignmentTypeConverter converter; BufferAssignmentTypeConverter converter;
auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); }; auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };

View File

@ -34,7 +34,7 @@ struct TestMaterializeBroadcastsPass
OwningRewritePatternList conversionPatterns; OwningRewritePatternList conversionPatterns;
// Consider the mhlo dialect legal for tests. // Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<XlaHloDialect>(); conversionTarget.addLegalDialect<MhloDialect>();
// The conversion uses helpers from the Standard dialect. // The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>(); conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();

View File

@ -152,7 +152,7 @@ struct TransformUnrankedHloPass
// Setup conversion target. // Setup conversion target.
MLIRContext &ctx = getContext(); MLIRContext &ctx = getContext();
ConversionTarget target(ctx); ConversionTarget target(ctx);
target.addLegalDialect<XlaHloDialect, StandardOpsDialect, target.addLegalDialect<MhloDialect, StandardOpsDialect,
shape::ShapeDialect>(); shape::ShapeDialect>();
target.addLegalOp<FuncOp>(); target.addLegalOp<FuncOp>();
AddLegalOpOnRankedTensor<SqrtOp>(&target); AddLegalOpOnRankedTensor<SqrtOp>(&target);