aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-04-04 15:42:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-04 15:44:55 -0700
commit4cfb393b087dc50c150054531186ccb71882e2d0 (patch)
treeac5320aef3125e3f78741151a82186b9ef4ccdab /tensorflow/python/client
parente7ad6ec4267f1f79ee7d9f558c8a008746682959 (diff)
Adding Operation._control_outputs
PiperOrigin-RevId: 191659944
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r--tensorflow/python/client/tf_session.i19
-rw-r--r--tensorflow/python/client/tf_session_helper.cc9
-rw-r--r--tensorflow/python/client/tf_session_helper.h4
3 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 0c18d973a7..b82182d5d3 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -157,6 +157,25 @@ tensorflow::ImportNumpy();
}
}
+// We use TF_OperationGetControlOutputs_wrapper instead of
+// TF_OperationGetControlOutputs
+%ignore TF_OperationGetControlOutputs;
+%unignore TF_OperationGetControlOutputs_wrapper;
+// See comment for "%noexception TF_SessionRun_wrapper;"
+%noexception TF_OperationGetControlOutputs_wrapper;
+
+// Build a Python list of TF_Operation* and return it.
+%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlOutputs_wrapper {
+ $result = PyList_New($1.size());
+ if (!$result) {
+ SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
+ }
+
+ for (size_t i = 0; i < $1.size(); ++i) {
+ PyList_SET_ITEM($result, i, CreateWrappedTFOperation($1[i]));
+ }
+}
+
%ignore TF_OperationOutputConsumers;
%unignore TF_OperationOutputConsumers_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index ca57abd712..b48d758e4a 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -550,6 +550,15 @@ std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
return control_inputs;
}
+std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
+ TF_Operation* oper) {
+ std::vector<TF_Operation*> control_outputs(
+ TF_OperationNumControlOutputs(oper));
+ TF_OperationGetControlOutputs(oper, control_outputs.data(),
+ control_outputs.size());
+ return control_outputs;
+}
+
std::vector<const char*> TF_OperationOutputConsumers_wrapper(
TF_Output oper_out) {
int num_consumers = TF_OperationOutputNumConsumers(oper_out);
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 5416d41376..d2b4abc476 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -190,6 +190,10 @@ std::vector<TF_Output> GetOperationInputs(TF_Operation* oper);
std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
TF_Operation* oper);
+// Retrieves the control outputs of this operation.
+std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper(
+ TF_Operation* oper);
+
// Retrieves the op names of the consumers of `oper_out`. The returned strings
// have the lifetime of the underlying TF_Graph.
std::vector<const char*> TF_OperationOutputConsumers_wrapper(