Sync with latest MLIR. (#26)

This commit is contained in:
Tian Jin 2020-01-13 12:21:29 -05:00 committed by GitHub
parent f384e3187e
commit 22a6bdc574
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 94 additions and 94 deletions

View File

@ -680,7 +680,7 @@ private:
auto tensor_val = auto tensor_val =
frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name); frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name);
ret_types.emplace_back(tensor_val->getType()); ret_types.emplace_back(tensor_val.getType());
ret_vals.push_back(tensor_val); ret_vals.push_back(tensor_val);
} }

View File

@ -65,14 +65,14 @@ ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
// Exp // Exp
/// Infer the output shape of the ONNXExpOp. This method is required by the /// Infer the output shape of the ONNXExpOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXExpOp::inferShapes() { getResult()->setType(getOperand()->getType()); } void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Tanh // Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the /// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXTanhOp::inferShapes() { void ONNXTanhOp::inferShapes() {
getResult()->setType(getOperand()->getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -80,7 +80,7 @@ void ONNXTanhOp::inferShapes() {
/// Infer the output shape of the ONNXSinhOp. This method is required by the /// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXSinhOp::inferShapes() { void ONNXSinhOp::inferShapes() {
getResult()->setType(getOperand()->getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -88,27 +88,27 @@ void ONNXSinhOp::inferShapes() {
/// Infer the output shape of the ONNXCoshOp. This method is required by the /// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXCoshOp::inferShapes() { void ONNXCoshOp::inferShapes() {
getResult()->setType(getOperand()->getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Cos // Cos
/// Infer the output shape of the ONNXCosOp. This method is required by the /// Infer the output shape of the ONNXCosOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXCosOp::inferShapes() { getResult()->setType(getOperand()->getType()); } void ONNXCosOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Log // Log
/// Infer the output shape of the ONNXLogOp. This method is required by the /// Infer the output shape of the ONNXLogOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXLogOp::inferShapes() { getResult()->setType(getOperand()->getType()); } void ONNXLogOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// HardSigmoid // HardSigmoid
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by /// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXHardSigmoidOp::inferShapes() { void ONNXHardSigmoidOp::inferShapes() {
getResult()->setType(getOperand()->getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -116,21 +116,21 @@ void ONNXHardSigmoidOp::inferShapes() {
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the /// Infer the output shape of the ONNXSigmoidOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXSigmoidOp::inferShapes() { void ONNXSigmoidOp::inferShapes() {
getResult()->setType(getOperand()->getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Elu // Elu
/// Infer the output shape of the ONNXEluOp. This method is required by the /// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXEluOp::inferShapes() { getResult()->setType(getOperand()->getType()); } void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Relu // Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the /// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXReluOp::inferShapes() { void ONNXReluOp::inferShapes() {
getResult()->setType(getOperand()->getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -138,7 +138,7 @@ void ONNXReluOp::inferShapes() {
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by /// Infer the output shape of the ONNXLeakyReluOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXLeakyReluOp::inferShapes() { void ONNXLeakyReluOp::inferShapes() {
getResult()->setType(getOperand()->getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -146,7 +146,7 @@ void ONNXLeakyReluOp::inferShapes() {
/// Infer the output shape of the ONNXSeluOp. This method is required by /// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSeluOp::inferShapes() { void ONNXSeluOp::inferShapes() {
getResult()->setType(getOperand()->getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -154,7 +154,7 @@ void ONNXSeluOp::inferShapes() {
/// Infer the output shape of the ONNXReciprocalOp. This method is required by /// Infer the output shape of the ONNXReciprocalOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXReciprocalOp::inferShapes() { void ONNXReciprocalOp::inferShapes() {
getResult()->setType(getOperand()->getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -162,12 +162,12 @@ void ONNXReciprocalOp::inferShapes() {
/// Infer the output shape of the ONNXAddOp. This method is required by the /// Infer the output shape of the ONNXAddOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXAddOp::inferShapes() { void ONNXAddOp::inferShapes() {
if (!getOperand(0)->getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1)->getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -175,12 +175,12 @@ void ONNXAddOp::inferShapes() {
/// Infer the output shape of the ONNXMulOp. This method is required by the /// Infer the output shape of the ONNXMulOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXMulOp::inferShapes() { void ONNXMulOp::inferShapes() {
if (!getOperand(0)->getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1)->getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -188,12 +188,12 @@ void ONNXMulOp::inferShapes() {
/// Infer the output shape of the ONNXDivOp. This method is required by the /// Infer the output shape of the ONNXDivOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXDivOp::inferShapes() { void ONNXDivOp::inferShapes() {
if (!getOperand(0)->getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1)->getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -201,12 +201,12 @@ void ONNXDivOp::inferShapes() {
/// Infer the output shape of the ONNXSubOp. This method is required by the /// Infer the output shape of the ONNXSubOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXSubOp::inferShapes() { void ONNXSubOp::inferShapes() {
if (!getOperand(0)->getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1)->getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -214,12 +214,12 @@ void ONNXSubOp::inferShapes() {
/// Infer the output shape of the ONNXAndOp. This method is required by the /// Infer the output shape of the ONNXAndOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXAndOp::inferShapes() { void ONNXAndOp::inferShapes() {
if (!getOperand(0)->getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1)->getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -227,12 +227,12 @@ void ONNXAndOp::inferShapes() {
/// Infer the output shape of the ONNXOrOp. This method is required by the /// Infer the output shape of the ONNXOrOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXOrOp::inferShapes() { void ONNXOrOp::inferShapes() {
if (!getOperand(0)->getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1)->getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -240,12 +240,12 @@ void ONNXOrOp::inferShapes() {
/// Infer the output shape of the ONNXXorOp. This method is required by the /// Infer the output shape of the ONNXXorOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXXorOp::inferShapes() { void ONNXXorOp::inferShapes() {
if (!getOperand(0)->getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1)->getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult()->setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -256,15 +256,15 @@ void ONNXXorOp::inferShapes() {
/// shape inference interface. /// shape inference interface.
void ONNXSumOp::inferShapes() { void ONNXSumOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) { for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i)->getType().cast<RankedTensorType>()) if (!getOperand(i).getType().cast<RankedTensorType>())
return; return;
} }
Type resultTy = getOperand(0)->getType().cast<RankedTensorType>(); Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) { for (int i = 1; i < getNumOperands(); ++i) {
Type nextTy = getOperand(i)->getType().cast<RankedTensorType>(); Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
resultTy = getBroadcastedType(resultTy, nextTy); resultTy = getBroadcastedType(resultTy, nextTy);
} }
getResult()->setType(resultTy); getResult().setType(resultTy);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -273,15 +273,15 @@ void ONNXSumOp::inferShapes() {
/// shape inference interface. /// shape inference interface.
void ONNXMaxOp::inferShapes() { void ONNXMaxOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) { for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i)->getType().cast<RankedTensorType>()) if (!getOperand(i).getType().cast<RankedTensorType>())
return; return;
} }
Type resultTy = getOperand(0)->getType().cast<RankedTensorType>(); Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) { for (int i = 1; i < getNumOperands(); ++i) {
Type nextTy = getOperand(i)->getType().cast<RankedTensorType>(); Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
resultTy = getBroadcastedType(resultTy, nextTy); resultTy = getBroadcastedType(resultTy, nextTy);
} }
getResult()->setType(resultTy); getResult().setType(resultTy);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -290,15 +290,15 @@ void ONNXMaxOp::inferShapes() {
/// shape inference interface. /// shape inference interface.
void ONNXMinOp::inferShapes() { void ONNXMinOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) { for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i)->getType().cast<RankedTensorType>()) if (!getOperand(i).getType().cast<RankedTensorType>())
return; return;
} }
Type resultTy = getOperand(0)->getType().cast<RankedTensorType>(); Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) { for (int i = 1; i < getNumOperands(); ++i) {
Type nextTy = getOperand(i)->getType().cast<RankedTensorType>(); Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
resultTy = getBroadcastedType(resultTy, nextTy); resultTy = getBroadcastedType(resultTy, nextTy);
} }
getResult()->setType(resultTy); getResult().setType(resultTy);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -306,7 +306,7 @@ void ONNXMinOp::inferShapes() {
/// Infer the output shape of the ONNXIdentityOp. This method is required by the /// Infer the output shape of the ONNXIdentityOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXIdentityOp::inferShapes() { void ONNXIdentityOp::inferShapes() {
getResult()->setType(getOperand()->getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -315,15 +315,15 @@ void ONNXIdentityOp::inferShapes() {
void ONNXMatMulOp::inferShapes() { void ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0)->getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1)->getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
dims.emplace_back(lhsTy.getShape()[0]); dims.emplace_back(lhsTy.getShape()[0]);
dims.emplace_back(rhsTy.getShape()[1]); dims.emplace_back(rhsTy.getShape()[1]);
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
} }
// TODO: // TODO:
@ -336,30 +336,30 @@ void ONNXMatMulOp::inferShapes() {
void ONNXGemmOp::inferShapes() { void ONNXGemmOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0)->getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1)->getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
dims.emplace_back(lhsTy.getShape()[0]); dims.emplace_back(lhsTy.getShape()[0]);
dims.emplace_back(rhsTy.getShape()[1]); dims.emplace_back(rhsTy.getShape()[1]);
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
} }
// FullGemm // FullGemm
void ONNXFullGemmOp::inferShapes() { void ONNXFullGemmOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0)->getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1)->getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
dims.emplace_back(lhsTy.getShape()[0]); dims.emplace_back(lhsTy.getShape()[0]);
dims.emplace_back(rhsTy.getShape()[1]); dims.emplace_back(rhsTy.getShape()[1]);
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
} }
// TODO: // TODO:
@ -372,11 +372,11 @@ void ONNXFullGemmOp::inferShapes() {
void ONNXReshapeOp::inferShapes() { void ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified. // Cannot infer shape if no shape tensor is specified.
if (!getOperand(1)->getType().isa<RankedTensorType>()) if (!getOperand(1).getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked."); emitError("Shape tensor not ranked.");
auto inputTensorTy = getOperand(0)->getType().cast<RankedTensorType>(); auto inputTensorTy = getOperand(0).getType().cast<RankedTensorType>();
auto shapeTensorTy = getOperand(1)->getType().cast<RankedTensorType>(); auto shapeTensorTy = getOperand(1).getType().cast<RankedTensorType>();
// Only rank 1 shape tensors are supported. // Only rank 1 shape tensors are supported.
if (shapeTensorTy.getShape().size() != 1) if (shapeTensorTy.getShape().size() != 1)
@ -392,7 +392,7 @@ void ONNXReshapeOp::inferShapes() {
for (int i = 0; i < outputRank; ++i) for (int i = 0; i < outputRank; ++i)
dims.emplace_back(-1); dims.emplace_back(-1);
getResult()->setType( getResult().setType(
RankedTensorType::get(dims, inputTensorTy.getElementType())); RankedTensorType::get(dims, inputTensorTy.getElementType()));
} }
@ -402,16 +402,16 @@ void ONNXReshapeOp::inferShapes() {
void ONNXTransposeOp::inferShapes() { void ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand()->getType().isa<RankedTensorType>()) if (!getOperand().getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked."); emitError("Shape tensor not ranked.");
// Naive transposition which handles the default case of // Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose). // reversing the shape of the tensor (similar to numpy.transpose).
// TODO: Once attributes are supported we can handle the case where the // TODO: Once attributes are supported we can handle the case where the
// transposition uses a permutation vector to interchange the axes. // transposition uses a permutation vector to interchange the axes.
auto arrayTy = getOperand()->getType().cast<RankedTensorType>(); auto arrayTy = getOperand().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape())); SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -61,7 +61,7 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
Value maxDim = nullptr; Value maxDim = nullptr;
for (int i = 0; i < operands.size(); i++) { for (int i = 0; i < operands.size(); i++) {
auto operandShape = auto operandShape =
operands[i]->getType().cast<MemRefType>().getShape(); operands[i].getType().cast<MemRefType>().getShape();
int operandDimIdx = operandShape.size() - 1 - reversedIdx; int operandDimIdx = operandShape.size() - 1 - reversedIdx;
if (operandDimIdx < 0) if (operandDimIdx < 0)
@ -162,7 +162,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
int dimIdx = rank - 1 - reversedIdx; int dimIdx = rank - 1 - reversedIdx;
sharedDimCount[dimIdx] = 0; sharedDimCount[dimIdx] = 0;
for (int i = 0; i < operands.size(); ++i) { for (int i = 0; i < operands.size(); ++i) {
auto shape = operands[i]->getType().cast<MemRefType>().getShape(); auto shape = operands[i].getType().cast<MemRefType>().getShape();
if (reversedIdx <= shape.size() - 1) if (reversedIdx <= shape.size() - 1)
sharedDimCount[dimIdx]++; sharedDimCount[dimIdx]++;
} }
@ -174,7 +174,7 @@ getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
// more than one, since they are potentially broadcasted dimensions. // more than one, since they are potentially broadcasted dimensions.
for (int i = 0; i < operands.size(); ++i) { for (int i = 0; i < operands.size(); ++i) {
std::map<int, Value> broadcastedDims; std::map<int, Value> broadcastedDims;
auto shape = operands[i]->getType().cast<MemRefType>().getShape(); auto shape = operands[i].getType().cast<MemRefType>().getShape();
int size = shape.size(); int size = shape.size();
for (int j = 0; j < shape.size(); ++j) { for (int j = 0; j < shape.size(); ++j) {
if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) { if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) {
@ -198,7 +198,7 @@ getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
std::map<int, Value> broadcastedDims) { std::map<int, Value> broadcastedDims) {
// `operand` must has a ranked type. This should have been checked by the // `operand` must has a ranked type. This should have been checked by the
// shape inference pass. // shape inference pass.
auto operandShape = operand->getType().cast<MemRefType>().getShape(); auto operandShape = operand.getType().cast<MemRefType>().getShape();
auto rank = operandShape.size(); auto rank = operandShape.size();
auto loopCount = loopIVs.size(); auto loopCount = loopIVs.size();
@ -319,7 +319,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
/* Lower UnaryOp to Ops in the Standard dialect. /* Lower UnaryOp to Ops in the Standard dialect.
*/ */
auto loc = op->getLoc(); auto loc = op->getLoc();
Type element_type = operands.front()->getType(); Type element_type = operands.front().getType();
if (element_type.isa<IntegerType>()) { if (element_type.isa<IntegerType>()) {
return rewriter.create<ScalarIOp<UnaryOp>>(loc, result_types, operands, return rewriter.create<ScalarIOp<UnaryOp>>(loc, result_types, operands,
mlir::None); mlir::None);

View File

@ -24,7 +24,7 @@ include "dialect/onnx/onnx.td"
/// dag benefitsAdded = (addBenefit 0) /// dag benefitsAdded = (addBenefit 0)
/// >; /// >;
def HasOneUse : Constraint<CPred<"$0->hasOneUse()">>; def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite // Pattern-Match and Rewrite

View File

@ -55,7 +55,7 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
for (size_t i = 0; i < nestedForOps.size() - 1; i++) { for (size_t i = 0; i < nestedForOps.size() - 1; i++) {
auto iterateIV = iterateOp.bodyRegion().front().getArgument(0); auto iterateIV = iterateOp.bodyRegion().front().getArgument(0);
auto forIV = nestedForOps[i].getBody()->getArgument(0); auto forIV = nestedForOps[i].getBody()->getArgument(0);
iterateIV->replaceAllUsesWith(forIV); iterateIV.replaceAllUsesWith(forIV);
iterateOp.bodyRegion().front().eraseArgument(0); iterateOp.bodyRegion().front().eraseArgument(0);
} }

View File

@ -65,7 +65,7 @@ public:
// First operand. // First operand.
Type dstType = Type dstType =
operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1); operands[0].getType().cast<LLVM::LLVMType>().getStructElementType(1);
Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>( Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1)); loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>( Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
@ -73,7 +73,7 @@ public:
// Second operand. // Second operand.
Type srcType = Type srcType =
operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1); operands[1].getType().cast<LLVM::LLVMType>().getStructElementType(1);
Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>( Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1)); loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>( Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
@ -253,7 +253,7 @@ public:
// Get the first memref returned, convert to a dynamic memref and store // Get the first memref returned, convert to a dynamic memref and store
// it in the wrapped Output. // it in the wrapped Output.
auto outMemRef = outputMemRefs.getResult(0); auto outMemRef = outputMemRefs.getResult(0);
auto outMemRefTy = outMemRef->getType().dyn_cast<LLVMType>(); auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>();
auto outMemRefRank = auto outMemRefRank =
outMemRefTy.getStructElementType(3).getArrayNumElements(); outMemRefTy.getStructElementType(3).getArrayNumElements();
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>( auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(