diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-06-07 08:47:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-07 08:50:48 -0700 |
commit | 537e8c7a28b6b793eb570c957c4e90bf81ce9c3b (patch) | |
tree | 815217e5667b97b7afc7e70a28c967f2b787da2a /tensorflow/python/client | |
parent | 3f31670ddc140a62ffac9d8b9310f71bdfbae629 (diff) |
Remove _USE_C_API staging from session.py.
PiperOrigin-RevId: 199641205
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r-- | tensorflow/python/client/session.py | 159 |
1 files changed, 39 insertions, 120 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 5507d011bb..648e35cdf2 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -619,21 +619,12 @@ class BaseSession(SessionInterface): self._config = None self._add_shapes = False - # pylint: disable=protected-access - # We cache _USE_C_API's value because some test cases will create a session - # with _USE_C_API = False but set it back to True before calling close(). - self._created_with_new_api = ops._USE_C_API - # pylint: enable=protected-access - self._session = None opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) try: - if self._created_with_new_api: - # pylint: disable=protected-access - self._session = tf_session.TF_NewSession(self._graph._c_graph, opts) - # pylint: enable=protected-access - else: - self._session = tf_session.TF_NewDeprecatedSession(opts) + # pylint: disable=protected-access + self._session = tf_session.TF_NewSession(self._graph._c_graph, opts) + # pylint: enable=protected-access finally: tf_session.TF_DeleteSessionOptions(opts) @@ -660,11 +651,7 @@ class BaseSession(SessionInterface): Returns: A list of devices in the session. """ - if self._created_with_new_api: - raw_device_list = tf_session.TF_SessionListDevices(self._session) - else: - raw_device_list = tf_session.TF_DeprecatedSessionListDevices( - self._session) + raw_device_list = tf_session.TF_SessionListDevices(self._session) device_list = [] size = tf_session.TF_DeviceListCount(raw_device_list) for i in range(size): @@ -684,16 +671,9 @@ class BaseSession(SessionInterface): tf.errors.OpError: Or one of its subclasses if an error occurs while closing the TensorFlow session. """ - if self._created_with_new_api: - if self._session and not self._closed: - self._closed = True - tf_session.TF_CloseSession(self._session) - - else: - with self._extend_lock: - if self._opened and not self._closed: - self._closed = True - tf_session.TF_CloseDeprecatedSession(self._session) + if self._session and not self._closed: + self._closed = True + tf_session.TF_CloseSession(self._session) def __del__(self): # cleanly ignore all exceptions @@ -703,10 +683,7 @@ class BaseSession(SessionInterface): pass if self._session is not None: try: - if self._created_with_new_api: - tf_session.TF_DeleteSession(self._session) - else: - tf_session.TF_DeleteDeprecatedSession(self._session) + tf_session.TF_DeleteSession(self._session) except AttributeError: # At shutdown, `c_api_util` or `tf_session` may have been garbage # collected, causing the above method calls to fail. In this case, @@ -1005,12 +982,9 @@ class BaseSession(SessionInterface): try: subfeed_t = self.graph.as_graph_element( subfeed, allow_tensor=True, allow_operation=False) - if self._created_with_new_api: - # pylint: disable=protected-access - feed_list.append(subfeed_t._as_tf_output()) - # pylint: enable=protected-access - else: - feed_list.append(compat.as_bytes(subfeed_t.name)) + # pylint: disable=protected-access + feed_list.append(subfeed_t._as_tf_output()) + # pylint: enable=protected-access except Exception as e: e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message) e.args = (e.message,) @@ -1023,22 +997,13 @@ class BaseSession(SessionInterface): # Set up a graph with feeds and fetches for partial run. def _setup_fn(session, feed_list, fetch_list, target_list): self._extend_graph() - if self._created_with_new_api: - return tf_session.TF_SessionPRunSetup_wrapper( - session, feed_list, fetch_list, target_list) - else: - with errors.raise_exception_on_not_ok_status() as status: - return tf_session.TF_PRunSetup(session, feed_list, fetch_list, - target_list, status) + return tf_session.TF_SessionPRunSetup_wrapper( + session, feed_list, fetch_list, target_list) - if self._created_with_new_api: - # pylint: disable=protected-access - final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()] - final_targets = [op._c_op for op in fetch_handler.targets()] - # pylint: enable=protected-access - else: - final_fetches = _name_list(fetch_handler.fetches()) - final_targets = _name_list(fetch_handler.targets()) + # pylint: disable=protected-access + final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()] + final_targets = [op._c_op for op in fetch_handler.targets()] + # pylint: enable=protected-access return self._do_call(_setup_fn, self._session, feed_list, final_fetches, final_targets) @@ -1196,14 +1161,10 @@ class BaseSession(SessionInterface): # Create a fetch handler to take care of the structure of fetches. fetch_handler = _FetchHandler(self._graph, fetches, {}) - if self._created_with_new_api: - # pylint: disable=protected-access - fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()] - target_list = [op._c_op for op in fetch_handler.targets()] - # pylint: enable=protected-access - else: - fetch_list = _name_list(fetch_handler.fetches()) - target_list = _name_list(fetch_handler.targets()) + # pylint: disable=protected-access + fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()] + target_list = [op._c_op for op in fetch_handler.targets()] + # pylint: enable=protected-access def _callable_template_with_options_and_metadata(fetch_list, target_list, @@ -1289,16 +1250,11 @@ class BaseSession(SessionInterface): Raises: tf.errors.OpError: Or one of its subclasses on error. """ - if self._created_with_new_api: - # pylint: disable=protected-access - feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items()) - fetches = [t._as_tf_output() for t in fetch_list] - targets = [op._c_op for op in target_list] - # pylint: enable=protected-access - else: - feeds = dict((compat.as_bytes(t.name), v) for t, v in feed_dict.items()) - fetches = _name_list(fetch_list) - targets = _name_list(target_list) + # pylint: disable=protected-access + feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items()) + fetches = [t._as_tf_output() for t in fetch_list] + targets = [op._c_op for op in target_list] + # pylint: enable=protected-access def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata): # Ensure any changes to the graph are reflected in the runtime. @@ -1335,22 +1291,8 @@ class BaseSession(SessionInterface): raise type(e)(node_def, op, message) def _extend_graph(self): - if self._created_with_new_api: - with self._graph._lock: # pylint: disable=protected-access - tf_session.ExtendSession(self._session) - else: - # Ensure any changes to the graph are reflected in the runtime. - with self._extend_lock: - if self._graph.version > self._current_version: - # pylint: disable=protected-access - graph_def, self._current_version = self._graph._as_graph_def( - from_version=self._current_version, add_shapes=self._add_shapes) - # pylint: enable=protected-access - - with errors.raise_exception_on_not_ok_status() as status: - tf_session.TF_ExtendGraph(self._session, - graph_def.SerializeToString(), status) - self._opened = True + with self._graph._lock: # pylint: disable=protected-access + tf_session.ExtendSession(self._session) # The threshold to run garbage collection to delete dead tensors. _DEAD_HANDLES_THRESHOLD = 10 @@ -1403,24 +1345,13 @@ class BaseSession(SessionInterface): def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata): - if self._created_with_new_api: - return tf_session.TF_SessionRun_wrapper( - self._session, options, feed_dict, fetch_list, target_list, - run_metadata) - else: - with errors.raise_exception_on_not_ok_status() as status: - return tf_session.TF_Run( - self._session, options, feed_dict, fetch_list, target_list, - status, run_metadata) + return tf_session.TF_SessionRun_wrapper( + self._session, options, feed_dict, fetch_list, target_list, + run_metadata) def _call_tf_sessionprun(self, handle, feed_dict, fetch_list): - if self._created_with_new_api: - return tf_session.TF_SessionPRun_wrapper( - self._session, handle, feed_dict, fetch_list) - else: - with errors.raise_exception_on_not_ok_status() as status: - return tf_session.TF_PRun( - self._session, handle, feed_dict, fetch_list, status) + return tf_session.TF_SessionPRun_wrapper( + self._session, handle, feed_dict, fetch_list) # pylint: disable=protected-access class _Callable(object): @@ -1433,12 +1364,8 @@ class BaseSession(SessionInterface): compat.as_bytes(callable_options.SerializeToString())) try: with errors.raise_exception_on_not_ok_status() as status: - if session._created_with_new_api: - self._handle = tf_session.TF_SessionMakeCallable( - session._session, options_ptr, status) - else: - self._handle = tf_session.TF_DeprecatedSessionMakeCallable( - session._session, options_ptr, status) + self._handle = tf_session.TF_SessionMakeCallable( + session._session, options_ptr, status) finally: tf_session.TF_DeleteBuffer(options_ptr) @@ -1446,12 +1373,8 @@ class BaseSession(SessionInterface): # TODO(b/74355905): Support argument and return value nested structures, # and tensor-like objects such as SparseTensors. with errors.raise_exception_on_not_ok_status() as status: - if self._session._created_with_new_api: - return tf_session.TF_SessionRunCallable( - self._session._session, self._handle, args, status, None) - else: - return tf_session.TF_DeprecatedSessionRunCallable( - self._session._session, self._handle, args, status, None) + return tf_session.TF_SessionRunCallable( + self._session._session, self._handle, args, status, None) def __del__(self): # NOTE(mrry): It is possible that `self._session.__del__()` could be @@ -1459,12 +1382,8 @@ class BaseSession(SessionInterface): # will be `None`. if self._handle is not None and self._session._session is not None: with errors.raise_exception_on_not_ok_status() as status: - if self._session._created_with_new_api: - tf_session.TF_SessionReleaseCallable( - self._session._session, self._handle, status) - else: - tf_session.TF_DeprecatedSessionReleaseCallable( - self._session._session, self._handle, status) + tf_session.TF_SessionReleaseCallable( + self._session._session, self._handle, status) # pylint: enable=protected-access # TODO(b/74355905): Reimplement `Session.make_callable()` using this method |