aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc6
-rw-r--r--tensorflow/core/framework/resource_mgr.cc6
-rw-r--r--tensorflow/core/framework/resource_mgr.h73
-rw-r--r--tensorflow/core/kernels/io.cc2
-rw-r--r--tensorflow/core/kernels/matrix_inverse_op.cc9
-rw-r--r--tensorflow/core/kernels/save_op_test.cc31
-rw-r--r--tensorflow/core/kernels/tile_ops.cc10
-rw-r--r--tensorflow/core/lib/random/random.cc2
-rw-r--r--tensorflow/core/ops/io_ops.cc4
-rw-r--r--tensorflow/core/ops/ops.pbtxt2
-rw-r--r--tensorflow/core/util/saved_tensor_slice_util.h1
-rw-r--r--tensorflow/core/util/tensor_slice_reader_cache.cc8
-rw-r--r--tensorflow/examples/android/BUILD18
-rwxr-xr-xtensorflow/examples/android/jni/libpthread.sobin14096 -> 0 bytes
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java18
-rw-r--r--tensorflow/examples/how_tos/reading_data/BUILD68
-rw-r--r--tensorflow/examples/how_tos/reading_data/__init__.py (renamed from tensorflow/g3doc/how_tos/reading_data/__init__.py)0
-rw-r--r--tensorflow/examples/how_tos/reading_data/convert_to_records.py (renamed from tensorflow/g3doc/how_tos/reading_data/convert_to_records.py)4
-rw-r--r--tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py (renamed from tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py)17
-rw-r--r--tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py (renamed from tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py)16
-rw-r--r--tensorflow/examples/how_tos/reading_data/fully_connected_reader.py (renamed from tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py)5
-rw-r--r--tensorflow/examples/tutorials/__init__.py (renamed from tensorflow/g3doc/tutorials/mnist/__init__.py)0
-rw-r--r--tensorflow/examples/tutorials/mnist/BUILD115
-rw-r--r--tensorflow/examples/tutorials/mnist/__init__.py22
-rw-r--r--tensorflow/examples/tutorials/mnist/fully_connected_feed.py (renamed from tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py)14
-rw-r--r--tensorflow/examples/tutorials/mnist/input_data.py (renamed from tensorflow/g3doc/tutorials/mnist/input_data.py)0
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist.py (renamed from tensorflow/g3doc/tutorials/mnist/mnist.py)16
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist_softmax.py (renamed from tensorflow/g3doc/tutorials/mnist/mnist_softmax.py)16
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist_with_summaries.py (renamed from tensorflow/g3doc/tutorials/mnist/mnist_with_summaries.py)20
-rw-r--r--tensorflow/examples/tutorials/word2vec/BUILD30
-rw-r--r--tensorflow/examples/tutorials/word2vec/__init__.py (renamed from tensorflow/g3doc/tutorials/word2vec/__init__.py)0
-rw-r--r--tensorflow/examples/tutorials/word2vec/word2vec_basic.py (renamed from tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py)0
-rw-r--r--tensorflow/g3doc/how_tos/reading_data/index.md10
-rw-r--r--tensorflow/g3doc/resources/faq.md5
-rw-r--r--tensorflow/g3doc/tutorials/mnist/beginners/index.md8
-rw-r--r--tensorflow/g3doc/tutorials/mnist/download/index.md4
-rw-r--r--tensorflow/g3doc/tutorials/mnist/pros/index.md5
-rw-r--r--tensorflow/g3doc/tutorials/mnist/tf/index.md6
-rw-r--r--tensorflow/g3doc/tutorials/word2vec/index.md6
-rw-r--r--tensorflow/python/framework/dtypes.py17
-rw-r--r--tensorflow/python/framework/dtypes_test.py12
-rw-r--r--tensorflow/python/kernel_tests/reshape_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py110
-rw-r--r--tensorflow/python/ops/array_ops.py18
-rw-r--r--tensorflow/python/ops/clip_ops.py7
-rw-r--r--tensorflow/python/ops/gradients.py2
-rw-r--r--tensorflow/python/ops/rnn.py18
-rw-r--r--tensorflow/python/ops/rnn_cell.py73
-rw-r--r--tensorflow/python/summary/event_accumulator.py5
-rw-r--r--tensorflow/python/summary/event_accumulator_test.py10
-rw-r--r--tensorflow/tensorboard/CHANGES4
-rw-r--r--tensorflow/tensorboard/TAG2
-rw-r--r--tensorflow/tensorboard/tensorboard.py7
-rw-r--r--tensorflow/tools/pip_package/BUILD1
54 files changed, 652 insertions, 216 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index fd5b2d5927..0f2a50a76f 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
-#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -114,7 +113,10 @@ DirectSession::DirectSession(const SessionOptions& options,
cancellation_manager_(new CancellationManager()) {
static bool init = InitModule(options);
CHECK(init); // Avoids compiler warning that init is unused.
- session_handle_ = strings::FpToString(random::New64());
+ // NOTE(mrry): We do not need to use a unique string for the session
+ // handle, because DirectSession owns its devices. This may change
+ // in future versions.
+ session_handle_ = "direct";
int devices_added = 0;
if (options.config.log_device_placement()) {
const string mapping_str = device_mgr_->DeviceMappingString();
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 71dc968312..3a7de172a6 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -41,7 +41,7 @@ void ResourceMgr::Clear() {
containers_.clear();
}
-Status ResourceMgr::DoCreate(const string& container, std::type_index type,
+Status ResourceMgr::DoCreate(const string& container, ResourceMgrTypeIndex type,
const string& name, ResourceBase* resource) {
{
mutex_lock l(mu_);
@@ -58,7 +58,7 @@ Status ResourceMgr::DoCreate(const string& container, std::type_index type,
type.name());
}
-Status ResourceMgr::DoLookup(const string& container, std::type_index type,
+Status ResourceMgr::DoLookup(const string& container, ResourceMgrTypeIndex type,
const string& name,
ResourceBase** resource) const {
mutex_lock l(mu_);
@@ -76,7 +76,7 @@ Status ResourceMgr::DoLookup(const string& container, std::type_index type,
return Status::OK();
}
-Status ResourceMgr::DoDelete(const string& container, std::type_index type,
+Status ResourceMgr::DoDelete(const string& container, ResourceMgrTypeIndex type,
const string& name) {
ResourceBase* base = nullptr;
{
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index 8ce6cb604e..e931608d8a 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -71,6 +71,65 @@ class ResourceBase : public core::RefCounted {
virtual string DebugString() = 0;
};
+// On Android, we would like to avoid using RTTI for smaller binary sizes. The
+// following #ifdef section provides a non-functional replacement for
+// std::type_index (with a minimal set of functions needed by ResourceMgr).
+#ifdef __ANDROID__
+
+// A thin TypeIndex class that mimics std::type_index but does not use RTTI. As
+// a result, it does not provide the actual name of the type, and only returns a
+// pre-baked string specifying that RTTI is disabled.
+// The hash code provided in this class is unique for each class. However, it is
+// generated at runtime so this hash code should not be serialized - the value
+// for the same type can change from different runs.
+class ResourceMgrTypeIndex {
+ public:
+ ResourceMgrTypeIndex(const ResourceMgrTypeIndex& src) : hash_(src.hash_) {}
+ ResourceMgrTypeIndex& operator=(const ResourceMgrTypeIndex& src) {
+ hash_ = src.hash_;
+ return *this;
+ }
+ bool operator==(const ResourceMgrTypeIndex& rhs) const {
+ return (hash_ == rhs.hash_);
+ }
+ bool operator!=(const ResourceMgrTypeIndex& rhs) const {
+ return (hash_ != rhs.hash_);
+ }
+ ~ResourceMgrTypeIndex() {}
+
+ string name() const { return "[RTTI disabled for Android]"; }
+ uint64 hash_code() const { return hash_; }
+
+ // Returns a ResourceMgrTypeIndex object that corresponds to a typename.
+ template <typename T>
+ static ResourceMgrTypeIndex Make() {
+ static bool hash_bit[1];
+ return ResourceMgrTypeIndex(
+ static_cast<uint64>(reinterpret_cast<intptr_t>(hash_bit)));
+ }
+
+ private:
+ // We hide the constructor to be private. One needs to create the templated
+ // Make<T>() function to create a ResourceMgrTypeIndex object.
+ ResourceMgrTypeIndex(const uint64 hash) : hash_(hash) {}
+ uint64 hash_;
+};
+
+template <typename T>
+inline ResourceMgrTypeIndex GetResourceMgrTypeIndex() {
+ return ResourceMgrTypeIndex::Make<T>();
+}
+
+#else // __ANDROID__
+
+typedef std::type_index ResourceMgrTypeIndex;
+template <typename T>
+inline ResourceMgrTypeIndex GetResourceMgrTypeIndex() {
+ return ResourceMgrTypeIndex(typeid(T));
+}
+
+#endif // __ANDROID__
+
class ResourceMgr {
public:
ResourceMgr();
@@ -122,7 +181,7 @@ class ResourceMgr {
void Clear();
private:
- typedef std::pair<std::type_index, string> Key;
+ typedef std::pair<ResourceMgrTypeIndex, string> Key;
struct KeyHash {
std::size_t operator()(const Key& k) const {
return Hash64(k.second.data(), k.second.size(), k.first.hash_code());
@@ -139,13 +198,13 @@ class ResourceMgr {
mutable mutex mu_;
std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
- Status DoCreate(const string& container, std::type_index type,
+ Status DoCreate(const string& container, ResourceMgrTypeIndex type,
const string& name,
ResourceBase* resource) TF_MUST_USE_RESULT;
- Status DoLookup(const string& container, std::type_index type,
+ Status DoLookup(const string& container, ResourceMgrTypeIndex type,
const string& name,
ResourceBase** resource) const TF_MUST_USE_RESULT;
- Status DoDelete(const string& container, std::type_index type,
+ Status DoDelete(const string& container, ResourceMgrTypeIndex type,
const string& name) TF_MUST_USE_RESULT;
TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
@@ -223,7 +282,7 @@ Status ResourceMgr::Create(const string& container, const string& name,
T* resource) {
CheckDeriveFromResourceBase<T>();
CHECK(resource != nullptr);
- return DoCreate(container, std::type_index(typeid(T)), name, resource);
+ return DoCreate(container, GetResourceMgrTypeIndex<T>(), name, resource);
}
template <typename T>
@@ -231,7 +290,7 @@ Status ResourceMgr::Lookup(const string& container, const string& name,
T** resource) const {
CheckDeriveFromResourceBase<T>();
ResourceBase* found = nullptr;
- Status s = DoLookup(container, std::type_index(typeid(T)), name, &found);
+ Status s = DoLookup(container, GetResourceMgrTypeIndex<T>(), name, &found);
if (s.ok()) {
// It's safe to down cast 'found' to T* since
// typeid(T).hash_code() is part of the map key.
@@ -265,7 +324,7 @@ Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
template <typename T>
Status ResourceMgr::Delete(const string& container, const string& name) {
CheckDeriveFromResourceBase<T>();
- return DoDelete(container, std::type_index(typeid(T)), name);
+ return DoDelete(container, GetResourceMgrTypeIndex<T>(), name);
}
template <typename T>
diff --git a/tensorflow/core/kernels/io.cc b/tensorflow/core/kernels/io.cc
index 2af6bee079..d7443b8239 100644
--- a/tensorflow/core/kernels/io.cc
+++ b/tensorflow/core/kernels/io.cc
@@ -148,6 +148,7 @@ void SaveTensors(
break
switch (input.dtype()) {
+ WRITER_ADD(DT_BOOL);
WRITER_ADD(DT_FLOAT);
WRITER_ADD(DT_DOUBLE);
WRITER_ADD(DT_INT32);
@@ -269,6 +270,7 @@ void RestoreTensor(OpKernelContext* context,
break
switch (type) {
+ READER_COPY(DT_BOOL);
READER_COPY(DT_FLOAT);
READER_COPY(DT_DOUBLE);
READER_COPY(DT_INT32);
diff --git a/tensorflow/core/kernels/matrix_inverse_op.cc b/tensorflow/core/kernels/matrix_inverse_op.cc
index d7c8149ceb..345d2ec250 100644
--- a/tensorflow/core/kernels/matrix_inverse_op.cc
+++ b/tensorflow/core/kernels/matrix_inverse_op.cc
@@ -51,11 +51,10 @@ class MatrixInverseOp
}
}
- using typename UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::Matrix;
- using
- typename UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
- using typename UnaryLinearAlgebraOp<Scalar,
- SupportsBatchOperationT>::ConstMatrixMap;
+ typedef UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> Base;
+ using Matrix = typename Base::Matrix;
+ using MatrixMap = typename Base::MatrixMap;
+ using ConstMatrixMap = typename Base::ConstMatrixMap;
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& input,
MatrixMap* output) override {
diff --git a/tensorflow/core/kernels/save_op_test.cc b/tensorflow/core/kernels/save_op_test.cc
index f05274129a..c09c8cfd55 100644
--- a/tensorflow/core/kernels/save_op_test.cc
+++ b/tensorflow/core/kernels/save_op_test.cc
@@ -44,8 +44,8 @@ class SaveOpTest : public OpsTestBase {
ASSERT_OK(NodeDefBuilder("myop", "Save")
.Input(FakeInput())
.Input(FakeInput())
- .Input(FakeInput(
- {DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8, DT_QINT32}))
+ .Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE,
+ DT_QINT8, DT_QINT32}))
.Finalize(node_def()));
ASSERT_OK(InitOp());
}
@@ -53,7 +53,8 @@ class SaveOpTest : public OpsTestBase {
TEST_F(SaveOpTest, Simple) {
const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple");
- const string tensornames[] = {"tensor_int", "tensor_float", "tensor_double",
+ const string tensornames[] = {"tensor_bool", "tensor_int",
+ "tensor_float", "tensor_double",
"tensor_qint8", "tensor_qint32"};
MakeOp();
@@ -62,9 +63,12 @@ TEST_F(SaveOpTest, Simple) {
[&filename](int x) -> string { return filename; });
// Add the tensor names
- AddInput<string>(TensorShape({5}),
+ AddInput<string>(TensorShape({6}),
[&tensornames](int x) -> string { return tensornames[x]; });
+ // Add a 1-d bool tensor
+ AddInput<bool>(TensorShape({2}), [](int x) -> bool { return x != 0; });
+
// Add a 1-d integer tensor
AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; });
@@ -94,6 +98,25 @@ TEST_F(SaveOpTest, Simple) {
// We expect to find all saved tensors
{
+ // The 1-d bool tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_bool", &shape, &type));
+ TensorShape expected({2});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_BOOL, type);
+
+ // We expect the tensor value to be correct.
+ TensorSlice s = TensorSlice::ParseOrDie("-");
+ bool data[2];
+ std::fill_n(data, 2, false);
+ EXPECT_TRUE(reader.CopySliceData("tensor_bool", s, data));
+ for (int i = 0; i < 2; ++i) {
+ EXPECT_EQ((i != 0), data[i]);
+ }
+ }
+
+ {
// The 1-d integer tensor
TensorShape shape;
DataType type;
diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc
index 6968169248..d4c6ecb6c7 100644
--- a/tensorflow/core/kernels/tile_ops.cc
+++ b/tensorflow/core/kernels/tile_ops.cc
@@ -132,9 +132,11 @@ template <DataType DT, int NDIM>
inline void TileOp<Device>::HandleCase(
OpKernelContext* context, const gtl::ArraySlice<int32>& multiples_array,
Tensor* result) {
+ // TODO(vrv): print out the device name if useful. Currently disabled to avoid
+ // having to use RTTI.
LOG(FATAL) << "TileOp: Invalid combination of Device, DT and NDIM: "
- << typeid(Device).name() << ", " << DataTypeString(DT) << ", "
- << NDIM;
+ // << typeid(Device).name() << ", "
+ << DataTypeString(DT) << ", " << NDIM;
}
#define HANDLE_CASE(device, dtype, ndim) \
@@ -353,7 +355,11 @@ inline void TileGradientOp<Device>::HandleCase(
OpKernelContext* context, const std::vector<int32>& input_dims,
const gtl::ArraySlice<int32>& multiples_array, Tensor* result) {
LOG(FATAL) << "TileGradientOp: Invalid combination of Device, DT and NDIM: "
+#ifdef __ANDROID__
+ << "[Device not shown, no RTTI], " << DataTypeString(DT) << ", "
+#else // __ANDROID__
<< typeid(Device).name() << ", " << DataTypeString(DT) << ", "
+#endif // __ANDROID__
<< NDIM;
}
diff --git a/tensorflow/core/lib/random/random.cc b/tensorflow/core/lib/random/random.cc
index f0cecfd03c..8bb7730cfc 100644
--- a/tensorflow/core/lib/random/random.cc
+++ b/tensorflow/core/lib/random/random.cc
@@ -22,7 +22,7 @@ namespace tensorflow {
namespace random {
std::mt19937_64* InitRng() {
- std::random_device device("/dev/random");
+ std::random_device device("/dev/urandom");
return new std::mt19937_64(device());
}
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc
index 62b06d3aa1..84f06e3ef7 100644
--- a/tensorflow/core/ops/io_ops.cc
+++ b/tensorflow/core/ops/io_ops.cc
@@ -22,7 +22,7 @@ REGISTER_OP("Save")
.Input("filename: string")
.Input("tensor_names: string")
.Input("data: T")
- .Attr("T: list({float, double, int32, int64, quint8, qint8, qint32})")
+ .Attr("T: list({bool, float, double, int32, int64, quint8, qint8, qint32})")
.Doc(R"doc(
Saves the input tensors to disk.
@@ -42,7 +42,7 @@ REGISTER_OP("SaveSlices")
.Input("tensor_names: string")
.Input("shapes_and_slices: string")
.Input("data: T")
- .Attr("T: list({float, double, int32, int64, quint8, qint8, qint32})")
+ .Attr("T: list({bool, float, double, int32, int64, quint8, qint8, qint32})")
.Doc(R"doc(
Saves input tensors slices to disk.
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index aa8faa5144..98e435c157 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -6163,6 +6163,7 @@ op {
minimum: 1
allowed_values {
list {
+ type: DT_BOOL
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -6205,6 +6206,7 @@ op {
minimum: 1
allowed_values {
list {
+ type: DT_BOOL
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h
index 587d99cc07..0a7d67ee20 100644
--- a/tensorflow/core/util/saved_tensor_slice_util.h
+++ b/tensorflow/core/util/saved_tensor_slice_util.h
@@ -90,6 +90,7 @@ void Fill(T* data, size_t n, TensorProto* t);
t->mutable_##FIELD##_val()->Swap(&copy); \
}
+TENSOR_PROTO_EXTRACT_TYPE(bool, bool, bool);
TENSOR_PROTO_EXTRACT_TYPE(float, float, float);
TENSOR_PROTO_EXTRACT_TYPE(double, double, double);
TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32);
diff --git a/tensorflow/core/util/tensor_slice_reader_cache.cc b/tensorflow/core/util/tensor_slice_reader_cache.cc
index 0de576fd28..a200a1834c 100644
--- a/tensorflow/core/util/tensor_slice_reader_cache.cc
+++ b/tensorflow/core/util/tensor_slice_reader_cache.cc
@@ -52,9 +52,17 @@ const TensorSliceReader* TensorSliceReaderCache::GetReader(
TensorSliceReader::OpenTableFunction open_function, int preferred_shard) {
mutex_lock l(mu_);
+#ifdef ANDROID
+ // On Android, we have RTTI disabled so we will hard-code func_ptr to be zero,
+ // since we cannot figure out the target type for open_function.
+ // TODO(jiayq): find a more elegant way to possibly enable cache again.
+ TensorSliceReaderCache::OpenFuncType* func_ptr = nullptr;
+#else // ANDROID
// Get the function pointer from the open_function value.
TensorSliceReaderCache::OpenFuncType* func_ptr =
open_function.target<TensorSliceReaderCache::OpenFuncType>();
+#endif
+
if (!func_ptr) {
// We could not get the pointer, no caching is possible.
LOG(WARNING) << "Caching disabled because the open function is a lambda.";
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index 042e5e6e14..8a9d525b5a 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -11,7 +11,7 @@ exports_files(["LICENSE"])
cc_library(
name = "tensorflow_native_libs",
- srcs = glob(["jni/**/*.cc"]),
+ srcs = glob(["jni/**/*.cc"]) + [":libpthread.so"],
hdrs = glob(["jni/**/*.h"]),
copts = [
"-std=c++11",
@@ -23,18 +23,20 @@ cc_library(
"manual",
"notap",
],
- deps = [
- ":dummy_pthread",
- "//tensorflow/core:android_tensorflow_lib",
- ],
+ deps = ["//tensorflow/core:android_tensorflow_lib"],
)
# This library only exists as a workaround to satisfy dependencies
# that declare -lpthread in their linkopts. Although Android supports
# pthreads, it does not provide it as a separate library.
-cc_library(
- name = "dummy_pthread",
- srcs = ["jni/libpthread.so"],
+cc_binary(
+ name = "libpthread.so",
+ srcs = [],
+ linkopts = ["-shared"],
+ tags = [
+ "manual",
+ "notap",
+ ],
)
android_binary(
diff --git a/tensorflow/examples/android/jni/libpthread.so b/tensorflow/examples/android/jni/libpthread.so
deleted file mode 100755
index 7992d0de4c..0000000000
--- a/tensorflow/examples/android/jni/libpthread.so
+++ /dev/null
Binary files differ
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java b/tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java
index 9c53021636..e09a3d403d 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java
@@ -180,9 +180,9 @@ public class CameraConnectionFragment extends Fragment {
private Handler backgroundHandler;
/**
- * An {@link ImageReader} that handles still image capture.
+ * An {@link ImageReader} that handles preview frame capture.
*/
- private ImageReader imageReader;
+ private ImageReader previewReader;
/**
* {@link android.hardware.camera2.CaptureRequest.Builder} for the camera preview
@@ -328,10 +328,6 @@ public class CameraConnectionFragment extends Fragment {
Arrays.asList(map.getOutputSizes(ImageFormat.YUV_420_888)),
new CompareSizesByArea());
- imageReader =
- ImageReader.newInstance(
- largest.getWidth(), largest.getHeight(), ImageFormat.YUV_420_888, /*maxImages*/ 2);
-
// Danger, W.R.! Attempting to use too large a preview size could exceed the camera
// bus' bandwidth limitation, resulting in gorgeous previews but the storage of
// garbage capture data.
@@ -393,9 +389,9 @@ public class CameraConnectionFragment extends Fragment {
cameraDevice.close();
cameraDevice = null;
}
- if (null != imageReader) {
- imageReader.close();
- imageReader = null;
+ if (null != previewReader) {
+ previewReader.close();
+ previewReader = null;
}
} catch (final InterruptedException e) {
throw new RuntimeException("Interrupted while trying to lock camera closing.", e);
@@ -465,7 +461,7 @@ public class CameraConnectionFragment extends Fragment {
LOGGER.i("Opening camera preview: " + previewSize.getWidth() + "x" + previewSize.getHeight());
// Create the reader for the preview frames.
- final ImageReader previewReader =
+ previewReader =
ImageReader.newInstance(
previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2);
@@ -474,7 +470,7 @@ public class CameraConnectionFragment extends Fragment {
// Here, we create a CameraCaptureSession for camera preview.
cameraDevice.createCaptureSession(
- Arrays.asList(surface, imageReader.getSurface(), previewReader.getSurface()),
+ Arrays.asList(surface, previewReader.getSurface()),
new CameraCaptureSession.StateCallback() {
@Override
diff --git a/tensorflow/examples/how_tos/reading_data/BUILD b/tensorflow/examples/how_tos/reading_data/BUILD
new file mode 100644
index 0000000000..c1e773d905
--- /dev/null
+++ b/tensorflow/examples/how_tos/reading_data/BUILD
@@ -0,0 +1,68 @@
+# Description:
+# Example MNIST TensorFlow models for demonstrating data reading.
+
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+ name = "convert_to_records",
+ srcs = ["convert_to_records.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/examples/tutorials/mnist:input_data",
+ ],
+)
+
+py_binary(
+ name = "fully_connected_reader",
+ srcs = [
+ "fully_connected_reader.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/examples/tutorials/mnist",
+ ],
+)
+
+py_binary(
+ name = "fully_connected_preloaded",
+ srcs = [
+ "fully_connected_preloaded.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/examples/tutorials/mnist",
+ "//tensorflow/examples/tutorials/mnist:input_data",
+ ],
+)
+
+py_binary(
+ name = "fully_connected_preloaded_var",
+ srcs = [
+ "fully_connected_preloaded_var.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/examples/tutorials/mnist",
+ "//tensorflow/examples/tutorials/mnist:input_data",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/g3doc/how_tos/reading_data/__init__.py b/tensorflow/examples/how_tos/reading_data/__init__.py
index e69de29bb2..e69de29bb2 100644
--- a/tensorflow/g3doc/how_tos/reading_data/__init__.py
+++ b/tensorflow/examples/how_tos/reading_data/__init__.py
diff --git a/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py b/tensorflow/examples/how_tos/reading_data/convert_to_records.py
index ce3b016798..30b5a384a8 100644
--- a/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py
+++ b/tensorflow/examples/how_tos/reading_data/convert_to_records.py
@@ -23,7 +23,7 @@ import tensorflow.python.platform
import numpy
import tensorflow as tf
-from tensorflow.g3doc.tutorials.mnist import input_data
+from tensorflow.examples.tutorials.mnist import input_data
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' # MNIST filenames
@@ -32,7 +32,7 @@ TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
-tf.app.flags.DEFINE_string('directory', 'data',
+tf.app.flags.DEFINE_string('directory', '/tmp/data',
'Directory to download data files and write the '
'converted result')
tf.app.flags.DEFINE_integer('validation_size', 5000,
diff --git a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py
index 79f945a845..39ce1a759b 100644
--- a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py
@@ -15,10 +15,16 @@
"""Trains the MNIST network using preloaded data in a constant.
-Command to run this py_binary target:
+Run using bazel:
bazel run -c opt \
- <...>/tensorflow/g3doc/how_tos/reading_data:fully_connected_preloaded
+ <...>/tensorflow/examples/how_tos/reading_data:fully_connected_preloaded
+
+or, if installed via pip:
+
+cd tensorflow/examples/how_tos/reading_data
+python fully_connected_preloaded.py
+
"""
from __future__ import absolute_import
from __future__ import division
@@ -31,8 +37,8 @@ import tensorflow.python.platform
import numpy
import tensorflow as tf
-from tensorflow.g3doc.tutorials.mnist import input_data
-from tensorflow.g3doc.tutorials.mnist import mnist
+from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.examples.tutorials.mnist import mnist
# Basic model parameters as external flags.
@@ -44,7 +50,8 @@ flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 100, 'Batch size. '
'Must divide evenly into the dataset sizes.')
-flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
+flags.DEFINE_string('train_dir', '/tmp/data',
+ 'Directory to put the training data.')
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
'for unit testing.')
diff --git a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
index 68c8fce7dd..9a7e4e8e81 100644
--- a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
@@ -15,10 +15,15 @@
"""Trains the MNIST network using preloaded data stored in a variable.
-Command to run this py_binary target:
+Run using bazel:
bazel run -c opt \
- <...>/tensorflow/g3doc/how_tos/reading_data:fully_connected_preloaded_var
+ <...>/tensorflow/examples/how_tos/reading_data:fully_connected_preloaded_var
+
+or, if installed via pip:
+
+cd tensorflow/examples/how_tos/reading_data
+python fully_connected_preloaded_var.py
"""
from __future__ import absolute_import
from __future__ import division
@@ -31,8 +36,8 @@ import tensorflow.python.platform
import numpy
import tensorflow as tf
-from tensorflow.g3doc.tutorials.mnist import input_data
-from tensorflow.g3doc.tutorials.mnist import mnist
+from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.examples.tutorials.mnist import mnist
# Basic model parameters as external flags.
@@ -44,7 +49,8 @@ flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 100, 'Batch size. '
'Must divide evenly into the dataset sizes.')
-flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
+flags.DEFINE_string('train_dir', '/tmp/data',
+ 'Directory to put the training data.')
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
'for unit testing.')
diff --git a/tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
index cebedd5f3c..bf1ef08c60 100644
--- a/tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
@@ -34,7 +34,7 @@ import tensorflow.python.platform
import numpy
import tensorflow as tf
-from tensorflow.g3doc.tutorials.mnist import mnist
+from tensorflow.examples.tutorials.mnist import mnist
# Basic model parameters as external flags.
@@ -45,7 +45,8 @@ flags.DEFINE_integer('num_epochs', 2, 'Number of epochs to run trainer.')
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 100, 'Batch size.')
-flags.DEFINE_string('train_dir', 'data', 'Directory with the training data.')
+flags.DEFINE_string('train_dir', '/tmp/data',
+ 'Directory with the training data.')
# Constants used for dealing with the files, matches convert_to_records.
TRAIN_FILE = 'train.tfrecords'
diff --git a/tensorflow/g3doc/tutorials/mnist/__init__.py b/tensorflow/examples/tutorials/__init__.py
index e69de29bb2..e69de29bb2 100644
--- a/tensorflow/g3doc/tutorials/mnist/__init__.py
+++ b/tensorflow/examples/tutorials/__init__.py
diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD
new file mode 100644
index 0000000000..4fc90730d5
--- /dev/null
+++ b/tensorflow/examples/tutorials/mnist/BUILD
@@ -0,0 +1,115 @@
+# Description:
+# Example TensorFlow models for MNIST used in tutorials
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "package",
+ srcs = [
+ "__init__.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":input_data",
+ ":mnist",
+ ],
+)
+
+py_library(
+ name = "input_data",
+ srcs = ["input_data.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = ["//tensorflow:tensorflow_py"],
+)
+
+py_library(
+ name = "mnist",
+ srcs = [
+ "mnist.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "fully_connected_feed",
+ srcs = [
+ "fully_connected_feed.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":input_data",
+ ":mnist",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "mnist_with_summaries",
+ srcs = [
+ "mnist_with_summaries.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":input_data",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "fully_connected_feed_test",
+ size = "small",
+ srcs = [
+ "fully_connected_feed.py",
+ ],
+ args = [
+ "--fake_data",
+ "--max_steps=10",
+ "--train_dir=/tmp/mnist",
+ ],
+ main = "fully_connected_feed.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":input_data",
+ ":mnist",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "mnist_with_summaries_test",
+ size = "small",
+ srcs = [
+ "mnist_with_summaries.py",
+ ],
+ args = [
+ "--fake_data",
+ "--max_steps=10",
+ "--learning_rate=0.00",
+ ],
+ main = "mnist_with_summaries.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":input_data",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/examples/tutorials/mnist/__init__.py b/tensorflow/examples/tutorials/mnist/__init__.py
new file mode 100644
index 0000000000..026a1fcf64
--- /dev/null
+++ b/tensorflow/examples/tutorials/mnist/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Imports mnist tutorial libraries used by tutorial examples."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.examples.tutorials.mnist import mnist
diff --git a/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
index 0242c81563..3f5beff6c5 100644
--- a/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py
+++ b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
@@ -13,15 +13,7 @@
# limitations under the License.
# ==============================================================================
-"""Trains and Evaluates the MNIST network using a feed dictionary.
-
-TensorFlow install instructions:
-https://tensorflow.org/get_started/os_setup.html
-
-MNIST tutorial:
-https://tensorflow.org/tutorials/mnist/tf/index.html
-
-"""
+"""Trains and Evaluates the MNIST network using a feed dictionary."""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
@@ -35,8 +27,8 @@ import numpy
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
-from tensorflow.g3doc.tutorials.mnist import input_data
-from tensorflow.g3doc.tutorials.mnist import mnist
+from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.examples.tutorials.mnist import mnist
# Basic model parameters as external flags.
diff --git a/tensorflow/g3doc/tutorials/mnist/input_data.py b/tensorflow/examples/tutorials/mnist/input_data.py
index ae3727c82e..ae3727c82e 100644
--- a/tensorflow/g3doc/tutorials/mnist/input_data.py
+++ b/tensorflow/examples/tutorials/mnist/input_data.py
diff --git a/tensorflow/g3doc/tutorials/mnist/mnist.py b/tensorflow/examples/tutorials/mnist/mnist.py
index 4cb53b7deb..096344e0b4 100644
--- a/tensorflow/g3doc/tutorials/mnist/mnist.py
+++ b/tensorflow/examples/tutorials/mnist/mnist.py
@@ -25,12 +25,6 @@ apply gradients.
This file is used by the various "fully_connected_*.py" files and not meant to
be run.
-
-TensorFlow install instructions:
-https://tensorflow.org/get_started/os_setup.html
-
-MNIST tutorial:
-https://tensorflow.org/tutorials/mnist/tf/index.html
"""
from __future__ import absolute_import
from __future__ import division
@@ -54,14 +48,14 @@ def inference(images, hidden1_units, hidden2_units):
Args:
images: Images placeholder, from inputs().
- hidden1: Size of the first hidden layer.
- hidden2: Size of the second hidden layer.
+ hidden1_units: Size of the first hidden layer.
+ hidden2_units: Size of the second hidden layer.
Returns:
softmax_linear: Output tensor with the computed logits.
"""
# Hidden 1
- with tf.name_scope('hidden1') as scope:
+ with tf.name_scope('hidden1'):
weights = tf.Variable(
tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
@@ -70,7 +64,7 @@ def inference(images, hidden1_units, hidden2_units):
name='biases')
hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
# Hidden 2
- with tf.name_scope('hidden2') as scope:
+ with tf.name_scope('hidden2'):
weights = tf.Variable(
tf.truncated_normal([hidden1_units, hidden2_units],
stddev=1.0 / math.sqrt(float(hidden1_units))),
@@ -79,7 +73,7 @@ def inference(images, hidden1_units, hidden2_units):
name='biases')
hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
# Linear
- with tf.name_scope('softmax_linear') as scope:
+ with tf.name_scope('softmax_linear'):
weights = tf.Variable(
tf.truncated_normal([hidden2_units, NUM_CLASSES],
stddev=1.0 / math.sqrt(float(hidden2_units))),
diff --git a/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py b/tensorflow/examples/tutorials/mnist/mnist_softmax.py
index 050d660038..aace92f092 100644
--- a/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_softmax.py
@@ -23,21 +23,23 @@ from __future__ import division
from __future__ import print_function
# Import data
-import input_data
-mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
+from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
+
+mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
+
sess = tf.InteractiveSession()
# Create the model
x = tf.placeholder("float", [None, 784])
-W = tf.Variable(tf.zeros([784,10]))
+W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
-y = tf.nn.softmax(tf.matmul(x,W) + b)
+y = tf.nn.softmax(tf.matmul(x, W) + b)
# Define loss and optimizer
-y_ = tf.placeholder("float", [None,10])
-cross_entropy = -tf.reduce_sum(y_*tf.log(y))
+y_ = tf.placeholder("float", [None, 10])
+cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
# Train
@@ -47,6 +49,6 @@ for i in range(1000):
train_step.run({x: batch_xs, y_: batch_ys})
# Test trained model
-correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
+correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
diff --git a/tensorflow/g3doc/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
index 95373404d6..9d2d624d0c 100644
--- a/tensorflow/g3doc/tutorials/mnist/mnist_with_summaries.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
@@ -30,7 +30,7 @@ from __future__ import division
from __future__ import print_function
import tensorflow.python.platform
-from tensorflow.g3doc.tutorials.mnist import input_data
+from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
flags = tf.app.flags
@@ -54,28 +54,28 @@ def main(_):
b = tf.Variable(tf.zeros([10], name='bias'))
# use a name scope to organize nodes in the graph visualizer
- with tf.name_scope('Wx_b') as scope:
+ with tf.name_scope('Wx_b'):
y = tf.nn.softmax(tf.matmul(x, W) + b)
# Add summary ops to collect data
- w_hist = tf.histogram_summary('weights', W)
- b_hist = tf.histogram_summary('biases', b)
- y_hist = tf.histogram_summary('y', y)
+ _ = tf.histogram_summary('weights', W)
+ _ = tf.histogram_summary('biases', b)
+ _ = tf.histogram_summary('y', y)
# Define loss and optimizer
y_ = tf.placeholder('float', [None, 10], name='y-input')
# More name scopes will clean up the graph representation
- with tf.name_scope('xent') as scope:
+ with tf.name_scope('xent'):
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
- ce_summ = tf.scalar_summary('cross entropy', cross_entropy)
- with tf.name_scope('train') as scope:
+ _ = tf.scalar_summary('cross entropy', cross_entropy)
+ with tf.name_scope('train'):
train_step = tf.train.GradientDescentOptimizer(
FLAGS.learning_rate).minimize(cross_entropy)
- with tf.name_scope('test') as scope:
+ with tf.name_scope('test'):
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
- accuracy_summary = tf.scalar_summary('accuracy', accuracy)
+ _ = tf.scalar_summary('accuracy', accuracy)
# Merge all the summaries and write them out to /tmp/mnist_logs
merged = tf.merge_all_summaries()
diff --git a/tensorflow/examples/tutorials/word2vec/BUILD b/tensorflow/examples/tutorials/word2vec/BUILD
new file mode 100644
index 0000000000..03e8c845a5
--- /dev/null
+++ b/tensorflow/examples/tutorials/word2vec/BUILD
@@ -0,0 +1,30 @@
+# Description:
+# TensorFlow model for word2vec
+
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+ name = "word2vec_basic",
+ srcs = [
+ "word2vec_basic.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+)
diff --git a/tensorflow/g3doc/tutorials/word2vec/__init__.py b/tensorflow/examples/tutorials/word2vec/__init__.py
index e69de29bb2..e69de29bb2 100644
--- a/tensorflow/g3doc/tutorials/word2vec/__init__.py
+++ b/tensorflow/examples/tutorials/word2vec/__init__.py
diff --git a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index e04e86a100..e04e86a100 100644
--- a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
diff --git a/tensorflow/g3doc/how_tos/reading_data/index.md b/tensorflow/g3doc/how_tos/reading_data/index.md
index e8ad141fea..7caff72c2e 100644
--- a/tensorflow/g3doc/how_tos/reading_data/index.md
+++ b/tensorflow/g3doc/how_tos/reading_data/index.md
@@ -35,7 +35,7 @@ it is executed without a feed, so you won't forget to feed it.
An example using `placeholder` and feeding to train on MNIST data can be found
in
-[`tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py),
+[`tensorflow/examples/tutorials/mnist/fully_connected_feed.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/mnist/fully_connected_feed.py),
and is described in the [MNIST tutorial](../../tutorials/mnist/tf/index.md).
## Reading from files
@@ -154,7 +154,7 @@ as a field). You write a little program that gets your data, stuffs it in an
writes the string to a TFRecords file using the
[`tf.python_io.TFRecordWriter` class](../../api_docs/python/python_io.md#TFRecordWriter).
For example,
-[`tensorflow/g3doc/how_tos/reading_data/convert_to_records.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py)
+[`tensorflow/examples/how_tos/reading_data/convert_to_records.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/how_tos/reading_data/convert_to_records.py)
converts MNIST data to this format.
To read a file of TFRecords, use
@@ -163,7 +163,7 @@ the [`tf.parse_single_example`](../../api_docs/python/io_ops.md#parse_single_exa
decoder. The `parse_single_example` op decodes the example protocol buffers into
tensors. An MNIST example using the data produced by `convert_to_records` can be
found in
-[`tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py),
+[`tensorflow/examples/how_tos/reading_data/fully_connected_reader.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py),
which you can compare with the `fully_connected_feed` version.
### Preprocessing
@@ -455,8 +455,8 @@ multiple preprocessing threads, set the `num_threads` parameter to a number
bigger than 1.
An MNIST example that preloads the data using constants can be found in
-[`tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py), and one that preloads the data using variables can be found in
-[`tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py),
+[`tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py), and one that preloads the data using variables can be found in
+[`tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py),
You can compare these with the `fully_connected_feed` and
`fully_connected_reader` versions above.
diff --git a/tensorflow/g3doc/resources/faq.md b/tensorflow/g3doc/resources/faq.md
index 171697c69e..0db42a0b8b 100644
--- a/tensorflow/g3doc/resources/faq.md
+++ b/tensorflow/g3doc/resources/faq.md
@@ -255,6 +255,11 @@ these summaries to a log directory. Then, start TensorBoard using
For more details, see the [Summaries and TensorBoard tutorial]
(../how_tos/summaries_and_tensorboard/index.md).
+#### Every time I launch TensorBoard, I get a network security popup!
+
+You can change TensorBoard to serve on localhost rather than '0.0.0.0' by
+the flag --host=localhost. This should quiet any security warnings.
+
## Extending TensorFlow
See also the how-to documentation for
diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/index.md b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
index 99453d065d..2d4d6c566c 100644
--- a/tensorflow/g3doc/tutorials/mnist/beginners/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
@@ -3,7 +3,9 @@
*This tutorial is intended for readers who are new to both machine learning and
TensorFlow. If you already
know what MNIST is, and what softmax (multinomial logistic) regression is,
-you might prefer this [faster paced tutorial](../pros/index.md).*
+you might prefer this [faster paced tutorial](../pros/index.md).
+Be sure to [install TensorFlow](../../../get_started/os_setup.md) before
+starting either tutorial.*
When one learns how to program, there's a tradition that the first thing you do
is print "Hello World." Just like programming has Hello World, machine learning
@@ -37,11 +39,11 @@ The MNIST data is hosted on
[Yann LeCun's website](http://yann.lecun.com/exdb/mnist/). For your
convenience, we've included some python code to download and install the data
automatically. You can either download
-[the code](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/input_data.py)
+[the code](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/mnist/input_data.py)
and import it as below, or simply copy and paste it in.
```python
-import input_data
+import tensorflow.examples.tutorials.mnist.input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
```
diff --git a/tensorflow/g3doc/tutorials/mnist/download/index.md b/tensorflow/g3doc/tutorials/mnist/download/index.md
index 3cb9528f34..dcd7dfc23d 100644
--- a/tensorflow/g3doc/tutorials/mnist/download/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/download/index.md
@@ -1,6 +1,6 @@
# MNIST Data Download
-Code: [tensorflow/g3doc/tutorials/mnist/](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/)
+Code: [tensorflow/examples/tutorials/mnist/](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/mnist/)
The goal of this tutorial is to show how to download the dataset files required
for handwritten digit classification using the (classic) MNIST data set.
@@ -11,7 +11,7 @@ This tutorial references the following files:
File | Purpose
--- | ---
-[`input_data.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/input_data.py) | The code to download the MNIST dataset for training and evaluation.
+[`input_data.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/mnist/input_data.py) | The code to download the MNIST dataset for training and evaluation.
## Prepare the Data
diff --git a/tensorflow/g3doc/tutorials/mnist/pros/index.md b/tensorflow/g3doc/tutorials/mnist/pros/index.md
index 4d8b5e84bd..866d4c8367 100644
--- a/tensorflow/g3doc/tutorials/mnist/pros/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/pros/index.md
@@ -9,7 +9,8 @@ while constructing a deep convolutional MNIST classifier.
*This introduction assumes familiarity with neural networks and the MNIST
dataset. If you don't have
a background with them, check out the
-[introduction for beginners](../beginners/index.md).*
+[introduction for beginners](../beginners/index.md). Be sure to
+[install TensorFlow](../../../get_started/os_setup.md) before starting.*
## Setup
@@ -19,7 +20,7 @@ TensorFlow session.
### Load MNIST Data
For your convenience, we've included
-[a script](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/input_data.py)
+[a script](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/mnist/input_data.py)
which automatically downloads and imports the MNIST dataset. It will create a
directory `'MNIST_data'` in which to store the data files.
diff --git a/tensorflow/g3doc/tutorials/mnist/tf/index.md b/tensorflow/g3doc/tutorials/mnist/tf/index.md
index 373b8968c5..94418481a6 100644
--- a/tensorflow/g3doc/tutorials/mnist/tf/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/tf/index.md
@@ -1,6 +1,6 @@
# TensorFlow Mechanics 101
-Code: [tensorflow/g3doc/tutorials/mnist/](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/)
+Code: [tensorflow/examples/tutorials/mnist/](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/mnist/)
The goal of this tutorial is to show how to use TensorFlow to train and
evaluate a simple feed-forward neural network for handwritten digit
@@ -18,8 +18,8 @@ This tutorial references the following files:
File | Purpose
--- | ---
-[`mnist.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/mnist.py) | The code to build a fully-connected MNIST model.
-[`fully_connected_feed.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py) | The main code to train the built MNIST model against the downloaded dataset using a feed dictionary.
+[`mnist.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/mnist/mnist.py) | The code to build a fully-connected MNIST model.
+[`fully_connected_feed.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/mnist/fully_connected_feed.py) | The main code to train the built MNIST model against the downloaded dataset using a feed dictionary.
Simply run the `fully_connected_feed.py` file directly to start training:
diff --git a/tensorflow/g3doc/tutorials/word2vec/index.md b/tensorflow/g3doc/tutorials/word2vec/index.md
index f026d8c5b5..1882c56265 100644
--- a/tensorflow/g3doc/tutorials/word2vec/index.md
+++ b/tensorflow/g3doc/tutorials/word2vec/index.md
@@ -19,7 +19,7 @@ represent words as vectors.
We walk through the code later during the tutorial, but if you'd prefer to dive
straight in, feel free to look at the minimalistic implementation in
-[tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py)
+[tensorflow/examples/tutorials/word2vec/word2vec_basic.py](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/word2vec/word2vec_basic.py)
This basic example contains the code needed to download some data, train on it a
bit and visualize the result. Once you get comfortable with reading and running
the basic version, you can graduate to
@@ -269,7 +269,7 @@ nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
Now that we have the parameters in place, we can define our skip-gram model
graph. For simplicity, let's suppose we've already integerized our text corpus
with a vocabulary so that each word is represented as an integer (see
-[tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py)
+[tensorflow/examples/tutorials/word2vec/word2vec_basic.py](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/word2vec/word2vec_basic.py)
for the details). The skip-gram model takes two inputs. One is a batch full of
integers representing the source context words, the other is for the target
words. Let's create placeholder nodes for these inputs, so that we can feed in
@@ -321,7 +321,7 @@ for inputs, labels in generate_batch(...):
```
See the full example code in
-[tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py](./word2vec_basic.py).
+[tensorflow/examples/tutorials/word2vec/word2vec_basic.py](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/word2vec/word2vec_basic.py).
## Visualizing the Learned Embeddings
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 48bf6cac00..52cd56efbf 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -59,8 +59,10 @@ class DType(object):
@@base_dtype
@@is_ref_dtype
@@as_ref
+ @@is_floating
@@is_integer
@@is_quantized
+ @@is_unsigned
@@as_numpy_dtype
@@as_datatype_enum
@@ -137,6 +139,21 @@ class DType(object):
return self.base_dtype in [qint8, quint8, qint32, bfloat16]
@property
+ def is_unsigned(self):
+ """Returns whether this type is unsigned.
+
+ Non-numeric, unordered, and quantized types are not considered unsigned, and
+ this function returns `False`.
+
+ Returns:
+ Whether a `DType` is unsigned.
+ """
+ try:
+ return self.min == 0
+ except TypeError:
+ return False
+
+ @property
def min(self):
"""Returns the minimum representable value in this data type.
diff --git a/tensorflow/python/framework/dtypes_test.py b/tensorflow/python/framework/dtypes_test.py
index 6a05208410..f82ba741dc 100644
--- a/tensorflow/python/framework/dtypes_test.py
+++ b/tensorflow/python/framework/dtypes_test.py
@@ -145,6 +145,18 @@ class TypesTest(test_util.TensorFlowTestCase):
self.assertEqual(tf.as_dtype("string").is_floating, False)
self.assertEqual(tf.as_dtype("bool").is_floating, False)
+ def testIsUnsigned(self):
+ self.assertEqual(tf.as_dtype("int8").is_unsigned, False)
+ self.assertEqual(tf.as_dtype("int16").is_unsigned, False)
+ self.assertEqual(tf.as_dtype("int32").is_unsigned, False)
+ self.assertEqual(tf.as_dtype("int64").is_unsigned, False)
+ self.assertEqual(tf.as_dtype("uint8").is_unsigned, True)
+ self.assertEqual(tf.as_dtype("float32").is_unsigned, False)
+ self.assertEqual(tf.as_dtype("float64").is_unsigned, False)
+ self.assertEqual(tf.as_dtype("bool").is_unsigned, False)
+ self.assertEqual(tf.as_dtype("string").is_unsigned, False)
+ self.assertEqual(tf.as_dtype("complex64").is_unsigned, False)
+
def testMinMax(self):
# make sure min/max evaluates for all data types that have min/max
for datatype_enum in types_pb2.DataType.values():
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
index f3fc9086d6..12eb7fcc2f 100644
--- a/tensorflow/python/kernel_tests/reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -106,6 +106,11 @@ class ReshapeTest(tf.test.TestCase):
with self.assertRaisesRegexp(ValueError, "isn't divisible by 17"):
tf.reshape(y, [17, -1])
+ z = tf.constant(0.0, shape=[32, 128])
+ with self.assertRaisesRegexp(ValueError,
+ "Cannot reshape a tensor with 4096 elements"):
+ tf.reshape(z, [4095])
+
def testPartialShapes(self):
x = tf.placeholder(tf.float32)
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 604936a0d5..29431891b3 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -64,7 +64,10 @@ class RNNTest(tf.test.TestCase):
def testRNN(self):
cell = Plus1RNNCell()
batch_size = 2
- inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10
+ input_size = 5
+ max_length = 8 # unrolled up to this length
+ inputs = max_length * [
+ tf.placeholder(tf.float32, shape=(batch_size, input_size))]
outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
for out, inp in zip(outputs, inputs):
@@ -72,7 +75,7 @@ class RNNTest(tf.test.TestCase):
self.assertEqual(out.dtype, inp.dtype)
with self.test_session(use_gpu=False) as sess:
- input_value = np.random.randn(batch_size, 5)
+ input_value = np.random.randn(batch_size, input_size)
values = sess.run(outputs + [states[-1]],
feed_dict={inputs[0]: input_value})
@@ -82,14 +85,18 @@ class RNNTest(tf.test.TestCase):
# Final state
self.assertAllClose(
- values[-1], 10.0*np.ones((batch_size, 5), dtype=np.float32))
+ values[-1],
+ max_length * np.ones((batch_size, input_size), dtype=np.float32))
def testDropout(self):
cell = Plus1RNNCell()
full_dropout_cell = tf.nn.rnn_cell.DropoutWrapper(
cell, input_keep_prob=1e-12, seed=0)
batch_size = 2
- inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10
+ input_size = 5
+ max_length = 8
+ inputs = max_length * [
+ tf.placeholder(tf.float32, shape=(batch_size, input_size))]
with tf.variable_scope("share_scope"):
outputs, states = tf.nn.rnn(cell, inputs, dtype=tf.float32)
with tf.variable_scope("drop_scope"):
@@ -101,7 +108,7 @@ class RNNTest(tf.test.TestCase):
self.assertEqual(out.dtype, inp.dtype)
with self.test_session(use_gpu=False) as sess:
- input_value = np.random.randn(batch_size, 5)
+ input_value = np.random.randn(batch_size, input_size)
values = sess.run(outputs + [states[-1]],
feed_dict={inputs[0]: input_value})
full_dropout_values = sess.run(dropped_outputs,
@@ -116,7 +123,10 @@ class RNNTest(tf.test.TestCase):
cell = Plus1RNNCell()
sequence_length = tf.placeholder(tf.int64)
batch_size = 2
- inputs = [tf.placeholder(tf.float32, shape=(batch_size, 5))] * 10
+ input_size = 5
+ max_length = 8
+ inputs = max_length * [
+ tf.placeholder(tf.float32, shape=(batch_size, input_size))]
with tf.variable_scope("drop_scope"):
dynamic_outputs, dynamic_states = tf.nn.rnn(
cell, inputs, sequence_length=sequence_length, dtype=tf.float32)
@@ -124,7 +134,7 @@ class RNNTest(tf.test.TestCase):
self.assertEqual(len(dynamic_states), len(inputs))
with self.test_session(use_gpu=False) as sess:
- input_value = np.random.randn(batch_size, 5)
+ input_value = np.random.randn(batch_size, input_size)
dynamic_values = sess.run(dynamic_outputs,
feed_dict={inputs[0]: input_value,
sequence_length: [2, 3]})
@@ -136,7 +146,8 @@ class RNNTest(tf.test.TestCase):
for v in dynamic_values[:3]:
self.assertAllClose(v, input_value + 1.0)
for vi, v in enumerate(dynamic_state_values[:3]):
- self.assertAllEqual(v, 1.0 * (vi + 1) * np.ones((batch_size, 5)))
+ self.assertAllEqual(v, 1.0 * (vi + 1) *
+ np.ones((batch_size, input_size)))
# zeros for t = 3+
for v in dynamic_values[3:]:
self.assertAllEqual(v, np.zeros_like(input_value))
@@ -154,11 +165,12 @@ class LSTMTest(tf.test.TestCase):
num_units = 3
input_size = 5
batch_size = 2
+ max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
cell = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, initializer=initializer)
- inputs = 10 * [
+ inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
@@ -173,12 +185,13 @@ class LSTMTest(tf.test.TestCase):
num_units = 3
input_size = 5
batch_size = 2
+ max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
cell = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, use_peepholes=True,
cell_clip=0.0, initializer=initializer)
- inputs = 10 * [
+ inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
self.assertEqual(len(outputs), len(inputs))
@@ -197,12 +210,13 @@ class LSTMTest(tf.test.TestCase):
num_units = 3
input_size = 5
batch_size = 2
+ max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
- state_saver = TestStateSaver(batch_size, 2*num_units)
+ state_saver = TestStateSaver(batch_size, 2 * num_units)
cell = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, use_peepholes=False, initializer=initializer)
- inputs = 10 * [
+ inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
with tf.variable_scope("share_scope"):
outputs, states = tf.nn.state_saving_rnn(
@@ -223,9 +237,10 @@ class LSTMTest(tf.test.TestCase):
input_size = 5
batch_size = 2
num_proj = 4
+ max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
- inputs = 10 * [
+ inputs = max_length * [
tf.placeholder(tf.float32, shape=(None, input_size))]
cell = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, use_peepholes=True,
@@ -244,10 +259,11 @@ class LSTMTest(tf.test.TestCase):
num_proj = 4
num_proj_shards = 4
num_unit_shards = 2
+ max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
- inputs = 10 * [
+ inputs = max_length * [
tf.placeholder(tf.float32, shape=(None, input_size))]
cell = tf.nn.rnn_cell.LSTMCell(
@@ -274,9 +290,10 @@ class LSTMTest(tf.test.TestCase):
num_proj = 4
num_proj_shards = 4
num_unit_shards = 2
+ max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed)
- inputs = 10 * [tf.placeholder(tf.float64)]
+ inputs = max_length * [tf.placeholder(tf.float64)]
cell = tf.nn.rnn_cell.LSTMCell(
num_units,
@@ -305,8 +322,9 @@ class LSTMTest(tf.test.TestCase):
num_proj = 4
num_proj_shards = 4
num_unit_shards = 2
+ max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
- inputs = 10 * [tf.placeholder(tf.float32)]
+ inputs = max_length * [tf.placeholder(tf.float32)]
initializer = tf.constant_initializer(0.001)
cell_noshard = tf.nn.rnn_cell.LSTMCell(
@@ -355,10 +373,11 @@ class LSTMTest(tf.test.TestCase):
num_proj = 4
num_proj_shards = 4
num_unit_shards = 2
+ max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
sequence_length = tf.placeholder(tf.int64)
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
- inputs = 10 * [tf.placeholder(tf.float64)]
+ inputs = max_length * [tf.placeholder(tf.float64)]
cell = tf.nn.rnn_cell.LSTMCell(
num_units,
@@ -392,28 +411,33 @@ class LSTMTest(tf.test.TestCase):
input_size = 5
batch_size = 2
num_proj = 4
+ max_length = 8
with self.test_session(graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed)
- inputs = 10 * [
+ initializer_d = tf.random_uniform_initializer(-1, 1, seed=self._seed+1)
+ inputs = max_length * [
tf.placeholder(tf.float32, shape=(None, input_size))]
cell = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, use_peepholes=True,
num_proj=num_proj, initializer=initializer)
+ cell_d = tf.nn.rnn_cell.LSTMCell(
+ num_units, input_size, use_peepholes=True,
+ num_proj=num_proj, initializer=initializer_d)
with tf.variable_scope("share_scope"):
outputs0, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
with tf.variable_scope("share_scope", reuse=True):
outputs1, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
with tf.variable_scope("diff_scope"):
- outputs2, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
+ outputs2, _ = tf.nn.rnn(cell_d, inputs, dtype=tf.float32)
tf.initialize_all_variables().run()
input_value = np.random.randn(batch_size, input_size)
output_values = sess.run(
outputs0 + outputs1 + outputs2, feed_dict={inputs[0]: input_value})
- outputs0_values = output_values[:10]
- outputs1_values = output_values[10:20]
- outputs2_values = output_values[20:]
+ outputs0_values = output_values[:max_length]
+ outputs1_values = output_values[max_length:2*max_length]
+ outputs2_values = output_values[2*max_length:]
self.assertEqual(len(outputs0_values), len(outputs1_values))
self.assertEqual(len(outputs0_values), len(outputs2_values))
for o1, o2, o3 in zip(outputs0_values, outputs1_values, outputs2_values):
@@ -427,9 +451,10 @@ class LSTMTest(tf.test.TestCase):
input_size = 5
batch_size = 2
num_proj = 4
+ max_length = 8
with self.test_session(graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-1, 1, seed=self._seed)
- inputs = 10 * [
+ inputs = max_length * [
tf.placeholder(tf.float32, shape=(None, input_size))]
cell = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, use_peepholes=True,
@@ -446,43 +471,43 @@ class LSTMTest(tf.test.TestCase):
input_value = np.random.randn(batch_size, input_size)
output_values = sess.run(
outputs0 + outputs1, feed_dict={inputs[0]: input_value})
- outputs0_values = output_values[:10]
- outputs1_values = output_values[10:]
+ outputs0_values = output_values[:max_length]
+ outputs1_values = output_values[max_length:]
self.assertEqual(len(outputs0_values), len(outputs1_values))
for out0, out1 in zip(outputs0_values, outputs1_values):
self.assertAllEqual(out0, out1)
def testNoProjNoShardingSimpleStateSaver(self):
- self._testNoProjNoShardingSimpleStateSaver(False)
- self._testNoProjNoShardingSimpleStateSaver(True)
+ self._testNoProjNoShardingSimpleStateSaver(use_gpu=False)
+ self._testNoProjNoShardingSimpleStateSaver(use_gpu=True)
def testNoProjNoSharding(self):
- self._testNoProjNoSharding(False)
- self._testNoProjNoSharding(True)
+ self._testNoProjNoSharding(use_gpu=False)
+ self._testNoProjNoSharding(use_gpu=True)
def testCellClipping(self):
- self._testCellClipping(False)
- self._testCellClipping(True)
+ self._testCellClipping(use_gpu=False)
+ self._testCellClipping(use_gpu=True)
def testProjNoSharding(self):
- self._testProjNoSharding(False)
- self._testProjNoSharding(True)
+ self._testProjNoSharding(use_gpu=False)
+ self._testProjNoSharding(use_gpu=True)
def testProjSharding(self):
- self._testProjSharding(False)
- self._testProjSharding(True)
+ self._testProjSharding(use_gpu=False)
+ self._testProjSharding(use_gpu=True)
def testShardNoShardEquivalentOutput(self):
- self._testShardNoShardEquivalentOutput(False)
- self._testShardNoShardEquivalentOutput(True)
+ self._testShardNoShardEquivalentOutput(use_gpu=False)
+ self._testShardNoShardEquivalentOutput(use_gpu=True)
def testDoubleInput(self):
- self._testDoubleInput(False)
- self._testDoubleInput(True)
+ self._testDoubleInput(use_gpu=False)
+ self._testDoubleInput(use_gpu=True)
def testDoubleInputWithDropoutAndDynamicCalculation(self):
- self._testDoubleInputWithDropoutAndDynamicCalculation(False)
- self._testDoubleInputWithDropoutAndDynamicCalculation(True)
+ self._testDoubleInputWithDropoutAndDynamicCalculation(use_gpu=False)
+ self._testDoubleInputWithDropoutAndDynamicCalculation(use_gpu=True)
class BidirectionalRNNTest(tf.test.TestCase):
@@ -495,6 +520,7 @@ class BidirectionalRNNTest(tf.test.TestCase):
num_units = 3
input_size = 5
batch_size = 2
+ max_length = 8
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
sequence_length = tf.placeholder(tf.int64)
@@ -502,7 +528,7 @@ class BidirectionalRNNTest(tf.test.TestCase):
num_units, input_size, initializer=initializer)
cell_bw = tf.nn.rnn_cell.LSTMCell(
num_units, input_size, initializer=initializer)
- inputs = 10 * [
+ inputs = max_length * [
tf.placeholder(tf.float32, shape=(batch_size, input_size))]
outputs = tf.nn.bidirectional_rnn(
cell_fw, cell_bw, inputs, dtype=tf.float32,
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 5930b3486d..bcbca6943a 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -844,6 +844,12 @@ def _SqueezeShape(op):
def _ReshapeShape(op):
"""Shape function for Reshape op."""
input_shape = op.inputs[0].get_shape()
+ if input_shape.ndims is not None:
+ num_elements = tensor_shape.Dimension(1)
+ for dim in input_shape.dims:
+ num_elements *= dim
+ else:
+ num_elements = tensor_shape.Dimension(None)
new_shape_shape = op.inputs[1].get_shape().with_rank_at_most(1)
new_shape = tensor_util.ConstantValue(op.inputs[1])
if new_shape is None:
@@ -853,13 +859,15 @@ def _ReshapeShape(op):
new_shape = np.reshape(new_shape, -1).tolist()
if -1 not in new_shape:
# The new shape is fully defined.
+ if (num_elements.value is not None
+ and num_elements.value != np.prod(new_shape)):
+ raise ValueError(
+ "Cannot reshape a tensor with %d elements to shape %s (%d elements)"
+ % (num_elements.value, new_shape, np.prod(new_shape)))
return [tensor_shape.TensorShape(new_shape)]
- elif input_shape.is_fully_defined():
- # We know the input shape, so we can calculate the missing
+ elif num_elements.value is not None:
+ # We know the number of elements, so we can calculate the missing
# dimension in the new_shape.
- num_elements = 1
- for dim in input_shape.dims:
- num_elements *= dim.value
known_elements = 1
unknown_index = None
for i, dim in enumerate(new_shape):
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index a2b39d6594..a85eb85a82 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -52,11 +52,8 @@ def clip_by_value(t, clip_value_min, clip_value_max,
t = ops.convert_to_tensor(t, name="t")
# Go through list of tensors, for each value in each tensor clip
- t_min = math_ops.minimum(
- t, array_ops.fill(array_ops.shape(t), clip_value_max))
- t_max = math_ops.maximum(
- t_min, array_ops.fill(array_ops.shape(t), clip_value_min),
- name=name)
+ t_min = math_ops.minimum(t, clip_value_max)
+ t_max = math_ops.maximum(t_min, clip_value_min, name=name)
return t_max
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 3c9b432030..f17211d677 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -444,7 +444,7 @@ def gradients(ys,
op_wrapper = control_flow_ops.MakeWrapper(op)
in_grads = _AsList(grad_fn(op_wrapper, *out_grads))
_VerifyGeneratedGradients(in_grads, op)
- if gate_gradients and len(in_grads) > 1:
+ if gate_gradients and len(filter(None, in_grads)) > 1:
in_grads = control_flow_ops.tuple(in_grads)
logging.vlog(1, "Gradient for '" + op.name + "'")
logging.vlog(1, " in --> %s",
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index dcd2334e19..3b5e9d2364 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -55,7 +55,8 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
Args:
cell: An instance of RNNCell.
- inputs: A length T list of inputs, each a vector with shape [batch_size].
+ inputs: A length T list of inputs, each a tensor of shape
+ [batch_size, cell.input_size].
initial_state: (optional) An initial state for the RNN. This must be
a tensor of appropriate type and shape [batch_size x cell.state_size].
dtype: (optional) The data type for the initial state. Required if
@@ -124,7 +125,8 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
Args:
cell: An instance of RNNCell.
- inputs: A length T list of inputs, each a vector with shape [batch_size].
+ inputs: A length T list of inputs, each a tensor of shape
+ [batch_size, cell.input_size].
state_saver: A state saver object with methods `state` and `save_state`.
state_name: The name to use with the state_saver.
sequence_length: (optional) An int64 vector (tensor) size [batch_size].
@@ -182,15 +184,17 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs,
Similar to the unidirectional case above (rnn) but takes input and builds
independent forward and backward RNNs with the final forward and backward
outputs depth-concatenated, such that the output will have the format
- [time][batch][cell_fw.output_size + cell_bw.output_size]. The initial state
- for both directions is zero by default (but can be set optionally) and no
- intermediate states are ever returned -- the network is fully unrolled for
- the given (passed in) length(s) of the sequence(s).
+ [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
+ forward and backward cell must match. The initial state for both directions
+ is zero by default (but can be set optionally) and no intermediate states are
+ ever returned -- the network is fully unrolled for the given (passed in)
+ length(s) of the sequence(s).
Args:
cell_fw: An instance of RNNCell, to be used for forward direction.
cell_bw: An instance of RNNCell, to be used for backward direction.
- inputs: A length T list of inputs, each a vector with shape [batch_size].
+ inputs: A length T list of inputs, each a tensor of shape
+ [batch_size, cell.input_size].
initial_state_fw: (optional) An initial state for the forward RNN.
This must be a tensor of appropriate type and shape
[batch_size x cell.state_size].
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
index 584849236a..b1a92c06e8 100644
--- a/tensorflow/python/ops/rnn_cell.py
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -198,6 +198,24 @@ class BasicLSTMCell(RNNCell):
return new_h, array_ops.concat(1, [new_c, new_h])
+def _get_sharded_variable(name, shape, initializer, dtype, num_shards):
+ """Get a list of sharded variables with the given dtype and initializer."""
+ unit_shard_size = int(math.ceil(shape[1] / num_shards))
+
+ shards = []
+ for i in range(num_shards):
+ current_size = min(unit_shard_size, shape[1] - unit_shard_size * i)
+ shards.append(vs.get_variable(name + "_%d" % i, [shape[0], current_size],
+ initializer=initializer, dtype=dtype))
+ return shards
+
+
+def _matmul_with_sharded_variable(tensor, sharded_tensor):
+ """Multiply tensor with each tensor in sharded_tensor and column-concat."""
+ return array_ops.concat(1, [math_ops.matmul(tensor, shard)
+ for shard in sharded_tensor])
+
+
class LSTMCell(RNNCell):
"""Long short-term memory unit (LSTM) recurrent network cell.
@@ -231,15 +249,8 @@ class LSTMCell(RNNCell):
matrices. If None, no projection is performed.
num_unit_shards: How to split the weight matrix. If >1, the weight
matrix is stored across num_unit_shards.
- Note that num_unit_shards must evenly divide num_units * 4.
num_proj_shards: How to split the projection matrix. If >1, the
projection matrix is stored across num_proj_shards.
- Note that num_proj_shards must evenly divide num_proj
- (if num_proj is not None).
-
- Raises:
- ValueError: if num_unit_shards doesn't divide 4 * num_units or
- num_proj_shards doesn't divide num_proj
"""
self._num_units = num_units
self._input_size = input_size
@@ -250,11 +261,6 @@ class LSTMCell(RNNCell):
self._num_unit_shards = num_unit_shards
self._num_proj_shards = num_proj_shards
- if (num_units * 4) % num_unit_shards != 0:
- raise ValueError("num_unit_shards must evently divide 4 * num_units")
- if num_proj and num_proj % num_proj_shards != 0:
- raise ValueError("num_proj_shards must evently divide num_proj")
-
if num_proj:
self._state_size = num_units + num_proj
self._output_size = num_proj
@@ -299,15 +305,10 @@ class LSTMCell(RNNCell):
dtype = input_.dtype
- unit_shard_size = (4 * self._num_units) // self._num_unit_shards
-
with vs.variable_scope(scope or type(self).__name__): # "LSTMCell"
- w = array_ops.concat(
- 1,
- [vs.get_variable("W_%d" % i,
- shape=[self.input_size + num_proj, unit_shard_size],
- initializer=self._initializer,
- dtype=dtype) for i in xrange(self._num_unit_shards)])
+ sharded_w = _get_sharded_variable(
+ "W", [self.input_size + num_proj, 4 * self._num_units],
+ self._initializer, dtype, self._num_unit_shards)
b = vs.get_variable(
"B", shape=[4 * self._num_units],
@@ -315,17 +316,24 @@ class LSTMCell(RNNCell):
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
cell_inputs = array_ops.concat(1, [input_, m_prev])
- i, j, f, o = array_ops.split(
- 1, 4, nn_ops.bias_add(math_ops.matmul(cell_inputs, w), b))
+ lstm_matrix = nn_ops.bias_add(
+ _matmul_with_sharded_variable(cell_inputs, sharded_w), b)
+ i, j, f, o = array_ops.split(1, 4, lstm_matrix)
# Diagonal connections
if self._use_peepholes:
w_f_diag = vs.get_variable(
- "W_F_diag", shape=[self._num_units], dtype=dtype)
+ "W_F_diag", shape=[self._num_units],
+ initializer=self._initializer,
+ dtype=dtype)
w_i_diag = vs.get_variable(
- "W_I_diag", shape=[self._num_units], dtype=dtype)
+ "W_I_diag", shape=[self._num_units],
+ initializer=self._initializer,
+ dtype=dtype)
w_o_diag = vs.get_variable(
- "W_O_diag", shape=[self._num_units], dtype=dtype)
+ "W_O_diag", shape=[self._num_units],
+ initializer=self._initializer,
+ dtype=dtype)
if self._use_peepholes:
c = (sigmoid(f + 1 + w_f_diag * c_prev) * c_prev +
@@ -342,16 +350,11 @@ class LSTMCell(RNNCell):
m = sigmoid(o) * tanh(c)
if self._num_proj is not None:
- proj_shard_size = self._num_proj // self._num_proj_shards
- w_proj = array_ops.concat(
- 1,
- [vs.get_variable("W_P_%d" % i,
- shape=[self._num_units, proj_shard_size],
- initializer=self._initializer,
- dtype=dtype)
- for i in xrange(self._num_proj_shards)])
- # TODO(ebrevdo), use matmulsum
- m = math_ops.matmul(m, w_proj)
+ sharded_w_proj = _get_sharded_variable(
+ "W_P", [self._num_units, self._num_proj], self._initializer,
+ dtype, self._num_proj_shards)
+
+ m = _matmul_with_sharded_variable(m, sharded_w_proj)
return m, array_ops.concat(1, [c, m])
diff --git a/tensorflow/python/summary/event_accumulator.py b/tensorflow/python/summary/event_accumulator.py
index f2b4437809..ab5fb4a426 100644
--- a/tensorflow/python/summary/event_accumulator.py
+++ b/tensorflow/python/summary/event_accumulator.py
@@ -174,7 +174,10 @@ class EventAccumulator(object):
with self._generator_mutex:
for event in self._generator.Load():
## Check if the event happened after a crash
- if event.step < self.most_recent_step:
+ ## file_version events always have step 0, ignore.
+ ## TODO(danmane): Have this check for restart events explicitly
+ if (event.step < self.most_recent_step and
+ not event.HasField('file_version')):
## Keep data in reservoirs that has a step less than event.step
_NotExpired = lambda x: x.step < event.step
diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py
index 3cc7e493d0..a64084d826 100644
--- a/tensorflow/python/summary/event_accumulator_test.py
+++ b/tensorflow/python/summary/event_accumulator_test.py
@@ -391,6 +391,16 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
## Check that we have discarded 200 and 300
self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301])
+ def testFileVersionEventDoesntTriggerDiscard(self):
+ """Test that file version event doesnt trigger data purge."""
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ gen.AddScalar('s1', wall_time=1, step=100, value=20)
+ ev = tf.Event(wall_time=2, step=0, file_version='0')
+ gen.AddEvent(ev)
+ acc.Reload()
+ self.assertEqual([x.step for x in acc.Scalars('s1')], [100])
+
class RealisticEventAccumulatorTest(EventAccumulatorTest):
diff --git a/tensorflow/tensorboard/CHANGES b/tensorflow/tensorboard/CHANGES
index 2fe8332ece..88e252f457 100644
--- a/tensorflow/tensorboard/CHANGES
+++ b/tensorflow/tensorboard/CHANGES
@@ -1,2 +1,6 @@
--- 2 ---
Begin tracking TensorBoard changes.
+
+--- 3 ---
+Change default # of scalar values to 1000
+Fix bug where TensorBoard discards all values after a restart. \ No newline at end of file
diff --git a/tensorflow/tensorboard/TAG b/tensorflow/tensorboard/TAG
index 0cfbf08886..00750edc07 100644
--- a/tensorflow/tensorboard/TAG
+++ b/tensorflow/tensorboard/TAG
@@ -1 +1 @@
-2
+3
diff --git a/tensorflow/tensorboard/tensorboard.py b/tensorflow/tensorboard/tensorboard.py
index 87c0827c16..f6666d3b8b 100644
--- a/tensorflow/tensorboard/tensorboard.py
+++ b/tensorflow/tensorboard/tensorboard.py
@@ -55,8 +55,9 @@ flags.DEFINE_boolean('debug', False, 'Whether to run the app in debug mode. '
'This increases log verbosity to DEBUG.')
-flags.DEFINE_string('host', '127.0.0.1', 'What host to listen to. Defaults to '
- 'serving on localhost, set to 0.0.0.0 for remote access.')
+flags.DEFINE_string('host', '0.0.0.0', 'What host to listen to. Defaults to '
+ 'serving on 0.0.0.0, set to 127.0.0.1 (localhost) to'
+ 'disable remote access (also quiets security warnings).')
flags.DEFINE_integer('port', 6006, 'What port to serve TensorBoard on.')
@@ -66,7 +67,7 @@ FLAGS = flags.FLAGS
TENSORBOARD_SIZE_GUIDANCE = {
event_accumulator.COMPRESSED_HISTOGRAMS: 500,
event_accumulator.IMAGES: 4,
- event_accumulator.SCALARS: 10000,
+ event_accumulator.SCALARS: 1000,
event_accumulator.HISTOGRAMS: 1,
}
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 692080fc23..649bf23284 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -19,6 +19,7 @@ sh_binary(
"setup.py",
":simple_console",
"//tensorflow:tensorflow_py",
+ "//tensorflow/examples/tutorials/mnist:package",
"//tensorflow/models/embedding:package",
"//tensorflow/models/image/cifar10:all_files",
"//tensorflow/models/image/mnist:convolutional",