Add chlo.constant_like op which splats a constant to shape of operand

This allows specifying a constant whose shape is only known when operand shape is. Also use it to update tf.Acos legalization.

PiperOrigin-RevId: 325860604
This commit is contained in:
Jacques Pienaar 2020-08-10 12:19:15 -07:00 committed by TensorFlow MLIR Team
parent e340e41059
commit 5dac76f4af
6 changed files with 123 additions and 0 deletions

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
@ -46,6 +47,19 @@ class HloClientDialect : public Dialect {
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
template <typename T>
static Value getConstantLike(OpBuilder& b, T constant, Value val) {
Type ty = getElementTypeOrSelf(val.getType());
auto getAttr = [&]() -> Attribute {
if (ty.isa<IntegerType>()) return b.getIntegerAttr(ty, constant);
if (ty.isa<FloatType>()) return b.getFloatAttr(ty, constant);
llvm_unreachable("unhandled element type");
};
// TODO(jpienaar): Add ability to pass loc via native call and update.
return b.create<ConstantLikeOp>(b.getUnknownLoc(), getAttr(), val);
}
} // namespace chlo } // namespace chlo
} // namespace mlir } // namespace mlir

View File

@ -364,6 +364,24 @@ def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos",
}]; }];
} }
def HLOClient_ConstantLikeOp: HLOClient_Op<"constant_like",
[NoSideEffect, SameOperandsAndResultShape,
InferTypeOpInterface,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
NativeOpTrait<"InferTensorType">]> {
let summary = "Constant like operator";
let description = [{
Returns a splat constant of the same shape as the operand.
}];
// TODO(jpienaar): value's type could be tightened.
let arguments = (ins AnyAttr:$value, HLO_Tensor:$operand);
let results = (outs HLO_Tensor);
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Broadcasting compare op // Broadcasting compare op
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -27,6 +27,9 @@ def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
class ConstantSplat<string value> : NativeCodeCall< class ConstantSplat<string value> : NativeCodeCall<
"hlo::getSplat(&$_builder, $0, " # value # ")">; "hlo::getSplat(&$_builder, $0, " # value # ")">;
class HLO_ConstantLike<string value> : NativeCodeCall<
"chlo::getConstantLike($_builder, " # value # ", $0)">;
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
def BinBroadcastDimensions : NativeCodeCall< def BinBroadcastDimensions : NativeCodeCall<

View File

@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the canonicalize pattern definition file.
include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
def UnaryToBinaryEinsumEq : NativeCodeCall<
"$_builder.getStringAttr(\",\" + $0.getValue().str())">;
// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first
// operand.
def UnaryEinsumToEinsum : Pat<
(HLO_UnaryEinsumOp $operand, $equation),
(HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
$operand, (UnaryToBinaryEinsumEq $equation))>;

View File

@ -15,10 +15,12 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/utils/broadcast_utils.h" #include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
@ -259,6 +261,48 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS #undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
#undef BROADCAST_BINARY_OP_DEFS #undef BROADCAST_BINARY_OP_DEFS
static LogicalResult Verify(ConstantLikeOp op) {
if (op.value().getType() != op.getType().cast<ShapedType>().getElementType())
return op.emitOpError() << "value's type doesn't match element return type";
return success();
}
LogicalResult ConstantLikeOp::inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
ConstantLikeOp::Adaptor op(operands, attributes);
if (failed(op.verify(location.getValue()))) return failure();
Type element_type = op.value().getType();
Type operand_type = op.operand().getType();
if (operand_type.isa<UnrankedTensorType>()) {
inferedReturnShapes.emplace_back(element_type);
} else {
const auto& shape = operand_type.cast<RankedTensorType>().getShape();
inferedReturnShapes.emplace_back(shape, element_type);
}
return success();
}
struct ConstantLikeToConstant : public OpRewritePattern<ConstantLikeOp> {
using OpRewritePattern<ConstantLikeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ConstantLikeOp op,
PatternRewriter& rewriter) const override {
auto op_type = op.operand().getType().cast<ShapedType>();
if (!op_type.hasStaticShape()) return failure();
auto type = RankedTensorType::get(op_type.getShape(), op.value().getType());
ElementsAttr attr = DenseElementsAttr::get(type, op.value());
rewriter.replaceOpWithNewOp<mhlo::ConstOp>(op.getOperation(), attr);
return success();
}
};
void ConstantLikeOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<ConstantLikeToConstant>(context);
}
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"

View File

@ -191,6 +191,20 @@ func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
return %2 : tensor<2x2xi32> return %2 : tensor<2x2xi32>
} }
// CHECK-LABEL: constant_like_constant
func @constant_like_constant(%arg0: tensor<3x4xi32>) -> tensor<3x4xf32> {
// CHECK: mhlo.constant dense<3.200000e+00>
%0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<3x4xi32>) -> tensor<3x4xf32>
return %0 : tensor<3x4xf32>
}
// CHECK-LABEL: constant_like_constant_dynamic
func @constant_like_constant_dynamic(%arg0: tensor<*xi32>) -> tensor<*xf32> {
// CHECK: chlo.constant_like
%0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<*xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: dynamic_slice_variable_start // CHECK-LABEL: dynamic_slice_variable_start
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> { func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
// CHECK: "mhlo.dynamic-slice" // CHECK: "mhlo.dynamic-slice"