diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-10-31 12:24:31 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-31 13:38:39 -0700 |
commit | 2940b6c9ac3518b27633d90880b9092157496ee8 (patch) | |
tree | cba06a8dabb4c1cf79faba8267b8eea4e16b2a6d /tensorflow/core/framework/shape_inference.cc | |
parent | 32906b8c26608185eb7062b9bb32108b1b416d8a (diff) |
Automated rollback of change 137731142
Change: 137740850
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 59 |
1 file changed, 5 insertions, 54 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 4aa32f6a84..da88b6a7ca 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -31,9 +31,7 @@ InferenceContext::InferenceContext( const NodeDef* node_def, const OpDef& op_def, const std::vector<TensorShapeProto>& input_shapes, const std::vector<const Tensor*>& input_tensors, - const std::vector<ShapeHandle>& input_tensors_as_shapes, - const std::vector<TensorShapeProto>& input_handle_shapes, - const std::vector<DataType>& input_handle_dtypes) + const std::vector<ShapeHandle>& input_tensors_as_shapes) : node_def_(*CHECK_NOTNULL(node_def)) { PreInputInit(op_def, input_tensors, input_tensors_as_shapes); if (!construction_status_.ok()) return; @@ -45,30 +43,19 @@ InferenceContext::InferenceContext( } inputs_.push_back(shape); } - std::vector<ShapeHandle> handle_shapes; - for (const auto& p : input_handle_shapes) { - ShapeHandle shape; - construction_status_.Update(MakeShapeFromShapeProto(p, &shape)); - if (!construction_status_.ok()) { - return; - } - handle_shapes.push_back(shape); - } - PostInputInit(handle_shapes, input_handle_dtypes); + PostInputInit(); } InferenceContext::InferenceContext( const NodeDef* node_def, const OpDef& op_def, const std::vector<ShapeHandle>& input_shapes, const std::vector<const Tensor*>& input_tensors, - const std::vector<ShapeHandle>& input_tensors_as_shapes, - const std::vector<ShapeHandle>& input_handle_shapes, - const std::vector<DataType>& input_handle_dtypes) + const std::vector<ShapeHandle>& input_tensors_as_shapes) : node_def_(*CHECK_NOTNULL(node_def)) { PreInputInit(op_def, input_tensors, input_tensors_as_shapes); if (!construction_status_.ok()) return; inputs_ = input_shapes; - PostInputInit(input_handle_shapes, input_handle_dtypes); + PostInputInit(); } InferenceContext::~InferenceContext() { @@ -137,44 +124,15 @@ void InferenceContext::PreInputInit( for (int i = 
0; i < num_outputs; ++i) { outputs_.push_back(nullptr); } - output_handle_shape_.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - output_handle_shape_.push_back(UnknownShape()); - } - output_handle_dtype_ = std::vector<DataType>(num_outputs, DT_INVALID); } -void InferenceContext::PostInputInit( - const std::vector<ShapeHandle>& input_handle_shapes, - const std::vector<DataType>& input_handle_dtypes) { +void InferenceContext::PostInputInit() { int num_inputs_from_node_def = 0; for (const auto& e : input_name_map_) { num_inputs_from_node_def = std::max(num_inputs_from_node_def, e.second.second); } - // Allow passing empty shapes/dtypes to avoid changing every single test. - if (input_handle_shapes.empty()) { - input_handle_shape_.resize(inputs_.size()); - } else { - input_handle_shape_ = input_handle_shapes; - if (input_handle_shape_.size() != inputs_.size()) { - construction_status_ = errors::InvalidArgument( - "Wrong number of handle shapes passed; expected ", inputs_.size(), - " got ", input_handle_shape_.size()); - } - } - if (input_handle_dtypes.empty()) { - input_handle_dtype_ = std::vector<DataType>(inputs_.size(), DT_INVALID); - } else { - input_handle_dtype_ = input_handle_dtypes; - if (input_handle_dtype_.size() != inputs_.size()) { - construction_status_ = errors::InvalidArgument( - "Wrong number of handle dtypes passed; expected ", inputs_.size(), - " got ", input_handle_dtype_.size()); - } - } - if (inputs_.size() != num_inputs_from_node_def) { construction_status_ = errors::InvalidArgument( "Wrong number of inputs passed: ", inputs_.size(), " while ", @@ -779,13 +737,6 @@ Status InferenceContext::AttachContext(const Status& status) { strings::StrCat(status.error_message(), error_context)); } -ShapeHandle InferenceContext::input_handle_shape(int idx) { - if (!input_handle_shape_[idx].IsSet()) { - input_handle_shape_[idx] = UnknownShape(); - } - return input_handle_shape_[idx]; -} - // 
----------------------------------------------------------------------------- // ShapeManager // ----------------------------------------------------------------------------- |