OneHotEncoder Shape Inference (#265)

* move scalerop to decompose

* change clang format

* change clang format

* add shape inference for scaler op

* fixing generated onnxop

* generate onnx.md

* redo get onnx.md and onnxop.td.inc using onnx 1.6

* Add shape inference for scaler op

* add benefit for scaler decompose and simplify scaler shape inference

* add scaler decompose benefit num and simplify shape inference

* add cast builder

* cast rewrite only for float

* add cast op same type rewrite rule

* working on cast lowering

* cast lowering working

* add cast lowering

* fix format

* Delete OpBuildTable.inc

* complete requested changes

Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
Anh Leu 2020-08-14 15:13:31 -05:00 committed by GitHub
parent d3dcee7366
commit 00299910f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 1 deletions

View File

@ -2775,6 +2775,45 @@ LogicalResult ONNXDropoutOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// OneHotEncoder
//===----------------------------------------------------------------------===//
LogicalResult ONNXOneHotEncoderOp::inferShapes() {
ShapedType inputType = X().getType().dyn_cast<ShapedType>();
if (!inputType)
return emitError("Non-shaped input type");
auto shape = inputType.getShape();
int64_t outDim = 0;
// If the input is a tensor of float, int32, or double,
// the data will be cast to integers and
// the cats_int64s category list will be used for the lookups.
if (inputType.getElementType().isIntOrFloat()) {
if (!cats_int64s())
return emitError("input is a tensor of float, int32, or double, but no "
"cats_int64s attribute");
outDim = ArrayAttrSize(cats_int64s());
} else {
if (!cats_strings())
return emitError("input is not a tensor of float, int32, or double, but "
"no cats_strings attribute");
outDim = ArrayAttrSize(cats_strings());
}
// Encoded output data, having one more dimension than X
// total category count will determine the size of the extra dimension
SmallVector<int64_t, 2> dims;
for (int i = 0; i != shape.size(); ++i) {
dims.emplace_back(shape[i]);
}
dims.emplace_back(outDim);
getResult().setType(
RankedTensorType::get(dims, FloatType::getF32(getContext())));
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ONNX type related code // ONNX type related code
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -6019,7 +6019,7 @@ def ONNXNormalizerOp:ONNX_Op<"Normalizer",
} }
def ONNXOneHotEncoderOp:ONNX_Op<"OneHotEncoder", def ONNXOneHotEncoderOp:ONNX_Op<"OneHotEncoder",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX OneHotEncoder operation"; let summary = "ONNX OneHotEncoder operation";
let description = [{ let description = [{
"Replace each input element with an array of ones and zeros, where a single" "Replace each input element with an array of ones and zeros, where a single"

View File

@ -1480,3 +1480,51 @@ func @test_dropout(%arg0: tensor<1x2x3x4xf32>) -> (tensor<*xf32>, tensor<*xi1>)
// CHECK: [[RES:%.+]], [[MASK:%.+]] = "onnx.Dropout"(%arg0) {ratio = 1.000000e-01 : f32} : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>) // CHECK: [[RES:%.+]], [[MASK:%.+]] = "onnx.Dropout"(%arg0) {ratio = 1.000000e-01 : f32} : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>)
// CHECK: return [[RES]], [[MASK]] : tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1> // CHECK: return [[RES]], [[MASK]] : tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>
} }
// -----
//===----------------------------------------------------------------------===//
/// Test shape inference for OneHotEncoder.
//===----------------------------------------------------------------------===//
func @test_onehotencoder_string1 (%arg0: tensor<20x1x!onnx.String>) -> tensor<*xf32> {
%0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x1x!onnx.String>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_onehotencoder_string1
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x1x!onnx.String>) -> tensor<20x1x2xf32>
// CHECK: return [[RES]] : tensor<20x1x2xf32>
}
// -----
func @test_onehotencoder_string2 (%arg0: tensor<20x2x!onnx.String>) -> tensor<*xf32> {
%0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x!onnx.String>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_onehotencoder_string2
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x!onnx.String>) -> tensor<20x2x2xf32>
// CHECK: return [[RES]] : tensor<20x2x2xf32>
}
// -----
func @test_onehotencoder_float1(%arg0: tensor<20x1xf32>) -> tensor<*xf32> {
%0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], cats_int64s = [1, 2, 4], zeros = 1 : i64} : (tensor<20x1xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_onehotencoder_float1
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x1xf32>) -> tensor<20x1x3xf32>
// CHECK: return [[RES]] : tensor<20x1x3xf32>
}
// -----
func @test_onehotencoder_float2(%arg0: tensor<20x2x3xf32>) -> tensor<*xf32> {
%0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], cats_int64s = [1, 2, 4], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_onehotencoder_float2
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<20x2x3x3xf32>
// CHECK: return [[RES]] : tensor<20x2x3x3xf32>
}

View File

@ -280,6 +280,7 @@ OpsWithShapeInference=[
'Min', 'Min',
'Mul', 'Mul',
'Neg', 'Neg',
'OneHotEncoder',
'Or', 'Or',
'Pad', 'Pad',
'Pow', 'Pow',