aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-31 12:24:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-31 13:38:39 -0700
commit2940b6c9ac3518b27633d90880b9092157496ee8 (patch)
treecba06a8dabb4c1cf79faba8267b8eea4e16b2a6d /tensorflow/core/framework/shape_inference.cc
parent32906b8c26608185eb7062b9bb32108b1b416d8a (diff)
Automated rollback of change 137731142
Change: 137740850
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r--tensorflow/core/framework/shape_inference.cc59
1 files changed, 5 insertions, 54 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 4aa32f6a84..da88b6a7ca 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -31,9 +31,7 @@ InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
- const std::vector<ShapeHandle>& input_tensors_as_shapes,
- const std::vector<TensorShapeProto>& input_handle_shapes,
- const std::vector<DataType>& input_handle_dtypes)
+ const std::vector<ShapeHandle>& input_tensors_as_shapes)
: node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
@@ -45,30 +43,19 @@ InferenceContext::InferenceContext(
}
inputs_.push_back(shape);
}
- std::vector<ShapeHandle> handle_shapes;
- for (const auto& p : input_handle_shapes) {
- ShapeHandle shape;
- construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
- if (!construction_status_.ok()) {
- return;
- }
- handle_shapes.push_back(shape);
- }
- PostInputInit(handle_shapes, input_handle_dtypes);
+ PostInputInit();
}
InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
- const std::vector<ShapeHandle>& input_tensors_as_shapes,
- const std::vector<ShapeHandle>& input_handle_shapes,
- const std::vector<DataType>& input_handle_dtypes)
+ const std::vector<ShapeHandle>& input_tensors_as_shapes)
: node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
inputs_ = input_shapes;
- PostInputInit(input_handle_shapes, input_handle_dtypes);
+ PostInputInit();
}
InferenceContext::~InferenceContext() {
@@ -137,44 +124,15 @@ void InferenceContext::PreInputInit(
for (int i = 0; i < num_outputs; ++i) {
outputs_.push_back(nullptr);
}
- output_handle_shape_.reserve(num_outputs);
- for (int i = 0; i < num_outputs; ++i) {
- output_handle_shape_.push_back(UnknownShape());
- }
- output_handle_dtype_ = std::vector<DataType>(num_outputs, DT_INVALID);
}
-void InferenceContext::PostInputInit(
- const std::vector<ShapeHandle>& input_handle_shapes,
- const std::vector<DataType>& input_handle_dtypes) {
+void InferenceContext::PostInputInit() {
int num_inputs_from_node_def = 0;
for (const auto& e : input_name_map_) {
num_inputs_from_node_def =
std::max(num_inputs_from_node_def, e.second.second);
}
- // Allow passing empty shapes/dtypes to avoid changing every single test.
- if (input_handle_shapes.empty()) {
- input_handle_shape_.resize(inputs_.size());
- } else {
- input_handle_shape_ = input_handle_shapes;
- if (input_handle_shape_.size() != inputs_.size()) {
- construction_status_ = errors::InvalidArgument(
- "Wrong number of handle shapes passed; expected ", inputs_.size(),
- " got ", input_handle_shape_.size());
- }
- }
- if (input_handle_dtypes.empty()) {
- input_handle_dtype_ = std::vector<DataType>(inputs_.size(), DT_INVALID);
- } else {
- input_handle_dtype_ = input_handle_dtypes;
- if (input_handle_dtype_.size() != inputs_.size()) {
- construction_status_ = errors::InvalidArgument(
- "Wrong number of handle dtypes passed; expected ", inputs_.size(),
- " got ", input_handle_dtype_.size());
- }
- }
-
if (inputs_.size() != num_inputs_from_node_def) {
construction_status_ = errors::InvalidArgument(
"Wrong number of inputs passed: ", inputs_.size(), " while ",
@@ -779,13 +737,6 @@ Status InferenceContext::AttachContext(const Status& status) {
strings::StrCat(status.error_message(), error_context));
}
-ShapeHandle InferenceContext::input_handle_shape(int idx) {
- if (!input_handle_shape_[idx].IsSet()) {
- input_handle_shape_[idx] = UnknownShape();
- }
- return input_handle_shape_[idx];
-}
-
// -----------------------------------------------------------------------------
// ShapeManager
// -----------------------------------------------------------------------------