[MLIR][KernelGen] Lower mhlo.log_plus_one to std.log1p
PiperOrigin-RevId: 353200069
This commit is contained in:
parent
f6bf9d5780
commit
56758a9562
|
@ -442,11 +442,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
auto ty = result_types.front().cast<FloatType>();
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Log1pOp>{}(
|
||||||
Value x = args.front();
|
loc, result_types, args, b);
|
||||||
Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0));
|
|
||||||
Value x_plus_one = b->create<AddFOp>(loc, x, one);
|
|
||||||
return b->create<::mlir::LogOp>(loc, x_plus_one);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
|
|
@ -166,6 +166,16 @@ func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @float_log1p
|
||||||
|
func @float_log1p(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: log1p
|
||||||
|
%0 = "mhlo.log_plus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||||
|
return %0 : tensor<2x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @float_ceil
|
// CHECK-LABEL: func @float_ceil
|
||||||
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
|
|
Loading…
Reference in New Issue