2020-07-07 04:57:00 +08:00
|
|
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
==============================================================================*/
|
|
|
|
|
2020-07-29 07:12:08 +08:00
|
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
|
|
#include "mlir/IR/Attributes.h"
|
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
#include "mlir/IR/MLIRContext.h"
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
#include "mlir/IR/StandardTypes.h"
|
|
|
|
#include "mlir/IR/Types.h"
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
namespace mlir {
|
2020-07-07 12:51:24 +08:00
|
|
|
namespace mhlo {
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
// Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If
|
|
|
|
// 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates
|
|
|
|
// a static broadcast.
|
|
|
|
Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
|
|
|
|
Value value_1d, Value shape_value,
|
|
|
|
int64_t feature_dim,
|
|
|
|
PatternRewriter& rewriter) { // NOLINT
|
|
|
|
Builder b(rewriter.getContext());
|
|
|
|
auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
|
|
|
|
auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
|
|
|
|
if (shape_value) {
|
2020-07-07 12:51:24 +08:00
|
|
|
return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
|
2020-07-07 04:57:00 +08:00
|
|
|
loc, result_type, value_1d, shape_value, dims);
|
|
|
|
}
|
|
|
|
assert(result_type.hasStaticShape());
|
2020-07-07 12:51:24 +08:00
|
|
|
return rewriter.create<mhlo::BroadcastInDimOp>(loc, result_type, value_1d,
|
|
|
|
dims);
|
2020-07-07 04:57:00 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Calculate the shape value of operand, assuming it is a dynamic shape with
|
|
|
|
// static rank.
|
|
|
|
Value CalculateShapeValue(Location loc, Value operand,
|
|
|
|
PatternRewriter& rewriter) { // NOLINT
|
|
|
|
RankedTensorType result_type = operand.getType().dyn_cast<RankedTensorType>();
|
|
|
|
llvm::SmallVector<Value, 4> shape_values;
|
|
|
|
int64_t rank = result_type.getRank();
|
|
|
|
shape_values.reserve(rank);
|
|
|
|
for (int64_t i = 0; i < rank; ++i) {
|
|
|
|
shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i));
|
|
|
|
}
|
|
|
|
return rewriter.create<TensorFromElementsOp>(loc, shape_values);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Materializes `epsilon_attr` as a tensor of type `broadcast_to_type` whose
// elements have type `fp_type`, ready to be added to `variance`.
//
// If the attribute's type differs from `fp_type`, its value is converted with
// round-to-nearest-ties-to-even:
//   * Any conversion status other than opOK/opInexact emits a warning on `op`
//     and returns a null Value. No ops have been created at that point.
//   * A successful but lossy conversion only emits a warning.
// A rank-0 mhlo::ConstOp is then created and broadcast with empty
// broadcast_dimensions. The broadcast is static (BroadcastInDimOp) when
// `broadcast_to_type` has a static shape. Otherwise it is dynamic
// (DynamicBroadcastInDimOp) to the runtime shape of `variance`.
Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
                         FloatType fp_type, Value variance,
                         RankedTensorType broadcast_to_type,
                         PatternRewriter& rewriter) {  // NOLINT
  Builder b(rewriter.getContext());
  Location loc = op->getLoc();

  if (epsilon_attr.getType() != fp_type) {
    // Re-encode the epsilon value in the target floating-point semantics.
    APFloat converted = epsilon_attr.getValue();
    bool loses_info;
    auto status = converted.convert(fp_type.getFloatSemantics(),
                                    APFloat::rmNearestTiesToEven, &loses_info);
    // Inexactness is acceptable; any other status bit means failure.
    bool conversion_failed = (status & (~APFloat::opInexact)) != APFloat::opOK;
    if (conversion_failed) {
      op->emitWarning() << "Could not convert batch_norm epsilon to target fp "
                           "type: opStatus = "
                        << static_cast<int>(status);
      return nullptr;
    }
    if (loses_info) op->emitWarning("Conversion of epsilon loses precision");
    epsilon_attr = b.getFloatAttr(fp_type, converted);
  }

  // Scalar (rank-0) constant carrying the epsilon value.
  auto scalar_type = RankedTensorType::get({}, fp_type);
  auto epsilon_value_attr =
      DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
  Value epsilon = rewriter.create<mhlo::ConstOp>(loc, epsilon_value_attr);

  // A scalar operand has no dimensions to map, so broadcast_dimensions is [].
  auto empty_dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
  auto dims =
      DenseIntElementsAttr::get(empty_dims_type, SmallVector<int64_t, 1>{});

  if (!broadcast_to_type.hasStaticShape()) {
    // Dynamic target: take the output shape from `variance` at runtime.
    Value shape_value = CalculateShapeValue(loc, variance, rewriter);
    return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
        loc, broadcast_to_type, epsilon, shape_value,
        /*broadcast_dims=*/dims);
  }
  return rewriter.create<mhlo::BroadcastInDimOp>(
      loc, broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
}
|
|
|
|
|
|
|
|
class UnfuseBatchNormInferencePattern
|
2020-07-07 12:51:24 +08:00
|
|
|
: public OpRewritePattern<mhlo::BatchNormInferenceOp> {
|
2020-07-07 04:57:00 +08:00
|
|
|
public:
|
2020-07-07 12:51:24 +08:00
|
|
|
using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;
|
2020-07-07 04:57:00 +08:00
|
|
|
|
2020-07-07 12:51:24 +08:00
|
|
|
LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,
|
2020-07-07 04:57:00 +08:00
|
|
|
PatternRewriter& rewriter) const override {
|
|
|
|
// Enforce type invariants.
|
|
|
|
// Note that we deduce the actual element type from the variance,
|
|
|
|
// which should not be subject to quantization at a higher level.
|
|
|
|
auto input_type = bn_op.operand().getType().dyn_cast<RankedTensorType>();
|
|
|
|
auto variance_type =
|
|
|
|
bn_op.variance().getType().dyn_cast<RankedTensorType>();
|
|
|
|
if (!input_type || !variance_type) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
auto fp_type = variance_type.getElementType().dyn_cast<FloatType>();
|
|
|
|
if (!fp_type) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
int64_t feature_dim = bn_op.feature_index().getSExtValue();
|
|
|
|
|
|
|
|
// Add epsilon to the variance and sqrt to get stddev:
|
|
|
|
// stddev = sqrt(variance + epsilon)
|
|
|
|
auto epsilon =
|
|
|
|
MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type,
|
|
|
|
bn_op.variance(), variance_type, rewriter);
|
|
|
|
if (!epsilon) {
|
|
|
|
return failure();
|
|
|
|
}
|
2020-07-07 12:51:24 +08:00
|
|
|
Value stddev =
|
|
|
|
rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon);
|
|
|
|
stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
// Broadcast all terms.
|
|
|
|
Value shape_value;
|
|
|
|
if (!input_type.hasStaticShape()) {
|
|
|
|
shape_value =
|
|
|
|
CalculateShapeValue(bn_op.getLoc(), bn_op.operand(), rewriter);
|
|
|
|
}
|
|
|
|
auto broadcast_scale =
|
|
|
|
BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.scale(),
|
|
|
|
shape_value, feature_dim, rewriter);
|
|
|
|
auto broadcast_offset =
|
|
|
|
BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.offset(),
|
|
|
|
shape_value, feature_dim, rewriter);
|
|
|
|
auto broadcast_mean =
|
|
|
|
BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.mean(),
|
|
|
|
shape_value, feature_dim, rewriter);
|
|
|
|
auto broadcast_stddev = BroadcastToFeatureDim(
|
|
|
|
bn_op.getLoc(), input_type, stddev, shape_value, feature_dim, rewriter);
|
|
|
|
|
|
|
|
// Compute:
|
|
|
|
// scale * (input - mean) / stddev + offset
|
2020-07-07 12:51:24 +08:00
|
|
|
Value result = rewriter.create<mhlo::SubOp>(bn_op.getLoc(), bn_op.operand(),
|
|
|
|
broadcast_mean);
|
|
|
|
result =
|
|
|
|
rewriter.create<mhlo::MulOp>(bn_op.getLoc(), result, broadcast_scale);
|
|
|
|
result =
|
|
|
|
rewriter.create<mhlo::DivOp>(bn_op.getLoc(), result, broadcast_stddev);
|
|
|
|
rewriter.replaceOpWithNewOp<mhlo::AddOp>(bn_op, result, broadcast_offset);
|
2020-07-07 04:57:00 +08:00
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Populates conversion patterns to unfuse batch normalization operations.
|
|
|
|
// In combination with marking such ops as illegal, this allows backends that
|
|
|
|
// do not have special support for fused batchnorm to use simpler arithmetic
|
|
|
|
// primitives.
|
|
|
|
void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
|
|
|
|
OwningRewritePatternList* patterns) {
|
|
|
|
patterns->insert<UnfuseBatchNormInferencePattern>(context);
|
|
|
|
}
|
|
|
|
|
2020-07-07 12:51:24 +08:00
|
|
|
} // namespace mhlo
|
2020-07-07 04:57:00 +08:00
|
|
|
} // namespace mlir
|