diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-05-02 08:06:50 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-02 09:34:53 -0700 |
commit | e8eafd94de1fc90a5f4724570f5882b01e1626dc (patch) | |
tree | 25084a1ed40a7acbc592e1a48be4cbd81cebd473 | |
parent | 883e32600ef242cb44d0702bb96f71f3140b5403 (diff) |
Don't try to refine the shapes for a node if its inference context wasn't
successfully built by the AddNode() method.
Change: 154838211
-rw-r--r-- | tensorflow/core/common_runtime/shape_refiner.cc | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index daa9e5091a..828297a1ab 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -88,7 +88,7 @@ Status ShapeRefiner::AddNode(const Node* node) { } // This needs to be filled in with real data in a second pass. - std::vector<const Tensor*> input_tensors(node->num_inputs()); + std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr); std::vector<ShapeHandle> input_tensors_as_shapes; // Create the inference context for this node with the existing input shapes. @@ -145,6 +145,9 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) { } InferenceContext* node_context = it->second.get(); + // Give up if the context wasn't successfully built by the AddNode() method. + TF_RETURN_IF_ERROR(node_context->construction_status()); + // Check if the shapes of the nodes in the fan-in of this node have changed, // and if they have update the node input shapes. for (const Edge* e : node->in_edges()) { @@ -458,7 +461,7 @@ Status ShapeRefiner::RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, shape_inference::InferenceContext* c) { // This will be filled in with real data in a second pass. - std::vector<const Tensor*> input_tensors(node->num_inputs()); + std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr); std::vector<Tensor> real_tensors(node->num_inputs()); std::vector<bool> attempted_materialization(node->num_inputs()); std::vector<bool> attempted_tensor_as_shape_conversion(node->num_inputs()); |