Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/BUILD | 19
-rw-r--r--  tensorflow/python/debug/cli/readline_ui.py | 8
-rw-r--r--  tensorflow/python/debug/wrappers/grpc_wrapper.py | 11
-rw-r--r--  tensorflow/python/debug/wrappers/hooks.py | 17
-rw-r--r--  tensorflow/python/estimator/canned/head.py | 9
-rw-r--r--  tensorflow/python/estimator/estimator.py | 5
-rw-r--r--  tensorflow/python/estimator/run_config.py | 33
-rw-r--r--  tensorflow/python/estimator/run_config_test.py | 24
-rw-r--r--  tensorflow/python/feature_column/feature_column.py | 1
-rw-r--r--  tensorflow/python/framework/dtypes.py | 14
-rw-r--r--  tensorflow/python/framework/graph_util_impl.py | 2
-rw-r--r--  tensorflow/python/framework/graph_util_test.py | 2
-rw-r--r--  tensorflow/python/framework/load_library.py | 2
-rw-r--r--  tensorflow/python/framework/python_op_gen.i | 8
-rw-r--r--  tensorflow/python/framework/test_util.py | 2
-rw-r--r--  tensorflow/python/grappler/layout_optimizer_test.py | 10
-rw-r--r--  tensorflow/python/keras/_impl/keras/backend.py | 4
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/normalization.py | 4
-rw-r--r--  tensorflow/python/kernel_tests/BUILD | 26
-rw-r--r--  tensorflow/python/kernel_tests/broadcast_to_ops_test.py | 85
-rw-r--r--  tensorflow/python/kernel_tests/confusion_matrix_test.py | 7
-rw-r--r--  tensorflow/python/kernel_tests/constant_op_test.py | 5
-rw-r--r--  tensorflow/python/kernel_tests/conv3d_transpose_test.py | 12
-rw-r--r--  tensorflow/python/kernel_tests/manip_ops_test.py | 55
-rw-r--r--  tensorflow/python/kernel_tests/norm_op_test.py | 16
-rw-r--r--  tensorflow/python/kernel_tests/py_func_test.py | 32
-rw-r--r--  tensorflow/python/kernel_tests/random/multinomial_op_test.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/random/random_ops_test.py | 11
-rw-r--r--  tensorflow/python/kernel_tests/string_strip_op_test.py | 56
-rw-r--r--  tensorflow/python/lib/core/py_func.cc | 3
-rw-r--r--  tensorflow/python/ops/array_ops.py | 15
-rw-r--r--  tensorflow/python/ops/distributions/categorical.py | 2
-rw-r--r--  tensorflow/python/ops/embedding_ops.py | 26
-rw-r--r--  tensorflow/python/ops/histogram_ops.py | 1
-rw-r--r--  tensorflow/python/ops/image_ops_impl.py | 74
-rw-r--r--  tensorflow/python/ops/init_ops.py | 18
-rw-r--r--  tensorflow/python/ops/linalg_ops.py | 77
-rw-r--r--  tensorflow/python/ops/linalg_ops_impl.py | 73
-rw-r--r--  tensorflow/python/ops/losses/losses_impl.py | 23
-rw-r--r--  tensorflow/python/ops/math_ops.py | 38
-rw-r--r--  tensorflow/python/ops/nn.py | 1
-rw-r--r--  tensorflow/python/ops/nn_impl.py | 11
-rw-r--r--  tensorflow/python/ops/nn_ops.py | 8
-rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py | 4
-rw-r--r--  tensorflow/python/profiler/tfprof_logger_test.py | 2
-rw-r--r--  tensorflow/python/tools/saved_model_cli.py | 3
-rw-r--r--  tensorflow/python/training/saver_test.py | 2
-rw-r--r--  tensorflow/python/util/compat.py | 7
48 files changed, 652 insertions(+), 218 deletions(-)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 9dc03d7cdb..8e7f0cadad 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1946,7 +1946,8 @@ py_library(
":array_ops",
":constant_op",
":dtypes",
- ":linalg_ops",
+ ":linalg_ops_gen",
+ ":linalg_ops_impl",
":math_ops",
":nn_ops",
":random_ops",
@@ -1997,7 +1998,22 @@ py_library(
":array_ops",
":dtypes",
":framework_ops",
+ ":functional_ops",
":linalg_ops_gen",
+ ":linalg_ops_impl",
+ ":math_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "linalg_ops_impl",
+ srcs = ["ops/linalg_ops_impl.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":dtypes",
+ ":framework_ops",
":math_ops",
"//third_party/py/numpy",
],
@@ -3493,6 +3509,7 @@ tf_py_wrap_cc(
"//tensorflow/core/profiler/internal:print_model_analysis",
"//tensorflow/tools/graph_transforms:transform_graph_lib",
"//tensorflow/python/eager:pywrap_tfe_lib",
+ "//tensorflow/python/eager:python_eager_op_gen",
"//util/python:python_headers",
] + (tf_additional_lib_deps() +
tf_additional_plugin_deps() +
diff --git a/tensorflow/python/debug/cli/readline_ui.py b/tensorflow/python/debug/cli/readline_ui.py
index 151638789f..3296e45d07 100644
--- a/tensorflow/python/debug/cli/readline_ui.py
+++ b/tensorflow/python/debug/cli/readline_ui.py
@@ -19,6 +19,8 @@ from __future__ import print_function
import readline
+import six
+
from tensorflow.python.debug.cli import base_ui
from tensorflow.python.debug.cli import debugger_cli_common
@@ -39,11 +41,7 @@ class ReadlineUI(base_ui.BaseUI):
readline.set_completer(self._readline_complete)
readline.parse_and_bind("tab: complete")
- # For Python 2-3 compatibility.
- try:
- self._input = raw_input
- except NameError:
- self._input = input
+ self._input = six.moves.input
def _readline_complete(self, text, state):
context, prefix, except_last_word = self._analyze_tab_complete_input(text)
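
A minimal sketch of the shim this change adopts: `six.moves.input` resolves to `raw_input` on Python 2 and to `input` on Python 3, replacing the hand-rolled try/except.

```python
# Sketch: six.moves.input is raw_input on Python 2, input on Python 3.
import six

answer = six.moves.input("Quit program? (Y/n): ").strip()
print("You typed: %s" % answer)
```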
diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py
index fb9494f576..1f9c8fa5a9 100644
--- a/tensorflow/python/debug/wrappers/grpc_wrapper.py
+++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py
@@ -21,6 +21,8 @@ import signal
import sys
import traceback
+import six
+
# Google-internal import(s).
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.wrappers import framework
@@ -140,14 +142,9 @@ class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
def _signal_handler(unused_signal, unused_frame):
- try:
- input_func = raw_input
- except NameError:
- # Python 3 does not have raw_input.
- input_func = input
-
while True:
- response = input_func("\nSIGINT received. Quit program? (Y/n): ").strip()
+ response = six.moves.input(
+ "\nSIGINT received. Quit program? (Y/n): ").strip()
if response in ("", "Y", "y"):
sys.exit(0)
elif response in ("N", "n"):
diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py
index 6705cd31e2..5e4604fda4 100644
--- a/tensorflow/python/debug/wrappers/hooks.py
+++ b/tensorflow/python/debug/wrappers/hooks.py
@@ -31,15 +31,18 @@ from tensorflow.python.training import session_run_hook
class LocalCLIDebugHook(session_run_hook.SessionRunHook):
"""Command-line-interface debugger hook.
- Can be used as a monitor/hook for `tf.train.MonitoredSession`s and
- `tf.contrib.learn`'s `Estimator`s and `Experiment`s.
+ Can be used as a hook for `tf.train.MonitoredSession`s and
+ `tf.estimator.Estimator`s. Provides a substitute for
+ `tfdbg.LocalCLIDebugWrapperSession` in cases where the session is not directly
+ available.
"""
def __init__(self, ui_type="curses", dump_root=None, thread_name_filter=None):
"""Create a local debugger command-line interface (CLI) hook.
Args:
- ui_type: (str) user-interface type.
+ ui_type: (`str`) requested user-interface type. Currently supported:
+ (curses | readline).
dump_root: (`str`) optional path to the dump root directory. Must be a
directory that does not exist or an empty directory. If the directory
does not exist, it will be created by the debugger core during debug
@@ -153,8 +156,8 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook):
class DumpingDebugHook(session_run_hook.SessionRunHook):
"""A debugger hook that dumps debug data to filesystem.
- Can be used as a monitor/hook for `tf.train.MonitoredSession`s and
- `tf.contrib.learn`'s `Estimator`s and `Experiment`s.
+ Can be used as a hook for `tf.train.MonitoredSession`s and
+ `tf.estimator.Estimator`s.
"""
def __init__(self,
@@ -229,8 +232,8 @@ class GrpcDebugHook(session_run_hook.SessionRunHook):
When the arguments of debug_utils.watch_graph changes, strongly consider
changing arguments here too so that features are available to tflearn users.
- Can be used as a monitor/hook for `tf.train.MonitoredSession`s and
- `tf.contrib.learn`'s `Estimator`s and `Experiment`s.
+ Can be used as a hook for `tf.train.MonitoredSession`s and
+ `tf.estimator.Estimator`s.
"""
def __init__(self,
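
For context, a hedged sketch of how such a hook attaches to a `tf.estimator.Estimator`; the `estimator` and `train_input_fn` names here are hypothetical and assumed to be defined elsewhere.

```python
# Hypothetical usage sketch: attach the CLI debugger hook to an Estimator.
from tensorflow.python import debug as tf_debug

hook = tf_debug.LocalCLIDebugHook(ui_type="readline")  # or "curses"
# `estimator` and `train_input_fn` are assumed to exist already.
estimator.train(input_fn=train_input_fn, hooks=[hook])
```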
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index c365ea8b4a..efa4bdf598 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -263,9 +263,12 @@ def _check_dense_labels_match_logits_and_reshape(
if (dim1 is not None) and (dim1 != expected_labels_dimension):
raise ValueError(
'Mismatched label shape. '
- 'Classifier configured with n_classes=%s. Received %s. '
- 'Suggested Fix: check your n_classes argument to the estimator '
- 'and/or the shape of your label.' %
+ 'Expected labels dimension=%s. Received %s. '
+          'Suggested Fix: '
+          'If your classifier expects one-hot encoded labels, '
+          'check your n_classes argument to the estimator '
+          'and/or the shape of your label. '
+          'Otherwise, check the shape of your label.' %
(expected_labels_dimension, dim1))
expected_labels_shape = array_ops.concat(
[logits_shape[:-1], [expected_labels_dimension]], axis=0)
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 351fcb6423..2f1212d5a2 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -207,7 +207,8 @@ class Estimator(object):
else:
self._session_config = self._config.session_config
- self._device_fn = _get_replica_device_setter(self._config)
+ self._device_fn = self._config.device_fn or \
+ _get_replica_device_setter(self._config)
if model_fn is None:
raise ValueError('model_fn must be provided to Estimator.')
@@ -716,7 +717,7 @@ class Estimator(object):
batch_length = batch_length or value.shape[0]
if value.shape[0] != batch_length:
raise ValueError('Batch length of predictions should be same. %s has '
- 'different batch length then others.' % key)
+ 'different batch length than others.' % key)
return batch_length
def _extract_keys(self, predictions, predict_keys):
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index dab442aeda..8162b249f1 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -27,11 +27,13 @@ import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
+from tensorflow.python.estimator import util
from tensorflow.python.util import compat_internal
from tensorflow.python.util.tf_export import tf_export
_USE_DEFAULT = object()
+_VALID_DEVICE_FN_ARGS = set(['op'])
# A list of the property names in RunConfig that the user is allowed to change.
_DEFAULT_REPLACEABLE_LIST = [
@@ -44,7 +46,8 @@ _DEFAULT_REPLACEABLE_LIST = [
'keep_checkpoint_max',
'keep_checkpoint_every_n_hours',
'log_step_count_steps',
- 'train_distribute'
+ 'train_distribute',
+ 'device_fn'
]
_SAVE_CKPT_ERR = (
@@ -279,6 +282,11 @@ def _validate_properties(run_config):
_validate('tf_random_seed', lambda seed: isinstance(seed, six.integer_types),
message='tf_random_seed must be integer.')
+ _validate('device_fn', lambda device_fn: six.callable(device_fn) and
+ set(util.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
+ message='device_fn must be callable with exactly'
+ ' one argument "op".')
+
class TaskType(object):
MASTER = 'master'
@@ -302,7 +310,8 @@ class RunConfig(object):
keep_checkpoint_max=5,
keep_checkpoint_every_n_hours=10000,
log_step_count_steps=100,
- train_distribute=None):
+ train_distribute=None,
+ device_fn=None):
"""Constructs a RunConfig.
All distributed training related properties `cluster_spec`, `is_chief`,
@@ -430,6 +439,10 @@ class RunConfig(object):
`tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during training,
according to the policy specified by that strategy.
+ device_fn: A callable invoked for every `Operation` that takes the
+ `Operation` and returns the device string. If `None`, defaults to
+ the device function returned by `tf.train.replica_device_setter`
+ with round-robin strategy.
Raises:
ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
@@ -466,7 +479,8 @@ class RunConfig(object):
keep_checkpoint_max=keep_checkpoint_max,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
log_step_count_steps=log_step_count_steps,
- train_distribute=train_distribute)
+ train_distribute=train_distribute,
+ device_fn=device_fn)
self._init_distributed_setting_from_environment_var(tf_config)
@@ -569,6 +583,16 @@ class RunConfig(object):
return self._cluster_spec
@property
+ def device_fn(self):
+ """Returns the device_fn.
+
+ If device_fn is not `None`, it overrides the default
+ device function used in `Estimator`.
+ Otherwise the default one is used.
+ """
+ return self._device_fn
+
+ @property
def evaluation_master(self):
return self._evaluation_master
@@ -697,7 +721,8 @@ class RunConfig(object):
- `keep_checkpoint_max`,
- `keep_checkpoint_every_n_hours`,
- `log_step_count_steps`,
- - `train_distribute`.
+ - `train_distribute`,
+ - `device_fn`.
In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`
can be set (should not be both).
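
A minimal sketch of the new option: per the validation added above, `device_fn` must be a callable whose sole argument is named `op`; it receives every `Operation` and returns a device string. The placement policy below is illustrative only.

```python
# Sketch: override Estimator device placement via RunConfig.device_fn.
import tensorflow as tf

def device_fn(op):
    # Example policy: variables on CPU, everything else on the first GPU.
    return "/cpu:0" if op.type.startswith("Variable") else "/gpu:0"

config = tf.estimator.RunConfig(device_fn=device_fn)
```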
diff --git a/tensorflow/python/estimator/run_config_test.py b/tensorflow/python/estimator/run_config_test.py
index a3eef4c53f..c8b12605e1 100644
--- a/tensorflow/python/estimator/run_config_test.py
+++ b/tensorflow/python/estimator/run_config_test.py
@@ -42,6 +42,7 @@ _SESSION_CONFIG_ERR = 'session_config must be instance of ConfigProto'
_KEEP_CKPT_MAX_ERR = 'keep_checkpoint_max should be >= 0'
_KEEP_CKPT_HOURS_ERR = 'keep_checkpoint_every_n_hours should be > 0'
_TF_RANDOM_SEED_ERR = 'tf_random_seed must be integer'
+_DEVICE_FN_ERR = 'device_fn must be callable with exactly one argument "op".'
_ONE_CHIEF_ERR = 'The "cluster" in TF_CONFIG must have only one "chief" node.'
_ONE_MASTER_ERR = 'The "cluster" in TF_CONFIG must have only one "master" node.'
_INVALID_TASK_TYPE_FOR_EVAL_MASTER = (
@@ -83,6 +84,7 @@ class RunConfigTest(test.TestCase):
self.assertEqual(5, config.keep_checkpoint_max)
self.assertEqual(10000, config.keep_checkpoint_every_n_hours)
self.assertIsNone(config.service)
+ self.assertIsNone(config.device_fn)
def test_model_dir(self):
empty_config = run_config_lib.RunConfig()
@@ -93,6 +95,7 @@ class RunConfigTest(test.TestCase):
def test_replace_with_allowed_properties(self):
session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+ device_fn = lambda op: "/cpu:0"
config = run_config_lib.RunConfig().replace(
tf_random_seed=11,
@@ -100,13 +103,15 @@ class RunConfigTest(test.TestCase):
save_checkpoints_secs=14,
session_config=session_config,
keep_checkpoint_max=16,
- keep_checkpoint_every_n_hours=17)
+ keep_checkpoint_every_n_hours=17,
+ device_fn=device_fn)
self.assertEqual(11, config.tf_random_seed)
self.assertEqual(12, config.save_summary_steps)
self.assertEqual(14, config.save_checkpoints_secs)
self.assertEqual(session_config, config.session_config)
self.assertEqual(16, config.keep_checkpoint_max)
self.assertEqual(17, config.keep_checkpoint_every_n_hours)
+ self.assertEqual(device_fn, config.device_fn)
def test_replace_none_value(self):
config = run_config_lib.RunConfig().replace(
@@ -117,7 +122,8 @@ class RunConfigTest(test.TestCase):
save_checkpoints_steps=None,
session_config=None,
keep_checkpoint_max=None,
- keep_checkpoint_every_n_hours=None)
+ keep_checkpoint_every_n_hours=None,
+ device_fn=None)
self.assertIsNone(config.tf_random_seed)
self.assertIsNone(config.model_dir)
self.assertIsNone(config.save_summary_steps)
@@ -126,6 +132,7 @@ class RunConfigTest(test.TestCase):
self.assertIsNone(config.session_config)
self.assertIsNone(config.keep_checkpoint_max)
self.assertIsNone(config.keep_checkpoint_every_n_hours)
+ self.assertIsNone(config.device_fn)
def test_replace_with_disallowallowed_properties(self):
config = run_config_lib.RunConfig()
@@ -166,9 +173,12 @@ class RunConfigTest(test.TestCase):
config.replace(keep_checkpoint_every_n_hours=0)
with self.assertRaisesRegexp(ValueError, _TF_RANDOM_SEED_ERR):
config.replace(tf_random_seed=1.0)
+ with self.assertRaisesRegexp(ValueError, _DEVICE_FN_ERR):
+ config.replace(device_fn=lambda x, y: 0)
def test_init_with_allowed_properties(self):
session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+ device_fn = lambda op: "/cpu:0"
config = run_config_lib.RunConfig(
tf_random_seed=11,
@@ -176,13 +186,15 @@ class RunConfigTest(test.TestCase):
save_checkpoints_secs=14,
session_config=session_config,
keep_checkpoint_max=16,
- keep_checkpoint_every_n_hours=17)
+ keep_checkpoint_every_n_hours=17,
+ device_fn=device_fn)
self.assertEqual(11, config.tf_random_seed)
self.assertEqual(12, config.save_summary_steps)
self.assertEqual(14, config.save_checkpoints_secs)
self.assertEqual(session_config, config.session_config)
self.assertEqual(16, config.keep_checkpoint_max)
self.assertEqual(17, config.keep_checkpoint_every_n_hours)
+ self.assertEqual(device_fn, config.device_fn)
def test_init_none_value(self):
config = run_config_lib.RunConfig(
@@ -193,7 +205,8 @@ class RunConfigTest(test.TestCase):
save_checkpoints_steps=None,
session_config=None,
keep_checkpoint_max=None,
- keep_checkpoint_every_n_hours=None)
+ keep_checkpoint_every_n_hours=None,
+ device_fn=None)
self.assertIsNone(config.tf_random_seed)
self.assertIsNone(config.model_dir)
self.assertIsNone(config.save_summary_steps)
@@ -202,6 +215,7 @@ class RunConfigTest(test.TestCase):
self.assertIsNone(config.session_config)
self.assertIsNone(config.keep_checkpoint_max)
self.assertIsNone(config.keep_checkpoint_every_n_hours)
+ self.assertIsNone(config.device_fn)
def test_init_invalid_values(self):
with self.assertRaisesRegexp(ValueError, _MODEL_DIR_ERR):
@@ -220,6 +234,8 @@ class RunConfigTest(test.TestCase):
run_config_lib.RunConfig(keep_checkpoint_every_n_hours=0)
with self.assertRaisesRegexp(ValueError, _TF_RANDOM_SEED_ERR):
run_config_lib.RunConfig(tf_random_seed=1.0)
+ with self.assertRaisesRegexp(ValueError, _DEVICE_FN_ERR):
+ run_config_lib.RunConfig(device_fn=lambda x: "/cpu:0")
class RunConfigDistributedSettingTest(test.TestCase):
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index a7c4eabcb2..c16c3cda48 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -162,7 +162,6 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
-from tensorflow.python.util.tf_export import tf_export
def _internal_input_layer(features,
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 807582bd7e..7f9ef53457 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -700,11 +700,13 @@ def as_dtype(type_value):
if type_value.type == np.string_ or type_value.type == np.unicode_:
return string
- for key, val in _NP_TO_TF:
- try:
- if key == type_value:
- return val
- except TypeError as e:
- raise TypeError("Cannot convert {} to a dtype. {}".format(type_value, e))
+ if isinstance(type_value, (type, np.dtype)):
+ for key, val in _NP_TO_TF:
+ try:
+ if key == type_value:
+ return val
+ except TypeError as e:
+ raise TypeError("Cannot convert {} to a dtype. {}".format(
+ type_value, e))
raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value)
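
The new `isinstance` guard means values that are neither a `type` nor an `np.dtype` fall through to the final, descriptive `TypeError` instead of failing inside the comparison loop (GitHub issue 18474). A sketch:

```python
# Sketch: as_dtype now rejects non-type values with a clear TypeError.
import tensorflow as tf

print(tf.as_dtype("float32"))   # tf.float32
try:
    tf.as_dtype([1, 2])         # neither a type nor an np.dtype
except TypeError as e:
    print(e)                    # "Cannot convert value [1, 2] to a ..."
```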
diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py
index 910364364c..394fac6c85 100644
--- a/tensorflow/python/framework/graph_util_impl.py
+++ b/tensorflow/python/framework/graph_util_impl.py
@@ -285,7 +285,7 @@ def convert_variables_to_constants(sess,
output_graph_def.node.extend([output_node])
output_graph_def.library.CopyFrom(inference_graph.library)
- print("Converted %d variables to const ops." % how_many_converted)
+ logging.info("Converted %d variables to const ops.", how_many_converted)
return output_graph_def
diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py
index b618152b02..2dafb94ba7 100644
--- a/tensorflow/python/framework/graph_util_test.py
+++ b/tensorflow/python/framework/graph_util_test.py
@@ -209,7 +209,7 @@ class DeviceFunctionsTest(test.TestCase):
defun_node, 2.0, name="output_node")
with session.Session() as sess:
- init = variables.initialize_variables([variable_node])
+ init = variables.variables_initializer([variable_node])
sess.run(init)
output = sess.run(output_node)
self.assertNear(4.0, output, 0.00001)
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index 535c6017f5..9a8477debb 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -58,7 +58,7 @@ def load_op_library(library_filename):
op_list_str = py_tf.TF_GetOpList(lib_handle)
op_list = op_def_pb2.OpList()
op_list.ParseFromString(compat.as_bytes(op_list_str))
- wrappers = py_tf.GetPythonWrappers(op_list_str)
+ wrappers = py_tf.GetEagerPythonWrappers(op_list_str)
# Delete the library handle to release any memory held in C
# that are no longer needed.
diff --git a/tensorflow/python/framework/python_op_gen.i b/tensorflow/python/framework/python_op_gen.i
index 26ec4e8e66..efcce2f209 100644
--- a/tensorflow/python/framework/python_op_gen.i
+++ b/tensorflow/python/framework/python_op_gen.i
@@ -16,10 +16,10 @@ limitations under the License.
%include "tensorflow/python/platform/base.i"
%{
-#include "tensorflow/python/framework/python_op_gen.h"
+#include "tensorflow/python/eager/python_eager_op_gen.h"
%}
-// Input typemap for GetPythonWrappers.
+// Input typemap for GetEagerPythonWrappers.
// Accepts a python object of 'bytes' type, and converts it to
// a const char* pointer and size_t length. The default typemap
// going from python bytes to const char* tries to decode the
@@ -37,5 +37,5 @@ limitations under the License.
%ignoreall;
-%unignore tensorflow::GetPythonWrappers;
-%include "tensorflow/python/framework/python_op_gen.h"
+%unignore tensorflow::GetEagerPythonWrappers;
+%include "tensorflow/python/eager/python_eager_op_gen.h"
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index f954b9d6c7..5a8bc43727 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1014,6 +1014,8 @@ class TensorFlowTestCase(googletest.TestCase):
config.graph_options.optimizer_options.opt_level = -1
config.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
+ config.graph_options.rewrite_options.arithmetic_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
return config
if graph is None:
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 5a84b16a23..e3dd4b0bdf 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -476,7 +476,7 @@ class LayoutOptimizerTest(test.TestCase):
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([1, 784], seed=0)
conv = _two_layer_model(x)
- reduce_sum = math_ops.reduce_sum(conv, axis=[1, 2], keep_dims=True)
+ reduce_sum = math_ops.reduce_sum(conv, axis=[1, 2], keepdims=True)
squeeze = array_ops.squeeze(reduce_sum, axis=[1, 2])
output = array_ops.identity(squeeze)
@@ -506,7 +506,7 @@ class LayoutOptimizerTest(test.TestCase):
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([1, 784], seed=0)
conv = _two_layer_model(x)
- reduce_sum = math_ops.reduce_sum(conv, axis=[0, 1, 2], keep_dims=True)
+ reduce_sum = math_ops.reduce_sum(conv, axis=[0, 1, 2], keepdims=True)
squeeze = array_ops.squeeze(reduce_sum, axis=[0, 1, 2])
output = array_ops.identity(squeeze)
@@ -623,7 +623,7 @@ class LayoutOptimizerTest(test.TestCase):
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([1, 784], seed=0)
conv = _two_layer_model(x)
- reduce_sum = math_ops.reduce_sum(conv, axis=[3], keep_dims=True)
+ reduce_sum = math_ops.reduce_sum(conv, axis=[3], keepdims=True)
output = array_ops.identity(reduce_sum)
with session.Session(config=_get_config(False)) as sess:
@@ -653,7 +653,7 @@ class LayoutOptimizerTest(test.TestCase):
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([1, 784], seed=0)
conv = _two_layer_model(x)
- reduce_sum = math_ops.reduce_sum(conv, axis=[2], keep_dims=True)
+ reduce_sum = math_ops.reduce_sum(conv, axis=[2], keepdims=True)
output = array_ops.identity(reduce_sum)
with session.Session(config=_get_config(False)) as sess:
@@ -682,7 +682,7 @@ class LayoutOptimizerTest(test.TestCase):
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([1, 784], seed=0)
conv = _two_layer_model(x)
- reduce_sum = math_ops.reduce_sum(conv, axis=[2, 3], keep_dims=True)
+ reduce_sum = math_ops.reduce_sum(conv, axis=[2, 3], keepdims=True)
output = array_ops.identity(reduce_sum)
with session.Session(config=_get_config(False)) as sess:
diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index 81a4d2f820..449410fe08 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -3448,7 +3448,7 @@ def categorical_crossentropy(target, output, from_logits=False):
Returns:
Output tensor.
"""
- # Note: nn.softmax_cross_entropy_with_logits
+ # Note: nn.softmax_cross_entropy_with_logits_v2
# expects logits, Keras expects probabilities.
if not from_logits:
# scale preds so that the class probas of each sample sum to 1
@@ -3512,7 +3512,7 @@ def binary_crossentropy(target, output, from_logits=False):
Returns:
A tensor.
"""
- # Note: nn.softmax_cross_entropy_with_logits
+ # Note: nn.sigmoid_cross_entropy_with_logits
# expects logits, Keras expects probabilities.
if not from_logits:
# transform back to logits
diff --git a/tensorflow/python/keras/_impl/keras/layers/normalization.py b/tensorflow/python/keras/_impl/keras/layers/normalization.py
index 5462a95d7d..c16fc07fb4 100644
--- a/tensorflow/python/keras/_impl/keras/layers/normalization.py
+++ b/tensorflow/python/keras/_impl/keras/layers/normalization.py
@@ -593,9 +593,9 @@ class BatchNormalization(Layer):
# used during evaluation, it is more efficient to just update in one
# step and should not make a significant difference in the result.
new_mean = math_ops.reduce_mean(new_mean,
- axis=1, keep_dims=True)
+ axis=1, keepdims=True)
new_variance = math_ops.reduce_mean(new_variance,
- axis=1, keep_dims=True)
+ axis=1, keepdims=True)
def _do_update(var, value):
if in_eager_mode and not self.trainable:
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index ebbec39cf3..c03c514699 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -918,6 +918,20 @@ tf_py_test(
)
tf_py_test(
+ name = "string_strip_op_test",
+ size = "small",
+ srcs = ["string_strip_op_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:string_ops",
+ ],
+)
+
+tf_py_test(
name = "substr_op_test",
size = "small",
srcs = ["substr_op_test.py"],
@@ -1196,6 +1210,18 @@ cuda_py_test(
)
cuda_py_test(
+ name = "broadcast_to_ops_test",
+ size = "small",
+ srcs = ["broadcast_to_ops_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+cuda_py_test(
name = "inplace_ops_test",
size = "small",
srcs = ["inplace_ops_test.py"],
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
new file mode 100644
index 0000000000..6a1bd958ba
--- /dev/null
+++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
@@ -0,0 +1,85 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for broadcast_to ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test as test_lib
+
+
+class BroadcastToTest(test_util.TensorFlowTestCase):
+
+ def testBroadcastToBasic(self):
+ for dtype in [np.uint8, np.uint16, np.int8, np.int16, np.int32, np.int64]:
+ with self.test_session(use_gpu=True):
+ x = np.array([1, 2, 3], dtype=dtype)
+ v_tf = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
+ v_np = np.broadcast_to(x, [3, 3])
+ self.assertAllEqual(v_tf.eval(), v_np)
+
+ def testBroadcastToString(self):
+ with self.test_session(use_gpu=True):
+ x = np.array([b"1", b"2", b"3"])
+ v_tf = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
+ v_np = np.broadcast_to(x, [3, 3])
+ self.assertAllEqual(v_tf.eval(), v_np)
+
+ def testBroadcastToBool(self):
+ with self.test_session(use_gpu=True):
+ x = np.array([True, False, True], dtype=np.bool)
+ v_tf = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
+ v_np = np.broadcast_to(x, [3, 3])
+ self.assertAllEqual(v_tf.eval(), v_np)
+
+ def testBroadcastToShape(self):
+ for input_dim in range(1, 6):
+ for output_dim in range(input_dim, 6):
+ with self.test_session(use_gpu=True):
+ input_shape = [2] * input_dim
+ output_shape = [2] * output_dim
+ x = np.array(np.random.randint(5, size=input_shape), dtype=np.int32)
+ v_tf = array_ops.broadcast_to(constant_op.constant(x), output_shape)
+ v_np = np.broadcast_to(x, output_shape)
+ self.assertAllEqual(v_tf.eval(), v_np)
+
+ def testBroadcastToScalar(self):
+ with self.test_session(use_gpu=True):
+ x = np.array(1, dtype=np.int32)
+ v_tf = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
+ v_np = np.broadcast_to(x, [3, 3])
+ self.assertAllEqual(v_tf.eval(), v_np)
+
+ def testBroadcastToShapeTypeAndInference(self):
+ for dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session(use_gpu=True):
+ x = np.array([1, 2, 3])
+ v_tf = array_ops.broadcast_to(
+ constant_op.constant(x),
+ constant_op.constant([3, 3], dtype=dtype))
+ shape = v_tf.get_shape().as_list()
+ v_np = np.broadcast_to(x, [3, 3])
+ self.assertAllEqual(v_tf.eval(), v_np)
+ # check shape inference when shape input is constant
+ self.assertAllEqual(shape, v_np.shape)
+
+if __name__ == "__main__":
+ test_lib.main()
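
As the tests above check the op directly against `np.broadcast_to`, a tiny sketch of the intended semantics:

```python
# Sketch mirroring the new tests: broadcast a row to a 3x3 matrix.
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops

x = np.array([1, 2, 3])
v = array_ops.broadcast_to(constant_op.constant(x), [3, 3])
with session.Session() as sess:
    print(sess.run(v))          # three stacked copies of [1 2 3]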
diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py
index 670a625f0f..79e419867d 100644
--- a/tensorflow/python/kernel_tests/confusion_matrix_test.py
+++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -104,11 +105,7 @@ class ConfusionMatrixTest(test.TestCase):
d, l, cm_out = sess.run([data, lab, cm], {m_neg: 0.0, m_pos: 1.0, s: 1.0})
truth = np.zeros([2, 2], dtype=np_dtype)
- try:
- range_builder = xrange
- except NameError: # In Python 3.
- range_builder = range
- for i in range_builder(len(d)):
+ for i in xrange(len(d)):
truth[l[i], d[i]] += 1
self.assertEqual(cm_out.dtype, np_dtype)
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 749313b00d..107ee37fab 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -65,6 +65,11 @@ class ConstantTest(test.TestCase):
self._testCpu(x)
self._testGpu(x)
+ def testInvalidDType(self):
+ # Test case for GitHub issue 18474
+ with self.assertRaises(TypeError):
+ constant_op.constant(dtypes_lib.string, "[,]")
+
def testBFloat16(self):
bfloat16 = dtypes_lib.bfloat16.as_numpy_dtype
self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(bfloat16))
diff --git a/tensorflow/python/kernel_tests/conv3d_transpose_test.py b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
index a8b3af5096..8973a450fa 100644
--- a/tensorflow/python/kernel_tests/conv3d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
@@ -119,6 +119,18 @@ class Conv3DTransposeTest(test.TestCase):
target = 3.0
self.assertAllClose(target, value[n, d, h, w, k])
+ def testConv3DTransposeShapeMismatch(self):
+ # Test case for GitHub issue 18460
+ x_shape = [2, 2, 3, 4, 3]
+ f_shape = [3, 3, 3, 2, 2]
+ y_shape = [2, 2, 6, 8, 6]
+ strides = [1, 1, 2, 2, 2]
+ np.random.seed(1)
+ x_value = np.random.random_sample(x_shape).astype(np.float64)
+ f_value = np.random.random_sample(f_shape).astype(np.float64)
+ nn_ops.conv3d_transpose(
+ x_value, f_value, y_shape, strides, data_format='NCDHW')
+
def testConv3DTransposeValid(self):
with self.test_session():
strides = [1, 2, 2, 2, 1]
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py
index b8200ac0cb..f31426713c 100644
--- a/tensorflow/python/kernel_tests/manip_ops_test.py
+++ b/tensorflow/python/kernel_tests/manip_ops_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import manip_ops
from tensorflow.python.platform import test as test_lib
@@ -88,41 +90,78 @@ class RollTest(test_util.TensorFlowTestCase):
x = np.random.rand(3, 2, 1, 1).astype(t)
self._testAll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2])
+ def testNegativeAxis(self):
+ self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1)
+ self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2)
+    # A negative axis must satisfy 0 <= axis + dims < dims.
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "is out of range"):
+ manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
+ 3, -10).eval()
+
+ def testInvalidInputShape(self):
+ # The input should be 1-D or higher, checked in shape function.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be at least rank 1 but is rank 0"):
+ manip_ops.roll(7, 1, 0)
+
def testRollInputMustVectorHigherRaises(self):
- tensor = 7
+ # The input should be 1-D or higher, checked in kernel.
+ tensor = array_ops.placeholder(dtype=dtypes.int32)
shift = 1
axis = 0
with self.test_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"input must be 1-D or higher"):
- manip_ops.roll(tensor, shift, axis).eval()
+ manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
+
+ def testInvalidAxisShape(self):
+ # The axis should be a scalar or 1-D, checked in shape function.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be at most rank 1 but is rank 2"):
+ manip_ops.roll([[1, 2], [3, 4]], 1, [[0, 1]])
def testRollAxisMustBeScalarOrVectorRaises(self):
+ # The axis should be a scalar or 1-D, checked in kernel.
tensor = [[1, 2], [3, 4]]
shift = 1
- axis = [[0, 1]]
+ axis = array_ops.placeholder(dtype=dtypes.int32)
with self.test_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"axis must be a scalar or a 1-D vector"):
- manip_ops.roll(tensor, shift, axis).eval()
+ manip_ops.roll(tensor, shift, axis).eval(feed_dict={axis: [[0, 1]]})
+
+ def testInvalidShiftShape(self):
+ # The shift should be a scalar or 1-D, checked in shape function.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be at most rank 1 but is rank 2"):
+ manip_ops.roll([[1, 2], [3, 4]], [[0, 1]], 1)
def testRollShiftMustBeScalarOrVectorRaises(self):
+ # The shift should be a scalar or 1-D, checked in kernel.
tensor = [[1, 2], [3, 4]]
- shift = [[0, 1]]
+ shift = array_ops.placeholder(dtype=dtypes.int32)
axis = 1
with self.test_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift must be a scalar or a 1-D vector"):
- manip_ops.roll(tensor, shift, axis).eval()
+ manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]})
+
+ def testInvalidShiftAndAxisNotEqualShape(self):
+ # The shift and axis must be same size, checked in shape function.
+ with self.assertRaisesRegexp(ValueError, "both shapes must be equal"):
+ manip_ops.roll([[1, 2], [3, 4]], [1], [0, 1])
def testRollShiftAndAxisMustBeSameSizeRaises(self):
+ # The shift and axis must be same size, checked in kernel.
tensor = [[1, 2], [3, 4]]
- shift = [1]
+ shift = array_ops.placeholder(dtype=dtypes.int32)
axis = [0, 1]
with self.test_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift and axis must have the same size"):
- manip_ops.roll(tensor, shift, axis).eval()
+ manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]})
def testRollAxisOutOfRangeRaises(self):
tensor = [1, 2]
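
A sketch of the behavior the new tests pin down: `axis=-1` rolls the last dimension, and an out-of-range negative axis raises `InvalidArgumentError`.

```python
# Sketch: rolling along axis -1 (the last dimension).
from tensorflow.python.client import session
from tensorflow.python.ops import manip_ops

with session.Session() as sess:
    out = manip_ops.roll([1, 2, 3, 4, 5], shift=2, axis=-1)
    print(sess.run(out))        # [4 5 1 2 3]
```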
diff --git a/tensorflow/python/kernel_tests/norm_op_test.py b/tensorflow/python/kernel_tests/norm_op_test.py
index d85512fae6..3f71b326a2 100644
--- a/tensorflow/python/kernel_tests/norm_op_test.py
+++ b/tensorflow/python/kernel_tests/norm_op_test.py
@@ -37,17 +37,17 @@ class NormOpTest(test_lib.TestCase):
def testBadOrder(self):
matrix = [[0., 1.], [2., 3.]]
- for ord_ in "foo", -7, -1.1, 0:
+ for ord_ in "fro", -7, -1.1, 0:
with self.assertRaisesRegexp(ValueError,
"'ord' must be a supported vector norm"):
- linalg_ops.norm(matrix, ord="fro")
+ linalg_ops.norm(matrix, ord=ord_)
- for ord_ in "foo", -7, -1.1, 0:
+ for ord_ in "fro", -7, -1.1, 0:
with self.assertRaisesRegexp(ValueError,
"'ord' must be a supported vector norm"):
linalg_ops.norm(matrix, ord=ord_, axis=-1)
- for ord_ in 1.1, 2:
+ for ord_ in "foo", -7, -1.1, 1.1:
with self.assertRaisesRegexp(ValueError,
"'ord' must be a supported matrix norm"):
linalg_ops.norm(matrix, ord=ord_, axis=[-2, -1])
@@ -69,14 +69,14 @@ def _GetNormOpTest(dtype_, shape_, ord_, axis_, keep_dims_, use_static_shape_):
if use_static_shape_:
tf_matrix = constant_op.constant(matrix)
tf_norm = linalg_ops.norm(
- tf_matrix, ord=ord_, axis=axis_, keep_dims=keep_dims_)
+ tf_matrix, ord=ord_, axis=axis_, keepdims=keep_dims_)
tf_norm_val = sess.run(tf_norm)
else:
tf_matrix = array_ops.placeholder(dtype_)
tf_norm = linalg_ops.norm(
- tf_matrix, ord=ord_, axis=axis_, keep_dims=keep_dims_)
+ tf_matrix, ord=ord_, axis=axis_, keepdims=keep_dims_)
tf_norm_val = sess.run(tf_norm, feed_dict={tf_matrix: matrix})
- self.assertAllClose(np_norm, tf_norm_val)
+ self.assertAllClose(np_norm, tf_norm_val, rtol=1e-5, atol=1e-5)
def Test(self):
is_matrix_norm = (isinstance(axis_, tuple) or
@@ -85,8 +85,6 @@ def _GetNormOpTest(dtype_, shape_, ord_, axis_, keep_dims_, use_static_shape_):
if ((not is_matrix_norm and ord_ == "fro") or
(is_matrix_norm and is_fancy_p_norm)):
self.skipTest("Not supported by neither numpy.linalg.norm nor tf.norm")
- if is_matrix_norm and ord_ == 2:
- self.skipTest("Not supported by tf.norm")
if ord_ == 'euclidean' or (axis_ is None and len(shape) > 2):
self.skipTest("Not supported by numpy.linalg.norm")
matrix = np.random.randn(*shape_).astype(dtype_)
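
The removed skip suggests `tf.norm` now accepts `ord=2` over matrix axes (the spectral norm, i.e. the largest singular value); a hedged sketch under that assumption:

```python
# Hedged sketch: 2-norm over the last two axes, previously skipped as
# unsupported by tf.norm.
from tensorflow.python.client import session
from tensorflow.python.ops import linalg_ops

with session.Session() as sess:
    m = [[0., 1.], [2., 3.]]
    print(sess.run(linalg_ops.norm(m, ord=2, axis=[-2, -1])))
```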
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index 5b508b7c0e..b9f44d728a 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -52,6 +52,38 @@ class PyFuncTest(test.TestCase):
"""Encapsulates tests for py_func and eager_py_func."""
# ----- Tests for py_func -----
+ def testRealDataTypes(self):
+ def sum_func(x, y):
+ return x + y
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
+ dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16,
+ dtypes.int32, dtypes.int64]:
+ with self.test_session():
+ x = constant_op.constant(1, dtype=dtype)
+ y = constant_op.constant(2, dtype=dtype)
+ z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype))
+ self.assertEqual(z, 3)
+
+ def testComplexDataTypes(self):
+ def sub_func(x, y):
+ return x - y
+ for dtype in [dtypes.complex64, dtypes.complex128]:
+ with self.test_session():
+ x = constant_op.constant(1 + 1j, dtype=dtype)
+ y = constant_op.constant(2 - 2j, dtype=dtype)
+ z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype))
+ self.assertEqual(z, -1 + 3j)
+
+ def testBoolDataTypes(self):
+ def and_func(x, y):
+ return x and y
+ dtype = dtypes.bool
+ with self.test_session():
+ x = constant_op.constant(True, dtype=dtype)
+ y = constant_op.constant(False, dtype=dtype)
+ z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype))
+ self.assertEqual(z, False)
+
def testSingleType(self):
with self.test_session():
x = constant_op.constant(1.0, dtypes.float32)
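
Together with the `NPY_UINT16` mapping added in py_func.cc below, these tests imply `py_func` can now round-trip uint16 tensors through a Python function; a sketch:

```python
# Sketch: py_func with uint16 inputs and output.
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import script_ops

x = constant_op.constant(1, dtype=dtypes.uint16)
y = constant_op.constant(2, dtype=dtypes.uint16)
z = script_ops.py_func(lambda a, b: a + b, [x, y], dtypes.uint16)
with session.Session() as sess:
    print(sess.run(z))          # 3
```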
diff --git a/tensorflow/python/kernel_tests/random/multinomial_op_test.py b/tensorflow/python/kernel_tests/random/multinomial_op_test.py
index a9dc7b7de0..051c7d86bf 100644
--- a/tensorflow/python/kernel_tests/random/multinomial_op_test.py
+++ b/tensorflow/python/kernel_tests/random/multinomial_op_test.py
@@ -46,7 +46,7 @@ def composed_sampler(logits, num_samples):
logits = array_ops.expand_dims(logits, -1)
# [batch size, num samples]
- return math_ops.argmax(logits + noise, dimension=1)
+ return math_ops.argmax(logits + noise, axis=1)
native_sampler = random_ops.multinomial
diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py
index df37dd98ec..e4b5c3832a 100644
--- a/tensorflow/python/kernel_tests/random/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/random_ops_test.py
@@ -228,6 +228,17 @@ class RandomUniformTest(test.TestCase):
print("count = ", count)
self.assertTrue(count < count_limit)
+ def testUniformIntsWithInvalidShape(self):
+ for dtype in dtypes.int32, dtypes.int64:
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ random_ops.random_uniform(
+ [1000], minval=[1, 2], maxval=3, dtype=dtype)
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ random_ops.random_uniform(
+ [1000], minval=1, maxval=[2, 3], dtype=dtype)
+
# Check that uniform ints actually follow a uniform distribution.
def testUniformInts(self):
minv = -2
diff --git a/tensorflow/python/kernel_tests/string_strip_op_test.py b/tensorflow/python/kernel_tests/string_strip_op_test.py
new file mode 100644
index 0000000000..30fd477ff4
--- /dev/null
+++ b/tensorflow/python/kernel_tests/string_strip_op_test.py
@@ -0,0 +1,56 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for string_strip_op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class StringStripOpTest(test.TestCase):
+ """ Test cases for tf.string_strip."""
+
+ def test_string_strip(self):
+ strings = ["pigs on the wing", "animals"]
+
+ with self.test_session() as sess:
+ output = string_ops.string_strip(strings)
+ output = sess.run(output)
+ self.assertAllEqual(output, [b"pigs on the wing", b"animals"])
+
+ def test_string_strip_2d(self):
+ strings = [["pigs on the wing", "animals"],
+ [" hello ", "\n\tworld \r \n"]]
+
+ with self.test_session() as sess:
+ output = string_ops.string_strip(strings)
+ output = sess.run(output)
+ self.assertAllEqual(output, [[b"pigs on the wing", b"animals"],
+ [b"hello", b"world"]])
+
+ def test_string_strip_with_empty_strings(self):
+ strings = [" hello ", "", "world ", " \t \r \n "]
+
+ with self.test_session() as sess:
+ output = string_ops.string_strip(strings)
+ output = sess.run(output)
+ self.assertAllEqual(output, [b"hello", b"", b"world", b""])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index 22317a348c..8c6bb7955a 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -126,6 +126,9 @@ Status NumericNpDTypeToTfDType(const int np, DataType* tf) {
case NPY_INT8:
*tf = DT_INT8;
break;
+ case NPY_UINT16:
+ *tf = DT_UINT16;
+ break;
case NPY_INT16:
*tf = DT_INT16;
break;
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index fa26e07c85..ceeabe090d 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -144,6 +144,7 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin,protected-access
@tf_export("expand_dims")
+@deprecation.deprecated_args(None, "Use the `axis` argument instead", "dim")
def expand_dims(input, axis=None, name=None, dim=None):
"""Inserts a dimension of 1 into a tensor's shape.
@@ -193,11 +194,7 @@ def expand_dims(input, axis=None, name=None, dim=None):
Raises:
ValueError: if both `dim` and `axis` are specified.
"""
- # TODO(aselle): Remove argument dim
- if dim is not None:
- if axis is not None:
- raise ValueError("can't specify both 'dim' and 'axis'")
- axis = dim
+ axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
return gen_array_ops.expand_dims(input, axis, name)
@@ -2581,6 +2578,8 @@ def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
@tf_export("squeeze")
+@deprecation.deprecated_args(None, "Use the `axis` argument instead",
+ "squeeze_dims")
def squeeze(input, axis=None, name=None, squeeze_dims=None):
# pylint: disable=redefined-builtin
"""Removes dimensions of size 1 from the shape of a tensor.
@@ -2621,10 +2620,8 @@ def squeeze(input, axis=None, name=None, squeeze_dims=None):
Raises:
ValueError: When both `squeeze_dims` and `axis` are specified.
"""
- if squeeze_dims is not None:
- if axis is not None:
- raise ValueError("Cannot specify both 'squeeze_dims' and 'axis'")
- axis = squeeze_dims
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "squeeze_dims", squeeze_dims)
if np.isscalar(axis):
axis = [axis]
return gen_array_ops.squeeze(input, axis, name)
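
Both changes replace hand-rolled argument checks with the shared deprecation helpers; a sketch of how the pair works together, using a hypothetical function `scale`:

```python
# Sketch of the pattern: deprecated_args warns when `dim` is passed, and
# deprecated_argument_lookup picks the value, raising if both are given.
from tensorflow.python.util import deprecation

@deprecation.deprecated_args(None, "Use the `axis` argument instead", "dim")
def scale(x, axis=None, dim=None):
    axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
    return x, axis
```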
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index 66fa9e110c..8f25b1149c 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -311,7 +311,7 @@ class Categorical(distribution.Distribution):
nn_ops.log_softmax(self.logits) * self.probs, axis=-1)
def _mode(self):
- ret = math_ops.argmax(self.logits, dimension=self._batch_rank)
+ ret = math_ops.argmax(self.logits, axis=self._batch_rank)
ret = math_ops.cast(ret, self.dtype)
ret.set_shape(self.batch_shape)
return ret
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index f0120f2957..9e46739bc1 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -331,11 +331,11 @@ def embedding_lookup_sparse(params,
representing sharded embedding tensors. Alternatively, a
`PartitionedVariable`, created by partitioning along dimension 0. Each
element must be appropriately sized for the given `partition_strategy`.
- sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
+ sp_ids: N x M `SparseTensor` of int64 ids (typically from FeatureValueToId),
where N is typically batch size and M is arbitrary.
- sp_weights: either a SparseTensor of float / double weights, or None to
- indicate all weights should be taken to be 1. If specified, sp_weights
- must have exactly the same shape and indices as sp_ids.
+ sp_weights: either a `SparseTensor` of float / double weights, or `None` to
+ indicate all weights should be taken to be 1. If specified, `sp_weights`
+ must have exactly the same shape and indices as `sp_ids`.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
is `"mod"`. See `tf.nn.embedding_lookup` for more details.
@@ -351,39 +351,43 @@ def embedding_lookup_sparse(params,
Returns:
A dense tensor representing the combined embeddings for the
- sparse ids. For each row in the dense tensor represented by sp_ids, the op
+ sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
looks up the embeddings for all ids in that row, multiplies them by the
corresponding weight, and combines these embeddings as specified.
In other words, if
- shape(combined params) = [p0, p1, ..., pm]
+ `shape(combined params) = [p0, p1, ..., pm]`
and
- shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]
+ `shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]`
then
- shape(output) = [d0, d1, ..., dn-1, p1, ..., pm].
+ `shape(output) = [d0, d1, ..., dn-1, p1, ..., pm]`.
For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
+ ```python
[0, 0]: id 1, weight 2.0
[0, 1]: id 3, weight 0.5
[1, 0]: id 0, weight 1.0
[2, 3]: id 1, weight 3.0
+ ```
with `combiner`="mean", then the output will be a 3x20 matrix where
+ ```python
output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
output[1, :] = (params[0, :] * 1.0) / 1.0
output[2, :] = (params[1, :] * 3.0) / 3.0
+ ```
Raises:
- TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
- None nor SparseTensor.
- ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
+ TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
+ neither `None` nor `SparseTensor`.
+ ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
"""
if combiner is None:
logging.warn("The default value of combiner will change from \"mean\" "
diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py
index 4a1ef54fb5..ec38d89a0e 100644
--- a/tensorflow/python/ops/histogram_ops.py
+++ b/tensorflow/python/ops/histogram_ops.py
@@ -32,7 +32,6 @@ from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
-from tensorflow.python.util.tf_export import tf_export
@tf_export('histogram_fixed_width_bins')
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 3369fe3c9b..601010bce9 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -269,17 +269,7 @@ def random_flip_up_down(image, seed=None):
Raises:
ValueError: if the shape of `image` not supported.
"""
- with ops.name_scope(None, 'random_flip_up_down', [image]) as scope:
- image = ops.convert_to_tensor(image, name='image')
- image = _Assert3DImage(image)
- uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
- mirror_cond = math_ops.less(uniform_random, .5)
- result = control_flow_ops.cond(
- mirror_cond,
- lambda: array_ops.reverse(image, [0]),
- lambda: image,
- name=scope)
- return fix_image_flip_shape(image, result)
+ return _random_flip(image, 0, seed, 'random_flip_up_down')
@tf_export('image.random_flip_left_right')
@@ -301,14 +291,34 @@ def random_flip_left_right(image, seed=None):
Raises:
ValueError: if the shape of `image` not supported.
"""
- with ops.name_scope(None, 'random_flip_left_right', [image]) as scope:
+ return _random_flip(image, 1, seed, 'random_flip_left_right')
+
+
+def _random_flip(image, flip_index, seed, scope_name):
+ """Randomly (50% chance) flip an image along axis `flip_index`.
+ Args:
+ image: A 3-D tensor of shape `[height, width, channels].`
+ flip_index: The dimension along which to flip the image.
+ Vertical: 0, Horizontal: 1
+ seed: A Python integer. Used to create a random seed. See
+ @{tf.set_random_seed}
+ for behavior.
+ scope_name: Name of the scope in which the ops are added.
+
+ Returns:
+ A 3-D tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` not supported.
+ """
+ with ops.name_scope(None, scope_name, [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
image = _Assert3DImage(image)
uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
mirror_cond = math_ops.less(uniform_random, .5)
result = control_flow_ops.cond(
mirror_cond,
- lambda: array_ops.reverse(image, [1]),
+ lambda: array_ops.reverse(image, [flip_index]),
lambda: image,
name=scope)
return fix_image_flip_shape(image, result)
@@ -332,16 +342,7 @@ def flip_left_right(image):
Raises:
ValueError: if the shape of `image` not supported.
"""
- with ops.name_scope(None, 'flip_left_right', [image]):
- image = ops.convert_to_tensor(image, name='image')
- image = _AssertAtLeast3DImage(image)
- shape = image.get_shape()
- if shape.ndims == 3 or shape.ndims is None:
- return fix_image_flip_shape(image, array_ops.reverse(image, [1]))
- elif shape.ndims == 4:
- return array_ops.reverse(image, [2])
- else:
- raise ValueError('\'image\' must have either 3 or 4 dimensions.')
+ return _flip(image, 1, 'flip_left_right')
@tf_export('image.flip_up_down')
@@ -362,14 +363,35 @@ def flip_up_down(image):
Raises:
ValueError: if the shape of `image` not supported.
"""
- with ops.name_scope(None, 'flip_up_down', [image]):
+ return _flip(image, 0, 'flip_up_down')
+
+
+def _flip(image, flip_index, scope_name):
+ """Flip an image either horizontally or vertically.
+
+ Outputs the contents of `image` flipped along the dimension `flip_index`.
+
+ See also `reverse()`.
+
+ Args:
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
+ flip_index: 0 For vertical, 1 for horizontal.
+
+ Returns:
+ A tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` not supported.
+ """
+ with ops.name_scope(None, scope_name, [image]):
image = ops.convert_to_tensor(image, name='image')
image = _AssertAtLeast3DImage(image)
shape = image.get_shape()
if shape.ndims == 3 or shape.ndims is None:
- return fix_image_flip_shape(image, array_ops.reverse(image, [0]))
+ return fix_image_flip_shape(image, array_ops.reverse(image, [flip_index]))
elif shape.ndims == 4:
- return array_ops.reverse(image, [1])
+ return array_ops.reverse(image, [flip_index+1])
else:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
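
With the shared `_flip` helper, both public flips accept a 3-D image or a 4-D batch; for a batch the flip axis shifts past the leading batch dimension (hence `flip_index+1`, so `flip_left_right` reverses axis 2, the width axis of NHWC). A sketch:

```python
# Sketch: flip_left_right on a single image vs. a 4-D batch.
import tensorflow as tf

image = tf.zeros([64, 64, 3])       # HWC: reverses axis 1 (width)
batch = tf.zeros([8, 64, 64, 3])    # NHWC: reverses axis 2 (width)
flipped_image = tf.image.flip_left_right(image)
flipped_batch = tf.image.flip_left_right(batch)
```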
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 39b7295124..f93bf0a17f 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -39,10 +39,10 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import linalg_ops_impl
+from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import random_ops
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -529,7 +529,7 @@ class Orthogonal(Initializer):
# Generate a random matrix
a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed)
# Compute the qr factorization
- q, r = linalg_ops.qr(a, full_matrices=False)
+ q, r = gen_linalg_ops.qr(a, full_matrices=False)
# Make Q uniform
d = array_ops.diag_part(r)
q *= math_ops.sign(d)
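The sign correction after QR is what makes the sampled `Q` uniformly distributed (Haar measure); raw QR leaves an arbitrary per-column sign. A NumPy sketch of the same normalization, assuming it mirrors the TF ops above:

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.standard_normal((4, 4))
    q, r = np.linalg.qr(a)
    d = np.diag(r)
    q *= np.sign(d)  # flip column signs so diag(R) > 0, removing QR's ambiguity
    assert np.allclose(q.T @ q, np.eye(4), atol=1e-10)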
@@ -577,7 +577,7 @@ class ConvolutionDeltaOrthogonal(Initializer):
a = random_ops.random_normal([shape[-1], shape[-1]],
dtype=dtype, seed=self.seed)
# Compute the qr factorization
- q, r = linalg_ops.qr(a, full_matrices=False)
+ q, r = gen_linalg_ops.qr(a, full_matrices=False)
# Make Q uniform
d = array_ops.diag_part(r)
q *= math_ops.sign(d)
@@ -636,7 +636,7 @@ class ConvolutionOrthogonal(Initializer):
a = random_ops.random_normal([n, n], dtype=self.dtype, seed=self.seed)
if self.seed:
self.seed += 1
- q, r = linalg_ops.qr(a)
+ q, r = gen_linalg_ops.qr(a)
d = array_ops.diag_part(r)
# make q uniform
q *= math_ops.sign(d)
@@ -723,7 +723,7 @@ class ConvolutionOrthogonal2D(ConvolutionOrthogonal):
raise ValueError("The dimension of the matrices must be the same.")
n = p1.shape.as_list()[0]
kernel2x2 = {}
- eye = linalg_ops.eye(n, dtype=self.dtype)
+ eye = linalg_ops_impl.eye(n, dtype=self.dtype)
kernel2x2[0, 0] = math_ops.matmul(p1, p2)
kernel2x2[0, 1] = math_ops.matmul(p1, (eye - p2))
kernel2x2[1, 0] = math_ops.matmul((eye - p1), p2)
@@ -848,7 +848,7 @@ class ConvolutionOrthogonal1D(ConvolutionOrthogonal):
"""
n = projection_matrix.shape.as_list()[0]
kernel = {}
- eye = linalg_ops.eye(n, dtype=self.dtype)
+ eye = linalg_ops_impl.eye(n, dtype=self.dtype)
kernel[0] = projection_matrix
kernel[1] = eye - projection_matrix
return kernel
@@ -976,7 +976,7 @@ class ConvolutionOrthogonal3D(ConvolutionOrthogonal):
if p1_shape != p2.shape.as_list() or p1_shape != p3.shape.as_list():
raise ValueError("The dimension of the matrices must be the same.")
n = p1_shape[0]
- eye = linalg_ops.eye(n, dtype=self.dtype)
+ eye = linalg_ops_impl.eye(n, dtype=self.dtype)
kernel2x2x2 = {}
def matmul(p1, p2, p3):
return math_ops.matmul(math_ops.matmul(p1, p2), p3)
@@ -1084,7 +1084,7 @@ class Identity(Initializer):
"Identity matrix initializer can only be used for 2D matrices.")
if dtype is None:
dtype = self.dtype
- initializer = linalg_ops.eye(*full_shape, dtype=dtype)
+ initializer = linalg_ops_impl.eye(*full_shape, dtype=dtype)
if partition_info is not None:
initializer = array_ops.slice(initializer, partition_info.var_offset,
shape)
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 170861b43f..a0dfa543f9 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -24,12 +24,13 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_linalg_ops
+from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_linalg_ops import *
# pylint: enable=wildcard-import
-from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -159,36 +160,11 @@ def eye(num_rows,
Returns:
A `Tensor` of shape `batch_shape + [num_rows, num_columns]`
"""
- with ops.name_scope(
- name, default_name='eye', values=[num_rows, num_columns, batch_shape]):
- is_square = num_columns is None
- batch_shape = [] if batch_shape is None else batch_shape
- num_columns = num_rows if num_columns is None else num_columns
- if isinstance(num_rows, ops.Tensor) or isinstance(
- num_columns, ops.Tensor) or isinstance(batch_shape, ops.Tensor):
- batch_shape = ops.convert_to_tensor(
- batch_shape, name='shape', dtype=dtypes.int32)
- diag_size = math_ops.minimum(num_rows, num_columns)
- diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
- if not is_square:
- shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
- else:
- if not isinstance(num_rows, compat.integral_types) or not isinstance(
- num_columns, compat.integral_types):
- raise TypeError(
- 'num_rows and num_columns must be positive integer values.')
- batch_shape = [dim for dim in batch_shape]
- is_square = num_rows == num_columns
- diag_shape = batch_shape + [np.minimum(num_rows, num_columns)]
- if not is_square:
- shape = batch_shape + [num_rows, num_columns]
-
- diag_ones = array_ops.ones(diag_shape, dtype=dtype)
- if is_square:
- return array_ops.matrix_diag(diag_ones)
- else:
- zero_matrix = array_ops.zeros(shape, dtype=dtype)
- return array_ops.matrix_set_diag(zero_matrix, diag_ones)
+ return linalg_ops_impl.eye(num_rows,
+ num_columns=num_columns,
+ batch_shape=batch_shape,
+ dtype=dtype,
+ name=name)
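`eye` now simply forwards to `linalg_ops_impl.eye`, so the public behavior should be unchanged. Typical calls through the exported alias (assuming the usual `tf.eye` export):

    import tensorflow as tf

    tf.eye(2)                    # [2, 2] identity
    tf.eye(2, num_columns=3)     # rectangular [2, 3], ones on the main diagonal
    tf.eye(2, batch_shape=[3])   # batch of identities, shape [3, 2, 2]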
@tf_export('matrix_solve_ls', 'linalg.lstsq')
@@ -454,7 +430,7 @@ def norm(tensor,
This function can compute several different vector norms (the 1-norm, the
Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
- matrix norms (Frobenius, 1-norm, and inf-norm).
+  matrix norms (Frobenius, 1-norm, 2-norm, and inf-norm).
Args:
tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
@@ -465,7 +441,7 @@ def norm(tensor,
Some restrictions apply:
a) The Frobenius norm `fro` is not defined for vectors,
b) If axis is a 2-tuple (matrix norm), only 'euclidean', 'fro', `1`,
- `np.inf` are supported.
+ `2`, `np.inf` are supported.
See the description of `axis` on how to compute norms for a batch of
vectors or matrices stored in a tensor.
axis: If `axis` is `None` (the default), the input is considered a vector
@@ -521,8 +497,7 @@ def norm(tensor,
axis[0] == axis[1]):
raise ValueError(
"'axis' must be None, an integer, or a tuple of 2 unique integers")
- # TODO(rmlarsen): Implement matrix 2-norm using tf.svd().
- supported_matrix_norms = ['euclidean', 'fro', 1, np.inf]
+ supported_matrix_norms = ['euclidean', 'fro', 1, 2, np.inf]
if ord not in supported_matrix_norms:
raise ValueError("'ord' must be a supported matrix norm in %s, got %s" %
(supported_matrix_norms, ord))
@@ -539,12 +514,34 @@ def norm(tensor,
with ops.name_scope(name, 'norm', [tensor]):
tensor = ops.convert_to_tensor(tensor)
+
if ord in ['fro', 'euclidean', 2, 2.0]:
- # TODO(rmlarsen): Move 2-norm to a separate clause once we support it for
- # matrices.
- result = math_ops.sqrt(
- math_ops.reduce_sum(
- tensor * math_ops.conj(tensor), axis, keepdims=True))
+ if is_matrix_norm and ord in [2, 2.0]:
+ rank = array_ops.rank(tensor)
+ positive_axis = functional_ops.map_fn(
+ lambda i: control_flow_ops.cond(i >= 0, lambda: i, lambda: i + rank),
+ ops.convert_to_tensor(axis))
+ axes = math_ops.range(rank)
+ perm_before = array_ops.concat(
+ [array_ops.setdiff1d(axes, positive_axis)[0], positive_axis],
+ axis=0)
+ perm_after = functional_ops.map_fn(
+ lambda i: math_ops.cast(
+ array_ops.squeeze(
+ array_ops.where(math_ops.equal(perm_before, i))),
+ dtype=dtypes.int32), axes)
+ permed = array_ops.transpose(tensor, perm=perm_before)
+ matrix_2_norm = array_ops.expand_dims(
+ math_ops.reduce_max(
+ math_ops.abs(gen_linalg_ops.svd(permed, compute_uv=False)[0]),
+ axis=-1,
+ keepdims=True),
+ axis=-1)
+ result = array_ops.transpose(matrix_2_norm, perm=perm_after)
+ else:
+ result = math_ops.sqrt(
+ math_ops.reduce_sum(
+ tensor * math_ops.conj(tensor), axis, keepdims=True))
else:
result = math_ops.abs(tensor)
if ord == 1:
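The new branch computes the matrix 2-norm as the largest singular value, which is its definition; the transpose bookkeeping just moves the two matrix axes to the end for the batched SVD. A NumPy cross-check of the identity being used (sketch, not patch code):

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.standard_normal((3, 4))
    sigma_max = np.linalg.svd(a, compute_uv=False).max()
    assert np.isclose(np.linalg.norm(a, ord=2), sigma_max)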
diff --git a/tensorflow/python/ops/linalg_ops_impl.py b/tensorflow/python/ops/linalg_ops_impl.py
new file mode 100644
index 0000000000..e7c89f6ae3
--- /dev/null
+++ b/tensorflow/python/ops/linalg_ops_impl.py
@@ -0,0 +1,73 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operations for linear algebra."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util import compat
+
+# Names below are lower_case.
+# pylint: disable=invalid-name
+
+
+def eye(num_rows,
+ num_columns=None,
+ batch_shape=None,
+ dtype=dtypes.float32,
+ name=None):
+ """Construct an identity matrix, or a batch of matrices.
+
+ See `linalg_ops.eye`.
+ """
+ with ops.name_scope(
+ name, default_name='eye', values=[num_rows, num_columns, batch_shape]):
+ is_square = num_columns is None
+ batch_shape = [] if batch_shape is None else batch_shape
+ num_columns = num_rows if num_columns is None else num_columns
+ if isinstance(num_rows, ops.Tensor) or isinstance(
+ num_columns, ops.Tensor) or isinstance(batch_shape, ops.Tensor):
+ batch_shape = ops.convert_to_tensor(
+ batch_shape, name='shape', dtype=dtypes.int32)
+ diag_size = math_ops.minimum(num_rows, num_columns)
+ diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
+ if not is_square:
+ shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
+ else:
+ if not isinstance(num_rows, compat.integral_types) or not isinstance(
+ num_columns, compat.integral_types):
+ raise TypeError(
+ 'num_rows and num_columns must be positive integer values.')
+ batch_shape = [dim for dim in batch_shape]
+ is_square = num_rows == num_columns
+ diag_shape = batch_shape + [np.minimum(num_rows, num_columns)]
+ if not is_square:
+ shape = batch_shape + [num_rows, num_columns]
+
+ diag_ones = array_ops.ones(diag_shape, dtype=dtype)
+ if is_square:
+ return array_ops.matrix_diag(diag_ones)
+ else:
+ zero_matrix = array_ops.zeros(shape, dtype=dtype)
+ return array_ops.matrix_set_diag(zero_matrix, diag_ones)
+
+# pylint: enable=invalid-name
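Note the two code paths in `eye`: a graph-time path when any of `num_rows`, `num_columns`, or `batch_shape` is a `Tensor`, and a static path built from Python ints. The rectangular case writes a ones vector onto the diagonal of a zero matrix; a NumPy analogue of that construction (illustrative only):

    import numpy as np

    num_rows, num_columns = 2, 4
    eye = np.zeros((num_rows, num_columns))
    np.fill_diagonal(eye, 1.0)  # analogue of matrix_set_diag with diag_ones
    assert (eye == np.eye(num_rows, num_columns)).all()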
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 34ca1adc3e..9fc545c967 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import util
from tensorflow.python.util.deprecation import deprecated_args
+from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export
@@ -306,11 +307,8 @@ def cosine_distance(
ValueError: If `predictions` shape doesn't match `labels` shape, or
`axis`, `labels`, `predictions` or `weights` is `None`.
"""
- if dim is not None:
- if axis is not None:
- raise ValueError("Cannot specify both 'axis' and 'dim'")
- axis = dim
- if axis is None and dim is None:
+ axis = deprecated_argument_lookup("axis", axis, "dim", dim)
+ if axis is None:
raise ValueError("You must specify 'axis'.")
if labels is None:
raise ValueError("labels must not be None.")
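`deprecated_argument_lookup` replaces the hand-rolled axis/dim reconciliation here and in `math_ops` below. A minimal sketch of its behavior, inferred from these call sites (not the library source):

    def deprecated_argument_lookup(new_name, new_value, old_name, old_value):
        # Prefer the new kwarg; accept the deprecated one; reject both at once.
        if old_value is not None:
            if new_value is not None:
                raise ValueError("Cannot specify both '%s' and '%s'"
                                 % (new_name, old_name))
            return old_value
        return new_value

    assert deprecated_argument_lookup("axis", None, "dim", 2) == 2
    assert deprecated_argument_lookup("axis", 1, "dim", None) == 1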
@@ -696,7 +694,7 @@ def softmax_cross_entropy(
onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
- """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits.
+ """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits_v2.
`weights` acts as a coefficient for the loss. If a scalar is provided,
then the loss is simply scaled by the given value. If `weights` is a
@@ -707,11 +705,16 @@ def softmax_cross_entropy(
new_onehot_labels = onehot_labels * (1 - label_smoothing)
+ label_smoothing / num_classes
+  Note that `onehot_labels` and `logits` must have the same shape,
+  e.g. `[batch_size, num_classes]`. The shape of `weights` must be
+  broadcastable to the loss, whose shape is determined by the shape of
+  `logits`. If the shape of `logits` is `[batch_size, num_classes]`,
+  the loss is a `Tensor` of shape `[batch_size]`.
+
Args:
- onehot_labels: `[batch_size, num_classes]` target one-hot-encoded labels.
- logits: `[batch_size, num_classes]` logits outputs of the network .
- weights: Optional `Tensor` whose rank is either 0, or rank 1 and is
- broadcastable to the loss which is a `Tensor` of shape `[batch_size]`.
+ onehot_labels: One-hot-encoded labels.
+ logits: Logits outputs of the network.
+ weights: Optional `Tensor` that is broadcastable to loss.
label_smoothing: If greater than 0 then smooth the labels.
scope: the scope for the operations performed in computing the loss.
loss_collection: collection to which the loss will be added.
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 2b04866fef..2feb88cb7b 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -211,11 +211,9 @@ def argmax(input,
name=None,
dimension=None,
output_type=dtypes.int64):
- if dimension is not None:
- if axis is not None:
- raise ValueError("Cannot specify both 'axis' and 'dimension'")
- axis = dimension
- elif axis is None:
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "dimension", dimension)
+ if axis is None:
axis = 0
return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
@@ -231,11 +229,9 @@ def argmin(input,
name=None,
dimension=None,
output_type=dtypes.int64):
- if dimension is not None:
- if axis is not None:
- raise ValueError("Cannot specify both 'axis' and 'dimension'")
- axis = dimension
- elif axis is None:
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "dimension", dimension)
+ if axis is None:
axis = 0
return gen_math_ops.arg_min(input, axis, name=name, output_type=output_type)
@@ -761,13 +757,25 @@ def cast(x, dtype, name=None):
tf.cast(x, tf.int32) # [1, 2], dtype=tf.int32
```
+ The operation supports data types (for `x` and `dtype`) of
+ `uint8`, `int8`, `uint16`, `int16`, `int32`, `int64`, `float16`, `float32`,
+ `float64`, `complex64`, `complex128`, `bfloat16`. In case of casting from
+ complex types (`complex64`, `complex128`) to real types, only the real part
+ of `x` is returned. In case of casting from real types to complex types
+ (`complex64`, `complex128`), the imaginary part of the returned value is set
+ to `0`. The handling of complex types here matches the behavior of numpy.
+
Args:
- x: A `Tensor` or `SparseTensor`.
- dtype: The destination type.
+    x: A `Tensor` or `SparseTensor` of numeric type. It can be one of
+      `uint8`, `int8`, `uint16`, `int16`, `int32`, `int64`,
+      `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
+    dtype: The destination type. The list of supported dtypes is the same
+      as for `x`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x`.
+    A `Tensor` or `SparseTensor` with the same shape as `x` and the same
+    type as `dtype`.
Raises:
TypeError: If `x` cannot be cast to the `dtype`.
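The complex-to-real convention documented above matches NumPy, so it can be sanity-checked outside TensorFlow (NumPy emits a ComplexWarning on the lossy direction):

    import numpy as np

    np.complex64(1 + 2j).astype(np.float32)  # 1.0: keeps only the real part
    np.float32(1.0).astype(np.complex64)     # (1+0j): imaginary part set to 0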
@@ -1634,7 +1642,7 @@ def reduce_min(input_tensor,
tensor with a single element is returned.
Args:
- input_tensor: The tensor to reduce. Should have numeric type.
+ input_tensor: The tensor to reduce. Should have real numeric type.
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
@@ -1683,7 +1691,7 @@ def reduce_max(input_tensor,
tensor with a single element is returned.
Args:
- input_tensor: The tensor to reduce. Should have numeric type.
+ input_tensor: The tensor to reduce. Should have real numeric type.
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 244702d13b..1d0d9a52a1 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -98,6 +98,7 @@ See the @{$python/nn} guide.
@@fixed_unigram_candidate_sampler
@@compute_accidental_hits
@@quantized_conv2d
+@@quantized_relu
@@quantized_relu_x
@@quantized_max_pool
@@quantized_avg_pool
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 47cc4da7f2..d0d5ed07ce 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -987,7 +987,7 @@ def _compute_sampled_logits(weights,
class biases.
labels: A `Tensor` of type `int64` and shape `[batch_size,
num_true]`. The target classes. Note that this format differs from
- the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
+ the `labels` argument of `nn.softmax_cross_entropy_with_logits_v2`.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
num_sampled: An `int`. The number of classes to randomly sample per batch.
@@ -1012,7 +1012,7 @@ def _compute_sampled_logits(weights,
out_logits: `Tensor` object with shape
`[batch_size, num_true + num_sampled]`, for passing to either
`nn.sigmoid_cross_entropy_with_logits` (NCE) or
- `nn.softmax_cross_entropy_with_logits` (sampled softmax).
+ `nn.softmax_cross_entropy_with_logits_v2` (sampled softmax).
out_labels: A Tensor object with the same shape as `out_logits`.
"""
@@ -1285,7 +1285,7 @@ def sampled_softmax_loss(weights,
logits = tf.matmul(inputs, tf.transpose(weights))
logits = tf.nn.bias_add(logits, biases)
labels_one_hot = tf.one_hot(labels, n_classes)
- loss = tf.nn.softmax_cross_entropy_with_logits(
+ loss = tf.nn.softmax_cross_entropy_with_logits_v2(
labels=labels_one_hot,
logits=logits)
```
@@ -1303,7 +1303,7 @@ def sampled_softmax_loss(weights,
biases: A `Tensor` of shape `[num_classes]`. The class biases.
labels: A `Tensor` of type `int64` and shape `[batch_size,
num_true]`. The target classes. Note that this format differs from
- the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
+ the `labels` argument of `nn.softmax_cross_entropy_with_logits_v2`.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
num_sampled: An `int`. The number of classes to randomly sample per batch.
@@ -1340,7 +1340,8 @@ def sampled_softmax_loss(weights,
partition_strategy=partition_strategy,
name=name,
seed=seed)
- sampled_losses = nn_ops.softmax_cross_entropy_with_logits(
+ labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
+ sampled_losses = nn_ops.softmax_cross_entropy_with_logits_v2(
labels=labels, logits=logits)
# sampled_losses is a [batch_size] tensor.
return sampled_losses
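The added `stop_gradient` matters because `softmax_cross_entropy_with_logits_v2` backpropagates into `labels` by default, unlike the v1 op; the sampled labels here are constructed targets, so gradients through them are unwanted. The general pattern, sketched with TF 1.x-style graph code (values are illustrative):

    import tensorflow as tf

    logits = tf.constant([[2.0, 1.0, 0.1]])
    labels = tf.one_hot([0], depth=3)
    # Freeze labels so only logits receive gradients, matching v1 behavior.
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=tf.stop_gradient(labels), logits=logits)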
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index bb454b3c3a..cd07550d2e 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1155,7 +1155,7 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
Returns:
A `Tensor` with the same type as `value`.
- Output shape with `'VALID`` padding is:
+ Output shape with `'VALID'` padding is:
    [batch, height - 2 * (filter_height - 1),
     width - 2 * (filter_width - 1), out_channels].
@@ -1458,10 +1458,10 @@ def conv3d_transpose(
if isinstance(output_shape, (list, np.ndarray)):
# output_shape's shape should be == [5] if reached this point.
- if not filter.get_shape()[3].is_compatible_with(output_shape[4]):
+ if not filter.get_shape()[3].is_compatible_with(output_shape[axis]):
raise ValueError(
"output_shape does not match filter's output channels, "
- "{} != {}".format(output_shape[4],
+ "{} != {}".format(output_shape[axis],
filter.get_shape()[3]))
if padding != "VALID" and padding != "SAME":
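Indexing `output_shape` with `axis` instead of a hard-coded `4` presumably lets the channel check follow the data format. A hypothetical sketch of that selection (names are illustrative; the patch's surrounding code is not shown here):

    def channel_axis(data_format):
        # Channels sit last for "NDHWC" and right after batch for "NCDHW".
        if data_format == "NDHWC":
            return 4
        if data_format == "NCDHW":
            return 1
        raise ValueError("data_format must be 'NDHWC' or 'NCDHW'.")

    assert channel_axis("NDHWC") == 4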
@@ -1986,7 +1986,7 @@ def sparse_softmax_cross_entropy_with_logits(
must provide a single specific index for the true class for each row of
`logits` (each minibatch entry). For soft softmax classification with
a probability distribution for each entry, see
- `softmax_cross_entropy_with_logits`.
+ `softmax_cross_entropy_with_logits_v2`.
**WARNING:** This op expects unscaled logits, since it performs a `softmax`
on `logits` internally for efficiency. Do not call this op with the
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 9251e9802c..86dc053c0f 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -617,9 +617,9 @@ class BasicLSTMCell(LayerRNNCell):
Args:
inputs: `2-D` tensor with shape `[batch_size, input_size]`.
state: An `LSTMStateTuple` of state tensors, each shaped
- `[batch_size, self.state_size]`, if `state_is_tuple` has been set to
+ `[batch_size, num_units]`, if `state_is_tuple` has been set to
`True`. Otherwise, a `Tensor` shaped
- `[batch_size, 2 * self.state_size]`.
+ `[batch_size, 2 * num_units]`.
Returns:
A pair containing the new hidden state, and the new state (either a
diff --git a/tensorflow/python/profiler/tfprof_logger_test.py b/tensorflow/python/profiler/tfprof_logger_test.py
index 141144f987..caf3869f56 100644
--- a/tensorflow/python/profiler/tfprof_logger_test.py
+++ b/tensorflow/python/profiler/tfprof_logger_test.py
@@ -38,7 +38,7 @@ class TFProfLoggerTest(test.TestCase):
return math_ops.matmul(a, b)
# pylint: disable=pointless-string-statement
- """# TODO(xpan): This this out of core so it doesn't depend on contrib.
+  """# TODO(xpan): Move this out of core so it doesn't depend on contrib.
def testFillMissingShape(self):
a, b, y = self._BuildSmallPlaceholderlModel()
run_options = config_pb2.RunOptions(
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index b88be4ae04..73ea85ab0c 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -41,6 +41,7 @@ from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.framework import meta_graph as meta_graph_lib
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.platform import app # pylint: disable=unused-import
+from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import loader
from tensorflow.python.tools import saved_model_utils
@@ -543,7 +544,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
input_examples = preprocess_input_examples_arg_string(input_examples_str)
for input_tensor_key, (filename, variable_name) in inputs.items():
- data = np.load(filename)
+ data = np.load(file_io.FileIO(filename, mode='r'))
# When a variable_name key is specified for the input file
if variable_name:
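Wrapping the path in `file_io.FileIO` lets `np.load` read from any filesystem TensorFlow supports, not just local disk, since `np.load` only needs a file-like object. A usage sketch (the bucket path is illustrative):

    import numpy as np
    from tensorflow.python.lib.io import file_io

    # Works for local paths and e.g. gs:// URLs alike.
    data = np.load(file_io.FileIO('gs://some-bucket/inputs.npy', mode='r'))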
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 3867c0d8da..70495291bc 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -2731,7 +2731,7 @@ class ScopedGraphTest(test.TestCase):
# The rest of the variables.
rest_variables = list(
set(variables.global_variables()) - set(var_list.keys()))
- init_rest_op = variables.initialize_variables(rest_variables)
+ init_rest_op = variables.variables_initializer(rest_variables)
with self.test_session(graph=graph) as sess:
saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py
index 4163fcac79..3358ffe526 100644
--- a/tensorflow/python/util/compat.py
+++ b/tensorflow/python/util/compat.py
@@ -42,10 +42,8 @@ import six as _six
from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
-from tensorflow.python.util.tf_export import tf_export
-@tf_export('compat.as_bytes', 'compat.as_str')
def as_bytes(bytes_or_text, encoding='utf-8'):
"""Converts either bytes or unicode to `bytes`, using utf-8 encoding for text.
@@ -68,7 +66,6 @@ def as_bytes(bytes_or_text, encoding='utf-8'):
(bytes_or_text,))
-@tf_export('compat.as_text')
def as_text(bytes_or_text, encoding='utf-8'):
"""Returns the given argument as a unicode string.
@@ -93,8 +90,12 @@ def as_text(bytes_or_text, encoding='utf-8'):
# Convert an object to a `str` in both Python 2 and 3.
if _six.PY2:
as_str = as_bytes
+ tf_export('compat.as_bytes', 'compat.as_str')(as_bytes)
+ tf_export('compat.as_text')(as_text)
else:
as_str = as_text
+ tf_export('compat.as_bytes')(as_bytes)
+ tf_export('compat.as_text', 'compat.as_str')(as_text)
@tf_export('compat.as_str_any')
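Moving the `tf_export` calls inside the version check keeps `compat.as_str` pointing at the right implementation per interpreter while still exporting all three names. Expected behavior after this change (sketch):

    import tensorflow as tf

    tf.compat.as_bytes(u'text')  # b'text' on both Python 2 and 3
    tf.compat.as_str(b'text')    # bytes on PY2 (alias of as_bytes),
                                 # unicode str on PY3 (alias of as_text)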