about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/core/common_runtime/function.cc 5
-rw-r--r-- tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc 100
-rw-r--r-- tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h 27
-rw-r--r-- tensorflow/core/common_runtime/gpu/process_state.cc 4
-rw-r--r-- tensorflow/core/kernels/concat_op.cc 1
-rw-r--r-- tensorflow/core/kernels/concat_op_gpu.cu.cc 2
-rw-r--r-- tensorflow/core/kernels/conv_grad_ops.cc 62
-rw-r--r-- tensorflow/core/kernels/tile_ops.h 5
-rw-r--r-- tensorflow/core/ops/array_ops.cc 2
-rw-r--r-- tensorflow/g3doc/api_docs/python/array_ops.md 2
-rw-r--r-- tensorflow/g3doc/api_docs/python/constant_op.md 15
-rw-r--r-- tensorflow/g3doc/api_docs/python/train.md 8
-rw-r--r-- tensorflow/g3doc/get_started/os_setup.md 38
-rw-r--r-- tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md 2
-rw-r--r-- tensorflow/g3doc/tutorials/mnist/input_data.py 4
-rw-r--r-- tensorflow/g3doc/tutorials/mnist/mnist.py 2
-rw-r--r-- tensorflow/g3doc/tutorials/mnist/pros/index.md 4
-rw-r--r-- tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py 4
-rw-r--r-- tensorflow/models/image/cifar10/cifar10.py 7
-rw-r--r-- tensorflow/models/image/mnist/convolutional.py 4
-rw-r--r-- tensorflow/models/rnn/rnn_cell_test.py 2
-rw-r--r-- tensorflow/models/rnn/seq2seq.py 2
-rw-r--r-- tensorflow/models/rnn/translate/data_utils.py 4
-rw-r--r-- tensorflow/python/__init__.py 12
-rw-r--r-- tensorflow/python/framework/ops.py 25
-rw-r--r-- tensorflow/python/kernel_tests/concat_op_test.py 26
-rw-r--r-- tensorflow/python/kernel_tests/embedding_ops_test.py 2
-rw-r--r-- tensorflow/python/kernel_tests/init_ops_test.py 4
-rw-r--r-- tensorflow/python/kernel_tests/lookup_table_op_test.py 218
-rw-r--r-- tensorflow/python/ops/clip_ops.py 4
-rw-r--r-- tensorflow/python/ops/data_flow_ops.py 169
-rw-r--r-- tensorflow/python/ops/embedding_ops.py 2
-rw-r--r-- tensorflow/python/ops/math_grad.py 2
-rw-r--r-- tensorflow/python/ops/math_ops.py 15
-rw-r--r-- tensorflow/python/ops/nn.py 17
-rw-r--r-- tensorflow/python/ops/nn_grad.py 2
-rw-r--r-- tensorflow/python/ops/nn_test.py 88
-rw-r--r-- tensorflow/python/ops/sparse_ops.py 3
-rw-r--r-- tensorflow/python/platform/default/_logging.py 31
-rw-r--r-- tensorflow/python/training/input.py 2
-rw-r--r-- tensorflow/python/training/learning_rate_decay.py 6
-rw-r--r-- tensorflow/tensorboard/bower.json 39
-rw-r--r-- tensorflow/tensorboard/tensorboard_handler.py 4
43 files changed, 397 insertions, 580 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 2b1a041235..4667f096e0 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -495,6 +495,11 @@ static void SimplifyGraph(Graph* g) {
void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) {
DumpGraph("Initial", *g);
+
+ // Run SimplifyGraph at least once to rewrite away ops such as
+ // _ListToArray, _ArrayToList, etc.
+ SimplifyGraph(*g);
+
const int kNumInlineRounds = 10;
for (int i = 0; i < kNumInlineRounds; ++i) {
if (!ExpandInlineFunctions(lib, *g)) break;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
index 3df833594f..8979c94e3d 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -65,7 +65,7 @@ GPUBFCAllocator::GPUBFCAllocator(int device_id, size_t total_memory)
ptr_to_chunk_map_.insert(std::make_pair(c->ptr, c));
// Insert the chunk into the right bin.
- ReassignChunkToBin(c);
+ InsertFreeChunkIntoBin(c);
}
GPUBFCAllocator::~GPUBFCAllocator() {
@@ -76,6 +76,7 @@ GPUBFCAllocator::~GPUBFCAllocator() {
}
gtl::STLDeleteValues(&bins_);
+ gtl::STLDeleteValues(&ptr_to_chunk_map_);
}
void* GPUBFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes) {
@@ -115,10 +116,12 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
// Start searching from the first bin for the smallest chunk that fits
// rounded_bytes.
Bin* b = it->second;
- for (GPUBFCAllocator::Chunk* chunk : b->chunks) {
- if (!chunk->in_use && chunk->size > rounded_bytes) {
- // We found an existing chunk that fits us that wasn't in use.
- chunk->in_use = true;
+ for (GPUBFCAllocator::Chunk* chunk : b->free_chunks) {
+ DCHECK(!chunk->in_use);
+ if (chunk->size >= rounded_bytes) {
+ // We found an existing chunk that fits us that wasn't in use, so remove
+ // it from the free bin structure prior to using.
+ RemoveFreeChunkFromBin(chunk);
// If we can break the size of the chunk into two reasonably
// large pieces, do so.
@@ -132,6 +135,7 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
// The requested size of the returned chunk is what the user
// has allocated.
chunk->requested_size = num_bytes;
+ chunk->in_use = true;
VLOG(4) << "Returning: " << chunk->ptr;
return chunk->ptr;
@@ -152,6 +156,8 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
}
void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) {
+ CHECK(!c->in_use && !c->bin);
+
// Create a new chunk starting num_bytes after c
GPUBFCAllocator::Chunk* new_chunk = new GPUBFCAllocator::Chunk();
new_chunk->ptr = static_cast<void*>(static_cast<char*>(c->ptr) + num_bytes);
@@ -176,9 +182,8 @@ void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) {
c_neighbor->prev = new_chunk;
}
- // Maintain the bins
- ReassignChunkToBin(new_chunk);
- ReassignChunkToBin(c);
+ // Add the newly free chunk to the free bin.
+ InsertFreeChunkIntoBin(new_chunk);
}
void GPUBFCAllocator::DeallocateRaw(void* ptr) {
@@ -200,11 +205,9 @@ void GPUBFCAllocator::DeallocateRawInternal(void* ptr) {
GPUBFCAllocator::Chunk* c = it->second;
VLOG(6) << "Chunk at " << c->ptr << " no longer in use";
- // Mark the chunk as no longer in use
- c->in_use = false;
// Consider coalescing it.
- MaybeCoalesce(c);
+ FreeAndMaybeCoalesce(c);
}
// Merges c1 and c2 when c1->next is c2 and c2->prev is c1.
@@ -212,7 +215,7 @@ void GPUBFCAllocator::DeallocateRawInternal(void* ptr) {
void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1,
GPUBFCAllocator::Chunk* c2) {
// We can only merge chunks that are not in use.
- DCHECK(!c1->in_use && !c2->in_use);
+ CHECK(!c1->in_use && !c2->in_use);
// c1's prev doesn't change, still points to the same ptr, and is
// still not in use.
@@ -231,62 +234,42 @@ void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1,
// Set the new size
c1->size += c2->size;
+ DeleteChunk(c2);
+}
+
+void GPUBFCAllocator::DeleteChunk(Chunk* c) {
// Delete c2 and cleanup all state
- RemoveChunkFromBin(c2);
+ VLOG(4) << "Removing: " << c->ptr;
+ ptr_to_chunk_map_.erase(c->ptr);
+ delete c;
}
-void GPUBFCAllocator::ReassignChunkToBin(GPUBFCAllocator::Chunk* c) {
+void GPUBFCAllocator::InsertFreeChunkIntoBin(GPUBFCAllocator::Chunk* c) {
+ CHECK(!c->in_use && !c->bin);
auto it = bins_.lower_bound(c->size);
CHECK(it != bins_.end()) << " Tried to reassign to non-existent bin for size "
<< c->size;
-
Bin* new_bin = it->second;
-
- // If the bin has not changed, do nothing.
- Bin* old_bin = c->bin;
- if (old_bin != nullptr && new_bin == old_bin) {
- return;
- }
-
- // The bin has changed. Add the chunk to the new bin and remove
- // the chunk from the old bin.
- new_bin->chunks.insert(c);
c->bin = new_bin;
+ new_bin->free_chunks.insert(c);
+}
- if (old_bin == nullptr) {
- return;
- }
-
- // Remove chunk from old bin
- for (auto it = old_bin->chunks.begin(); it != old_bin->chunks.end(); ++it) {
- if (*it == c) {
- old_bin->chunks.erase(it);
- return;
- }
- }
- CHECK(false) << "Could not find chunk in old bin";
+void GPUBFCAllocator::RemoveFreeChunkFromBin(GPUBFCAllocator::Chunk* c) {
+ CHECK(!c->in_use && c->bin);
+ int count = c->bin->free_chunks.erase(c);
+ CHECK(count > 0) << "Could not find chunk in bin";
+ c->bin = nullptr;
}
-void GPUBFCAllocator::RemoveChunkFromBin(GPUBFCAllocator::Chunk* c) {
- Bin* b = c->bin;
- for (auto it = b->chunks.begin(); it != b->chunks.end(); ++it) {
- Chunk* other_c = *it;
- if (other_c->ptr == c->ptr) {
- b->chunks.erase(it);
- VLOG(4) << "Removing: " << c->ptr;
- ptr_to_chunk_map_.erase(c->ptr);
- delete c;
- return;
- }
- }
+void GPUBFCAllocator::FreeAndMaybeCoalesce(GPUBFCAllocator::Chunk* c) {
+ CHECK(c->in_use && !c->bin);
- CHECK(false) << "Could not find chunk in bin";
-}
+ // Mark the chunk as no longer in use
+ c->in_use = false;
-void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
// This chunk is no longer in-use, consider coalescing the chunk
// with adjacent chunks.
- Chunk* chunk_to_reassign = nullptr;
+ Chunk* chunk_to_reassign = c;
// If the next chunk is free, coalesce the two, if the result would
// fit in an existing bin.
@@ -296,6 +279,7 @@ void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
chunk_to_reassign = c;
// Deletes c->next
+ RemoveFreeChunkFromBin(c->next);
Merge(c, c->next);
}
@@ -307,13 +291,11 @@ void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
chunk_to_reassign = c->prev;
// Deletes c
+ RemoveFreeChunkFromBin(c->prev);
Merge(c->prev, c);
}
- // Reassign the final merged chunk into the right bin.
- if (chunk_to_reassign) {
- ReassignChunkToBin(chunk_to_reassign);
- }
+ InsertFreeChunkIntoBin(chunk_to_reassign);
}
void GPUBFCAllocator::AddAllocVisitor(Visitor visitor) {
@@ -354,7 +336,7 @@ void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
size_t total_requested_bytes_in_bin = 0;
size_t total_chunks_in_use = 0;
size_t total_chunks_in_bin = 0;
- for (Chunk* c : b->chunks) {
+ for (Chunk* c : b->free_chunks) {
total_bytes_in_bin += c->size;
total_requested_bytes_in_bin += c->requested_size;
++total_chunks_in_bin;
@@ -388,7 +370,7 @@ void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
<< " was " << strings::HumanReadableNumBytes(b->bin_size)
<< ", Chunk State: ";
- for (Chunk* c : b->chunks) {
+ for (Chunk* c : b->free_chunks) {
LOG(INFO) << c->DebugString(true);
}
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index 3d1601e132..417df6f413 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -102,28 +102,33 @@ class GPUBFCAllocator : public VisitableAllocator {
Chunk* AllocateNewChunk(size_t num_bytes);
void SplitChunk(Chunk* c, size_t num_bytes);
void Merge(Chunk* c1, Chunk* c2);
- void MaybeCoalesce(Chunk* c);
-
- void ReassignChunkToBin(Chunk* c);
- void RemoveChunkFromBin(Chunk* c);
+ void FreeAndMaybeCoalesce(Chunk* c);
+ void InsertFreeChunkIntoBin(Chunk* c);
+ void RemoveFreeChunkFromBin(Chunk* c);
+ void DeleteChunk(Chunk* c);
void DumpMemoryLog(size_t num_bytes);
- // A Bin is a collection of similar-sized Chunks.
+ // A Bin is a collection of similar-sized free chunks.
struct Bin {
// All chunks in this bin have >= bin_size memory.
size_t bin_size = 0;
struct ChunkComparator {
- bool operator()(Chunk* a, Chunk* b) { return a->size < b->size; }
+ // Sort first by size and then use pointer address as a tie breaker.
+ bool operator()(const Chunk* a, const Chunk* b) const {
+ if (a->size != b->size) {
+ return a->size < b->size;
+ }
+ return a->ptr < b->ptr;
+ }
};
- // List of chunks within the bin, sorted by chunk size.
- std::multiset<Chunk*, ChunkComparator> chunks;
+ // List of free chunks within the bin, sorted by chunk size.
+ // Chunk * not owned.
+ std::set<Chunk*, ChunkComparator> free_chunks;
explicit Bin(size_t bs) : bin_size(bs) {}
-
- ~Bin() { gtl::STLDeleteElements(&chunks); }
};
GPUAllocatorRetry retry_helper_;
@@ -142,7 +147,7 @@ class GPUBFCAllocator : public VisitableAllocator {
// Structures mutable after construction
mutable mutex lock_;
- // Not owned.
+ // Chunk * owned.
std::unordered_map<void*, Chunk*> ptr_to_chunk_map_;
// Called once on each region, ASAP.
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc
index 474b988d2f..fa9c0170f5 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/process_state.cc
@@ -20,7 +20,7 @@ DEFINE_bool(record_mem_types, false,
DEFINE_bool(brain_mem_reg_cuda_dma, true,
"If true, register CPU RAM used to copy to/from GPU RAM "
"with the CUDA driver.");
-DEFINE_bool(brain_gpu_use_bfc_allocator, false,
+DEFINE_bool(brain_gpu_use_bfc_allocator, true,
"If true, uses the Best-Fit GPU allocator.");
DEFINE_bool(brain_gpu_region_allocator_debug, false,
"If true, checks for memory overwrites by writing "
@@ -34,7 +34,7 @@ bool FLAGS_record_mem_types = false;
bool FLAGS_brain_mem_reg_cuda_dma = true;
bool FLAGS_brain_gpu_region_allocator_debug = false;
bool FLAGS_brain_gpu_region_allocator_reset_to_nan = false;
-bool FLAGS_brain_gpu_use_bfc_allocator = false;
+bool FLAGS_brain_gpu_use_bfc_allocator = true;
#endif
namespace gpu = ::perftools::gputools;
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index b68fcec515..adc802cb45 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -135,6 +135,7 @@ REGISTER_CONCAT(bfloat16);
ConcatOp<GPUDevice, type>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+REGISTER_GPU(bfloat16);
#undef REGISTER_GPU
// A special GPU kernel for int32.
diff --git a/tensorflow/core/kernels/concat_op_gpu.cu.cc b/tensorflow/core/kernels/concat_op_gpu.cu.cc
index aed36dccef..0f21d5f07c 100644
--- a/tensorflow/core/kernels/concat_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/concat_op_gpu.cu.cc
@@ -6,6 +6,7 @@
#include <memory>
+#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -34,6 +35,7 @@ void ConcatGPU(const GPUDevice& d,
typename TTypes<T, 2>::Matrix* output);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+REGISTER_GPU(bfloat16);
#undef REGISTER_GPU
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index 7bfe8b095f..eeb9fa4c38 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -541,12 +541,36 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
// The output image size is the spatial size of the output.
const int output_image_size = out_rows * out_cols;
+ // Shard 'batch' images into 'shard_size' groups of images to be fed
+ // into the parallel matmul. Calculate 'shard_size' by dividing the L3 cache
+ // size ('target_working_set_size') by the matmul size of an individual
+ // image ('work_unit_size').
+
+ // TODO(andydavis)
+ // *) Get L3 cache size from device at runtime (30MB is from ivybridge).
+ // *) Consider reducing 'target_working_set_size' if L3 is shared by
+ // other concurrently running tensorflow ops.
+ const size_t target_working_set_size = (30LL << 20) / sizeof(T);
+
+ const size_t size_A = output_image_size * filter_total_size;
+
+ const size_t size_B = output_image_size * out_depth;
+
+ const size_t size_C = filter_total_size * out_depth;
+
+ const size_t work_unit_size = size_A + size_B + size_C;
+
+ const size_t shard_size =
+ (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
Tensor col_buffer;
- OP_REQUIRES_OK(
- context,
- context->allocate_temp(
- DataTypeToEnum<T>::value,
- TensorShape({output_image_size, filter_total_size}), &col_buffer));
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)}),
+ &col_buffer));
// The input offset corresponding to a single input image.
const int input_offset = input_rows * input_cols * in_depth;
@@ -571,21 +595,29 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
contract_dims[0].first = 0;
contract_dims[0].second = 0;
- for (int image_id = 0; image_id < batch; ++image_id) {
- // When we compute the gradient with respect to the filters, we need to do
- // im2col to allow gemm-type computation.
- Im2col<T>(input_data, in_depth, input_rows, input_cols, filter_rows,
- filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
- stride, col_buffer_data);
+ for (int image_id = 0; image_id < batch; image_id += shard_size) {
+ const int shard_limit = std::min(static_cast<int>(shard_size),
+ static_cast<int>(batch) - image_id);
+ for (int shard_id = 0; shard_id < shard_limit; ++shard_id) {
+ // TODO(andydavis) Parallelize this loop.
+ // When we compute the gradient with respect to the filters, we need
+ // to do im2col to allow gemm-type computation.
+ Im2col<T>(input_data, in_depth, input_rows, input_cols, filter_rows,
+ filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
+ stride, col_buffer_data + shard_id * size_A);
+
+ input_data += input_offset;
+ }
- ConstTensorMap A(col_buffer_data, output_image_size, filter_total_size);
- ConstTensorMap B(out_backprop_data + output_offset * image_id,
- output_image_size, out_depth);
+ ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
+ filter_total_size);
+ ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
+ out_depth);
// Gradient with respect to filter.
C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
- input_data += input_offset;
+ out_backprop_data += output_offset * shard_limit;
}
}
diff --git a/tensorflow/core/kernels/tile_ops.h b/tensorflow/core/kernels/tile_ops.h
index 1a614fe4f1..99455adce2 100644
--- a/tensorflow/core/kernels/tile_ops.h
+++ b/tensorflow/core/kernels/tile_ops.h
@@ -46,8 +46,9 @@ template <typename Device, typename T>
struct TileGrad<Device, T, 0> {
void operator()(const Device& d, typename TTypes<T, 0>::Tensor out,
typename TTypes<T, 0>::ConstTensor in,
- const Eigen::DSizes<ptrdiff_t, 0>&,
- const Eigen::DSizes<ptrdiff_t, 0>&, bool first) const {
+ const Eigen::DSizes<Eigen::DenseIndex, 0>&,
+ const Eigen::DSizes<Eigen::DenseIndex, 0>&,
+ bool first) const {
if (first) {
out.device(d) = in;
} else {
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 8c0571b50e..321d9c0276 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -407,7 +407,7 @@ reshape(t, [3, 3]) ==> [[1, 2, 3]
# tensor 't' is [[[1, 1], [2, 2]]
# [[3, 3], [4, 4]]]
-# tensor 't' has shape [2, 2]
+# tensor 't' has shape [2, 2, 2]
reshape(t, [2, 4]) ==> [[1, 1, 2, 2]
[3, 3, 4, 4]]
diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md
index f34e12b1dc..10a6df41ff 100644
--- a/tensorflow/g3doc/api_docs/python/array_ops.md
+++ b/tensorflow/g3doc/api_docs/python/array_ops.md
@@ -529,7 +529,7 @@ tf.shape(split0) ==> [5, 10]
* <b>`split_dim`</b>: A 0-D `int32` `Tensor`. The dimension along which to split.
Must be in the range `[0, rank(value))`.
-* <b>`num_split`</b>: A 0-D `int32` `Tensor`. The number of ways to split.
+* <b>`num_split`</b>: A Python integer. The number of ways to split.
* <b>`value`</b>: The `Tensor` to split.
* <b>`name`</b>: A name for the operation (optional).
diff --git a/tensorflow/g3doc/api_docs/python/constant_op.md b/tensorflow/g3doc/api_docs/python/constant_op.md
index edbb8101a4..0e3abf6676 100644
--- a/tensorflow/g3doc/api_docs/python/constant_op.md
+++ b/tensorflow/g3doc/api_docs/python/constant_op.md
@@ -17,7 +17,7 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by
* [`tf.constant(value, dtype=None, shape=None, name='Const')`](#constant)
* [Sequences](#AUTOGENERATED-sequences)
* [`tf.linspace(start, stop, num, name=None)`](#linspace)
- * [`tf.range(start, limit, delta=1, name='range')`](#range)
+ * [`tf.range(start, limit=None, delta=1, name='range')`](#range)
* [Random Tensors](#AUTOGENERATED-random-tensors)
* [Examples:](#AUTOGENERATED-examples-)
* [`tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)`](#random_normal)
@@ -273,12 +273,15 @@ tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0]
- - -
-### `tf.range(start, limit, delta=1, name='range')` <a class="md-anchor" id="range"></a>
+### `tf.range(start, limit=None, delta=1, name='range')` <a class="md-anchor" id="range"></a>
Creates a sequence of integers.
-This operation creates a sequence of integers that begins at `start` and
-extends by increments of `delta` up to but not including `limit`.
+Creates a sequence of integers that begins at `start` and extends by
+increments of `delta` up to but not including `limit`.
+
+Like the Python builtin `range`, `start` defaults to 0, so that
+`range(n) = range(0, n)`.
For example:
@@ -287,12 +290,16 @@ For example:
# 'limit' is 18
# 'delta' is 3
tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+
+# 'limit' is 5
+tf.range(limit) ==> [0, 1, 2, 3, 4]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
* <b>`start`</b>: A 0-D (scalar) of type `int32`. First entry in sequence.
+ Defaults to 0.
* <b>`limit`</b>: A 0-D (scalar) of type `int32`. Upper limit of sequence,
exclusive.
* <b>`delta`</b>: A 0-D `Tensor` (scalar) of type `int32`. Optional. Default is 1.
diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md
index 90c70bec24..85739e6d3b 100644
--- a/tensorflow/g3doc/api_docs/python/train.md
+++ b/tensorflow/g3doc/api_docs/python/train.md
@@ -143,7 +143,7 @@ This must be called by the constructors of subclasses.
- - -
-#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, name=None)` <a class="md-anchor" id="Optimizer.minimize"></a>
+#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, aggregation_method=None, name=None)` <a class="md-anchor" id="Optimizer.minimize"></a>
Add operations to minimize 'loss' by updating 'var_list'.
@@ -163,6 +163,8 @@ this function.
under the key GraphKeys.TRAINABLE_VARIABLES.
* <b>`gate_gradients`</b>: How to gate the computation of gradients. Can be
GATE_NONE, GATE_OP, or GATE_GRAPH.
+* <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
* <b>`name`</b>: Optional name for the returned operation.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
@@ -178,7 +180,7 @@ this function.
- - -
-#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1)` <a class="md-anchor" id="Optimizer.compute_gradients"></a>
+#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None)` <a class="md-anchor" id="Optimizer.compute_gradients"></a>
Compute gradients of "loss" for the variables in "var_list".
@@ -197,6 +199,8 @@ given variable.
under the key GraphKey.TRAINABLE_VARIABLES.
* <b>`gate_gradients`</b>: How to gate the computation of gradients. Can be
GATE_NONE, GATE_OP, or GATE_GRAPH.
+* <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index aa0301028a..e5cd016544 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -268,6 +268,34 @@ $ bazel-bin/tensorflow/cc/tutorials_example_trainer --use_gpu
Note that "--config=cuda" is needed to enable the GPU support.
+##### Enabling Cuda 3.0. <a class="md-anchor" id="AUTOGENERATED-enabling-cuda-3.0."></a>
+TensorFlow officially supports Cuda devices with compute capabilities 3.5 and
+5.2. In order to enable earlier Cuda devices such as the Grid K520, you need to
+target Cuda compute capability 3.0. This can be done through TensorFlow's
+unofficial settings with "configure".
+
+```bash
+$ TF_UNOFFICIAL_SETTING=1 ./configure
+
+# Same as the official settings above
+
+WARNING: You are configuring unofficial settings in TensorFlow. Because some
+external libraries are not backward compatible, these settings are largely
+untested and unsupported.
+
+Please specify a list of comma-separated Cuda compute capabilities you want to
+build with. You can find the compute capability of your device at:
+https://developer.nvidia.com/cuda-gpus.
+Please note that each additional compute capability significantly increases
+your build time and binary size. [Default is: "3.5,5.2"]: 3.0
+
+Setting up Cuda include
+Setting up Cuda lib64
+Setting up Cuda bin
+Setting up Cuda nvvm
+Configuration finished
+```
+
##### Known issues <a class="md-anchor" id="AUTOGENERATED-known-issues"></a>
* Although it is possible to build both Cuda and non-Cuda configs under the same
@@ -360,14 +388,20 @@ Make sure you followed the GPU installation [instructions](#install_cuda).
#### Can't find setup.py <a class="md-anchor" id="AUTOGENERATED-can-t-find-setup.py"></a>
-If, during pip install, you encounter an error like:
+If, during `pip install`, you encounter an error like:
```bash
...
IOError: [Errno 2] No such file or directory: '/tmp/pip-o6Tpui-build/setup.py'
```
-Solution: upgrade your version of pip.
+Solution: upgrade your version of `pip`:
+
+```bash
+pip install --upgrade pip
+```
+
+This may require `sudo`, depending on how `pip` is installed.
#### SSLError: SSL_VERIFY_FAILED <a class="md-anchor" id="AUTOGENERATED-sslerror--ssl_verify_failed"></a>
diff --git a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md
index 8de7b080eb..e431f4c26d 100644
--- a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md
+++ b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md
@@ -66,7 +66,7 @@ every hundred steps or so, as in the following code example.
```python
merged_summary_op = tf.merge_all_summaries()
-summary_writer = tf.train.SummaryWriter('/tmp/mnist_logs', sess.graph)
+summary_writer = tf.train.SummaryWriter('/tmp/mnist_logs', sess.graph_def)
total_step = 0
while training:
total_step += 1
diff --git a/tensorflow/g3doc/tutorials/mnist/input_data.py b/tensorflow/g3doc/tutorials/mnist/input_data.py
index 391d133ea1..890a552010 100644
--- a/tensorflow/g3doc/tutorials/mnist/input_data.py
+++ b/tensorflow/g3doc/tutorials/mnist/input_data.py
@@ -5,9 +5,9 @@ from __future__ import print_function
import gzip
import os
-import urllib
import numpy
+from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
@@ -19,7 +19,7 @@ def maybe_download(filename, work_directory):
os.mkdir(work_directory)
filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath):
- filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
+ filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
statinfo = os.stat(filepath)
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
return filepath
diff --git a/tensorflow/g3doc/tutorials/mnist/mnist.py b/tensorflow/g3doc/tutorials/mnist/mnist.py
index 64be52293a..925debac6e 100644
--- a/tensorflow/g3doc/tutorials/mnist/mnist.py
+++ b/tensorflow/g3doc/tutorials/mnist/mnist.py
@@ -91,7 +91,7 @@ def loss(logits, labels):
# be a 1.0 in the entry corresponding to the label).
batch_size = tf.size(labels)
labels = tf.expand_dims(labels, 1)
- indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)
+ indices = tf.expand_dims(tf.range(batch_size), 1)
concated = tf.concat(1, [indices, labels])
onehot_labels = tf.sparse_to_dense(
concated, tf.pack([batch_size, NUM_CLASSES]), 1.0, 0.0)
diff --git a/tensorflow/g3doc/tutorials/mnist/pros/index.md b/tensorflow/g3doc/tutorials/mnist/pros/index.md
index a83d3eabd5..0a7d1aeae0 100644
--- a/tensorflow/g3doc/tutorials/mnist/pros/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/pros/index.md
@@ -39,7 +39,7 @@ Tensorflow relies on a highly efficient C++ backend to do its computation. The
connection to this backend is called a session. The common usage for TensorFlow
programs is to first create a graph and then launch it in a session.
-Here we instead use the convenience `InteractiveSession` class, which
+Here we instead use the convenient `InteractiveSession` class, which
makes TensorFlow more flexible about how you
structure your code.
It allows you to interleave operations which build a
@@ -232,7 +232,7 @@ print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
## Build a Multilayer Convolutional Network <a class="md-anchor" id="AUTOGENERATED-build-a-multilayer-convolutional-network"></a>
Getting 91% accuracy on MNIST is bad. It's almost embarrassingly bad. In this
-section, we'll fix that, jumping from a very simple model to something moderatly
+section, we'll fix that, jumping from a very simple model to something moderately
sophisticated: a small convolutional neural network. This will get us to around
99.2% accuracy -- not state of the art, but respectable.
diff --git a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py b/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
index a9e0f28436..d7acb3960e 100644
--- a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
@@ -9,9 +9,9 @@ import math
import numpy as np
import os
import random
+from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
-import urllib
import zipfile
# Step 1: Download the data.
@@ -20,7 +20,7 @@ url = 'http://mattmahoney.net/dc/'
def maybe_download(filename, expected_bytes):
"""Download a file if not present, and make sure it's the right size."""
if not os.path.exists(filename):
- filename, _ = urllib.urlretrieve(url + filename, filename)
+ filename, _ = urllib.request.urlretrieve(url + filename, filename)
statinfo = os.stat(filename)
if statinfo.st_size == expected_bytes:
print('Found and verified', filename)
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py
index 8fcd790130..627cf01b6a 100644
--- a/tensorflow/models/image/cifar10/cifar10.py
+++ b/tensorflow/models/image/cifar10/cifar10.py
@@ -25,9 +25,9 @@ import os
import re
import sys
import tarfile
-import urllib
import tensorflow.python.platform
+from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
@@ -366,7 +366,7 @@ def loss(logits, labels):
# Reshape the labels into a dense Tensor of
# shape [batch_size, NUM_CLASSES].
sparse_labels = tf.reshape(labels, [FLAGS.batch_size, 1])
- indices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1])
+ indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
concated = tf.concat(1, [indices, sparse_labels])
dense_labels = tf.sparse_to_dense(concated,
[FLAGS.batch_size, NUM_CLASSES],
@@ -478,7 +478,8 @@ def maybe_download_and_extract():
sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
- filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress)
+ filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath,
+ reporthook=_progress)
print()
statinfo = os.stat(filepath)
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py
index e388b772fe..8e9275ab6a 100644
--- a/tensorflow/models/image/mnist/convolutional.py
+++ b/tensorflow/models/image/mnist/convolutional.py
@@ -11,11 +11,11 @@ from __future__ import print_function
import gzip
import os
import sys
-import urllib
import tensorflow.python.platform
import numpy
+from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
@@ -41,7 +41,7 @@ def maybe_download(filename):
os.mkdir(WORK_DIRECTORY)
filepath = os.path.join(WORK_DIRECTORY, filename)
if not os.path.exists(filepath):
- filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
+ filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
statinfo = os.stat(filepath)
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
return filepath
diff --git a/tensorflow/models/rnn/rnn_cell_test.py b/tensorflow/models/rnn/rnn_cell_test.py
index 937e1557bd..447ddfebd4 100644
--- a/tensorflow/models/rnn/rnn_cell_test.py
+++ b/tensorflow/models/rnn/rnn_cell_test.py
@@ -118,7 +118,7 @@ class RNNCellTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
x = tf.zeros([1, 3])
m = tf.zeros([1, 3])
- keep = tf.zeros([1]) + 1
+ keep = tf.zeros([]) + 1
g, new_m = rnn_cell.DropoutWrapper(rnn_cell.GRUCell(3),
keep, keep)(x, m)
sess.run([tf.variables.initialize_all_variables()])
diff --git a/tensorflow/models/rnn/seq2seq.py b/tensorflow/models/rnn/seq2seq.py
index 875bcb5e6e..63e6181a26 100644
--- a/tensorflow/models/rnn/seq2seq.py
+++ b/tensorflow/models/rnn/seq2seq.py
@@ -636,7 +636,7 @@ def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols,
# SparseToDense does not accept batched inputs, we need to do this by
# re-indexing and re-sizing. When TensorFlow adds SparseCrossEntropy,
# rewrite this method.
- indices = targets[i] + num_decoder_symbols * tf.range(0, batch_size)
+ indices = targets[i] + num_decoder_symbols * tf.range(batch_size)
with tf.device("/cpu:0"): # Sparse-to-dense must happen on CPU for now.
dense = tf.sparse_to_dense(indices, tf.expand_dims(length, 0), 1.0,
0.0)
diff --git a/tensorflow/models/rnn/translate/data_utils.py b/tensorflow/models/rnn/translate/data_utils.py
index b9d951ccd7..00f77599af 100644
--- a/tensorflow/models/rnn/translate/data_utils.py
+++ b/tensorflow/models/rnn/translate/data_utils.py
@@ -7,9 +7,9 @@ import gzip
import os
import re
import tarfile
-import urllib
from tensorflow.python.platform import gfile
+from six.moves import urllib
# Special vocabulary symbols - we always put them at the start.
_PAD = "_PAD"
@@ -40,7 +40,7 @@ def maybe_download(directory, filename, url):
filepath = os.path.join(directory, filename)
if not os.path.exists(filepath):
print("Downloading %s to %s" % (url, filepath))
- filepath, _ = urllib.urlretrieve(url, filepath)
+ filepath, _ = urllib.request.urlretrieve(url, filepath)
statinfo = os.stat(filepath)
print("Succesfully downloaded", filename, statinfo.st_size, "bytes")
return filepath
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 2cbdf191c6..1b7e0eb9a9 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -13,8 +13,16 @@ import tensorflow as tf
"""
-import tensorflow.python.platform
-from tensorflow.core.framework.graph_pb2 import *
+try:
+ import tensorflow.python.platform
+ from tensorflow.core.framework.graph_pb2 import *
+except ImportError as e:
+ msg = """Error importing tensorflow: you should not try to import
+ tensorflow from its source directory; please exit the tensorflow source tree,
+ and relaunch your python interpreter from there.
+ Original ImportError: %s""" % str(e)
+ raise ImportError(msg)
+
from tensorflow.core.framework.summary_pb2 import *
from tensorflow.core.framework.config_pb2 import *
from tensorflow.core.util.event_pb2 import *
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 70321e76dc..2801d588e8 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1551,6 +1551,8 @@ class Graph(object):
# True if the graph is considered "finalized". In that case no
# new operations can be added.
self._finalized = False
+ # Functions defined in the graph
+ self._functions = []
def _check_not_finalized(self):
"""Check if the graph is finalized.
@@ -1655,8 +1657,30 @@ class Graph(object):
bytesize += op.node_def.ByteSize()
if bytesize >= (1 << 31) or bytesize < 0:
raise ValueError("GraphDef cannot be larger than 2GB.")
+ if self._functions:
+ for f in self._functions:
+ bytesize += f.ByteSize()
+ if bytesize >= (1 << 31) or bytesize < 0:
+ raise ValueError("GraphDef cannot be larger than 2GB.")
+ graph.library.function.extend(self._functions)
return graph
+ def _add_function(self, function_def):
+ """Adds a function to the graph.
+
+ The function is specified as a [`FunctionDef`]
+ (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
+ protocol buffer.
+
+ After the function has been added, you can call to the function by
+ passing the function name in place of an op name to
+ `Graph.create_op()`.
+
+ Args:
+ function_def: A `FunctionDef` protocol buffer.
+ """
+ self._functions.append(function_def)
+
# Helper functions to create operations.
def create_op(self, op_type, inputs, dtypes,
input_types=None, name=None, attrs=None, op_def=None,
@@ -1869,7 +1893,6 @@ class Graph(object):
A list of Operations.
"""
return list(self._nodes_by_id.values())
-
def get_operation_by_name(self, name):
"""Returns the `Operation` with the given `name`.
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index b495948554..16efddc697 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -75,18 +75,28 @@ class ConcatOpTest(tf.test.TestCase):
# Random dim to concat on
concat_dim = np.random.randint(5)
params = {}
+ if dtype == tf.bfloat16:
+ dtype_feed = tf.float32
+ else:
+ dtype_feed = dtype
with self.test_session(use_gpu=use_gpu):
p = []
for i in np.arange(num_tensors):
input_shape = shape
input_shape[concat_dim] = np.random.randint(1, 5)
- placeholder = tf.placeholder(dtype, shape=input_shape)
+ placeholder = tf.placeholder(dtype_feed, shape=input_shape)
p.append(placeholder)
- t = dtype.as_numpy_dtype
+ t = dtype_feed.as_numpy_dtype
params[placeholder] = np.random.rand(*input_shape).astype(t)
- c = tf.concat(concat_dim, p)
+ if dtype != dtype_feed:
+ concat_inputs = [tf.cast(p_i, dtype) for p_i in p]
+ else:
+ concat_inputs = p
+ c = tf.concat(concat_dim, concat_inputs)
+ if dtype != dtype_feed:
+ c = tf.cast(c, dtype_feed)
result = c.eval(feed_dict=params)
self.assertEqual(result.shape, c.get_shape())
@@ -100,15 +110,17 @@ class ConcatOpTest(tf.test.TestCase):
ind[concat_dim] = slice(cur_offset,
cur_offset + params[p[i]].shape[concat_dim])
cur_offset += params[p[i]].shape[concat_dim]
- self.assertAllEqual(result[ind], params[p[i]])
+ if dtype == dtype_feed:
+ self.assertAllEqual(result[ind], params[p[i]])
+ else:
+ self.assertAllClose(result[ind], params[p[i]], 0.01)
def testRandom(self):
self._testRandom(tf.float32)
self._testRandom(tf.int16)
self._testRandom(tf.int32, use_gpu=True)
- # Note that the following does not work since bfloat16 is not supported in
- # numpy.
- # self._testRandom(tf.bfloat16)
+ self._testRandom(tf.bfloat16)
+ self._testRandom(tf.bfloat16, use_gpu=True)
def _testGradientsSimple(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 03844d6177..5a987c912c 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -262,7 +262,7 @@ class EmbeddingLookupTest(tf.test.TestCase):
self.assertAllEqual(simple, tf.gather(params, ids).eval())
# Run a few random sharded versions
for procs in 1, 2, 3:
- stride = procs * tf.range(0, params.shape[0] // procs)
+ stride = procs * tf.range(params.shape[0] // procs)
split_params = [tf.gather(params, stride + p)
for p in xrange(procs)]
sharded = tf.nn.embedding_lookup(split_params, ids).eval()
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index 1b9f1323e8..8036989a0e 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -190,6 +190,10 @@ class RangeTest(tf.test.TestCase):
self._Range(100, 500, 100), np.array([100, 200, 300, 400])))
self.assertEqual(tf.range(0, 5, 1).dtype, tf.int32)
+ def testLimitOnly(self):
+ with self.test_session():
+ self.assertAllEqual(np.arange(5), tf.range(5).eval())
+
def testEmpty(self):
for start in 0, 5:
self.assertTrue(np.array_equal(self._Range(start, start, 1), []))
diff --git a/tensorflow/python/kernel_tests/lookup_table_op_test.py b/tensorflow/python/kernel_tests/lookup_table_op_test.py
deleted file mode 100644
index 7b5942cacd..0000000000
--- a/tensorflow/python/kernel_tests/lookup_table_op_test.py
+++ /dev/null
@@ -1,218 +0,0 @@
-"""Tests for lookup table ops from tf."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow.python.platform
-
-import numpy as np
-import tensorflow as tf
-
-
-class HashTableOpTest(tf.test.TestCase):
-
- def testHashTable(self):
- with self.test_session():
- shared_name = ''
- default_val = -1
- table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
- # Initialize with keys and values tensors.
- keys = tf.constant(['brain', 'salad', 'surgery'])
- values = tf.constant([0, 1, 2], tf.int64)
- init = table.initialize_from(keys, values)
- init.run()
- self.assertAllEqual(3, table.size().eval())
-
- input_string = tf.constant(['brain', 'salad', 'tank'])
- output = table.lookup(input_string)
-
- result = output.eval()
- self.assertAllEqual([0, 1, -1], result)
-
- def testHashTableFindHighRank(self):
- with self.test_session():
- shared_name = ''
- default_val = -1
- table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
- # Initialize with keys and values tensors.
- keys = tf.constant(['brain', 'salad', 'surgery'])
- values = tf.constant([0, 1, 2], tf.int64)
- init = table.initialize_from(keys, values)
- init.run()
- self.assertAllEqual(3, table.size().eval())
-
- input_string = tf.constant([['brain', 'salad'], ['tank', 'tarkus']])
- output = table.lookup(input_string)
-
- result = output.eval()
- self.assertAllEqual([[0, 1], [-1, -1]], result)
-
- def testHashTableInitWithPythonArrays(self):
- with self.test_session():
- shared_name = ''
- default_val = -1
- table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
- # Empty table.
- self.assertAllEqual(0, table.size().eval())
-
- # Initialize with keys and values tensors.
- keys = ['brain', 'salad', 'surgery']
- values = [0, 1, 2]
- init = table.initialize_from(keys, values)
- init.run()
- self.assertAllEqual(3, table.size().eval())
-
- input_string = tf.constant(['brain', 'salad', 'tank'])
- output = table.lookup(input_string)
-
- result = output.eval()
- self.assertAllEqual([0, 1, -1], result)
-
- def testHashTableInitWithNumPyArrays(self):
- with self.test_session():
- shared_name = ''
- default_val = -1
- table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
- # Initialize with keys and values tensors.
- keys = np.array(['brain', 'salad', 'surgery'], dtype=np.str)
- values = np.array([0, 1, 2], dtype=np.int64)
- init = table.initialize_from(keys, values)
- init.run()
- self.assertAllEqual(3, table.size().eval())
-
- input_string = tf.constant(['brain', 'salad', 'tank'])
- output = table.lookup(input_string)
-
- result = output.eval()
- self.assertAllEqual([0, 1, -1], result)
-
- def testMultipleHashTables(self):
- with self.test_session() as sess:
- shared_name = ''
- default_val = -1
- table1 = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
- table2 = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
- table3 = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
- keys = tf.constant(['brain', 'salad', 'surgery'])
- values = tf.constant([0, 1, 2], tf.int64)
- table1.initialize_from(keys, values)
- table2.initialize_from(keys, values)
- table3.initialize_from(keys, values)
-
- tf.initialize_all_tables().run()
- self.assertAllEqual(3, table1.size().eval())
- self.assertAllEqual(3, table2.size().eval())
- self.assertAllEqual(3, table3.size().eval())
-
- input_string = tf.constant(['brain', 'salad', 'tank'])
- output1 = table1.lookup(input_string)
- output2 = table2.lookup(input_string)
- output3 = table3.lookup(input_string)
-
- out1, out2, out3 = sess.run([output1, output2, output3])
- self.assertAllEqual([0, 1, -1], out1)
- self.assertAllEqual([0, 1, -1], out2)
- self.assertAllEqual([0, 1, -1], out3)
-
- def testHashTableWithTensorDefault(self):
- with self.test_session():
- shared_name = ''
- default_val = tf.constant(-1, tf.int64)
- table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
- # Initialize with keys and values tensors.
- keys = tf.constant(['brain', 'salad', 'surgery'])
- values = tf.constant([0, 1, 2], tf.int64)
- init = table.initialize_from(keys, values)
- init.run()
-
- input_string = tf.constant(['brain', 'salad', 'tank'])
- output = table.lookup(input_string)
-
- result = output.eval()
- self.assertAllEqual([0, 1, -1], result)
-
- def testSignatureMismatch(self):
- with self.test_session():
- shared_name = ''
- default_val = -1
- table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
- # Initialize with keys and values tensors.
- keys = tf.constant(['brain', 'salad', 'surgery'])
- values = tf.constant([0, 1, 2], tf.int64)
- init = table.initialize_from(keys, values)
- init.run()
-
- input_string = tf.constant([1, 2, 3], tf.int64)
- with self.assertRaises(TypeError):
- table.lookup(input_string)
-
- with self.assertRaises(TypeError):
- tf.HashTable(tf.string, tf.int64, 'UNK', shared_name)
-
- def testDTypes(self):
- with self.test_session():
- shared_name = ''
- default_val = -1
- with self.assertRaises(TypeError):
- tf.HashTable([tf.string], tf.string, default_val, shared_name)
-
- def testNotInitialized(self):
- with self.test_session():
- shared_name = ''
- default_val = -1
- table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
- input_string = tf.constant(['brain', 'salad', 'surgery'])
- output = table.lookup(input_string)
-
- with self.assertRaisesOpError('Table not initialized'):
- output.eval()
-
- def testInitializeTwice(self):
- with self.test_session():
- shared_name = ''
- default_val = -1
- table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
- # Initialize with keys and values tensors.
- keys = tf.constant(['brain', 'salad', 'surgery'])
- values = tf.constant([0, 1, 2], tf.int64)
- init = table.initialize_from(keys, values)
- init.run()
-
- with self.assertRaisesOpError('Table already initialized'):
- init.run()
-
- def testInitializationWithInvalidDimensions(self):
- with self.test_session():
- shared_name = ''
- default_val = -1
- table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
- # Initialize with keys and values tensors.
- keys = tf.constant(['brain', 'salad', 'surgery'])
- values = tf.constant([0, 1, 2, 3, 4], tf.int64)
- with self.assertRaises(ValueError):
- table.initialize_from(keys, values)
-
- def testInitializationWithInvalidDataTypes(self):
- with self.test_session():
- shared_name = ''
- default_val = -1
- table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
- # Initialize with keys and values tensors.
- keys = [0, 1, 2]
- values = ['brain', 'salad', 'surgery']
- with self.assertRaises(TypeError):
- table.initialize_from(keys, values)
-
-
-if __name__ == '__main__':
- tf.test.main()
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index 3dedd33cb9..9893c9b824 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -74,7 +74,7 @@ def clip_by_norm(t, clip_norm, name=None):
# Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
l2norm_inv = math_ops.rsqrt(
- math_ops.reduce_sum(t * t, math_ops.range(0, array_ops.rank(t))))
+ math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t))))
tclip = array_ops.identity(t * clip_norm * math_ops.minimum(
l2norm_inv, constant_op.constant(1.0 / clip_norm)), name=name)
@@ -228,7 +228,7 @@ def clip_by_average_norm(t, clip_norm, name=None):
# L2-norm per element
n_element = math_ops.cast(array_ops.size(t), types.float32)
l2norm_inv = math_ops.rsqrt(
- math_ops.reduce_sum(t * t, math_ops.range(0, array_ops.rank(t))))
+ math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t))))
tclip = array_ops.identity(
t * clip_norm * math_ops.minimum(
l2norm_inv * n_element, constant_op.constant(1.0 / clip_norm)),
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 178f716e48..ed09ff3655 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -10,7 +10,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import control_flow_ops
@@ -408,174 +407,12 @@ class FIFOQueue(QueueBase):
# TODO(josh11b): class BatchQueue(QueueBase):
-# pylint: disable=protected-access
-class LookupTableBase(object):
- """Represents a lookup table that persists across different steps."""
-
- def __init__(self, key_dtype, value_dtype, default_value, table_ref):
- """Construct a table object from a table reference.
-
- Args:
- key_dtype: The table key type.
- value_dtype: The table value type.
- default_value: The value to use if a key is missing in the table.
- table_ref: The table reference, i.e. the output of the lookup table ops.
- """
- self._key_dtype = types.as_dtype(key_dtype)
- self._value_dtype = types.as_dtype(value_dtype)
- self._shapes = [tensor_shape.TensorShape([1])]
- self._table_ref = table_ref
- self._name = self._table_ref.op.name.split("/")[-1]
- self._default_value = ops.convert_to_tensor(default_value,
- dtype=self._value_dtype)
- self._default_value.get_shape().merge_with(tensor_shape.scalar())
-
- @property
- def table_ref(self):
- """Get the underlying table reference."""
- return self._table_ref
-
- @property
- def key_dtype(self):
- """The table key dtype."""
- return self._key_dtype
-
- @property
- def value_dtype(self):
- """The table value dtype."""
- return self._value_dtype
-
- @property
- def name(self):
- """The name of the table."""
- return self._name
-
- @property
- def default_value(self):
- """The default value of the table."""
- return self._default_value
-
- def size(self, name=None):
- """Compute the number of elements in this table.
-
- Args:
- name: A name for the operation (optional).
-
- Returns:
- A scalar tensor containing the number of elements in this table.
- """
- if name is None:
- name = "%s_Size" % self._name
- return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
-
- def lookup(self, keys, name=None):
- """Looks up `keys` in a table, outputs the corresponding values.
-
- The `default_value` is use for keys not present in the table.
-
- Args:
- keys: Keys to look up.
- name: Optional name for the op.
-
- Returns:
- The operation that looks up the keys.
-
- Raises:
- TypeError: when `keys` or `default_value` doesn't match the table data
- types.
- """
- if name is None:
- name = "%s_lookup_table_find" % self._name
-
- if keys.dtype != self._key_dtype:
- raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (
- self._key_dtype, keys.dtype))
-
- return gen_data_flow_ops._lookup_table_find(
- self._table_ref, keys, self._default_value, name=name)
-
- def initialize_from(self, keys, values, name=None):
- """Initialize the table with the provided keys and values tensors.
-
- Construct an initializer object from keys and value tensors.
-
- Args:
- keys: The tensor for the keys.
- values: The tensor for the values.
- name: Optional name for the op.
-
- Returns:
- The operation that initializes the table.
-
- Raises:
- TypeError: when the keys and values data types do not match the table
- key and value data types.
- """
- if name is None:
- name = "%s_initialize_table" % self.name
- with ops.op_scope([keys, values], None, name):
- keys = ops.convert_to_tensor(keys, dtype=self.key_dtype, name="keys")
- values = ops.convert_to_tensor(values, dtype=self.value_dtype,
- name="values")
-
- init_op = gen_data_flow_ops._initialize_table(
- self.table_ref, keys, values, name=name)
- ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
- return init_op
-
- def _check_table_dtypes(self, key_dtype, value_dtype):
- """Check that the given key_dtype and value_dtype matches the table dtypes'.
-
- Args:
- key_dtype: The key data type to check.
- value_dtype: The value data type to check.
-
- Raises:
- TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
- types.
- """
- if key_dtype != self.key_dtype:
- raise TypeError("Invalid key dtype, expected %s but got %s." % (
- self.key_dtype, key_dtype))
- if value_dtype != self.value_dtype:
- raise TypeError("Invalid value dtype, expected %s but got %s." % (
- self.value_dtype, value_dtype))
-
-
-class HashTable(LookupTableBase):
- """A generic hash table implementation."""
-
- def __init__(self, key_dtype, value_dtype, default_value, shared_name=None,
- name="hash_table"):
- """Creates a non-initialized hash table.
-
- This op creates a hash table, specifying the type of its keys and values.
- Before using the table you will have to initialize it. After initialization
- the table will be immutable.
-
- Args:
- key_dtype: Type of the table keys.
- value_dtype: Type of the table values.
- default_value: The scalar tensor to be used when a key is missing in the
- table.
- shared_name: Optional. If non-empty, this table will be shared under
- the given name across multiple sessions.
- name: Optional name for the hash table op.
-
- Returns:
- A `HashTable` object.
- """
- table_ref = gen_data_flow_ops._hash_table(
- shared_name=shared_name, key_dtype=key_dtype,
- value_dtype=value_dtype, name=name)
-
- super(HashTable, self).__init__(key_dtype, value_dtype, default_value,
- table_ref)
-
-
def initialize_all_tables(name="init_all_tables"):
"""Returns an Op that initializes all tables of the default graph.
+ Args:
+ name: Optional name for the initialization op.
+
Returns:
An Op that initializes all tables. Note that if there are
not tables the returned Op is a NoOp.
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 80bedd4984..b74f8f5426 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -51,7 +51,7 @@ def embedding_lookup(params, ids, name=None):
else:
ids = ops.convert_to_tensor(ids, name="ids")
flat_ids = array_ops.reshape(ids, [-1])
- original_indices = math_ops.range(0, array_ops.size(flat_ids))
+ original_indices = math_ops.range(array_ops.size(flat_ids))
# Compute flat_ids % partitions for each id
ids_mod_p = flat_ids % np
if ids_mod_p.dtype != types.int32:
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index b404fbc7d7..b8b0741e07 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -21,7 +21,7 @@ def _ReductionGradAssist(op):
indices = op.inputs[1] # [1, 2]
indices_shape = array_ops.shape(indices) # [2]
new_output_shape = data_flow_ops.dynamic_stitch( # [2, 1, 1, 7]
- [math_ops.range(0, input_rank), # [0, 1, 2, 3]
+ [math_ops.range(input_rank), # [0, 1, 2, 3]
indices], # [1, 2]
[input_shape, # [2, 3, 5, 7]
array_ops.fill(indices_shape, 1)]) # [1, 1]
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index f7289ff234..4a2473cae5 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -536,11 +536,14 @@ ops.Tensor._override_operator("__gt__", greater)
ops.Tensor._override_operator("__ge__", greater_equal)
-def range(start, limit, delta=1, name="range"):
+def range(start, limit=None, delta=1, name="range"):
"""Creates a sequence of integers.
- This operation creates a sequence of integers that begins at `start` and
- extends by increments of `delta` up to but not including `limit`.
+ Creates a sequence of integers that begins at `start` and extends by
+ increments of `delta` up to but not including `limit`.
+
+ Like the Python builtin `range`, `start` defaults to 0, so that
+ `range(n) = range(0, n)`.
For example:
@@ -549,10 +552,14 @@ def range(start, limit, delta=1, name="range"):
# 'limit' is 18
# 'delta' is 3
tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+
+ # 'limit' is 5
+ tf.range(limit) ==> [0, 1, 2, 3, 4]
```
Args:
start: A 0-D (scalar) of type `int32`. First entry in sequence.
+ Defaults to 0.
limit: A 0-D (scalar) of type `int32`. Upper limit of sequence,
exclusive.
delta: A 0-D `Tensor` (scalar) of type `int32`. Optional. Default is 1.
@@ -562,6 +569,8 @@ def range(start, limit, delta=1, name="range"):
Returns:
An 1-D `int32` `Tensor`.
"""
+ if limit is None:
+ start, limit = 0, start
return gen_math_ops._range(start, limit, delta, name=name)
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index caf47b1431..5a5c06f975 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -173,6 +173,7 @@ from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import candidate_sampling_ops
@@ -347,7 +348,8 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
Args:
x: A tensor.
- keep_prob: A Python float. The probability that each element is kept.
+ keep_prob: A scalar `Tensor` with the same type as x. The probability
+ that each element is kept.
noise_shape: A 1-D `Tensor` of type `int32`, representing the
shape for randomly generated keep/drop flags.
seed: A Python integer. Used to create random seeds. See
@@ -361,10 +363,15 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
Raises:
ValueError: If `keep_prob` is not in `(0, 1]`.
"""
- if not (0 < keep_prob <= 1):
- raise ValueError("Expected keep_prob in (0, 1], got %g" % keep_prob)
with ops.op_scope([x], name, "dropout") as name:
x = ops.convert_to_tensor(x, name="x")
+ if isinstance(keep_prob, float) and not(0 < keep_prob <= 1):
+ raise ValueError("keep_prob must be a scalar tensor or a float in the "
+ "range (0, 1], got %g" % keep_prob)
+ keep_prob = ops.convert_to_tensor(
+ keep_prob, dtype=x.dtype, name="keep_prob")
+ keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
+
noise_shape = noise_shape or array_ops.shape(x)
# uniform [keep_prob, 1.0 + keep_prob)
random_tensor = keep_prob
@@ -372,7 +379,9 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
noise_shape, seed=seed, dtype=x.dtype)
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
binary_tensor = math_ops.floor(random_tensor)
- return x * (1.0 / keep_prob) * binary_tensor
+ ret = x * math_ops.inv(keep_prob) * binary_tensor
+ ret.set_shape(x.get_shape())
+ return ret
def depthwise_conv2d(input, filter, strides, padding, name=None):
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 535d55f00f..919ce78491 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -85,7 +85,7 @@ def _BiasAddGrad(unused_bias_op, received_grad):
Two tensors, the first one for the "tensor" input of the BiasOp,
the second one for the "bias" input of the BiasOp.
"""
- reduction_dim_tensor = math_ops.range(0, array_ops.rank(received_grad) - 1)
+ reduction_dim_tensor = math_ops.range(array_ops.rank(received_grad) - 1)
return (received_grad, math_ops.reduce_sum(received_grad, reduction_dim_tensor))
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 48f7a4c987..c24bd0e372 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -405,28 +405,82 @@ class DropoutTest(test_util.TensorFlowTestCase):
sorted_value = np.unique(np.sort(value[i, :]))
self.assertEqual(sorted_value.size, 1)
+ def testDropoutPlaceholderKeepProb(self):
+ # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate
+ # that it is producing approximately the right number of ones over a large
+ # number of samples, based on the keep probability.
+ x_dim = 40
+ y_dim = 30
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ with self.test_session():
+ t = constant_op.constant(1.0,
+ shape=[x_dim, y_dim],
+ dtype=types.float32)
+ keep_prob_placeholder = array_ops.placeholder(types.float32)
+ dropout = nn.dropout(t, keep_prob_placeholder)
+ final_count = 0
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ for _ in xrange(0, num_iter):
+ value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob})
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+ # Check that we are in the 15% error range
+ expected_count = x_dim * y_dim * keep_prob * num_iter
+ rel_error = math.fabs(final_count - expected_count) / expected_count
+ print(rel_error)
+ self.assertTrue(rel_error < 0.15)
+
+ def testShapedDropoutUnknownShape(self):
+ x_dim = 40
+ y_dim = 30
+ keep_prob = 0.5
+ x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=types.float32)
+ dropout_x = nn.dropout(
+ x, keep_prob, noise_shape=array_ops.placeholder(types.int32))
+ self.assertEqual(x.get_shape(), dropout_x.get_shape())
+
+ def testInvalidKeepProb(self):
+ x_dim = 40
+ y_dim = 30
+ t = constant_op.constant(1.0,
+ shape=[x_dim, y_dim],
+ dtype=types.float32)
+ with self.assertRaises(ValueError):
+ nn.dropout(t, -1.0)
+ with self.assertRaises(ValueError):
+ nn.dropout(t, 1.1)
+ with self.assertRaises(ValueError):
+ nn.dropout(t, [0.0, 1.0])
+ with self.assertRaises(ValueError):
+ nn.dropout(t, array_ops.placeholder(types.float64))
+ with self.assertRaises(ValueError):
+ nn.dropout(t, array_ops.placeholder(types.float32, shape=[2]))
+
def testShapedDropoutShapeError(self):
# Runs shaped dropout and verifies an error is thrown on misshapen noise.
x_dim = 40
y_dim = 30
keep_prob = 0.5
- with self.test_session():
- t = constant_op.constant(1.0,
- shape=[x_dim, y_dim],
- dtype=types.float32)
- with self.assertRaises(ValueError):
- _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
- with self.assertRaises(ValueError):
- _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
- with self.assertRaises(ValueError):
- _ = nn.dropout(t, keep_prob, noise_shape=[x_dim + 3])
- with self.assertRaises(ValueError):
- _ = nn.dropout(t, keep_prob, noise_shape=[x_dim])
- # test that broadcasting proceeds
- _ = nn.dropout(t, keep_prob, noise_shape=[y_dim])
- _ = nn.dropout(t, keep_prob, noise_shape=[1, y_dim])
- _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
- _ = nn.dropout(t, keep_prob, noise_shape=[1, 1])
+ t = constant_op.constant(1.0,
+ shape=[x_dim, y_dim],
+ dtype=types.float32)
+ with self.assertRaises(ValueError):
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
+ with self.assertRaises(ValueError):
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
+ with self.assertRaises(ValueError):
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim + 3])
+ with self.assertRaises(ValueError):
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim])
+ # test that broadcasting proceeds
+ _ = nn.dropout(t, keep_prob, noise_shape=[y_dim])
+ _ = nn.dropout(t, keep_prob, noise_shape=[1, y_dim])
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+ _ = nn.dropout(t, keep_prob, noise_shape=[1, 1])
class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 1a7af78a33..23d3a974cf 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -437,8 +437,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
default_value, dtype=sp_input.values.dtype)
num_rows = math_ops.cast(sp_input.shape[0], types.int32)
- all_row_indices = math_ops.cast(
- math_ops.range(0, num_rows, 1), types.int64)
+ all_row_indices = math_ops.cast(math_ops.range(num_rows), types.int64)
empty_row_indices, _ = array_ops.list_diff(
all_row_indices, sp_input.indices[:, 0])
empty_row_indicator = gen_sparse_ops.sparse_to_dense(
diff --git a/tensorflow/python/platform/default/_logging.py b/tensorflow/python/platform/default/_logging.py
index 66bf2c0889..23318691ee 100644
--- a/tensorflow/python/platform/default/_logging.py
+++ b/tensorflow/python/platform/default/_logging.py
@@ -6,18 +6,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import logging
import os
import sys
import time
import thread
-from logging import getLogger
-from logging import log
-from logging import debug
-from logging import error
-from logging import fatal
-from logging import info
-from logging import warn
-from logging import warning
from logging import DEBUG
from logging import ERROR
from logging import FATAL
@@ -25,13 +18,25 @@ from logging import INFO
from logging import WARN
# Controls which methods from pyglib.logging are available within the project
-# Do not add methods here without also adding to platform/default/_logging.py
+# Do not add methods here without also adding to platform/google/_logging.py
__all__ = ['log', 'debug', 'error', 'fatal', 'info', 'warn', 'warning',
'DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN',
'flush', 'log_every_n', 'log_first_n', 'vlog',
'TaskLevelStatusMessage', 'get_verbosity', 'set_verbosity']
-warning = warn
+# Scope the tensorflow logger to not conflict with users' loggers
+_logger = logging.getLogger('tensorflow')
+_handler = logging.StreamHandler()
+_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT, None))
+_logger.addHandler(_handler)
+
+log = _logger.log
+debug = _logger.debug
+error = _logger.error
+fatal = _logger.fatal
+info = _logger.info
+warn = _logger.warn
+warning = _logger.warn
_level_names = {
FATAL: 'FATAL',
@@ -61,7 +66,7 @@ def flush():
# Code below is taken from pyglib/logging
def vlog(level, msg, *args, **kwargs):
- log(level, msg, *args, **kwargs)
+ _logger.log(level, msg, *args, **kwargs)
def _GetNextLogCountPerToken(token):
@@ -169,12 +174,12 @@ def google2_log_prefix(level, timestamp=None, file_and_line=None):
def get_verbosity():
"""Return how much logging output will be produced."""
- return getLogger().getEffectiveLevel()
+ return _logger.getEffectiveLevel()
def set_verbosity(verbosity):
"""Sets the threshold for what messages will be logged."""
- getLogger().setLevel(verbosity)
+ _logger.setLevel(verbosity)
def _get_thread_id():
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 6734690397..77c496fb85 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -163,7 +163,7 @@ def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
is added to the current Graph's QUEUE_RUNNER collection.
"""
with ops.op_scope([limit], name, "input_producer") as name:
- range_tensor = math_ops.range(0, limit)
+ range_tensor = math_ops.range(limit)
return _input_producer(
range_tensor, types.int32, num_epochs, shuffle, seed, capacity, name,
"fraction_of_%d_full" % capacity)
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index 8450fae5cb..c1f886d664 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -33,9 +33,9 @@ def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
...
global_step = tf.Variable(0, trainable=False)
starter_learning_rate = 0.1
- learning_rate = tf.exponential_decay(starter_learning_rate, global_step,
- 100000, 0.96, staircase=True)
- optimizer = tf.GradientDescent(learning_rate)
+ learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
+ 100000, 0.96, staircase=True)
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# Passing global_step to minimize() will increment it at each step.
optimizer.minimize(...my loss..., global_step=global_step)
```
diff --git a/tensorflow/tensorboard/bower.json b/tensorflow/tensorboard/bower.json
index bdd16d662a..995ba30363 100644
--- a/tensorflow/tensorboard/bower.json
+++ b/tensorflow/tensorboard/bower.json
@@ -19,26 +19,27 @@
"dagre": "~0.7.4",
"es6-promise": "~3.0.2",
"graphlib": "~1.0.7",
- "iron-ajax": "PolymerElements/iron-ajax#~1.0.8",
- "iron-collapse": "PolymerElements/iron-collapse#~1.0.4",
- "iron-list": "PolymerElements/iron-list#~1.1.5",
- "paper-button": "PolymerElements/paper-button#~1.0.7",
- "paper-checkbox": "PolymerElements/paper-checkbox#~1.0.6",
- "paper-dropdown-menu": "PolymerElements/paper-dropdown-menu#~1.0.4",
- "paper-header-panel": "PolymerElements/paper-header-panel#~1.0.5",
- "paper-icon-button": "PolymerElements/paper-icon-button#~1.0.3",
- "paper-input": "PolymerElements/paper-input#~1.0.15",
- "paper-item": "PolymerElements/paper-item#~1.0.3",
- "paper-menu": "PolymerElements/paper-menu#~1.1.1",
- "paper-progress": "PolymerElements/paper-progress#~1.0.7",
- "paper-radio-button": "PolymerElements/paper-radio-button#~1.0.8",
- "paper-radio-group": "PolymerElements/paper-radio-group#~1.0.4",
- "paper-slider": "PolymerElements/paper-slider#~1.0.4",
- "paper-styles": "PolymerElements/paper-styles#~1.0.11",
- "paper-toggle-button": "PolymerElements/paper-toggle-button#~1.0.6",
- "paper-toolbar": "PolymerElements/paper-toolbar#~1.0.4",
+ "iron-ajax": "PolymerElements/iron-ajax#1.0.7",
+ "iron-collapse": "PolymerElements/iron-collapse#1.0.4",
+ "iron-list": "PolymerElements/iron-list#1.1.5",
+ "iron-selector": "PolymerElements/iron-selector#1.0.7",
+ "paper-button": "PolymerElements/paper-button#1.0.8",
+ "paper-checkbox": "PolymerElements/paper-checkbox#1.0.13",
+ "paper-dropdown-menu": "PolymerElements/paper-dropdown-menu#1.0.5",
+ "paper-header-panel": "PolymerElements/paper-header-panel#1.0.5",
+ "paper-icon-button": "PolymerElements/paper-icon-button#1.0.5",
+ "paper-input": "PolymerElements/paper-input#1.0.16",
+ "paper-item": "PolymerElements/paper-item#1.0.5",
+ "paper-menu": "PolymerElements/paper-menu#1.1.1",
+ "paper-progress": "PolymerElements/paper-progress#1.0.7",
+ "paper-radio-button": "PolymerElements/paper-radio-button#1.0.10",
+ "paper-radio-group": "PolymerElements/paper-radio-group#1.0.6",
+ "paper-slider": "PolymerElements/paper-slider#1.0.7",
+ "paper-styles": "PolymerElements/paper-styles#1.0.12",
+ "paper-toggle-button": "PolymerElements/paper-toggle-button#1.0.11",
+ "paper-toolbar": "PolymerElements/paper-toolbar#1.0.4",
"plottable": "~1.16.1",
- "polymer": "~1.2.0"
+ "polymer": "1.1.5"
},
"devDependencies": {
"iron-component-page": "PolymerElements/iron-component-page#^1.0.0",
diff --git a/tensorflow/tensorboard/tensorboard_handler.py b/tensorflow/tensorboard/tensorboard_handler.py
index 0ea7f3a58d..2cec3b8812 100644
--- a/tensorflow/tensorboard/tensorboard_handler.py
+++ b/tensorflow/tensorboard/tensorboard_handler.py
@@ -17,9 +17,9 @@ import json
import mimetypes
import os
import StringIO
-import urllib
import urlparse
+from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
from google.protobuf import text_format
import tensorflow.python.platform
@@ -289,7 +289,7 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
A string representation of a URL that will load the index-th
sampled image in the given run with the given tag.
"""
- query_string = urllib.urlencode({
+ query_string = urllib.parse.urlencode({
'run': run,
'tag': tag,
'index': index