diff options
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r-- | tensorflow/python/keras/backend.py | 113 |
1 files changed, 83 insertions, 30 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 11f99c030f..38794f1612 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -963,13 +963,14 @@ def zeros(shape, dtype=None, name=None): [ 0., 0., 0., 0.]], dtype=float32) ``` """ - if dtype is None: - dtype = floatx() - tf_dtype = dtypes_module.as_dtype(dtype) - v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name) - if py_all(v.get_shape().as_list()): - return variable(v, dtype=dtype, name=name) - return v + with ops.init_scope(): + if dtype is None: + dtype = floatx() + tf_dtype = dtypes_module.as_dtype(dtype) + v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name) + if py_all(v.get_shape().as_list()): + return variable(v, dtype=dtype, name=name) + return v @tf_export('keras.backend.ones') @@ -996,13 +997,14 @@ def ones(shape, dtype=None, name=None): [ 1., 1., 1., 1.]], dtype=float32) ``` """ - if dtype is None: - dtype = floatx() - tf_dtype = dtypes_module.as_dtype(dtype) - v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name) - if py_all(v.get_shape().as_list()): - return variable(v, dtype=dtype, name=name) - return v + with ops.init_scope(): + if dtype is None: + dtype = floatx() + tf_dtype = dtypes_module.as_dtype(dtype) + v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name) + if py_all(v.get_shape().as_list()): + return variable(v, dtype=dtype, name=name) + return v @tf_export('keras.backend.eye') @@ -2795,10 +2797,15 @@ class Function(object): if not isinstance(self.fetches, list): self.fetches = [self.fetches] # The main use case of `fetches` being passed to a model is the ability - # to run custom updates (since the outputs of fetches are never returned). + # to run custom updates # This requires us to wrap fetches in `identity` ops. self.fetches = [array_ops.identity(x) for x in self.fetches] self.session_kwargs = session_kwargs + # This mapping keeps track of the function that should receive the + # output from a fetch in `fetches`: { fetch: function(fetch_output) } + # A Callback can use this to register a function with access to the + # output values for a fetch it added. + self.fetch_callbacks = dict() if session_kwargs: raise ValueError('Some keys in session_kwargs are not supported at this ' @@ -2808,6 +2815,7 @@ class Function(object): self._feed_arrays = None self._feed_symbols = None self._symbol_vals = None + self._fetches = None self._session = None def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session): @@ -2853,8 +2861,14 @@ class Function(object): self._feed_arrays = feed_arrays self._feed_symbols = feed_symbols self._symbol_vals = symbol_vals + self._fetches = list(self.fetches) self._session = session + def _call_fetch_callbacks(self, fetches_output): + for fetch, output in zip(self._fetches, fetches_output): + if fetch in self.fetch_callbacks: + self.fetch_callbacks[fetch](output) + def __call__(self, inputs): if not isinstance(inputs, (list, tuple)): raise TypeError('`inputs` should be a list or tuple.') @@ -2891,14 +2905,14 @@ class Function(object): np.asarray(self.feed_dict[key], dtype=key.dtype.base_dtype.name)) # Refresh callable if anything has changed. - if (self._callable_fn is None or - feed_arrays != self._feed_arrays or + if (self._callable_fn is None or feed_arrays != self._feed_arrays or symbol_vals != self._symbol_vals or - feed_symbols != self._feed_symbols or + feed_symbols != self._feed_symbols or self.fetches != self._fetches or session != self._session): self._make_callable(feed_arrays, feed_symbols, symbol_vals, session) fetched = self._callable_fn(*array_vals) + self._call_fetch_callbacks(fetched[-len(self._fetches):]) return fetched[:len(self.outputs)] @@ -3358,26 +3372,48 @@ def in_test_phase(x, alt, training=None): @tf_export('keras.backend.relu') -def relu(x, alpha=0., max_value=None): +def relu(x, alpha=0., max_value=None, threshold=0): """Rectified linear unit. With default values, it returns element-wise `max(x, 0)`. + Otherwise, it follows: + `f(x) = max_value` for `x >= max_value`, + `f(x) = x` for `threshold <= x < max_value`, + `f(x) = alpha * (x - threshold)` otherwise. + Arguments: x: A tensor or variable. alpha: A scalar, slope of negative section (default=`0.`). - max_value: Saturation threshold. + max_value: float. Saturation threshold. + threshold: float. Threshold value for thresholded activation. Returns: A tensor. """ + clip_max = max_value is not None + if alpha != 0.: - negative_part = nn.relu(-x) - x = nn.relu(x) - if max_value is not None: + if threshold != 0: + negative_part = nn.relu(-x + threshold) + else: + negative_part = nn.relu(-x) + + if threshold != 0: + # computes x for x > threshold else 0 + x = x * math_ops.cast(math_ops.greater(x, threshold), floatx()) + elif max_value == 6: + # if no threshold, then can use nn.relu6 native TF op for performance + x = nn.relu6(x) + clip_max = False + else: + x = nn.relu(x) + + if clip_max: max_value = _to_tensor(max_value, x.dtype.base_dtype) zero = _to_tensor(0., x.dtype.base_dtype) x = clip_ops.clip_by_value(x, zero, max_value) + if alpha != 0.: alpha = _to_tensor(alpha, x.dtype.base_dtype) x -= alpha * negative_part @@ -3444,7 +3480,7 @@ def softsign(x): @tf_export('keras.backend.categorical_crossentropy') -def categorical_crossentropy(target, output, from_logits=False): +def categorical_crossentropy(target, output, from_logits=False, axis=-1): """Categorical crossentropy between an output tensor and a target tensor. Arguments: @@ -3454,28 +3490,33 @@ def categorical_crossentropy(target, output, from_logits=False): case `output` is expected to be the logits). from_logits: Boolean, whether `output` is the result of a softmax, or is a tensor of logits. + axis: Int specifying the channels axis. `axis=-1` corresponds to data + format `channels_last', and `axis=1` corresponds to data format + `channels_first`. Returns: Output tensor. + + Raises: + ValueError: if `axis` is neither -1 nor one of the axes of `output`. """ + rank = len(output.get_shape()) + axis = axis % rank # Note: nn.softmax_cross_entropy_with_logits_v2 # expects logits, Keras expects probabilities. if not from_logits: # scale preds so that the class probas of each sample sum to 1 - output = output / math_ops.reduce_sum( # pylint: disable=g-no-augmented-assignment - output, len(output.get_shape()) - 1, True) + output = output / math_ops.reduce_sum(output, axis, True) # manual computation of crossentropy epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype) output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_) - return -math_ops.reduce_sum( - target * math_ops.log(output), - axis=len(output.get_shape()) - 1) + return -math_ops.reduce_sum(target * math_ops.log(output), axis) else: return nn.softmax_cross_entropy_with_logits_v2(labels=target, logits=output) @tf_export('keras.backend.sparse_categorical_crossentropy') -def sparse_categorical_crossentropy(target, output, from_logits=False): +def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): """Categorical crossentropy with integer targets. Arguments: @@ -3485,10 +3526,22 @@ def sparse_categorical_crossentropy(target, output, from_logits=False): case `output` is expected to be the logits). from_logits: Boolean, whether `output` is the result of a softmax, or is a tensor of logits. + axis: Int specifying the channels axis. `axis=-1` corresponds to data + format `channels_last', and `axis=1` corresponds to data format + `channels_first`. Returns: Output tensor. + + Raises: + ValueError: if `axis` is neither -1 nor one of the axes of `output`. """ + rank = len(output.get_shape()) + axis = axis % rank + if axis != rank - 1: + permutation = list(range(axis)) + list(range(axis + 1, rank)) + [axis] + output = array_ops.transpose(output, perm=permutation) + # Note: nn.sparse_softmax_cross_entropy_with_logits # expects logits, Keras expects probabilities. if not from_logits: |