OneHotEncoder Shape Inference (#265)
* move scalerop to decompose * change clang format * change clang format * add shape inference for scaler op * fixing generated onnxop * generate onnx.md * redo get onnx.md and onnxop.td.inc using onnx 1.6 * Add shape inference for scaler op * add benefit for scaler decompose and simplify scaler shape inference * add scaler decompose benefit num and simplify shape inference * add cast builder * cast rewrite only for float * add cast op same type rewrite rule * working on cast lowering * cast lowering working * add cast lowering * fix format * Delete OpBuildTable.inc * complete requested changes Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
parent
d3dcee7366
commit
00299910f3
|
@ -2775,6 +2775,45 @@ LogicalResult ONNXDropoutOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OneHotEncoder
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult ONNXOneHotEncoderOp::inferShapes() {
|
||||||
|
ShapedType inputType = X().getType().dyn_cast<ShapedType>();
|
||||||
|
if (!inputType)
|
||||||
|
return emitError("Non-shaped input type");
|
||||||
|
auto shape = inputType.getShape();
|
||||||
|
int64_t outDim = 0;
|
||||||
|
|
||||||
|
// If the input is a tensor of float, int32, or double,
|
||||||
|
// the data will be cast to integers and
|
||||||
|
// the cats_int64s category list will be used for the lookups.
|
||||||
|
if (inputType.getElementType().isIntOrFloat()) {
|
||||||
|
if (!cats_int64s())
|
||||||
|
return emitError("input is a tensor of float, int32, or double, but no "
|
||||||
|
"cats_int64s attribute");
|
||||||
|
outDim = ArrayAttrSize(cats_int64s());
|
||||||
|
} else {
|
||||||
|
if (!cats_strings())
|
||||||
|
return emitError("input is not a tensor of float, int32, or double, but "
|
||||||
|
"no cats_strings attribute");
|
||||||
|
outDim = ArrayAttrSize(cats_strings());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encoded output data, having one more dimension than X
|
||||||
|
// total category count will determine the size of the extra dimension
|
||||||
|
SmallVector<int64_t, 2> dims;
|
||||||
|
for (int i = 0; i != shape.size(); ++i) {
|
||||||
|
dims.emplace_back(shape[i]);
|
||||||
|
}
|
||||||
|
dims.emplace_back(outDim);
|
||||||
|
|
||||||
|
getResult().setType(
|
||||||
|
RankedTensorType::get(dims, FloatType::getF32(getContext())));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ONNX type related code
|
// ONNX type related code
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -6019,7 +6019,7 @@ def ONNXNormalizerOp:ONNX_Op<"Normalizer",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXOneHotEncoderOp:ONNX_Op<"OneHotEncoder",
|
def ONNXOneHotEncoderOp:ONNX_Op<"OneHotEncoder",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX OneHotEncoder operation";
|
let summary = "ONNX OneHotEncoder operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Replace each input element with an array of ones and zeros, where a single"
|
"Replace each input element with an array of ones and zeros, where a single"
|
||||||
|
|
|
@ -1480,3 +1480,51 @@ func @test_dropout(%arg0: tensor<1x2x3x4xf32>) -> (tensor<*xf32>, tensor<*xi1>)
|
||||||
// CHECK: [[RES:%.+]], [[MASK:%.+]] = "onnx.Dropout"(%arg0) {ratio = 1.000000e-01 : f32} : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>)
|
// CHECK: [[RES:%.+]], [[MASK:%.+]] = "onnx.Dropout"(%arg0) {ratio = 1.000000e-01 : f32} : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>)
|
||||||
// CHECK: return [[RES]], [[MASK]] : tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>
|
// CHECK: return [[RES]], [[MASK]] : tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test shape inference for OneHotEncoder.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @test_onehotencoder_string1 (%arg0: tensor<20x1x!onnx.String>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x1x!onnx.String>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_onehotencoder_string1
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x1x!onnx.String>) -> tensor<20x1x2xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<20x1x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @test_onehotencoder_string2 (%arg0: tensor<20x2x!onnx.String>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x!onnx.String>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_onehotencoder_string2
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x!onnx.String>) -> tensor<20x2x2xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<20x2x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @test_onehotencoder_float1(%arg0: tensor<20x1xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], cats_int64s = [1, 2, 4], zeros = 1 : i64} : (tensor<20x1xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_onehotencoder_float1
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x1xf32>) -> tensor<20x1x3xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<20x1x3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @test_onehotencoder_float2(%arg0: tensor<20x2x3xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.OneHotEncoder"(%arg0) {cats_strings = ["female", "male"], cats_int64s = [1, 2, 4], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_onehotencoder_float2
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.OneHotEncoder"(%arg0) {cats_int64s = [1, 2, 4], cats_strings = ["female", "male"], zeros = 1 : i64} : (tensor<20x2x3xf32>) -> tensor<20x2x3x3xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<20x2x3x3xf32>
|
||||||
|
}
|
|
@ -280,6 +280,7 @@ OpsWithShapeInference=[
|
||||||
'Min',
|
'Min',
|
||||||
'Mul',
|
'Mul',
|
||||||
'Neg',
|
'Neg',
|
||||||
|
'OneHotEncoder',
|
||||||
'Or',
|
'Or',
|
||||||
'Pad',
|
'Pad',
|
||||||
'Pow',
|
'Pow',
|
||||||
|
|
Loading…
Reference in New Issue