[MLIR][KernelGen] Implement InferShapedTypeOpInterface for `mhlo.compare/select`

PiperOrigin-RevId: 332227340
Author: A. Unique TensorFlower, 2020-09-17 07:09:17 -07:00 (committed by TensorFlow MLIR Team)
commit b1fd4d27cf (parent 91f16172a4)
4 changed files with 74 additions and 6 deletions
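
For context: InferShapedTypeOpInterface is the upstream MLIR interface (declared in mlir/Interfaces/InferTypeOpInterface.h) that shape-materialization passes query to turn an op's result shape, including dynamic extents, into SSA values. The following is a minimal caller sketch of how a pass might use the methods added by this change; the helper name reifyResultShape is made up for illustration and is not part of this commit:

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LogicalResult.h"

// Hypothetical helper: ask any op implementing InferShapedTypeOpInterface
// (after this change, mhlo.compare and mhlo.select) to materialize the shapes
// of its results as SSA values.
static mlir::LogicalResult reifyResultShape(
    mlir::Operation* op, mlir::OpBuilder& builder,
    llvm::SmallVectorImpl<mlir::Value>& reified_shapes) {
  auto iface = llvm::dyn_cast<mlir::InferShapedTypeOpInterface>(op);
  if (!iface) return mlir::failure();
  // Dispatches to the reifyReturnTypeShapes overrides defined below.
  return iface.reifyReturnTypeShapes(builder, reified_shapes);
}

This is roughly what the -mhlo-test-infer-shaped-type-methods pass exercised by the tests below appears to do via the mhlo_test.reify_return_type_shapes op.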


@@ -678,9 +678,10 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
let hasCanonicalizer = 1;
}
-def HLO_CompareOp: HLO_Op<"compare",
-    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>,
-    BASE_HLO_CompareOp {
+def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
+    SameOperandsAndResultShape,
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+      ["reifyReturnTypeShapes"]>]>, BASE_HLO_CompareOp {
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
@@ -1152,7 +1153,10 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>,
}
// TODO(jpienaar): Add broadcastable trait.
-def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]>, BASE_HLO_SelectOp {
+def HLO_SelectOp: HLO_Op<"select", [NoSideEffect,
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+      ["reifyReturnTypeShapes"]>, DeclareOpInterfaceMethods<InferTypeOpInterface>,
+  ]>, BASE_HLO_SelectOp {
let arguments = (ins
HLO_PredTensor:$pred,
HLO_Tensor:$on_true,

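A note on the two ODS changes above: DeclareOpInterfaceMethods<InferShapedTypeOpInterface, ["reifyReturnTypeShapes"]> makes tablegen declare the interface hooks on the generated CompareOp and SelectOp classes, which is why the out-of-line C++ definitions in the next file are needed. The sketch below shows roughly what the generated classes gain; it is inferred from those definitions, not copied from the actual tablegen output:

// Approximate declarations added to the generated op classes (illustrative only).
static mlir::LogicalResult inferReturnTypeComponents(
    mlir::MLIRContext* context, llvm::Optional<mlir::Location> location,
    mlir::ValueRange operands, mlir::DictionaryAttr attributes,
    mlir::RegionRange regions,
    llvm::SmallVectorImpl<mlir::ShapedTypeComponents>& inferredReturnShapes);
mlir::LogicalResult reifyReturnTypeShapes(
    mlir::OpBuilder& builder,
    llvm::SmallVectorImpl<mlir::Value>& reifiedReturnShapes);

inferReturnTypeComponents has no default implementation in the interface, so it appears to be declared regardless of the explicit method list (hence the TODO(b/168772852) stubs below), whereas reifyReturnTypeShapes does have a default and therefore must be listed explicitly to be overridden.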

@@ -1678,6 +1678,20 @@ LogicalResult SelectOp::inferReturnTypes(
return success();
}
LogicalResult SelectOp::inferReturnTypeComponents(
mlir::MLIRContext*, llvm::Optional<mlir::Location>, mlir::ValueRange,
mlir::DictionaryAttr, mlir::RegionRange,
llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
// TODO(b/168772852)
return failure();
}
LogicalResult SelectOp::reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//
@@ -2473,9 +2487,22 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
build(builder, result, new_type, lhs, rhs, comparison_direction);
}
LogicalResult CompareOp::inferReturnTypeComponents(
mlir::MLIRContext*, llvm::Optional<mlir::Location>, mlir::ValueRange,
mlir::DictionaryAttr, mlir::RegionRange,
llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
// TODO(b/168772852)
return failure();
}
LogicalResult CompareOp::reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes);
}
} // namespace mhlo
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
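
Both new reifyReturnTypeShapes overrides delegate to the pre-existing deriveShapeFromFirstOperand helper, which this diff does not show. Judging from the CHECK lines in the new test file below, it materializes a rank-1 i64 shape tensor from the first operand: an i64 constant per static dimension and a dim plus index_cast per dynamic dimension, combined with tensor_from_elements. The following is only an illustrative sketch consistent with that output, written as if inside the same mlir::mhlo namespace as the code above; the exact standard-dialect builder overloads used by the real helper are an assumption:

// Illustrative sketch of the helper's behaviour (not the actual implementation).
static LogicalResult deriveShapeFromFirstOperand(
    OpBuilder* builder, Operation* op,
    SmallVectorImpl<Value>* reifiedReturnShapes) {
  Value operand = op->getOperand(0);
  auto operand_type = operand.getType().dyn_cast<ShapedType>();
  if (!operand_type || !operand_type.hasRank()) return failure();
  Location loc = op->getLoc();
  Type shape_scalar_type = builder->getIntegerType(64);
  SmallVector<Value, 4> shape_values;
  for (auto element : llvm::enumerate(operand_type.getShape())) {
    if (element.value() == ShapedType::kDynamicSize) {
      // Dynamic extent: `dim` yields an index, cast it to i64.
      Value dim = builder->create<DimOp>(loc, operand, element.index());
      shape_values.push_back(
          builder->create<IndexCastOp>(loc, shape_scalar_type, dim));
    } else {
      // Static extent: emit it directly as an i64 constant.
      shape_values.push_back(builder->create<ConstantOp>(
          loc, builder->getI64IntegerAttr(element.value())));
    }
  }
  reifiedReturnShapes->push_back(
      builder->create<TensorFromElementsOp>(loc, shape_values));
  return success();
}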


@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -mhlo-test-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt --mhlo-test-infer-shaped-type-methods --allow-unregistered-dialect --split-input-file %s | FileCheck %s
// CHECK-LABEL: @broadcast_add
// Note that all broadcast_ops are expanded from the same template, so


@@ -0,0 +1,37 @@
// RUN: mlir-hlo-opt --mhlo-test-infer-shaped-type-methods --allow-unregistered-dialect --split-input-file %s | FileCheck %s
// -----
// CHECK-LABEL: @select
// CHECK-SAME: (%[[PRED:.*]]: tensor<2x?xi1>,
func @select(%pred : tensor<2x?xi1>, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>)
-> tensor<2xi64> {
// CHECK: %[[C2:.*]] = constant 2 : i64
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[PRED]], %[[C1]] : tensor<2x?xi1>
// CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
// CHECK: return %[[SHAPE]] : tensor<2xi64>
%0 = "mhlo.select"(%pred, %a, %b)
: (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
%1 = "mhlo_test.reify_return_type_shapes"(%0)
: (tensor<2x?xf32>) -> tensor<2xi64>
return %1 : tensor<2xi64>
}
// -----
// CHECK-LABEL: @compare
// CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>,
func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> {
// CHECK: %[[C2:.*]] = constant 2 : i64
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM_AS_INDEX:.*]] = dim %[[A]], %[[C1]] : tensor<2x?xf32>
// CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
// CHECK: return %[[SHAPE]] : tensor<2xi64>
%0 = "mhlo.compare"(%a, %b) { comparison_direction = "NE" }
: (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>
%1 = "mhlo_test.reify_return_type_shapes"(%0)
: (tensor<2x?xi1>) -> tensor<2xi64>
return %1 : tensor<2xi64>
}