diff options
author | Alexandre Passos <apassos@google.com> | 2018-04-04 15:42:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-04 15:44:55 -0700 |
commit | 4cfb393b087dc50c150054531186ccb71882e2d0 (patch) | |
tree | ac5320aef3125e3f78741151a82186b9ef4ccdab /tensorflow/python/client | |
parent | e7ad6ec4267f1f79ee7d9f558c8a008746682959 (diff) |
Adding Operation._control_outputs
PiperOrigin-RevId: 191659944
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r-- | tensorflow/python/client/tf_session.i | 19 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.cc | 9 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.h | 4 |
3 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 0c18d973a7..b82182d5d3 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -157,6 +157,25 @@ tensorflow::ImportNumpy(); } } +// We use TF_OperationGetControlOutputs_wrapper instead of +// TF_OperationGetControlOutputs +%ignore TF_OperationGetControlOutputs; +%unignore TF_OperationGetControlOutputs_wrapper; +// See comment for "%noexception TF_SessionRun_wrapper;" +%noexception TF_OperationGetControlOutputs_wrapper; + +// Build a Python list of TF_Operation* and return it. +%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlOutputs_wrapper { + $result = PyList_New($1.size()); + if (!$result) { + SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list"); + } + + for (size_t i = 0; i < $1.size(); ++i) { + PyList_SET_ITEM($result, i, CreateWrappedTFOperation($1[i])); + } +} + %ignore TF_OperationOutputConsumers; %unignore TF_OperationOutputConsumers_wrapper; // See comment for "%noexception TF_SessionRun_wrapper;" diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index ca57abd712..b48d758e4a 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -550,6 +550,15 @@ std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper( return control_inputs; } +std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper( + TF_Operation* oper) { + std::vector<TF_Operation*> control_outputs( + TF_OperationNumControlOutputs(oper)); + TF_OperationGetControlOutputs(oper, control_outputs.data(), + control_outputs.size()); + return control_outputs; +} + std::vector<const char*> TF_OperationOutputConsumers_wrapper( TF_Output oper_out) { int num_consumers = TF_OperationOutputNumConsumers(oper_out); diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 5416d41376..d2b4abc476 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -190,6 +190,10 @@ std::vector<TF_Output> GetOperationInputs(TF_Operation* oper); std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper( TF_Operation* oper); +// Retrieves the control outputs of this operation. +std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper( + TF_Operation* oper); + // Retrieves the op names of the consumers of `oper_out`. The returned strings // have the lifetime of the underlying TF_Graph. std::vector<const char*> TF_OperationOutputConsumers_wrapper( |