aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-06-24 16:50:40 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-24 18:05:02 -0700
commit67324b1e3af826c4c491802f4022a5f5be9f6670 (patch)
tree018e21feade905a14d8beb2c8e7ebdfd905bbf51
parente30936c026655f1b2f4f45997da32c257d18b076 (diff)
Merge changes from github.
Change: 125835079
-rw-r--r--README.md5
-rw-r--r--tensorflow/contrib/cmake/tf_core_cpu.cmake9
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake19
-rw-r--r--tensorflow/contrib/cmake/tf_core_kernels.cmake8
-rw-r--r--tensorflow/contrib/cmake/tf_stream_executor.cmake1
-rw-r--r--tensorflow/contrib/cmake/tf_tutorials.cmake5
-rw-r--r--tensorflow/contrib/ios_examples/camera/CameraExampleViewController.mm2
-rw-r--r--tensorflow/contrib/ios_examples/simple/RunModelViewController.mm2
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/dnn_ops.py5
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/dropout_ops.py12
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/tests/dropout_ops_test.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_functions_test.py14
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py4
-rw-r--r--tensorflow/contrib/makefile/Makefile6
-rwxr-xr-xtensorflow/contrib/makefile/gen_file_lists.sh13
-rw-r--r--tensorflow/contrib/makefile/proto_text_cc_files.txt2
-rw-r--r--tensorflow/contrib/makefile/tf_cc_files.txt8
-rw-r--r--tensorflow/contrib/pi_examples/README.md73
-rw-r--r--tensorflow/contrib/pi_examples/camera/Makefile84
-rw-r--r--tensorflow/contrib/pi_examples/camera/camera.cc533
-rw-r--r--tensorflow/contrib/pi_examples/label_image/Makefile83
-rw-r--r--tensorflow/contrib/pi_examples/label_image/data/grace_hopper.jpgbin0 -> 61306 bytes
-rw-r--r--tensorflow/contrib/pi_examples/label_image/label_image.cc397
-rw-r--r--tensorflow/contrib/quantization/tools/quantize_graph.py2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.h2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc2
-rw-r--r--tensorflow/core/kernels/eigen_spatial_convolutions.h1
-rw-r--r--tensorflow/core/kernels/reader_base.h2
-rw-r--r--tensorflow/core/kernels/sparse_matmul_op.h58
-rw-r--r--tensorflow/core/lib/core/arena.cc2
-rw-r--r--tensorflow/core/lib/io/record_reader.cc9
-rw-r--r--tensorflow/core/lib/io/record_reader.h6
-rw-r--r--tensorflow/core/lib/io/record_writer.cc11
-rw-r--r--tensorflow/core/lib/io/record_writer.h10
-rw-r--r--tensorflow/examples/image_retraining/retrain.py118
-rw-r--r--tensorflow/g3doc/api_docs/python/contrib.metrics.md2
-rw-r--r--tensorflow/g3doc/api_docs/python/framework.md4
-rw-r--r--tensorflow/g3doc/get_started/basic_usage.md2
-rw-r--r--tensorflow/g3doc/how_tos/image_retraining/index.md16
-rw-r--r--tensorflow/g3doc/tutorials/mnist/pros/index.md4
-rw-r--r--tensorflow/python/framework/dtypes.py4
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py36
-rw-r--r--tensorflow/python/kernel_tests/shape_ops_test.py11
-rw-r--r--tensorflow/python/ops/array_ops.py30
-rw-r--r--tensorflow/python/ops/math_grad.py2
-rw-r--r--tensorflow/python/ops/math_ops.py42
-rw-r--r--tensorflow/python/platform/flags.py5
-rw-r--r--tensorflow/python/platform/flags_test.py11
-rw-r--r--tensorflow/stream_executor/dso_loader.cc2
-rw-r--r--tensorflow/tools/pip_package/setup.py2
-rw-r--r--util/python/BUILD8
52 files changed, 1597 insertions, 96 deletions
diff --git a/README.md b/README.md
index adf3602d89..2327f8e92c 100644
--- a/README.md
+++ b/README.md
@@ -39,9 +39,10 @@ People who are a little bit adventurous can also try our nightly binaries:
* [Android](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/lastSuccessfulBuild/artifact/bazel-out/local_linux/bin/tensorflow/examples/android/tensorflow_demo.apk) ([build history](http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/))
#### *Try your first TensorFlow program*
-```python
+```shell
$ python
-
+```
+```python
>>> import tensorflow as tf
>>> hello = tf.constant('Hello, TensorFlow!')
>>> sess = tf.Session()
diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake
index 374096e942..135c001536 100644
--- a/tensorflow/contrib/cmake/tf_core_cpu.cmake
+++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake
@@ -4,7 +4,6 @@
file(GLOB_RECURSE tf_core_cpu_srcs
"${tensorflow_source_dir}/tensorflow/core/common_runtime/*.h"
"${tensorflow_source_dir}/tensorflow/core/common_runtime/*.cc"
- "${tensorflow_source_dir}/tensorflow/core/client/*.cc"
"${tensorflow_source_dir}/tensorflow/core/graph/*.h"
"${tensorflow_source_dir}/tensorflow/core/graph/*.cc"
"${tensorflow_source_dir}/tensorflow/core/public/*.h"
@@ -18,9 +17,17 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs
"${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu_device_factory.cc"
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.cc"
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.h"
+ "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc"
+ "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_factory.cc"
+ "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_options.cc"
)
list(REMOVE_ITEM tf_core_cpu_srcs ${tf_core_cpu_exclude_srcs})
+# We need to include stubs for the GPU tracer, which are in the exclude glob.
+list(APPEND tf_core_cpu_srcs
+ "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_tracer.cc"
+ "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_tracer.h"
+)
add_library(tf_core_cpu OBJECT ${tf_core_cpu_srcs})
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index 78aa9169dd..3e6ec3c389 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -121,7 +121,10 @@ target_include_directories(tf_protos_cc PUBLIC
target_link_libraries(tf_protos_cc PUBLIC
${PROTOBUF_LIBRARIES}
)
-
+# C++11
+target_compile_features(tf_protos_cc PRIVATE
+ cxx_rvalue_references
+)
########################################################
# tf_core_lib library
@@ -154,11 +157,6 @@ target_include_directories(tf_core_lib PUBLIC
${jsoncpp_INCLUDE_DIR}
${boringssl_INCLUDE_DIR}
)
-#target_link_libraries(tf_core_lib
-# ${CMAKE_THREAD_LIBS_INIT}
-# ${PROTOBUF_LIBRARIES}
-# tf_protos_cc
-#)
target_compile_options(tf_core_lib PRIVATE
-fno-exceptions
-DEIGEN_AVOID_STL_ARRAY
@@ -188,6 +186,10 @@ file(GLOB_RECURSE tf_core_framework_srcs
"${tensorflow_source_dir}/tensorflow/core/framework/*.cc"
"${tensorflow_source_dir}/tensorflow/core/util/*.h"
"${tensorflow_source_dir}/tensorflow/core/util/*.cc"
+ "${tensorflow_source_dir}/tensorflow/core/client/tensor_c_api.cc"
+ "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc"
+ "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_factory.cc"
+ "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_options.cc"
"${tensorflow_source_dir}/public/*.h"
)
@@ -204,7 +206,10 @@ file(GLOB_RECURSE tf_core_framework_test_srcs
list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_test_srcs})
-add_library(tf_core_framework OBJECT ${tf_core_framework_srcs} ${PROTO_TEXT_HDRS})
+add_library(tf_core_framework OBJECT
+ ${tf_core_framework_srcs}
+ ${PROTO_TEXT_HDRS}
+ ${PROTO_TEXT_SRCS})
target_include_directories(tf_core_framework PUBLIC
${tensorflow_source_dir}
${eigen_INCLUDE_DIRS}
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index 5a22d88103..2fff5c2dd3 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -19,7 +19,13 @@ list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_exclude_srcs})
add_library(tf_core_kernels OBJECT ${tf_core_kernels_srcs})
-add_dependencies(tf_core_kernels tf_core_cpu farmhash highwayhash)
+add_dependencies(tf_core_kernels
+ tf_core_cpu
+ farmhash
+ highwayhash
+ farmhash_copy_headers_to_destination
+ highwayhash_copy_headers_to_destination
+)
target_include_directories(tf_core_kernels PRIVATE
${tensorflow_source_dir}
diff --git a/tensorflow/contrib/cmake/tf_stream_executor.cmake b/tensorflow/contrib/cmake/tf_stream_executor.cmake
index 0bc8dad0ab..e1aa0cd7b5 100644
--- a/tensorflow/contrib/cmake/tf_stream_executor.cmake
+++ b/tensorflow/contrib/cmake/tf_stream_executor.cmake
@@ -58,6 +58,7 @@ add_library(tf_stream_executor OBJECT ${tf_stream_executor_srcs})
target_include_directories(tf_stream_executor PRIVATE
${tensorflow_source_dir}
+ ${eigen_INCLUDE_DIRS}
)
add_dependencies(tf_stream_executor
tf_core_lib
diff --git a/tensorflow/contrib/cmake/tf_tutorials.cmake b/tensorflow/contrib/cmake/tf_tutorials.cmake
index 89511b096d..11dfd4739b 100644
--- a/tensorflow/contrib/cmake/tf_tutorials.cmake
+++ b/tensorflow/contrib/cmake/tf_tutorials.cmake
@@ -35,10 +35,13 @@ target_include_directories(tf_tutorials_example_trainer PUBLIC
target_link_libraries(tf_tutorials_example_trainer PUBLIC
${CMAKE_THREAD_LIBS_INIT}
- ${PROTOBUF_LIBRARIES}
+ ${PROTOBUF_STATIC_LIBRARIES}
tf_protos_cc
re2_lib
+ ${boringssl_STATIC_LIBRARIES}
+ ${farmhash_STATIC_LIBRARIES}
${jpeg_STATIC_LIBRARIES}
+ ${jsoncpp_STATIC_LIBRARIES}
${png_STATIC_LIBRARIES}
${ZLIB_LIBRARIES}
${CMAKE_DL_LIBS}
diff --git a/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.mm b/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.mm
index c529a2e171..dc79e7a12a 100644
--- a/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.mm
+++ b/tensorflow/contrib/ios_examples/camera/CameraExampleViewController.mm
@@ -291,7 +291,7 @@ didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer
in + (in_y * image_width * image_channels) + (in_x * image_channels);
float *out_pixel = out_row + (x * wanted_channels);
for (int c = 0; c < wanted_channels; ++c) {
- out_pixel[c] = (in_pixel[c] / input_std) - input_mean;
+ out_pixel[c] = (in_pixel[c] - input_mean) / input_std;
}
}
}
diff --git a/tensorflow/contrib/ios_examples/simple/RunModelViewController.mm b/tensorflow/contrib/ios_examples/simple/RunModelViewController.mm
index 19f00ad479..2e389b39d4 100644
--- a/tensorflow/contrib/ios_examples/simple/RunModelViewController.mm
+++ b/tensorflow/contrib/ios_examples/simple/RunModelViewController.mm
@@ -202,7 +202,7 @@ NSString* RunInferenceOnImage() {
tensorflow::uint8* in_pixel = in_row + (in_x * image_channels);
float* out_pixel = out_row + (x * wanted_channels);
for (int c = 0; c < wanted_channels; ++c) {
- out_pixel[c] = (in_pixel[c] / input_std) - input_mean;
+ out_pixel[c] = (in_pixel[c] - input_mean) / input_std;
}
}
}
diff --git a/tensorflow/contrib/learn/python/learn/ops/dnn_ops.py b/tensorflow/contrib/learn/python/learn/ops/dnn_ops.py
index c1fba21619..e2d55ef809 100644
--- a/tensorflow/contrib/learn/python/learn/ops/dnn_ops.py
+++ b/tensorflow/contrib/learn/python/learn/ops/dnn_ops.py
@@ -22,7 +22,9 @@ from __future__ import print_function
from tensorflow.contrib import layers
from tensorflow.contrib.framework.python.framework.deprecation import deprecated
from tensorflow.contrib.learn.python.learn.ops import dropout_ops
+
from tensorflow.python.framework import ops
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.ops import array_ops as array_ops_
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn
@@ -32,6 +34,7 @@ from tensorflow.python.ops import variable_scope as vs
@deprecated('2016-08-01', 'Please use tf.contrib.layers.stack instead.')
def dnn(tensor_in, hidden_units, activation=nn.relu, dropout=None):
"""Creates fully connected deep neural network subgraph.
+ This is deprecated. Please use contrib.layers.dnn instead.
Args:
tensor_in: tensor or placeholder for input features.
@@ -42,6 +45,8 @@ def dnn(tensor_in, hidden_units, activation=nn.relu, dropout=None):
Returns:
A tensor which would be a deep neural network.
"""
+ logging.warning("learn.ops.dnn is deprecated, \
+ please use contrib.layers.dnn.")
with vs.variable_scope('dnn'):
for i, n_units in enumerate(hidden_units):
with vs.variable_scope('layer%d' % i):
diff --git a/tensorflow/contrib/learn/python/learn/ops/dropout_ops.py b/tensorflow/contrib/learn/python/learn/ops/dropout_ops.py
index d49a7d0d58..a0153f1ac6 100644
--- a/tensorflow/contrib/learn/python/learn/ops/dropout_ops.py
+++ b/tensorflow/contrib/learn/python/learn/ops/dropout_ops.py
@@ -1,3 +1,4 @@
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,8 +22,10 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.platform import tf_logging as logging
+
+from tensorflow.contrib.layers import dropout as contrib_dropout
# Key to collect dropout probabilities.
DROPOUTS = "dropouts"
@@ -30,7 +33,8 @@ DROPOUTS = "dropouts"
def dropout(tensor_in, prob, name=None):
"""Adds dropout node and stores probability tensor into graph collection.
-
+ This is deprecated. Please use contrib.layers.dropout instead.
+
Args:
tensor_in: Input tensor.
prob: Float or Tensor.
@@ -42,10 +46,12 @@ def dropout(tensor_in, prob, name=None):
Raises:
ValueError: If `keep_prob` is not in `(0, 1]`.
"""
+ logging.warning("learn.ops.dropout is deprecated, \
+ please use contrib.layers.dropout.")
with ops.op_scope([tensor_in], name, "dropout") as name:
if isinstance(prob, float):
prob = vs.get_variable("prob", [],
initializer=init_ops.constant_initializer(prob),
trainable=False)
ops.add_to_collection(DROPOUTS, prob)
- return nn.dropout(tensor_in, prob)
+ return contrib_dropout(tensor_in, keep_prob=prob)
diff --git a/tensorflow/contrib/learn/python/learn/ops/tests/dropout_ops_test.py b/tensorflow/contrib/learn/python/learn/ops/tests/dropout_ops_test.py
index 4ce38b49eb..e6d3911c01 100644
--- a/tensorflow/contrib/learn/python/learn/ops/tests/dropout_ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/tests/dropout_ops_test.py
@@ -29,6 +29,7 @@ class DropoutTest(tf.test.TestCase):
def test_dropout_float(self):
with self.test_session() as session:
+ tf.add_to_collection("IS_TRAINING", True)
x = tf.placeholder(tf.float32, [5, 5])
ops.dropout(x, 0.5)
probs = tf.get_collection(ops.DROPOUTS)
@@ -38,6 +39,7 @@ class DropoutTest(tf.test.TestCase):
def test_dropout_tensor(self):
with self.test_session():
+ tf.add_to_collection("IS_TRAINING", True)
x = tf.placeholder(tf.float32, [5, 5])
y = tf.get_variable("prob", [], initializer=tf.constant_initializer(0.5))
ops.dropout(x, y)
diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_functions_test.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_functions_test.py
index 47d835e6eb..a79bd29b8b 100644
--- a/tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_functions_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_functions_test.py
@@ -20,10 +20,16 @@ from __future__ import division
from __future__ import print_function
import numpy as np
-import pandas as pd
import tensorflow as tf
import tensorflow.contrib.learn.python.learn.dataframe.queues.feeding_functions as ff
+# pylint: disable=g-import-not-at-top
+try:
+ import pandas as pd
+ HAS_PANDAS = True
+except ImportError:
+ HAS_PANDAS = False
+
def vals_to_list(a):
return {key: val.tolist() if isinstance(val, np.ndarray) else val
@@ -72,6 +78,8 @@ class _FeedingFunctionsTestCase(tf.test.TestCase):
self.assertEqual(expected, vals_to_list(actual))
def testPandasFeedFnBatchOne(self):
+ if not HAS_PANDAS:
+ return
array1 = np.arange(32, 64)
array2 = np.arange(64, 96)
df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(96, 128))
@@ -88,6 +96,8 @@ class _FeedingFunctionsTestCase(tf.test.TestCase):
self.assertEqual(expected, vals_to_list(actual))
def testPandasFeedFnBatchFive(self):
+ if not HAS_PANDAS:
+ return
array1 = np.arange(32, 64)
array2 = np.arange(64, 96)
df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(96, 128))
@@ -105,6 +115,8 @@ class _FeedingFunctionsTestCase(tf.test.TestCase):
self.assertEqual(expected, vals_to_list(actual))
def testPandasFeedFnBatchOneHundred(self):
+ if not HAS_PANDAS:
+ return
array1 = np.arange(32, 64)
array2 = np.arange(64, 96)
df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(96, 128))
diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py
index 8da96e5444..c200140d80 100644
--- a/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py
@@ -123,7 +123,6 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
"""Test construction from Pandas DataFrame."""
if not HAS_PANDAS:
return
-
pandas_df = pd.DataFrame({"sparrow": range(10), "ostrich": 1})
tensorflow_df = df.TensorFlowDataFrame.from_pandas(pandas_df,
batch_size=10,
@@ -176,7 +175,6 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
def testFromCSV(self):
if not HAS_PANDAS:
return
-
num_batches = 100
batch_size = 8
enqueue_size = 7
@@ -214,6 +212,8 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
self.assertEqual(expected_num_batches, actual_num_batches)
def testFromCSVWithFeatureSpec(self):
+ if not HAS_PANDAS:
+ return
num_batches = 100
batch_size = 8
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index 5f77ee40e2..c9b4641afa 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -66,7 +66,8 @@ HOST_LIBS := \
-lstdc++ \
-lprotobuf \
-lpthread \
--lm
+-lm \
+-lz
# If we're on Linux, also link in the dl library.
ifeq ($(HOST_OS),LINUX)
@@ -115,7 +116,7 @@ PROTOGENDIR := $(GENDIR)proto/
# Settings for the target compiler.
CXX := $(CC_PREFIX) gcc
OPTFLAGS := -O0
-CXXFLAGS := --std=c++11 $(OPTFLAGS)
+CXXFLAGS := --std=c++11 -DIS_SLIM_BUILD $(OPTFLAGS)
LDFLAGS := \
-L/usr/local/lib
@@ -367,6 +368,7 @@ TF_CC_SRCS := $(shell cat $(MAKEFILE_DIR)/tf_cc_files.txt)
PBT_CC_SRCS := $(shell cat $(MAKEFILE_DIR)/tf_pb_text_files.txt)
PROTO_SRCS := $(shell cat $(MAKEFILE_DIR)/tf_proto_files.txt)
BENCHMARK_SRCS := \
+tensorflow/core/util/reporter.cc \
tensorflow/tools/benchmark/benchmark_model.cc \
tensorflow/tools/benchmark/benchmark_model_main.cc
diff --git a/tensorflow/contrib/makefile/gen_file_lists.sh b/tensorflow/contrib/makefile/gen_file_lists.sh
index 71a0d8d618..2bbc6bfcae 100755
--- a/tensorflow/contrib/makefile/gen_file_lists.sh
+++ b/tensorflow/contrib/makefile/gen_file_lists.sh
@@ -21,21 +21,22 @@ grep "//tensorflow/.*\.cc$" | \
grep -v "gen_proto_text" | \
grep -E -v "jpeg" | \
grep -E -v "png" | \
+grep -E -v "zlib" | \
sed -E 's#^//##g' | \
sed -E 's#:#/#g' \
-> make/tf_cc_files.txt
+> tensorflow/contrib/makefile/tf_cc_files.txt
bazel query 'kind("source file", deps(//tensorflow/core:android_tensorflow_lib))' | \
grep "//tensorflow/.*\.proto$" | \
sed -E 's#^//##g' | \
sed -E 's#:#/#g' \
-> make/tf_proto_files.txt
+> tensorflow/contrib/makefile/tf_proto_files.txt
bazel query 'kind("generated file", deps(//tensorflow/core:proto_text))' | \
grep "pb_text\.cc$" | \
sed -E 's#^//##g' | \
sed -E 's#:#/#g' \
-> make/tf_pb_text_files.txt
+> tensorflow/contrib/makefile/tf_pb_text_files.txt
bazel query 'kind("source file", deps(//tensorflow/tools/proto_text:gen_proto_text_functions))' | \
grep -E "//tensorflow/.*\.cc$" | \
@@ -43,16 +44,16 @@ grep -E -v "jpeg" | \
grep -E -v "png" | \
sed -E 's#^//##g' | \
sed -E 's#:#/#g' \
-> make/proto_text_cc_files.txt
+> tensorflow/contrib/makefile/proto_text_cc_files.txt
bazel query 'kind("generated file", deps(//tensorflow/tools/proto_text:gen_proto_text_functions))' | \
grep -E "//tensorflow/.*\.cc$" | \
sed -E 's#^//##g' | \
sed -E 's#:#/#g' \
-> make/proto_text_pb_cc_files.txt
+> tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
bazel query 'kind("generated file", deps(//tensorflow/tools/proto_text:gen_proto_text_functions))' | \
grep -E "//tensorflow/.*\.h$" | \
sed -E 's#^//##g' | \
sed -E 's#:#/#g' \
-> make/proto_text_pb_h_files.txt
+> tensorflow/contrib/makefile/proto_text_pb_h_files.txt
diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt
index 5dc57a0484..1809a7a69b 100644
--- a/tensorflow/contrib/makefile/proto_text_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt
@@ -24,6 +24,8 @@ tensorflow/core/lib/random/weighted_picker.cc
tensorflow/core/lib/random/simple_philox.cc
tensorflow/core/lib/random/random.cc
tensorflow/core/lib/random/distribution_sampler.cc
+tensorflow/core/lib/io/zlib_outputbuffer.cc
+tensorflow/core/lib/io/zlib_inputbuffer.cc
tensorflow/core/lib/io/two_level_iterator.cc
tensorflow/core/lib/io/table_builder.cc
tensorflow/core/lib/io/table.cc
diff --git a/tensorflow/contrib/makefile/tf_cc_files.txt b/tensorflow/contrib/makefile/tf_cc_files.txt
index 220d409d30..5402642f5b 100644
--- a/tensorflow/contrib/makefile/tf_cc_files.txt
+++ b/tensorflow/contrib/makefile/tf_cc_files.txt
@@ -7,6 +7,7 @@ tensorflow/core/kernels/transpose_functor_cpu.cc
tensorflow/core/kernels/training_ops.cc
tensorflow/core/kernels/topk_op.cc
tensorflow/core/kernels/tile_ops.cc
+tensorflow/core/kernels/strided_slice_op.cc
tensorflow/core/kernels/stack_ops.cc
tensorflow/core/kernels/split_op.cc
tensorflow/core/kernels/split_lib_cpu.cc
@@ -25,6 +26,7 @@ tensorflow/core/kernels/reverse_sequence_op.cc
tensorflow/core/kernels/reverse_op.cc
tensorflow/core/kernels/restore_op.cc
tensorflow/core/kernels/resize_nearest_neighbor_op.cc
+tensorflow/core/kernels/resize_bilinear_op.cc
tensorflow/core/kernels/reshape_op.cc
tensorflow/core/kernels/relu_op.cc
tensorflow/core/kernels/reduction_ops_sum.cc
@@ -52,6 +54,7 @@ tensorflow/core/kernels/dense_update_ops.cc
tensorflow/core/kernels/cwise_ops_common.cc
tensorflow/core/kernels/cwise_op_tanh.cc
tensorflow/core/kernels/cwise_op_sub.cc
+tensorflow/core/kernels/cwise_op_squared_difference.cc
tensorflow/core/kernels/cwise_op_square.cc
tensorflow/core/kernels/cwise_op_sqrt.cc
tensorflow/core/kernels/cwise_op_sigmoid.cc
@@ -64,6 +67,7 @@ tensorflow/core/kernels/cwise_op_maximum.cc
tensorflow/core/kernels/cwise_op_log.cc
tensorflow/core/kernels/cwise_op_less.cc
tensorflow/core/kernels/cwise_op_isfinite.cc
+tensorflow/core/kernels/cwise_op_inverse.cc
tensorflow/core/kernels/cwise_op_greater.cc
tensorflow/core/kernels/cwise_op_exp.cc
tensorflow/core/kernels/cwise_op_equal_to.cc
@@ -71,6 +75,7 @@ tensorflow/core/kernels/cwise_op_div.cc
tensorflow/core/kernels/cwise_op_add.cc
tensorflow/core/kernels/ctc_decoder_ops.cc
tensorflow/core/kernels/conv_ops.cc
+tensorflow/core/kernels/conv_grad_ops.cc
tensorflow/core/kernels/control_flow_ops.cc
tensorflow/core/kernels/constant_op.cc
tensorflow/core/kernels/concat_op.cc
@@ -94,7 +99,6 @@ tensorflow/core/util/tensor_format.cc
tensorflow/core/util/stat_summarizer.cc
tensorflow/core/util/sparse/group_iterator.cc
tensorflow/core/util/saved_tensor_slice_util.cc
-tensorflow/core/util/reporter.cc
tensorflow/core/util/port.cc
tensorflow/core/util/padding.cc
tensorflow/core/util/mirror_pad_mode.cc
@@ -179,6 +183,7 @@ tensorflow/core/lib/core/arena.cc
tensorflow/core/graph/validate.cc
tensorflow/core/graph/tensor_id.cc
tensorflow/core/graph/subgraph.cc
+tensorflow/core/graph/quantize_training.cc
tensorflow/core/graph/optimizer_cse.cc
tensorflow/core/graph/node_builder.cc
tensorflow/core/graph/graph_partition.cc
@@ -200,6 +205,7 @@ tensorflow/core/framework/tensor_slice.cc
tensorflow/core/framework/tensor_shape.cc
tensorflow/core/framework/tensor_reference.cc
tensorflow/core/framework/tensor.cc
+tensorflow/core/framework/shape_inference.cc
tensorflow/core/framework/resource_mgr.cc
tensorflow/core/framework/rendezvous.cc
tensorflow/core/framework/reader_op_kernel.cc
diff --git a/tensorflow/contrib/pi_examples/README.md b/tensorflow/contrib/pi_examples/README.md
new file mode 100644
index 0000000000..8dde63e4c6
--- /dev/null
+++ b/tensorflow/contrib/pi_examples/README.md
@@ -0,0 +1,73 @@
+# TensorFlow Raspberry Pi Examples
+
+This folder contains examples of how to build applications for the Raspberry Pi using TensorFlow.
+
+## Building the Examples
+
+ - Follow the Raspberry Pi section of the instructions at [tensorflow/contrib/makefile](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/makefile) to compile a static library containing the core TensorFlow code.
+
+ - Install libjpeg, so we can load image files:
+
+```
+sudo apt-get install -y libjpeg-dev
+```
+
+ - To download the example model you'll need, run these commands:
+
+```bash
+curl https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015_stripped.zip \
+-o /tmp/inception_dec_2015_stripped.zip
+unzip /tmp/inception_dec_2015_stripped.zip \
+-d tensorflow/contrib/pi_examples/label_image/data/
+```
+
+ - From the root of the TensorFlow source tree, run `make -f tensorflow/contrib/pi_examples/label_image/Makefile` to build a basic example.
+
+## Usage
+
+Run `tensorflow/contrib/pi_examples/label_image/gen/bin/label_image` to try out image labeling with the default Grace Hopper image. You should several lines of output, with "Military Uniform" shown as the top result, something like this:
+
+```bash
+I tensorflow/contrib/pi_examples/label_image/label_image.cc:384] Running model succeeded!
+I tensorflow/contrib/pi_examples/label_image/label_image.cc:284] military uniform (866): 0.624293
+I tensorflow/contrib/pi_examples/label_image/label_image.cc:284] suit (794): 0.0473981
+I tensorflow/contrib/pi_examples/label_image/label_image.cc:284] academic gown (896): 0.0280926
+I tensorflow/contrib/pi_examples/label_image/label_image.cc:284] bolo tie (940): 0.0156956
+I tensorflow/contrib/pi_examples/label_image/label_image.cc:284] bearskin (849): 0.0143348
+```
+
+Once you've verified that is working, you can supply your own images with `--image=your_image.jpg`, or even with graphs you've trained yourself with the TensorFlow for Poets tutorial using `--graph=your_graph.pb --input=Mul:0 --output=final_result:0`.
+
+## Camera Example
+
+Once you have the simple example running, you can try out a more complex version that
+reads frames from a camera attached to the Pi. You'll need to install and set up your
+camera module first. The example uses Video4Linux, so you'll need to install that first.
+Here's some commands I found necessary to get that set up, and I found more information
+at this blog post: http://www.richardmudhar.com/blog/2015/02/raspberry-pi-camera-and-motion-out-of-the-box-sparrowcam/
+
+```
+sudo bash -c "echo 'bcm2835-v4l2' >> /etc/modules"
+sudo apt-get install libv4l-dev
+```
+
+Once that's working, run the following commands to build and run the camera example:
+
+```bash
+make -f tensorflow/contrib/pi_examples/camera/Makefile
+tensorflow/contrib/pi_examples/camera/gen/bin/camera
+```
+
+You should see it looping over camera frames as they come in, and printing the top labels
+to the command line. This is a great starting point for all sorts of fun image recognition
+applications, especially when you combine it with a custom model you've built using
+something like the TensorFlow for Poets tutorial.
+
+The example is designed to work with the Flite speech synthesis tool, so that your Pi
+can speak any labels that have a high enough score. To enable this, just install the
+Flite package and then pipe the output of the binary you've built, like this:
+
+```
+sudo apt-get install flite
+tensorflow/contrib/pi_examples/camera/gen/bin/camera | xargs -n1 flite -t
+```
diff --git a/tensorflow/contrib/pi_examples/camera/Makefile b/tensorflow/contrib/pi_examples/camera/Makefile
new file mode 100644
index 0000000000..2d14606400
--- /dev/null
+++ b/tensorflow/contrib/pi_examples/camera/Makefile
@@ -0,0 +1,84 @@
+# This Makefile compiles the label_image example for the Raspberry Pi.
+# See tensorflow/contrib/pi_examples/README.md for full build instructions.
+
+# Find where we're running from, so we can store generated files here.
+SCRIPT_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
+
+# The location of the tensorflow/contrib/makefile directory.
+TFMAKEFILE_DIR := $(SCRIPT_DIR)/../../makefile
+
+# Where compiled objects are stored.
+GENDIR := $(SCRIPT_DIR)/gen/
+OBJDIR := $(GENDIR)obj/
+LIBDIR := $(GENDIR)lib/
+BINDIR := $(GENDIR)bin/
+
+# The expected locations of the TensorFlow library.
+TFLIBDIR := $(TFMAKEFILE_DIR)/gen/lib
+TFLIBS := $(TFLIBDIR)/libtensorflow-core.a
+
+# Where the downloads have been stored.
+DOWNLOADSDIR := $(TFMAKEFILE_DIR)/downloads
+
+# The location of the compiled protobuf headers generated by TensorFlow.
+PBTGENDIR := $(TFMAKEFILE_DIR)/gen/proto_text/
+PROTOGENDIR := $(TFMAKEFILE_DIR)/gen/proto/
+
+# The name of the output program we're compiling.
+EXECUTABLE_NAME := $(BINDIR)/camera
+
+# Settings for the target compiler.
+CXX := gcc
+OPTFLAGS := -O0
+CXXFLAGS := --std=c++11 $(OPTFLAGS)
+LDFLAGS := \
+-L/usr/local/lib \
+-L$(TFLIBDIR) \
+-Wl,--no-whole-archive
+INCLUDES := \
+-I/usr/local/include \
+-I. \
+-I$(DOWNLOADSDIR) \
+-I$(DOWNLOADSDIR)/eigen-latest/ \
+-I$(PROTOGENDIR) \
+-I$(PBTGENDIR)
+LIBS := \
+-lstdc++ \
+-lprotobuf \
+-lv4l2 \
+-Wl,--allow-multiple-definition \
+-Wl,--whole-archive \
+-ltensorflow-core \
+-Wl,--no-whole-archive \
+-ldl \
+-lpthread \
+-lm \
+-ljpeg
+LIBFLAGS :=
+
+EXECUTABLE_SRCS := tensorflow/contrib/pi_examples/camera/camera.cc
+
+# File names of the intermediate files target compilation generates.
+EXECUTABLE_OBJS := $(addprefix $(OBJDIR), $(EXECUTABLE_SRCS:.cc=.o))
+
+.PHONY: clean
+
+# The target that's compiled if there's no command-line arguments.
+all: $(EXECUTABLE_NAME)
+
+# Rules for target compilation.
+
+$(EXECUTABLE_NAME): $(EXECUTABLE_OBJS) $(TFLIBS)
+ @mkdir -p $(dir $@)
+ $(CXX) $(CXXFLAGS) $(INCLUDES) \
+ -o $(EXECUTABLE_NAME) $(EXECUTABLE_OBJS) \
+ $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
+
+# Matches on C++ source files.
+$(OBJDIR)%.o: %.cc
+ @mkdir -p $(dir $@)
+ $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@
+
+# Gets rid of all generated files.
+clean:
+ rm -rf $(GENDIR)
diff --git a/tensorflow/contrib/pi_examples/camera/camera.cc b/tensorflow/contrib/pi_examples/camera/camera.cc
new file mode 100644
index 0000000000..9bba110a52
--- /dev/null
+++ b/tensorflow/contrib/pi_examples/camera/camera.cc
@@ -0,0 +1,533 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Full build instructions are at tensorflow/contrib/pi_examples/README.md.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <fstream>
+#include <libv4l2.h>
+#include <linux/videodev2.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/ioctl.h>
+#include <sys/types.h>
+#include <sys/time.h>
+#include <sys/mman.h>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/default_device.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+// These are all common classes it's handy to reference with no namespace.
+using tensorflow::Flag;
+using tensorflow::Tensor;
+using tensorflow::Status;
+using tensorflow::string;
+using tensorflow::int32;
+
+// Used to store the memory-mapped buffers we use for capture.
+struct CameraBuffer {
+ void* start;
+ size_t length;
+};
+
+// Wrapper around camera command sending.
+Status SendCameraCommand(int fh, int request, void* arg) {
+ int r;
+ do {
+ r = v4l2_ioctl(fh, request, arg);
+ } while (r == -1 && ((errno == EINTR) || (errno == EAGAIN)));
+ if (r == -1) {
+ LOG(ERROR) << "SendCameraCommand error " << errno << " (" << strerror(errno)
+ << ")";
+ return tensorflow::errors::Unknown("SendCameraCommand error ", errno,
+ strerror(errno));
+ }
+ return Status::OK();
+}
+
+Status OpenCamera(int* camera_handle) {
+ const char* dev_name = "/dev/video0";
+ int fd = v4l2_open(dev_name, O_RDWR | O_NONBLOCK, 0);
+ if (fd < 0) {
+ LOG(ERROR) << "Cannot open camera device";
+ return tensorflow::errors::NotFound("V4L2 camera device not found");
+ }
+ *camera_handle = fd;
+ return Status::OK();
+}
+
+Status CloseCamera(int camera_handle) {
+ v4l2_close(camera_handle);
+ return Status::OK();
+}
+
+Status SetCameraFormat(int camera_handle, int wanted_width, int wanted_height) {
+ struct v4l2_format fmt;
+ memset(&fmt, 0, sizeof(fmt));
+ fmt.type = V4L2_BUF_TYPE_VIDEO_CAPTURE;
+ fmt.fmt.pix.width = wanted_width;
+ fmt.fmt.pix.height = wanted_height;
+ fmt.fmt.pix.pixelformat = V4L2_PIX_FMT_RGB24;
+ fmt.fmt.pix.field = V4L2_FIELD_INTERLACED;
+ Status set_format_status =
+ SendCameraCommand(camera_handle, VIDIOC_S_FMT, &fmt);
+ if (!set_format_status.ok()) {
+ LOG(ERROR) << "Setting format failed with " << set_format_status;
+ return set_format_status;
+ }
+ if (fmt.fmt.pix.pixelformat != V4L2_PIX_FMT_RGB24) {
+ LOG(ERROR) << "Libv4l didn't accept RGB24 format. Can't proceed.";
+ return tensorflow::errors::Unknown("Libv4l didn't accept RGB24 format");
+ }
+ if ((fmt.fmt.pix.width != wanted_width) ||
+ (fmt.fmt.pix.height != wanted_height)) {
+ LOG(WARNING) << "Warning: driver is sending image at " << fmt.fmt.pix.width
+ << "x" << fmt.fmt.pix.height;
+ }
+ return Status::OK();
+}
+
+Status StartCameraCapture(int camera_handle, int buffer_count,
+ CameraBuffer** buffers) {
+ struct v4l2_requestbuffers req;
+ memset(&req, 0, sizeof(req));
+ req.count = buffer_count;
+ req.type = V4L2_BUF_TYPE_VIDEO_CAPTURE;
+ req.memory = V4L2_MEMORY_MMAP;
+ Status request_buffers_status =
+ SendCameraCommand(camera_handle, VIDIOC_REQBUFS, &req);
+ if (!request_buffers_status.ok()) {
+ LOG(ERROR) << "Request buffers failed with " << request_buffers_status;
+ return request_buffers_status;
+ }
+
+ *buffers = (CameraBuffer*)(calloc(buffer_count, sizeof(*buffers)));
+ for (int n_buffers = 0; n_buffers < buffer_count; ++n_buffers) {
+ struct v4l2_buffer buf;
+ memset(&buf, 0, sizeof(buf));
+ buf.type = V4L2_BUF_TYPE_VIDEO_CAPTURE;
+ buf.memory = V4L2_MEMORY_MMAP;
+ buf.index = n_buffers;
+ Status query_buffer_status =
+ SendCameraCommand(camera_handle, VIDIOC_QUERYBUF, &buf);
+ if (!query_buffer_status.ok()) {
+ LOG(ERROR) << "Query buffer failed with " << query_buffer_status;
+ return query_buffer_status;
+ }
+ (*buffers)[n_buffers].length = buf.length;
+ (*buffers)[n_buffers].start =
+ v4l2_mmap(NULL, buf.length, PROT_READ | PROT_WRITE, MAP_SHARED,
+ camera_handle, buf.m.offset);
+
+ if (MAP_FAILED == (*buffers)[n_buffers].start) {
+ LOG(ERROR) << "Memory-mapping buffer failed";
+ return tensorflow::errors::Unknown("Memory-mapping buffer failed");
+ }
+ }
+
+ for (int i = 0; i < buffer_count; ++i) {
+ struct v4l2_buffer buf;
+ memset(&buf, 0, sizeof(buf));
+ buf.type = V4L2_BUF_TYPE_VIDEO_CAPTURE;
+ buf.memory = V4L2_MEMORY_MMAP;
+ buf.index = i;
+ Status set_buffer_status =
+ SendCameraCommand(camera_handle, VIDIOC_QBUF, &buf);
+ if (!set_buffer_status.ok()) {
+ LOG(ERROR) << "Set buffer failed with " << set_buffer_status;
+ return set_buffer_status;
+ }
+ }
+
+ enum v4l2_buf_type type = V4L2_BUF_TYPE_VIDEO_CAPTURE;
+ Status stream_on_status =
+ SendCameraCommand(camera_handle, VIDIOC_STREAMON, &type);
+ if (!stream_on_status.ok()) {
+ LOG(ERROR) << "Turning stream on failed with " << stream_on_status;
+ return stream_on_status;
+ }
+ return Status::OK();
+}
+
+Status EndCameraCapture(int camera_handle, CameraBuffer* buffers,
+ int buffer_count) {
+ enum v4l2_buf_type type = V4L2_BUF_TYPE_VIDEO_CAPTURE;
+ Status stream_off_status =
+ SendCameraCommand(camera_handle, VIDIOC_STREAMOFF, &type);
+ if (!stream_off_status.ok()) {
+ LOG(ERROR) << "Turning stream off failed with " << stream_off_status;
+ return stream_off_status;
+ }
+ for (int i = 0; i < buffer_count; ++i)
+ v4l2_munmap(buffers[i].start, buffers[i].length);
+ return Status::OK();
+}
+
+Status CaptureNextFrame(int camera_handle, CameraBuffer* buffers,
+ uint8_t** frame_data, int* frame_data_size,
+ v4l2_buffer* buf) {
+ int r;
+ do {
+ fd_set fds;
+ FD_ZERO(&fds);
+ FD_SET(camera_handle, &fds);
+ struct timeval tv;
+ tv.tv_sec = 2;
+ tv.tv_usec = 0;
+ r = select(camera_handle + 1, &fds, NULL, NULL, &tv);
+ } while ((r == -1 && (errno = EINTR)));
+ if (r == -1) {
+ LOG(ERROR) << "select() failed while waiting for the camera with " << errno;
+ return tensorflow::errors::Unknown(
+ "CaptureCameraFrame: select() failed with", errno);
+ }
+
+ memset(buf, 0, sizeof(*buf));
+ buf->type = V4L2_BUF_TYPE_VIDEO_CAPTURE;
+ buf->memory = V4L2_MEMORY_MMAP;
+ Status get_buffer_status =
+ SendCameraCommand(camera_handle, VIDIOC_DQBUF, buf);
+ if (!get_buffer_status.ok()) {
+ LOG(ERROR) << "Get buffer failed with " << get_buffer_status;
+ return get_buffer_status;
+ }
+
+ *frame_data = static_cast<uint8_t*>(buffers[buf->index].start);
+ *frame_data_size = buf->bytesused;
+
+ return Status::OK();
+}
+
+Status ReleaseFrame(int camera_handle, v4l2_buffer* buf) {
+ Status release_buffer_status =
+ SendCameraCommand(camera_handle, VIDIOC_QBUF, buf);
+ if (!release_buffer_status.ok()) {
+ LOG(ERROR) << "Release buffer failed with " << release_buffer_status;
+ return release_buffer_status;
+ }
+}
+
+// Reads a model graph definition from disk, and creates a session object you
+// can use to run it.
+Status LoadGraph(string graph_file_name,
+ std::unique_ptr<tensorflow::Session>* session) {
+ tensorflow::GraphDef graph_def;
+ Status load_graph_status =
+ ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
+ if (!load_graph_status.ok()) {
+ return tensorflow::errors::NotFound("Failed to load compute graph at '",
+ graph_file_name, "'");
+ }
+ session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
+ Status session_create_status = (*session)->Create(graph_def);
+ if (!session_create_status.ok()) {
+ return session_create_status;
+ }
+ return Status::OK();
+}
+
+// Analyzes the output of the Inception graph to retrieve the highest scores and
+// their positions in the tensor, which correspond to categories.
+Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
+ Tensor* out_indices, Tensor* out_scores) {
+ const Tensor& unsorted_scores_tensor = outputs[0];
+ auto unsorted_scores_flat = unsorted_scores_tensor.flat<float>();
+ std::vector<std::pair<int, float>> scores;
+ for (int i = 0; i < unsorted_scores_flat.size(); ++i) {
+ scores.push_back(std::pair<int, float>({i, unsorted_scores_flat(i)}));
+ }
+ std::sort(scores.begin(), scores.end(),
+ [](const std::pair<int, float>& left,
+ const std::pair<int, float>& right) {
+ return left.second > right.second;
+ });
+ scores.resize(how_many_labels);
+ Tensor sorted_indices(tensorflow::DT_INT32, {scores.size()});
+ Tensor sorted_scores(tensorflow::DT_FLOAT, {scores.size()});
+ for (int i = 0; i < scores.size(); ++i) {
+ sorted_indices.flat<int>()(i) = scores[i].first;
+ sorted_scores.flat<float>()(i) = scores[i].second;
+ }
+ *out_indices = sorted_indices;
+ *out_scores = sorted_scores;
+ return Status::OK();
+}
+
+// Takes a file name, and loads a list of labels from it, one per line, and
+// returns a vector of the strings. It pads with empty strings so the length
+// of the result is a multiple of 16, because our model expects that.
+Status ReadLabelsFile(string file_name, std::vector<string>* result,
+ size_t* found_label_count) {
+ std::ifstream file(file_name);
+ if (!file) {
+ return tensorflow::errors::NotFound("Labels file ", file_name,
+ " not found.");
+ }
+ result->clear();
+ string line;
+ while (std::getline(file, line)) {
+ result->push_back(line);
+ }
+ *found_label_count = result->size();
+ const int padding = 16;
+ while (result->size() % padding) {
+ result->emplace_back();
+ }
+ return Status::OK();
+}
+
+// Given the output of a model run, and the name of a file containing the labels
+// this prints out the top five highest-scoring values.
+Status PrintTopLabels(const std::vector<Tensor>& outputs,
+ const std::vector<string>& labels, int label_count,
+ float print_threshold) {
+ const int how_many_labels = std::min(5, static_cast<int>(label_count));
+ Tensor indices;
+ Tensor scores;
+ TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
+ tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
+ tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
+ for (int pos = 0; pos < how_many_labels; ++pos) {
+ const int label_index = indices_flat(pos);
+ const float score = scores_flat(pos);
+ LOG(INFO) << labels[label_index] << " (" << label_index << "): " << score;
+ // Print the top label to stdout if it's above a threshold.
+ if ((pos == 0) && (score > print_threshold)) {
+ std::cout << labels[label_index] << std::endl;
+ }
+ }
+ return Status::OK();
+}
+
+// Given an image buffer, resize it to the requested size, and then scale the
+// values as desired.
+Status TensorFromFrame(uint8_t* image_data, int image_width, int image_height,
+ int image_channels, const int wanted_height,
+ const int wanted_width, const float input_mean,
+ const float input_std,
+ std::vector<Tensor>* out_tensors) {
+ const int wanted_channels = 3;
+ if (image_channels < wanted_channels) {
+ return tensorflow::errors::FailedPrecondition(
+ "Image needs to have at least ", wanted_channels, " but only has ",
+ image_channels);
+ }
+ // In these loops, we convert the eight-bit data in the image into float,
+ // resize it using bilinear filtering, and scale it numerically to the float
+ // range that the model expects (given by input_mean and input_std).
+ tensorflow::Tensor image_tensor(
+ tensorflow::DT_FLOAT,
+ tensorflow::TensorShape(
+ {1, wanted_height, wanted_width, wanted_channels}));
+ auto image_tensor_mapped = image_tensor.tensor<float, 4>();
+ tensorflow::uint8* in = image_data;
+ float* out = image_tensor_mapped.data();
+ const size_t image_rowlen = image_width * image_channels;
+ const float width_scale = static_cast<float>(image_width) / wanted_width;
+ const float height_scale = static_cast<float>(image_height) / wanted_height;
+ for (int y = 0; y < wanted_height; ++y) {
+ const float in_y = y * height_scale;
+ const int top_y_index = static_cast<int>(floorf(in_y));
+ const int bottom_y_index =
+ std::min(static_cast<int>(ceilf(in_y)), (image_height - 1));
+ const float y_lerp = in_y - top_y_index;
+ tensorflow::uint8* in_top_row = in + (top_y_index * image_rowlen);
+ tensorflow::uint8* in_bottom_row = in + (bottom_y_index * image_rowlen);
+ float* out_row = out + (y * wanted_width * wanted_channels);
+ for (int x = 0; x < wanted_width; ++x) {
+ const float in_x = x * width_scale;
+ const int left_x_index = static_cast<int>(floorf(in_x));
+ const int right_x_index =
+ std::min(static_cast<int>(ceilf(in_x)), (image_width - 1));
+ tensorflow::uint8* in_top_left_pixel =
+ in_top_row + (left_x_index * wanted_channels);
+ tensorflow::uint8* in_top_right_pixel =
+ in_top_row + (right_x_index * wanted_channels);
+ tensorflow::uint8* in_bottom_left_pixel =
+ in_bottom_row + (left_x_index * wanted_channels);
+ tensorflow::uint8* in_bottom_right_pixel =
+ in_bottom_row + (right_x_index * wanted_channels);
+ const float x_lerp = in_x - left_x_index;
+ float* out_pixel = out_row + (x * wanted_channels);
+ for (int c = 0; c < wanted_channels; ++c) {
+ const float top_left((in_top_left_pixel[c] - input_mean) / input_std);
+ const float top_right((in_top_right_pixel[c] - input_mean) / input_std);
+ const float bottom_left((in_bottom_left_pixel[c] - input_mean) /
+ input_std);
+ const float bottom_right((in_bottom_right_pixel[c] - input_mean) /
+ input_std);
+ const float top = top_left + (top_right - top_left) * x_lerp;
+ const float bottom =
+ bottom_left + (bottom_right - bottom_left) * x_lerp;
+ out_pixel[c] = top + (bottom - top) * y_lerp;
+ }
+ }
+ }
+
+ out_tensors->push_back(image_tensor);
+ return Status::OK();
+}
+
+int main(int argc, char** argv) {
+ string graph =
+ "tensorflow/contrib/pi_examples/label_image/data/"
+ "tensorflow_inception_stripped.pb";
+ string labels_file_name =
+ "tensorflow/contrib/pi_examples/label_image/data/"
+ "imagenet_comp_graph_label_strings.txt";
+ int32 input_width = 299;
+ int32 input_height = 299;
+ int32 input_mean = 128;
+ int32 input_std = 128;
+ string input_layer = "Mul";
+ string output_layer = "softmax";
+ int32 video_width = 640;
+ int32 video_height = 480;
+ int print_threshold = 50;
+ string root_dir = "";
+ const bool parse_result = tensorflow::ParseFlags(
+ &argc, argv, {Flag("graph", &graph), //
+ Flag("labels", &labels_file_name), //
+ Flag("input_width", &input_width), //
+ Flag("input_height", &input_height), //
+ Flag("input_mean", &input_mean), //
+ Flag("input_std", &input_std), //
+ Flag("input_layer", &input_layer), //
+ Flag("output_layer", &output_layer), //
+ Flag("video_width", &video_width), //
+ Flag("video_height", &video_height), //
+ Flag("print_threshold", &print_threshold), //
+ Flag("root_dir", &root_dir)});
+ if (!parse_result) {
+ LOG(ERROR) << "Error parsing command-line flags.";
+ return -1;
+ }
+
+ // First we load and initialize the model.
+ std::unique_ptr<tensorflow::Session> session;
+ string graph_path = tensorflow::io::JoinPath(root_dir, graph);
+ Status load_graph_status = LoadGraph(graph_path, &session);
+ if (!load_graph_status.ok()) {
+ LOG(ERROR) << load_graph_status;
+ return -1;
+ }
+
+ std::vector<string> labels;
+ size_t label_count;
+ Status read_labels_status =
+ ReadLabelsFile(labels_file_name, &labels, &label_count);
+ if (!read_labels_status.ok()) {
+ LOG(ERROR) << read_labels_status;
+ return -1;
+ }
+
+ int camera_handle;
+ Status open_status = OpenCamera(&camera_handle);
+ if (!open_status.ok()) {
+ LOG(ERROR) << "OpenCamera failed with " << open_status;
+ return -1;
+ }
+
+ Status format_status =
+ SetCameraFormat(camera_handle, video_width, video_height);
+ if (!format_status.ok()) {
+ LOG(ERROR) << "SetCameraFormat failed with " << format_status;
+ return -1;
+ }
+
+ const int how_many_buffers = 2;
+ CameraBuffer* buffers;
+ Status start_capture_status =
+ StartCameraCapture(camera_handle, how_many_buffers, &buffers);
+ if (!start_capture_status.ok()) {
+ LOG(ERROR) << "StartCameraCapture failed with " << start_capture_status;
+ return -1;
+ }
+
+ for (int i = 0; i < 200; i++) {
+ uint8_t* frame_data;
+ int frame_data_size;
+ v4l2_buffer buf;
+ Status capture_next_status = CaptureNextFrame(
+ camera_handle, buffers, &frame_data, &frame_data_size, &buf);
+ if (!capture_next_status.ok()) {
+ LOG(ERROR) << "CaptureNextFrame failed with " << capture_next_status;
+ return -1;
+ }
+
+ std::vector<Tensor> resized_tensors;
+ Status tensor_from_frame_status =
+ TensorFromFrame(frame_data, video_width, video_height, 3, input_height,
+ input_width, input_mean, input_std, &resized_tensors);
+ if (!tensor_from_frame_status.ok()) {
+ LOG(ERROR) << tensor_from_frame_status;
+ return -1;
+ }
+ const Tensor& resized_tensor = resized_tensors[0];
+
+ Status release_frame_status = ReleaseFrame(camera_handle, &buf);
+ if (!release_frame_status.ok()) {
+ LOG(ERROR) << "ReleaseFrame failed with " << release_frame_status;
+ return -1;
+ }
+
+ // Actually run the image through the model.
+ std::vector<Tensor> outputs;
+ Status run_status = session->Run({{input_layer, resized_tensor}},
+ {output_layer}, {}, &outputs);
+ if (!run_status.ok()) {
+ LOG(ERROR) << "Running model failed: " << run_status;
+ return -1;
+ }
+
+ // Do something interesting with the results we've generated.
+ Status print_status =
+ PrintTopLabels(outputs, labels, label_count, print_threshold * 0.01f);
+ if (!print_status.ok()) {
+ LOG(ERROR) << "Running print failed: " << print_status;
+ return -1;
+ }
+ }
+
+ Status end_capture_status =
+ EndCameraCapture(camera_handle, buffers, how_many_buffers);
+ if (!end_capture_status.ok()) {
+ LOG(ERROR) << "EndCameraCapture failed with " << end_capture_status;
+ return -1;
+ }
+
+ Status close_status = CloseCamera(camera_handle);
+ if (!close_status.ok()) {
+ LOG(ERROR) << "CloseCamera failed with " << open_status;
+ return -1;
+ }
+
+ return 0;
+}
diff --git a/tensorflow/contrib/pi_examples/label_image/Makefile b/tensorflow/contrib/pi_examples/label_image/Makefile
new file mode 100644
index 0000000000..1f310ec93b
--- /dev/null
+++ b/tensorflow/contrib/pi_examples/label_image/Makefile
@@ -0,0 +1,83 @@
+# This Makefile compiles the label_image example for the Raspberry Pi.
+# See tensorflow/contrib/pi_examples/README.md for full build instructions.
+
+# Find where we're running from, so we can store generated files here.
+SCRIPT_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
+
+# The location of the tensorflow/contrib/makefile directory.
+TFMAKEFILE_DIR := $(SCRIPT_DIR)/../../makefile
+
+# Where compiled objects are stored.
+GENDIR := $(SCRIPT_DIR)/gen/
+OBJDIR := $(GENDIR)obj/
+LIBDIR := $(GENDIR)lib/
+BINDIR := $(GENDIR)bin/
+
+# The expected locations of the TensorFlow library.
+TFLIBDIR := $(TFMAKEFILE_DIR)/gen/lib
+TFLIBS := $(TFLIBDIR)/libtensorflow-core.a
+
+# Where the downloads have been stored.
+DOWNLOADSDIR := $(TFMAKEFILE_DIR)/downloads
+
+# The location of the compiled protobuf headers generated by TensorFlow.
+PBTGENDIR := $(TFMAKEFILE_DIR)/gen/proto_text/
+PROTOGENDIR := $(TFMAKEFILE_DIR)/gen/proto/
+
+# The name of the output program we're compiling.
+EXECUTABLE_NAME := $(BINDIR)/label_image
+
+# Settings for the target compiler.
+CXX := gcc
+OPTFLAGS := -O0
+CXXFLAGS := --std=c++11 $(OPTFLAGS)
+LDFLAGS := \
+-L/usr/local/lib \
+-L$(TFLIBDIR) \
+-Wl,--no-whole-archive
+INCLUDES := \
+-I/usr/local/include \
+-I. \
+-I$(DOWNLOADSDIR) \
+-I$(DOWNLOADSDIR)/eigen-latest/ \
+-I$(PROTOGENDIR) \
+-I$(PBTGENDIR)
+LIBS := \
+-lstdc++ \
+-lprotobuf \
+-Wl,--allow-multiple-definition \
+-Wl,--whole-archive \
+-ltensorflow-core \
+-Wl,--no-whole-archive \
+-ldl \
+-lpthread \
+-lm \
+-ljpeg
+LIBFLAGS :=
+
+EXECUTABLE_SRCS := tensorflow/contrib/pi_examples/label_image/label_image.cc
+
+# File names of the intermediate files target compilation generates.
+EXECUTABLE_OBJS := $(addprefix $(OBJDIR), $(EXECUTABLE_SRCS:.cc=.o))
+
+.PHONY: clean
+
+# The target that's compiled if there's no command-line arguments.
+all: $(EXECUTABLE_NAME)
+
+# Rules for target compilation.
+
+$(EXECUTABLE_NAME): $(EXECUTABLE_OBJS) $(TFLIBS)
+ @mkdir -p $(dir $@)
+ $(CXX) $(CXXFLAGS) $(INCLUDES) \
+ -o $(EXECUTABLE_NAME) $(EXECUTABLE_OBJS) \
+ $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
+
+# Matches on C++ source files.
+$(OBJDIR)%.o: %.cc
+ @mkdir -p $(dir $@)
+ $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@
+
+# Gets rid of all generated files.
+clean:
+ rm -rf $(GENDIR)
diff --git a/tensorflow/contrib/pi_examples/label_image/data/grace_hopper.jpg b/tensorflow/contrib/pi_examples/label_image/data/grace_hopper.jpg
new file mode 100644
index 0000000000..478720d669
--- /dev/null
+++ b/tensorflow/contrib/pi_examples/label_image/data/grace_hopper.jpg
Binary files differ
diff --git a/tensorflow/contrib/pi_examples/label_image/label_image.cc b/tensorflow/contrib/pi_examples/label_image/label_image.cc
new file mode 100644
index 0000000000..70f32f2199
--- /dev/null
+++ b/tensorflow/contrib/pi_examples/label_image/label_image.cc
@@ -0,0 +1,397 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// A minimal but useful C++ example showing how to load an Imagenet-style object
+// recognition TensorFlow model, prepare input images for it, run them through
+// the graph, and interpret the results.
+//
+// It has been stripped down from the tensorflow/examples/label_image sample
+// code to remove features and ops not included in the mobile/embedded core
+// library available on the Raspberry Pi.
+//
+// Full build instructions are at tensorflow/contrib/pi_examples/README.md.
+
+#include <fstream>
+#include <jpeglib.h>
+#include <setjmp.h>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/default_device.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+// These are all common classes it's handy to reference with no namespace.
+using tensorflow::Flag;
+using tensorflow::Tensor;
+using tensorflow::Status;
+using tensorflow::string;
+using tensorflow::int32;
+
+// Takes a file name, and loads a list of labels from it, one per line, and
+// returns a vector of the strings. It pads with empty strings so the length
+// of the result is a multiple of 16, because our model expects that.
+Status ReadLabelsFile(string file_name, std::vector<string>* result,
+ size_t* found_label_count) {
+ std::ifstream file(file_name);
+ if (!file) {
+ return tensorflow::errors::NotFound("Labels file ", file_name,
+ " not found.");
+ }
+ result->clear();
+ string line;
+ while (std::getline(file, line)) {
+ result->push_back(line);
+ }
+ *found_label_count = result->size();
+ const int padding = 16;
+ while (result->size() % padding) {
+ result->emplace_back();
+ }
+ return Status::OK();
+}
+
+// Error handling for JPEG decoding.
+void CatchError(j_common_ptr cinfo) {
+ (*cinfo->err->output_message)(cinfo);
+ jmp_buf *jpeg_jmpbuf = reinterpret_cast<jmp_buf *>(cinfo->client_data);
+ jpeg_destroy(cinfo);
+ longjmp(*jpeg_jmpbuf, 1);
+}
+
+// Decompresses a JPEG file from disk.
+Status LoadJpegFile(string file_name, std::vector<tensorflow::uint8>* data,
+ int* width, int* height, int* channels) {
+ struct jpeg_decompress_struct cinfo;
+ FILE * infile;
+ JSAMPARRAY buffer;
+ int row_stride;
+
+ if ((infile = fopen(file_name.c_str(), "rb")) == NULL) {
+ LOG(ERROR) << "Can't open " << file_name;
+ return tensorflow::errors::NotFound("JPEG file ", file_name,
+ " not found");
+ }
+
+ struct jpeg_error_mgr jerr;
+ jmp_buf jpeg_jmpbuf; // recovery point in case of error
+ cinfo.err = jpeg_std_error(&jerr);
+ cinfo.client_data = &jpeg_jmpbuf;
+ jerr.error_exit = CatchError;
+ if (setjmp(jpeg_jmpbuf)) {
+ return tensorflow::errors::Unknown("JPEG decoding failed");
+ }
+
+ jpeg_create_decompress(&cinfo);
+ jpeg_stdio_src(&cinfo, infile);
+ jpeg_read_header(&cinfo, TRUE);
+ jpeg_start_decompress(&cinfo);
+ *width = cinfo.output_width;
+ *height = cinfo.output_height;
+ *channels = cinfo.output_components;
+ data->resize((*height) * (*width) * (*channels));
+
+ row_stride = cinfo.output_width * cinfo.output_components;
+ buffer = (*cinfo.mem->alloc_sarray)
+ ((j_common_ptr) &cinfo, JPOOL_IMAGE, row_stride, 1);
+ while (cinfo.output_scanline < cinfo.output_height) {
+ tensorflow::uint8* row_address = &((*data)[cinfo.output_scanline * row_stride]);
+ jpeg_read_scanlines(&cinfo, buffer, 1);
+ memcpy(row_address, buffer[0], row_stride);
+ }
+
+ jpeg_finish_decompress(&cinfo);
+ jpeg_destroy_decompress(&cinfo);
+ fclose(infile);
+ return Status::OK();
+}
+
+// Given an image file name, read in the data, try to decode it as an image,
+// resize it to the requested size, and then scale the values as desired.
+Status ReadTensorFromImageFile(string file_name, const int wanted_height,
+ const int wanted_width, const float input_mean,
+ const float input_std,
+ std::vector<Tensor>* out_tensors) {
+ std::vector<tensorflow::uint8> image_data;
+ int image_width;
+ int image_height;
+ int image_channels;
+ TF_RETURN_IF_ERROR(LoadJpegFile(file_name, &image_data, &image_width,
+ &image_height, &image_channels));
+ LOG(INFO) << "Loaded JPEG: " << image_width << "x" << image_height
+ << "x" << image_channels;
+ const int wanted_channels = 3;
+ if (image_channels < wanted_channels) {
+ return tensorflow::errors::FailedPrecondition("Image needs to have at least ",
+ wanted_channels, " but only has ",
+ image_channels);
+ }
+ // In these loops, we convert the eight-bit data in the image into float, resize
+ // it using bilinear filtering, and scale it numerically to the float range that
+ // the model expects (given by input_mean and input_std).
+ tensorflow::Tensor image_tensor(
+ tensorflow::DT_FLOAT, tensorflow::TensorShape(
+ {1, wanted_height, wanted_width, wanted_channels}));
+ auto image_tensor_mapped = image_tensor.tensor<float, 4>();
+ tensorflow::uint8* in = image_data.data();
+ float *out = image_tensor_mapped.data();
+ const size_t image_rowlen = image_width * image_channels;
+ const float width_scale = static_cast<float>(image_width) / wanted_width;
+ const float height_scale = static_cast<float>(image_height) / wanted_height;
+ for (int y = 0; y < wanted_height; ++y) {
+ const float in_y = y * height_scale;
+ const int top_y_index = static_cast<int>(floorf(in_y));
+ const int bottom_y_index =
+ std::min(static_cast<int>(ceilf(in_y)), (image_height - 1));
+ const float y_lerp = in_y - top_y_index;
+ tensorflow::uint8* in_top_row = in + (top_y_index * image_rowlen);
+ tensorflow::uint8* in_bottom_row = in + (bottom_y_index * image_rowlen);
+ float *out_row = out + (y * wanted_width * wanted_channels);
+ for (int x = 0; x < wanted_width; ++x) {
+ const float in_x = x * width_scale;
+ const int left_x_index = static_cast<int>(floorf(in_x));
+ const int right_x_index =
+ std::min(static_cast<int>(ceilf(in_x)), (image_width - 1));
+ tensorflow::uint8* in_top_left_pixel =
+ in_top_row + (left_x_index * wanted_channels);
+ tensorflow::uint8* in_top_right_pixel =
+ in_top_row + (right_x_index * wanted_channels);
+ tensorflow::uint8* in_bottom_left_pixel =
+ in_bottom_row + (left_x_index * wanted_channels);
+ tensorflow::uint8* in_bottom_right_pixel =
+ in_bottom_row + (right_x_index * wanted_channels);
+ const float x_lerp = in_x - left_x_index;
+ float *out_pixel = out_row + (x * wanted_channels);
+ for (int c = 0; c < wanted_channels; ++c) {
+ const float top_left((in_top_left_pixel[c] - input_mean) / input_std);
+ const float top_right((in_top_right_pixel[c] - input_mean) / input_std);
+ const float bottom_left((in_bottom_left_pixel[c] - input_mean) / input_std);
+ const float bottom_right((in_bottom_right_pixel[c] - input_mean) / input_std);
+ const float top = top_left + (top_right - top_left) * x_lerp;
+ const float bottom =
+ bottom_left + (bottom_right - bottom_left) * x_lerp;
+ out_pixel[c] = top + (bottom - top) * y_lerp;
+ }
+ }
+ }
+
+ out_tensors->push_back(image_tensor);
+ return Status::OK();
+}
+
+// Reads a model graph definition from disk, and creates a session object you
+// can use to run it.
+Status LoadGraph(string graph_file_name,
+ std::unique_ptr<tensorflow::Session>* session) {
+ tensorflow::GraphDef graph_def;
+ Status load_graph_status =
+ ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
+ if (!load_graph_status.ok()) {
+ return tensorflow::errors::NotFound("Failed to load compute graph at '",
+ graph_file_name, "'");
+ }
+ session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
+ Status session_create_status = (*session)->Create(graph_def);
+ if (!session_create_status.ok()) {
+ return session_create_status;
+ }
+ return Status::OK();
+}
+
+// Analyzes the output of the Inception graph to retrieve the highest scores and
+// their positions in the tensor, which correspond to categories.
+Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
+ Tensor* out_indices, Tensor* out_scores) {
+ const Tensor& unsorted_scores_tensor = outputs[0];
+ auto unsorted_scores_flat = unsorted_scores_tensor.flat<float>();
+ std::vector<std::pair<int, float>> scores;
+ for (int i = 0; i < unsorted_scores_flat.size(); ++i) {
+ scores.push_back(std::pair<int, float>({i, unsorted_scores_flat(i)}));
+ }
+ std::sort(scores.begin(), scores.end(),
+ [](const std::pair<int, float> &left,
+ const std::pair<int, float> &right) {
+ return left.second > right.second;
+ });
+ scores.resize(how_many_labels);
+ Tensor sorted_indices(tensorflow::DT_INT32, {scores.size()});
+ Tensor sorted_scores(tensorflow::DT_FLOAT, {scores.size()});
+ for (int i = 0; i < scores.size(); ++i) {
+ sorted_indices.flat<int>()(i) = scores[i].first;
+ sorted_scores.flat<float>()(i) = scores[i].second;
+ }
+ *out_indices = sorted_indices;
+ *out_scores = sorted_scores;
+ return Status::OK();
+}
+
+// Given the output of a model run, and the name of a file containing the labels
+// this prints out the top five highest-scoring values.
+Status PrintTopLabels(const std::vector<Tensor>& outputs,
+ string labels_file_name) {
+ std::vector<string> labels;
+ size_t label_count;
+ Status read_labels_status =
+ ReadLabelsFile(labels_file_name, &labels, &label_count);
+ if (!read_labels_status.ok()) {
+ LOG(ERROR) << read_labels_status;
+ return read_labels_status;
+ }
+ const int how_many_labels = std::min(5, static_cast<int>(label_count));
+ Tensor indices;
+ Tensor scores;
+ TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
+ tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
+ tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
+ for (int pos = 0; pos < how_many_labels; ++pos) {
+ const int label_index = indices_flat(pos);
+ const float score = scores_flat(pos);
+ LOG(INFO) << labels[label_index] << " (" << label_index << "): " << score;
+ }
+ return Status::OK();
+}
+
+// This is a testing function that returns whether the top label index is the
+// one that's expected.
+Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected,
+ bool* is_expected) {
+ *is_expected = false;
+ Tensor indices;
+ Tensor scores;
+ const int how_many_labels = 1;
+ TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
+ tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
+ if (indices_flat(0) != expected) {
+ LOG(ERROR) << "Expected label #" << expected << " but got #"
+ << indices_flat(0);
+ *is_expected = false;
+ } else {
+ *is_expected = true;
+ }
+ return Status::OK();
+}
+
+int main(int argc, char* argv[]) {
+ // These are the command-line flags the program can understand.
+ // They define where the graph and input data is located, and what kind of
+ // input the model expects. If you train your own model, or use something
+ // other than GoogLeNet you'll need to update these.
+ string image = "tensorflow/contrib/pi_examples/label_image/data/"
+ "grace_hopper.jpg";
+ string graph =
+ "tensorflow/contrib/pi_examples/label_image/data/"
+ "tensorflow_inception_stripped.pb";
+ string labels =
+ "tensorflow/contrib/pi_examples/label_image/data/"
+ "imagenet_comp_graph_label_strings.txt";
+ int32 input_width = 299;
+ int32 input_height = 299;
+ int32 input_mean = 128;
+ int32 input_std = 128;
+ string input_layer = "Mul";
+ string output_layer = "softmax";
+ bool self_test = false;
+ string root_dir = "";
+ const bool parse_result = tensorflow::ParseFlags(
+ &argc, argv, {Flag("image", &image), //
+ Flag("graph", &graph), //
+ Flag("labels", &labels), //
+ Flag("input_width", &input_width), //
+ Flag("input_height", &input_height), //
+ Flag("input_mean", &input_mean), //
+ Flag("input_std", &input_std), //
+ Flag("input_layer", &input_layer), //
+ Flag("output_layer", &output_layer), //
+ Flag("self_test", &self_test), //
+ Flag("root_dir", &root_dir)});
+ if (!parse_result) {
+ LOG(ERROR) << "Error parsing command-line flags.";
+ return -1;
+ }
+
+ // We need to call this to set up global state for TensorFlow.
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1];
+ return -1;
+ }
+
+ // First we load and initialize the model.
+ std::unique_ptr<tensorflow::Session> session;
+ string graph_path = tensorflow::io::JoinPath(root_dir, graph);
+ Status load_graph_status = LoadGraph(graph_path, &session);
+ if (!load_graph_status.ok()) {
+ LOG(ERROR) << load_graph_status;
+ return -1;
+ }
+
+ // Get the image from disk as a float array of numbers, resized and normalized
+ // to the specifications the main graph expects.
+ std::vector<Tensor> resized_tensors;
+ string image_path = tensorflow::io::JoinPath(root_dir, image);
+ Status read_tensor_status =
+ ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
+ input_std, &resized_tensors);
+ if (!read_tensor_status.ok()) {
+ LOG(ERROR) << read_tensor_status;
+ return -1;
+ }
+ const Tensor& resized_tensor = resized_tensors[0];
+
+ // Actually run the image through the model.
+ std::vector<Tensor> outputs;
+ Status run_status = session->Run({{input_layer, resized_tensor}},
+ {output_layer}, {}, &outputs);
+ if (!run_status.ok()) {
+ LOG(ERROR) << "Running model failed: " << run_status;
+ return -1;
+ } else {
+ LOG(INFO) << "Running model succeeded!";
+ }
+
+ // This is for automated testing to make sure we get the expected result with
+ // the default settings. We know that label 866 (military uniform) should be
+ // the top label for the Admiral Hopper image.
+ if (self_test) {
+ bool expected_matches;
+ Status check_status = CheckTopLabel(outputs, 866, &expected_matches);
+ if (!check_status.ok()) {
+ LOG(ERROR) << "Running check failed: " << check_status;
+ return -1;
+ }
+ if (!expected_matches) {
+ LOG(ERROR) << "Self-test failed!";
+ return -1;
+ }
+ }
+
+ // Do something interesting with the results we've generated.
+ Status print_status = PrintTopLabels(outputs, labels);
+ if (!print_status.ok()) {
+ LOG(ERROR) << "Running print failed: " << print_status;
+ return -1;
+ }
+
+ return 0;
+}
diff --git a/tensorflow/contrib/quantization/tools/quantize_graph.py b/tensorflow/contrib/quantization/tools/quantize_graph.py
index 34bc61d06d..3ed2ee07f7 100644
--- a/tensorflow/contrib/quantization/tools/quantize_graph.py
+++ b/tensorflow/contrib/quantization/tools/quantize_graph.py
@@ -15,7 +15,7 @@
r"""Transforms a float-trained graph into an equivalent quantized version.
An example of command-line usage is:
-bazel build tensorflow/contrib/quantization/tools/:quantize_graph \
+bazel build tensorflow/contrib/quantization/tools:quantize_graph \
&& bazel-bin/tensorflow/contrib/quantization/tools/quantize_graph \
--input=tensorflow_inception_graph.pb
--output_node_names="softmax2" --print_nodes --output=/tmp/quantized_graph.pb \
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
index f8f540c0c7..9823980e83 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
@@ -101,7 +101,7 @@ class GrpcMasterService : public AsyncServiceInterface {
} \
} while (0)
- void HandleRPCsLoop() {
+ void HandleRPCsLoop() override {
ENQUEUE_REQUEST(CreateSession, true);
ENQUEUE_REQUEST(ExtendSession, false);
for (int i = 0; i < 100; ++i) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
index f8ad42a691..0d532520cf 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
@@ -73,7 +73,7 @@ class GrpcSession : public Session {
const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs, RunMetadata* run_metadata);
+ std::vector<Tensor>* outputs, RunMetadata* run_metadata) override;
Status Extend(const GraphDef& graph) override;
Status Extend(const RunOptions& run_options, const GraphDef& graph) override;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index bf3f413c66..bba579a6a8 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -108,7 +108,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
} while (0)
// This method blocks forever handling requests from the completion queue.
- void HandleRPCsLoop() {
+ void HandleRPCsLoop() override {
// TODO(mrry): This may require performance engineering. We can
// add more threads to service the completion queue, and add more
// of various request types if they are short and frequent.
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
index 1b30313694..8925e9a58f 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -837,7 +837,6 @@ struct gemm_pack_rhs<
EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
const Index packet_cols4 = (cols / 4) * 4;
- const bool non_standard_patches = rhs.nonStandardPatches();
for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
diff --git a/tensorflow/core/kernels/reader_base.h b/tensorflow/core/kernels/reader_base.h
index 8c5bb308ec..3cb910751d 100644
--- a/tensorflow/core/kernels/reader_base.h
+++ b/tensorflow/core/kernels/reader_base.h
@@ -110,7 +110,7 @@ class ReaderBase : public ReaderInterface {
// In this implementation all the records come from the same work unit.
int64 ReadUpTo(const int64 num_records, QueueInterface* queue,
std::vector<string>* keys, std::vector<string>* value,
- OpKernelContext* context);
+ OpKernelContext* context) override;
Status Reset() override;
int64 NumRecordsProduced() override;
diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h
index 613c6a15c5..49c22306a9 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_matmul_op.h
@@ -40,6 +40,35 @@ EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) {
return reinterpret_cast<const float&>(tmp);
}
+// Specialization non-scalar version on non-sse.
+#ifndef EIGEN_VECTORIZE_SSE2
+template <typename Packet>
+EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
+ float r[4];
+ tensorflow::uint32 p[4];
+ pstoreu(r, from);
+ tensorflow::uint32 * ir = reinterpret_cast<tensorflow::uint32 *>(r);
+ p[0] = (ir[0] << 16) & 0xffff0000;
+ p[1] = ir[0]& 0xffff0000;
+ p[2] = (ir[1] << 16) & 0xffff0000;
+ p[3] = ir[1] & 0xffff0000;
+ return ploadu<Packet4f>(reinterpret_cast<float *>(p));
+}
+
+template <typename Packet>
+EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
+ float r[4];
+ tensorflow::uint32 p[4];
+ pstoreu(r, from);
+ tensorflow::uint32 * ir = reinterpret_cast<tensorflow::uint32 *>(r);
+ p[0] = (ir[2] << 16) & 0xffff0000;
+ p[1] = ir[2] & 0xffff0000;
+ p[2] = (ir[3] << 16) & 0xffff0000;
+ p[3] = ir[3] & 0xffff0000;
+ return ploadu<Packet4f>(reinterpret_cast<float *>(p));
+}
+#endif
+
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pinterleave4x64(const Packet& from) {
return from;
@@ -72,16 +101,41 @@ template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload4bf16(
const typename unpacket_traits<Packet>::type* from) {
assert(false && "Not applicable to Scalar Values");
- return *from;
+ return Packet();
}
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload2bf16(
const typename unpacket_traits<Packet>::type* from) {
assert(false && "Not applicable to Scalar Values");
- return *from;
+ return Packet();
}
+// Specialization for pload4bf16 and pload2bf16 for non-sse.
+#ifndef EIGEN_VECTORIZE_SSE2
+template <>
+EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
+ tensorflow::uint32 p[4];
+ const tensorflow::uint32* ir = reinterpret_cast<const tensorflow::uint32 *>(from);
+ p[0] = (ir[0] << 16) & 0xffff0000;
+ p[1] = ir[0]& 0xffff0000;
+ p[2] = (ir[1] << 16) & 0xffff0000;
+ p[3] = ir[1] & 0xffff0000;
+ return ploadu<Packet4f>(reinterpret_cast<float *>(p));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
+ tensorflow::uint32 p[4];
+ const tensorflow::uint32* ir = reinterpret_cast<const tensorflow::uint32 *>(from);
+ p[0] = (ir[0] << 16) & 0xffff0000;
+ p[1] = ir[0]& 0xffff0000;
+ p[2] = (ir[0] << 16) & 0xffff0000;
+ p[3] = ir[0] & 0xffff0000;
+ return ploadu<Packet4f>(reinterpret_cast<float *>(p));
+}
+#endif
+
#ifdef EIGEN_VECTORIZE_SSE2
// For PacketSize of 4 floats the Packet is not modified
template <>
diff --git a/tensorflow/core/lib/core/arena.cc b/tensorflow/core/lib/core/arena.cc
index 5da991084c..403a7cf0ea 100644
--- a/tensorflow/core/lib/core/arena.cc
+++ b/tensorflow/core/lib/core/arena.cc
@@ -35,8 +35,6 @@ limitations under the License.
namespace tensorflow {
namespace core {
-static const int kPageSize = getpagesize();
-
// ----------------------------------------------------------------------
// Arena::Arena()
// Arena::~Arena()
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index eb194a14d4..73b0280a8f 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -29,9 +29,14 @@ RecordReader::RecordReader(RandomAccessFile* file,
const RecordReaderOptions& options)
: src_(file), options_(options) {
if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) {
+// We don't have zlib available on all embedded platforms, so fail.
+#if defined(IS_SLIM_BUILD)
+ LOG(FATAL) << "Zlib compression is unsupported on mobile platforms.";
+#else // IS_SLIM_BUILD
zlib_input_buffer_.reset(new ZlibInputBuffer(
src_, options.zlib_options.input_buffer_size,
options.zlib_options.output_buffer_size, options.zlib_options));
+#endif // IS_SLIM_BUILD
} else if (options.compression_type == RecordReaderOptions::NONE) {
// Nothing to do.
} else {
@@ -53,6 +58,7 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n,
const size_t expected = n + sizeof(uint32);
storage->resize(expected);
+#if !defined(IS_SLIM_BUILD)
if (zlib_input_buffer_) {
// If we have a zlib compressed buffer, we assume that the
// file is being read sequentially, and we use the underlying
@@ -77,6 +83,7 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n,
}
*result = StringPiece(storage->data(), n);
} else {
+#endif // IS_SLIM_BUILD
// This version supports reading from arbitrary offsets
// since we are accessing the random access file directly.
StringPiece data;
@@ -93,7 +100,9 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n,
return errors::DataLoss("corrupted record at ", offset);
}
*result = StringPiece(data.data(), n);
+#if !defined(IS_SLIM_BUILD)
}
+#endif // IS_SLIM_BUILD
return Status::OK();
}
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index b4c56451be..e6e2a8c8ab 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -19,8 +19,10 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
+#if !defined(IS_SLIM_BUILD)
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputbuffer.h"
+#endif // IS_SLIM_BUILD
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -35,8 +37,10 @@ class RecordReaderOptions {
enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 };
CompressionType compression_type = NONE;
+#if !defined(IS_SLIM_BUILD)
// Options specific to zlib compression.
ZlibCompressionOptions zlib_options;
+#endif // IS_SLIM_BUILD
};
class RecordReader {
@@ -59,7 +63,9 @@ class RecordReader {
RandomAccessFile* src_;
RecordReaderOptions options_;
+#if !defined(IS_SLIM_BUILD)
std::unique_ptr<ZlibInputBuffer> zlib_input_buffer_;
+#endif // IS_SLIM_BUILD
TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
};
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
index 7993f6ca20..25873b83ba 100644
--- a/tensorflow/core/lib/io/record_writer.cc
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -26,9 +26,14 @@ RecordWriter::RecordWriter(WritableFile* dest,
const RecordWriterOptions& options)
: dest_(dest), options_(options) {
if (options.compression_type == RecordWriterOptions::ZLIB_COMPRESSION) {
+// We don't have zlib available on all embedded platforms, so fail.
+#if defined(IS_SLIM_BUILD)
+ LOG(FATAL) << "Zlib compression is unsupported on mobile platforms.";
+#else // IS_SLIM_BUILD
zlib_output_buffer_.reset(new ZlibOutputBuffer(
dest_, options.zlib_options.input_buffer_size,
options.zlib_options.output_buffer_size, options.zlib_options));
+#endif // IS_SLIM_BUILD
} else if (options.compression_type == RecordWriterOptions::NONE) {
// Nothing to do
} else {
@@ -37,12 +42,14 @@ RecordWriter::RecordWriter(WritableFile* dest,
}
RecordWriter::~RecordWriter() {
+#if !defined(IS_SLIM_BUILD)
if (zlib_output_buffer_) {
Status s = zlib_output_buffer_->Close();
if (!s.ok()) {
LOG(ERROR) << "Could not finish writing file: " << s;
}
}
+#endif // IS_SLIM_BUILD
}
static uint32 MaskedCrc(const char* data, size_t n) {
@@ -62,16 +69,20 @@ Status RecordWriter::WriteRecord(StringPiece data) {
char footer[sizeof(uint32)];
core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size()));
+#if !defined(IS_SLIM_BUILD)
if (zlib_output_buffer_) {
TF_RETURN_IF_ERROR(
zlib_output_buffer_->Write(StringPiece(header, sizeof(header))));
TF_RETURN_IF_ERROR(zlib_output_buffer_->Write(data));
return zlib_output_buffer_->Write(StringPiece(footer, sizeof(footer)));
} else {
+#endif // IS_SLIM_BUILD
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
TF_RETURN_IF_ERROR(dest_->Append(data));
return dest_->Append(StringPiece(footer, sizeof(footer)));
+#if !defined(IS_SLIM_BUILD)
}
+#endif // IS_SLIM_BUILD
}
} // namespace io
diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h
index 2344df3b25..3d42a281de 100644
--- a/tensorflow/core/lib/io/record_writer.h
+++ b/tensorflow/core/lib/io/record_writer.h
@@ -18,8 +18,10 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#if !defined(IS_SLIM_BUILD)
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
+#endif // IS_SLIM_BUILD
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -34,8 +36,10 @@ class RecordWriterOptions {
enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 };
CompressionType compression_type = NONE;
- // Options specific to zlib compression.
+// Options specific to zlib compression.
+#if !defined(IS_SLIM_BUILD)
ZlibCompressionOptions zlib_options;
+#endif // IS_SLIM_BUILD
};
class RecordWriter {
@@ -54,9 +58,11 @@ class RecordWriter {
// RecordWriter to the WritableFile. Does *not* flush the
// WritableFile.
Status Flush() {
+#if !defined(IS_SLIM_BUILD)
if (zlib_output_buffer_) {
return zlib_output_buffer_->Flush();
}
+#endif // IS_SLIM_BUILD
return Status::OK();
}
@@ -64,7 +70,9 @@ class RecordWriter {
private:
WritableFile* const dest_;
RecordWriterOptions options_;
+#if !defined(IS_SLIM_BUILD)
std::unique_ptr<ZlibOutputBuffer> zlib_output_buffer_;
+#endif // IS_SLIM_BUILD
TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter);
};
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 9b01376577..6a3024d5bc 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Simple transfer learning with an Inception v3 architecture model.
+"""Simple transfer learning with an Inception v3 architecture model which
+displays summaries in TensorBoard.
This example shows how to take a Inception v3 architecture model trained on
ImageNet images, and train a new top layer that can recognize other classes of
@@ -49,6 +50,15 @@ in.
This produces a new model file that can be loaded and run by any TensorFlow
program, for example the label_image sample code.
+
+To use with TensorBoard:
+
+By default, this script will log summaries to /tmp/retrain_logs directory
+
+Visualize the summaries with this command:
+
+tensorboard --logdir /tmp/retrain_logs
+
"""
from __future__ import absolute_import
from __future__ import division
@@ -81,6 +91,8 @@ tf.app.flags.DEFINE_string('output_graph', '/tmp/output_graph.pb',
"""Where to save the trained graph.""")
tf.app.flags.DEFINE_string('output_labels', '/tmp/output_labels.txt',
"""Where to save the trained graph's labels.""")
+tf.app.flags.DEFINE_string('summaries_dir', '/tmp/retrain_logs',
+ """Where to save summary logs for TensorBoard.""")
# Details of the training configuration.
tf.app.flags.DEFINE_integer('how_many_training_steps', 4000,
@@ -650,6 +662,19 @@ def add_input_distortions(flip_left_right, random_crop, random_scale,
return jpeg_data, distort_result
+def variable_summaries(var, name):
+ """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
+ with tf.name_scope('summaries'):
+ mean = tf.reduce_mean(var)
+ tf.scalar_summary('mean/' + name, mean)
+ with tf.name_scope('stddev'):
+ stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean)))
+ tf.scalar_summary('sttdev/' + name, stddev)
+ tf.scalar_summary('max/' + name, tf.reduce_max(var))
+ tf.scalar_summary('min/' + name, tf.reduce_min(var))
+ tf.histogram_summary(name, var)
+
+
def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
"""Adds a new softmax and fully-connected layer for training.
@@ -670,24 +695,43 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
The tensors for the training and cross entropy results, and tensors for the
bottleneck input and ground truth input.
"""
- bottleneck_input = tf.placeholder_with_default(
- bottleneck_tensor, shape=[None, BOTTLENECK_TENSOR_SIZE],
- name='BottleneckInputPlaceholder')
- layer_weights = tf.Variable(
- tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count], stddev=0.001),
- name='final_weights')
- layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
- logits = tf.matmul(bottleneck_input, layer_weights,
- name='final_matmul') + layer_biases
+ with tf.name_scope('input'):
+ bottleneck_input = tf.placeholder_with_default(
+ bottleneck_tensor, shape=[None, BOTTLENECK_TENSOR_SIZE],
+ name='BottleneckInputPlaceholder')
+
+ ground_truth_input = tf.placeholder(tf.float32,
+ [None, class_count],
+ name='GroundTruthInput')
+
+ # Organizing the following ops as `final_training_ops` so they're easier
+ # to see in TensorBoard
+ layer_name = 'final_training_ops'
+ with tf.name_scope(layer_name):
+ with tf.name_scope('weights'):
+ layer_weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count], stddev=0.001), name='final_weights')
+ variable_summaries(layer_weights, layer_name + '/weights')
+ with tf.name_scope('biases'):
+ layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
+ variable_summaries(layer_biases, layer_name + '/biases')
+ with tf.name_scope('Wx_plus_b'):
+ logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
+ tf.histogram_summary(layer_name + '/pre_activations', logits)
+
final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
- ground_truth_input = tf.placeholder(tf.float32,
- [None, class_count],
- name='GroundTruthInput')
- cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
+ tf.histogram_summary(final_tensor_name + '/activations', final_tensor)
+
+ with tf.name_scope('cross_entropy'):
+ cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
logits, ground_truth_input)
- cross_entropy_mean = tf.reduce_mean(cross_entropy)
- train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize(
- cross_entropy_mean)
+ with tf.name_scope('total'):
+ cross_entropy_mean = tf.reduce_mean(cross_entropy)
+ tf.scalar_summary('cross entropy', cross_entropy_mean)
+
+ with tf.name_scope('train'):
+ train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize(
+ cross_entropy_mean)
+
return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
final_tensor)
@@ -703,13 +747,22 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
Returns:
Nothing.
"""
- correct_prediction = tf.equal(
- tf.argmax(result_tensor, 1), tf.argmax(ground_truth_tensor, 1))
- evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
+ with tf.name_scope('accuracy'):
+ with tf.name_scope('correct_prediction'):
+ correct_prediction = tf.equal(tf.argmax(result_tensor, 1), \
+ tf.argmax(ground_truth_tensor, 1))
+ with tf.name_scope('accuracy'):
+ evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+ tf.scalar_summary('accuracy', evaluation_step)
return evaluation_step
def main(_):
+ # Setup the directory we'll write summaries to for TensorBoard
+ if tf.gfile.Exists(FLAGS.summaries_dir):
+ tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
+ tf.gfile.MakeDirs(FLAGS.summaries_dir)
+
# Set up the pre-trained graph.
maybe_download_and_extract()
graph, bottleneck_tensor, jpeg_data_tensor, resized_image_tensor = (
@@ -750,13 +803,19 @@ def main(_):
FLAGS.final_tensor_name,
bottleneck_tensor)
+ # Create the operations we need to evaluate the accuracy of our new layer.
+ evaluation_step = add_evaluation_step(final_tensor, ground_truth_input)
+
+ # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
+ merged = tf.merge_all_summaries()
+ train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train',
+ sess.graph)
+ validation_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/validation')
+
# Set up all our weights to their initial default values.
init = tf.initialize_all_variables()
sess.run(init)
- # Create the operations we need to evaluate the accuracy of our new layer.
- evaluation_step = add_evaluation_step(final_tensor, ground_truth_input)
-
# Run the training for as many cycles as requested on the command line.
for i in range(FLAGS.how_many_training_steps):
# Get a catch of input bottleneck values, either calculated fresh every time
@@ -772,10 +831,12 @@ def main(_):
FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
bottleneck_tensor)
# Feed the bottlenecks and ground truth into the graph, and run a training
- # step.
- sess.run(train_step,
+ # step. Capture training summaries for TensorBoard with the `merged` op.
+ train_summary, _ = sess.run([merged, train_step],
feed_dict={bottleneck_input: train_bottlenecks,
ground_truth_input: train_ground_truth})
+ train_writer.add_summary(train_summary, i)
+
# Every so often, print out how well the graph is training.
is_last_step = (i + 1 == FLAGS.how_many_training_steps)
if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
@@ -792,10 +853,13 @@ def main(_):
sess, image_lists, FLAGS.validation_batch_size, 'validation',
FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
bottleneck_tensor))
- validation_accuracy = sess.run(
- evaluation_step,
+ # Run a validation step and capture training summaries for TensorBoard
+ # with the `merged` op.
+ validation_summary, validation_accuracy = sess.run(
+ [merged, evaluation_step],
feed_dict={bottleneck_input: validation_bottlenecks,
ground_truth_input: validation_ground_truth})
+ validation_writer.add_summary(validation_summary, i)
print('%s: Step %d: Validation accuracy = %.1f%%' %
(datetime.now(), i, validation_accuracy * 100))
diff --git a/tensorflow/g3doc/api_docs/python/contrib.metrics.md b/tensorflow/g3doc/api_docs/python/contrib.metrics.md
index b7a136f87f..aaeed39411 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.metrics.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.metrics.md
@@ -106,7 +106,7 @@ idempotent operation that simply divides `total` by `count`.
To facilitate the estimation of the accuracy over a stream of data, the
function utilizes two operations. First, an `is_correct` operation that
computes a tensor whose shape matches `predictions` and whose elements are
-set to 1.0 when the corresponding values of `predictions` and `labels match
+set to 1.0 when the corresponding values of `predictions` and `labels` match
and 0.0 otherwise. Second, an `update_op` operation whose behavior is
dependent on the value of `weights`. If `weights` is None, then `update_op`
increments `total` with the number of elements of `predictions` that match
diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md
index 3a0fc8d28e..e3e80da7f1 100644
--- a/tensorflow/g3doc/api_docs/python/framework.md
+++ b/tensorflow/g3doc/api_docs/python/framework.md
@@ -1340,18 +1340,14 @@ The following `DType` objects are defined:
* `tf.bfloat16`: 16-bit truncated floating-point.
* `tf.complex64`: 64-bit single-precision complex.
* `tf.complex128`: 128-bit double-precision complex.
-
* `tf.int8`: 8-bit signed integer.
* `tf.uint8`: 8-bit unsigned integer.
* `tf.uint16`: 16-bit unsigned integer.
* `tf.int16`: 16-bit signed integer.
* `tf.int32`: 32-bit signed integer.
* `tf.int64`: 64-bit signed integer.
-
* `tf.bool`: Boolean.
-
* `tf.string`: String.
-
* `tf.qint8`: Quantized 8-bit signed integer.
* `tf.quint8`: Quantized 8-bit unsigned integer.
* `tf.qint16`: Quantized 16-bit signed integer.
diff --git a/tensorflow/g3doc/get_started/basic_usage.md b/tensorflow/g3doc/get_started/basic_usage.md
index b4289a986d..1603df9335 100644
--- a/tensorflow/g3doc/get_started/basic_usage.md
+++ b/tensorflow/g3doc/get_started/basic_usage.md
@@ -319,6 +319,6 @@ with tf.Session() as sess:
A `placeholder()` operation generates an error if you do not supply a feed for
it. See the
[MNIST fully-connected feed tutorial](../tutorials/mnist/tf/index.md)
-([source code](https://www.tensorflow.org/code/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py))
+([source code](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/fully_connected_feed.py))
for a larger-scale example of feeds.
diff --git a/tensorflow/g3doc/how_tos/image_retraining/index.md b/tensorflow/g3doc/how_tos/image_retraining/index.md
index 52c1d00a75..8ebbe57af1 100644
--- a/tensorflow/g3doc/how_tos/image_retraining/index.md
+++ b/tensorflow/g3doc/how_tos/image_retraining/index.md
@@ -117,6 +117,22 @@ to run since there's randomness in the training process. This number is based on
the percent of the images in the test set that are given the correct label
after the model is fully trained.
+## Visualizing the Retraining with TensorBoard
+
+The script includes TensorBoard summaries that make it easier to understand, debug, and optimize the retraining. For example, you can visualize the graph and statistics, such as how the weights or accuracy varied during training.
+
+To launch TensorBoard, run this command during or after retraining:
+
+```sh
+tensorboard --logdir /tmp/retrain_logs
+```
+
+Once TensorBoard is running, navigate your web browser to `localhost:6006` to view the TensorBoard.
+
+The script will log TensorBoard summaries to `/tmp/retrain_logs` by default. You can change the directory with the `--summaries_dir` flag.
+
+The [TensorBoard README](../../../tensorboard/README.md) has a lot more information on TensorBoard usage, including tips & tricks, and debugging information.
+
## Using the Retrained Model
The script will write out a version of the Inception v3 network with a final
diff --git a/tensorflow/g3doc/tutorials/mnist/pros/index.md b/tensorflow/g3doc/tutorials/mnist/pros/index.md
index 73cc87eb57..12de1df66c 100644
--- a/tensorflow/g3doc/tutorials/mnist/pros/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/pros/index.md
@@ -21,8 +21,8 @@ TensorFlow session.
For your convenience, we've included
[a script](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/input_data.py)
-which automatically downloads and imports the MNIST dataset. It will create a
-directory `'MNIST_data'` in which to store the data files.
+which will help you download and import the MNIST dataset. Run the following commands to create a
+directory `'MNIST_data'` in the current folder, the data files will be stored inside that directory.
```python
from tensorflow.examples.tutorials.mnist import input_data
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 85a5cc7444..1f29426b4c 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -34,18 +34,14 @@ class DType(object):
* `tf.bfloat16`: 16-bit truncated floating-point.
* `tf.complex64`: 64-bit single-precision complex.
* `tf.complex128`: 128-bit double-precision complex.
-
* `tf.int8`: 8-bit signed integer.
* `tf.uint8`: 8-bit unsigned integer.
* `tf.uint16`: 16-bit unsigned integer.
* `tf.int16`: 16-bit signed integer.
* `tf.int32`: 32-bit signed integer.
* `tf.int64`: 64-bit signed integer.
-
* `tf.bool`: Boolean.
-
* `tf.string`: String.
-
* `tf.qint8`: Quantized 8-bit signed integer.
* `tf.quint8`: Quantized 8-bit unsigned integer.
* `tf.qint16`: Quantized 16-bit signed integer.
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 87e3528196..4a21d1acc6 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -115,18 +115,22 @@ class UnaryOpTest(tf.test.TestCase):
x_init_value=x)
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
- def _check(self, result_tensor, result_np, input_sp_t):
+ def _check(self, result_tensor, result_np, input_sp_t, tol):
self.assertTrue(isinstance(result_tensor, tf.SparseTensor))
self.assertTrue(isinstance(input_sp_t, tf.SparseTensor))
self.assertAllEqual(input_sp_t.indices.eval(), result_tensor.indices.eval())
self.assertAllEqual(input_sp_t.shape.eval(), result_tensor.shape.eval())
- self.assertAllClose(result_np, result_tensor.values.eval())
+ if tol is None:
+ self.assertAllClose(result_np, result_tensor.values.eval())
+ else:
+ self.assertAllClose(result_np, result_tensor.values.eval(), rtol=tol,
+ atol=tol)
- def _compareSparseCpu(self, x, np_func, tf_func):
+ def _compareSparseCpu(self, x, np_func, tf_func, tol):
x_sp, x_sp_vals = _sparsify(x)
res_np = np_func(x_sp_vals)
with self.test_session(use_gpu=False):
- self._check(tf_func(x_sp), res_np, x_sp)
+ self._check(tf_func(x_sp), res_np, x_sp, tol)
def _compareGpu(self, x, np_func, tf_func):
np_ans = np_func(x)
@@ -139,19 +143,19 @@ class UnaryOpTest(tf.test.TestCase):
self.assertAllClose(np_ans, tf_gpu)
# TODO(zhifengc/ke): make gradient checker work on GPU.
- def _compareSparseGpu(self, x, np_func, tf_func):
+ def _compareSparseGpu(self, x, np_func, tf_func, tol):
x_sp, x_sp_vals = _sparsify(x)
res_np = np_func(x_sp_vals)
with self.test_session(use_gpu=True):
- self._check(tf_func(x_sp), res_np, x_sp)
+ self._check(tf_func(x_sp), res_np, x_sp, tol)
def _compareBoth(self, x, np_func, tf_func):
self._compareCpu(x, np_func, tf_func)
self._compareGpu(x, np_func, tf_func)
- def _compareBothSparse(self, x, np_func, tf_func):
- self._compareSparseCpu(x, np_func, tf_func)
- self._compareSparseGpu(x, np_func, tf_func)
+ def _compareBothSparse(self, x, np_func, tf_func, tol=None):
+ self._compareSparseCpu(x, np_func, tf_func, tol)
+ self._compareSparseGpu(x, np_func, tf_func, tol)
def _inv(self, x):
return 1.0 / x
@@ -207,6 +211,8 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.abs, tf.abs)
self._compareBothSparse(x, np.negative, tf.neg)
+ self._compareBothSparse(x, np.square, tf.square)
+ self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(y, np.sign, tf.sign)
def testFloatTanhEdge(self):
@@ -243,6 +249,8 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.abs, tf.abs)
self._compareBothSparse(x, np.negative, tf.neg)
+ self._compareBothSparse(x, np.square, tf.square)
+ self._compareBothSparse(x, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.sign, tf.sign)
def testDoubleBasic(self):
@@ -278,6 +286,8 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.abs, tf.abs)
self._compareBothSparse(x, np.negative, tf.neg)
+ self._compareBothSparse(x, np.square, tf.square)
+ self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(y, np.sign, tf.sign)
def testHalfBasic(self):
@@ -308,6 +318,8 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.abs, tf.abs)
self._compareBothSparse(x, np.negative, tf.neg)
+ self._compareBothSparse(x, np.square, tf.square)
+ self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(y, np.sign, tf.sign)
def testInt32Basic(self):
@@ -321,6 +333,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.abs, tf.abs)
self._compareBothSparse(x, np.negative, tf.neg)
+ self._compareBothSparse(x, np.square, tf.square)
self._compareBothSparse(x, np.sign, tf.sign)
def testInt64Basic(self):
@@ -335,6 +348,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.abs, tf.abs)
self._compareBothSparse(x, np.negative, tf.neg)
+ self._compareBothSparse(x, np.square, tf.square)
self._compareBothSparse(x, np.sign, tf.sign)
def testComplex64Basic(self):
@@ -358,6 +372,8 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.abs, tf.abs)
self._compareBothSparse(x, np.negative, tf.neg)
+ self._compareBothSparse(x, np.square, tf.square)
+ self._compareBothSparse(x, np.sqrt, tf.sqrt, 1e-3)
# Numpy uses an incorrect definition of sign; use the right one instead.
def complex_sign(x):
@@ -386,6 +402,8 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.abs, tf.abs)
self._compareBothSparse(x, np.negative, tf.neg)
+ self._compareBothSparse(x, np.square, tf.square)
+ self._compareBothSparse(x, np.sqrt, tf.sqrt, 1e-3)
# Numpy uses an incorrect definition of sign; use the right one instead.
def complex_sign(x):
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index 4ec455fd61..90678929c0 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -91,6 +91,15 @@ class ShapeOpsTest(tf.test.TestCase):
self.assertAllEqual(np_ans, result)
self.assertShapeEqual(np_ans, tf_ans)
+ def _compareSizeSparse(self, x_np, use_gpu=False):
+ np_ans = np.asarray(np.size(x_np))
+ x_tf, unused_nnz = _sparsify(x_np)
+ with self.test_session(use_gpu=use_gpu):
+ tf_ans = tf.size(x_tf)
+ result = tf_ans.eval()
+ self.assertAllEqual(np_ans, result)
+ self.assertShapeEqual(np_ans, tf_ans)
+
def _testCpu(self, x):
self._compareShape(x, use_gpu=False)
self._compareShapeN(x, use_gpu=False)
@@ -98,6 +107,7 @@ class ShapeOpsTest(tf.test.TestCase):
self._compareSize(x, use_gpu=False)
self._compareShapeSparse(x, use_gpu=False)
self._compareRankSparse(x, use_gpu=False)
+ self._compareSizeSparse(x, use_gpu=False)
def _testGpu(self, x):
self._compareShape(x, use_gpu=True)
@@ -106,6 +116,7 @@ class ShapeOpsTest(tf.test.TestCase):
self._compareSize(x, use_gpu=True)
self._compareShapeSparse(x, use_gpu=True)
self._compareRankSparse(x, use_gpu=True)
+ self._compareSizeSparse(x, use_gpu=True)
def _testAll(self, x):
self._testCpu(x)
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 2dbd7b94f2..0167aba0e8 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -125,11 +125,39 @@ def shape(input, name=None):
"""
with ops.op_scope([input], name, "Shape") as name:
if isinstance(input, ops.SparseTensor):
- return input.shape
+ return gen_math_ops.cast(input.shape, dtypes.int32)
else:
return gen_array_ops.shape(input, name=name)
+def size(input, name=None):
+ """Returns the size of a tensor.
+
+ This operation returns an integer representing the number of elements in
+ `input`.
+
+ For example:
+
+ ```python
+ # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]]
+ size(t) ==> 12
+ ```
+
+ Args:
+ input: A `Tensor` or `SparseTensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `int32`.
+ """
+ with ops.op_scope([input], name, "Size") as name:
+ if isinstance(input, ops.SparseTensor):
+ return gen_math_ops._prod(gen_math_ops.cast(input.shape, dtypes.int32), 0,
+ name=name)
+ else:
+ return gen_array_ops.size(input, name=name)
+
+
def rank(input, name=None):
"""Returns the rank of a tensor.
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index d552221e60..8bfd9ce8bf 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -60,7 +60,7 @@ def _SumGrad(op, grad):
def _MinOrMaxGrad(op, grad):
- """Gradient for Max or Max. Amazingly it's precisely the same code."""
+ """Gradient for Min or Max. Amazingly it's precisely the same code."""
input_shape = array_ops.shape(op.inputs[0])
output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
y = op.outputs[0]
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 1647ceb1d5..d27cefc61d 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -306,6 +306,48 @@ def sign(x, name=None):
return gen_math_ops.sign(x, name=name)
+def square(x, name=None):
+ """Computes square of x element-wise.
+
+ I.e., \\(y = x * x = x^2\\).
+
+ Args:
+ x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
+ `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor`. Has the same type as `x`.
+ """
+ with ops.op_scope([x], name, "Square") as name:
+ if isinstance(x, ops.SparseTensor):
+ x_square = gen_math_ops.square(x.values, name=name)
+ return ops.SparseTensor(indices=x.indices, values=x_square, shape=x.shape)
+ else:
+ return gen_math_ops.square(x, name=name)
+
+
+def sqrt(x, name=None):
+ """Computes square root of x element-wise.
+
+ I.e., \\(y = \sqrt{x} = x^{1/2}\\).
+
+ Args:
+ x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
+ `float32`, `float64`, `complex64`, `complex128`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
+ """
+ with ops.op_scope([x], name, "Sqrt") as name:
+ if isinstance(x, ops.SparseTensor):
+ x_sqrt = gen_math_ops.sqrt(x.values, name=name)
+ return ops.SparseTensor(indices=x.indices, values=x_sqrt, shape=x.shape)
+ else:
+ return gen_math_ops.sqrt(x, name=name)
+
+
def complex_abs(x, name=None):
r"""Computes the complex absolute value of a tensor.
diff --git a/tensorflow/python/platform/flags.py b/tensorflow/python/platform/flags.py
index 4b22b3f5d6..85f9e2cb86 100644
--- a/tensorflow/python/platform/flags.py
+++ b/tensorflow/python/platform/flags.py
@@ -101,9 +101,12 @@ def DEFINE_boolean(flag_name, default_value, docstring):
help=docstring,
default=default_value,
type=str2bool)
+
+ # Add negated version, stay consistent with argparse with regard to
+ # dashes in flag names.
_global_parser.add_argument('--no' + flag_name,
action='store_false',
- dest=flag_name)
+ dest=flag_name.replace('-', '_'))
# The internal google library defines the following alias, so we match
diff --git a/tensorflow/python/platform/flags_test.py b/tensorflow/python/platform/flags_test.py
index 473877eb1e..39d92bd399 100644
--- a/tensorflow/python/platform/flags_test.py
+++ b/tensorflow/python/platform/flags_test.py
@@ -31,6 +31,7 @@ flags.DEFINE_float("float_foo", 42.0, "HelpString")
flags.DEFINE_boolean("bool_foo", True, "HelpString")
flags.DEFINE_boolean("bool_negation", True, "HelpString")
+flags.DEFINE_boolean("bool-dash-negation", True, "HelpString")
flags.DEFINE_boolean("bool_a", False, "HelpString")
flags.DEFINE_boolean("bool_c", False, "HelpString")
flags.DEFINE_boolean("bool_d", True, "HelpString")
@@ -64,6 +65,10 @@ class FlagsTest(googletest.TestCase):
# --bool_flag=True sets to True
self.assertEqual(True, FLAGS.bool_c)
+ # --no before the flag mirrors argparse's behavior with
+ # regard to dashes in flag names
+ self.assertEqual(False, FLAGS.bool_dash_negation)
+
# --bool_flag=False sets to False
self.assertEqual(False, FLAGS.bool_d)
@@ -85,9 +90,9 @@ class FlagsTest(googletest.TestCase):
if __name__ == "__main__":
# Test command lines
- sys.argv.extend(["--bool_a", "--nobool_negation", "--bool_c=True",
- "--bool_d=False", "--bool_e=gibberish", "--unknown_flag",
- "and_argument"])
+ sys.argv.extend(["--bool_a", "--nobool_negation", "--nobool-dash-negation",
+ "--bool_c=True", "--bool_d=False", "--bool_e=gibberish",
+ "--unknown_flag", "and_argument"])
# googletest.main() tries to interpret the above flags, so use the
# direct functions instead.
diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc
index fab3385cca..cce31ef4dc 100644
--- a/tensorflow/stream_executor/dso_loader.cc
+++ b/tensorflow/stream_executor/dso_loader.cc
@@ -72,7 +72,7 @@ string GetCudnnVersion() { return ""; }
}
/* static */ port::Status DsoLoader::GetLibcudaDsoHandle(void** dso_handle) {
- return GetDsoHandle(FindDsoPath(tensorflow::internal::FormatLibraryFileName("cuda", ""),
+ return GetDsoHandle(FindDsoPath(tensorflow::internal::FormatLibraryFileName("cuda", "1"),
GetCudaDriverLibraryPath()),
dso_handle);
}
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 59cb647e45..07b10de08c 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -27,7 +27,7 @@ from setuptools import find_packages, setup, Command, Extension
from setuptools.command.install import install as InstallCommandBase
from setuptools.dist import Distribution
-_VERSION = '0.9.0'
+_VERSION = '0.8.0'
numpy_version = "1.8.2"
if platform.system() == "Darwin":
diff --git a/util/python/BUILD b/util/python/BUILD
index a610c29299..af05de2004 100644
--- a/util/python/BUILD
+++ b/util/python/BUILD
@@ -15,6 +15,7 @@ genrule(
name = "python_check",
srcs = [
"python_config.sh",
+ "configure_files"
],
outs = [
"python_checked",
@@ -22,3 +23,10 @@ genrule(
cmd = "OUTPUTDIR=\"$(@D)/\"; $(location :python_config.sh) --check && touch $$OUTPUTDIR/python_checked",
local = 1,
)
+
+filegroup(
+ name = "configure_files",
+ data = glob([
+ "*",
+ ])
+)