diff options
author | Shanqing Cai <cais@google.com> | 2017-12-24 21:42:52 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-12-24 21:42:52 -0500 |
commit | aece567081c6fed35bba14745131a67ca824a4a6 (patch) | |
tree | f50fa4c4a07dd360647e2326c0ebef790b7d0f3e | |
parent | 26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca (diff) | |
parent | 4a7385f88194ae64c2cad7e2a920b66dd7edc982 (diff) |
Merge pull request #15615 from caisq/branch_180008567
Branch 180008567
109 files changed, 1322 insertions, 615 deletions
diff --git a/configure.py b/configure.py index 1917af4b65..e4218b5651 100644 --- a/configure.py +++ b/configure.py @@ -265,19 +265,6 @@ def reset_tf_configure_bazelrc(): f.write('import %workspace%/.tf_configure.bazelrc\n') -def run_gen_git_source(environ_cp): - """Run the gen_git_source to create links. - - The links are for bazel to track dependencies for git hash propagation. - - Args: - environ_cp: copy of the os.environ. - """ - cmd = '"%s" tensorflow/tools/git/gen_git_source.py --configure %s' % ( - environ_cp.get('PYTHON_BIN_PATH'), os.getcwd()) - os.system(cmd) - - def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. @@ -1251,7 +1238,6 @@ def main(): reset_tf_configure_bazelrc() cleanup_makefile() setup_python(environ_cp) - run_gen_git_source(environ_cp) if is_windows(): environ_cp['TF_NEED_S3'] = '0' diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index beffa191d1..589afb9031 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -518,6 +518,15 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } return; } + std::unique_ptr<tensorflow::NodeExecStats> maybe_stats; + if (ctx->should_store_metadata.load()) { + maybe_stats.reset(new tensorflow::NodeExecStats); + maybe_stats->set_node_name(op->name); + maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros()); + maybe_stats->set_op_start_rel_micros(0); + maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros()); + // TODO(apassos) track referenced tensors + } // WARNING: kernel->Run utilizes the FunctionLibraryRuntime // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation @@ -525,12 +534,38 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. // This is quite subtle. Re-work things to make this better? (Would it make // sense for FunctionLibraryRuntime to ensure thread-safe access to - // FunctionLibraryDefinition?). - status->status = kernel->Run(&op->inputs, &outputs); + // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats + // for ops which are a part of functions. + status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get()); for (auto* t : copied_tensors) { TFE_DeleteTensorHandle(t); } if (!status->status.ok()) return; + if (maybe_stats != nullptr) { + maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - + maybe_stats->all_start_micros()); + tensorflow::mutex_lock ml(ctx->metadata_mu); + if (ctx->should_store_metadata.load()) { + auto* step_stats = ctx->run_metadata.mutable_step_stats(); + // Lazily initialize the RunMetadata with information about all devices if + // this is the first call. + while (step_stats->dev_stats_size() < ctx->devices().size()) { + step_stats->add_dev_stats(); + } + // Find the current device's index. + int device_idx = 0; + for (int i = 0; i < ctx->devices().size(); ++i) { + if (ctx->devices()[i] == device) { + device_idx = i; + break; + } + } + // Populate the device stats for this device. + auto* dev_stats = step_stats->mutable_dev_stats(device_idx); + dev_stats->set_device(device->name()); + *dev_stats->add_node_stats() = *maybe_stats; + } + } *num_retvals = std::min<int>(*num_retvals, outputs.size()); for (int i = 0; i < *num_retvals; ++i) { tensorflow::Device* d = IsCPU(device) ? nullptr : device; @@ -577,3 +612,20 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( } return &h->t; } + +void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { + ctx->should_store_metadata.store(true); +} + +void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { + tensorflow::mutex_lock ml(ctx->metadata_mu); + ctx->should_store_metadata.store(false); + ctx->run_metadata.Clear(); +} + +void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, + TF_Status* status) { + tensorflow::mutex_lock ml(ctx->metadata_mu); + status->status = MessageToBuffer(ctx->run_metadata, buf); + ctx->run_metadata.Clear(); +} diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index ca105962df..7caab43d00 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -207,6 +207,19 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status); +// Enables tracing of RunMetadata on the ops executed from this context. +TF_CAPI_EXPORT extern void TFE_ContextEnableRunMetadata(TFE_Context* ctx); + +// Disables tracing of RunMetadata on the ops executed from this context. +TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx); + +// Populates the passed-in buffer with a serialized RunMetadata protocol buffer +// containing any run metadata information accumulated so far and clears this +// information. +TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, + TF_Buffer* buf, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 0971e2ab2f..11b7a519c3 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -67,6 +67,11 @@ struct TFE_Context { } const std::vector<tensorflow::Device*>& devices() { return session->devices; } + + // Whether we should compute RunMetadata. + std::atomic<bool> should_store_metadata{false}; + tensorflow::mutex metadata_mu; + tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu); }; struct TFE_TensorHandle { diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index c5ec0cfc31..423a7e1ff7 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/protobuf/config.pb.h" using tensorflow::string; @@ -353,6 +354,47 @@ TEST(CAPI, Execute) { TF_DeleteStatus(status); } +TEST(CAPI, ExecuteWithTracing) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + TFE_ContextEnableRunMetadata(ctx); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + TF_Buffer* b = TF_NewBuffer(); + TFE_ContextExportRunMetadata(ctx, b, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + tensorflow::RunMetadata rm; + EXPECT_TRUE( + rm.ParseFromString({reinterpret_cast<const char*>(b->data), b->length})); + TF_DeleteBuffer(b); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + TF_DeleteStatus(status); +} + TEST(CAPI, Function) { // First create a simple identity function. TF_Graph* function_graph = TF_NewGraph(); diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc index 38066682a9..ec34b0ea77 100644 --- a/tensorflow/c/eager/runtime.cc +++ b/tensorflow/c/eager/runtime.cc @@ -262,7 +262,8 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, } Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, - std::vector<Tensor>* output_tensors) { + std::vector<Tensor>* output_tensors, + NodeExecStats* stats) { gtl::InlinedVector<TensorValue, 4> inputs; for (Tensor& t : *input_tensors) { inputs.push_back(TensorValue(&t)); @@ -284,6 +285,9 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, params.function_library = flib_; params.slice_reader_cache = &slice_reader_cache_; params.rendezvous = rendez_; + if (stats != nullptr) { + params.track_allocations = true; + } // TODO(apassos): use a thread pool. std::function<void(std::function<void()>)> runner = [](std::function<void()> f) { f(); }; @@ -297,6 +301,34 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, for (int i = 0; i < context.num_outputs(); ++i) { output_tensors->push_back(Tensor(*context.mutable_output(i))); } + if (stats != nullptr) { + for (const auto& allocator_pair : context.wrapped_allocators()) { + AllocatorMemoryUsed* memory = stats->add_memory(); + memory->set_allocator_name(allocator_pair.first->Name()); + auto sizes = allocator_pair.second->GetSizes(); + memory->set_total_bytes(std::get<0>(sizes)); + memory->set_peak_bytes(std::get<1>(sizes)); + memory->set_live_bytes(std::get<2>(sizes)); + + AllocatorStats allocator_stats; + allocator_pair.first->GetStats(&allocator_stats); + memory->set_allocator_bytes_in_use(allocator_stats.bytes_in_use); + allocator_pair.second->GetRecordsAndUnRef(); + } + auto* ms = stats->mutable_memory_stats(); + ms->set_host_temp_memory_size(context.host_temp_memory_size()); + ms->set_device_temp_memory_size(context.device_temp_memory_size()); + for (const auto& alloc_id : context.host_persistent_alloc_ids()) { + ms->mutable_host_persistent_tensor_alloc_ids()->Add(alloc_id); + } + for (const auto& alloc_id : context.device_persistent_alloc_ids()) { + ms->mutable_device_persistent_tensor_alloc_ids()->Add(alloc_id); + } + ms->set_host_persistent_memory_size( + context.host_persistent_memory_allocated()); + ms->set_device_persistent_memory_size( + context.device_persistent_memory_allocated()); + } return Status::OK(); } diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index fb97e94a94..e28a416e67 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -175,7 +175,8 @@ class KernelAndDevice { : device_(nullptr), flib_(nullptr), rendez_(rendez) {} // TODO(ashankar): Handle list-valued inputs. - Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs); + Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs, + NodeExecStats* stats); const OpKernel* kernel() const { return kernel_.get(); } diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc index 3236c6be0e..2ccca66f67 100644 --- a/tensorflow/c/eager/runtime_test.cc +++ b/tensorflow/c/eager/runtime_test.cc @@ -96,7 +96,7 @@ TEST(KernelAndDevice, Run) { KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel); ASSERT_TRUE(s.ok()) << s; std::vector<Tensor> outputs; - s = kernel.Run(&inputs, &outputs); + s = kernel.Run(&inputs, &outputs, nullptr); ASSERT_TRUE(s.ok()) << s; ASSERT_EQ(1, outputs.size()); const Tensor& out = outputs[0]; @@ -183,7 +183,7 @@ void BM_KernelAndDeviceRun(int iters) { KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel)); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(kernel.Run(&inputs, &outputs)); + TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr)); } } BENCHMARK(BM_KernelAndDeviceRun); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 2374620f58..13343967c4 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -215,7 +215,6 @@ cc_library( ":common", ":compilation_passes", "//tensorflow/compiler/jit/kernels:xla_launch_op", - "//tensorflow/compiler/tf2xla:const_analysis", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -245,7 +244,6 @@ cc_library( "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", - "//tensorflow/compiler/tf2xla:const_analysis", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 4ccea57422..3c7dfef03d 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -120,6 +120,7 @@ cc_library( cc_library( name = "xla_compiler", srcs = [ + "const_analysis.cc", "graph_compiler.cc", "xla_compilation_device.cc", "xla_compiler.cc", @@ -133,6 +134,7 @@ cc_library( "xla_gpu_backend.cc", ]), hdrs = [ + "const_analysis.h", "graph_compiler.h", "xla_compilation_device.h", "xla_compiler.h", @@ -145,7 +147,6 @@ cc_library( visibility = [":friends"], deps = [ ":common", - ":const_analysis", ":dump_graph", ":functionalize_control_flow", ":sharding_util", @@ -160,7 +161,6 @@ cc_library( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -356,28 +356,16 @@ tf_cc_test( ], ) -cc_library( - name = "const_analysis", - srcs = ["const_analysis.cc"], - hdrs = ["const_analysis.h"], - deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - tf_cc_test( name = "const_analysis_test", size = "small", srcs = ["const_analysis_test.cc"], deps = [ - ":const_analysis", + ":xla_compiler", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:ops", "//tensorflow/core:test", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index ab2f1e9a7a..0249500910 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -18,6 +18,7 @@ limitations under the License. #include <unordered_map> #include <unordered_set> +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" @@ -27,96 +28,18 @@ namespace tensorflow { // compile-time constants. Status BackwardsConstAnalysis(const Graph& g, std::vector<bool>* compile_time_const_args) { - // TODO(phawkins): annotate these on the kernel registrations, rather than - // using a hard-coded list. - // (operator, argument) pairs that must be compile-time constants. - const std::unordered_multimap<string, string> compile_time_const_inputs = { - {"All", "reduction_indices"}, - {"Any", "reduction_indices"}, - {"ArgMin", "dimension"}, - {"ArgMax", "dimension"}, - {"AvgPoolGrad", "orig_input_shape"}, - {"AvgPool3DGrad", "orig_input_shape"}, - {"BatchToSpace", "crops"}, - {"BatchToSpaceND", "block_shape"}, - {"BatchToSpaceND", "crops"}, - {"BroadcastArgs", "s0"}, - {"BroadcastArgs", "s1"}, - {"BroadcastGradientArgs", "s0"}, - {"BroadcastGradientArgs", "s1"}, - {"Concat", "concat_dim"}, - {"ConcatV2", "axis"}, - {"ConcatOffset", "concat_dim"}, - {"ConcatOffset", "shape"}, - {"Conv2DBackpropFilter", "filter_sizes"}, - {"Conv2DBackpropInput", "input_sizes"}, - {"Conv3DBackpropFilterV2", "filter_sizes"}, - {"Conv3DBackpropInputV2", "input_sizes"}, - {"Cumprod", "axis"}, - {"Cumsum", "axis"}, - {"DepthwiseConv2dNativeBackpropFilter", "filter_sizes"}, - {"DepthwiseConv2dNativeBackpropInput", "input_sizes"}, - {"DynamicStitch", "indices"}, - {"ExpandDims", "dim"}, - {"Fill", "dims"}, - {"GatherV2", "axis"}, - {"InvertPermutation", "x"}, - {"LinSpace", "start"}, - {"LinSpace", "stop"}, - {"LinSpace", "num"}, - {"Max", "reduction_indices"}, - {"Mean", "reduction_indices"}, - {"Min", "reduction_indices"}, - {"OneHot", "depth"}, - {"Pad", "paddings"}, - {"PadV2", "paddings"}, - {"MirrorPad", "paddings"}, - {"Multinomial", "num_samples"}, - {"Prod", "reduction_indices"}, - {"RandomStandardNormal", "shape"}, - {"RandomUniform", "shape"}, - {"RandomUniformInt", "shape"}, - {"Range", "start"}, - {"Range", "limit"}, - {"Range", "delta"}, - {"Reshape", "shape"}, - {"ResizeBilinear", "size"}, - {"ResourceStridedSliceAssign", "begin"}, - {"ResourceStridedSliceAssign", "end"}, - {"ResourceStridedSliceAssign", "strides"}, - {"Reverse", "dims"}, - {"ReverseV2", "axis"}, - {"Slice", "begin"}, - {"Slice", "size"}, - {"SpaceToBatch", "paddings"}, - {"SpaceToBatchND", "block_shape"}, - {"SpaceToBatchND", "paddings"}, - {"Split", "split_dim"}, - {"SplitV", "split_dim"}, - {"SplitV", "size_splits"}, - {"StackV2", "max_size"}, - {"StridedSlice", "begin"}, - {"StridedSlice", "end"}, - {"StridedSlice", "strides"}, - {"StridedSliceGrad", "shape"}, - {"StridedSliceGrad", "begin"}, - {"StridedSliceGrad", "end"}, - {"StridedSliceGrad", "strides"}, - {"Sum", "reduction_indices"}, - {"TensorArrayV3", "size"}, - {"TensorArraySplitV3", "lengths"}, - {"Tile", "multiples"}, - {"Transpose", "perm"}}; - // Operators that don't look at the data of their inputs, just the shapes. const std::unordered_set<string> metadata_ops = { - "Rank", "Shape", "ShapeN", "Size", + "Rank", + "Shape", + "ShapeN", + "Size", }; Status status; std::unordered_set<Node*> must_be_const; - auto visit = [&status, &metadata_ops, &compile_time_const_inputs, - &must_be_const, compile_time_const_args](Node* node) { + auto visit = [&status, &metadata_ops, &must_be_const, + compile_time_const_args](Node* node) { if (!status.ok()) return; // If this is a metadata-only op, don't propagate the const requirement. @@ -139,16 +62,17 @@ Status BackwardsConstAnalysis(const Graph& g, } // Mark any compile-time constant operator arguments as const. - auto range = compile_time_const_inputs.equal_range(node->type_string()); - if (range.first == range.second) return; + const std::unordered_set<string>* const_inputs = + XlaOpRegistry::CompileTimeConstantInputs(node->type_string()); + if (!const_inputs) return; NameRangeMap input_name_ranges; status = NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr); if (!status.ok()) return; - for (auto it = range.first; it != range.second; ++it) { - auto name_range = input_name_ranges.find(it->second); + for (const string& input : *const_inputs) { + auto name_range = input_name_ranges.find(input); if (name_range == input_name_ranges.end()) continue; for (Edge const* edge : node->in_edges()) { diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 21d3e64872..344a2ab2b6 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -159,7 +159,8 @@ class BatchToSpaceNDOp : public XlaOpKernel { block_shape, crops); } }; -REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp); +REGISTER_XLA_OP(Name("BatchToSpaceND").CompileTimeConstInput("crops"), + BatchToSpaceNDOp); class BatchToSpaceOp : public XlaOpKernel { public: @@ -181,7 +182,10 @@ class BatchToSpaceOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp); +REGISTER_XLA_OP(Name("BatchToSpace") + .CompileTimeConstInput("crops") + .CompileTimeConstInput("block_shape"), + BatchToSpaceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index bb031b8c47..ee2c920453 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -65,7 +65,10 @@ class BCastArgsOp : public XlaOpKernel { private: TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp); }; -REGISTER_XLA_OP(Name("BroadcastArgs"), BCastArgsOp); +REGISTER_XLA_OP(Name("BroadcastArgs") + .CompileTimeConstInput("s0") + .CompileTimeConstInput("s1"), + BCastArgsOp); // Given shapes of two tensors, computes the reduction indices for the // gradient computation. @@ -121,7 +124,10 @@ class BCastGradArgsOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp); }; -REGISTER_XLA_OP(Name("BroadcastGradientArgs"), BCastGradArgsOp); +REGISTER_XLA_OP(Name("BroadcastGradientArgs") + .CompileTimeConstInput("s0") + .CompileTimeConstInput("s1"), + BCastGradArgsOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 592f3ecc3c..545aa364f9 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -92,7 +92,8 @@ class CategoricalOp : public XlaOpKernel { }; // TODO(b/68769717): Rename this sampler to Categorical. -REGISTER_XLA_OP(Name("Multinomial"), CategoricalOp); +REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstInput("num_samples"), + CategoricalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 73a4740e29..1a246e8df9 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -84,8 +84,8 @@ class ConcatBaseOp : public XlaOpKernel { in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar), errors::InvalidArgument( "ConcatOp : Ranks of all input tensors should match: shape[0] = ", - input_shape.DebugString(), " vs. shape[", i, "] = ", - in_shape.DebugString())); + input_shape.DebugString(), " vs. shape[", i, + "] = ", in_shape.DebugString())); if (in_shape.dims() == 0) { // Inputs that come in as scalars must be reshaped to 1-vectors. input_data.push_back(ctx->builder()->Reshape(handle, {1})); @@ -117,8 +117,11 @@ class ConcatV2Op : public ConcatBaseOp { : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {} }; -REGISTER_XLA_OP(Name("Concat"), ConcatOp); -REGISTER_XLA_OP(Name("ConcatV2").TypeConstraint("Tidx", DT_INT32), ConcatV2Op); +REGISTER_XLA_OP(Name("Concat").CompileTimeConstInput("concat_dim"), ConcatOp); +REGISTER_XLA_OP(Name("ConcatV2") + .TypeConstraint("Tidx", DT_INT32) + .CompileTimeConstInput("axis"), + ConcatV2Op); class ConcatOffsetOp : public XlaOpKernel { public: @@ -189,10 +192,10 @@ class ConcatOffsetOp : public XlaOpKernel { } else { const int32 inp0_element = inp0_literal.Get<int>({j}); const int32 inp_element = inp_literal.Get<int>({j}); - OP_REQUIRES( - ctx, (inp0_element == inp_element), - errors::InvalidArgument("input[", i, ",", j, "] mismatch: ", - inp0_element, " vs. ", inp_element)); + OP_REQUIRES(ctx, (inp0_element == inp_element), + errors::InvalidArgument("input[", i, ",", j, + "] mismatch: ", inp0_element, + " vs. ", inp_element)); out_vec(j) = 0; } } @@ -202,7 +205,10 @@ class ConcatOffsetOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("ConcatOffset"), ConcatOffsetOp); +REGISTER_XLA_OP(Name("ConcatOffset") + .CompileTimeConstInput("concat_dim") + .CompileTimeConstInput("shape"), + ConcatOffsetOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index aaddbe811c..da9be68732 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -445,21 +445,26 @@ class Conv2DBackpropInputOp : public ConvBackpropInputOp { explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} }; -REGISTER_XLA_OP(Name("Conv2DBackpropInput"), Conv2DBackpropInputOp); +REGISTER_XLA_OP( + Name("Conv2DBackpropInput").CompileTimeConstInput("input_sizes"), + Conv2DBackpropInputOp); class Conv3DBackpropInputOp : public ConvBackpropInputOp { public: explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} }; -REGISTER_XLA_OP(Name("Conv3DBackpropInputV2"), Conv3DBackpropInputOp); +REGISTER_XLA_OP( + Name("Conv3DBackpropInputV2").CompileTimeConstInput("input_sizes"), + Conv3DBackpropInputOp); class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp { public: explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; -REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput"), +REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput") + .CompileTimeConstInput("input_sizes"), DepthwiseConv2DBackpropInputOp); class ConvBackpropFilterOp : public XlaOpKernel { @@ -644,7 +649,9 @@ class Conv2DBackpropFilterOp : public ConvBackpropFilterOp { : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) { } }; -REGISTER_XLA_OP(Name("Conv2DBackpropFilter"), Conv2DBackpropFilterOp); +REGISTER_XLA_OP( + Name("Conv2DBackpropFilter").CompileTimeConstInput("filter_sizes"), + Conv2DBackpropFilterOp); class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { public: @@ -652,14 +659,17 @@ class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) { } }; -REGISTER_XLA_OP(Name("Conv3DBackpropFilterV2"), Conv3DBackpropFilterOp); +REGISTER_XLA_OP( + Name("Conv3DBackpropFilterV2").CompileTimeConstInput("filter_sizes"), + Conv3DBackpropFilterOp); class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp { public: explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx) : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; -REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter"), +REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter") + .CompileTimeConstInput("filter_sizes"), DepthwiseConv2DBackpropFilterOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 7349dcb987..f2cd21ffb9 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -72,22 +72,24 @@ class DynamicStitchOp : public XlaOpKernel { XLAShapeToTensorShape(indices_input[input_num].shape(), &indices_shape)); const TensorShape& data_shape = data_shapes[input_num]; - OP_REQUIRES(ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape), - errors::InvalidArgument( - "data[", input_num, "].shape = ", - data_shape.DebugString(), " does not start with indices[", - input_num, "].shape = ", indices_shape.DebugString())); - OP_REQUIRES(ctx, - input_num == 0 || SameExtraShape(data0_shape, indices0_shape, - data_shape, indices_shape), - errors::InvalidArgument( - "Need data[0].shape[", indices0_shape.dims(), - ":] = data[", input_num, "].shape[", indices_shape.dims(), - ":], got data[0].shape = ", data0_shape.DebugString(), - ", data[", input_num, "].shape = ", - data_shape.DebugString(), ", indices[0].shape = ", - indices0_shape.DebugString(), ", indices[", input_num, - "].shape = ", indices_shape.DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape), + errors::InvalidArgument("data[", input_num, + "].shape = ", data_shape.DebugString(), + " does not start with indices[", input_num, + "].shape = ", indices_shape.DebugString())); + OP_REQUIRES( + ctx, + input_num == 0 || SameExtraShape(data0_shape, indices0_shape, + data_shape, indices_shape), + errors::InvalidArgument( + "Need data[0].shape[", indices0_shape.dims(), ":] = data[", + input_num, "].shape[", indices_shape.dims(), + ":], got data[0].shape = ", data0_shape.DebugString(), ", data[", + input_num, "].shape = ", data_shape.DebugString(), + ", indices[0].shape = ", indices0_shape.DebugString(), + ", indices[", input_num, + "].shape = ", indices_shape.DebugString())); OP_REQUIRES_OK(ctx, XlaHelpers::ReshapeLiteral(indices_input[input_num], @@ -159,8 +161,8 @@ class DynamicStitchOp : public XlaOpKernel { indices0_shape.dims()); std::vector<int64> slice_limit(1 + data0_shape.dims() - indices0_shape.dims()); - std::vector<int64> stride(1 + data0_shape.dims() - - indices0_shape.dims(), 1); + std::vector<int64> stride(1 + data0_shape.dims() - indices0_shape.dims(), + 1); for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); } @@ -198,8 +200,10 @@ class DynamicStitchOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("DynamicStitch"), DynamicStitchOp); -REGISTER_XLA_OP(Name("ParallelDynamicStitch"), DynamicStitchOp); +REGISTER_XLA_OP(Name("DynamicStitch").CompileTimeConstInput("indices"), + DynamicStitchOp); +REGISTER_XLA_OP(Name("ParallelDynamicStitch").CompileTimeConstInput("indices"), + DynamicStitchOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 9e090fe01c..eaa13b8dfa 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -69,7 +69,7 @@ class FillOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Fill"), FillOp); +REGISTER_XLA_OP(Name("Fill").CompileTimeConstInput("dims"), FillOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index e420f21ca3..70192cb324 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -198,6 +198,7 @@ void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) { } REGISTER_XLA_OP(Name("Gather"), GatherOpDynamicSlice); -REGISTER_XLA_OP(Name("GatherV2"), GatherOpDynamicSlice); +REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"), + GatherOpDynamicSlice); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index d91ebb500b..c0b8f9c179 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -243,7 +243,8 @@ class ResizeBilinearOp : public XlaOpKernel { bool align_corners_; }; -REGISTER_XLA_OP(Name("ResizeBilinear"), ResizeBilinearOp); +REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstInput("size"), + ResizeBilinearOp); class ResizeBilinearGradOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index e0dc1870f2..7bf4b435f5 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -80,7 +80,10 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx) : XlaArgMinMaxOp(ctx, /*is_min=*/false) {} -REGISTER_XLA_OP(Name("ArgMax").Device(DEVICE_GPU_XLA_JIT), XlaArgMaxOp); +REGISTER_XLA_OP(Name("ArgMax") + .Device(DEVICE_GPU_XLA_JIT) + .CompileTimeConstInput("dimension"), + XlaArgMaxOp); namespace { @@ -90,7 +93,7 @@ class XlaArgMinOp : public XlaArgMinMaxOp { }; XlaArgMinOp::XlaArgMinOp(OpKernelConstruction* ctx) : XlaArgMinMaxOp(ctx, /*is_min=*/true) {} -REGISTER_XLA_OP(Name("ArgMin"), XlaArgMinOp); +REGISTER_XLA_OP(Name("ArgMin").CompileTimeConstInput("dimension"), XlaArgMinOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 20946e247a..b1f3c3c298 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -56,10 +56,10 @@ class ArgMaxCustomCallOp : public XlaOpKernel { errors::InvalidArgument("dim must be < input rank (", input_shape.dims(), "), but got: ", dim)); const int64 dim_size = input_shape.dim_size(dim); - OP_REQUIRES( - ctx, dim_size > 0, - errors::InvalidArgument("Reduction axis ", dim, " is empty in shape: ", - input_shape.DebugString())); + OP_REQUIRES(ctx, dim_size > 0, + errors::InvalidArgument( + "Reduction axis ", dim, + " is empty in shape: ", input_shape.DebugString())); // The output shape is the input shape contracted along dim. TensorShape output_shape; @@ -113,9 +113,11 @@ class ArgMaxCustomCallOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp); }; -REGISTER_XLA_OP( - Name("ArgMax").TypeConstraint("T", DT_FLOAT).Device(DEVICE_CPU_XLA_JIT), - ArgMaxCustomCallOp); +REGISTER_XLA_OP(Name("ArgMax") + .TypeConstraint("T", DT_FLOAT) + .Device(DEVICE_CPU_XLA_JIT) + .CompileTimeConstInput("dimension"), + ArgMaxCustomCallOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index bea1d1600b..05a36a031a 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -92,7 +92,8 @@ class MirrorPadOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(MirrorPadOp); }; -REGISTER_XLA_OP(Name("MirrorPad"), MirrorPadOp); +REGISTER_XLA_OP(Name("MirrorPad").CompileTimeConstInput("paddings"), + MirrorPadOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index 2a9cfcb2eb..9f7c991380 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -76,7 +76,7 @@ class OneHotOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp); }; -REGISTER_XLA_OP(Name("OneHot"), OneHotOp); +REGISTER_XLA_OP(Name("OneHot").CompileTimeConstInput("depth"), OneHotOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index d841bd37b3..791351637a 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -83,8 +83,8 @@ class PadOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Pad"), PadOp); -REGISTER_XLA_OP(Name("PadV2"), PadOp); +REGISTER_XLA_OP(Name("Pad").CompileTimeConstInput("paddings"), PadOp); +REGISTER_XLA_OP(Name("PadV2").CompileTimeConstInput("paddings"), PadOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 2b6053d19d..0b5a38967a 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -455,14 +455,16 @@ class AvgPool2DGradOp : public AvgPoolGradOp { errors::InvalidArgument("Invalid data format")); } }; -REGISTER_XLA_OP(Name("AvgPoolGrad"), AvgPool2DGradOp); +REGISTER_XLA_OP(Name("AvgPoolGrad").CompileTimeConstInput("orig_input_shape"), + AvgPool2DGradOp); class AvgPool3DGradOp : public AvgPoolGradOp { public: explicit AvgPool3DGradOp(OpKernelConstruction* ctx) : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {} }; -REGISTER_XLA_OP(Name("AvgPool3DGrad"), AvgPool3DGradOp); +REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"), + AvgPool3DGradOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 2421825ead..c0994c434b 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -52,7 +52,8 @@ class RandomUniformOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp); }; -REGISTER_XLA_OP(Name("RandomUniform"), RandomUniformOp); +REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"), + RandomUniformOp); class RandomUniformIntOp : public XlaOpKernel { public: @@ -83,7 +84,8 @@ class RandomUniformIntOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp); }; -REGISTER_XLA_OP(Name("RandomUniformInt"), RandomUniformIntOp); +REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstInput("shape"), + RandomUniformIntOp); class RandomStandardNormalOp : public XlaOpKernel { public: @@ -111,7 +113,8 @@ class RandomStandardNormalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp); }; -REGISTER_XLA_OP(Name("RandomStandardNormal"), RandomStandardNormalOp); +REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstInput("shape"), + RandomStandardNormalOp); class TruncatedNormalOp : public XlaOpKernel { public: @@ -183,7 +186,8 @@ class TruncatedNormalOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("TruncatedNormal"), TruncatedNormalOp); +REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"), + TruncatedNormalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 647b627408..03b13b2924 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -35,7 +35,7 @@ class SumOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Sum"), SumOp); +REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp); class ProdOp : public XlaReductionOp { public: @@ -53,7 +53,8 @@ class ProdOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Prod"), ProdOp); +REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"), + ProdOp); class MinOp : public XlaReductionOp { public: @@ -73,7 +74,7 @@ class MinOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Min"), MinOp); +REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp); class MaxOp : public XlaReductionOp { public: @@ -93,7 +94,7 @@ class MaxOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Max"), MaxOp); +REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp); class MeanOp : public XlaReductionOp { public: @@ -115,7 +116,8 @@ class MeanOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Mean"), MeanOp); +REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"), + MeanOp); class AllOp : public XlaReductionOp { public: @@ -133,7 +135,7 @@ class AllOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("All"), AllOp); +REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp); class AnyOp : public XlaReductionOp { public: @@ -151,7 +153,7 @@ class AnyOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Any"), AnyOp); +REGISTER_XLA_OP(Name("Any").CompileTimeConstInput("reduction_indices"), AnyOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 5952e75272..af4d64b159 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -95,7 +95,7 @@ class ReshapeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Reshape"), ReshapeOp); +REGISTER_XLA_OP(Name("Reshape").CompileTimeConstInput("shape"), ReshapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index bdfd066f01..17a345fc94 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -65,7 +65,7 @@ class ReverseOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Reverse"), ReverseOp); +REGISTER_XLA_OP(Name("Reverse").CompileTimeConstInput("dims"), ReverseOp); class ReverseV2Op : public XlaOpKernel { public: @@ -103,7 +103,7 @@ class ReverseV2Op : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("ReverseV2"), ReverseV2Op); +REGISTER_XLA_OP(Name("ReverseV2").CompileTimeConstInput("axis"), ReverseV2Op); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 650f8c7dc8..ee4a94164c 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -129,13 +129,19 @@ class CumsumOp : public ScanOp { public: explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {} }; -REGISTER_XLA_OP(Name("Cumsum").TypeConstraint("T", kScanOpTypes), CumsumOp); +REGISTER_XLA_OP(Name("Cumsum") + .TypeConstraint("T", kScanOpTypes) + .CompileTimeConstInput("axis"), + CumsumOp); class CumprodOp : public ScanOp { public: explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {} }; -REGISTER_XLA_OP(Name("Cumprod").TypeConstraint("T", kScanOpTypes), CumprodOp); +REGISTER_XLA_OP(Name("Cumprod") + .TypeConstraint("T", kScanOpTypes) + .CompileTimeConstInput("axis"), + CumprodOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index c2b0e1bb4c..2c31f8d908 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -138,7 +138,11 @@ class RangeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Range"), RangeOp); +REGISTER_XLA_OP(Name("Range") + .CompileTimeConstInput("start") + .CompileTimeConstInput("limit") + .CompileTimeConstInput("delta"), + RangeOp); class LinSpaceOp : public XlaOpKernel { public: @@ -207,7 +211,11 @@ class LinSpaceOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("LinSpace"), LinSpaceOp); +REGISTER_XLA_OP(Name("LinSpace") + .CompileTimeConstInput("start") + .CompileTimeConstInput("stop") + .CompileTimeConstInput("num"), + LinSpaceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index e205fadd2b..8fb7a74310 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -150,7 +150,7 @@ class ExpandDimsOp : public XlaOpKernel { ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); } }; -REGISTER_XLA_OP(Name("ExpandDims"), ExpandDimsOp); +REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp); class SqueezeOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index fbe8c78d8f..be1e97bf26 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -112,7 +112,9 @@ class SliceOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Slice"), SliceOp); +REGISTER_XLA_OP( + Name("Slice").CompileTimeConstInput("begin").CompileTimeConstInput("size"), + SliceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 83a87f19a7..01b46e160d 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -162,7 +162,10 @@ class SpaceToBatchNDOp : public XlaOpKernel { block_shape, paddings); } }; -REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp); +REGISTER_XLA_OP(Name("SpaceToBatchND") + .CompileTimeConstInput("paddings") + .CompileTimeConstInput("block_shape"), + SpaceToBatchNDOp); class SpaceToBatchOp : public XlaOpKernel { public: @@ -184,7 +187,8 @@ class SpaceToBatchOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp); +REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstInput("paddings"), + SpaceToBatchOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 795eb1794f..79c435c90a 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -103,7 +103,7 @@ class SplitOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Split"), SplitOp); +REGISTER_XLA_OP(Name("Split").CompileTimeConstInput("split_dim"), SplitOp); class SplitVOp : public XlaOpKernel { public: @@ -142,8 +142,9 @@ class SplitVOp : public XlaOpKernel { int neg_one_dim = -1; std::vector<int64> split_sizes_vec(num_split, -1); const TensorShape split_size_shape = ctx->InputShape(1); - OP_REQUIRES(ctx, split_size_shape.dims() == 1 && - split_size_shape.num_elements() == num_split, + OP_REQUIRES(ctx, + split_size_shape.dims() == 1 && + split_size_shape.num_elements() == num_split, errors::InvalidArgument( "shape of tensor describing " " the output must have dimension 1 and the same " @@ -171,10 +172,11 @@ class SplitVOp : public XlaOpKernel { } OP_REQUIRES( - ctx, (neg_one_dim == -1 && - total_split_size == input_shape.dim_size(split_dim)) || - (neg_one_dim >= 0 && - total_split_size <= input_shape.dim_size(split_dim)), + ctx, + (neg_one_dim == -1 && + total_split_size == input_shape.dim_size(split_dim)) || + (neg_one_dim >= 0 && + total_split_size <= input_shape.dim_size(split_dim)), errors::InvalidArgument("Determined shape must either match " "input shape along split_dim exactly if " "fully specified, or be less than the size of " @@ -206,7 +208,10 @@ class SplitVOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("SplitV"), SplitVOp); +REGISTER_XLA_OP(Name("SplitV") + .CompileTimeConstInput("split_dim") + .CompileTimeConstInput("size_splits"), + SplitVOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 8013ece861..c912876e65 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -129,7 +129,7 @@ class StackOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackOp); }; -REGISTER_XLA_OP(Name("StackV2"), StackOp); +REGISTER_XLA_OP(Name("StackV2").CompileTimeConstInput("max_size"), StackOp); class StackPushOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 6af4bd0496..f0525a5fb8 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -106,7 +106,11 @@ class StridedSliceOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP(Name("StridedSlice"), StridedSliceOp); +REGISTER_XLA_OP(Name("StridedSlice") + .CompileTimeConstInput("begin") + .CompileTimeConstInput("end") + .CompileTimeConstInput("strides"), + StridedSliceOp); class StridedSliceGradOp : public XlaOpKernel { public: @@ -211,7 +215,12 @@ class StridedSliceGradOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp); +REGISTER_XLA_OP(Name("StridedSliceGrad") + .CompileTimeConstInput("shape") + .CompileTimeConstInput("begin") + .CompileTimeConstInput("end") + .CompileTimeConstInput("strides"), + StridedSliceGradOp); class StridedSliceAssignOp : public XlaOpKernel { public: @@ -320,7 +329,11 @@ class StridedSliceAssignOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp); +REGISTER_XLA_OP(Name("ResourceStridedSliceAssign") + .CompileTimeConstInput("begin") + .CompileTimeConstInput("end") + .CompileTimeConstInput("strides"), + StridedSliceAssignOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 8a742ff11c..9224072a3c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -192,7 +192,8 @@ class TensorArrayOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp); }; -REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp); +REGISTER_XLA_OP(Name("TensorArrayV3").CompileTimeConstInput("size"), + TensorArrayOp); class TensorArrayWriteOp : public XlaOpKernel { public: @@ -414,8 +415,8 @@ class TensorArrayScatterOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto index = b->Slice(indices, {i}, {i + 1}, {1}); auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } } @@ -537,7 +538,8 @@ class TensorArraySplitOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp); }; -REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp); +REGISTER_XLA_OP(Name("TensorArraySplitV3").CompileTimeConstInput("lengths"), + TensorArraySplitOp); class TensorArraySizeOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 9ee6bd8925..9aefcd4fc7 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -122,7 +122,7 @@ class TileOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TileOp); }; -REGISTER_XLA_OP(Name("Tile"), TileOp); +REGISTER_XLA_OP(Name("Tile").CompileTimeConstInput("multiples"), TileOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 2fc5d40d10..5c17b7fbf0 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -72,8 +72,9 @@ class TransposeOp : public XlaOpKernel { } } for (int i = 0; i < dims; ++i) { - OP_REQUIRES(ctx, bits[i], errors::InvalidArgument( - i, " is missing from 'perm' argument.")); + OP_REQUIRES( + ctx, bits[i], + errors::InvalidArgument(i, " is missing from 'perm' argument.")); } // 0-D, 1-D, and identity transposes do nothing. @@ -87,7 +88,7 @@ class TransposeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Transpose"), TransposeOp); +REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp); // InvertPermutation frequently forms part of the gradient of Transpose. // @@ -103,8 +104,9 @@ class InvertPermutationOp : public XlaOpKernel { explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, FastBoundsCheck(ctx->InputShape(0).num_elements(), - std::numeric_limits<int32>::max()), + OP_REQUIRES(ctx, + FastBoundsCheck(ctx->InputShape(0).num_elements(), + std::numeric_limits<int32>::max()), errors::InvalidArgument("permutation of nonnegative int32s " "must have <= int32 max elements")); @@ -128,7 +130,9 @@ class InvertPermutationOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32), +REGISTER_XLA_OP(Name("InvertPermutation") + .TypeConstraint("T", DT_INT32) + .CompileTimeConstInput("x"), InvertPermutationOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index faf47434b5..97bb100fb1 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -83,6 +83,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; return false; } } + if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible compile time constant inputs."; + return false; + } return true; } @@ -263,6 +268,17 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels( return kernels; } +/* static */ const std::unordered_set<string>* +XlaOpRegistry::CompileTimeConstantInputs(const string& op) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.ops_.find(op); + if (it == registry.ops_.end()) { + return nullptr; + } + return &it->second->compile_time_constant_inputs; +} + std::vector<string> XlaOpRegistry::BackendNames() { std::vector<string> names; XlaOpRegistry& registry = Instance(); @@ -337,6 +353,12 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( return *this; } +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( + StringPiece input_name) { + registration_->compile_time_constant_inputs.insert(input_name.ToString()); + return *this; +} + std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build( XlaOpRegistry::Factory factory) { registration_->factory = factory; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 8bfd9758f7..ff7453194a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -128,6 +128,11 @@ class XlaOpRegistry { const string& compilation_device_name, bool include_compilation_only_kernels); + // Returns the set of compile-time constant inputs to 'op'. Returns nullptr + // if the op is not registered. + static const std::unordered_set<string>* CompileTimeConstantInputs( + const string& op); + private: friend class XlaBackendRegistrar; friend class XlaOpRegistrar; @@ -181,6 +186,9 @@ class XlaOpRegistry { bool has_device_whitelist = false; std::unordered_set<string> device_whitelist; + // Names of arguments that must be compile-time constants. + std::unordered_set<string> compile_time_constant_inputs; + // Factory used to build OpKernels that perform symbolic execution. Factory factory; }; @@ -242,6 +250,9 @@ class XlaOpRegistrationBuilder { // Allow DT_RESOURCE types for type parameters. XlaOpRegistrationBuilder& AllowResourceTypes(); + // Mark 'input_name' as an argument whose value must be known at compile-time. + XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name); + std::unique_ptr<XlaOpRegistry::OpRegistration> Build( XlaOpRegistry::Factory factory); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 2a335843f5..80d89d851e 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -276,6 +276,23 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) { return false; } +bool HloDataflowAnalysis::UpdateSliceValueSet(HloInstruction* slice) { + CHECK_EQ(slice->opcode(), HloOpcode::kSlice); + if (!slice->IsInPlaceSlice()) { + return false; + } + // If this slice is lowered to an in-place version, then it forwards the + // operand value to the output. + const InstructionValueSet& operand_set = + GetInstructionValueSet(slice->operand(0)); + InstructionValueSet& slice_set = GetInstructionValueSet(slice); + if (operand_set != slice_set) { + slice_set = operand_set; + return true; + } + return false; +} + bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { CHECK_EQ(send->opcode(), HloOpcode::kSend); bool changed = false; @@ -527,6 +544,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( switch (instruction->opcode()) { case HloOpcode::kBitcast: return UpdateBitcastValueSet(instruction); + case HloOpcode::kSlice: + return UpdateSliceValueSet(instruction); case HloOpcode::kCopy: return UpdateCopyValueSet(instruction); case HloOpcode::kGetTupleElement: @@ -688,6 +707,11 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_all_values(); } break; + case HloOpcode::kSlice: + if (!instruction->IsInPlaceSlice()) { + define_all_values(); + } + break; case HloOpcode::kWhile: case HloOpcode::kCall: case HloOpcode::kConditional: diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 469620d012..89d318188f 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -145,6 +145,7 @@ class HloDataflowAnalysis { // Updates the value set for a particular instruction type. Returns whether // the instruction value set changed. bool UpdateBitcastValueSet(HloInstruction* bitcast); + bool UpdateSliceValueSet(HloInstruction* slice); bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 9b778c574c..d455cfc3f1 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -964,6 +964,17 @@ class HloInstruction { } const std::vector<int64>& slice_strides() const { return slice_strides_; } + // Returns the flag that describes whether a slice must be lowered into an + // offset into the original operand. + bool IsInPlaceSlice() const { return is_in_place_slice_; } + + // Sets and returns the flag that describes whether a slice must be lowered + // into an offset into the original operand. + bool SetIsInPlaceSlice(bool value) { + is_in_place_slice_ = value; + return value; + } + // Returns the size of the slice in the given dimension for a dynamic // slice node. // @@ -1297,6 +1308,9 @@ class HloInstruction { std::vector<int64> slice_limits_; std::vector<int64> slice_strides_; + // Describes whether the slice can be lowered to an offset into the operand. + bool is_in_place_slice_ = false; + // The bit sizes for a reduce-precision operation. int32 exponent_bits_ = 0; int32 mantissa_bits_ = 0; diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index b01fcccdb4..0cb9b5d810 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -39,7 +39,6 @@ namespace xla { namespace interpreter { namespace se = ::perftools::gputools; -namespace sep = ::perftools::gputools::interpreter; InterpreterExecutable::InterpreterExecutable( std::unique_ptr<const HloModule> hlo_module) diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 0c84856647..657a8fe09a 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -273,6 +273,16 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { + // A kSlice instruction aliases its operand if the backend lowers it to an + // in-place implementation. + if (slice->IsInPlaceSlice()) { + CreateCopiedPointsToSet(slice, slice->operand(0)); + return Status::OK(); + } + return DefaultAction(slice); +} + Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { // RecvDone aliases its input (Recv) tuple element {0} to its output. PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done); @@ -427,10 +437,15 @@ bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex( Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) { - return FailedPrecondition( - "LogicalBuffer %s is ill-defined: instruction %s does not define a " - "buffer at that index", - buffer.ToString().c_str(), buffer.instruction()->name().c_str()); + // kSlice ops that are lowered to an in-place version are expected to not + // define their output buffer. + if (buffer.instruction()->opcode() != HloOpcode::kSlice || + !buffer.instruction()->IsInPlaceSlice()) { + return FailedPrecondition( + "LogicalBuffer %s is ill-defined: instruction %s does not define a " + "buffer at that index", + buffer.ToString().c_str(), buffer.instruction()->name().c_str()); + } } if (buffer.id() < 0 || diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 8928de107e..5ca6ccb5c9 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -250,6 +250,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleSlice(HloInstruction* slice) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md index 2e329cc513..4b43810a68 100644 --- a/tensorflow/compiler/xla/tools/parser/README.md +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -1,4 +1,4 @@ -# HloModule string syntax +# HLO Text Syntax ```yacc hlo_module diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py index bc749339bd..046c509626 100644 --- a/tensorflow/contrib/crf/__init__.py +++ b/tensorflow/contrib/crf/__init__.py @@ -16,15 +16,15 @@ See the @{$python/contrib.crf} guide. -@@crf_sequence_score -@@crf_log_norm -@@crf_log_likelihood -@@crf_unary_score @@crf_binary_score @@crf_decode -@@CrfForwardRnnCell -@@CrfDecodeForwardRnnCell +@@crf_log_likelihood +@@crf_log_norm +@@crf_sequence_score +@@crf_unary_score @@CrfDecodeBackwardRnnCell +@@CrfDecodeForwardRnnCell +@@CrfForwardRnnCell @@viterbi_decode """ @@ -32,16 +32,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.crf.python.ops.crf import _lengths_to_masks from tensorflow.contrib.crf.python.ops.crf import crf_binary_score from tensorflow.contrib.crf.python.ops.crf import crf_decode from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood from tensorflow.contrib.crf.python.ops.crf import crf_log_norm from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score from tensorflow.contrib.crf.python.ops.crf import crf_unary_score -from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell -from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell from tensorflow.contrib.crf.python.ops.crf import CrfDecodeBackwardRnnCell +from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell +from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell from tensorflow.contrib.crf.python.ops.crf import viterbi_decode from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index b47fb426a1..721dc4d080 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -179,17 +179,6 @@ class CrfTest(test.TestCase): tf_total_log_likelihood = sess.run(total_log_likelihood) self.assertAllClose(tf_total_log_likelihood, 0.0) - def testLengthsToMasks(self): - with self.test_session() as sess: - sequence_lengths = [4, 1, 8, 2] - max_sequence_length = max(sequence_lengths) - mask = crf._lengths_to_masks(sequence_lengths, max_sequence_length) - tf_mask = sess.run(mask) - self.assertEqual(len(tf_mask), len(sequence_lengths)) - for m, l in zip(tf_mask, sequence_lengths): - self.assertAllEqual(m[:l], [1] * l) - self.assertAllEqual(m[l:], [0] * (len(m) - l)) - def testViterbiDecode(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 7f5ae937b2..62708636c6 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -70,25 +70,6 @@ __all__ = [ ] -def _lengths_to_masks(lengths, max_length): - """Creates a binary matrix that can be used to mask away padding. - - Args: - lengths: A vector of integers representing lengths. - max_length: An integer indicating the maximum length. All values in - lengths should be less than max_length. - Returns: - masks: Masks that can be used to get rid of padding. - """ - tiled_ranges = array_ops.tile( - array_ops.expand_dims(math_ops.range(max_length), 0), - [array_ops.shape(lengths)[0], 1]) - lengths = array_ops.expand_dims(lengths, 1) - masks = math_ops.to_float( - math_ops.to_int64(tiled_ranges) < math_ops.to_int64(lengths)) - return masks - - def crf_sequence_score(inputs, tag_indices, sequence_lengths, transition_params): """Computes the unnormalized score for a tag sequence. @@ -234,7 +215,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): array_ops.gather(flattened_inputs, flattened_tag_indices), [batch_size, max_seq_len]) - masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1]) + masks = array_ops.sequence_mask(sequence_lengths, + maxlen=array_ops.shape(tag_indices)[1], + dtype=dtypes.float32) unary_scores = math_ops.reduce_sum(unary_scores * masks, 1) return unary_scores @@ -268,7 +251,9 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params): binary_scores = array_ops.gather(flattened_transition_params, flattened_transition_indices) - masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1]) + masks = array_ops.sequence_mask(sequence_lengths, + maxlen=array_ops.shape(tag_indices)[1], + dtype=dtypes.float32) truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1]) binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1) return binary_scores diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 54f6974dba..015f69c567 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -305,10 +305,10 @@ class BatchDatasetTest(test.TestCase): iterator = ( dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: array_ops.fill([x], x)).apply( - batching.dense_to_sparse_batch(4, - [12])).make_initializable_iterator()) + batching.dense_to_sparse_batch(4, [12])) + .make_initializable_iterator()) init_op = iterator.initializer - get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + get_next = iterator.get_next() with self.test_session() as sess: sess.run(init_op) @@ -334,9 +334,9 @@ class BatchDatasetTest(test.TestCase): dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: array_ops.fill([x, x], x)).apply( batching.dense_to_sparse_batch( - 4, [5, -1])).make_initializable_iterator()) + 4, [5, None])).make_initializable_iterator()) init_op = iterator.initializer - get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + get_next = iterator.get_next() with self.test_session() as sess: sess.run(init_op) @@ -363,25 +363,18 @@ class BatchDatasetTest(test.TestCase): def testDenseToSparseBatchDatasetWithInvalidShape(self): input_tensor = array_ops.constant([[1]]) - iterator = ( - dataset_ops.Dataset.from_tensors(input_tensor).apply( - batching.dense_to_sparse_batch(4, [-2])) - .make_initializable_iterator()) - init_op = iterator.initializer - - with self.test_session() as sess: - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dimension -2 must be >= -1"): - sess.run(init_op) + with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"): + dataset_ops.Dataset.from_tensors(input_tensor).apply( + batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator() def testDenseToSparseBatchDatasetShapeErrors(self): input_tensor = array_ops.placeholder(dtypes.int32) iterator = ( dataset_ops.Dataset.from_tensors(input_tensor).apply( - batching.dense_to_sparse_batch(4, - [12])).make_initializable_iterator()) + batching.dense_to_sparse_batch(4, [12])) + .make_initializable_iterator()) init_op = iterator.initializer - get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + get_next = iterator.get_next() with self.test_session() as sess: # Initialize with an input tensor of incompatible rank. @@ -740,7 +733,9 @@ class BatchDatasetSerializationTest( lambda x: array_ops.fill([x], x)).apply( batching.dense_to_sparse_batch(4, [12])) - def testDenseToSparseBatchDatasetCore(self): + # TODO(b/70988345): Re-enable when sparse tensors are properly supported by + # the DatasetSerializationTestBase. + def _testDenseToSparseBatchDatasetCore(self): components = np.random.randint(5, size=(40,)).astype(np.int32) diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index bf25cc60a1..7cde6e05b2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -40,6 +40,8 @@ class DatasetSerializationTestBase(test.TestCase): def tearDown(self): self._delete_ckpt() + # TODO(b/70988345): Support native `tf.SparseTensor` objects and get rid of + # `sparse_tensors` argument. def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False): """Runs the core tests. diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index e8b2d44a8b..e0860cbe8a 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -22,6 +22,7 @@ from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -231,32 +232,29 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): input_dataset.output_types) self._input_dataset = input_dataset self._batch_size = batch_size - # pylint: disable=protected-access - self._row_shape = dataset_ops._partial_shape_to_tensor(row_shape) - # pylint: enable=protected-access + self._row_shape = row_shape def _as_variant_tensor(self): return gen_dataset_ops.dense_to_sparse_batch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._batch_size, - self._row_shape, - output_shapes=self.output_shapes, - output_types=self.output_types) + row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape), # pylint: disable=protected-access + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) @property def output_classes(self): - return (ops.Tensor, ops.Tensor, ops.Tensor) + return sparse_tensor.SparseTensor @property def output_shapes(self): - num_elements = tensor_shape.Dimension(None) - return (tensor_shape.matrix(num_elements, self._row_shape.shape[0] + 1), - tensor_shape.vector(num_elements), - tensor_shape.vector(self._row_shape.shape[0] + 1)) + return tensor_shape.vector(None).concatenate(self._row_shape) @property def output_types(self): - return (dtypes.int64, self._input_dataset.output_types, dtypes.int64) + return self._input_dataset.output_types class _RestructuredDataset(dataset_ops.Dataset): diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 00a18569fc..4bea99fbb7 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -18,15 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.contrib.distributions.python.ops import mvn_tril from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops __all__ = [ @@ -170,15 +167,11 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): covariance_matrix = ops.convert_to_tensor( covariance_matrix, name="covariance_matrix") if validate_args: - tol = np.finfo(covariance_matrix.dtype.as_numpy_dtype).eps * 10 - diff = math_ops.abs( - covariance_matrix - - array_ops.matrix_transpose(covariance_matrix)) - assert_symmetric = check_ops.assert_less( - diff, tol + tol * math_ops.abs(covariance_matrix), - message="Matrix was not symmetric.") - covariance_matrix = control_flow_ops.with_dependencies( - [assert_symmetric], covariance_matrix) + covariance_matrix = control_flow_ops.with_dependencies([ + check_ops.assert_near( + covariance_matrix, + array_ops.matrix_transpose(covariance_matrix), + message="Matrix was not symmetric")], covariance_matrix) # No need to validate that covariance_matrix is non-singular. # LinearOperatorLowerTriangular has an assert_non_singular method that # is called by the Bijector. diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py index 136085eba2..205709fe2e 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py @@ -39,40 +39,22 @@ def random_dataset(): return tf.data.Dataset.from_tensors((images, labels)) -def train_one_epoch(defun=False): - model = mnist.MNISTModel(data_format()) - if defun: - model.call = tfe.defun(model.call) - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - dataset = random_dataset() - with tf.device(device()): - tf.train.get_or_create_global_step() - mnist.train_one_epoch(model, optimizer, dataset) - - -def evaluate(defun=False): - model = mnist.MNISTModel(data_format()) - dataset = random_dataset() - if defun: - model.call = tfe.defun(model.call) - with tf.device(device()): - tf.train.get_or_create_global_step() - mnist.test(model, dataset) - - class MNISTTest(tf.test.TestCase): def testTrainOneEpoch(self): - train_one_epoch(defun=False) + model = mnist.MNISTModel(data_format()) + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) + dataset = random_dataset() + with tf.device(device()): + tf.train.get_or_create_global_step() + mnist.train_one_epoch(model, optimizer, dataset) def testTest(self): - evaluate(defun=False) - - def testTrainOneEpochWithDefunCall(self): - train_one_epoch(defun=True) - - def testTestWithDefunCall(self): - evaluate(defun=True) + model = mnist.MNISTModel(data_format()) + dataset = random_dataset() + with tf.device(device()): + tf.train.get_or_create_global_step() + mnist.test(model, dataset) if __name__ == "__main__": diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index e2ae665a74..d8d8644dde 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -64,22 +64,14 @@ def train_one_step(model, images, labels, optimizer): class ResNet50Test(tf.test.TestCase): - def _apply(self, defun=False): + def test_apply(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) - if defun: - model.call = tfe.defun(model.call) with tf.device(device): images, _ = random_batch(2) output = model(images) self.assertEqual((2, 1000), output.shape) - def test_apply(self): - self._apply(defun=False) - - def test_apply_with_defun(self): - self._apply(defun=True) - def test_apply_no_top(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) @@ -183,11 +175,9 @@ class ResNet50Benchmarks(tf.test.Benchmark): # a sync. This is a roundabout way, yes. tf.constant(1.).cpu() - def _benchmark_eager_apply(self, label, defun=False): + def benchmark_eager_apply(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) - if defun: - model.call = tfe.defun(model.call) batch_size = 64 num_burn = 5 num_iters = 30 @@ -199,23 +189,16 @@ class ResNet50Benchmarks(tf.test.Benchmark): start = time.time() for _ in xrange(num_iters): model(images).cpu() - self._report(label, start, num_iters, device, batch_size, data_format) - - def benchmark_eager_apply(self): - self._benchmark_eager_apply('eager_apply', defun=False) - - def benchmark_eager_apply_with_defun(self): - self._benchmark_eager_apply('eager_apply_with_defun', defun=True) + self._report('eager_apply', start, num_iters, device, batch_size, + data_format) - def _benchmark_eager_train(self, label, make_iterator, defun=False): + def _benchmark_eager_train(self, label, make_iterator): device, data_format = device_and_data_format() for batch_size in self._train_batch_sizes(): (images, labels) = random_batch(batch_size) num_burn = 3 num_iters = 10 model = resnet50.ResNet50(data_format) - if defun: - model.call = tfe.defun(model.call) optimizer = tf.train.GradientDescentOptimizer(0.1) with tf.device(device): @@ -234,11 +217,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): self._report(label, start, num_iters, device, batch_size, data_format) def benchmark_eager_train(self): - self._benchmark_eager_train('eager_train', MockIterator, defun=False) - - def benchmark_eager_train_with_defun(self): - self._benchmark_eager_train( - 'eager_train_with_defun', MockIterator, defun=True) + self._benchmark_eager_train('eager_train', MockIterator) def benchmark_eager_train_datasets(self): @@ -247,18 +226,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): ds = tf.data.Dataset.from_tensors(tensors).repeat() return tfe.Iterator(ds) - self._benchmark_eager_train( - 'eager_train_dataset', make_iterator, defun=False) - - def benchmark_eager_train_datasets_with_defun(self): - - def make_iterator(tensors): - with tf.device('/device:CPU:0'): - ds = tf.data.Dataset.from_tensors(tensors).repeat() - return tfe.Iterator(ds) - - self._benchmark_eager_train( - 'eager_train_dataset_with_defun', make_iterator, defun=True) + self._benchmark_eager_train('eager_train_dataset', make_iterator) if __name__ == '__main__': diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 2f8016ede3..bf029ca5f9 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -51,6 +51,20 @@ class Metric(object): ```python m = SomeMetric(...) + inputs = ... # Some tensors to compute the metric on. + m_update = m(inputs) + # Variables defined in first call, so get the initialization op afterwards. + m_init = m.init_variables() # or tf.global_variables_initializer() + m_result = m.result() + with tf.Session() as sess: + sess.run(m_init) + for input in ...: + sess.run(m_update) + print(sess.run(m_result)) + ``` + Example use with graph execution with placeholders and feed_dict: + ```python + m = SomeMetric(...) m_placeholder = tf.placeholder(...) m_update = m(m_placeholder) # Variables defined in first call, so get the initialization op afterwards. @@ -107,6 +121,7 @@ class Metric(object): """Returns op to execute to update this metric for these inputs. Returns None if eager execution is enabled. + Returns a graph-mode function if graph execution is enabled. Args: *args: @@ -183,6 +198,13 @@ class Metric(object): """Computes and returns a final value for the metric.""" raise NotImplementedError("Metrics must define a result() member function") + def value(self): + """In graph mode returns the result Tensor while in eager the callable.""" + if context.in_graph_mode(): + return self.result() + else: + return self.result + # We can support two different strategies of for doing data-parallel # distributed metric computations: # * Put metric variables on the first device and rely on small diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 1055f4563c..9cf34fd9b2 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -27,6 +27,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.training import training_util @@ -137,7 +138,7 @@ class MetricsTest(test.TestCase): self.assertEqual(m1.name, "has space") self.assertEqual(m1.numer.name, "has_space/numer:0") - def testGraph(self): + def testGraphWithPlaceholder(self): with context.graph_mode(), self.test_session() as sess: m = metrics.Mean() p = array_ops.placeholder(dtypes.float32) @@ -153,6 +154,22 @@ class MetricsTest(test.TestCase): sess.run(accumulate, feed_dict={p: 7}) self.assertAllEqual(m.result().eval(), 7) + @test_util.run_in_graph_and_eager_modes() + def testGraphAndEagerTensor(self): + m = metrics.Mean() + inputs = ops.convert_to_tensor([1.0, 2.0]) + accumulate = m(inputs) + result = m.result() + self.evaluate(m.init_variables()) + self.evaluate(accumulate) + self.assertEqual(self.evaluate(result), 1.5) + # Second init resets all the variables. + self.evaluate(m.init_variables()) + inputs = ops.convert_to_tensor([2.0, 3.0]) + self.evaluate(m(inputs)) + value = m.value() + self.assertEqual(self.evaluate(value), 2.5) + def testTwoMeansGraph(self): # Verify two metrics with the same class and name don't # accidentally share state. diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 8e6b947e5c..3eb4f5f8b3 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -105,13 +105,15 @@ class NetworkTest(test.TestCase): result = net(constant_op.constant([[2.0]])) self.assertEqual(34.0, self.evaluate(result)) + # TODO(akshayka): This test should be changed once an API for compiling + # `call` into a defun is implemented. def testReplacingNetworkCallWithDefun(self): net = MyNetwork(name="abcd") - net.call = function.defun(net.call) x = constant_op.constant([[2.0]]) net(x) # Force variables to be created. self.evaluate(net.trainable_variables[0].assign([[17.0]])) + net.call = function.defun(net.call) result = net(x) # Build and execute the TensorFlow function self.assertEqual(34.0, self.evaluate(result)) diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 6146b92311..ce38ea4754 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1710,6 +1710,7 @@ bool InlineAllFunctions(GraphDef* graphdef) { tensorflow::Graph graph(fld); tensorflow::ImportGraphDefOptions gc_opts; + gc_opts.validate_shape = false; const auto& tf_convert_status = tensorflow::ImportGraphDef( gc_opts, graphdef_copy, &graph, nullptr, nullptr); if (!tf_convert_status.ok()) { diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index a34c7f91f2..43cb38f169 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -201,6 +201,17 @@ tf_py_test( ], ) +tf_py_test( + name = "tpu_config_test", + size = "small", + srcs = ["python/tpu/tpu_config_test.py"], + additional_deps = [ + ":tpu_estimator", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 77ce38991b..1b6ce2dfdf 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -20,9 +20,19 @@ from __future__ import division from __future__ import print_function import collections +import json +import os from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.platform import tf_logging as logging + +# pylint: disable=protected-access +_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV +_SERVICE_KEY = run_config_lib._SERVICE_KEY +_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name' + +# pylint: enable=protected-access class TPUConfig( @@ -74,6 +84,9 @@ class TPUConfig( if initial_infeed_sleep_secs: util_lib.check_positive_integer(initial_infeed_sleep_secs, 'TPUConfig initial_infeed_sleep_secs') + + tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config() + return super(TPUConfig, cls).__new__( cls, iterations_per_loop=iterations_per_loop, @@ -126,3 +139,14 @@ class RunConfig(run_config_lib.RunConfig): new_instance = super(RunConfig, self).replace(**kwargs) new_instance._tpu_config = tpu_config # pylint: disable=protected-access return new_instance + + +def _get_tpu_job_name_from_tf_config(): + """Extracts the TPU job name from TF_CONFIG env variable.""" + # TODO(xiejw): Extends this to support both TF_CONFIG env variable and cluster + # spec propagation. + tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}')) + tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME) + if tpu_job_name: + logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name) + return tpu_job_name diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py new file mode 100644 index 0000000000..618f263618 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py @@ -0,0 +1,60 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TPU RunConfig tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from tensorflow.contrib.tpu.python.tpu import tpu_config as tpu_config_lib +from tensorflow.python.platform import test + + +def _set_tf_config_env_variable(tf_config): + return test.mock.patch.dict('os.environ', { + 'TF_CONFIG': json.dumps(tf_config) + }) + + +class TPURunConfigTest(test.TestCase): + + def test_fail_with_invalid_num_shards(self): + with self.assertRaisesRegexp(ValueError, 'must be positive'): + tpu_config_lib.RunConfig( + tpu_config=tpu_config_lib.TPUConfig(num_shards=0)) + + def test_fail_with_iterations_per_loop(self): + with self.assertRaisesRegexp(ValueError, 'must be positive'): + tpu_config_lib.RunConfig( + tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0)) + + +class TPUJobNameTest(test.TestCase): + + def test_default_name(self): + config = tpu_config_lib.RunConfig() + self.assertIsNone(config.tpu_config.tpu_job_name) + + def test_with_tf_config(self): + tf_config = {'service': {'tpu_worker_job_name': '_my_new_name',}} + with _set_tf_config_env_variable(tf_config): + config = tpu_config_lib.RunConfig() + self.assertEqual('_my_new_name', config.tpu_config.tpu_job_name) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/core/api_def/base_api/api_def_DenseToSparseBatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_DenseToSparseBatchDataset.pbtxt index f2f5594c7c..e275cfdd3d 100644 --- a/tensorflow/core/api_def/base_api/api_def_DenseToSparseBatchDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DenseToSparseBatchDataset.pbtxt @@ -21,5 +21,5 @@ SparseTensor. The shape may be partially specified, using `-1` to indicate that a particular dimension should use the maximum size of all batch elements. END } - summary: "Creates a dataset that yields a SparseTensor for each element of the input." + summary: "Creates a dataset that batches input elements into a SparseTensor." } diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 15edce6a68..99b33e2ef0 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -1265,7 +1265,7 @@ TEST(DirectSessionTest, LocalDeviceManager) { // A simple benchmark for the overhead of `DirectSession::Run()` calls // with varying numbers of feeds/fetches. -void FeedFetchBenchmarkHelper(int num_feeds, int iters) { +void FeedFetchBenchmarkHelper(int iters, int num_feeds) { testing::StopTiming(); Tensor value(DT_FLOAT, TensorShape()); diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index b5f4ebb005..3fd932da5b 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -202,11 +202,12 @@ class ExecutorBarrier { // below. if (--pending_ == 0) { CHECK(done_cb_ != nullptr); - done = done_cb_; - done_cb_ = nullptr; + std::swap(done, done_cb_); } - status = status_; + if (!status_.ok()) { + status = status_; + } } if (error) { diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc index 7c09451a8a..08961fc105 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc @@ -25,8 +25,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/platform/stream_executor.h" -namespace gpu = ::perftools::gputools; - namespace tensorflow { GPUcudaMallocAllocator::GPUcudaMallocAllocator(VisitableAllocator* allocator, diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc index 45e97fdbf0..eea857f8ce 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h" +#include <cstddef> #include <vector> #include "tensorflow/core/common_runtime/gpu/gpu_id.h" @@ -22,16 +23,13 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/platform/stream_executor.h" -namespace gpu = ::perftools::gputools; - -namespace tensorflow { - #define MASK_WORDS 2 #define MASK_BYTES (MASK_WORDS * sizeof(int64)) +namespace tensorflow { namespace { -static int64* NewMask(int64 word) { +int64* NewMask(int64 word) { int64* m = new int64[MASK_WORDS]; for (int i = 0; i < MASK_WORDS; ++i) { m[i] = word; @@ -39,8 +37,8 @@ static int64* NewMask(int64 word) { return m; } -static int64* before_mask = NewMask(0xabababababababab); -static int64* after_mask = NewMask(0xcdcdcdcdcdcdcdcd); +int64* before_mask = NewMask(0xabababababababab); +int64* after_mask = NewMask(0xcdcdcdcdcdcdcdcd); bool CheckMask(perftools::gputools::StreamExecutor* exec, void* ptr, int64* mask) { diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc index ca4b93815c..d34f0cb3c2 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc @@ -30,9 +30,8 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -namespace gpu = ::perftools::gputools; - namespace tensorflow { +namespace { TEST(GPUDebugAllocatorTest, OverwriteDetection_None) { const CudaGpuId cuda_gpu_id(0); @@ -223,6 +222,7 @@ TEST(GPUDebugAllocatorTest, AllocatedVsRequested) { a.DeallocateRaw(t1); } +} // namespace } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc index 8a3220ce2b..995fd1253f 100644 --- a/tensorflow/core/common_runtime/gpu/process_state.cc +++ b/tensorflow/core/common_runtime/gpu/process_state.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/process_state.h" +#include <cstring> #include <vector> #include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h" @@ -47,27 +48,19 @@ const bool FLAGS_brain_mem_reg_cuda_dma = true; // performance issues. const bool FLAGS_brain_gpu_record_mem_types = false; -namespace gpu = ::perftools::gputools; - namespace tensorflow { - namespace { + bool useCudaMallocAllocator() { const char* debug_allocator_str = std::getenv("TF_GPU_ALLOCATOR"); - if (debug_allocator_str != nullptr && - strcmp(debug_allocator_str, "cuda_malloc") == 0) - return true; - else - return false; + return debug_allocator_str != nullptr && + std::strcmp(debug_allocator_str, "cuda_malloc") == 0; } bool useCudaMemoryGuardAllocator() { const char* debug_allocator_str = std::getenv("TF_GPU_ALLOCATOR"); - if (debug_allocator_str != nullptr && - strcmp(debug_allocator_str, "memory_guard") == 0) - return true; - else - return false; + return debug_allocator_str != nullptr && + std::strcmp(debug_allocator_str, "memory_guard") == 0; } } // namespace diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc index 54f082e823..a913f20751 100644 --- a/tensorflow/core/common_runtime/placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -369,7 +369,8 @@ class ColocationGraph { "Could not satisfy explicit device specification '", node->requested_device(), "' because no supported kernel for ", specified_device_name.type, " devices is available.", - debug_info); + debug_info, "\nRegistered kernels:\n", + KernelsRegisteredForOp(node->type_string())); } else { return errors::InvalidArgument( "Could not satisfy explicit device specification '", diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h index a87c2c9a57..4e9352ee32 100644 --- a/tensorflow/core/example/feature_util.h +++ b/tensorflow/core/example/feature_util.h @@ -33,7 +33,7 @@ limitations under the License. // GetFeatureValues<int64>("tag", &example)->Add(id); // // Modification of bytes features is slightly different: -// auto tag = GetFeatureValues<string>("tag", example); +// auto tag = GetFeatureValues<string>("tag", &example); // *tag->Add() = "lorem ipsum"; // // To copy multiple values into a feature: diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index e19f4aebba..2a52c7516e 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -84,7 +84,8 @@ class GraphConstructor { return_tensors(in.return_tensors), return_nodes(in.return_nodes), importing(true), - validate_colocation_constraints(in.validate_colocation_constraints) {} + validate_colocation_constraints(in.validate_colocation_constraints), + validate_shape(in.validate_shape) {} bool allow_internal_ops; bool expect_device_spec; @@ -108,6 +109,7 @@ class GraphConstructor { // remove this. bool importing; bool validate_colocation_constraints; + bool validate_shape = true; }; typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice; @@ -561,7 +563,7 @@ Status GraphConstructor::MakeNode(const NodeDef& node_def, Node** node) { } Status GraphConstructor::ValidateShape(Node* node) { - if (!opts_.importing) return Status::OK(); + if (!opts_.importing || !opts_.validate_shape) return Status::OK(); TF_RETURN_IF_ERROR(refiner_->AddNode(node)); // For nodes with the _output_shapes attribute, override the shape. std::vector<TensorShapeProto> shape_attrs; diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index 07814b2ef7..b03d655fe6 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -57,7 +57,8 @@ struct ImportGraphDefOptions { ImportGraphDefOptions() : uniquify_names(false), uniquify_prefix(false), - skip_mapped_nodes(false) {} + skip_mapped_nodes(false), + validate_shape(true) {} // Name prefix to use for nodes imported from the GraphDef. For example, if // prefix="animals" and GraphDef contains a node "bunny" then the node will be @@ -130,6 +131,9 @@ struct ImportGraphDefOptions { // If true, checks that all colocation constraints are nodes in the GraphDef. bool validate_colocation_constraints = true; + // If false skips shape validation. + bool validate_shape; + // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries // with ops that are not defined in the binary calling ImportGraphDef. // Similar to the producer_op_list argument to import_graph_def in the diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc index b97e3d1db1..ae70c98608 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -25,14 +25,16 @@ namespace grappler { VirtualCluster::VirtualCluster( const std::unordered_map<string, DeviceProperties>& devices) - : Cluster(0), node_estimator_(new OpLevelCostEstimator()) { + : Cluster(0), + node_estimator_(new OpLevelCostEstimator()), + node_manager_(new FirstReadyManager()) { devices_ = devices; } VirtualCluster::VirtualCluster( const std::unordered_map<string, DeviceProperties>& devices, - OpLevelCostEstimator* node_estimator) - : Cluster(0), node_estimator_(node_estimator) { + OpLevelCostEstimator* node_estimator, ReadyNodeManager* node_manager) + : Cluster(0), node_estimator_(node_estimator), node_manager_(node_manager) { devices_ = devices; } VirtualCluster::~VirtualCluster() {} @@ -54,7 +56,7 @@ Status VirtualCluster::Run(const GraphDef& graph, item.graph = graph; item.feed = feed; item.fetch = fetch; - VirtualScheduler scheduler(&item, true, this); + VirtualScheduler scheduler(&item, true, this, node_manager_.get()); TF_RETURN_IF_ERROR(scheduler.Init()); if (metadata) { diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h index 1c73dbb240..dde70bab7a 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.h +++ b/tensorflow/core/grappler/clusters/virtual_cluster.h @@ -19,6 +19,7 @@ limitations under the License. #include <unordered_map> #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/grappler/costs/virtual_scheduler.h" #include "tensorflow/core/protobuf/device_properties.pb.h" namespace tensorflow { @@ -31,7 +32,8 @@ class VirtualCluster : public Cluster { public: VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices); VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices, - OpLevelCostEstimator* node_estimator); + OpLevelCostEstimator* node_estimator, + ReadyNodeManager* node_manager); ~VirtualCluster() override; @@ -45,6 +47,7 @@ class VirtualCluster : public Cluster { private: std::unique_ptr<OpLevelCostEstimator> node_estimator_; + std::unique_ptr<ReadyNodeManager> node_manager_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc index ca66f7c75a..e8d77f405c 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc @@ -34,13 +34,15 @@ AnalyticalCostEstimator::AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes) : cluster_(cluster), node_estimator_(new OpLevelCostEstimator()), + node_manager_(VirtualScheduler::ReadyNodeManagerFactory("FirstReady")), use_static_shapes_(use_static_shapes) {} AnalyticalCostEstimator::AnalyticalCostEstimator( Cluster* cluster, OpLevelCostEstimator* node_estimator, - bool use_static_shapes) + ReadyNodeManager* node_manager, bool use_static_shapes) : cluster_(cluster), node_estimator_(node_estimator), + node_manager_(node_manager), use_static_shapes_(use_static_shapes) {} Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) { @@ -61,7 +63,8 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph, } } std::vector<string> inaccurate_nodes; - VirtualScheduler scheduler(&item, use_static_shapes_, cluster_); + VirtualScheduler scheduler(&item, use_static_shapes_, cluster_, + node_manager_.get()); auto status = scheduler.Init(); if (!status.ok()) { costs->execution_time = Costs::Duration::max(); diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.h b/tensorflow/core/grappler/costs/analytical_cost_estimator.h index cf9163302c..dd2738e088 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.h +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/cost_estimator.h" #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/grappler/costs/virtual_scheduler.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/lib/core/status.h" @@ -39,9 +40,10 @@ class AnalyticalCostEstimator : public CostEstimator { // Does not take ownership of cluster. AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes); // Does not take ownership of the cluster, but takes ownership of the - // node_estimator + // node_estimator and the node_manager AnalyticalCostEstimator(Cluster* cluster, OpLevelCostEstimator* node_estimator, + ReadyNodeManager* node_manager, bool use_static_shapes); ~AnalyticalCostEstimator() override {} @@ -59,6 +61,7 @@ class AnalyticalCostEstimator : public CostEstimator { Cluster* cluster_; // Not owned. GrapplerItem item_; std::unique_ptr<OpLevelCostEstimator> node_estimator_; + std::unique_ptr<ReadyNodeManager> node_manager_; bool use_static_shapes_; }; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index fb3bdedcc6..a0f61c5392 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -113,10 +113,17 @@ void LIFOManager::RemoveCurrNode() { curr_pos_ = nodes_.end(); // Reset curr_pos_. } -FirstReadyManager::FirstReadyManager( - const std::unordered_map<const NodeDef*, NodeState>* node_state) - : ReadyNodeManager(), node_state_(node_state) { +FirstReadyManager::FirstReadyManager() : ReadyNodeManager() { std::make_heap(nodes_.begin(), nodes_.end()); +} + +void FirstReadyManager::Init( + const std::unordered_map<const NodeDef*, NodeState>* node_state) { + // Reset the node state since different instances of the scheduler can reuse + // the same node_manager. + node_state_ = node_state; + nodes_.clear(); + waiting_queue_.clear(); greater_ = [this](const NodeDef* a, const NodeDef* b) -> bool { if (node_state_->at(a).time_ready == node_state_->at(b).time_ready) { // Use Node name as tie-breaker for deterministic node scheduling. @@ -163,12 +170,14 @@ void FirstReadyManager::DrainWaitingQueue() { waiting_queue_.clear(); } -CompositeNodeManager::CompositeNodeManager( - const std::unordered_map<const NodeDef*, NodeState>* node_state) - : ReadyNodeManager(), - send_manager_(node_state), - recv_manager_(node_state), - node_state_(node_state) { +CompositeNodeManager::CompositeNodeManager() + : ReadyNodeManager(), send_manager_(), recv_manager_() {} + +void CompositeNodeManager::Init( + const std::unordered_map<const NodeDef*, NodeState>* node_state) { + node_state_ = node_state; + send_manager_.Init(node_state); + recv_manager_.Init(node_state); curr_node_ = nullptr; } @@ -241,11 +250,11 @@ bool CompositeNodeManager::Empty() const { return empty && send_manager_.Empty() && recv_manager_.Empty(); } -// VirtualScheduler VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item, const bool use_static_shapes, - Cluster* cluster) - : ready_nodes_(ReadyNodeManagerFactory("FirstReady")), + Cluster* cluster, + ReadyNodeManager* ready_nodes) + : ready_nodes_(ready_nodes), graph_costs_(Costs::ZeroCosts()), graph_properties_(*grappler_item), cluster_(cluster), @@ -262,9 +271,9 @@ ReadyNodeManager* VirtualScheduler::ReadyNodeManagerFactory( } else if (ready_node_manager == "LIFO") { return new LIFOManager(); } else if (ready_node_manager == "FirstReady") { - return new FirstReadyManager(GetNodeStates()); + return new FirstReadyManager(); } else if (ready_node_manager == "Composite") { - return new CompositeNodeManager(GetNodeStates()); + return new CompositeNodeManager(); } LOG(FATAL) << "Not a valid ready node manager: " << ready_node_manager; } @@ -274,7 +283,7 @@ Status VirtualScheduler::Init() { // necessary information for emulating tensorflow op scheduling and // construct internal data structures (NodeState and DeviceState) for virtual // scheduling. - + ready_nodes_->Init(GetNodeStates()); // Construct graph properties. Status status; if (use_static_shapes_) { diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index df8ae5861a..c180250908 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -127,6 +127,8 @@ class ReadyNodeManager { public: ReadyNodeManager() {} virtual ~ReadyNodeManager() {} + virtual void Init( + const std::unordered_map<const NodeDef*, NodeState>* node_state) {} virtual void AddNode(const NodeDef* node) = 0; virtual const NodeDef* GetCurrNode() = 0; virtual void RemoveCurrNode() = 0; @@ -137,6 +139,8 @@ class FIFOManager : public ReadyNodeManager { public: FIFOManager() : ReadyNodeManager() {} ~FIFOManager() override {} + virtual void Init( + const std::unordered_map<const NodeDef*, NodeState>* node_state) {} void AddNode(const NodeDef* node) override { nodes_.push_back(node); } const NodeDef* GetCurrNode() override { CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; @@ -157,6 +161,8 @@ class LIFOManager : public ReadyNodeManager { public: LIFOManager() : ReadyNodeManager() {} ~LIFOManager() override {} + void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state) + override {} void AddNode(const NodeDef* node) override { nodes_.push_back(node); } const NodeDef* GetCurrNode() override; void RemoveCurrNode() override; @@ -176,8 +182,9 @@ class LIFOManager : public ReadyNodeManager { // time_ready value (it depends on C++ STL push_heap and pop_heap). class FirstReadyManager : public ReadyNodeManager { public: - FirstReadyManager( - const std::unordered_map<const NodeDef*, NodeState>* node_state); + FirstReadyManager(); + void Init( + const std::unordered_map<const NodeDef*, NodeState>* node_state) override; ~FirstReadyManager() override {} void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); } const NodeDef* GetCurrNode() override; @@ -212,10 +219,11 @@ class FirstReadyManager : public ReadyNodeManager { // _Send and _Recv, fairly, in terms of their time_ready. class CompositeNodeManager : public ReadyNodeManager { public: - CompositeNodeManager( - const std::unordered_map<const NodeDef*, NodeState>* node_state); + CompositeNodeManager(); ~CompositeNodeManager() override {} + void Init( + const std::unordered_map<const NodeDef*, NodeState>* node_state) override; void AddNode(const NodeDef* node) override; const NodeDef* GetCurrNode() override; void RemoveCurrNode() override; @@ -244,8 +252,8 @@ class CompositeNodeManager : public ReadyNodeManager { class VirtualScheduler { public: VirtualScheduler(const GrapplerItem* grappler_item, - const bool use_static_shapes, Cluster* cluster); - + const bool use_static_shapes, Cluster* cluster, + ReadyNodeManager* ready_nodes); // Initializes NodeState and DeviceState from grappler_item_ and // graph_properties_. Status Init(); @@ -260,6 +268,9 @@ class VirtualScheduler { // Like the above, but writes detailed stats to RunMetadata. // If metadata is nullptr, then just calls and return Summary(). Costs Summary(RunMetadata* metadata); + // Methods called from constructor. + static ReadyNodeManager* ReadyNodeManagerFactory( + const string& ready_node_manager); // Return per device peak memory usage. const std::unordered_map<string, int64> GetPeakMemoryUsage() const; @@ -285,9 +296,6 @@ class VirtualScheduler { const string kAttrDstDevice = "dst_device_"; const string kChannelDevice = "Channel"; - // Methods called from constructor. - ReadyNodeManager* ReadyNodeManagerFactory(const string& ready_node_manager); - // Methods called from Init(). Fails if initialize_ is set. void MaybeUpdateInputOutput(const NodeDef* node); NodeState& GetNodeStateOrCreateIt(const NodeDef* node); @@ -304,7 +312,7 @@ class VirtualScheduler { bool IsPersistentNode(const NodeDef* node) const; // Scheduler states: - std::unique_ptr<ReadyNodeManager> ready_nodes_; + ReadyNodeManager* ready_nodes_; // Not owned. std::unordered_map<const NodeDef*, NodeState> node_map_; std::unordered_map<string, DeviceState> device_; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index cd960bab29..9e6561230a 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -29,7 +29,8 @@ class TestVirtualScheduler : public VirtualScheduler { public: TestVirtualScheduler(const GrapplerItem* grappler_item, const bool use_static_shapes, Cluster* cluster) - : VirtualScheduler(grappler_item, use_static_shapes, cluster) {} + : VirtualScheduler(grappler_item, use_static_shapes, cluster, + &ready_node_manager_) {} FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize); FRIEND_TEST(VirtualSchedulerTest, MemoryUsage); @@ -37,6 +38,9 @@ class TestVirtualScheduler : public VirtualScheduler { FRIEND_TEST(VirtualSchedulerTest, ComplexDependency); FRIEND_TEST(VirtualSchedulerTest, Variable); FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer); + + protected: + FirstReadyManager ready_node_manager_; }; class VirtualSchedulerTest : public ::testing::Test { @@ -1148,23 +1152,24 @@ TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleLIFOManager) { } TEST_F(VirtualSchedulerTest, GetSingleNodeFirstReadyManager) { - FirstReadyManager manager = FirstReadyManager(&node_states_); + FirstReadyManager manager; + manager.Init(&node_states_); manager.AddNode(&node1_); EXPECT_EQ("Node1", manager.GetCurrNode()->name()); } TEST_F(VirtualSchedulerTest, RemoveSingleNodeFirstReadyManager) { - FirstReadyManager manager = FirstReadyManager(&node_states_); - + FirstReadyManager manager; + manager.Init(&node_states_); manager.AddNode(&node1_); manager.RemoveCurrNode(); EXPECT_TRUE(manager.Empty()); } TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFirstReadyManager) { - FirstReadyManager manager = FirstReadyManager(&node_states_); - + FirstReadyManager manager; + manager.Init(&node_states_); // Insert nodes in some random order. manager.AddNode(&node2_); manager.AddNode(&node1_); @@ -1191,8 +1196,8 @@ TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFirstReadyManager) { } TEST_F(VirtualSchedulerTest, GetCurrNodeFirstReadyManager) { - FirstReadyManager manager = FirstReadyManager(&node_states_); - + FirstReadyManager manager; + manager.Init(&node_states_); // Insert nodes in some random order. manager.AddNode(&node2_); manager.AddNode(&node1_); @@ -1248,8 +1253,10 @@ TEST_F(VirtualSchedulerTest, GetCurrNodeFirstReadyManager) { } TEST_F(VirtualSchedulerTest, DeterminismInFirstReadyManager) { - FirstReadyManager manager1 = FirstReadyManager(&node_states_); - FirstReadyManager manager2 = FirstReadyManager(&node_states_); + FirstReadyManager manager1; + manager1.Init(&node_states_); + FirstReadyManager manager2; + manager2.Init(&node_states_); // 6 nodes with same time_ready. NodeDef node7; @@ -1312,15 +1319,16 @@ TEST_F(VirtualSchedulerTest, DeterminismInFirstReadyManager) { } TEST_F(VirtualSchedulerTest, RemoveSingleNodeCompositeNodeManager) { - CompositeNodeManager manager = CompositeNodeManager(&node_states_); - + CompositeNodeManager manager; + manager.Init(&node_states_); manager.AddNode(&node1_); manager.RemoveCurrNode(); EXPECT_TRUE(manager.Empty()); } TEST_F(VirtualSchedulerTest, RemoveSingleNodeComopsiteNodeManager) { - CompositeNodeManager manager = CompositeNodeManager(&node_states_); + CompositeNodeManager manager; + manager.Init(&node_states_); manager.AddNode(&node1_); manager.RemoveCurrNode(); @@ -1328,7 +1336,8 @@ TEST_F(VirtualSchedulerTest, RemoveSingleNodeComopsiteNodeManager) { } TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleComopsiteNodeManager) { - CompositeNodeManager manager = CompositeNodeManager(&node_states_); + CompositeNodeManager manager; + manager.Init(&node_states_); // Add the nodes to LIFOManager. manager.AddNode(&node1_); @@ -1359,8 +1368,8 @@ TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleComopsiteNodeManager) { } TEST_F(VirtualSchedulerTest, MultiDeviceSendRecvComopsiteNodeManager) { - CompositeNodeManager manager = CompositeNodeManager(&node_states_); - + CompositeNodeManager manager; + manager.Init(&node_states_); // Additional nodes on kCPU1 NodeDef node7; NodeDef node8; diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index c9e5f842be..2786b8cf62 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -926,7 +926,7 @@ class AvgPoolGradProcessor : public NodeProcessor { protected: std::vector<int> GetInputPos() const override { return {1}; } Status CustomizedProcessing() override { - return UpdateAttrValueOfInput(0, true); + return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32); } }; @@ -1062,9 +1062,7 @@ class Conv2DBackpropInputProcessor : public Conv2DProcessor { std::vector<int> GetInputPos() const override { return {2}; } Status CustomizedProcessing() override { - TF_RETURN_IF_ERROR( - UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32)); - return Status::OK(); + return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32); } }; @@ -1371,9 +1369,7 @@ class FillProcessor : public AgnosticNodeProcessor { Status CustomizedProcessing() override { DataType dtype = node_->attr().at("index_type").type(); - TF_RETURN_IF_ERROR( - UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype)); - return Status::OK(); + return UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype); } }; @@ -1470,9 +1466,7 @@ class PadProcessor : public AgnosticNodeProcessor { protected: Status CustomizedProcessing() override { DataType dtype = node_->attr().at("Tpaddings").type(); - TF_RETURN_IF_ERROR( - UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype)); - return Status::OK(); + return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype); } }; @@ -1484,9 +1478,7 @@ class ReverseProcessor : public AgnosticNodeProcessor { protected: Status CustomizedProcessing() override { DataType dtype = node_->attr().at("Tidx").type(); - TF_RETURN_IF_ERROR( - UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype)); - return Status::OK(); + return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype); } }; @@ -1511,9 +1503,8 @@ class SplitProcessor : public AgnosticNodeProcessor { } Status CustomizedProcessing() override { - TF_RETURN_IF_ERROR(UpdateOrTransformParamInput( - axis_node_pos_, "DataFormatDimMap", DT_INT32)); - return Status::OK(); + return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap", + DT_INT32); } int axis_node_pos_; @@ -1629,40 +1620,14 @@ class SumProcessor : public AgnosticNodeProcessor { int port; ParseNodeName(node_->input(0), &port); return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() && - IsPortDimsFour(*input0, port) && IsAlongDimNHW() && IsOnGPU(); + IsPortDimsFour(*input0, port) && IsOnGPU(); } Status AddLayoutTransposeToOutputs() override { return Status::OK(); } Status CustomizedProcessing() override { - return UpdateAttrValueOfInput(1, false); - } - - private: - bool IsAlongDimNHW() const { - NodeDef* reduction_indices = node_map_->GetNode(node_->input(1)); - if (!IsConstant(*reduction_indices)) { - return false; - } - Tensor tensor; - if (reduction_indices->attr().find({"value"}) == - reduction_indices->attr().end()) { - return false; - } - auto success = - tensor.FromProto(reduction_indices->attr().at({"value"}).tensor()); - if (!success) { - LOG(ERROR) << "Failed to parse TensorProto."; - return false; - } - if (tensor.flat<int>().size() != 3) { - return false; - } - if (tensor.flat<int>()(0) == 0 && tensor.flat<int>()(1) == 1 && - tensor.flat<int>()(2) == 2) { - return true; - } - return false; + DataType dtype = node_->attr().at("Tidx").type(); + return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype); } }; diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc index 6735373aac..24381a13ea 100644 --- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { @@ -85,11 +86,10 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { input_(input) { input_->Ref(); - output_shapes_.reserve(3); - // Outputs represent a SparseTensor as (indices, values, dense_shape). - output_shapes_.push_back({-1, row_shape_.dims() + 1}); - output_shapes_.push_back({-1}); - output_shapes_.push_back({row_shape_.dims() + 1}); + output_shapes_.reserve(1); + PartialTensorShape output_shape({-1}); + output_shape.AppendShape(row_shape_); + output_shapes_.push_back(output_shape); } ~Dataset() override { input_->Unref(); } @@ -101,8 +101,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { } const DataTypeVector& output_dtypes() const override { - static DataTypeVector* output_dtypes_ = - new DataTypeVector({DT_INT64, DataTypeToEnum<T>::value, DT_INT64}); + static DataTypeVector* output_dtypes_ = new DataTypeVector({DT_VARIANT}); return *output_dtypes_; } @@ -220,7 +219,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { {total_elements, row_ndims + 1}); Tensor values( cpu_allocator(), - DatasetIterator<Dataset<T>>::dataset()->output_dtypes()[1], + DatasetIterator<Dataset<T>>::dataset()->input_->output_dtypes()[0], {total_elements}); auto indices_matrix = indices.matrix<int64>(); auto values_flat = values.flat<T>(); @@ -256,9 +255,12 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { dense_shape_vec(0) = batch_elements.size(); - out_tensors->push_back(std::move(indices)); - out_tensors->push_back(std::move(values)); - out_tensors->push_back(std::move(dense_shape)); + Tensor serialized_sparse(DT_VARIANT, TensorShape({3})); + auto serialized_sparse_t = serialized_sparse.vec<Variant>(); + serialized_sparse_t(0) = std::move(indices); + serialized_sparse_t(1) = std::move(values); + serialized_sparse_t(2) = std::move(dense_shape); + out_tensors->push_back(std::move(serialized_sparse)); *end_of_sequence = false; return Status::OK(); diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index 1688674eb7..09ba092f40 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -566,6 +566,27 @@ class FusedBatchNormOp : public OpKernel { bool is_training_; }; +namespace { + +template <typename Device> +void FillZeros(Tensor* t); + +#if GOOGLE_CUDA +template <> +void FillZeros<GPUDevice>(Tensor* t) { + cudaMemset(const_cast<char*>(t->tensor_data().data()), 0, + t->tensor_data().size()); +} +#endif + +template <> +void FillZeros<CPUDevice>(Tensor* t) { + memset(const_cast<char*>(t->tensor_data().data()), 0, + t->tensor_data().size()); +} + +} // namespace + template <typename Device, typename T, typename U> class FusedBatchNormGradOp : public OpKernel { public: @@ -623,14 +644,17 @@ class FusedBatchNormGradOp : public OpKernel { Tensor* offset_backprop = nullptr; OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape, &offset_backprop)); - // two placeholders for estimated_mean and estimated_variance, which are + // Two placeholders for estimated_mean and estimated_variance, which are // used for inference and thus not needed here for gradient computation. + // They are filled with zeros so as to avoid NaN outputs. Tensor* placeholder_1 = nullptr; OP_REQUIRES_OK( context, context->allocate_output(3, TensorShape({}), &placeholder_1)); + FillZeros<Device>(placeholder_1); Tensor* placeholder_2 = nullptr; OP_REQUIRES_OK( context, context->allocate_output(4, TensorShape({}), &placeholder_2)); + FillZeros<Device>(placeholder_2); if (is_training_) { functor::FusedBatchNormGrad<Device, T, U>()( diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index caf4c533c3..4d95946467 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -12463,6 +12463,24 @@ op { } } op { + name: "DecodeCompressed" + input_arg { + name: "bytes" + type: DT_STRING + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "compression_type" + type: "string" + default_value { + s: "" + } + } +} +op { name: "DecodeGif" input_arg { name: "contents" @@ -18388,6 +18406,32 @@ op { } } op { + name: "HSVToRGB" + input_arg { + name: "images" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "HashTable" output_arg { name: "table_handle" @@ -31006,6 +31050,32 @@ op { } } op { + name: "RGBToHSV" + input_arg { + name: "images" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "RandomCrop" input_arg { name: "image" @@ -32351,6 +32421,60 @@ op { is_stateful: true } op { + name: "RecordInput" + output_arg { + name: "records" + type: DT_STRING + } + attr { + name: "file_pattern" + type: "string" + } + attr { + name: "file_random_seed" + type: "int" + default_value { + i: 301 + } + } + attr { + name: "file_shuffle_shift_ratio" + type: "float" + default_value { + f: 0 + } + } + attr { + name: "file_buffer_size" + type: "int" + default_value { + i: 10000 + } + } + attr { + name: "file_parallelism" + type: "int" + default_value { + i: 16 + } + } + attr { + name: "batch_size" + type: "int" + default_value { + i: 32 + } + } + attr { + name: "compression_type" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} +op { name: "ReduceJoin" input_arg { name: "inputs" diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index e943a698ae..9f4e0e91a7 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -437,13 +437,11 @@ REGISTER_OP("DenseToSparseBatchDataset") .Input("batch_size: int64") .Input("row_shape: int64") .Output("handle: variant") - // NOTE(mrry): the 0th and 2nd elements will be DT_INT64. .Attr("output_types: list(type) >= 1") - // NOTE(mrry): the 1st and 2nd elements will be vectors. .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape) .Doc(R"doc( -Creates a dataset that yields a SparseTensor for each element of the input. +Creates a dataset that batches input elements into a SparseTensor. input_dataset: A handle to an input dataset. Must have a single component. batch_size: A scalar representing the number of elements to accumulate in a diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 333c280146..726d376b25 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -6868,6 +6868,29 @@ op { description: "RFC 4180 format is expected for the CSV records.\n(https://tools.ietf.org/html/rfc4180)\nNote that we allow leading and trailing spaces with int or float field." } op { + name: "DecodeCompressed" + input_arg { + name: "bytes" + description: "A Tensor of string which is compressed." + type: DT_STRING + } + output_arg { + name: "output" + description: "A Tensor with the same shape as input `bytes`, uncompressed\nfrom bytes." + type: DT_STRING + } + attr { + name: "compression_type" + type: "string" + default_value { + s: "" + } + description: "A scalar containing either (i) the empty string (no\ncompression), (ii) \"ZLIB\", or (iii) \"GZIP\"." + } + summary: "Decompress strings." + description: "This op decompresses each element of the `bytes` input `Tensor`, which\nis assumed to be compressed using the given `compression_type`.\n\nThe `output` is a string `Tensor` of the same shape as `bytes`,\neach element containing the decompressed data from the corresponding\nelement in `bytes`." +} +op { name: "DecodeGif" input_arg { name: "contents" @@ -7169,7 +7192,7 @@ op { has_minimum: true minimum: 1 } - summary: "Creates a dataset that yields a SparseTensor for each element of the input." + summary: "Creates a dataset that batches input elements into a SparseTensor." } op { name: "DenseToSparseSetOperation" @@ -10888,6 +10911,8 @@ op { } allowed_values { list { + type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE } @@ -20070,6 +20095,8 @@ op { } allowed_values { list { + type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE } @@ -21287,6 +21314,14 @@ op { } description: "The batch size." } + attr { + name: "compression_type" + type: "string" + default_value { + s: "" + } + description: "The type of compression for the file. Currently ZLIB and\nGZIP are supported. Defaults to none." + } summary: "Emits randomized records." is_stateful: true } diff --git a/tensorflow/docs_src/mobile/leftnav_files b/tensorflow/docs_src/mobile/leftnav_files index 4d2c3b6234..ac50f528ba 100644 --- a/tensorflow/docs_src/mobile/leftnav_files +++ b/tensorflow/docs_src/mobile/leftnav_files @@ -1,6 +1,7 @@ index.md ### TensorFlow Lite tflite/index.md +tflite/demo_android.md >>> ### TensorFlow Mobile mobile_intro.md diff --git a/tensorflow/docs_src/mobile/tflite/demo_android.md b/tensorflow/docs_src/mobile/tflite/demo_android.md new file mode 100644 index 0000000000..79b567897c --- /dev/null +++ b/tensorflow/docs_src/mobile/tflite/demo_android.md @@ -0,0 +1,39 @@ +# TensorFlow Lite Demo for Android + +The TensorFlow Lite demo is a camera app that continuously classifies whatever +it sees from your device's back camera, using a quantized MobileNet model. + +You'll need an Android device running Android 5.0 or higher to run the demo. + +To get you started working with TensorFlow Lite on Android, we'll walk you +through building and deploying our TensorFlow demo app in Android Studio. + +It's also possible to build the demo app with Bazel, but we only recommend +this for advanced users who are very familiar with the Bazel build +environment. For more information on that, see our page [on Github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite#building-tensorflow-lite-and-the-demo-app-from-source). + +## Build and deploy with Android Studio + +1. Clone the TensorFlow repository from GitHub if you haven't already: + + git clone https://github.com/tensorflow/tensorflow + +2. Install the latest version of Android Studio from [here](https://developer.android.com/studio/index.html). + +3. From the **Welcome to Android Studio** screen, use the **Import Project + (Gradle, Eclipse ADT, etc)** option to import the + `tensorflow/contrib/lite/java/demo` directory as an existing Android Studio + Project. + + Android Studio may prompt you to install Gradle upgrades and other tool + versions; you should accept these upgrades. + +4. Download the TensorFlow Lite MobileNet model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip). + + Unzip this and copy the `mobilenet_quant_v1_224.tflite` file to the assets + directory: `tensorflow/contrib/lite/java/demo/app/src/main/assets/` + +5. Build and run the app in Android Studio. + +You'll have to grant permissions for the app to use the device's camera. Point +the camera at various objects and enjoy seeing how the model classifies things! diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 47f46cde50..7d1418d420 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -5953,7 +5953,7 @@ func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source return op.Output(0) } -// Creates a dataset that yields a SparseTensor for each element of the input. +// Creates a dataset that batches input elements into a SparseTensor. // // Arguments: // input_dataset: A handle to an input dataset. Must have a single component. @@ -8732,39 +8732,6 @@ func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAtt return op.Output(0) } -// Makes its input available to the next iteration. -// -// Arguments: -// data: The tensor to be made available to the next iteration. -// -// Returns The same tensor as `data`. -func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "NextIteration", - Input: []tf.Input{ - data, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Does nothing. Only useful as a placeholder for control edges. -// -// Returns the created operation. -func NoOp(scope *Scope) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "NoOp", - } - return scope.AddOperation(opspec) -} - // DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative. type DepthwiseConv2dNativeAttr func(optionalAttr) @@ -9731,6 +9698,39 @@ func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf. return op.Output(0) } +// Makes its input available to the next iteration. +// +// Arguments: +// data: The tensor to be made available to the next iteration. +// +// Returns The same tensor as `data`. +func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NextIteration", + Input: []tf.Input{ + data, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Does nothing. Only useful as a placeholder for control edges. +// +// Returns the created operation. +func NoOp(scope *Scope) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NoOp", + } + return scope.AddOperation(opspec) +} + // Returns the rank of a tensor. // // This operation returns an integer representing the rank of `input`. @@ -9969,6 +9969,53 @@ func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf. return sparse_indices, sparse_values, sparse_shapes, dense_values } +// DecodeCompressedAttr is an optional argument to DecodeCompressed. +type DecodeCompressedAttr func(optionalAttr) + +// DecodeCompressedCompressionType sets the optional compression_type attribute to value. +// +// value: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// If not specified, defaults to "" +func DecodeCompressedCompressionType(value string) DecodeCompressedAttr { + return func(m optionalAttr) { + m["compression_type"] = value + } +} + +// Decompress strings. +// +// This op decompresses each element of the `bytes` input `Tensor`, which +// is assumed to be compressed using the given `compression_type`. +// +// The `output` is a string `Tensor` of the same shape as `bytes`, +// each element containing the decompressed data from the corresponding +// element in `bytes`. +// +// Arguments: +// bytes: A Tensor of string which is compressed. +// +// Returns A Tensor with the same shape as input `bytes`, uncompressed +// from bytes. +func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeCompressed", + Input: []tf.Input{ + bytes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Copy a tensor setting everything outside a central band in each innermost matrix // // to zero. @@ -24616,6 +24663,17 @@ func RecordInputBatchSize(value int64) RecordInputAttr { } } +// RecordInputCompressionType sets the optional compression_type attribute to value. +// +// value: The type of compression for the file. Currently ZLIB and +// GZIP are supported. Defaults to none. +// If not specified, defaults to "" +func RecordInputCompressionType(value string) RecordInputAttr { + return func(m optionalAttr) { + m["compression_type"] = value + } +} + // Emits randomized records. // // Arguments: diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 3a2a7015a1..c2d9c67a29 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4418,7 +4418,10 @@ py_test( "grappler/datasets_test.py", ], srcs_version = "PY2AND3", - tags = ["no_pip"], # tf_optimizer is not available in pip. + tags = [ + "grappler", + "no_pip", # tf_optimizer is not available in pip. + ], deps = [ ":array_ops", ":client_testlib", @@ -4547,9 +4550,9 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow/core:protos_all_py", ], + shard_count = 10, tags = [ "grappler", - "manual", ], ) diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 569ea04f01..e08cc51b4d 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -426,6 +426,11 @@ def train_and_evaluate(estimator, train_spec, eval_spec): executor = _TrainingExecutor(estimator=estimator, train_spec=train_spec, eval_spec=eval_spec) + _execute_based_on_task_type(executor, config) + + +def _execute_based_on_task_type(executor, config): + """Executes the `executor` based on `config.task_type`.""" if (not config.cluster_spec and config.task_type != run_config_lib.TaskType.EVALUATOR): logging.info('Running training and evaluation locally (non-distributed).') @@ -462,7 +467,6 @@ def train_and_evaluate(estimator, train_spec, eval_spec): 'Task type {} is not supported. Supported task types are {}'.format( config.task_type, [x[len('run_'):] for x in available_tasks])) getattr(executor, task_to_run)() - return class _StopAtSecsHook(session_run_hook.SessionRunHook): @@ -492,6 +496,7 @@ class _TrainingExecutor(object): estimator, train_spec, eval_spec, + train_hooks=None, continuous_eval_listener=None): if not isinstance(estimator, estimator_lib.Estimator): raise TypeError('`estimator` must have type `tf.estimator.Estimator`.') @@ -505,6 +510,8 @@ class _TrainingExecutor(object): raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`.') self._eval_spec = eval_spec + self._train_hooks = _validate_hooks(train_hooks) + if (continuous_eval_listener and not isinstance(continuous_eval_listener, _ContinuousEvalListener)): raise TypeError('`continuous_eval_listener` must have type ' @@ -607,7 +614,8 @@ class _TrainingExecutor(object): self._eval_spec.throttle_secs)) stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs) - train_hooks = list(self._train_spec.hooks) + [stop_hook] + train_hooks = ( + list(self._train_spec.hooks) + [stop_hook] + list(self._train_hooks)) logging.info('Start train and evaluate loop. The evaluate will happen ' 'after {} secs (eval_spec.throttle_secs) or training is ' 'finished.'.format(self._eval_spec.throttle_secs)) @@ -695,10 +703,11 @@ class _TrainingExecutor(object): start_delay_secs) time.sleep(start_delay_secs) - self._estimator.train(input_fn=self._train_spec.input_fn, - max_steps=self._train_spec.max_steps, - hooks=self._train_spec.hooks, - saving_listeners=saving_listeners) + self._estimator.train( + input_fn=self._train_spec.input_fn, + max_steps=self._train_spec.max_steps, + hooks=list(self._train_spec.hooks) + list(self._train_hooks), + saving_listeners=saving_listeners) def _start_continuous_evaluation(self): """Repeatedly calls `Estimator` evaluate and export until training ends.""" diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index 6390a67762..9536ee44d5 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -72,6 +72,7 @@ _NONE_EXPORTER_NAME_MSG = ( 'An Exporter cannot have a name that is `None` or empty.') _INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`' _INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`' +_INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`' _INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG' _INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`' _INVALID_TASK_TYPE = '`estimator.config` must have task_type set.' @@ -522,6 +523,29 @@ class TrainingExecutorConstructorTest(test.TestCase): with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG): training._TrainingExecutor(estimator, train_spec, invalid_eval_spec) + def test_invalid_train_hooks(self): + estimator = estimator_lib.Estimator(model_fn=lambda features: features) + train_spec = training.TrainSpec(input_fn=lambda: 1) + eval_spec = training.EvalSpec(input_fn=lambda: 1) + invalid_train_hooks = [object()] + + with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG): + training._TrainingExecutor( + estimator, train_spec, eval_spec, train_hooks=invalid_train_hooks) + + def test_invalid_continuous_eval_listener(self): + estimator = estimator_lib.Estimator(model_fn=lambda features: features) + train_spec = training.TrainSpec(input_fn=lambda: 1) + eval_spec = training.EvalSpec(input_fn=lambda: 1) + invalid_continuous_eval_listener = object() + + with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_LISTENER_MSG): + training._TrainingExecutor( + estimator, + train_spec, + eval_spec, + continuous_eval_listener=invalid_continuous_eval_listener) + class _TrainingExecutorTrainingTest(object): """Tests training of _TrainingExecutor.""" @@ -554,19 +578,40 @@ class _TrainingExecutorTrainingTest(object): self.assertTrue(mock_server_instance.start.called) - mock_est.train.assert_called_with(input_fn=train_spec.input_fn, - max_steps=train_spec.max_steps, - hooks=train_spec.hooks, - saving_listeners=test.mock.ANY) + mock_est.train.assert_called_with( + input_fn=train_spec.input_fn, + max_steps=train_spec.max_steps, + hooks=list(train_spec.hooks), + saving_listeners=test.mock.ANY) mock_est.evaluate.assert_not_called() mock_est.export_savedmodel.assert_not_called() @test.mock.patch.object(time, 'sleep') @test.mock.patch.object(server_lib, 'Server') + def test_train_with_train_hooks(self, unused_mock_server, unused_mock_sleep): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.config = self._run_config + train_spec = training.TrainSpec( + input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()]) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) + extra_hooks = [_FakeHook()] + + executor = training._TrainingExecutor( + mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks) + self._run_task(executor) + + mock_est.train.assert_called_with( + input_fn=train_spec.input_fn, + max_steps=train_spec.max_steps, + hooks=list(train_spec.hooks) + extra_hooks, + saving_listeners=test.mock.ANY) + + @test.mock.patch.object(time, 'sleep') + @test.mock.patch.object(server_lib, 'Server') def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep): mock_est = test.mock.Mock(spec=estimator_lib.Estimator) mock_est.config = self._run_config - mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[]) mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) executor = training._TrainingExecutor(mock_est, mock_train_spec, @@ -614,7 +659,7 @@ class _TrainingExecutorTrainingTest(object): def test_single_worker_node_with_empty_tf_master( self, mock_server, unused_mock_sleep): mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[]) mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) @@ -676,7 +721,7 @@ class TrainingExecutorRunWorkerTest(_TrainingExecutorTrainingTest, def test_delay_for_worker(self, _): mock_est = test.mock.Mock(spec=estimator_lib.Estimator) mock_est.config = self._run_config - mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[]) mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) executor = training._TrainingExecutor(mock_est, mock_train_spec, @@ -703,7 +748,7 @@ class TrainingExecutorRunChiefTest(_TrainingExecutorTrainingTest, def test_no_delay_for_chief(self, _): mock_est = test.mock.Mock(spec=estimator_lib.Estimator) mock_est.config = self._run_config - mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[]) mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) executor = training._TrainingExecutor(mock_est, mock_train_spec, @@ -726,7 +771,8 @@ class TrainingExecutorRunMasterTest(test.TestCase): mock_est = test.mock.Mock(spec=estimator_lib.Estimator) mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123} mock_est.config = self._run_config - mock_train_spec = test.mock.Mock(spec=training.TrainSpec, max_steps=123) + mock_train_spec = test.mock.Mock( + spec=training.TrainSpec, max_steps=123, hooks=[]) mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[]) executor = training._TrainingExecutor(mock_est, mock_train_spec, @@ -759,19 +805,42 @@ class TrainingExecutorRunMasterTest(test.TestCase): self.assertTrue(mock_server_instance.start.called) - mock_est.train.assert_called_with(input_fn=train_spec.input_fn, - max_steps=train_spec.max_steps, - hooks=train_spec.hooks, - saving_listeners=test.mock.ANY) + mock_est.train.assert_called_with( + input_fn=train_spec.input_fn, + max_steps=train_spec.max_steps, + hooks=list(train_spec.hooks), + saving_listeners=test.mock.ANY) mock_est.export_savedmodel.assert_not_called() @test.mock.patch.object(time, 'sleep') @test.mock.patch.object(server_lib, 'Server') + def test_train_with_train_hooks(self, mock_server, unused_mock_sleep): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123} + mock_est.config = self._run_config + train_spec = training.TrainSpec( + input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()]) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[]) + extra_hooks = [_FakeHook()] + + executor = training._TrainingExecutor( + mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks) + executor.run_master() + + mock_est.train.assert_called_with( + input_fn=train_spec.input_fn, + max_steps=train_spec.max_steps, + hooks=list(train_spec.hooks) + extra_hooks, + saving_listeners=test.mock.ANY) + + @test.mock.patch.object(time, 'sleep') + @test.mock.patch.object(server_lib, 'Server') def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep): mock_est = test.mock.Mock(spec=estimator_lib.Estimator) mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123} mock_est.config = self._run_config - mock_train_spec = test.mock.Mock(spec=training.TrainSpec, max_steps=123) + mock_train_spec = test.mock.Mock( + spec=training.TrainSpec, max_steps=123, hooks=[]) mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[]) executor = training._TrainingExecutor(mock_est, mock_train_spec, @@ -821,7 +890,8 @@ class TrainingExecutorRunMasterTest(test.TestCase): mock_est = test.mock.Mock(spec=estimator_lib.Estimator) mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123} - mock_train_spec = test.mock.Mock(spec=training.TrainSpec, max_steps=123) + mock_train_spec = test.mock.Mock( + spec=training.TrainSpec, max_steps=123, hooks=[]) mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[]) mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) @@ -1039,6 +1109,28 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase): hooks=eval_spec.hooks) self.assertFalse(mock_est.train.called) + def test_evaluate_with_train_hooks(self): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.latest_checkpoint.return_value = 'latest_it_is' + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) + + eval_spec = training.EvalSpec( + input_fn=lambda: 1, + steps=2, + hooks=[_FakeHook()], + name='cont_eval', + start_delay_secs=0, + throttle_secs=0) + + # The train_hooks will not be called during eval. + mock_hook = test.mock.Mock(spec=session_run_hook.SessionRunHook) + executor = training._TrainingExecutor( + mock_est, mock_train_spec, eval_spec, train_hooks=[mock_hook]) + executor.run_evaluator() + + mock_hook.begin.assert_not_called() + def test_evaluate_multiple_times(self): training_max_step = 200 @@ -1095,7 +1187,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase): }] mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2'] - mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[]) mock_train_spec.max_steps = training_max_step class _Listener(training._ContinuousEvalListener): @@ -1112,8 +1204,9 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase): eval_spec = training.EvalSpec( input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) - training._TrainingExecutor(mock_est, mock_train_spec, eval_spec, - listener).run_evaluator() + training._TrainingExecutor( + mock_est, mock_train_spec, eval_spec, + continuous_eval_listener=listener).run_evaluator() # Before_eval returns False during the second time, so, evaluate will be # called once. @@ -1152,8 +1245,9 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase): eval_spec = training.EvalSpec( input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) - training._TrainingExecutor(mock_est, mock_train_spec, eval_spec, - listener).run_evaluator() + training._TrainingExecutor( + mock_est, mock_train_spec, eval_spec, + continuous_eval_listener=listener).run_evaluator() # after_eval returns False during the first time, so, evaluate will be # called once. @@ -1280,8 +1374,11 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase): eval_spec = training.EvalSpec( input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) - executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec, - continuous_eval_listener) + executor = training._TrainingExecutor( + mock_est, + mock_train_spec, + eval_spec, + continuous_eval_listener=continuous_eval_listener) executor.run_evaluator() # Three checkpoint paths are invalid. @@ -1667,6 +1764,26 @@ class TrainingExecutorRunLocalTest(test.TestCase): self.assertEqual(train_spec.input_fn, train_args['input_fn']) self.assertEqual(train_spec.max_steps, train_args['max_steps']) + def test_train_hooks(self): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') + mock_est.latest_checkpoint.return_value = 'checkpoint_path/' + train_spec = training.TrainSpec( + input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + eval_spec = training.EvalSpec(input_fn=lambda: 1, steps=2) + mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} + extra_hooks = [_FakeHook()] + + executor = training._TrainingExecutor( + mock_est, train_spec, eval_spec, train_hooks=extra_hooks) + executor.run_local() + + train_args = mock_est.train.call_args[1] + self.assertEqual( + list(train_spec.hooks) + extra_hooks, [ + h for h in train_args['hooks'] + if not isinstance(h, training._StopAtSecsHook) + ]) + def test_errors_out_if_throttle_secs_is_zero(self): mock_est = test.mock.Mock(spec=estimator_lib.Estimator) train_spec = training.TrainSpec(input_fn=lambda: 1) diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index bb4f6388fd..72e3520272 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -823,7 +823,7 @@ class LayoutOptimizerTest(test.TestCase): for node in optimized_graph.node: if node.op in ['Conv2D', 'Conv2DBackpropFilter', 'Conv2DBackpropInput']: found += 1 - self.assertEqual(node.attr['data_format'].s, 'NCHW') + self.assertEqual(node.attr['data_format'].s, b'NCHW') self.assertEqual(found, 5) def testDepthwise(self): @@ -840,7 +840,7 @@ class LayoutOptimizerTest(test.TestCase): 'DepthwiseConv2dNativeBackpropInput' ]: found += 1 - self.assertEqual(node.attr['data_format'].s, 'NCHW') + self.assertEqual(node.attr['data_format'].s, b'NCHW') self.assertEqual(found, 6) def testCheckpointCompatibility(self): diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 26ce0a6518..acbbb21322 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -634,8 +634,7 @@ class Layer(object): except AttributeError: pass input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs) - with ops.init_scope(): - self.build(input_shapes) + self.build(input_shapes) try: # Note: not all sub-classes of Layer call Layer.__init__ (especially # the ones under tensorflow/python/keras). Hence we recompute this diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index e6070e8123..2135f6dd01 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1565,9 +1565,9 @@ def tf_version_info_genrule(): native.genrule( name="version_info_gen", srcs=[ - clean_dep("//tensorflow/tools/git:gen/spec.json"), - clean_dep("//tensorflow/tools/git:gen/head"), - clean_dep("//tensorflow/tools/git:gen/branch_ref"), + clean_dep("@local_config_git//:gen/spec.json"), + clean_dep("@local_config_git//:gen/head"), + clean_dep("@local_config_git//:gen/branch_ref"), ], outs=["util/version_info.cc"], cmd= diff --git a/tensorflow/tools/git/BUILD b/tensorflow/tools/git/BUILD index f502c8dde0..942ceab85f 100644 --- a/tensorflow/tools/git/BUILD +++ b/tensorflow/tools/git/BUILD @@ -7,9 +7,7 @@ package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 exports_files( - glob(["gen/*"]) + [ - "gen_git_source.py", - ], + ["gen_git_source.py"], ) # ----------------------------------------------------------------------------- diff --git a/tensorflow/tools/git/gen/branch_ref b/tensorflow/tools/git/gen/branch_ref deleted file mode 100644 index 8b13789179..0000000000 --- a/tensorflow/tools/git/gen/branch_ref +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tensorflow/tools/git/gen/head b/tensorflow/tools/git/gen/head deleted file mode 100644 index 8b13789179..0000000000 --- a/tensorflow/tools/git/gen/head +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tensorflow/tools/git/gen/spec.json b/tensorflow/tools/git/gen/spec.json deleted file mode 100644 index 176bbc21cc..0000000000 --- a/tensorflow/tools/git/gen/spec.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "git": false -} diff --git a/tensorflow/tools/git/gen_git_source.py b/tensorflow/tools/git/gen_git_source.py index 2e27487d2f..3630dbd740 100755 --- a/tensorflow/tools/git/gen_git_source.py +++ b/tensorflow/tools/git/gen_git_source.py @@ -62,7 +62,7 @@ def parse_branch_ref(filename): raise RuntimeError("Git directory has unparseable HEAD") -def configure(src_base_path, debug=False): +def configure(src_base_path, gen_path, debug=False): """Configure `src_base_path` to embed git hashes if available.""" # TODO(aselle): No files generated or symlinked here are deleted by @@ -71,7 +71,6 @@ def configure(src_base_path, debug=False): # without running ./configure again. git_path = os.path.join(src_base_path, ".git") - gen_path = os.path.join(src_base_path, "tensorflow", "tools", "git", "gen") # Remove and recreate the path if os.path.exists(gen_path): @@ -261,6 +260,10 @@ parser.add_argument( help="Path to configure as a git repo dependency tracking sentinel") parser.add_argument( + "--gen_root_path", type=str, + help="Root path to place generated git files (created by --configure).") + +parser.add_argument( "--generate", type=str, help="Generate given spec-file, HEAD-symlink-file, ref-symlink-file", @@ -274,7 +277,9 @@ parser.add_argument( args = parser.parse_args() if args.configure is not None: - configure(args.configure, debug=args.debug) + if args.gen_root_path is None: + raise RuntimeError("Must pass --gen_root_path arg when running --configure") + configure(args.configure, args.gen_root_path, debug=args.debug) elif args.generate is not None: generate(args.generate) elif args.raw_generate is not None: diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 6a496f53f0..6571d9c463 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -2,6 +2,7 @@ load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/mkl:build_defs.bzl", "mkl_repository") +load("//third_party/git:git_configure.bzl", "git_configure") load("//third_party/py:python_configure.bzl", "python_configure") load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", "arm_compiler_configure") @@ -47,6 +48,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): # version we require here. check_version("0.5.4") cuda_configure(name="local_config_cuda") + git_configure(name="local_config_git") sycl_configure(name="local_config_sycl") python_configure(name="local_config_python") diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md index c00d8d9015..6aa95ccce9 100644 --- a/third_party/examples/eager/spinn/README.md +++ b/third_party/examples/eager/spinn/README.md @@ -15,6 +15,8 @@ https://github.com/jekbradbury/examples/blob/spinn/snli/spinn.py, which was released under the BSD 3-Clause License at: https://github.com/jekbradbury/examples/blob/spinn/LICENSE +Other eager execution examples can be found under [tensorflow/contrib/eager/python/examples](../../../../tensorflow/contrib/eager/python/examples). + ## Content Python source file(s): diff --git a/third_party/git/BUILD b/third_party/git/BUILD new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/third_party/git/BUILD diff --git a/third_party/git/BUILD.tpl b/third_party/git/BUILD.tpl new file mode 100644 index 0000000000..7b031e74d5 --- /dev/null +++ b/third_party/git/BUILD.tpl @@ -0,0 +1,10 @@ +# Description: +# Exports generated files used to generate tensorflow/core/util/version_info.cc + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +exports_files( + glob(["gen/*"]), +) diff --git a/third_party/git/git_configure.bzl b/third_party/git/git_configure.bzl new file mode 100644 index 0000000000..bd197bfd24 --- /dev/null +++ b/third_party/git/git_configure.bzl @@ -0,0 +1,20 @@ +"""Repository rule for Git autoconfiguration.""" + +def _git_conf_impl(repository_ctx): + repository_ctx.template( + "BUILD", + Label("//third_party/git:BUILD.tpl")) + + tensorflow_root_path = str(repository_ctx.path( + Label("@org_tensorflow//:BUILD")))[:-len("BUILD")] + python_script_path = repository_ctx.path( + Label("@org_tensorflow//tensorflow/tools/git:gen_git_source.py")) + generated_files_path = repository_ctx.path("gen") + + repository_ctx.execute([ + python_script_path, "--configure", tensorflow_root_path, + "--gen_root_path", generated_files_path], quiet=False) + +git_configure = repository_rule( + implementation = _git_conf_impl, +) |