Add chlo.constant_like op which splats a constant to shape of operand
This allows specifying a constant whose shape is only known when operand shape is. Also use it to update tf.Acos legalization. PiperOrigin-RevId: 325860604
This commit is contained in:
parent
e340e41059
commit
5dac76f4af
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
@ -46,6 +47,19 @@ class HloClientDialect : public Dialect {
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
|
||||
|
||||
template <typename T>
|
||||
static Value getConstantLike(OpBuilder& b, T constant, Value val) {
|
||||
Type ty = getElementTypeOrSelf(val.getType());
|
||||
|
||||
auto getAttr = [&]() -> Attribute {
|
||||
if (ty.isa<IntegerType>()) return b.getIntegerAttr(ty, constant);
|
||||
if (ty.isa<FloatType>()) return b.getFloatAttr(ty, constant);
|
||||
llvm_unreachable("unhandled element type");
|
||||
};
|
||||
// TODO(jpienaar): Add ability to pass loc via native call and update.
|
||||
return b.create<ConstantLikeOp>(b.getUnknownLoc(), getAttr(), val);
|
||||
}
|
||||
|
||||
} // namespace chlo
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -364,6 +364,24 @@ def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos",
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_ConstantLikeOp: HLOClient_Op<"constant_like",
|
||||
[NoSideEffect, SameOperandsAndResultShape,
|
||||
InferTypeOpInterface,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
|
||||
NativeOpTrait<"InferTensorType">]> {
|
||||
let summary = "Constant like operator";
|
||||
|
||||
let description = [{
|
||||
Returns a splat constant of the same shape as the operand.
|
||||
}];
|
||||
|
||||
// TODO(jpienaar): value's type could be tightened.
|
||||
let arguments = (ins AnyAttr:$value, HLO_Tensor:$operand);
|
||||
let results = (outs HLO_Tensor);
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Broadcasting compare op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -27,6 +27,9 @@ def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
|
|||
class ConstantSplat<string value> : NativeCodeCall<
|
||||
"hlo::getSplat(&$_builder, $0, " # value # ")">;
|
||||
|
||||
class HLO_ConstantLike<string value> : NativeCodeCall<
|
||||
"chlo::getConstantLike($_builder, " # value # ", $0)">;
|
||||
|
||||
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
||||
|
||||
def BinBroadcastDimensions : NativeCodeCall<
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This is the canonicalize pattern definition file.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
|
||||
|
||||
def UnaryToBinaryEinsumEq : NativeCodeCall<
|
||||
"$_builder.getStringAttr(\",\" + $0.getValue().str())">;
|
||||
|
||||
// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first
|
||||
// operand.
|
||||
def UnaryEinsumToEinsum : Pat<
|
||||
(HLO_UnaryEinsumOp $operand, $equation),
|
||||
(HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
|
||||
$operand, (UnaryToBinaryEinsumEq $equation))>;
|
|
@ -15,10 +15,12 @@ limitations under the License.
|
|||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/utils/broadcast_utils.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
|
@ -259,6 +261,48 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
|
|||
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
|
||||
#undef BROADCAST_BINARY_OP_DEFS
|
||||
|
||||
static LogicalResult Verify(ConstantLikeOp op) {
|
||||
if (op.value().getType() != op.getType().cast<ShapedType>().getElementType())
|
||||
return op.emitOpError() << "value's type doesn't match element return type";
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConstantLikeOp::inferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||
ConstantLikeOp::Adaptor op(operands, attributes);
|
||||
if (failed(op.verify(location.getValue()))) return failure();
|
||||
Type element_type = op.value().getType();
|
||||
Type operand_type = op.operand().getType();
|
||||
if (operand_type.isa<UnrankedTensorType>()) {
|
||||
inferedReturnShapes.emplace_back(element_type);
|
||||
} else {
|
||||
const auto& shape = operand_type.cast<RankedTensorType>().getShape();
|
||||
inferedReturnShapes.emplace_back(shape, element_type);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
struct ConstantLikeToConstant : public OpRewritePattern<ConstantLikeOp> {
|
||||
using OpRewritePattern<ConstantLikeOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ConstantLikeOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
auto op_type = op.operand().getType().cast<ShapedType>();
|
||||
if (!op_type.hasStaticShape()) return failure();
|
||||
auto type = RankedTensorType::get(op_type.getShape(), op.value().getType());
|
||||
ElementsAttr attr = DenseElementsAttr::get(type, op.value());
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConstOp>(op.getOperation(), attr);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void ConstantLikeOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<ConstantLikeToConstant>(context);
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
|
||||
|
||||
|
|
|
@ -191,6 +191,20 @@ func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
|
|||
return %2 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: constant_like_constant
|
||||
func @constant_like_constant(%arg0: tensor<3x4xi32>) -> tensor<3x4xf32> {
|
||||
// CHECK: mhlo.constant dense<3.200000e+00>
|
||||
%0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<3x4xi32>) -> tensor<3x4xf32>
|
||||
return %0 : tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: constant_like_constant_dynamic
|
||||
func @constant_like_constant_dynamic(%arg0: tensor<*xi32>) -> tensor<*xf32> {
|
||||
// CHECK: chlo.constant_like
|
||||
%0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<*xi32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: dynamic_slice_variable_start
|
||||
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
|
||||
// CHECK: "mhlo.dynamic-slice"
|
||||
|
|
Loading…
Reference in New Issue