aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/client/tf_session.i21
-rw-r--r--tensorflow/python/client/tf_session_helper.cc8
-rw-r--r--tensorflow/python/client/tf_session_helper.h5
-rw-r--r--tensorflow/python/framework/ops.py32
4 files changed, 65 insertions, 1 deletions
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 284c98639d..243c870e0d 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -68,6 +68,27 @@ tensorflow::ImportNumpy();
$result = PyUnicode_FromString($1);
}
+// We use TF_OperationGetControlInputs_wrapper instead of
+// TF_OperationGetControlInputs
+%ignore TF_OperationGetControlInputs;
+%unignore TF_OperationGetControlInputs_wrapper;
+// See comment for "%noexception TF_SessionRun_wrapper;"
+%noexception TF_OperationGetControlInputs_wrapper;
+
+// Build a Python list of TF_Operation* and return it.
+%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlInputs_wrapper {
+ $result = PyList_New($1.size());
+ if (!$result) {
+ SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
+ }
+
+ for (size_t i = 0; i < $1.size(); ++i) {
+ PyList_SET_ITEM($result, i, SWIG_NewPointerObj(
+ $1[i], SWIGTYPE_p_TF_Operation, 0));
+ }
+}
+
+
////////////////////////////////////////////////////////////////////////////////
// BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper()
////////////////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index 7ebb1a7fe4..0e89ae2426 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -766,4 +766,12 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
ClearDecrefCache();
}
+std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
+ TF_Operation* oper) {
+ std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper));
+ TF_OperationGetControlInputs(oper, control_inputs.data(),
+ control_inputs.size());
+ return control_inputs;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 9937b6aeeb..f1f70a9a1d 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -158,6 +158,11 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
TF_Status* out_status,
std::vector<PyObject*>* py_outputs);
+// Retrieves control inputs of this operation.
+// control_inputs should be empty.
+std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
+ TF_Operation* oper);
+
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f52127f119..a8c2930cbe 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1583,7 +1583,14 @@ class Operation(object):
A list of `Operation` objects.
"""
- return self._control_inputs
+ if _USE_C_API:
+ control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
+ # pylint: disable=protected-access
+ return [self.graph._get_operation_by_name_unsafe(
+ c_api.TF_OperationName(c_op)) for c_op in control_c_ops]
+ # pylint: enable=protected-access
+ else:
+ return self._control_inputs
@property
def type(self):
@@ -2781,6 +2788,29 @@ class Graph(object):
% type(name).__name__)
return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
+ def _get_operation_by_name_unsafe(self, name):
+ """Returns the `Operation` with the given `name`.
+
+ This is a internal unsafe version of get_operation_by_name. It skips many
+ checks and does not have user friedly error messages but runs considerably
+ faster. This method may be called concurrently from multiple threads.
+
+ Args:
+ name: The name of the `Operation` to return.
+
+ Returns:
+ The `Operation` with the given `name`.
+
+ Raises:
+ KeyError: If `name` does not correspond to an operation in this graph.
+ """
+
+ if self._finalized:
+ return self._nodes_by_name[name]
+
+ with self._lock:
+ return self._nodes_by_name[name]
+
def get_tensor_by_name(self, name):
"""Returns the `Tensor` with the given `name`.