aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-06-07 08:47:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 08:50:48 -0700
commit537e8c7a28b6b793eb570c957c4e90bf81ce9c3b (patch)
tree815217e5667b97b7afc7e70a28c967f2b787da2a /tensorflow/python/client
parent3f31670ddc140a62ffac9d8b9310f71bdfbae629 (diff)
Remove _USE_C_API staging from session.py.
PiperOrigin-RevId: 199641205
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r--tensorflow/python/client/session.py159
1 files changed, 39 insertions, 120 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 5507d011bb..648e35cdf2 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -619,21 +619,12 @@ class BaseSession(SessionInterface):
self._config = None
self._add_shapes = False
- # pylint: disable=protected-access
- # We cache _USE_C_API's value because some test cases will create a session
- # with _USE_C_API = False but set it back to True before calling close().
- self._created_with_new_api = ops._USE_C_API
- # pylint: enable=protected-access
-
self._session = None
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
try:
- if self._created_with_new_api:
- # pylint: disable=protected-access
- self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
- # pylint: enable=protected-access
- else:
- self._session = tf_session.TF_NewDeprecatedSession(opts)
+ # pylint: disable=protected-access
+ self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
+ # pylint: enable=protected-access
finally:
tf_session.TF_DeleteSessionOptions(opts)
@@ -660,11 +651,7 @@ class BaseSession(SessionInterface):
Returns:
A list of devices in the session.
"""
- if self._created_with_new_api:
- raw_device_list = tf_session.TF_SessionListDevices(self._session)
- else:
- raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
- self._session)
+ raw_device_list = tf_session.TF_SessionListDevices(self._session)
device_list = []
size = tf_session.TF_DeviceListCount(raw_device_list)
for i in range(size):
@@ -684,16 +671,9 @@ class BaseSession(SessionInterface):
tf.errors.OpError: Or one of its subclasses if an error occurs while
closing the TensorFlow session.
"""
- if self._created_with_new_api:
- if self._session and not self._closed:
- self._closed = True
- tf_session.TF_CloseSession(self._session)
-
- else:
- with self._extend_lock:
- if self._opened and not self._closed:
- self._closed = True
- tf_session.TF_CloseDeprecatedSession(self._session)
+ if self._session and not self._closed:
+ self._closed = True
+ tf_session.TF_CloseSession(self._session)
def __del__(self):
# cleanly ignore all exceptions
@@ -703,10 +683,7 @@ class BaseSession(SessionInterface):
pass
if self._session is not None:
try:
- if self._created_with_new_api:
- tf_session.TF_DeleteSession(self._session)
- else:
- tf_session.TF_DeleteDeprecatedSession(self._session)
+ tf_session.TF_DeleteSession(self._session)
except AttributeError:
# At shutdown, `c_api_util` or `tf_session` may have been garbage
# collected, causing the above method calls to fail. In this case,
@@ -1005,12 +982,9 @@ class BaseSession(SessionInterface):
try:
subfeed_t = self.graph.as_graph_element(
subfeed, allow_tensor=True, allow_operation=False)
- if self._created_with_new_api:
- # pylint: disable=protected-access
- feed_list.append(subfeed_t._as_tf_output())
- # pylint: enable=protected-access
- else:
- feed_list.append(compat.as_bytes(subfeed_t.name))
+ # pylint: disable=protected-access
+ feed_list.append(subfeed_t._as_tf_output())
+ # pylint: enable=protected-access
except Exception as e:
e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message)
e.args = (e.message,)
@@ -1023,22 +997,13 @@ class BaseSession(SessionInterface):
# Set up a graph with feeds and fetches for partial run.
def _setup_fn(session, feed_list, fetch_list, target_list):
self._extend_graph()
- if self._created_with_new_api:
- return tf_session.TF_SessionPRunSetup_wrapper(
- session, feed_list, fetch_list, target_list)
- else:
- with errors.raise_exception_on_not_ok_status() as status:
- return tf_session.TF_PRunSetup(session, feed_list, fetch_list,
- target_list, status)
+ return tf_session.TF_SessionPRunSetup_wrapper(
+ session, feed_list, fetch_list, target_list)
- if self._created_with_new_api:
- # pylint: disable=protected-access
- final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()]
- final_targets = [op._c_op for op in fetch_handler.targets()]
- # pylint: enable=protected-access
- else:
- final_fetches = _name_list(fetch_handler.fetches())
- final_targets = _name_list(fetch_handler.targets())
+ # pylint: disable=protected-access
+ final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()]
+ final_targets = [op._c_op for op in fetch_handler.targets()]
+ # pylint: enable=protected-access
return self._do_call(_setup_fn, self._session, feed_list, final_fetches,
final_targets)
@@ -1196,14 +1161,10 @@ class BaseSession(SessionInterface):
# Create a fetch handler to take care of the structure of fetches.
fetch_handler = _FetchHandler(self._graph, fetches, {})
- if self._created_with_new_api:
- # pylint: disable=protected-access
- fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
- target_list = [op._c_op for op in fetch_handler.targets()]
- # pylint: enable=protected-access
- else:
- fetch_list = _name_list(fetch_handler.fetches())
- target_list = _name_list(fetch_handler.targets())
+ # pylint: disable=protected-access
+ fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
+ target_list = [op._c_op for op in fetch_handler.targets()]
+ # pylint: enable=protected-access
def _callable_template_with_options_and_metadata(fetch_list,
target_list,
@@ -1289,16 +1250,11 @@ class BaseSession(SessionInterface):
Raises:
tf.errors.OpError: Or one of its subclasses on error.
"""
- if self._created_with_new_api:
- # pylint: disable=protected-access
- feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items())
- fetches = [t._as_tf_output() for t in fetch_list]
- targets = [op._c_op for op in target_list]
- # pylint: enable=protected-access
- else:
- feeds = dict((compat.as_bytes(t.name), v) for t, v in feed_dict.items())
- fetches = _name_list(fetch_list)
- targets = _name_list(target_list)
+ # pylint: disable=protected-access
+ feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items())
+ fetches = [t._as_tf_output() for t in fetch_list]
+ targets = [op._c_op for op in target_list]
+ # pylint: enable=protected-access
def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata):
# Ensure any changes to the graph are reflected in the runtime.
@@ -1335,22 +1291,8 @@ class BaseSession(SessionInterface):
raise type(e)(node_def, op, message)
def _extend_graph(self):
- if self._created_with_new_api:
- with self._graph._lock: # pylint: disable=protected-access
- tf_session.ExtendSession(self._session)
- else:
- # Ensure any changes to the graph are reflected in the runtime.
- with self._extend_lock:
- if self._graph.version > self._current_version:
- # pylint: disable=protected-access
- graph_def, self._current_version = self._graph._as_graph_def(
- from_version=self._current_version, add_shapes=self._add_shapes)
- # pylint: enable=protected-access
-
- with errors.raise_exception_on_not_ok_status() as status:
- tf_session.TF_ExtendGraph(self._session,
- graph_def.SerializeToString(), status)
- self._opened = True
+ with self._graph._lock: # pylint: disable=protected-access
+ tf_session.ExtendSession(self._session)
# The threshold to run garbage collection to delete dead tensors.
_DEAD_HANDLES_THRESHOLD = 10
@@ -1403,24 +1345,13 @@ class BaseSession(SessionInterface):
def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
run_metadata):
- if self._created_with_new_api:
- return tf_session.TF_SessionRun_wrapper(
- self._session, options, feed_dict, fetch_list, target_list,
- run_metadata)
- else:
- with errors.raise_exception_on_not_ok_status() as status:
- return tf_session.TF_Run(
- self._session, options, feed_dict, fetch_list, target_list,
- status, run_metadata)
+ return tf_session.TF_SessionRun_wrapper(
+ self._session, options, feed_dict, fetch_list, target_list,
+ run_metadata)
def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
- if self._created_with_new_api:
- return tf_session.TF_SessionPRun_wrapper(
- self._session, handle, feed_dict, fetch_list)
- else:
- with errors.raise_exception_on_not_ok_status() as status:
- return tf_session.TF_PRun(
- self._session, handle, feed_dict, fetch_list, status)
+ return tf_session.TF_SessionPRun_wrapper(
+ self._session, handle, feed_dict, fetch_list)
# pylint: disable=protected-access
class _Callable(object):
@@ -1433,12 +1364,8 @@ class BaseSession(SessionInterface):
compat.as_bytes(callable_options.SerializeToString()))
try:
with errors.raise_exception_on_not_ok_status() as status:
- if session._created_with_new_api:
- self._handle = tf_session.TF_SessionMakeCallable(
- session._session, options_ptr, status)
- else:
- self._handle = tf_session.TF_DeprecatedSessionMakeCallable(
- session._session, options_ptr, status)
+ self._handle = tf_session.TF_SessionMakeCallable(
+ session._session, options_ptr, status)
finally:
tf_session.TF_DeleteBuffer(options_ptr)
@@ -1446,12 +1373,8 @@ class BaseSession(SessionInterface):
# TODO(b/74355905): Support argument and return value nested structures,
# and tensor-like objects such as SparseTensors.
with errors.raise_exception_on_not_ok_status() as status:
- if self._session._created_with_new_api:
- return tf_session.TF_SessionRunCallable(
- self._session._session, self._handle, args, status, None)
- else:
- return tf_session.TF_DeprecatedSessionRunCallable(
- self._session._session, self._handle, args, status, None)
+ return tf_session.TF_SessionRunCallable(
+ self._session._session, self._handle, args, status, None)
def __del__(self):
# NOTE(mrry): It is possible that `self._session.__del__()` could be
@@ -1459,12 +1382,8 @@ class BaseSession(SessionInterface):
# will be `None`.
if self._handle is not None and self._session._session is not None:
with errors.raise_exception_on_not_ok_status() as status:
- if self._session._created_with_new_api:
- tf_session.TF_SessionReleaseCallable(
- self._session._session, self._handle, status)
- else:
- tf_session.TF_DeprecatedSessionReleaseCallable(
- self._session._session, self._handle, status)
+ tf_session.TF_SessionReleaseCallable(
+ self._session._session, self._handle, status)
# pylint: enable=protected-access
# TODO(b/74355905): Reimplement `Session.make_callable()` using this method