aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/python_api.cc
diff options
context:
space:
mode:
authorGravatar Olivia Nordquist <nolivia@google.com>2017-11-13 13:07:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-13 13:11:13 -0800
commitbac56b37be7736c9da9a3257696a9c1241327d60 (patch)
tree767a1449473fa471aa11d750ac666085494bf00b /tensorflow/c/python_api.cc
parent90222dd7b29ff2597bc7f8d0f92db17324f591b0 (diff)
Validate shapes when updating edges from Python.
Uses MergeInput from shape_inference to check if the new input is compatible with the preexisting shape. Also this changes the MergeInput method. Previously, MergeInput would only return true if the shapes differed *and* the merge was successful. Now, MergeInput returns true only if the merge is successful. PiperOrigin-RevId: 175576173
Diffstat (limited to 'tensorflow/c/python_api.cc')
-rw-r--r--tensorflow/c/python_api.cc27
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index c67007dca0..ba5a9268b4 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -46,6 +46,33 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status) {
mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(&new_src.oper->node);
+
+ if (ic->num_outputs() <= new_src.index) {
+ status->status = tensorflow::errors::OutOfRange(
+ "Cannot update edge. Output index [", new_src.index,
+ "] is greater than the number of total outputs [", ic->num_outputs(),
+ "].");
+ return;
+ }
+ tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
+
+ tensorflow::shape_inference::InferenceContext* ic_dst =
+ graph->refiner.GetContext(&dst.oper->node);
+ if (ic_dst->num_inputs() <= dst.index) {
+ status->status = tensorflow::errors::OutOfRange(
+ "Cannot update edge. Input index [", dst.index,
+ "] is greater than the number of total inputs [", ic_dst->num_inputs(),
+ "].");
+ return;
+ }
+ if (!ic_dst->MergeInput(dst.index, shape)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
+ " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
+ return;
+ }
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
&dst.oper->node, dst.index);
}