Add chlo.constant_like op which splats a constant to shape of operand
This allows specifying a constant whose shape is only known when operand shape is. Also use it to update tf.Acos legalization. PiperOrigin-RevId: 325860604
This commit is contained in:
		
							parent
							
								
									e340e41059
								
							
						
					
					
						commit
						5dac76f4af
					
				| 
						 | 
					@ -24,6 +24,7 @@ limitations under the License.
 | 
				
			||||||
#include "mlir/IR/OpDefinition.h"
 | 
					#include "mlir/IR/OpDefinition.h"
 | 
				
			||||||
#include "mlir/IR/Operation.h"
 | 
					#include "mlir/IR/Operation.h"
 | 
				
			||||||
#include "mlir/IR/StandardTypes.h"
 | 
					#include "mlir/IR/StandardTypes.h"
 | 
				
			||||||
 | 
					#include "mlir/IR/TypeUtilities.h"
 | 
				
			||||||
#include "mlir/IR/Types.h"
 | 
					#include "mlir/IR/Types.h"
 | 
				
			||||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
 | 
					#include "mlir/Interfaces/InferTypeOpInterface.h"
 | 
				
			||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
 | 
					#include "mlir/Interfaces/SideEffectInterfaces.h"
 | 
				
			||||||
| 
						 | 
					@ -46,6 +47,19 @@ class HloClientDialect : public Dialect {
 | 
				
			||||||
#define GET_OP_CLASSES
 | 
					#define GET_OP_CLASSES
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
 | 
					#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <typename T>
 | 
				
			||||||
 | 
					static Value getConstantLike(OpBuilder& b, T constant, Value val) {
 | 
				
			||||||
 | 
					  Type ty = getElementTypeOrSelf(val.getType());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto getAttr = [&]() -> Attribute {
 | 
				
			||||||
 | 
					    if (ty.isa<IntegerType>()) return b.getIntegerAttr(ty, constant);
 | 
				
			||||||
 | 
					    if (ty.isa<FloatType>()) return b.getFloatAttr(ty, constant);
 | 
				
			||||||
 | 
					    llvm_unreachable("unhandled element type");
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					  // TODO(jpienaar): Add ability to pass loc via native call and update.
 | 
				
			||||||
 | 
					  return b.create<ConstantLikeOp>(b.getUnknownLoc(), getAttr(), val);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace chlo
 | 
					}  // namespace chlo
 | 
				
			||||||
}  // namespace mlir
 | 
					}  // namespace mlir
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -364,6 +364,24 @@ def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos",
 | 
				
			||||||
  }];
 | 
					  }];
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def HLOClient_ConstantLikeOp: HLOClient_Op<"constant_like",
 | 
				
			||||||
 | 
					    [NoSideEffect, SameOperandsAndResultShape,
 | 
				
			||||||
 | 
					     InferTypeOpInterface,
 | 
				
			||||||
 | 
					     DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
 | 
				
			||||||
 | 
					     NativeOpTrait<"InferTensorType">]> {
 | 
				
			||||||
 | 
					  let summary = "Constant like operator";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let description = [{
 | 
				
			||||||
 | 
					    Returns a splat constant of the same shape as the operand.
 | 
				
			||||||
 | 
					  }];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // TODO(jpienaar): value's type could be tightened.
 | 
				
			||||||
 | 
					  let arguments = (ins AnyAttr:$value, HLO_Tensor:$operand);
 | 
				
			||||||
 | 
					  let results = (outs HLO_Tensor);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let hasCanonicalizer = 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// Broadcasting compare op
 | 
					// Broadcasting compare op
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -27,6 +27,9 @@ def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
 | 
				
			||||||
class ConstantSplat<string value> : NativeCodeCall<
 | 
					class ConstantSplat<string value> : NativeCodeCall<
 | 
				
			||||||
    "hlo::getSplat(&$_builder, $0, " # value # ")">;
 | 
					    "hlo::getSplat(&$_builder, $0, " # value # ")">;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HLO_ConstantLike<string value> : NativeCodeCall<
 | 
				
			||||||
 | 
					    "chlo::getConstantLike($_builder, " # value # ", $0)">;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
 | 
					def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def BinBroadcastDimensions : NativeCodeCall<
 | 
					def BinBroadcastDimensions : NativeCodeCall<
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,30 @@
 | 
				
			||||||
 | 
					/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					==============================================================================*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// This is the canonicalize pattern definition file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					include "mlir/IR/OpBase.td"
 | 
				
			||||||
 | 
					include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
 | 
				
			||||||
 | 
					include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def UnaryToBinaryEinsumEq : NativeCodeCall<
 | 
				
			||||||
 | 
					  "$_builder.getStringAttr(\",\" + $0.getValue().str())">;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first
 | 
				
			||||||
 | 
					// operand.
 | 
				
			||||||
 | 
					def UnaryEinsumToEinsum : Pat<
 | 
				
			||||||
 | 
					  (HLO_UnaryEinsumOp $operand, $equation),
 | 
				
			||||||
 | 
					  (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
 | 
				
			||||||
 | 
					                $operand, (UnaryToBinaryEinsumEq $equation))>;
 | 
				
			||||||
| 
						 | 
					@ -15,10 +15,12 @@ limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 | 
				
			||||||
#include "mlir-hlo/utils/broadcast_utils.h"
 | 
					#include "mlir-hlo/utils/broadcast_utils.h"
 | 
				
			||||||
#include "mlir/IR/Attributes.h"
 | 
					#include "mlir/IR/Attributes.h"
 | 
				
			||||||
#include "mlir/IR/Builders.h"
 | 
					#include "mlir/IR/Builders.h"
 | 
				
			||||||
#include "mlir/IR/Diagnostics.h"
 | 
					#include "mlir/IR/Diagnostics.h"
 | 
				
			||||||
 | 
					#include "mlir/IR/PatternMatch.h"
 | 
				
			||||||
#include "mlir/IR/StandardTypes.h"
 | 
					#include "mlir/IR/StandardTypes.h"
 | 
				
			||||||
#include "mlir/IR/TypeUtilities.h"
 | 
					#include "mlir/IR/TypeUtilities.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -259,6 +261,48 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
 | 
				
			||||||
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
 | 
					#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
 | 
				
			||||||
#undef BROADCAST_BINARY_OP_DEFS
 | 
					#undef BROADCAST_BINARY_OP_DEFS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static LogicalResult Verify(ConstantLikeOp op) {
 | 
				
			||||||
 | 
					  if (op.value().getType() != op.getType().cast<ShapedType>().getElementType())
 | 
				
			||||||
 | 
					    return op.emitOpError() << "value's type doesn't match element return type";
 | 
				
			||||||
 | 
					  return success();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					LogicalResult ConstantLikeOp::inferReturnTypeComponents(
 | 
				
			||||||
 | 
					    MLIRContext* context, Optional<Location> location, ValueRange operands,
 | 
				
			||||||
 | 
					    DictionaryAttr attributes, RegionRange regions,
 | 
				
			||||||
 | 
					    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
 | 
				
			||||||
 | 
					  ConstantLikeOp::Adaptor op(operands, attributes);
 | 
				
			||||||
 | 
					  if (failed(op.verify(location.getValue()))) return failure();
 | 
				
			||||||
 | 
					  Type element_type = op.value().getType();
 | 
				
			||||||
 | 
					  Type operand_type = op.operand().getType();
 | 
				
			||||||
 | 
					  if (operand_type.isa<UnrankedTensorType>()) {
 | 
				
			||||||
 | 
					    inferedReturnShapes.emplace_back(element_type);
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    const auto& shape = operand_type.cast<RankedTensorType>().getShape();
 | 
				
			||||||
 | 
					    inferedReturnShapes.emplace_back(shape, element_type);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return success();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ConstantLikeToConstant : public OpRewritePattern<ConstantLikeOp> {
 | 
				
			||||||
 | 
					  using OpRewritePattern<ConstantLikeOp>::OpRewritePattern;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  LogicalResult matchAndRewrite(ConstantLikeOp op,
 | 
				
			||||||
 | 
					                                PatternRewriter& rewriter) const override {
 | 
				
			||||||
 | 
					    auto op_type = op.operand().getType().cast<ShapedType>();
 | 
				
			||||||
 | 
					    if (!op_type.hasStaticShape()) return failure();
 | 
				
			||||||
 | 
					    auto type = RankedTensorType::get(op_type.getShape(), op.value().getType());
 | 
				
			||||||
 | 
					    ElementsAttr attr = DenseElementsAttr::get(type, op.value());
 | 
				
			||||||
 | 
					    rewriter.replaceOpWithNewOp<mhlo::ConstOp>(op.getOperation(), attr);
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void ConstantLikeOp::getCanonicalizationPatterns(
 | 
				
			||||||
 | 
					    OwningRewritePatternList& results, MLIRContext* context) {
 | 
				
			||||||
 | 
					  results.insert<ConstantLikeToConstant>(context);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#define GET_OP_CLASSES
 | 
					#define GET_OP_CLASSES
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
 | 
					#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -191,6 +191,20 @@ func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
 | 
				
			||||||
  return %2 : tensor<2x2xi32>
 | 
					  return %2 : tensor<2x2xi32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: constant_like_constant
 | 
				
			||||||
 | 
					func @constant_like_constant(%arg0: tensor<3x4xi32>) -> tensor<3x4xf32> {
 | 
				
			||||||
 | 
					  // CHECK: mhlo.constant dense<3.200000e+00>
 | 
				
			||||||
 | 
					  %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<3x4xi32>) -> tensor<3x4xf32>
 | 
				
			||||||
 | 
					  return %0 : tensor<3x4xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: constant_like_constant_dynamic
 | 
				
			||||||
 | 
					func @constant_like_constant_dynamic(%arg0: tensor<*xi32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  // CHECK: chlo.constant_like
 | 
				
			||||||
 | 
					  %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<*xi32>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  return %0 : tensor<*xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: dynamic_slice_variable_start
 | 
					// CHECK-LABEL: dynamic_slice_variable_start
 | 
				
			||||||
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
 | 
					func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
 | 
				
			||||||
  // CHECK: "mhlo.dynamic-slice"
 | 
					  // CHECK: "mhlo.dynamic-slice"
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue