aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rwxr-xr-xconfigure7
-rw-r--r--tensorflow/contrib/cmake/tf_core_kernels.cmake18
-rw-r--r--tensorflow/contrib/cmake/tf_python.cmake1
-rw-r--r--tensorflow/contrib/cmake/tf_tests.cmake5
-rw-r--r--tensorflow/contrib/distributions/python/ops/exponential.py2
-rw-r--r--tensorflow/contrib/makefile/Makefile10
-rwxr-xr-xtensorflow/contrib/makefile/build_all_ios.sh7
-rwxr-xr-xtensorflow/contrib/makefile/compile_ios_protobuf.sh15
-rw-r--r--tensorflow/contrib/rnn/python/ops/lstm_ops.py2
-rw-r--r--tensorflow/contrib/seq2seq/BUILD12
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py50
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/loss.py78
-rw-r--r--tensorflow/contrib/slim/python/slim/data/parallel_reader.py3
-rw-r--r--tensorflow/core/common_runtime/simple_placer.cc18
-rw-r--r--tensorflow/core/common_runtime/simple_placer_test.cc71
-rw-r--r--tensorflow/core/kernels/BUILD2
-rw-r--r--tensorflow/core/kernels/debug_ops.h2
-rw-r--r--tensorflow/core/platform/windows/windows_file_system.cc4
-rw-r--r--tensorflow/examples/how_tos/reading_data/convert_to_records.py8
-rw-r--r--tensorflow/examples/udacity/README.md9
-rw-r--r--tensorflow/g3doc/how_tos/quantization/index.md10
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli_test.py3
-rw-r--r--tensorflow/python/debug/debug_data.py2
-rw-r--r--tensorflow/python/debug/debug_data_test.py2
-rw-r--r--tensorflow/python/framework/graph_io.py4
-rw-r--r--tensorflow/python/ops/special_math_ops.py16
-rw-r--r--tensorflow/python/ops/special_math_ops_test.py37
-rw-r--r--tensorflow/python/ops/variable_scope.py2
-rw-r--r--tensorflow/python/training/saver_test.py16
-rwxr-xr-xtensorflow/tools/ci_build/builds/android_full.sh3
-rw-r--r--tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh2
31 files changed, 371 insertions, 50 deletions
diff --git a/configure b/configure
index 753841d166..87ef6e99be 100755
--- a/configure
+++ b/configure
@@ -23,13 +23,8 @@ function bazel_clean_and_fetch() {
# TODO(pcloudy): Re-enable it after bazel clean --expunge is fixed.
if ! is_windows; then
bazel clean --expunge
- # TODO(https://github.com/bazelbuild/bazel/issues/2220) Remove the nested `bazel query`.
- bazel fetch $(bazel query "//tensorflow/... -//tensorflow/examples/android/...")
- else
- # TODO(pcloudy): Also filter out //tensorflow/examples/android/... on Windows after
- # https://github.com/bazelbuild/bazel/issues/2248 is fixed.
- bazel fetch //tensorflow/...
fi
+ bazel fetch "//tensorflow/... -//tensorflow/examples/android/..."
}
## Set up python-related environment settings
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index 50eba857ad..911e52604e 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -102,6 +102,24 @@ file(GLOB_RECURSE tf_core_gpu_kernels_srcs
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/*.cu.cc"
)
+if(WIN32 AND tensorflow_ENABLE_GPU)
+ file(GLOB_RECURSE tf_core_kernels_cpu_only_srcs
+ # GPU implementation not working on Windows yet.
+ "${tensorflow_source_dir}/tensorflow/core/kernels/matrix_diag_op.cc"
+ "${tensorflow_source_dir}/tensorflow/core/kernels/one_hot_op.cc")
+ list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_cpu_only_srcs})
+ add_library(tf_core_kernels_cpu_only OBJECT ${tf_core_kernels_cpu_only_srcs})
+ add_dependencies(tf_core_kernels_cpu_only tf_core_cpu)
+ # Undefine GOOGLE_CUDA to avoid registering unsupported GPU kernel symbols.
+ get_target_property(target_compile_flags tf_core_kernels_cpu_only COMPILE_FLAGS)
+ if(target_compile_flags STREQUAL "target_compile_flags-NOTFOUND")
+ set(target_compile_flags "/UGOOGLE_CUDA")
+ else()
+ set(target_compile_flags "${target_compile_flags} /UGOOGLE_CUDA")
+ endif()
+ set_target_properties(tf_core_kernels_cpu_only PROPERTIES COMPILE_FLAGS ${target_compile_flags})
+endif(WIN32 AND tensorflow_ENABLE_GPU)
+
add_library(tf_core_kernels OBJECT ${tf_core_kernels_srcs})
add_dependencies(tf_core_kernels tf_core_cpu)
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index d7d3d54003..6fe83bf83e 100644
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -584,6 +584,7 @@ add_library(pywrap_tensorflow SHARED
$<TARGET_OBJECTS:tf_core_direct_session>
$<$<BOOL:${tensorflow_ENABLE_GRPC_SUPPORT}>:$<TARGET_OBJECTS:tf_core_distributed_runtime>>
$<TARGET_OBJECTS:tf_core_kernels>
+ $<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
)
target_include_directories(pywrap_tensorflow PUBLIC
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 081a7a7ce3..0909b61ba9 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -148,13 +148,10 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/reshape_op_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/backend/server_test.py"
+ "${tensorflow_source_dir}/tensorflow/python/kernel_tests/diag_op_test.py" # Silently failing with GPU kernel disabled.
# int32/int64 mixup
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py"
- # cuda launch failed
- "${tensorflow_source_dir}/tensorflow/python/kernel_tests/diag_op_test.py"
- "${tensorflow_source_dir}/tensorflow/python/kernel_tests/trace_op_test.py"
- "${tensorflow_source_dir}/tensorflow/python/kernel_tests/one_hot_op_test.py" # gpu, T=uint8
# training tests
"${tensorflow_source_dir}/tensorflow/python/training/basic_session_run_hooks_test.py" # Needs tf.contrib fix.
"${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" # Needs portpicker.
diff --git a/tensorflow/contrib/distributions/python/ops/exponential.py b/tensorflow/contrib/distributions/python/ops/exponential.py
index cd6e5c2d1a..d0245bf445 100644
--- a/tensorflow/contrib/distributions/python/ops/exponential.py
+++ b/tensorflow/contrib/distributions/python/ops/exponential.py
@@ -74,7 +74,7 @@ class Exponential(gamma.Gamma):
allow_nan_stats=allow_nan_stats,
validate_args=validate_args,
name=ns)
- # While the Gamma distribution is not reparameterizeable, the
+ # While the Gamma distribution is not re-parameterizable, the
# exponential distribution is.
self._is_reparameterized = True
self._parameters = parameters
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index bb22586264..284a5894cf 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -294,6 +294,7 @@ ifeq ($(TARGET),IOS)
ifeq ($(IOS_ARCH),ARMV7)
CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
-arch armv7 \
+ -fembed-bitcode \
-D__thread= \
-DUSE_GEMM_FOR_CONV \
-Wno-c++11-narrowing \
@@ -304,6 +305,7 @@ ifeq ($(TARGET),IOS)
-isysroot \
${IPHONEOS_SYSROOT}
LDFLAGS := -arch armv7 \
+ -fembed-bitcode \
-miphoneos-version-min=${MIN_SDK_VERSION} \
-framework Accelerate \
-Xlinker -S \
@@ -316,6 +318,7 @@ ifeq ($(TARGET),IOS)
ifeq ($(IOS_ARCH),ARMV7S)
CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
-arch armv7s \
+ -fembed-bitcode \
-D__thread= \
-DUSE_GEMM_FOR_CONV \
-Wno-c++11-narrowing \
@@ -326,6 +329,7 @@ ifeq ($(TARGET),IOS)
-isysroot \
${IPHONEOS_SYSROOT}
LDFLAGS := -arch armv7s \
+ -fembed-bitcode \
-miphoneos-version-min=${MIN_SDK_VERSION} \
-framework Accelerate \
-Xlinker -S \
@@ -338,6 +342,7 @@ ifeq ($(TARGET),IOS)
ifeq ($(IOS_ARCH),ARM64)
CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
-arch arm64 \
+ -fembed-bitcode \
-D__thread= \
-DUSE_GEMM_FOR_CONV \
-Wno-c++11-narrowing \
@@ -347,6 +352,7 @@ ifeq ($(TARGET),IOS)
-isysroot \
${IPHONEOS_SYSROOT}
LDFLAGS := -arch arm64 \
+ -fembed-bitcode \
-miphoneos-version-min=${MIN_SDK_VERSION} \
-framework Accelerate \
-Xlinker -S \
@@ -359,6 +365,7 @@ ifeq ($(TARGET),IOS)
ifeq ($(IOS_ARCH),I386)
CXXFLAGS += -mios-simulator-version-min=$(MIN_SDK_VERSION) \
-arch i386 \
+ -fembed-bitcode \
-D__thread= \
-DUSE_GEMM_FOR_CONV \
-Wno-c++11-narrowing \
@@ -368,6 +375,7 @@ ifeq ($(TARGET),IOS)
-isysroot \
${IPHONESIMULATOR_SYSROOT}
LDFLAGS := -arch i386 \
+ -fembed-bitcode \
-mios-simulator-version-min=${MIN_SDK_VERSION} \
-framework Accelerate \
-Xlinker -S \
@@ -380,6 +388,7 @@ ifeq ($(TARGET),IOS)
ifeq ($(IOS_ARCH),X86_64)
CXXFLAGS += -mios-simulator-version-min=$(MIN_SDK_VERSION) \
-arch x86_64 \
+ -fembed-bitcode \
-D__thread= \
-DUSE_GEMM_FOR_CONV \
-Wno-c++11-narrowing \
@@ -389,6 +398,7 @@ ifeq ($(TARGET),IOS)
-isysroot \
${IPHONESIMULATOR_SYSROOT}
LDFLAGS := -arch x86_64 \
+ -fembed-bitcode \
-mios-simulator-version-min=${MIN_SDK_VERSION} \
-framework Accelerate \
-Xlinker -S \
diff --git a/tensorflow/contrib/makefile/build_all_ios.sh b/tensorflow/contrib/makefile/build_all_ios.sh
index 4d9ce077ba..344bf49dcf 100755
--- a/tensorflow/contrib/makefile/build_all_ios.sh
+++ b/tensorflow/contrib/makefile/build_all_ios.sh
@@ -32,6 +32,13 @@ cd ${SCRIPT_DIR}/../../../
make -f tensorflow/contrib/makefile/Makefile clean
rm -rf tensorflow/contrib/makefile/downloads
+# Setting a deployment target is required for building with bitcode,
+# otherwise linking will fail with:
+#
+# ld: -bind_at_load and -bitcode_bundle (Xcode setting ENABLE_BITCODE=YES) cannot be used together
+#
+export MACOSX_DEPLOYMENT_TARGET="10.10"
+
# Pull down the required versions of the frameworks we need.
tensorflow/contrib/makefile/download_dependencies.sh
diff --git a/tensorflow/contrib/makefile/compile_ios_protobuf.sh b/tensorflow/contrib/makefile/compile_ios_protobuf.sh
index c413924538..12f34b38d0 100755
--- a/tensorflow/contrib/makefile/compile_ios_protobuf.sh
+++ b/tensorflow/contrib/makefile/compile_ios_protobuf.sh
@@ -76,14 +76,17 @@ make distclean
"CFLAGS=${CFLAGS} \
-mios-simulator-version-min=${MIN_SDK_VERSION} \
-arch i386 \
+-fembed-bitcode \
-isysroot ${IPHONESIMULATOR_SYSROOT}" \
"CXX=${CXX}" \
"CXXFLAGS=${CXXFLAGS} \
-mios-simulator-version-min=${MIN_SDK_VERSION} \
-arch i386 \
+-fembed-bitcode \
-isysroot \
${IPHONESIMULATOR_SYSROOT}" \
LDFLAGS="-arch i386 \
+-fembed-bitcode \
-mios-simulator-version-min=${MIN_SDK_VERSION} \
${LDFLAGS} \
-L${IPHONESIMULATOR_SYSROOT}/usr/lib/ \
@@ -103,14 +106,17 @@ make distclean
"CFLAGS=${CFLAGS} \
-mios-simulator-version-min=${MIN_SDK_VERSION} \
-arch x86_64 \
+-fembed-bitcode \
-isysroot ${IPHONESIMULATOR_SYSROOT}" \
"CXX=${CXX}" \
"CXXFLAGS=${CXXFLAGS} \
-mios-simulator-version-min=${MIN_SDK_VERSION} \
-arch x86_64 \
+-fembed-bitcode \
-isysroot \
${IPHONESIMULATOR_SYSROOT}" \
LDFLAGS="-arch x86_64 \
+-fembed-bitcode \
-mios-simulator-version-min=${MIN_SDK_VERSION} \
${LDFLAGS} \
-L${IPHONESIMULATOR_SYSROOT}/usr/lib/ \
@@ -129,13 +135,16 @@ make distclean
"CFLAGS=${CFLAGS} \
-miphoneos-version-min=${MIN_SDK_VERSION} \
-arch armv7 \
+-fembed-bitcode \
-isysroot ${IPHONEOS_SYSROOT}" \
"CXX=${CXX}" \
"CXXFLAGS=${CXXFLAGS} \
-miphoneos-version-min=${MIN_SDK_VERSION} \
-arch armv7 \
+-fembed-bitcode \
-isysroot ${IPHONEOS_SYSROOT}" \
LDFLAGS="-arch armv7 \
+-fembed-bitcode \
-miphoneos-version-min=${MIN_SDK_VERSION} \
${LDFLAGS}" \
"LIBS=${LIBS}"
@@ -152,13 +161,16 @@ make distclean
"CFLAGS=${CFLAGS} \
-miphoneos-version-min=${MIN_SDK_VERSION} \
-arch armv7s \
+-fembed-bitcode \
-isysroot ${IPHONEOS_SYSROOT}" \
"CXX=${CXX}" \
"CXXFLAGS=${CXXFLAGS} \
-miphoneos-version-min=${MIN_SDK_VERSION} \
-arch armv7s \
+-fembed-bitcode \
-isysroot ${IPHONEOS_SYSROOT}" \
LDFLAGS="-arch armv7s \
+-fembed-bitcode \
-miphoneos-version-min=${MIN_SDK_VERSION} \
${LDFLAGS}" \
"LIBS=${LIBS}"
@@ -175,12 +187,15 @@ make distclean
"CFLAGS=${CFLAGS} \
-miphoneos-version-min=${MIN_SDK_VERSION} \
-arch arm64 \
+-fembed-bitcode \
-isysroot ${IPHONEOS_SYSROOT}" \
"CXXFLAGS=${CXXFLAGS} \
-miphoneos-version-min=${MIN_SDK_VERSION} \
-arch arm64 \
+-fembed-bitcode \
-isysroot ${IPHONEOS_SYSROOT}" \
LDFLAGS="-arch arm64 \
+-fembed-bitcode \
-miphoneos-version-min=${MIN_SDK_VERSION} \
${LDFLAGS}" \
"LIBS=${LIBS}"
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index c1578f7da9..d1d547b952 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -635,7 +635,7 @@ class LSTMBlockFusedCell(LSTMBlockWrapper):
wci = wco = wcf = array_ops.zeros([self._num_units], dtype=dtype)
if sequence_length is None:
- max_seq_len = time_len
+ max_seq_len = math_ops.to_int64(time_len)
else:
max_seq_len = math_ops.to_int64(math_ops.reduce_max(sequence_length))
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index 9566d03211..3c314e2f28 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -41,6 +41,18 @@ cuda_py_test(
)
cuda_py_test(
+ name = "loss_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/loss_test.py"],
+ additional_deps = [
+ ":seq2seq_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "seq2seq_test",
size = "medium",
srcs = ["python/kernel_tests/seq2seq_test.py"],
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
index f99de76f17..95560fb254 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py
@@ -20,14 +20,58 @@ from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
+import numpy as np
import tensorflow as tf
-
class LossTest(tf.test.TestCase):
- def testLoss(self):
- pass
+ def testSequenceLoss(self):
+ with self.test_session() as sess:
+ with tf.variable_scope("root",
+ initializer=tf.constant_initializer(0.5)) as varscope:
+ batch_size = 2
+ sequence_length = 3
+ number_of_classes = 5
+ logits = [tf.constant(i + 0.5, shape=[batch_size, number_of_classes])
+ for i in range(sequence_length)]
+ logits = tf.stack(logits, axis=1)
+ targets = [tf.constant(i, tf.int32, shape=[batch_size]) for i in
+ range(sequence_length)]
+ targets = tf.stack(targets, axis=1)
+ weights = [tf.constant(1.0, shape=[batch_size]) for i in
+ range(sequence_length)]
+ weights = tf.stack(weights, axis=1)
+
+ average_loss_per_example = tf.contrib.seq2seq.sequence_loss(
+ logits, targets, weights,
+ average_across_timesteps=True,
+ average_across_batch=True)
+ res = sess.run(average_loss_per_example)
+ self.assertAllClose(1.60944, res)
+
+ average_loss_per_sequence = tf.contrib.seq2seq.sequence_loss(
+ logits, targets, weights,
+ average_across_timesteps=False,
+ average_across_batch=True)
+ res = sess.run(average_loss_per_sequence)
+ compare_per_sequence = np.ones((sequence_length)) * 1.60944
+ self.assertAllClose(compare_per_sequence, res)
+
+ average_loss_per_batch = tf.contrib.seq2seq.sequence_loss(
+ logits, targets, weights,
+ average_across_timesteps=True,
+ average_across_batch=False)
+ res = sess.run(average_loss_per_batch)
+ compare_per_batch = np.ones((batch_size)) * 1.60944
+ self.assertAllClose(compare_per_batch, res)
+ total_loss = tf.contrib.seq2seq.sequence_loss(
+ logits, targets, weights,
+ average_across_timesteps=False,
+ average_across_batch=False)
+ res = sess.run(total_loss)
+ compare_total = np.ones((batch_size, sequence_length)) * 1.60944
+ self.assertAllClose(compare_total, res)
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/contrib/seq2seq/python/ops/loss.py b/tensorflow/contrib/seq2seq/python/ops/loss.py
index b8a33b3f6f..bb87111266 100644
--- a/tensorflow/contrib/seq2seq/python/ops/loss.py
+++ b/tensorflow/contrib/seq2seq/python/ops/loss.py
@@ -13,18 +13,88 @@
# limitations under the License.
# ==============================================================================
-"""Seq2seq loss operations for use in neural networks.
+"""Seq2seq loss operations for use in sequence models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import math_ops
+__all__ = ["sequence_loss"]
-__all__ = ["seq2seq_loss"]
+def sequence_loss(logits, targets, weights,
+ average_across_timesteps=True, average_across_batch=True,
+ softmax_loss_function=None, name=None):
+ """Weighted cross-entropy loss for a sequence of logits (per example).
+ Args:
+ logits: A 3D Tensor of shape
+ [batch_size x sequence_length x num_decoder_symbols] and dtype float.
+ The logits correspond to the prediction across all classes at each
+ timestep.
+ targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype
+ int. The target represents the true class at each timestep.
+ weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype
+ float. Weights constitutes the weighting of each prediction in the
+ sequence. When using weights as masking set all valid timesteps to 1 and
+ all padded timesteps to 0.
+ average_across_timesteps: If set, sum the cost across the sequence
+ dimension and divide by the cost by the total label weight across
+ timesteps.
+ average_across_batch: If set, sum the cost across the batch dimension and
+ divide the returned cost by the batch size.
+ softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
+ to be used instead of the standard softmax (the default if this is None).
+ name: Optional name for this operation, defaults to "sequence_loss".
-def seq2seq_loss(*args, **kwargs):
- pass
+ Returns:
+ A scalar float Tensor: The average log-perplexity per symbol (weighted).
+
+ Raises:
+ ValueError: logits does not have 3 dimensions or targets does not have 2
+ dimensions or weights does not have 2 dimensions.
+ """
+ if len(logits.get_shape()) != 3:
+ raise ValueError("Logits must be a "
+ "[batch_size x sequence_length x logits] tensor")
+ if len(targets.get_shape()) != 2:
+ raise ValueError("Targets must be a [batch_size x sequence_length] "
+ "tensor")
+ if len(weights.get_shape()) != 2:
+ raise ValueError("Weights must be a [batch_size x sequence_length] "
+ "tensor")
+ with ops.name_scope(name, "sequence_loss", [logits, targets, weights]):
+ num_classes = array_ops.shape(logits)[2]
+ probs_flat = array_ops.reshape(logits, [-1, num_classes])
+ targets = array_ops.reshape(targets, [-1])
+ if softmax_loss_function is None:
+ crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=targets, logits=probs_flat)
+ else:
+ crossent = softmax_loss_function(probs_flat, targets)
+ crossent = crossent * array_ops.reshape(weights, [-1])
+ if average_across_timesteps and average_across_batch:
+ crossent = math_ops.reduce_sum(crossent)
+ total_size = math_ops.reduce_sum(weights)
+ total_size += 1e-12 # to avoid division by 0 for all-0 weights
+ crossent /= total_size
+ else:
+ batch_size = array_ops.shape(logits)[0]
+ sequence_length = array_ops.shape(logits)[1]
+ crossent = array_ops.reshape(crossent, [batch_size, sequence_length])
+ if average_across_timesteps and not average_across_batch:
+ crossent = math_ops.reduce_sum(crossent, axis=[1])
+ total_size = math_ops.reduce_sum(weights, axis=[1])
+ total_size += 1e-12 # to avoid division by 0 for all-0 weights
+ crossent /= total_size
+ if not average_across_timesteps and average_across_batch:
+ crossent = math_ops.reduce_sum(crossent, axis=[0])
+ total_size = math_ops.reduce_sum(weights, axis=[0])
+ total_size += 1e-12 # to avoid division by 0 for all-0 weights
+ crossent /= total_size
+ return crossent
diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py
index 170c5899b9..6082af008a 100644
--- a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py
+++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py
@@ -210,7 +210,8 @@ def parallel_read(data_sources,
data_files = get_data_files(data_sources)
with ops.name_scope(scope, 'parallel_read'):
filename_queue = tf_input.string_input_producer(
- data_files, num_epochs=num_epochs, shuffle=shuffle, name='filenames')
+ data_files, num_epochs=num_epochs, shuffle=shuffle, seed=seed,
+ name='filenames')
dtypes = dtypes or [tf_dtypes.string, tf_dtypes.string]
if shuffle:
common_queue = data_flow_ops.RandomShuffleQueue(
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
index d3110cba04..f6e6bf0692 100644
--- a/tensorflow/core/common_runtime/simple_placer.cc
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -605,7 +605,7 @@ bool IsMetadataNode(const Node* node) {
// outputs that are connected to nodes in the same colocation group.
bool IsGeneratorNode(const Node* node) {
return node->num_inputs() == 0 && node->num_outputs() == 1 &&
- node->out_edges().size() == 1 && !IsRefType(node->output_type(0));
+ !IsRefType(node->output_type(0));
}
} // namespace
@@ -730,9 +730,9 @@ Status SimplePlacer::Run() {
// Heuristic A: prefer to place "generators" with their only
// consumers.
//
- // If this is a node with no inputs and a single (non-ref)
- // consumer, we save this for a second pass, so that the
- // consumer's placement is chosen.
+ // If this is a node with no inputs and one output, we save
+ // this for a second pass, so that the consumer's placement
+ // is chosen.
if (IsGeneratorNode(node)) {
second_pass.push_back(node);
continue;
@@ -794,7 +794,15 @@ Status SimplePlacer::Run() {
if (IsGeneratorNode(node)) {
const Node* output = (*node->out_edges().begin())->dst();
const string& output_device_name = output->assigned_device_name();
- if (CanAssignToDevice(output_device_name, devices)) {
+
+ const bool consumers_on_same_device = std::all_of(
+ node->out_edges().begin(), node->out_edges().end(),
+ [output_device_name](const Edge* e) {
+ return e->dst()->assigned_device_name() == output_device_name;
+ });
+
+ if (consumers_on_same_device &&
+ CanAssignToDevice(output_device_name, devices)) {
assigned_device = output_device_name;
}
}
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc
index 06267d71ae..c73ed041ed 100644
--- a/tensorflow/core/common_runtime/simple_placer_test.cc
+++ b/tensorflow/core/common_runtime/simple_placer_test.cc
@@ -1226,5 +1226,76 @@ TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
.contains("Cannot colocate nodes 'var' and 'assign'"));
}
+// Test that a generator node follows its consumers (where there are several
+// consumer nodes on the same devices).
+TEST_F(SimplePlacerTest, TestGeneratorNodeFollowsConsumerNode) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+
+ // A variable is only on CPU
+ Node* var1_cpu =
+ ops::SourceOp("VariableCPU", b.opts().WithName("var1_cpu"));
+ Node* var2_cpu =
+ ops::SourceOp("VariableCPU", b.opts().WithName("var2_cpu"));
+
+ // The constant to be assigned can be on both GPU or CPU.
+ //
+ // Because of the heuristic, it gets placed on CPU to avoid a
+ // copy.
+ Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in"));
+
+ // The assigns are bound to CPU by the reference edge.
+ ops::BinaryOp("TestAssign", var1_cpu, input, b.opts().WithName("assign1"));
+ ops::BinaryOp("TestAssign", var2_cpu, input, b.opts().WithName("assign2"));
+
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ TF_EXPECT_OK(Place(&g));
+ EXPECT_COLOCATED(g, "var1_cpu", "in");
+ EXPECT_COLOCATED(g, "assign1", "in");
+ EXPECT_COLOCATED(g, "var2_cpu", "in");
+ EXPECT_COLOCATED(g, "assign2", "in");
+}
+
+// Test that a generator node does not follow its consumers (where there are
+// several consumers on different devices).
+TEST_F(SimplePlacerTest, TestGeneratorNodeDoesntFollowNonColocatedConsumers) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+
+ // A variable is only on CPU
+ Node* var1_cpu =
+ ops::SourceOp("VariableCPU", b.opts().WithName("var1_cpu"));
+ Node* var2_cpu =
+ ops::SourceOp("VariableCPU", b.opts().WithName("var2_cpu"));
+
+ // The constant to be assigned can be on both GPU or CPU.
+ //
+ // Because of the heuristic, it ought to be on the GPU (cannot be
+ // co-located with both consumers, so goes to the 'standard' place)
+ Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in"));
+
+ // The assigns are bound to CPU by the reference edge.
+ ops::BinaryOp("TestAssign", var1_cpu, input, b.opts().WithName("assign1"));
+ ops::BinaryOp("TestAssign", var2_cpu, input, b.opts().WithName("assign2"));
+
+ TF_EXPECT_OK(BuildGraph(b, &g));
+
+ GetNodeByName(g, "var1_cpu")
+ ->set_assigned_device_name("/job:a/replica:0/task:0/device:fakecpu:1");
+
+ GetNodeByName(g, "var2_cpu")
+ ->set_assigned_device_name("/job:a/replica:0/task:0/device:fakecpu:2");
+ }
+
+ TF_EXPECT_OK(Place(&g));
+ EXPECT_COLOCATED(g, "assign1", "var1_cpu");
+ EXPECT_COLOCATED(g, "assign2", "var2_cpu");
+ EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index e7854ac0db..5137156107 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -439,7 +439,6 @@ ARRAY_DEPS = [
cc_library(
name = "array_not_windows",
deps = [
- ":debug_ops",
":immutable_constant_op",
],
)
@@ -478,6 +477,7 @@ cc_library(
":bitcast_op",
":concat_op",
":constant_op",
+ ":debug_ops",
":depth_space_ops",
":diag_op",
":edit_distance_op",
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index aa47315f55..0a4dc45205 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -127,7 +127,7 @@ class DebugNanCountOp : public OpKernel {
const T* input_flat = input.template flat<T>().data();
for (int64 i = 0; i < input_shape.num_elements(); ++i) {
- if (Eigen::numext::isnan(input_flat[i])) {
+ if (Eigen::numext::isnan(static_cast<double>(input_flat[i]))) {
nan_count++;
}
}
diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc
index 670abf3fdf..facadc7f57 100644
--- a/tensorflow/core/platform/windows/windows_file_system.cc
+++ b/tensorflow/core/platform/windows/windows_file_system.cc
@@ -72,7 +72,9 @@ SSIZE_T pread(HANDLE hfile, char* src, size_t num_bytes, uint64_t offset) {
BOOL read_result = ::ReadFile(hfile, src, static_cast<DWORD>(num_bytes),
&bytes_read, &overlapped);
- if ((FALSE == read_result) &&
+ if (TRUE == read_result) {
+ result = bytes_read;
+ } else if ((FALSE == read_result) &&
((last_error = GetLastError()) != ERROR_IO_PENDING)) {
result = (last_error == ERROR_HANDLE_EOF) ? 0 : -1;
} else {
diff --git a/tensorflow/examples/how_tos/reading_data/convert_to_records.py b/tensorflow/examples/how_tos/reading_data/convert_to_records.py
index 5457b27eca..d14c1f7c86 100644
--- a/tensorflow/examples/how_tos/reading_data/convert_to_records.py
+++ b/tensorflow/examples/how_tos/reading_data/convert_to_records.py
@@ -26,14 +26,6 @@ import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets import mnist
-SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
-
-TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' # MNIST filenames
-TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
-TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
-TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
-
-
FLAGS = None
diff --git a/tensorflow/examples/udacity/README.md b/tensorflow/examples/udacity/README.md
index 6e40c3bae6..143a75a3e9 100644
--- a/tensorflow/examples/udacity/README.md
+++ b/tensorflow/examples/udacity/README.md
@@ -54,6 +54,15 @@ to get the ip of the new virtual machine. To switch from default virtual machine
Note that `docker-machine env tensorflow` outputs some environment variables such like `DOCKER_HOST`. Then your docker client is now connected to the docker host in virtual machine `tensorflow`
+* **I'm getting a TLS connection error.**
+
+If you get an error about the TLS connection of your docker, run the command below to confirm the problem.
+
+ docker-machine ip tensorflow
+
+Then if it is the case use the instructions on [this page](https://docs.docker.com/toolbox/faqs/troubleshoot/) to solve the issue.
+
+
* **I'm getting the error - docker: Cannot connect to the Docker daemon. Is the docker daemon running on this host? - when I run 'docker run'.**
This is a permissions issue, and a popular answer is provided for Linux and Max OSX [here](http://stackoverflow.com/questions/21871479/docker-cant-connect-to-docker-daemon) on StackOverflow.
diff --git a/tensorflow/g3doc/how_tos/quantization/index.md b/tensorflow/g3doc/how_tos/quantization/index.md
index 340d70f93e..fa61cadcea 100644
--- a/tensorflow/g3doc/how_tos/quantization/index.md
+++ b/tensorflow/g3doc/how_tos/quantization/index.md
@@ -91,11 +91,11 @@ eight-bit computations:
```sh
curl http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz -o /tmp/inceptionv3.tgz
tar xzf /tmp/inceptionv3.tgz -C /tmp/
-bazel build tensorflow/contrib/quantization/tools:quantize_graph
-bazel-bin/tensorflow/contrib/quantization/tools/quantize_graph \
---input=/tmp/classify_image_graph_def.pb \
---output_node_names="softmax" --output=/tmp/quantized_graph.pb \
---mode=eightbit
+bazel build tensorflow/tools/quantization/tools:quantize_graph
+bazel-bin/tensorflow/tools/quantization/tools/quantize_graph \
+ --input=/tmp/classify_image_graph_def.pb \
+ --output_node_names="softmax" --output=/tmp/quantized_graph.pb \
+ --mode=eightbit
```
This will produce a new model that runs the same operations as the original, but
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index d6b5c3a097..5d3f1aed2f 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -1608,7 +1608,8 @@ class AnalyzerCLIWhileLoopTest(test_util.TensorFlowTestCase):
self.assertEqual(" dtype: int32", output.lines[1])
self.assertEqual(" shape: ()", output.lines[2])
self.assertEqual("", output.lines[3])
- self.assertEqual("array(%d, dtype=int32)" % i, output.lines[4])
+ self.assertTrue(output.lines[4].startswith("array(%d" % i))
+ self.assertTrue(output.lines[4].endswith(")"))
def testMultipleDumpsPrintTensorInvalidNumber(self):
output = self._registry.dispatch_command("pt",
diff --git a/tensorflow/python/debug/debug_data.py b/tensorflow/python/debug/debug_data.py
index 3e638c6f04..e010e8f3fc 100644
--- a/tensorflow/python/debug/debug_data.py
+++ b/tensorflow/python/debug/debug_data.py
@@ -312,7 +312,7 @@ class DebugTensorDatum(object):
self._debug_op = base.split("_")[-2]
self._output_slot = int(base.split("_")[-3])
- namespace = os.path.dirname(debug_dump_rel_path)
+ namespace = os.path.dirname(debug_dump_rel_path).replace("\\", "/")
node_base_name = "_".join(base.split("_")[:-3])
if not namespace or namespace == ".":
self._node_name = node_base_name
diff --git a/tensorflow/python/debug/debug_data_test.py b/tensorflow/python/debug/debug_data_test.py
index 9910244ad3..753b76358b 100644
--- a/tensorflow/python/debug/debug_data_test.py
+++ b/tensorflow/python/debug/debug_data_test.py
@@ -133,7 +133,7 @@ class HasNanOrInfTest(test_util.TensorFlowTestCase):
a = np.array([1j, 3j, 3j, 7j], dtype=np.complex128)
self.assertFalse(debug_data.has_inf_or_nan(self._dummy_datum, a))
- b = np.array([1j, 3j, 3j, 7j, np.nan], dtype=np.complex256)
+ b = np.array([1j, 3j, 3j, 7j, np.nan], dtype=np.complex128)
self.assertTrue(debug_data.has_inf_or_nan(self._dummy_datum, b))
def testDTypeIntegerWorks(self):
diff --git a/tensorflow/python/framework/graph_io.py b/tensorflow/python/framework/graph_io.py
index 85a00efd74..0033a37088 100644
--- a/tensorflow/python/framework/graph_io.py
+++ b/tensorflow/python/framework/graph_io.py
@@ -50,6 +50,9 @@ def write_graph(graph_or_graph_def, logdir, name, as_text=True):
filesystems, such as Google Cloud Storage (GCS).
name: Filename for the graph.
as_text: If `True`, writes the graph as an ASCII proto.
+
+ Returns:
+ The path of the output proto file.
"""
if isinstance(graph_or_graph_def, ops.Graph):
graph_def = graph_or_graph_def.as_graph_def()
@@ -64,3 +67,4 @@ def write_graph(graph_or_graph_def, logdir, name, as_text=True):
file_io.atomic_write_string_to_file(path, str(graph_def))
else:
file_io.atomic_write_string_to_file(path, graph_def.SerializeToString())
+ return path
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index a2d995030f..bf4d198209 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -340,6 +340,22 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
t0_shape[:len(preserved_axes)+len(broadcast_axes[0])] +
t1_shape[len(t1_shape)-len(broadcast_axes[1]):]
)
+
+ # Check the number of None values and replace them with Tensors containing
+ # corresponding dimensions if there exist two or more None values
+ num_none_dims = sum(1 for d in uncompacted_shape if d is None)
+ if num_none_dims > 1:
+ uncompacted_shape = list(uncompacted_shape)
+ for i in xrange(len(uncompacted_shape)):
+ if uncompacted_shape[i] is None:
+ if i < len(preserved_axes) + len(broadcast_axes[0]):
+ uncompacted_shape[i] = array_ops.shape(inputs[0])[i]
+ else:
+ idx = (i - len(preserved_axes) - len(broadcast_axes[0])
+ + len(t1_shape) - len(broadcast_axes[1]))
+ uncompacted_shape[i] = array_ops.shape(inputs[1])[idx]
+ uncompacted_shape = tuple(uncompacted_shape)
+
product = _reshape_if_necessary(product, uncompacted_shape)
product_axes = (
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index d17bb80d4b..3d289bcc9a 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -283,6 +283,43 @@ class EinsumTest(test.TestCase):
}
np.testing.assert_almost_equal([7], sess.run(out, feed_dict=feed_dict))
+ # Tests for placeholders which have two or more None values
+ with ops.Graph().as_default():
+ m0 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2))
+ m1 = array_ops.placeholder(dtypes.int32, shape=(2, 1))
+ out = special_math_ops.einsum('ijk,kl->ijl', m0, m1)
+ with session.Session() as sess:
+ feed_dict = {
+ m0: [[[1,2]]],
+ m1: [[3], [2]],
+ }
+ np.testing.assert_almost_equal(
+ [[[7]]], sess.run(out, feed_dict=feed_dict))
+
+ with ops.Graph().as_default():
+ m0 = array_ops.placeholder(dtypes.int32, shape=(2, 1))
+ m1 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2))
+ out = special_math_ops.einsum('kl,ijk->ijl', m0, m1)
+ with session.Session() as sess:
+ feed_dict = {
+ m0: [[3], [2]],
+ m1: [[[1,2]]],
+ }
+ np.testing.assert_almost_equal(
+ [[[7]]], sess.run(out, feed_dict=feed_dict))
+
+ with ops.Graph().as_default():
+ m0 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2))
+ m1 = array_ops.placeholder(dtypes.int32, shape=(2,))
+ out = special_math_ops.einsum('ijk,k->ij', m0, m1)
+ with session.Session() as sess:
+ feed_dict = {
+ m0: [[[1, 2]]],
+ m1: [3, 2],
+ }
+ np.testing.assert_almost_equal(
+ [[7]], sess.run(out, feed_dict=feed_dict))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 5897f94f9d..ddba73f7e9 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -998,7 +998,7 @@ for an extensive description of how reusing works. Here is a basic example:
with tf.variable_scope("foo"):
v = tf.get_variable("v", [1]) # v.name == "foo/v:0"
w = tf.get_variable("w", [1]) # w.name == "foo/w:0"
-with tf.variable_scope("foo", reuse=True)
+with tf.variable_scope("foo", reuse=True):
v1 = tf.get_variable("v") # The same as v above.
```
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 03d1c06476..af9f13f438 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -1837,14 +1837,22 @@ class WriteGraphTest(test.TestCase):
def testWriteGraph(self):
test_dir = _TestDir("write_graph_dir")
variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
- graph_io.write_graph(ops_lib.get_default_graph(),
- "/".join([test_dir, "l1"]), "graph.pbtxt")
+ path = graph_io.write_graph(ops_lib.get_default_graph(),
+ os.path.join(test_dir, "l1"), "graph.pbtxt")
+ truth = os.path.join(test_dir, "l1", "graph.pbtxt")
+ self.assertEqual(path, truth)
+ self.assertTrue(os.path.exists(path))
+
def testRecursiveCreate(self):
test_dir = _TestDir("deep_dir")
variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
- graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
- "/".join([test_dir, "l1/l2/l3"]), "graph.pbtxt")
+ path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
+ os.path.join(test_dir, "l1", "l2", "l3"),
+ "graph.pbtxt")
+ truth = os.path.join(test_dir, 'l1', 'l2', 'l3', "graph.pbtxt")
+ self.assertEqual(path, truth)
+ self.assertTrue(os.path.exists(path))
class SaverUtilsTest(test.TestCase):
diff --git a/tensorflow/tools/ci_build/builds/android_full.sh b/tensorflow/tools/ci_build/builds/android_full.sh
index fce2ed7504..241f28ef9c 100755
--- a/tensorflow/tools/ci_build/builds/android_full.sh
+++ b/tensorflow/tools/ci_build/builds/android_full.sh
@@ -67,4 +67,7 @@ cp bazel-bin/tensorflow/examples/android/tensorflow_demo.apk \
bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar ${OUT_DIR}
# Test Makefile build just to make sure it still works.
+if [ -z "$NDK_ROOT" ]; then
+ export NDK_ROOT=${ANDROID_NDK_HOME}
+fi
tensorflow/contrib/makefile/build_all_android.sh
diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
index e809e89a41..a3b7ee786b 100644
--- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
@@ -121,7 +121,7 @@ function get_failing_cpu_py_tests() {
//$1/tensorflow/python:session_test + \
//$1/tensorflow/python:supervisor_test + \
//$1/tensorflow/python:sync_replicas_optimizer_test + \
- //$1/tensorflow/python/debug/... + \
+ //$1/tensorflow/python/debug:curses_ui_test + \
//$1/tensorflow/python/kernel_tests:as_string_op_test + \
//$1/tensorflow/python/kernel_tests:benchmark_test + \
//$1/tensorflow/python/kernel_tests:cast_op_test + \