diff options
author | Olivia Nordquist <nolivia@google.com> | 2017-11-13 13:07:45 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-13 13:11:13 -0800 |
commit | bac56b37be7736c9da9a3257696a9c1241327d60 (patch) | |
tree | 767a1449473fa471aa11d750ac666085494bf00b /tensorflow/c/python_api.cc | |
parent | 90222dd7b29ff2597bc7f8d0f92db17324f591b0 (diff) |
Validate shapes when updating edges from Python.
Uses MergeInput from shape_inference to check if the new input is compatible with the preexisting shape. Also this changes the MergeInput method. Previously, MergeInput would only return true if the shapes differed *and* the merge was successful. Now, MergeInput returns true only if the merge is successful.
PiperOrigin-RevId: 175576173
Diffstat (limited to 'tensorflow/c/python_api.cc')
-rw-r--r-- | tensorflow/c/python_api.cc | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index c67007dca0..ba5a9268b4 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -46,6 +46,33 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status) { mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&new_src.oper->node); + + if (ic->num_outputs() <= new_src.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Output index [", new_src.index, + "] is greater than the number of total outputs [", ic->num_outputs(), + "]."); + return; + } + tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index); + + tensorflow::shape_inference::InferenceContext* ic_dst = + graph->refiner.GetContext(&dst.oper->node); + if (ic_dst->num_inputs() <= dst.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Input index [", dst.index, + "] is greater than the number of total inputs [", ic_dst->num_inputs(), + "]."); + return; + } + if (!ic_dst->MergeInput(dst.index, shape)) { + status->status = tensorflow::errors::InvalidArgument( + "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape), + " and ", ic_dst->DebugString(ic_dst->input(dst.index)), "."); + return; + } status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, &dst.oper->node, dst.index); } |