about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/keras/backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r-- tensorflow/python/keras/backend.py 113
1 files changed, 83 insertions, 30 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 11f99c030f..38794f1612 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -963,13 +963,14 @@ def zeros(shape, dtype=None, name=None):
[ 0., 0., 0., 0.]], dtype=float32)
```
"""
- if dtype is None:
- dtype = floatx()
- tf_dtype = dtypes_module.as_dtype(dtype)
- v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
- if py_all(v.get_shape().as_list()):
- return variable(v, dtype=dtype, name=name)
- return v
+ with ops.init_scope():
+ if dtype is None:
+ dtype = floatx()
+ tf_dtype = dtypes_module.as_dtype(dtype)
+ v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
+ if py_all(v.get_shape().as_list()):
+ return variable(v, dtype=dtype, name=name)
+ return v
@tf_export('keras.backend.ones')
@@ -996,13 +997,14 @@ def ones(shape, dtype=None, name=None):
[ 1., 1., 1., 1.]], dtype=float32)
```
"""
- if dtype is None:
- dtype = floatx()
- tf_dtype = dtypes_module.as_dtype(dtype)
- v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
- if py_all(v.get_shape().as_list()):
- return variable(v, dtype=dtype, name=name)
- return v
+ with ops.init_scope():
+ if dtype is None:
+ dtype = floatx()
+ tf_dtype = dtypes_module.as_dtype(dtype)
+ v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
+ if py_all(v.get_shape().as_list()):
+ return variable(v, dtype=dtype, name=name)
+ return v
@tf_export('keras.backend.eye')
@@ -2795,10 +2797,15 @@ class Function(object):
if not isinstance(self.fetches, list):
self.fetches = [self.fetches]
# The main use case of `fetches` being passed to a model is the ability
- # to run custom updates (since the outputs of fetches are never returned).
+ # to run custom updates.
# This requires us to wrap fetches in `identity` ops.
self.fetches = [array_ops.identity(x) for x in self.fetches]
self.session_kwargs = session_kwargs
+ # This mapping keeps track of the function that should receive the
+ # output from a fetch in `fetches`: { fetch: function(fetch_output) }
+ # A Callback can use this to register a function with access to the
+ # output values for a fetch it added.
+ self.fetch_callbacks = dict()
if session_kwargs:
raise ValueError('Some keys in session_kwargs are not supported at this '
@@ -2808,6 +2815,7 @@ class Function(object):
self._feed_arrays = None
self._feed_symbols = None
self._symbol_vals = None
+ self._fetches = None
self._session = None
def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
@@ -2853,8 +2861,14 @@ class Function(object):
self._feed_arrays = feed_arrays
self._feed_symbols = feed_symbols
self._symbol_vals = symbol_vals
+ self._fetches = list(self.fetches)
self._session = session
+ def _call_fetch_callbacks(self, fetches_output):
+ for fetch, output in zip(self._fetches, fetches_output):
+ if fetch in self.fetch_callbacks:
+ self.fetch_callbacks[fetch](output)
+
def __call__(self, inputs):
if not isinstance(inputs, (list, tuple)):
raise TypeError('`inputs` should be a list or tuple.')
@@ -2891,14 +2905,14 @@ class Function(object):
np.asarray(self.feed_dict[key], dtype=key.dtype.base_dtype.name))
# Refresh callable if anything has changed.
- if (self._callable_fn is None or
- feed_arrays != self._feed_arrays or
+ if (self._callable_fn is None or feed_arrays != self._feed_arrays or
symbol_vals != self._symbol_vals or
- feed_symbols != self._feed_symbols or
+ feed_symbols != self._feed_symbols or self.fetches != self._fetches or
session != self._session):
self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
fetched = self._callable_fn(*array_vals)
+ self._call_fetch_callbacks(fetched[-len(self._fetches):])
return fetched[:len(self.outputs)]
@@ -3358,26 +3372,48 @@ def in_test_phase(x, alt, training=None):
@tf_export('keras.backend.relu')
-def relu(x, alpha=0., max_value=None):
+def relu(x, alpha=0., max_value=None, threshold=0):
"""Rectified linear unit.
With default values, it returns element-wise `max(x, 0)`.
+ Otherwise, it follows:
+ `f(x) = max_value` for `x >= max_value`,
+ `f(x) = x` for `threshold <= x < max_value`,
+ `f(x) = alpha * (x - threshold)` otherwise.
+
Arguments:
x: A tensor or variable.
alpha: A scalar, slope of negative section (default=`0.`).
- max_value: Saturation threshold.
+ max_value: float. Saturation threshold.
+ threshold: float. Threshold value for thresholded activation.
Returns:
A tensor.
"""
+ clip_max = max_value is not None
+
if alpha != 0.:
- negative_part = nn.relu(-x)
- x = nn.relu(x)
- if max_value is not None:
+ if threshold != 0:
+ negative_part = nn.relu(-x + threshold)
+ else:
+ negative_part = nn.relu(-x)
+
+ if threshold != 0:
+ # computes x for x > threshold else 0
+ x = x * math_ops.cast(math_ops.greater(x, threshold), floatx())
+ elif max_value == 6:
+ # if no threshold, then can use nn.relu6 native TF op for performance
+ x = nn.relu6(x)
+ clip_max = False
+ else:
+ x = nn.relu(x)
+
+ if clip_max:
max_value = _to_tensor(max_value, x.dtype.base_dtype)
zero = _to_tensor(0., x.dtype.base_dtype)
x = clip_ops.clip_by_value(x, zero, max_value)
+
if alpha != 0.:
alpha = _to_tensor(alpha, x.dtype.base_dtype)
x -= alpha * negative_part
@@ -3444,7 +3480,7 @@ def softsign(x):
@tf_export('keras.backend.categorical_crossentropy')
-def categorical_crossentropy(target, output, from_logits=False):
+def categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Categorical crossentropy between an output tensor and a target tensor.
Arguments:
@@ -3454,28 +3490,33 @@ def categorical_crossentropy(target, output, from_logits=False):
case `output` is expected to be the logits).
from_logits: Boolean, whether `output` is the
result of a softmax, or is a tensor of logits.
+ axis: Int specifying the channels axis. `axis=-1` corresponds to data
+ format `channels_last`, and `axis=1` corresponds to data format
+ `channels_first`.
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if `axis` is neither -1 nor one of the axes of `output`.
"""
+ rank = len(output.get_shape())
+ axis = axis % rank
# Note: nn.softmax_cross_entropy_with_logits_v2
# expects logits, Keras expects probabilities.
if not from_logits:
# scale preds so that the class probas of each sample sum to 1
- output = output / math_ops.reduce_sum( # pylint: disable=g-no-augmented-assignment
- output, len(output.get_shape()) - 1, True)
+ output = output / math_ops.reduce_sum(output, axis, True)
# manual computation of crossentropy
epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
- return -math_ops.reduce_sum(
- target * math_ops.log(output),
- axis=len(output.get_shape()) - 1)
+ return -math_ops.reduce_sum(target * math_ops.log(output), axis)
else:
return nn.softmax_cross_entropy_with_logits_v2(labels=target, logits=output)
@tf_export('keras.backend.sparse_categorical_crossentropy')
-def sparse_categorical_crossentropy(target, output, from_logits=False):
+def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Categorical crossentropy with integer targets.
Arguments:
@@ -3485,10 +3526,22 @@ def sparse_categorical_crossentropy(target, output, from_logits=False):
case `output` is expected to be the logits).
from_logits: Boolean, whether `output` is the
result of a softmax, or is a tensor of logits.
+ axis: Int specifying the channels axis. `axis=-1` corresponds to data
+ format `channels_last`, and `axis=1` corresponds to data format
+ `channels_first`.
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if `axis` is neither -1 nor one of the axes of `output`.
"""
+ rank = len(output.get_shape())
+ axis = axis % rank
+ if axis != rank - 1:
+ permutation = list(range(axis)) + list(range(axis + 1, rank)) + [axis]
+ output = array_ops.transpose(output, perm=permutation)
+
# Note: nn.sparse_softmax_cross_entropy_with_logits
# expects logits, Keras expects probabilities.
if not from_logits: