Add chlo.constant_like op which splats a constant to shape of operand
This allows specifying a constant whose shape is only known when operand shape is. Also use it to update tf.Acos legalization. PiperOrigin-RevId: 325860604
This commit is contained in:
parent
e340e41059
commit
5dac76f4af
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
@ -46,6 +47,19 @@ class HloClientDialect : public Dialect {
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static Value getConstantLike(OpBuilder& b, T constant, Value val) {
|
||||||
|
Type ty = getElementTypeOrSelf(val.getType());
|
||||||
|
|
||||||
|
auto getAttr = [&]() -> Attribute {
|
||||||
|
if (ty.isa<IntegerType>()) return b.getIntegerAttr(ty, constant);
|
||||||
|
if (ty.isa<FloatType>()) return b.getFloatAttr(ty, constant);
|
||||||
|
llvm_unreachable("unhandled element type");
|
||||||
|
};
|
||||||
|
// TODO(jpienaar): Add ability to pass loc via native call and update.
|
||||||
|
return b.create<ConstantLikeOp>(b.getUnknownLoc(), getAttr(), val);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace chlo
|
} // namespace chlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -364,6 +364,24 @@ def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLOClient_ConstantLikeOp: HLOClient_Op<"constant_like",
|
||||||
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
|
InferTypeOpInterface,
|
||||||
|
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
|
||||||
|
NativeOpTrait<"InferTensorType">]> {
|
||||||
|
let summary = "Constant like operator";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns a splat constant of the same shape as the operand.
|
||||||
|
}];
|
||||||
|
|
||||||
|
// TODO(jpienaar): value's type could be tightened.
|
||||||
|
let arguments = (ins AnyAttr:$value, HLO_Tensor:$operand);
|
||||||
|
let results = (outs HLO_Tensor);
|
||||||
|
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Broadcasting compare op
|
// Broadcasting compare op
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -27,6 +27,9 @@ def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
|
||||||
class ConstantSplat<string value> : NativeCodeCall<
|
class ConstantSplat<string value> : NativeCodeCall<
|
||||||
"hlo::getSplat(&$_builder, $0, " # value # ")">;
|
"hlo::getSplat(&$_builder, $0, " # value # ")">;
|
||||||
|
|
||||||
|
class HLO_ConstantLike<string value> : NativeCodeCall<
|
||||||
|
"chlo::getConstantLike($_builder, " # value # ", $0)">;
|
||||||
|
|
||||||
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
||||||
|
|
||||||
def BinBroadcastDimensions : NativeCodeCall<
|
def BinBroadcastDimensions : NativeCodeCall<
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This is the canonicalize pattern definition file.
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
|
||||||
|
|
||||||
|
def UnaryToBinaryEinsumEq : NativeCodeCall<
|
||||||
|
"$_builder.getStringAttr(\",\" + $0.getValue().str())">;
|
||||||
|
|
||||||
|
// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first
|
||||||
|
// operand.
|
||||||
|
def UnaryEinsumToEinsum : Pat<
|
||||||
|
(HLO_UnaryEinsumOp $operand, $equation),
|
||||||
|
(HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
|
||||||
|
$operand, (UnaryToBinaryEinsumEq $equation))>;
|
|
@ -15,10 +15,12 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir-hlo/utils/broadcast_utils.h"
|
#include "mlir-hlo/utils/broadcast_utils.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Diagnostics.h"
|
#include "mlir/IR/Diagnostics.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
|
|
||||||
|
@ -259,6 +261,48 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
|
||||||
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
|
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
|
||||||
#undef BROADCAST_BINARY_OP_DEFS
|
#undef BROADCAST_BINARY_OP_DEFS
|
||||||
|
|
||||||
|
static LogicalResult Verify(ConstantLikeOp op) {
|
||||||
|
if (op.value().getType() != op.getType().cast<ShapedType>().getElementType())
|
||||||
|
return op.emitOpError() << "value's type doesn't match element return type";
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ConstantLikeOp::inferReturnTypeComponents(
|
||||||
|
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||||
|
ConstantLikeOp::Adaptor op(operands, attributes);
|
||||||
|
if (failed(op.verify(location.getValue()))) return failure();
|
||||||
|
Type element_type = op.value().getType();
|
||||||
|
Type operand_type = op.operand().getType();
|
||||||
|
if (operand_type.isa<UnrankedTensorType>()) {
|
||||||
|
inferedReturnShapes.emplace_back(element_type);
|
||||||
|
} else {
|
||||||
|
const auto& shape = operand_type.cast<RankedTensorType>().getShape();
|
||||||
|
inferedReturnShapes.emplace_back(shape, element_type);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ConstantLikeToConstant : public OpRewritePattern<ConstantLikeOp> {
|
||||||
|
using OpRewritePattern<ConstantLikeOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ConstantLikeOp op,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
auto op_type = op.operand().getType().cast<ShapedType>();
|
||||||
|
if (!op_type.hasStaticShape()) return failure();
|
||||||
|
auto type = RankedTensorType::get(op_type.getShape(), op.value().getType());
|
||||||
|
ElementsAttr attr = DenseElementsAttr::get(type, op.value());
|
||||||
|
rewriter.replaceOpWithNewOp<mhlo::ConstOp>(op.getOperation(), attr);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void ConstantLikeOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList& results, MLIRContext* context) {
|
||||||
|
results.insert<ConstantLikeToConstant>(context);
|
||||||
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
|
||||||
|
|
||||||
|
|
|
@ -191,6 +191,20 @@ func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
|
||||||
return %2 : tensor<2x2xi32>
|
return %2 : tensor<2x2xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: constant_like_constant
|
||||||
|
func @constant_like_constant(%arg0: tensor<3x4xi32>) -> tensor<3x4xf32> {
|
||||||
|
// CHECK: mhlo.constant dense<3.200000e+00>
|
||||||
|
%0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<3x4xi32>) -> tensor<3x4xf32>
|
||||||
|
return %0 : tensor<3x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: constant_like_constant_dynamic
|
||||||
|
func @constant_like_constant_dynamic(%arg0: tensor<*xi32>) -> tensor<*xf32> {
|
||||||
|
// CHECK: chlo.constant_like
|
||||||
|
%0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<*xi32>) -> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: dynamic_slice_variable_start
|
// CHECK-LABEL: dynamic_slice_variable_start
|
||||||
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
|
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
|
||||||
// CHECK: "mhlo.dynamic-slice"
|
// CHECK: "mhlo.dynamic-slice"
|
||||||
|
|
Loading…
Reference in New Issue