author Shanqing Cai <cais@google.com> 2017-12-24 21:42:52 -0500
committer GitHub <noreply@github.com> 2017-12-24 21:42:52 -0500
commit aece567081c6fed35bba14745131a67ca824a4a6 (patch)
tree f50fa4c4a07dd360647e2326c0ebef790b7d0f3e
parent 26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca (diff)
parent 4a7385f88194ae64c2cad7e2a920b66dd7edc982 (diff)
Merge pull request #15615 from caisq/branch_180008567
Branch 180008567
-rw-r--r--  configure.py | 14
-rw-r--r--  tensorflow/c/eager/c_api.cc | 56
-rw-r--r--  tensorflow/c/eager/c_api.h | 13
-rw-r--r--  tensorflow/c/eager/c_api_internal.h | 5
-rw-r--r--  tensorflow/c/eager/c_api_test.cc | 42
-rw-r--r--  tensorflow/c/eager/runtime.cc | 34
-rw-r--r--  tensorflow/c/eager/runtime.h | 3
-rw-r--r--  tensorflow/c/eager/runtime_test.cc | 4
-rw-r--r--  tensorflow/compiler/jit/BUILD | 2
-rw-r--r--  tensorflow/compiler/tf2xla/BUILD | 20
-rw-r--r--  tensorflow/compiler/tf2xla/const_analysis.cc | 100
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/bcast_ops.cc | 10
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/categorical_op.cc | 3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/concat_op.cc | 24
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/conv_ops.cc | 22
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc | 44
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/fill_op.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/gather_op.cc | 3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc | 3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/index_ops.cc | 7
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc | 16
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc | 3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/one_hot_op.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/pad_op.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/pooling_ops.cc | 6
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/random_ops.cc | 12
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reduction_ops.cc | 16
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reshape_op.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reverse_op.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/scan_ops.cc | 10
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/sequence_ops.cc | 12
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/shape_op.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/slice_op.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/split_op.cc | 21
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/stack_ops.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc | 19
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 10
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tile_ops.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/transpose_op.cc | 16
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_registry.cc | 22
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_registry.h | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 24
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 14
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executable.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis.cc | 23
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis.h | 1
-rw-r--r--  tensorflow/compiler/xla/tools/parser/README.md | 2
-rw-r--r--  tensorflow/contrib/crf/__init__.py | 17
-rw-r--r--  tensorflow/contrib/crf/python/kernel_tests/crf_test.py | 11
-rw-r--r--  tensorflow/contrib/crf/python/ops/crf.py | 27
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 33
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py | 2
-rw-r--r--  tensorflow/contrib/data/python/ops/batching.py | 22
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py | 17
-rw-r--r--  tensorflow/contrib/eager/python/examples/mnist/mnist_test.py | 40
-rw-r--r--  tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py | 46
-rw-r--r--  tensorflow/contrib/eager/python/metrics_impl.py | 22
-rw-r--r--  tensorflow/contrib/eager/python/metrics_test.py | 19
-rw-r--r--  tensorflow/contrib/eager/python/network_test.py | 4
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc | 1
-rw-r--r--  tensorflow/contrib/tpu/BUILD | 11
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_config.py | 24
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_config_test.py | 60
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_DenseToSparseBatchDataset.pbtxt | 2
-rw-r--r--  tensorflow/core/common_runtime/direct_session_test.cc | 2
-rw-r--r--  tensorflow/core/common_runtime/executor.h | 7
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc | 2
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc | 12
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc | 4
-rw-r--r--  tensorflow/core/common_runtime/gpu/process_state.cc | 19
-rw-r--r--  tensorflow/core/common_runtime/placer.cc | 3
-rw-r--r--  tensorflow/core/example/feature_util.h | 2
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc | 6
-rw-r--r--  tensorflow/core/graph/graph_constructor.h | 6
-rw-r--r--  tensorflow/core/grappler/clusters/virtual_cluster.cc | 10
-rw-r--r--  tensorflow/core/grappler/clusters/virtual_cluster.h | 5
-rw-r--r--  tensorflow/core/grappler/costs/analytical_cost_estimator.cc | 7
-rw-r--r--  tensorflow/core/grappler/costs/analytical_cost_estimator.h | 5
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler.cc | 39
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler.h | 28
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler_test.cc | 41
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc | 55
-rw-r--r--  tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc | 24
-rw-r--r--  tensorflow/core/kernels/fused_batch_norm_op.cc | 26
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt | 124
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc | 4
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 37
-rw-r--r--  tensorflow/docs_src/mobile/leftnav_files | 1
-rw-r--r--  tensorflow/docs_src/mobile/tflite/demo_android.md | 39
-rw-r--r--  tensorflow/go/op/wrappers.go | 126
-rw-r--r--  tensorflow/python/BUILD | 7
-rw-r--r--  tensorflow/python/estimator/training.py | 21
-rw-r--r--  tensorflow/python/estimator/training_test.py | 161
-rw-r--r--  tensorflow/python/grappler/layout_optimizer_test.py | 4
-rw-r--r--  tensorflow/python/layers/base.py | 3
-rw-r--r--  tensorflow/tensorflow.bzl | 6
-rw-r--r--  tensorflow/tools/git/BUILD | 4
-rw-r--r--  tensorflow/tools/git/gen/branch_ref | 1
-rw-r--r--  tensorflow/tools/git/gen/head | 1
-rw-r--r--  tensorflow/tools/git/gen/spec.json | 3
-rwxr-xr-x  tensorflow/tools/git/gen_git_source.py | 11
-rw-r--r--  tensorflow/workspace.bzl | 2
-rw-r--r--  third_party/examples/eager/spinn/README.md | 2
-rw-r--r--  third_party/git/BUILD | 0
-rw-r--r--  third_party/git/BUILD.tpl | 10
-rw-r--r--  third_party/git/git_configure.bzl | 20
109 files changed, 1322 insertions(+), 615 deletions(-)
diff --git a/configure.py b/configure.py
index 1917af4b65..e4218b5651 100644
--- a/configure.py
+++ b/configure.py
@@ -265,19 +265,6 @@ def reset_tf_configure_bazelrc():
f.write('import %workspace%/.tf_configure.bazelrc\n')
-def run_gen_git_source(environ_cp):
- """Run the gen_git_source to create links.
-
- The links are for bazel to track dependencies for git hash propagation.
-
- Args:
- environ_cp: copy of the os.environ.
- """
- cmd = '"%s" tensorflow/tools/git/gen_git_source.py --configure %s' % (
- environ_cp.get('PYTHON_BIN_PATH'), os.getcwd())
- os.system(cmd)
-
-
def cleanup_makefile():
"""Delete any leftover BUILD files from the Makefile build.
@@ -1251,7 +1238,6 @@ def main():
reset_tf_configure_bazelrc()
cleanup_makefile()
setup_python(environ_cp)
- run_gen_git_source(environ_cp)
if is_windows():
environ_cp['TF_NEED_S3'] = '0'
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index beffa191d1..589afb9031 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -518,6 +518,15 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
}
return;
}
+ std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
+ if (ctx->should_store_metadata.load()) {
+ maybe_stats.reset(new tensorflow::NodeExecStats);
+ maybe_stats->set_node_name(op->name);
+ maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros());
+ maybe_stats->set_op_start_rel_micros(0);
+ maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
+ // TODO(apassos) track referenced tensors
+ }
// WARNING: kernel->Run utilizes the FunctionLibraryRuntime
// (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
// which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
@@ -525,12 +534,38 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
// This is quite subtle. Re-work things to make this better? (Would it make
// sense for FunctionLibraryRuntime to ensure thread-safe access to
- // FunctionLibraryDefinition?).
- status->status = kernel->Run(&op->inputs, &outputs);
+ // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats
+ // for ops which are a part of functions.
+ status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get());
for (auto* t : copied_tensors) {
TFE_DeleteTensorHandle(t);
}
if (!status->status.ok()) return;
+ if (maybe_stats != nullptr) {
+ maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
+ maybe_stats->all_start_micros());
+ tensorflow::mutex_lock ml(ctx->metadata_mu);
+ if (ctx->should_store_metadata.load()) {
+ auto* step_stats = ctx->run_metadata.mutable_step_stats();
+ // Lazily initialize the RunMetadata with information about all devices if
+ // this is the first call.
+ while (step_stats->dev_stats_size() < ctx->devices().size()) {
+ step_stats->add_dev_stats();
+ }
+ // Find the current device's index.
+ int device_idx = 0;
+ for (int i = 0; i < ctx->devices().size(); ++i) {
+ if (ctx->devices()[i] == device) {
+ device_idx = i;
+ break;
+ }
+ }
+ // Populate the device stats for this device.
+ auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
+ dev_stats->set_device(device->name());
+ *dev_stats->add_node_stats() = *maybe_stats;
+ }
+ }
*num_retvals = std::min<int>(*num_retvals, outputs.size());
for (int i = 0; i < *num_retvals; ++i) {
tensorflow::Device* d = IsCPU(device) ? nullptr : device;
@@ -577,3 +612,20 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
}
return &h->t;
}
+
+void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
+ ctx->should_store_metadata.store(true);
+}
+
+void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
+ tensorflow::mutex_lock ml(ctx->metadata_mu);
+ ctx->should_store_metadata.store(false);
+ ctx->run_metadata.Clear();
+}
+
+void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
+ TF_Status* status) {
+ tensorflow::mutex_lock ml(ctx->metadata_mu);
+ status->status = MessageToBuffer(ctx->run_metadata, buf);
+ ctx->run_metadata.Clear();
+}
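With this change, TFE_Execute records a NodeExecStats for each op executed while tracing is enabled and files it under the StepStats entry of the device that ran the op. A minimal C++ sketch of walking the exported proto, assuming a tensorflow::RunMetadata `rm` already parsed from the buffer returned by TFE_ContextExportRunMetadata (as in the ExecuteWithTracing test further down):

    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/protobuf/config.pb.h"

    void DumpTracingStats(const tensorflow::RunMetadata& rm) {
      for (const auto& dev : rm.step_stats().dev_stats()) {
        for (const auto& node : dev.node_stats()) {
          // Timing fields set in TFE_Execute above.
          LOG(INFO) << dev.device() << " ran " << node.node_name() << " in "
                    << node.op_end_rel_micros() << " us";
          for (const auto& mem : node.memory()) {
            // Per-allocator usage filled in by KernelAndDevice::Run.
            LOG(INFO) << "  " << mem.allocator_name() << ": peak "
                      << mem.peak_bytes() << " bytes";
          }
        }
      }
    }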
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index ca105962df..7caab43d00 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -207,6 +207,19 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx,
TF_Function* function,
TF_Status* status);
+// Enables tracing of RunMetadata on the ops executed from this context.
+TF_CAPI_EXPORT extern void TFE_ContextEnableRunMetadata(TFE_Context* ctx);
+
+// Disables tracing of RunMetadata on the ops executed from this context.
+TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx);
+
+// Populates the passed-in buffer with a serialized RunMetadata protocol buffer
+// containing any run metadata information accumulated so far and clears this
+// information.
+TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
+ TF_Buffer* buf,
+ TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
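Taken together, the intended call pattern mirrors the new ExecuteWithTracing test; a minimal sketch with error handling elided, assuming a live TFE_Context* ctx, a prepared TFE_Op* op, and a TF_Status* status:

    TFE_ContextEnableRunMetadata(ctx);
    TFE_TensorHandle* retvals[1] = {nullptr};
    int num_retvals = 1;
    TFE_Execute(op, &retvals[0], &num_retvals, status);  // stats recorded here
    TF_Buffer* buf = TF_NewBuffer();
    // Serializes the accumulated RunMetadata into buf and clears it in ctx.
    TFE_ContextExportRunMetadata(ctx, buf, status);
    TF_DeleteBuffer(buf);
    TFE_ContextDisableRunMetadata(ctx);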
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 0971e2ab2f..11b7a519c3 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -67,6 +67,11 @@ struct TFE_Context {
}
const std::vector<tensorflow::Device*>& devices() { return session->devices; }
+
+ // Whether we should compute RunMetadata.
+ std::atomic<bool> should_store_metadata{false};
+ tensorflow::mutex metadata_mu;
+ tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu);
};
struct TFE_TensorHandle {
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index c5ec0cfc31..423a7e1ff7 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/protobuf/config.pb.h"
using tensorflow::string;
@@ -353,6 +354,47 @@ TEST(CAPI, Execute) {
TF_DeleteStatus(status);
}
+TEST(CAPI, ExecuteWithTracing) {
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ TFE_ContextEnableRunMetadata(ctx);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* m = TestMatrixTensorHandle();
+ TFE_Op* matmul = MatMulOp(ctx, m, m);
+ TFE_TensorHandle* retvals[2] = {nullptr};
+ int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteOp(matmul);
+ TFE_DeleteTensorHandle(m);
+ TF_Buffer* b = TF_NewBuffer();
+ TFE_ContextExportRunMetadata(ctx, b, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ tensorflow::RunMetadata rm;
+ EXPECT_TRUE(
+ rm.ParseFromString({reinterpret_cast<const char*>(b->data), b->length}));
+ TF_DeleteBuffer(b);
+ TFE_DeleteContext(ctx, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(1, num_retvals);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ TFE_DeleteTensorHandle(retvals[0]);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(7, product[0]);
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+ TF_DeleteStatus(status);
+}
+
TEST(CAPI, Function) {
// First create a simple identity function.
TF_Graph* function_graph = TF_NewGraph();
diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc
index 38066682a9..ec34b0ea77 100644
--- a/tensorflow/c/eager/runtime.cc
+++ b/tensorflow/c/eager/runtime.cc
@@ -262,7 +262,8 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
}
Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
- std::vector<Tensor>* output_tensors) {
+ std::vector<Tensor>* output_tensors,
+ NodeExecStats* stats) {
gtl::InlinedVector<TensorValue, 4> inputs;
for (Tensor& t : *input_tensors) {
inputs.push_back(TensorValue(&t));
@@ -284,6 +285,9 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
params.function_library = flib_;
params.slice_reader_cache = &slice_reader_cache_;
params.rendezvous = rendez_;
+ if (stats != nullptr) {
+ params.track_allocations = true;
+ }
// TODO(apassos): use a thread pool.
std::function<void(std::function<void()>)> runner =
[](std::function<void()> f) { f(); };
@@ -297,6 +301,34 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
for (int i = 0; i < context.num_outputs(); ++i) {
output_tensors->push_back(Tensor(*context.mutable_output(i)));
}
+ if (stats != nullptr) {
+ for (const auto& allocator_pair : context.wrapped_allocators()) {
+ AllocatorMemoryUsed* memory = stats->add_memory();
+ memory->set_allocator_name(allocator_pair.first->Name());
+ auto sizes = allocator_pair.second->GetSizes();
+ memory->set_total_bytes(std::get<0>(sizes));
+ memory->set_peak_bytes(std::get<1>(sizes));
+ memory->set_live_bytes(std::get<2>(sizes));
+
+ AllocatorStats allocator_stats;
+ allocator_pair.first->GetStats(&allocator_stats);
+ memory->set_allocator_bytes_in_use(allocator_stats.bytes_in_use);
+ allocator_pair.second->GetRecordsAndUnRef();
+ }
+ auto* ms = stats->mutable_memory_stats();
+ ms->set_host_temp_memory_size(context.host_temp_memory_size());
+ ms->set_device_temp_memory_size(context.device_temp_memory_size());
+ for (const auto& alloc_id : context.host_persistent_alloc_ids()) {
+ ms->mutable_host_persistent_tensor_alloc_ids()->Add(alloc_id);
+ }
+ for (const auto& alloc_id : context.device_persistent_alloc_ids()) {
+ ms->mutable_device_persistent_tensor_alloc_ids()->Add(alloc_id);
+ }
+ ms->set_host_persistent_memory_size(
+ context.host_persistent_memory_allocated());
+ ms->set_device_persistent_memory_size(
+ context.device_persistent_memory_allocated());
+ }
return Status::OK();
}
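On the runtime side the stats pointer doubles as the opt-in switch: passing nullptr, as the updated callers in runtime_test.cc do, skips allocation tracking entirely, while a non-null NodeExecStats enables params.track_allocations and is filled with per-allocator and temp/persistent memory figures. A hedged sketch of a tracing caller, assuming a kernel and inputs prepared as in the existing KernelAndDevice tests:

    tensorflow::NodeExecStats stats;
    std::vector<tensorflow::Tensor> outputs;
    // A non-null stats argument turns on allocation tracking for this run.
    TF_CHECK_OK(kernel.Run(&inputs, &outputs, &stats));
    // stats.memory() and stats.memory_stats() are now populated.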
diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h
index fb97e94a94..e28a416e67 100644
--- a/tensorflow/c/eager/runtime.h
+++ b/tensorflow/c/eager/runtime.h
@@ -175,7 +175,8 @@ class KernelAndDevice {
: device_(nullptr), flib_(nullptr), rendez_(rendez) {}
// TODO(ashankar): Handle list-valued inputs.
- Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs);
+ Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
+ NodeExecStats* stats);
const OpKernel* kernel() const { return kernel_.get(); }
diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc
index 3236c6be0e..2ccca66f67 100644
--- a/tensorflow/c/eager/runtime_test.cc
+++ b/tensorflow/c/eager/runtime_test.cc
@@ -96,7 +96,7 @@ TEST(KernelAndDevice, Run) {
KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel);
ASSERT_TRUE(s.ok()) << s;
std::vector<Tensor> outputs;
- s = kernel.Run(&inputs, &outputs);
+ s = kernel.Run(&inputs, &outputs, nullptr);
ASSERT_TRUE(s.ok()) << s;
ASSERT_EQ(1, outputs.size());
const Tensor& out = outputs[0];
@@ -183,7 +183,7 @@ void BM_KernelAndDeviceRun(int iters) {
KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel));
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
- TF_CHECK_OK(kernel.Run(&inputs, &outputs));
+ TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr));
}
}
BENCHMARK(BM_KernelAndDeviceRun);
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 2374620f58..13343967c4 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -215,7 +215,6 @@ cc_library(
":common",
":compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
- "//tensorflow/compiler/tf2xla:const_analysis",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -245,7 +244,6 @@ cc_library(
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:parallel_check_op",
"//tensorflow/compiler/jit/ops:xla_ops",
- "//tensorflow/compiler/tf2xla:const_analysis",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 4ccea57422..3c7dfef03d 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -120,6 +120,7 @@ cc_library(
cc_library(
name = "xla_compiler",
srcs = [
+ "const_analysis.cc",
"graph_compiler.cc",
"xla_compilation_device.cc",
"xla_compiler.cc",
@@ -133,6 +134,7 @@ cc_library(
"xla_gpu_backend.cc",
]),
hdrs = [
+ "const_analysis.h",
"graph_compiler.h",
"xla_compilation_device.h",
"xla_compiler.h",
@@ -145,7 +147,6 @@ cc_library(
visibility = [":friends"],
deps = [
":common",
- ":const_analysis",
":dump_graph",
":functionalize_control_flow",
":sharding_util",
@@ -160,7 +161,6 @@ cc_library(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -356,28 +356,16 @@ tf_cc_test(
],
)
-cc_library(
- name = "const_analysis",
- srcs = ["const_analysis.cc"],
- hdrs = ["const_analysis.h"],
- deps = [
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- ],
-)
-
tf_cc_test(
name = "const_analysis_test",
size = "small",
srcs = ["const_analysis_test.cc"],
deps = [
- ":const_analysis",
+ ":xla_compiler",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:ops",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index ab2f1e9a7a..0249500910 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -27,96 +28,18 @@ namespace tensorflow {
// compile-time constants.
Status BackwardsConstAnalysis(const Graph& g,
std::vector<bool>* compile_time_const_args) {
- // TODO(phawkins): annotate these on the kernel registrations, rather than
- // using a hard-coded list.
- // (operator, argument) pairs that must be compile-time constants.
- const std::unordered_multimap<string, string> compile_time_const_inputs = {
- {"All", "reduction_indices"},
- {"Any", "reduction_indices"},
- {"ArgMin", "dimension"},
- {"ArgMax", "dimension"},
- {"AvgPoolGrad", "orig_input_shape"},
- {"AvgPool3DGrad", "orig_input_shape"},
- {"BatchToSpace", "crops"},
- {"BatchToSpaceND", "block_shape"},
- {"BatchToSpaceND", "crops"},
- {"BroadcastArgs", "s0"},
- {"BroadcastArgs", "s1"},
- {"BroadcastGradientArgs", "s0"},
- {"BroadcastGradientArgs", "s1"},
- {"Concat", "concat_dim"},
- {"ConcatV2", "axis"},
- {"ConcatOffset", "concat_dim"},
- {"ConcatOffset", "shape"},
- {"Conv2DBackpropFilter", "filter_sizes"},
- {"Conv2DBackpropInput", "input_sizes"},
- {"Conv3DBackpropFilterV2", "filter_sizes"},
- {"Conv3DBackpropInputV2", "input_sizes"},
- {"Cumprod", "axis"},
- {"Cumsum", "axis"},
- {"DepthwiseConv2dNativeBackpropFilter", "filter_sizes"},
- {"DepthwiseConv2dNativeBackpropInput", "input_sizes"},
- {"DynamicStitch", "indices"},
- {"ExpandDims", "dim"},
- {"Fill", "dims"},
- {"GatherV2", "axis"},
- {"InvertPermutation", "x"},
- {"LinSpace", "start"},
- {"LinSpace", "stop"},
- {"LinSpace", "num"},
- {"Max", "reduction_indices"},
- {"Mean", "reduction_indices"},
- {"Min", "reduction_indices"},
- {"OneHot", "depth"},
- {"Pad", "paddings"},
- {"PadV2", "paddings"},
- {"MirrorPad", "paddings"},
- {"Multinomial", "num_samples"},
- {"Prod", "reduction_indices"},
- {"RandomStandardNormal", "shape"},
- {"RandomUniform", "shape"},
- {"RandomUniformInt", "shape"},
- {"Range", "start"},
- {"Range", "limit"},
- {"Range", "delta"},
- {"Reshape", "shape"},
- {"ResizeBilinear", "size"},
- {"ResourceStridedSliceAssign", "begin"},
- {"ResourceStridedSliceAssign", "end"},
- {"ResourceStridedSliceAssign", "strides"},
- {"Reverse", "dims"},
- {"ReverseV2", "axis"},
- {"Slice", "begin"},
- {"Slice", "size"},
- {"SpaceToBatch", "paddings"},
- {"SpaceToBatchND", "block_shape"},
- {"SpaceToBatchND", "paddings"},
- {"Split", "split_dim"},
- {"SplitV", "split_dim"},
- {"SplitV", "size_splits"},
- {"StackV2", "max_size"},
- {"StridedSlice", "begin"},
- {"StridedSlice", "end"},
- {"StridedSlice", "strides"},
- {"StridedSliceGrad", "shape"},
- {"StridedSliceGrad", "begin"},
- {"StridedSliceGrad", "end"},
- {"StridedSliceGrad", "strides"},
- {"Sum", "reduction_indices"},
- {"TensorArrayV3", "size"},
- {"TensorArraySplitV3", "lengths"},
- {"Tile", "multiples"},
- {"Transpose", "perm"}};
-
// Operators that don't look at the data of their inputs, just the shapes.
const std::unordered_set<string> metadata_ops = {
- "Rank", "Shape", "ShapeN", "Size",
+ "Rank",
+ "Shape",
+ "ShapeN",
+ "Size",
};
Status status;
std::unordered_set<Node*> must_be_const;
- auto visit = [&status, &metadata_ops, &compile_time_const_inputs,
- &must_be_const, compile_time_const_args](Node* node) {
+ auto visit = [&status, &metadata_ops, &must_be_const,
+ compile_time_const_args](Node* node) {
if (!status.ok()) return;
// If this is a metadata-only op, don't propagate the const requirement.
@@ -139,16 +62,17 @@ Status BackwardsConstAnalysis(const Graph& g,
}
// Mark any compile-time constant operator arguments as const.
- auto range = compile_time_const_inputs.equal_range(node->type_string());
- if (range.first == range.second) return;
+ const std::unordered_set<string>* const_inputs =
+ XlaOpRegistry::CompileTimeConstantInputs(node->type_string());
+ if (!const_inputs) return;
NameRangeMap input_name_ranges;
status =
NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr);
if (!status.ok()) return;
- for (auto it = range.first; it != range.second; ++it) {
- auto name_range = input_name_ranges.find(it->second);
+ for (const string& input : *const_inputs) {
+ auto name_range = input_name_ranges.find(input);
if (name_range == input_name_ranges.end()) continue;
for (Edge const* edge : node->in_edges()) {
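The hard-coded table deleted above now lives with the kernels themselves: each REGISTER_XLA_OP call declares its own compile-time-constant inputs through the new CompileTimeConstInput builder method. For example, taken verbatim from fill_op.cc later in this patch:

    // The "dims" input of Fill must be a compile-time constant.
    REGISTER_XLA_OP(Name("Fill").CompileTimeConstInput("dims"), FillOp);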
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
index 21d3e64872..344a2ab2b6 100644
--- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -159,7 +159,8 @@ class BatchToSpaceNDOp : public XlaOpKernel {
block_shape, crops);
}
};
-REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp);
+REGISTER_XLA_OP(Name("BatchToSpaceND").CompileTimeConstInput("crops"),
+ BatchToSpaceNDOp);
class BatchToSpaceOp : public XlaOpKernel {
public:
@@ -181,7 +182,10 @@ class BatchToSpaceOp : public XlaOpKernel {
private:
int block_size_;
};
-REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp);
+REGISTER_XLA_OP(Name("BatchToSpace")
+ .CompileTimeConstInput("crops")
+ .CompileTimeConstInput("block_shape"),
+ BatchToSpaceOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
index bb031b8c47..ee2c920453 100644
--- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -65,7 +65,10 @@ class BCastArgsOp : public XlaOpKernel {
private:
TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp);
};
-REGISTER_XLA_OP(Name("BroadcastArgs"), BCastArgsOp);
+REGISTER_XLA_OP(Name("BroadcastArgs")
+ .CompileTimeConstInput("s0")
+ .CompileTimeConstInput("s1"),
+ BCastArgsOp);
// Given shapes of two tensors, computes the reduction indices for the
// gradient computation.
@@ -121,7 +124,10 @@ class BCastGradArgsOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp);
};
-REGISTER_XLA_OP(Name("BroadcastGradientArgs"), BCastGradArgsOp);
+REGISTER_XLA_OP(Name("BroadcastGradientArgs")
+ .CompileTimeConstInput("s0")
+ .CompileTimeConstInput("s1"),
+ BCastGradArgsOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index 592f3ecc3c..545aa364f9 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -92,7 +92,8 @@ class CategoricalOp : public XlaOpKernel {
};
// TODO(b/68769717): Rename this sampler to Categorical.
-REGISTER_XLA_OP(Name("Multinomial"), CategoricalOp);
+REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstInput("num_samples"),
+ CategoricalOp);
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index 73a4740e29..1a246e8df9 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -84,8 +84,8 @@ class ConcatBaseOp : public XlaOpKernel {
in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar),
errors::InvalidArgument(
"ConcatOp : Ranks of all input tensors should match: shape[0] = ",
- input_shape.DebugString(), " vs. shape[", i, "] = ",
- in_shape.DebugString()));
+ input_shape.DebugString(), " vs. shape[", i,
+ "] = ", in_shape.DebugString()));
if (in_shape.dims() == 0) {
// Inputs that come in as scalars must be reshaped to 1-vectors.
input_data.push_back(ctx->builder()->Reshape(handle, {1}));
@@ -117,8 +117,11 @@ class ConcatV2Op : public ConcatBaseOp {
: ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {}
};
-REGISTER_XLA_OP(Name("Concat"), ConcatOp);
-REGISTER_XLA_OP(Name("ConcatV2").TypeConstraint("Tidx", DT_INT32), ConcatV2Op);
+REGISTER_XLA_OP(Name("Concat").CompileTimeConstInput("concat_dim"), ConcatOp);
+REGISTER_XLA_OP(Name("ConcatV2")
+ .TypeConstraint("Tidx", DT_INT32)
+ .CompileTimeConstInput("axis"),
+ ConcatV2Op);
class ConcatOffsetOp : public XlaOpKernel {
public:
@@ -189,10 +192,10 @@ class ConcatOffsetOp : public XlaOpKernel {
} else {
const int32 inp0_element = inp0_literal.Get<int>({j});
const int32 inp_element = inp_literal.Get<int>({j});
- OP_REQUIRES(
- ctx, (inp0_element == inp_element),
- errors::InvalidArgument("input[", i, ",", j, "] mismatch: ",
- inp0_element, " vs. ", inp_element));
+ OP_REQUIRES(ctx, (inp0_element == inp_element),
+ errors::InvalidArgument("input[", i, ",", j,
+ "] mismatch: ", inp0_element,
+ " vs. ", inp_element));
out_vec(j) = 0;
}
}
@@ -202,7 +205,10 @@ class ConcatOffsetOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("ConcatOffset"), ConcatOffsetOp);
+REGISTER_XLA_OP(Name("ConcatOffset")
+ .CompileTimeConstInput("concat_dim")
+ .CompileTimeConstInput("shape"),
+ ConcatOffsetOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index aaddbe811c..da9be68732 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -445,21 +445,26 @@ class Conv2DBackpropInputOp : public ConvBackpropInputOp {
explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx)
: ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {}
};
-REGISTER_XLA_OP(Name("Conv2DBackpropInput"), Conv2DBackpropInputOp);
+REGISTER_XLA_OP(
+ Name("Conv2DBackpropInput").CompileTimeConstInput("input_sizes"),
+ Conv2DBackpropInputOp);
class Conv3DBackpropInputOp : public ConvBackpropInputOp {
public:
explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx)
: ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {}
};
-REGISTER_XLA_OP(Name("Conv3DBackpropInputV2"), Conv3DBackpropInputOp);
+REGISTER_XLA_OP(
+ Name("Conv3DBackpropInputV2").CompileTimeConstInput("input_sizes"),
+ Conv3DBackpropInputOp);
class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp {
public:
explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx)
: ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
};
-REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput"),
+REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput")
+ .CompileTimeConstInput("input_sizes"),
DepthwiseConv2DBackpropInputOp);
class ConvBackpropFilterOp : public XlaOpKernel {
@@ -644,7 +649,9 @@ class Conv2DBackpropFilterOp : public ConvBackpropFilterOp {
: ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {
}
};
-REGISTER_XLA_OP(Name("Conv2DBackpropFilter"), Conv2DBackpropFilterOp);
+REGISTER_XLA_OP(
+ Name("Conv2DBackpropFilter").CompileTimeConstInput("filter_sizes"),
+ Conv2DBackpropFilterOp);
class Conv3DBackpropFilterOp : public ConvBackpropFilterOp {
public:
@@ -652,14 +659,17 @@ class Conv3DBackpropFilterOp : public ConvBackpropFilterOp {
: ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {
}
};
-REGISTER_XLA_OP(Name("Conv3DBackpropFilterV2"), Conv3DBackpropFilterOp);
+REGISTER_XLA_OP(
+ Name("Conv3DBackpropFilterV2").CompileTimeConstInput("filter_sizes"),
+ Conv3DBackpropFilterOp);
class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp {
public:
explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx)
: ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
};
-REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter"),
+REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter")
+ .CompileTimeConstInput("filter_sizes"),
DepthwiseConv2DBackpropFilterOp);
} // namespace
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
index 7349dcb987..f2cd21ffb9 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -72,22 +72,24 @@ class DynamicStitchOp : public XlaOpKernel {
XLAShapeToTensorShape(indices_input[input_num].shape(),
&indices_shape));
const TensorShape& data_shape = data_shapes[input_num];
- OP_REQUIRES(ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape),
- errors::InvalidArgument(
- "data[", input_num, "].shape = ",
- data_shape.DebugString(), " does not start with indices[",
- input_num, "].shape = ", indices_shape.DebugString()));
- OP_REQUIRES(ctx,
- input_num == 0 || SameExtraShape(data0_shape, indices0_shape,
- data_shape, indices_shape),
- errors::InvalidArgument(
- "Need data[0].shape[", indices0_shape.dims(),
- ":] = data[", input_num, "].shape[", indices_shape.dims(),
- ":], got data[0].shape = ", data0_shape.DebugString(),
- ", data[", input_num, "].shape = ",
- data_shape.DebugString(), ", indices[0].shape = ",
- indices0_shape.DebugString(), ", indices[", input_num,
- "].shape = ", indices_shape.DebugString()));
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape),
+ errors::InvalidArgument("data[", input_num,
+ "].shape = ", data_shape.DebugString(),
+ " does not start with indices[", input_num,
+ "].shape = ", indices_shape.DebugString()));
+ OP_REQUIRES(
+ ctx,
+ input_num == 0 || SameExtraShape(data0_shape, indices0_shape,
+ data_shape, indices_shape),
+ errors::InvalidArgument(
+ "Need data[0].shape[", indices0_shape.dims(), ":] = data[",
+ input_num, "].shape[", indices_shape.dims(),
+ ":], got data[0].shape = ", data0_shape.DebugString(), ", data[",
+ input_num, "].shape = ", data_shape.DebugString(),
+ ", indices[0].shape = ", indices0_shape.DebugString(),
+ ", indices[", input_num,
+ "].shape = ", indices_shape.DebugString()));
OP_REQUIRES_OK(ctx,
XlaHelpers::ReshapeLiteral(indices_input[input_num],
@@ -159,8 +161,8 @@ class DynamicStitchOp : public XlaOpKernel {
indices0_shape.dims());
std::vector<int64> slice_limit(1 + data0_shape.dims() -
indices0_shape.dims());
- std::vector<int64> stride(1 + data0_shape.dims() -
- indices0_shape.dims(), 1);
+ std::vector<int64> stride(1 + data0_shape.dims() - indices0_shape.dims(),
+ 1);
for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) {
slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d);
}
@@ -198,8 +200,10 @@ class DynamicStitchOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("DynamicStitch"), DynamicStitchOp);
-REGISTER_XLA_OP(Name("ParallelDynamicStitch"), DynamicStitchOp);
+REGISTER_XLA_OP(Name("DynamicStitch").CompileTimeConstInput("indices"),
+ DynamicStitchOp);
+REGISTER_XLA_OP(Name("ParallelDynamicStitch").CompileTimeConstInput("indices"),
+ DynamicStitchOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
index 9e090fe01c..eaa13b8dfa 100644
--- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
@@ -69,7 +69,7 @@ class FillOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Fill"), FillOp);
+REGISTER_XLA_OP(Name("Fill").CompileTimeConstInput("dims"), FillOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index e420f21ca3..70192cb324 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -198,6 +198,7 @@ void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) {
}
REGISTER_XLA_OP(Name("Gather"), GatherOpDynamicSlice);
-REGISTER_XLA_OP(Name("GatherV2"), GatherOpDynamicSlice);
+REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"),
+ GatherOpDynamicSlice);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index d91ebb500b..c0b8f9c179 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -243,7 +243,8 @@ class ResizeBilinearOp : public XlaOpKernel {
bool align_corners_;
};
-REGISTER_XLA_OP(Name("ResizeBilinear"), ResizeBilinearOp);
+REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstInput("size"),
+ ResizeBilinearOp);
class ResizeBilinearGradOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
index e0dc1870f2..7bf4b435f5 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
@@ -80,7 +80,10 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx)
: XlaArgMinMaxOp(ctx, /*is_min=*/false) {}
-REGISTER_XLA_OP(Name("ArgMax").Device(DEVICE_GPU_XLA_JIT), XlaArgMaxOp);
+REGISTER_XLA_OP(Name("ArgMax")
+ .Device(DEVICE_GPU_XLA_JIT)
+ .CompileTimeConstInput("dimension"),
+ XlaArgMaxOp);
namespace {
@@ -90,7 +93,7 @@ class XlaArgMinOp : public XlaArgMinMaxOp {
};
XlaArgMinOp::XlaArgMinOp(OpKernelConstruction* ctx)
: XlaArgMinMaxOp(ctx, /*is_min=*/true) {}
-REGISTER_XLA_OP(Name("ArgMin"), XlaArgMinOp);
+REGISTER_XLA_OP(Name("ArgMin").CompileTimeConstInput("dimension"), XlaArgMinOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 20946e247a..b1f3c3c298 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -56,10 +56,10 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
errors::InvalidArgument("dim must be < input rank (",
input_shape.dims(), "), but got: ", dim));
const int64 dim_size = input_shape.dim_size(dim);
- OP_REQUIRES(
- ctx, dim_size > 0,
- errors::InvalidArgument("Reduction axis ", dim, " is empty in shape: ",
- input_shape.DebugString()));
+ OP_REQUIRES(ctx, dim_size > 0,
+ errors::InvalidArgument(
+ "Reduction axis ", dim,
+ " is empty in shape: ", input_shape.DebugString()));
// The output shape is the input shape contracted along dim.
TensorShape output_shape;
@@ -113,9 +113,11 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp);
};
-REGISTER_XLA_OP(
- Name("ArgMax").TypeConstraint("T", DT_FLOAT).Device(DEVICE_CPU_XLA_JIT),
- ArgMaxCustomCallOp);
+REGISTER_XLA_OP(Name("ArgMax")
+ .TypeConstraint("T", DT_FLOAT)
+ .Device(DEVICE_CPU_XLA_JIT)
+ .CompileTimeConstInput("dimension"),
+ ArgMaxCustomCallOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
index bea1d1600b..05a36a031a 100644
--- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
@@ -92,7 +92,8 @@ class MirrorPadOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(MirrorPadOp);
};
-REGISTER_XLA_OP(Name("MirrorPad"), MirrorPadOp);
+REGISTER_XLA_OP(Name("MirrorPad").CompileTimeConstInput("paddings"),
+ MirrorPadOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc
index 2a9cfcb2eb..9f7c991380 100644
--- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc
@@ -76,7 +76,7 @@ class OneHotOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
};
-REGISTER_XLA_OP(Name("OneHot"), OneHotOp);
+REGISTER_XLA_OP(Name("OneHot").CompileTimeConstInput("depth"), OneHotOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
index d841bd37b3..791351637a 100644
--- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
@@ -83,8 +83,8 @@ class PadOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Pad"), PadOp);
-REGISTER_XLA_OP(Name("PadV2"), PadOp);
+REGISTER_XLA_OP(Name("Pad").CompileTimeConstInput("paddings"), PadOp);
+REGISTER_XLA_OP(Name("PadV2").CompileTimeConstInput("paddings"), PadOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index 2b6053d19d..0b5a38967a 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -455,14 +455,16 @@ class AvgPool2DGradOp : public AvgPoolGradOp {
errors::InvalidArgument("Invalid data format"));
}
};
-REGISTER_XLA_OP(Name("AvgPoolGrad"), AvgPool2DGradOp);
+REGISTER_XLA_OP(Name("AvgPoolGrad").CompileTimeConstInput("orig_input_shape"),
+ AvgPool2DGradOp);
class AvgPool3DGradOp : public AvgPoolGradOp {
public:
explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
: AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
-REGISTER_XLA_OP(Name("AvgPool3DGrad"), AvgPool3DGradOp);
+REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"),
+ AvgPool3DGradOp);
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 2421825ead..c0994c434b 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -52,7 +52,8 @@ class RandomUniformOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp);
};
-REGISTER_XLA_OP(Name("RandomUniform"), RandomUniformOp);
+REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"),
+ RandomUniformOp);
class RandomUniformIntOp : public XlaOpKernel {
public:
@@ -83,7 +84,8 @@ class RandomUniformIntOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp);
};
-REGISTER_XLA_OP(Name("RandomUniformInt"), RandomUniformIntOp);
+REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstInput("shape"),
+ RandomUniformIntOp);
class RandomStandardNormalOp : public XlaOpKernel {
public:
@@ -111,7 +113,8 @@ class RandomStandardNormalOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp);
};
-REGISTER_XLA_OP(Name("RandomStandardNormal"), RandomStandardNormalOp);
+REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstInput("shape"),
+ RandomStandardNormalOp);
class TruncatedNormalOp : public XlaOpKernel {
public:
@@ -183,7 +186,8 @@ class TruncatedNormalOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("TruncatedNormal"), TruncatedNormalOp);
+REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"),
+ TruncatedNormalOp);
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index 647b627408..03b13b2924 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -35,7 +35,7 @@ class SumOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("Sum"), SumOp);
+REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp);
class ProdOp : public XlaReductionOp {
public:
@@ -53,7 +53,8 @@ class ProdOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("Prod"), ProdOp);
+REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"),
+ ProdOp);
class MinOp : public XlaReductionOp {
public:
@@ -73,7 +74,7 @@ class MinOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("Min"), MinOp);
+REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp);
class MaxOp : public XlaReductionOp {
public:
@@ -93,7 +94,7 @@ class MaxOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("Max"), MaxOp);
+REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp);
class MeanOp : public XlaReductionOp {
public:
@@ -115,7 +116,8 @@ class MeanOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("Mean"), MeanOp);
+REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"),
+ MeanOp);
class AllOp : public XlaReductionOp {
public:
@@ -133,7 +135,7 @@ class AllOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("All"), AllOp);
+REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp);
class AnyOp : public XlaReductionOp {
public:
@@ -151,7 +153,7 @@ class AnyOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("Any"), AnyOp);
+REGISTER_XLA_OP(Name("Any").CompileTimeConstInput("reduction_indices"), AnyOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
index 5952e75272..af4d64b159 100644
--- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -95,7 +95,7 @@ class ReshapeOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Reshape"), ReshapeOp);
+REGISTER_XLA_OP(Name("Reshape").CompileTimeConstInput("shape"), ReshapeOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
index bdfd066f01..17a345fc94 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
@@ -65,7 +65,7 @@ class ReverseOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Reverse"), ReverseOp);
+REGISTER_XLA_OP(Name("Reverse").CompileTimeConstInput("dims"), ReverseOp);
class ReverseV2Op : public XlaOpKernel {
public:
@@ -103,7 +103,7 @@ class ReverseV2Op : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("ReverseV2"), ReverseV2Op);
+REGISTER_XLA_OP(Name("ReverseV2").CompileTimeConstInput("axis"), ReverseV2Op);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index 650f8c7dc8..ee4a94164c 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -129,13 +129,19 @@ class CumsumOp : public ScanOp {
public:
explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {}
};
-REGISTER_XLA_OP(Name("Cumsum").TypeConstraint("T", kScanOpTypes), CumsumOp);
+REGISTER_XLA_OP(Name("Cumsum")
+ .TypeConstraint("T", kScanOpTypes)
+ .CompileTimeConstInput("axis"),
+ CumsumOp);
class CumprodOp : public ScanOp {
public:
explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {}
};
-REGISTER_XLA_OP(Name("Cumprod").TypeConstraint("T", kScanOpTypes), CumprodOp);
+REGISTER_XLA_OP(Name("Cumprod")
+ .TypeConstraint("T", kScanOpTypes)
+ .CompileTimeConstInput("axis"),
+ CumprodOp);
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
index c2b0e1bb4c..2c31f8d908 100644
--- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -138,7 +138,11 @@ class RangeOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Range"), RangeOp);
+REGISTER_XLA_OP(Name("Range")
+ .CompileTimeConstInput("start")
+ .CompileTimeConstInput("limit")
+ .CompileTimeConstInput("delta"),
+ RangeOp);
class LinSpaceOp : public XlaOpKernel {
public:
@@ -207,7 +211,11 @@ class LinSpaceOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("LinSpace"), LinSpaceOp);
+REGISTER_XLA_OP(Name("LinSpace")
+ .CompileTimeConstInput("start")
+ .CompileTimeConstInput("stop")
+ .CompileTimeConstInput("num"),
+ LinSpaceOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index e205fadd2b..8fb7a74310 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -150,7 +150,7 @@ class ExpandDimsOp : public XlaOpKernel {
ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape));
}
};
-REGISTER_XLA_OP(Name("ExpandDims"), ExpandDimsOp);
+REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp);
class SqueezeOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
index fbe8c78d8f..be1e97bf26 100644
--- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -112,7 +112,9 @@ class SliceOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Slice"), SliceOp);
+REGISTER_XLA_OP(
+ Name("Slice").CompileTimeConstInput("begin").CompileTimeConstInput("size"),
+ SliceOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
index 83a87f19a7..01b46e160d 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -162,7 +162,10 @@ class SpaceToBatchNDOp : public XlaOpKernel {
block_shape, paddings);
}
};
-REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp);
+REGISTER_XLA_OP(Name("SpaceToBatchND")
+ .CompileTimeConstInput("paddings")
+ .CompileTimeConstInput("block_shape"),
+ SpaceToBatchNDOp);
class SpaceToBatchOp : public XlaOpKernel {
public:
@@ -184,7 +187,8 @@ class SpaceToBatchOp : public XlaOpKernel {
private:
int block_size_;
};
-REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp);
+REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstInput("paddings"),
+ SpaceToBatchOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
index 795eb1794f..79c435c90a 100644
--- a/tensorflow/compiler/tf2xla/kernels/split_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -103,7 +103,7 @@ class SplitOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Split"), SplitOp);
+REGISTER_XLA_OP(Name("Split").CompileTimeConstInput("split_dim"), SplitOp);
class SplitVOp : public XlaOpKernel {
public:
@@ -142,8 +142,9 @@ class SplitVOp : public XlaOpKernel {
int neg_one_dim = -1;
std::vector<int64> split_sizes_vec(num_split, -1);
const TensorShape split_size_shape = ctx->InputShape(1);
- OP_REQUIRES(ctx, split_size_shape.dims() == 1 &&
- split_size_shape.num_elements() == num_split,
+ OP_REQUIRES(ctx,
+ split_size_shape.dims() == 1 &&
+ split_size_shape.num_elements() == num_split,
errors::InvalidArgument(
"shape of tensor describing "
" the output must have dimension 1 and the same "
@@ -171,10 +172,11 @@ class SplitVOp : public XlaOpKernel {
}
OP_REQUIRES(
- ctx, (neg_one_dim == -1 &&
- total_split_size == input_shape.dim_size(split_dim)) ||
- (neg_one_dim >= 0 &&
- total_split_size <= input_shape.dim_size(split_dim)),
+ ctx,
+ (neg_one_dim == -1 &&
+ total_split_size == input_shape.dim_size(split_dim)) ||
+ (neg_one_dim >= 0 &&
+ total_split_size <= input_shape.dim_size(split_dim)),
errors::InvalidArgument("Determined shape must either match "
"input shape along split_dim exactly if "
"fully specified, or be less than the size of "
@@ -206,7 +208,10 @@ class SplitVOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("SplitV"), SplitVOp);
+REGISTER_XLA_OP(Name("SplitV")
+ .CompileTimeConstInput("split_dim")
+ .CompileTimeConstInput("size_splits"),
+ SplitVOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
index 8013ece861..c912876e65 100644
--- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
@@ -129,7 +129,7 @@ class StackOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
};
-REGISTER_XLA_OP(Name("StackV2"), StackOp);
+REGISTER_XLA_OP(Name("StackV2").CompileTimeConstInput("max_size"), StackOp);
class StackPushOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 6af4bd0496..f0525a5fb8 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -106,7 +106,11 @@ class StridedSliceOp : public XlaOpKernel {
DataType index_type_;
};
-REGISTER_XLA_OP(Name("StridedSlice"), StridedSliceOp);
+REGISTER_XLA_OP(Name("StridedSlice")
+ .CompileTimeConstInput("begin")
+ .CompileTimeConstInput("end")
+ .CompileTimeConstInput("strides"),
+ StridedSliceOp);
class StridedSliceGradOp : public XlaOpKernel {
public:
@@ -211,7 +215,12 @@ class StridedSliceGradOp : public XlaOpKernel {
DataType index_type_;
};
-REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp);
+REGISTER_XLA_OP(Name("StridedSliceGrad")
+ .CompileTimeConstInput("shape")
+ .CompileTimeConstInput("begin")
+ .CompileTimeConstInput("end")
+ .CompileTimeConstInput("strides"),
+ StridedSliceGradOp);
class StridedSliceAssignOp : public XlaOpKernel {
public:
@@ -320,7 +329,11 @@ class StridedSliceAssignOp : public XlaOpKernel {
DataType index_type_;
};
-REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp);
+REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
+ .CompileTimeConstInput("begin")
+ .CompileTimeConstInput("end")
+ .CompileTimeConstInput("strides"),
+ StridedSliceAssignOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 8a742ff11c..9224072a3c 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -192,7 +192,8 @@ class TensorArrayOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp);
};
-REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp);
+REGISTER_XLA_OP(Name("TensorArrayV3").CompileTimeConstInput("size"),
+ TensorArrayOp);
class TensorArrayWriteOp : public XlaOpKernel {
public:
@@ -414,8 +415,8 @@ class TensorArrayScatterOp : public XlaOpKernel {
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto index = b->Slice(indices, {i}, {i + 1}, {1});
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}
}
@@ -537,7 +538,8 @@ class TensorArraySplitOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp);
};
-REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp);
+REGISTER_XLA_OP(Name("TensorArraySplitV3").CompileTimeConstInput("lengths"),
+ TensorArraySplitOp);
class TensorArraySizeOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
index 9ee6bd8925..9aefcd4fc7 100644
--- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -122,7 +122,7 @@ class TileOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(TileOp);
};
-REGISTER_XLA_OP(Name("Tile"), TileOp);
+REGISTER_XLA_OP(Name("Tile").CompileTimeConstInput("multiples"), TileOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
index 2fc5d40d10..5c17b7fbf0 100644
--- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
@@ -72,8 +72,9 @@ class TransposeOp : public XlaOpKernel {
}
}
for (int i = 0; i < dims; ++i) {
- OP_REQUIRES(ctx, bits[i], errors::InvalidArgument(
- i, " is missing from 'perm' argument."));
+ OP_REQUIRES(
+ ctx, bits[i],
+ errors::InvalidArgument(i, " is missing from 'perm' argument."));
}
// 0-D, 1-D, and identity transposes do nothing.
@@ -87,7 +88,7 @@ class TransposeOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Transpose"), TransposeOp);
+REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp);
// InvertPermutation frequently forms part of the gradient of Transpose.
//
@@ -103,8 +104,9 @@ class InvertPermutationOp : public XlaOpKernel {
explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- OP_REQUIRES(ctx, FastBoundsCheck(ctx->InputShape(0).num_elements(),
- std::numeric_limits<int32>::max()),
+ OP_REQUIRES(ctx,
+ FastBoundsCheck(ctx->InputShape(0).num_elements(),
+ std::numeric_limits<int32>::max()),
errors::InvalidArgument("permutation of nonnegative int32s "
"must have <= int32 max elements"));
@@ -128,7 +130,9 @@ class InvertPermutationOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32),
+REGISTER_XLA_OP(Name("InvertPermutation")
+ .TypeConstraint("T", DT_INT32)
+ .CompileTimeConstInput("x"),
InvertPermutationOp);
} // namespace
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index faf47434b5..97bb100fb1 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -83,6 +83,11 @@ XlaOpRegistry::~XlaOpRegistry() = default;
return false;
}
}
+ if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) {
+ LOG(WARNING) << "Registrations of " << x.name
+ << " have incompatible compile time constant inputs.";
+ return false;
+ }
return true;
}
@@ -263,6 +268,17 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
return kernels;
}
+/* static */ const std::unordered_set<string>*
+XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
+ XlaOpRegistry& registry = Instance();
+ mutex_lock lock(registry.mutex_);
+ auto it = registry.ops_.find(op);
+ if (it == registry.ops_.end()) {
+ return nullptr;
+ }
+ return &it->second->compile_time_constant_inputs;
+}
+
std::vector<string> XlaOpRegistry::BackendNames() {
std::vector<string> names;
XlaOpRegistry& registry = Instance();
@@ -337,6 +353,12 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
return *this;
}
+XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput(
+ StringPiece input_name) {
+ registration_->compile_time_constant_inputs.insert(input_name.ToString());
+ return *this;
+}
+
std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
XlaOpRegistry::Factory factory) {
registration_->factory = factory;
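Note: the new `CompileTimeConstantInputs()` accessor gives analyses (for example the reworked `const_analysis.cc` in this change) a registry-driven way to discover which arguments must be constant, instead of hard-coding per-op lists. A sketch of the lookup; `node` and the marking logic are illustrative only:

```cpp
// Sketch: ask the registry which inputs of this node must be constants.
// Returns nullptr for ops that have no XLA registration.
const std::unordered_set<string>* const_inputs =
    XlaOpRegistry::CompileTimeConstantInputs(node.op());
if (const_inputs != nullptr) {
  for (const string& input_name : *const_inputs) {
    // Mark the corresponding argument as must-be-compile-time-constant.
  }
}
```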
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 8bfd9758f7..ff7453194a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -128,6 +128,11 @@ class XlaOpRegistry {
const string& compilation_device_name,
bool include_compilation_only_kernels);
+ // Returns the set of compile-time constant inputs to 'op'. Returns nullptr
+ // if the op is not registered.
+ static const std::unordered_set<string>* CompileTimeConstantInputs(
+ const string& op);
+
private:
friend class XlaBackendRegistrar;
friend class XlaOpRegistrar;
@@ -181,6 +186,9 @@ class XlaOpRegistry {
bool has_device_whitelist = false;
std::unordered_set<string> device_whitelist;
+ // Names of arguments that must be compile-time constants.
+ std::unordered_set<string> compile_time_constant_inputs;
+
// Factory used to build OpKernels that perform symbolic execution.
Factory factory;
};
@@ -242,6 +250,9 @@ class XlaOpRegistrationBuilder {
// Allow DT_RESOURCE types for type parameters.
XlaOpRegistrationBuilder& AllowResourceTypes();
+ // Mark 'input_name' as an argument whose value must be known at compile-time.
+ XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name);
+
std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
XlaOpRegistry::Factory factory);
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 2a335843f5..80d89d851e 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -276,6 +276,23 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
return false;
}
+bool HloDataflowAnalysis::UpdateSliceValueSet(HloInstruction* slice) {
+ CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
+ if (!slice->IsInPlaceSlice()) {
+ return false;
+ }
+ // If this slice is lowered to an in-place version, then it forwards the
+ // operand value to the output.
+ const InstructionValueSet& operand_set =
+ GetInstructionValueSet(slice->operand(0));
+ InstructionValueSet& slice_set = GetInstructionValueSet(slice);
+ if (operand_set != slice_set) {
+ slice_set = operand_set;
+ return true;
+ }
+ return false;
+}
+
bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
CHECK_EQ(send->opcode(), HloOpcode::kSend);
bool changed = false;
@@ -527,6 +544,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
switch (instruction->opcode()) {
case HloOpcode::kBitcast:
return UpdateBitcastValueSet(instruction);
+ case HloOpcode::kSlice:
+ return UpdateSliceValueSet(instruction);
case HloOpcode::kCopy:
return UpdateCopyValueSet(instruction);
case HloOpcode::kGetTupleElement:
@@ -688,6 +707,11 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
define_all_values();
}
break;
+ case HloOpcode::kSlice:
+ if (!instruction->IsInPlaceSlice()) {
+ define_all_values();
+ }
+ break;
case HloOpcode::kWhile:
case HloOpcode::kCall:
case HloOpcode::kConditional:
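Note: with this change an in-place slice behaves like `kBitcast` for dataflow purposes: it forwards its operand's values rather than defining new ones. Assuming the analysis has been run, the slice and its operand should share a single `HloValue`; a sketch using the existing `GetUniqueValueAt` accessor:

```cpp
// Sketch: after HloDataflowAnalysis::Run(*module), an in-place slice does not
// define a value of its own; it carries its operand's value.
const HloValue& slice_value = dataflow->GetUniqueValueAt(slice);
const HloValue& operand_value = dataflow->GetUniqueValueAt(slice->operand(0));
CHECK_EQ(&slice_value, &operand_value);  // same HloValue, no copy implied
```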
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index 469620d012..89d318188f 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -145,6 +145,7 @@ class HloDataflowAnalysis {
// Updates the value set for a particular instruction type. Returns whether
// the instruction value set changed.
bool UpdateBitcastValueSet(HloInstruction* bitcast);
+ bool UpdateSliceValueSet(HloInstruction* slice);
bool UpdateCallValueSet(HloInstruction* call);
bool UpdateConditionalValueSet(HloInstruction* conditional);
bool UpdateCopyValueSet(HloInstruction* copy);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 9b778c574c..d455cfc3f1 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -964,6 +964,17 @@ class HloInstruction {
}
const std::vector<int64>& slice_strides() const { return slice_strides_; }
+  // Returns the flag that describes whether a slice must be lowered to an
+  // offset into the original operand.
+ bool IsInPlaceSlice() const { return is_in_place_slice_; }
+
+  // Sets and returns the flag that describes whether a slice must be lowered
+  // to an offset into the original operand.
+ bool SetIsInPlaceSlice(bool value) {
+ is_in_place_slice_ = value;
+ return value;
+ }
+
// Returns the size of the slice in the given dimension for a dynamic
// slice node.
//
@@ -1297,6 +1308,9 @@ class HloInstruction {
std::vector<int64> slice_limits_;
std::vector<int64> slice_strides_;
+ // Describes whether the slice can be lowered to an offset into the operand.
+ bool is_in_place_slice_ = false;
+
// The bit sizes for a reduce-precision operation.
int32 exponent_bits_ = 0;
int32 mantissa_bits_ = 0;
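Note: nothing in this change sets the flag yet; the intent is for a backend pass to mark slices it can lower as views of their operand. A hedged sketch, where `CanLowerSliceInPlace` is a hypothetical backend predicate and the iteration is schematic:

```cpp
// Hypothetical backend pass: mark slices as in-place views so the dataflow
// and points-to analyses above treat them as aliases of their operand.
for (HloInstruction* instr : computation->instructions()) {
  if (instr->opcode() == HloOpcode::kSlice && CanLowerSliceInPlace(instr)) {
    instr->SetIsInPlaceSlice(true);
  }
}
```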
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index b01fcccdb4..0cb9b5d810 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -39,7 +39,6 @@ namespace xla {
namespace interpreter {
namespace se = ::perftools::gputools;
-namespace sep = ::perftools::gputools::interpreter;
InterpreterExecutable::InterpreterExecutable(
std::unique_ptr<const HloModule> hlo_module)
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 0c84856647..657a8fe09a 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -273,6 +273,16 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
return Status::OK();
}
+Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) {
+ // A kSlice instruction aliases its operand if the backend lowers it to an
+ // in-place implementation.
+ if (slice->IsInPlaceSlice()) {
+ CreateCopiedPointsToSet(slice, slice->operand(0));
+ return Status::OK();
+ }
+ return DefaultAction(slice);
+}
+
Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
// RecvDone aliases its input (Recv) tuple element {0} to its output.
PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
@@ -427,10 +437,15 @@ bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) {
- return FailedPrecondition(
- "LogicalBuffer %s is ill-defined: instruction %s does not define a "
- "buffer at that index",
- buffer.ToString().c_str(), buffer.instruction()->name().c_str());
+    // kSlice ops that are lowered to an in-place version are expected not to
+    // define their own output buffer.
+ if (buffer.instruction()->opcode() != HloOpcode::kSlice ||
+ !buffer.instruction()->IsInPlaceSlice()) {
+ return FailedPrecondition(
+ "LogicalBuffer %s is ill-defined: instruction %s does not define a "
+ "buffer at that index",
+ buffer.ToString().c_str(), buffer.instruction()->name().c_str());
+ }
}
if (buffer.id() < 0 ||
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index 8928de107e..5ca6ccb5c9 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -250,6 +250,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
Status HandleTuple(HloInstruction* tuple) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleBitcast(HloInstruction* bitcast) override;
+ Status HandleSlice(HloInstruction* slice) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
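Note: points-to analysis follows the same rule as dataflow here: an in-place slice's output points at the operand's buffer rather than at a freshly defined one. A sketch of what a consumer would observe, assuming `points_to` is an already-run `TuplePointsToAnalysis`:

```cpp
// Sketch: the slice's points-to set is a copy of its operand's, so the
// buffers listed at the top-level index were defined by the operand.
const PointsToSet& pts = points_to->GetPointsToSet(slice);
for (const LogicalBuffer* buffer : pts.element(/*index=*/{})) {
  // buffer->instruction() is not `slice`: the slice aliases, it does not
  // define, which is why VerifyBuffer above special-cases it.
}
```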
diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md
index 2e329cc513..4b43810a68 100644
--- a/tensorflow/compiler/xla/tools/parser/README.md
+++ b/tensorflow/compiler/xla/tools/parser/README.md
@@ -1,4 +1,4 @@
-# HloModule string syntax
+# HLO Text Syntax
```yacc
hlo_module
diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py
index bc749339bd..046c509626 100644
--- a/tensorflow/contrib/crf/__init__.py
+++ b/tensorflow/contrib/crf/__init__.py
@@ -16,15 +16,15 @@
See the @{$python/contrib.crf} guide.
-@@crf_sequence_score
-@@crf_log_norm
-@@crf_log_likelihood
-@@crf_unary_score
@@crf_binary_score
@@crf_decode
-@@CrfForwardRnnCell
-@@CrfDecodeForwardRnnCell
+@@crf_log_likelihood
+@@crf_log_norm
+@@crf_sequence_score
+@@crf_unary_score
@@CrfDecodeBackwardRnnCell
+@@CrfDecodeForwardRnnCell
+@@CrfForwardRnnCell
@@viterbi_decode
"""
@@ -32,16 +32,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.crf.python.ops.crf import _lengths_to_masks
from tensorflow.contrib.crf.python.ops.crf import crf_binary_score
from tensorflow.contrib.crf.python.ops.crf import crf_decode
from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood
from tensorflow.contrib.crf.python.ops.crf import crf_log_norm
from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score
from tensorflow.contrib.crf.python.ops.crf import crf_unary_score
-from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell
-from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell
from tensorflow.contrib.crf.python.ops.crf import CrfDecodeBackwardRnnCell
+from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell
+from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell
from tensorflow.contrib.crf.python.ops.crf import viterbi_decode
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
index b47fb426a1..721dc4d080 100644
--- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
+++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
@@ -179,17 +179,6 @@ class CrfTest(test.TestCase):
tf_total_log_likelihood = sess.run(total_log_likelihood)
self.assertAllClose(tf_total_log_likelihood, 0.0)
- def testLengthsToMasks(self):
- with self.test_session() as sess:
- sequence_lengths = [4, 1, 8, 2]
- max_sequence_length = max(sequence_lengths)
- mask = crf._lengths_to_masks(sequence_lengths, max_sequence_length)
- tf_mask = sess.run(mask)
- self.assertEqual(len(tf_mask), len(sequence_lengths))
- for m, l in zip(tf_mask, sequence_lengths):
- self.assertAllEqual(m[:l], [1] * l)
- self.assertAllEqual(m[l:], [0] * (len(m) - l))
-
def testViterbiDecode(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index 7f5ae937b2..62708636c6 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -70,25 +70,6 @@ __all__ = [
]
-def _lengths_to_masks(lengths, max_length):
- """Creates a binary matrix that can be used to mask away padding.
-
- Args:
- lengths: A vector of integers representing lengths.
- max_length: An integer indicating the maximum length. All values in
- lengths should be less than max_length.
- Returns:
- masks: Masks that can be used to get rid of padding.
- """
- tiled_ranges = array_ops.tile(
- array_ops.expand_dims(math_ops.range(max_length), 0),
- [array_ops.shape(lengths)[0], 1])
- lengths = array_ops.expand_dims(lengths, 1)
- masks = math_ops.to_float(
- math_ops.to_int64(tiled_ranges) < math_ops.to_int64(lengths))
- return masks
-
-
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
transition_params):
"""Computes the unnormalized score for a tag sequence.
@@ -234,7 +215,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs):
array_ops.gather(flattened_inputs, flattened_tag_indices),
[batch_size, max_seq_len])
- masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1])
+ masks = array_ops.sequence_mask(sequence_lengths,
+ maxlen=array_ops.shape(tag_indices)[1],
+ dtype=dtypes.float32)
unary_scores = math_ops.reduce_sum(unary_scores * masks, 1)
return unary_scores
@@ -268,7 +251,9 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params):
binary_scores = array_ops.gather(flattened_transition_params,
flattened_transition_indices)
- masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1])
+ masks = array_ops.sequence_mask(sequence_lengths,
+ maxlen=array_ops.shape(tag_indices)[1],
+ dtype=dtypes.float32)
truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1])
binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1)
return binary_scores
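Note: both call sites now use the stock `tf.sequence_mask` instead of the removed `_lengths_to_masks` helper; the two compute the same row-wise masks. A small sketch of the replacement API (standard TF 1.x session style):

```python
import tensorflow as tf

# Minimal sketch: tf.sequence_mask reproduces what the removed
# _lengths_to_masks helper computed by hand.
lengths = tf.constant([4, 1, 3])
masks = tf.sequence_mask(lengths, maxlen=4, dtype=tf.float32)
with tf.Session() as sess:
    print(sess.run(masks))
    # [[1. 1. 1. 1.]
    #  [1. 0. 0. 0.]
    #  [1. 1. 1. 0.]]
```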
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 54f6974dba..015f69c567 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -305,10 +305,10 @@ class BatchDatasetTest(test.TestCase):
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.fill([x], x)).apply(
- batching.dense_to_sparse_batch(4,
- [12])).make_initializable_iterator())
+ batching.dense_to_sparse_batch(4, [12]))
+ .make_initializable_iterator())
init_op = iterator.initializer
- get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
@@ -334,9 +334,9 @@ class BatchDatasetTest(test.TestCase):
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.fill([x, x], x)).apply(
batching.dense_to_sparse_batch(
- 4, [5, -1])).make_initializable_iterator())
+ 4, [5, None])).make_initializable_iterator())
init_op = iterator.initializer
- get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
@@ -363,25 +363,18 @@ class BatchDatasetTest(test.TestCase):
def testDenseToSparseBatchDatasetWithInvalidShape(self):
input_tensor = array_ops.constant([[1]])
- iterator = (
- dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4, [-2]))
- .make_initializable_iterator())
- init_op = iterator.initializer
-
- with self.test_session() as sess:
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "Dimension -2 must be >= -1"):
- sess.run(init_op)
+ with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
+ dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
def testDenseToSparseBatchDatasetShapeErrors(self):
input_tensor = array_ops.placeholder(dtypes.int32)
iterator = (
dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4,
- [12])).make_initializable_iterator())
+ batching.dense_to_sparse_batch(4, [12]))
+ .make_initializable_iterator())
init_op = iterator.initializer
- get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ get_next = iterator.get_next()
with self.test_session() as sess:
# Initialize with an input tensor of incompatible rank.
@@ -740,7 +733,9 @@ class BatchDatasetSerializationTest(
lambda x: array_ops.fill([x], x)).apply(
batching.dense_to_sparse_batch(4, [12]))
- def testDenseToSparseBatchDatasetCore(self):
+ # TODO(b/70988345): Re-enable when sparse tensors are properly supported by
+ # the DatasetSerializationTestBase.
+ def _testDenseToSparseBatchDatasetCore(self):
components = np.random.randint(5, size=(40,)).astype(np.int32)
diff_comp = np.random.randint(2, size=(100,)).astype(np.int32)
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
index bf25cc60a1..7cde6e05b2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
@@ -40,6 +40,8 @@ class DatasetSerializationTestBase(test.TestCase):
def tearDown(self):
self._delete_ckpt()
+ # TODO(b/70988345): Support native `tf.SparseTensor` objects and get rid of
+ # `sparse_tensors` argument.
def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
"""Runs the core tests.
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index e8b2d44a8b..e0860cbe8a 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -22,6 +22,7 @@ from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
@@ -231,32 +232,29 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset):
input_dataset.output_types)
self._input_dataset = input_dataset
self._batch_size = batch_size
- # pylint: disable=protected-access
- self._row_shape = dataset_ops._partial_shape_to_tensor(row_shape)
- # pylint: enable=protected-access
+ self._row_shape = row_shape
def _as_variant_tensor(self):
return gen_dataset_ops.dense_to_sparse_batch_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._batch_size,
- self._row_shape,
- output_shapes=self.output_shapes,
- output_types=self.output_types)
+ row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape), # pylint: disable=protected-access
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)))
@property
def output_classes(self):
- return (ops.Tensor, ops.Tensor, ops.Tensor)
+ return sparse_tensor.SparseTensor
@property
def output_shapes(self):
- num_elements = tensor_shape.Dimension(None)
- return (tensor_shape.matrix(num_elements, self._row_shape.shape[0] + 1),
- tensor_shape.vector(num_elements),
- tensor_shape.vector(self._row_shape.shape[0] + 1))
+ return tensor_shape.vector(None).concatenate(self._row_shape)
@property
def output_types(self):
- return (dtypes.int64, self._input_dataset.output_types, dtypes.int64)
+ return self._input_dataset.output_types
class _RestructuredDataset(dataset_ops.Dataset):
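Note: the user-visible effect of the `DenseToSparseBatchDataset` rework, as the updated tests above show, is that `output_classes` is now `tf.SparseTensor`, so `Iterator.get_next()` yields a `SparseTensor` directly and the old manual `sparse_tensor.SparseTensor(*...)` reconstruction disappears. A short sketch:

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import batching

# Sketch: the iterator now yields a tf.SparseTensor directly, with dense shape
# [batch_size] + row_shape, instead of an (indices, values, shape) tuple.
dataset = (tf.data.Dataset.from_tensor_slices(tf.constant([1, 2, 3, 4]))
           .map(lambda x: tf.fill([x], x))
           .apply(batching.dense_to_sparse_batch(4, [12])))
iterator = dataset.make_initializable_iterator()
get_next = iterator.get_next()  # a tf.SparseTensor; no manual wrapping needed

with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(sess.run(get_next).dense_shape)  # [4, 12]
```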
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
index 00a18569fc..4bea99fbb7 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
@@ -18,15 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
from tensorflow.contrib.distributions.python.ops import mvn_tril
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
__all__ = [
@@ -170,15 +167,11 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
covariance_matrix = ops.convert_to_tensor(
covariance_matrix, name="covariance_matrix")
if validate_args:
- tol = np.finfo(covariance_matrix.dtype.as_numpy_dtype).eps * 10
- diff = math_ops.abs(
- covariance_matrix
- - array_ops.matrix_transpose(covariance_matrix))
- assert_symmetric = check_ops.assert_less(
- diff, tol + tol * math_ops.abs(covariance_matrix),
- message="Matrix was not symmetric.")
- covariance_matrix = control_flow_ops.with_dependencies(
- [assert_symmetric], covariance_matrix)
+ covariance_matrix = control_flow_ops.with_dependencies([
+ check_ops.assert_near(
+ covariance_matrix,
+ array_ops.matrix_transpose(covariance_matrix),
+ message="Matrix was not symmetric")], covariance_matrix)
# No need to validate that covariance_matrix is non-singular.
# LinearOperatorLowerTriangular has an assert_non_singular method that
# is called by the Bijector.
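Note: `check_ops.assert_near` defaults both tolerances to `10 * eps` of the dtype, which is exactly what the removed hand-rolled check computed, so the symmetry validation collapses into one call. A minimal sketch of the same pattern in isolation:

```python
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops

# Sketch: validate symmetry with assert_near's default eps-scaled tolerance.
x = ops.convert_to_tensor([[1.0, 2.0], [2.0, 1.0]], name="x")
x = control_flow_ops.with_dependencies([
    check_ops.assert_near(x, array_ops.matrix_transpose(x),
                          message="Matrix was not symmetric")], x)
```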
diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py
index 136085eba2..205709fe2e 100644
--- a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py
+++ b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py
@@ -39,40 +39,22 @@ def random_dataset():
return tf.data.Dataset.from_tensors((images, labels))
-def train_one_epoch(defun=False):
- model = mnist.MNISTModel(data_format())
- if defun:
- model.call = tfe.defun(model.call)
- optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
- dataset = random_dataset()
- with tf.device(device()):
- tf.train.get_or_create_global_step()
- mnist.train_one_epoch(model, optimizer, dataset)
-
-
-def evaluate(defun=False):
- model = mnist.MNISTModel(data_format())
- dataset = random_dataset()
- if defun:
- model.call = tfe.defun(model.call)
- with tf.device(device()):
- tf.train.get_or_create_global_step()
- mnist.test(model, dataset)
-
-
class MNISTTest(tf.test.TestCase):
def testTrainOneEpoch(self):
- train_one_epoch(defun=False)
+ model = mnist.MNISTModel(data_format())
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
+ dataset = random_dataset()
+ with tf.device(device()):
+ tf.train.get_or_create_global_step()
+ mnist.train_one_epoch(model, optimizer, dataset)
def testTest(self):
- evaluate(defun=False)
-
- def testTrainOneEpochWithDefunCall(self):
- train_one_epoch(defun=True)
-
- def testTestWithDefunCall(self):
- evaluate(defun=True)
+ model = mnist.MNISTModel(data_format())
+ dataset = random_dataset()
+ with tf.device(device()):
+ tf.train.get_or_create_global_step()
+ mnist.test(model, dataset)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index e2ae665a74..d8d8644dde 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -64,22 +64,14 @@ def train_one_step(model, images, labels, optimizer):
class ResNet50Test(tf.test.TestCase):
- def _apply(self, defun=False):
+ def test_apply(self):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format)
- if defun:
- model.call = tfe.defun(model.call)
with tf.device(device):
images, _ = random_batch(2)
output = model(images)
self.assertEqual((2, 1000), output.shape)
- def test_apply(self):
- self._apply(defun=False)
-
- def test_apply_with_defun(self):
- self._apply(defun=True)
-
def test_apply_no_top(self):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False)
@@ -183,11 +175,9 @@ class ResNet50Benchmarks(tf.test.Benchmark):
# a sync. This is a roundabout way, yes.
tf.constant(1.).cpu()
- def _benchmark_eager_apply(self, label, defun=False):
+ def benchmark_eager_apply(self):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format)
- if defun:
- model.call = tfe.defun(model.call)
batch_size = 64
num_burn = 5
num_iters = 30
@@ -199,23 +189,16 @@ class ResNet50Benchmarks(tf.test.Benchmark):
start = time.time()
for _ in xrange(num_iters):
model(images).cpu()
- self._report(label, start, num_iters, device, batch_size, data_format)
-
- def benchmark_eager_apply(self):
- self._benchmark_eager_apply('eager_apply', defun=False)
-
- def benchmark_eager_apply_with_defun(self):
- self._benchmark_eager_apply('eager_apply_with_defun', defun=True)
+ self._report('eager_apply', start, num_iters, device, batch_size,
+ data_format)
- def _benchmark_eager_train(self, label, make_iterator, defun=False):
+ def _benchmark_eager_train(self, label, make_iterator):
device, data_format = device_and_data_format()
for batch_size in self._train_batch_sizes():
(images, labels) = random_batch(batch_size)
num_burn = 3
num_iters = 10
model = resnet50.ResNet50(data_format)
- if defun:
- model.call = tfe.defun(model.call)
optimizer = tf.train.GradientDescentOptimizer(0.1)
with tf.device(device):
@@ -234,11 +217,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
self._report(label, start, num_iters, device, batch_size, data_format)
def benchmark_eager_train(self):
- self._benchmark_eager_train('eager_train', MockIterator, defun=False)
-
- def benchmark_eager_train_with_defun(self):
- self._benchmark_eager_train(
- 'eager_train_with_defun', MockIterator, defun=True)
+ self._benchmark_eager_train('eager_train', MockIterator)
def benchmark_eager_train_datasets(self):
@@ -247,18 +226,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
ds = tf.data.Dataset.from_tensors(tensors).repeat()
return tfe.Iterator(ds)
- self._benchmark_eager_train(
- 'eager_train_dataset', make_iterator, defun=False)
-
- def benchmark_eager_train_datasets_with_defun(self):
-
- def make_iterator(tensors):
- with tf.device('/device:CPU:0'):
- ds = tf.data.Dataset.from_tensors(tensors).repeat()
- return tfe.Iterator(ds)
-
- self._benchmark_eager_train(
- 'eager_train_dataset_with_defun', make_iterator, defun=True)
+ self._benchmark_eager_train('eager_train_dataset', make_iterator)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index 2f8016ede3..bf029ca5f9 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -51,6 +51,20 @@ class Metric(object):
```python
m = SomeMetric(...)
+ inputs = ... # Some tensors to compute the metric on.
+ m_update = m(inputs)
+ # Variables defined in first call, so get the initialization op afterwards.
+ m_init = m.init_variables() # or tf.global_variables_initializer()
+ m_result = m.result()
+ with tf.Session() as sess:
+ sess.run(m_init)
+ for input in ...:
+ sess.run(m_update)
+ print(sess.run(m_result))
+ ```
+  Example use with graph execution, using placeholders and feed_dict:
+ ```python
+ m = SomeMetric(...)
m_placeholder = tf.placeholder(...)
m_update = m(m_placeholder)
# Variables defined in first call, so get the initialization op afterwards.
@@ -107,6 +121,7 @@ class Metric(object):
"""Returns op to execute to update this metric for these inputs.
Returns None if eager execution is enabled.
+ Returns a graph-mode function if graph execution is enabled.
Args:
*args:
@@ -183,6 +198,13 @@ class Metric(object):
"""Computes and returns a final value for the metric."""
raise NotImplementedError("Metrics must define a result() member function")
+ def value(self):
+    """Returns the result Tensor (graph mode) or result callable (eager)."""
+ if context.in_graph_mode():
+ return self.result()
+ else:
+ return self.result
+
# We can support two different strategies for doing data-parallel
# distributed metric computations:
# * Put metric variables on the first device and rely on small
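Note: `value()` lets the same caller code serve both execution modes: in graph mode it returns the result Tensor, in eager mode the bound `result` method. A sketch; names are illustrative:

```python
from tensorflow.contrib.eager.python import metrics

m = metrics.Mean()
update = m([1.0, 2.0])  # update op in graph mode; applied eagerly otherwise
value = m.value()
# Graph mode: `value` is the result Tensor -> sess.run(update); sess.run(value)
# Eager mode: `value` is the bound result() method -> value() == 1.5
```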
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index 1055f4563c..9cf34fd9b2 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.training import training_util
@@ -137,7 +138,7 @@ class MetricsTest(test.TestCase):
self.assertEqual(m1.name, "has space")
self.assertEqual(m1.numer.name, "has_space/numer:0")
- def testGraph(self):
+ def testGraphWithPlaceholder(self):
with context.graph_mode(), self.test_session() as sess:
m = metrics.Mean()
p = array_ops.placeholder(dtypes.float32)
@@ -153,6 +154,22 @@ class MetricsTest(test.TestCase):
sess.run(accumulate, feed_dict={p: 7})
self.assertAllEqual(m.result().eval(), 7)
+ @test_util.run_in_graph_and_eager_modes()
+ def testGraphAndEagerTensor(self):
+ m = metrics.Mean()
+ inputs = ops.convert_to_tensor([1.0, 2.0])
+ accumulate = m(inputs)
+ result = m.result()
+ self.evaluate(m.init_variables())
+ self.evaluate(accumulate)
+ self.assertEqual(self.evaluate(result), 1.5)
+ # Second init resets all the variables.
+ self.evaluate(m.init_variables())
+ inputs = ops.convert_to_tensor([2.0, 3.0])
+ self.evaluate(m(inputs))
+ value = m.value()
+ self.assertEqual(self.evaluate(value), 2.5)
+
def testTwoMeansGraph(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py
index 8e6b947e5c..3eb4f5f8b3 100644
--- a/tensorflow/contrib/eager/python/network_test.py
+++ b/tensorflow/contrib/eager/python/network_test.py
@@ -105,13 +105,15 @@ class NetworkTest(test.TestCase):
result = net(constant_op.constant([[2.0]]))
self.assertEqual(34.0, self.evaluate(result))
+ # TODO(akshayka): This test should be changed once an API for compiling
+ # `call` into a defun is implemented.
def testReplacingNetworkCallWithDefun(self):
net = MyNetwork(name="abcd")
- net.call = function.defun(net.call)
x = constant_op.constant([[2.0]])
net(x) # Force variables to be created.
self.evaluate(net.trainable_variables[0].assign([[17.0]]))
+ net.call = function.defun(net.call)
result = net(x) # Build and execute the TensorFlow function
self.assertEqual(34.0, self.evaluate(result))
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 6146b92311..ce38ea4754 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1710,6 +1710,7 @@ bool InlineAllFunctions(GraphDef* graphdef) {
tensorflow::Graph graph(fld);
tensorflow::ImportGraphDefOptions gc_opts;
+ gc_opts.validate_shape = false;
const auto& tf_convert_status = tensorflow::ImportGraphDef(
gc_opts, graphdef_copy, &graph, nullptr, nullptr);
if (!tf_convert_status.ok()) {
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index a34c7f91f2..43cb38f169 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -201,6 +201,17 @@ tf_py_test(
],
)
+tf_py_test(
+ name = "tpu_config_test",
+ size = "small",
+ srcs = ["python/tpu/tpu_config_test.py"],
+ additional_deps = [
+ ":tpu_estimator",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 77ce38991b..1b6ce2dfdf 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -20,9 +20,19 @@ from __future__ import division
from __future__ import print_function
import collections
+import json
+import os
from tensorflow.contrib.tpu.python.tpu import util as util_lib
from tensorflow.python.estimator import run_config as run_config_lib
+from tensorflow.python.platform import tf_logging as logging
+
+# pylint: disable=protected-access
+_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
+_SERVICE_KEY = run_config_lib._SERVICE_KEY
+_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
+
+# pylint: enable=protected-access
class TPUConfig(
@@ -74,6 +84,9 @@ class TPUConfig(
if initial_infeed_sleep_secs:
util_lib.check_positive_integer(initial_infeed_sleep_secs,
'TPUConfig initial_infeed_sleep_secs')
+
+ tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()
+
return super(TPUConfig, cls).__new__(
cls,
iterations_per_loop=iterations_per_loop,
@@ -126,3 +139,14 @@ class RunConfig(run_config_lib.RunConfig):
new_instance = super(RunConfig, self).replace(**kwargs)
new_instance._tpu_config = tpu_config # pylint: disable=protected-access
return new_instance
+
+
+def _get_tpu_job_name_from_tf_config():
+ """Extracts the TPU job name from TF_CONFIG env variable."""
+  # TODO(xiejw): Extend this to support both the TF_CONFIG env variable and
+  # cluster spec propagation.
+ tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
+ tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)
+ if tpu_job_name:
+ logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name)
+ return tpu_job_name
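Note: the lookup keys mirror `run_config_lib`'s internals: the job name lives under the `service` section of `TF_CONFIG`, as exercised by the new test below. A runnable sketch:

```python
import json
import os

from tensorflow.contrib.tpu.python.tpu import tpu_config as tpu_config_lib

# Sketch: with no tpu_job_name passed to TPUConfig, the name is picked up
# from the TF_CONFIG environment variable at RunConfig construction time.
os.environ['TF_CONFIG'] = json.dumps(
    {'service': {'tpu_worker_job_name': 'my_tpu_worker'}})

config = tpu_config_lib.RunConfig()
print(config.tpu_config.tpu_job_name)  # my_tpu_worker
```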
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
new file mode 100644
index 0000000000..618f263618
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
@@ -0,0 +1,60 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TPU RunConfig tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+
+from tensorflow.contrib.tpu.python.tpu import tpu_config as tpu_config_lib
+from tensorflow.python.platform import test
+
+
+def _set_tf_config_env_variable(tf_config):
+ return test.mock.patch.dict('os.environ', {
+ 'TF_CONFIG': json.dumps(tf_config)
+ })
+
+
+class TPURunConfigTest(test.TestCase):
+
+ def test_fail_with_invalid_num_shards(self):
+ with self.assertRaisesRegexp(ValueError, 'must be positive'):
+ tpu_config_lib.RunConfig(
+ tpu_config=tpu_config_lib.TPUConfig(num_shards=0))
+
+ def test_fail_with_iterations_per_loop(self):
+ with self.assertRaisesRegexp(ValueError, 'must be positive'):
+ tpu_config_lib.RunConfig(
+ tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0))
+
+
+class TPUJobNameTest(test.TestCase):
+
+ def test_default_name(self):
+ config = tpu_config_lib.RunConfig()
+ self.assertIsNone(config.tpu_config.tpu_job_name)
+
+ def test_with_tf_config(self):
+ tf_config = {'service': {'tpu_worker_job_name': '_my_new_name',}}
+ with _set_tf_config_env_variable(tf_config):
+ config = tpu_config_lib.RunConfig()
+ self.assertEqual('_my_new_name', config.tpu_config.tpu_job_name)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/core/api_def/base_api/api_def_DenseToSparseBatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_DenseToSparseBatchDataset.pbtxt
index f2f5594c7c..e275cfdd3d 100644
--- a/tensorflow/core/api_def/base_api/api_def_DenseToSparseBatchDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DenseToSparseBatchDataset.pbtxt
@@ -21,5 +21,5 @@ SparseTensor. The shape may be partially specified, using `-1` to indicate
that a particular dimension should use the maximum size of all batch elements.
END
}
- summary: "Creates a dataset that yields a SparseTensor for each element of the input."
+ summary: "Creates a dataset that batches input elements into a SparseTensor."
}
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 15edce6a68..99b33e2ef0 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -1265,7 +1265,7 @@ TEST(DirectSessionTest, LocalDeviceManager) {
// A simple benchmark for the overhead of `DirectSession::Run()` calls
// with varying numbers of feeds/fetches.
-void FeedFetchBenchmarkHelper(int num_feeds, int iters) {
+void FeedFetchBenchmarkHelper(int iters, int num_feeds) {
testing::StopTiming();
Tensor value(DT_FLOAT, TensorShape());
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index b5f4ebb005..3fd932da5b 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -202,11 +202,12 @@ class ExecutorBarrier {
// below.
if (--pending_ == 0) {
CHECK(done_cb_ != nullptr);
- done = done_cb_;
- done_cb_ = nullptr;
+ std::swap(done, done_cb_);
}
- status = status_;
+ if (!status_.ok()) {
+ status = status_;
+ }
}
if (error) {
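Note: the barrier now hands off the callback via a single `std::swap` and copies the shared status out only when it actually holds an error, presumably to avoid needless `Status` copies on the hot path. A standalone sketch of the pattern, with std types standing in for TF's `mutex`/`Status`/`StatusCallback`:

```cpp
#include <functional>
#include <mutex>
#include <system_error>
#include <utility>

// Sketch: the first error is kept, and the completion callback is claimed
// exactly once when the pending count reaches zero.
struct Barrier {
  std::mutex mu;
  int pending = 1;
  std::error_code first_error;  // default-constructed == OK
  std::function<void(std::error_code)> done_cb;

  void WhenDone(std::error_code ec) {
    std::function<void(std::error_code)> done;
    std::error_code status;
    {
      std::lock_guard<std::mutex> l(mu);
      if (ec && !first_error) first_error = ec;      // remember first error
      if (--pending == 0) std::swap(done, done_cb);  // claim the callback
      if (first_error) status = first_error;         // copy only a real error
    }
    if (done) done(status);
  }
};
```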
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
index 7c09451a8a..08961fc105 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
@@ -25,8 +25,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/platform/stream_executor.h"
-namespace gpu = ::perftools::gputools;
-
namespace tensorflow {
GPUcudaMallocAllocator::GPUcudaMallocAllocator(VisitableAllocator* allocator,
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
index 45e97fdbf0..eea857f8ce 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
+#include <cstddef>
#include <vector>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
@@ -22,16 +23,13 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/platform/stream_executor.h"
-namespace gpu = ::perftools::gputools;
-
-namespace tensorflow {
-
#define MASK_WORDS 2
#define MASK_BYTES (MASK_WORDS * sizeof(int64))
+namespace tensorflow {
namespace {
-static int64* NewMask(int64 word) {
+int64* NewMask(int64 word) {
int64* m = new int64[MASK_WORDS];
for (int i = 0; i < MASK_WORDS; ++i) {
m[i] = word;
@@ -39,8 +37,8 @@ static int64* NewMask(int64 word) {
return m;
}
-static int64* before_mask = NewMask(0xabababababababab);
-static int64* after_mask = NewMask(0xcdcdcdcdcdcdcdcd);
+int64* before_mask = NewMask(0xabababababababab);
+int64* after_mask = NewMask(0xcdcdcdcdcdcdcdcd);
bool CheckMask(perftools::gputools::StreamExecutor* exec, void* ptr,
int64* mask) {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
index ca4b93815c..d34f0cb3c2 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
@@ -30,9 +30,8 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
-namespace gpu = ::perftools::gputools;
-
namespace tensorflow {
+namespace {
TEST(GPUDebugAllocatorTest, OverwriteDetection_None) {
const CudaGpuId cuda_gpu_id(0);
@@ -223,6 +222,7 @@ TEST(GPUDebugAllocatorTest, AllocatedVsRequested) {
a.DeallocateRaw(t1);
}
+} // namespace
} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc
index 8a3220ce2b..995fd1253f 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/process_state.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include <cstring>
#include <vector>
#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
@@ -47,27 +48,19 @@ const bool FLAGS_brain_mem_reg_cuda_dma = true;
// performance issues.
const bool FLAGS_brain_gpu_record_mem_types = false;
-namespace gpu = ::perftools::gputools;
-
namespace tensorflow {
-
namespace {
+
bool useCudaMallocAllocator() {
const char* debug_allocator_str = std::getenv("TF_GPU_ALLOCATOR");
- if (debug_allocator_str != nullptr &&
- strcmp(debug_allocator_str, "cuda_malloc") == 0)
- return true;
- else
- return false;
+ return debug_allocator_str != nullptr &&
+ std::strcmp(debug_allocator_str, "cuda_malloc") == 0;
}
bool useCudaMemoryGuardAllocator() {
const char* debug_allocator_str = std::getenv("TF_GPU_ALLOCATOR");
- if (debug_allocator_str != nullptr &&
- strcmp(debug_allocator_str, "memory_guard") == 0)
- return true;
- else
- return false;
+ return debug_allocator_str != nullptr &&
+ std::strcmp(debug_allocator_str, "memory_guard") == 0;
}
} // namespace
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index 54f082e823..a913f20751 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -369,7 +369,8 @@ class ColocationGraph {
"Could not satisfy explicit device specification '",
node->requested_device(), "' because no supported kernel for ",
specified_device_name.type, " devices is available.",
- debug_info);
+ debug_info, "\nRegistered kernels:\n",
+ KernelsRegisteredForOp(node->type_string()));
} else {
return errors::InvalidArgument(
"Could not satisfy explicit device specification '",
diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h
index a87c2c9a57..4e9352ee32 100644
--- a/tensorflow/core/example/feature_util.h
+++ b/tensorflow/core/example/feature_util.h
@@ -33,7 +33,7 @@ limitations under the License.
// GetFeatureValues<int64>("tag", &example)->Add(id);
//
// Modification of bytes features is slightly different:
-// auto tag = GetFeatureValues<string>("tag", example);
+// auto tag = GetFeatureValues<string>("tag", &example);
// *tag->Add() = "lorem ipsum";
//
// To copy multiple values into a feature:
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index e19f4aebba..2a52c7516e 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -84,7 +84,8 @@ class GraphConstructor {
return_tensors(in.return_tensors),
return_nodes(in.return_nodes),
importing(true),
- validate_colocation_constraints(in.validate_colocation_constraints) {}
+ validate_colocation_constraints(in.validate_colocation_constraints),
+ validate_shape(in.validate_shape) {}
bool allow_internal_ops;
bool expect_device_spec;
@@ -108,6 +109,7 @@ class GraphConstructor {
// remove this.
bool importing;
bool validate_colocation_constraints;
+ bool validate_shape = true;
};
typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;
@@ -561,7 +563,7 @@ Status GraphConstructor::MakeNode(const NodeDef& node_def, Node** node) {
}
Status GraphConstructor::ValidateShape(Node* node) {
- if (!opts_.importing) return Status::OK();
+ if (!opts_.importing || !opts_.validate_shape) return Status::OK();
TF_RETURN_IF_ERROR(refiner_->AddNode(node));
// For nodes with the _output_shapes attribute, override the shape.
std::vector<TensorShapeProto> shape_attrs;
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
index 07814b2ef7..b03d655fe6 100644
--- a/tensorflow/core/graph/graph_constructor.h
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -57,7 +57,8 @@ struct ImportGraphDefOptions {
ImportGraphDefOptions()
: uniquify_names(false),
uniquify_prefix(false),
- skip_mapped_nodes(false) {}
+ skip_mapped_nodes(false),
+ validate_shape(true) {}
// Name prefix to use for nodes imported from the GraphDef. For example, if
// prefix="animals" and GraphDef contains a node "bunny" then the node will be
@@ -130,6 +131,9 @@ struct ImportGraphDefOptions {
// If true, checks that all colocation constraints are nodes in the GraphDef.
bool validate_colocation_constraints = true;
+  // If false, skips shape validation during import.
+ bool validate_shape;
+
// TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
// with ops that are not defined in the binary calling ImportGraphDef.
// Similar to the producer_op_list argument to import_graph_def in the
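Note: the toco change earlier in this diff (`import_tensorflow.cc`) is the first user of this flag. A minimal sketch of an import that skips shape validation; `graphdef` and `graph` are assumed to exist in the caller:

```cpp
// Sketch: import a GraphDef without per-node shape validation, matching how
// toco's InlineAllFunctions() now calls ImportGraphDef.
tensorflow::ImportGraphDefOptions opts;
opts.validate_shape = false;
TF_RETURN_IF_ERROR(
    tensorflow::ImportGraphDef(opts, graphdef, &graph, nullptr, nullptr));
```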
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc
index b97e3d1db1..ae70c98608 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.cc
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc
@@ -25,14 +25,16 @@ namespace grappler {
VirtualCluster::VirtualCluster(
const std::unordered_map<string, DeviceProperties>& devices)
- : Cluster(0), node_estimator_(new OpLevelCostEstimator()) {
+ : Cluster(0),
+ node_estimator_(new OpLevelCostEstimator()),
+ node_manager_(new FirstReadyManager()) {
devices_ = devices;
}
VirtualCluster::VirtualCluster(
const std::unordered_map<string, DeviceProperties>& devices,
- OpLevelCostEstimator* node_estimator)
- : Cluster(0), node_estimator_(node_estimator) {
+ OpLevelCostEstimator* node_estimator, ReadyNodeManager* node_manager)
+ : Cluster(0), node_estimator_(node_estimator), node_manager_(node_manager) {
devices_ = devices;
}
VirtualCluster::~VirtualCluster() {}
@@ -54,7 +56,7 @@ Status VirtualCluster::Run(const GraphDef& graph,
item.graph = graph;
item.feed = feed;
item.fetch = fetch;
- VirtualScheduler scheduler(&item, true, this);
+ VirtualScheduler scheduler(&item, true, this, node_manager_.get());
TF_RETURN_IF_ERROR(scheduler.Init());
if (metadata) {
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h
index 1c73dbb240..dde70bab7a 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.h
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
+#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
namespace tensorflow {
@@ -31,7 +32,8 @@ class VirtualCluster : public Cluster {
public:
VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices);
VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices,
- OpLevelCostEstimator* node_estimator);
+ OpLevelCostEstimator* node_estimator,
+ ReadyNodeManager* node_manager);
~VirtualCluster() override;
@@ -45,6 +47,7 @@ class VirtualCluster : public Cluster {
private:
std::unique_ptr<OpLevelCostEstimator> node_estimator_;
+ std::unique_ptr<ReadyNodeManager> node_manager_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
index ca66f7c75a..e8d77f405c 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
@@ -34,13 +34,15 @@ AnalyticalCostEstimator::AnalyticalCostEstimator(Cluster* cluster,
bool use_static_shapes)
: cluster_(cluster),
node_estimator_(new OpLevelCostEstimator()),
+ node_manager_(VirtualScheduler::ReadyNodeManagerFactory("FirstReady")),
use_static_shapes_(use_static_shapes) {}
AnalyticalCostEstimator::AnalyticalCostEstimator(
Cluster* cluster, OpLevelCostEstimator* node_estimator,
- bool use_static_shapes)
+ ReadyNodeManager* node_manager, bool use_static_shapes)
: cluster_(cluster),
node_estimator_(node_estimator),
+ node_manager_(node_manager),
use_static_shapes_(use_static_shapes) {}
Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
@@ -61,7 +63,8 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
}
}
std::vector<string> inaccurate_nodes;
- VirtualScheduler scheduler(&item, use_static_shapes_, cluster_);
+ VirtualScheduler scheduler(&item, use_static_shapes_, cluster_,
+ node_manager_.get());
auto status = scheduler.Init();
if (!status.ok()) {
costs->execution_time = Costs::Duration::max();
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.h b/tensorflow/core/grappler/costs/analytical_cost_estimator.h
index cf9163302c..dd2738e088 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
+#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/status.h"
@@ -39,9 +40,10 @@ class AnalyticalCostEstimator : public CostEstimator {
// Does not take ownership of cluster.
AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes);
// Does not take ownership of the cluster, but takes ownership of the
- // node_estimator
+ // node_estimator and the node_manager
AnalyticalCostEstimator(Cluster* cluster,
OpLevelCostEstimator* node_estimator,
+ ReadyNodeManager* node_manager,
bool use_static_shapes);
~AnalyticalCostEstimator() override {}
@@ -59,6 +61,7 @@ class AnalyticalCostEstimator : public CostEstimator {
Cluster* cluster_; // Not owned.
GrapplerItem item_;
std::unique_ptr<OpLevelCostEstimator> node_estimator_;
+ std::unique_ptr<ReadyNodeManager> node_manager_;
bool use_static_shapes_;
};
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index fb3bdedcc6..a0f61c5392 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -113,10 +113,17 @@ void LIFOManager::RemoveCurrNode() {
curr_pos_ = nodes_.end(); // Reset curr_pos_.
}
-FirstReadyManager::FirstReadyManager(
- const std::unordered_map<const NodeDef*, NodeState>* node_state)
- : ReadyNodeManager(), node_state_(node_state) {
+FirstReadyManager::FirstReadyManager() : ReadyNodeManager() {
std::make_heap(nodes_.begin(), nodes_.end());
+}
+
+void FirstReadyManager::Init(
+ const std::unordered_map<const NodeDef*, NodeState>* node_state) {
+ // Reset the node state since different instances of the scheduler can reuse
+ // the same node_manager.
+ node_state_ = node_state;
+ nodes_.clear();
+ waiting_queue_.clear();
greater_ = [this](const NodeDef* a, const NodeDef* b) -> bool {
if (node_state_->at(a).time_ready == node_state_->at(b).time_ready) {
// Use Node name as tie-breaker for deterministic node scheduling.
@@ -163,12 +170,14 @@ void FirstReadyManager::DrainWaitingQueue() {
waiting_queue_.clear();
}
-CompositeNodeManager::CompositeNodeManager(
- const std::unordered_map<const NodeDef*, NodeState>* node_state)
- : ReadyNodeManager(),
- send_manager_(node_state),
- recv_manager_(node_state),
- node_state_(node_state) {
+CompositeNodeManager::CompositeNodeManager()
+ : ReadyNodeManager(), send_manager_(), recv_manager_() {}
+
+void CompositeNodeManager::Init(
+ const std::unordered_map<const NodeDef*, NodeState>* node_state) {
+ node_state_ = node_state;
+ send_manager_.Init(node_state);
+ recv_manager_.Init(node_state);
curr_node_ = nullptr;
}
@@ -241,11 +250,11 @@ bool CompositeNodeManager::Empty() const {
return empty && send_manager_.Empty() && recv_manager_.Empty();
}
-// VirtualScheduler
VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
const bool use_static_shapes,
- Cluster* cluster)
- : ready_nodes_(ReadyNodeManagerFactory("FirstReady")),
+ Cluster* cluster,
+ ReadyNodeManager* ready_nodes)
+ : ready_nodes_(ready_nodes),
graph_costs_(Costs::ZeroCosts()),
graph_properties_(*grappler_item),
cluster_(cluster),
@@ -262,9 +271,9 @@ ReadyNodeManager* VirtualScheduler::ReadyNodeManagerFactory(
} else if (ready_node_manager == "LIFO") {
return new LIFOManager();
} else if (ready_node_manager == "FirstReady") {
- return new FirstReadyManager(GetNodeStates());
+ return new FirstReadyManager();
} else if (ready_node_manager == "Composite") {
- return new CompositeNodeManager(GetNodeStates());
+ return new CompositeNodeManager();
}
LOG(FATAL) << "Not a valid ready node manager: " << ready_node_manager;
}
@@ -274,7 +283,7 @@ Status VirtualScheduler::Init() {
// necessary information for emulating tensorflow op scheduling and
// construct internal data structures (NodeState and DeviceState) for virtual
// scheduling.
-
+ ready_nodes_->Init(GetNodeStates());
// Construct graph properties.
Status status;
if (use_static_shapes_) {
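Note: the scheduler now borrows its `ReadyNodeManager` instead of owning one, and `Init()` (re)binds the manager to the scheduler's node state, so a single manager instance can be reused across schedulers. A sketch of the new construction pattern, mirroring `AnalyticalCostEstimator` and `VirtualCluster` above; `item` and `cluster` are assumed to exist:

```cpp
// Sketch: the caller owns the manager; the scheduler only borrows it.
std::unique_ptr<ReadyNodeManager> node_manager(
    VirtualScheduler::ReadyNodeManagerFactory("FirstReady"));
VirtualScheduler scheduler(&item, /*use_static_shapes=*/true, cluster,
                           node_manager.get());
// Init() calls node_manager->Init(...) with the node states before scheduling.
TF_RETURN_IF_ERROR(scheduler.Init());
```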
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index df8ae5861a..c180250908 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -127,6 +127,8 @@ class ReadyNodeManager {
public:
ReadyNodeManager() {}
virtual ~ReadyNodeManager() {}
+ virtual void Init(
+ const std::unordered_map<const NodeDef*, NodeState>* node_state) {}
virtual void AddNode(const NodeDef* node) = 0;
virtual const NodeDef* GetCurrNode() = 0;
virtual void RemoveCurrNode() = 0;
@@ -137,6 +139,8 @@ class FIFOManager : public ReadyNodeManager {
public:
FIFOManager() : ReadyNodeManager() {}
~FIFOManager() override {}
+ void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
+ override {}
void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
const NodeDef* GetCurrNode() override {
CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
@@ -157,6 +161,8 @@ class LIFOManager : public ReadyNodeManager {
public:
LIFOManager() : ReadyNodeManager() {}
~LIFOManager() override {}
+ void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
+ override {}
void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
const NodeDef* GetCurrNode() override;
void RemoveCurrNode() override;
@@ -176,8 +182,9 @@ class LIFOManager : public ReadyNodeManager {
// time_ready value (it depends on C++ STL push_heap and pop_heap).
class FirstReadyManager : public ReadyNodeManager {
public:
- FirstReadyManager(
- const std::unordered_map<const NodeDef*, NodeState>* node_state);
+ FirstReadyManager();
+ void Init(
+ const std::unordered_map<const NodeDef*, NodeState>* node_state) override;
~FirstReadyManager() override {}
void AddNode(const NodeDef* node) override { waiting_queue_.push_back(node); }
const NodeDef* GetCurrNode() override;
@@ -212,10 +219,11 @@ class FirstReadyManager : public ReadyNodeManager {
// _Send and _Recv, fairly, in terms of their time_ready.
class CompositeNodeManager : public ReadyNodeManager {
public:
- CompositeNodeManager(
- const std::unordered_map<const NodeDef*, NodeState>* node_state);
+ CompositeNodeManager();
~CompositeNodeManager() override {}
+ void Init(
+ const std::unordered_map<const NodeDef*, NodeState>* node_state) override;
void AddNode(const NodeDef* node) override;
const NodeDef* GetCurrNode() override;
void RemoveCurrNode() override;
@@ -244,8 +252,8 @@ class CompositeNodeManager : public ReadyNodeManager {
class VirtualScheduler {
public:
VirtualScheduler(const GrapplerItem* grappler_item,
- const bool use_static_shapes, Cluster* cluster);
-
+ const bool use_static_shapes, Cluster* cluster,
+ ReadyNodeManager* ready_nodes);
// Initializes NodeState and DeviceState from grappler_item_ and
// graph_properties_.
Status Init();
@@ -260,6 +268,9 @@ class VirtualScheduler {
// Like the above, but writes detailed stats to RunMetadata.
// If metadata is nullptr, then just calls and return Summary().
Costs Summary(RunMetadata* metadata);
+ // Factory for creating a ReadyNodeManager by name; caller takes ownership.
+ static ReadyNodeManager* ReadyNodeManagerFactory(
+ const string& ready_node_manager);
// Return per device peak memory usage.
const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
@@ -285,9 +296,6 @@ class VirtualScheduler {
const string kAttrDstDevice = "dst_device_";
const string kChannelDevice = "Channel";
- // Methods called from constructor.
- ReadyNodeManager* ReadyNodeManagerFactory(const string& ready_node_manager);
-
// Methods called from Init(). Fails if initialize_ is set.
void MaybeUpdateInputOutput(const NodeDef* node);
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
@@ -304,7 +312,7 @@ class VirtualScheduler {
bool IsPersistentNode(const NodeDef* node) const;
// Scheduler states:
- std::unique_ptr<ReadyNodeManager> ready_nodes_;
+ ReadyNodeManager* ready_nodes_; // Not owned.
std::unordered_map<const NodeDef*, NodeState> node_map_;
std::unordered_map<string, DeviceState> device_;
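Because the base-class Init() above defaults to a no-op, only managers that consult node states (FirstReadyManager, CompositeNodeManager) must supply a real override. A hypothetical subclass sketched against this interface, assuming Empty() has the const signature the managers above use; it is functionally a LIFO and exists purely for illustration:

    #include <vector>

    #include "tensorflow/core/grappler/costs/virtual_scheduler.h"

    namespace tensorflow {
    namespace grappler {

    // Hypothetical: a state-independent ordering inherits the default no-op
    // Init() and needs no node-state plumbing at all.
    class NewestFirstManager : public ReadyNodeManager {
     public:
      void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
      const NodeDef* GetCurrNode() override { return nodes_.back(); }
      void RemoveCurrNode() override { nodes_.pop_back(); }
      bool Empty() const override { return nodes_.empty(); }

     private:
      std::vector<const NodeDef*> nodes_;
    };

    }  // namespace grappler
    }  // namespace tensorflow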
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index cd960bab29..9e6561230a 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -29,7 +29,8 @@ class TestVirtualScheduler : public VirtualScheduler {
public:
TestVirtualScheduler(const GrapplerItem* grappler_item,
const bool use_static_shapes, Cluster* cluster)
- : VirtualScheduler(grappler_item, use_static_shapes, cluster) {}
+ : VirtualScheduler(grappler_item, use_static_shapes, cluster,
+ &ready_node_manager_) {}
FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize);
FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
@@ -37,6 +38,9 @@ class TestVirtualScheduler : public VirtualScheduler {
FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
FRIEND_TEST(VirtualSchedulerTest, Variable);
FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
+
+ protected:
+ FirstReadyManager ready_node_manager_;
};
class VirtualSchedulerTest : public ::testing::Test {
@@ -1148,23 +1152,24 @@ TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleLIFOManager) {
}
TEST_F(VirtualSchedulerTest, GetSingleNodeFirstReadyManager) {
- FirstReadyManager manager = FirstReadyManager(&node_states_);
+ FirstReadyManager manager;
+ manager.Init(&node_states_);
manager.AddNode(&node1_);
EXPECT_EQ("Node1", manager.GetCurrNode()->name());
}
TEST_F(VirtualSchedulerTest, RemoveSingleNodeFirstReadyManager) {
- FirstReadyManager manager = FirstReadyManager(&node_states_);
-
+ FirstReadyManager manager;
+ manager.Init(&node_states_);
manager.AddNode(&node1_);
manager.RemoveCurrNode();
EXPECT_TRUE(manager.Empty());
}
TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFirstReadyManager) {
- FirstReadyManager manager = FirstReadyManager(&node_states_);
-
+ FirstReadyManager manager;
+ manager.Init(&node_states_);
// Insert nodes in some random order.
manager.AddNode(&node2_);
manager.AddNode(&node1_);
@@ -1191,8 +1196,8 @@ TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFirstReadyManager) {
}
TEST_F(VirtualSchedulerTest, GetCurrNodeFirstReadyManager) {
- FirstReadyManager manager = FirstReadyManager(&node_states_);
-
+ FirstReadyManager manager;
+ manager.Init(&node_states_);
// Insert nodes in some random order.
manager.AddNode(&node2_);
manager.AddNode(&node1_);
@@ -1248,8 +1253,10 @@ TEST_F(VirtualSchedulerTest, GetCurrNodeFirstReadyManager) {
}
TEST_F(VirtualSchedulerTest, DeterminismInFirstReadyManager) {
- FirstReadyManager manager1 = FirstReadyManager(&node_states_);
- FirstReadyManager manager2 = FirstReadyManager(&node_states_);
+ FirstReadyManager manager1;
+ manager1.Init(&node_states_);
+ FirstReadyManager manager2;
+ manager2.Init(&node_states_);
// 6 nodes with same time_ready.
NodeDef node7;
@@ -1312,15 +1319,16 @@ TEST_F(VirtualSchedulerTest, DeterminismInFirstReadyManager) {
}
TEST_F(VirtualSchedulerTest, RemoveSingleNodeCompositeNodeManager) {
- CompositeNodeManager manager = CompositeNodeManager(&node_states_);
-
+ CompositeNodeManager manager;
+ manager.Init(&node_states_);
manager.AddNode(&node1_);
manager.RemoveCurrNode();
EXPECT_TRUE(manager.Empty());
}
TEST_F(VirtualSchedulerTest, RemoveSingleNodeComopsiteNodeManager) {
- CompositeNodeManager manager = CompositeNodeManager(&node_states_);
+ CompositeNodeManager manager;
+ manager.Init(&node_states_);
manager.AddNode(&node1_);
manager.RemoveCurrNode();
@@ -1328,7 +1336,8 @@ TEST_F(VirtualSchedulerTest, RemoveSingleNodeComopsiteNodeManager) {
}
TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleComopsiteNodeManager) {
- CompositeNodeManager manager = CompositeNodeManager(&node_states_);
+ CompositeNodeManager manager;
+ manager.Init(&node_states_);
// Add the nodes to LIFOManager.
manager.AddNode(&node1_);
@@ -1359,8 +1368,8 @@ TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleComopsiteNodeManager) {
}
TEST_F(VirtualSchedulerTest, MultiDeviceSendRecvComopsiteNodeManager) {
- CompositeNodeManager manager = CompositeNodeManager(&node_states_);
-
+ CompositeNodeManager manager;
+ manager.Init(&node_states_);
// Additional nodes on kCPU1
NodeDef node7;
NodeDef node8;
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index c9e5f842be..2786b8cf62 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -926,7 +926,7 @@ class AvgPoolGradProcessor : public NodeProcessor {
protected:
std::vector<int> GetInputPos() const override { return {1}; }
Status CustomizedProcessing() override {
- return UpdateAttrValueOfInput(0, true);
+ return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
}
};
@@ -1062,9 +1062,7 @@ class Conv2DBackpropInputProcessor : public Conv2DProcessor {
std::vector<int> GetInputPos() const override { return {2}; }
Status CustomizedProcessing() override {
- TF_RETURN_IF_ERROR(
- UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32));
- return Status::OK();
+ return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
}
};
@@ -1371,9 +1369,7 @@ class FillProcessor : public AgnosticNodeProcessor {
Status CustomizedProcessing() override {
DataType dtype = node_->attr().at("index_type").type();
- TF_RETURN_IF_ERROR(
- UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype));
- return Status::OK();
+ return UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype);
}
};
@@ -1470,9 +1466,7 @@ class PadProcessor : public AgnosticNodeProcessor {
protected:
Status CustomizedProcessing() override {
DataType dtype = node_->attr().at("Tpaddings").type();
- TF_RETURN_IF_ERROR(
- UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype));
- return Status::OK();
+ return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
}
};
@@ -1484,9 +1478,7 @@ class ReverseProcessor : public AgnosticNodeProcessor {
protected:
Status CustomizedProcessing() override {
DataType dtype = node_->attr().at("Tidx").type();
- TF_RETURN_IF_ERROR(
- UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype));
- return Status::OK();
+ return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype);
}
};
@@ -1511,9 +1503,8 @@ class SplitProcessor : public AgnosticNodeProcessor {
}
Status CustomizedProcessing() override {
- TF_RETURN_IF_ERROR(UpdateOrTransformParamInput(
- axis_node_pos_, "DataFormatDimMap", DT_INT32));
- return Status::OK();
+ return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
+ DT_INT32);
}
int axis_node_pos_;
@@ -1629,40 +1620,14 @@ class SumProcessor : public AgnosticNodeProcessor {
int port;
ParseNodeName(node_->input(0), &port);
return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
- IsPortDimsFour(*input0, port) && IsAlongDimNHW() && IsOnGPU();
+ IsPortDimsFour(*input0, port) && IsOnGPU();
}
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
Status CustomizedProcessing() override {
- return UpdateAttrValueOfInput(1, false);
- }
-
- private:
- bool IsAlongDimNHW() const {
- NodeDef* reduction_indices = node_map_->GetNode(node_->input(1));
- if (!IsConstant(*reduction_indices)) {
- return false;
- }
- Tensor tensor;
- if (reduction_indices->attr().find({"value"}) ==
- reduction_indices->attr().end()) {
- return false;
- }
- auto success =
- tensor.FromProto(reduction_indices->attr().at({"value"}).tensor());
- if (!success) {
- LOG(ERROR) << "Failed to parse TensorProto.";
- return false;
- }
- if (tensor.flat<int>().size() != 3) {
- return false;
- }
- if (tensor.flat<int>()(0) == 0 && tensor.flat<int>()(1) == 1 &&
- tensor.flat<int>()(2) == 2) {
- return true;
- }
- return false;
+ DataType dtype = node_->attr().at("Tidx").type();
+ return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype);
}
};
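In each processor above, the in-place attribute rewrite (UpdateAttrValueOfInput) is replaced by inserting a DataFormatVecPermute or DataFormatDimMap node, which also handles non-constant parameter inputs and uses the input's own dtype. As a rough illustration of what those two ops compute for NHWC-to-NCHW; the helper names are hypothetical, not the kernel code:

    #include <array>

    // Illustration only: DataFormatVecPermute reorders a 4-element parameter
    // vector (e.g. paddings sizes or an output shape) from NHWC to NCHW.
    std::array<int, 4> NHWCToNCHWVec(const std::array<int, 4>& v) {
      return {v[0], v[3], v[1], v[2]};  // N, C, H, W drawn from N, H, W, C.
    }

    // Illustration only: DataFormatDimMap maps a single axis index (e.g. the
    // reduction axis of Sum or the split axis of Split) from NHWC to NCHW.
    int NHWCToNCHWDim(int dim) {
      static const std::array<int, 4> kMap = {0, 2, 3, 1};
      return kMap[dim];
    }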
diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
index 6735373aac..24381a13ea 100644
--- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
@@ -85,11 +86,10 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
input_(input) {
input_->Ref();
- output_shapes_.reserve(3);
- // Outputs represent a SparseTensor as (indices, values, dense_shape).
- output_shapes_.push_back({-1, row_shape_.dims() + 1});
- output_shapes_.push_back({-1});
- output_shapes_.push_back({row_shape_.dims() + 1});
+ output_shapes_.reserve(1);
+ PartialTensorShape output_shape({-1});
+ output_shape.AppendShape(row_shape_);
+ output_shapes_.push_back(output_shape);
}
~Dataset() override { input_->Unref(); }
@@ -101,8 +101,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
}
const DataTypeVector& output_dtypes() const override {
- static DataTypeVector* output_dtypes_ =
- new DataTypeVector({DT_INT64, DataTypeToEnum<T>::value, DT_INT64});
+ static DataTypeVector* output_dtypes_ = new DataTypeVector({DT_VARIANT});
return *output_dtypes_;
}
@@ -220,7 +219,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
{total_elements, row_ndims + 1});
Tensor values(
cpu_allocator(),
- DatasetIterator<Dataset<T>>::dataset()->output_dtypes()[1],
+ DatasetIterator<Dataset<T>>::dataset()->input_->output_dtypes()[0],
{total_elements});
auto indices_matrix = indices.matrix<int64>();
auto values_flat = values.flat<T>();
@@ -256,9 +255,12 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
dense_shape_vec(0) = batch_elements.size();
- out_tensors->push_back(std::move(indices));
- out_tensors->push_back(std::move(values));
- out_tensors->push_back(std::move(dense_shape));
+ Tensor serialized_sparse(DT_VARIANT, TensorShape({3}));
+ auto serialized_sparse_t = serialized_sparse.vec<Variant>();
+ serialized_sparse_t(0) = std::move(indices);
+ serialized_sparse_t(1) = std::move(values);
+ serialized_sparse_t(2) = std::move(dense_shape);
+ out_tensors->push_back(std::move(serialized_sparse));
*end_of_sequence = false;
return Status::OK();
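The dataset now emits one DT_VARIANT component per batch instead of three separate tensors, with the SparseTensor's parts stored inside the variant vector. A sketch of how a consumer could unpack that component; the function name is hypothetical, and the Variant header included above is assumed:

    // Sketch only: recovering the (indices, values, dense_shape) triple from
    // the single variant-typed output built above.
    void UnpackSerializedSparse(const Tensor& serialized_sparse) {
      auto parts = serialized_sparse.vec<Variant>();
      const Tensor* indices = parts(0).get<Tensor>();      // [total_elements, row_ndims + 1]
      const Tensor* values = parts(1).get<Tensor>();       // [total_elements]
      const Tensor* dense_shape = parts(2).get<Tensor>();  // [row_ndims + 1]
      // ... e.g. rebuild a sparse tensor from the three parts here.
      (void)indices; (void)values; (void)dense_shape;
    }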
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 1688674eb7..09ba092f40 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -566,6 +566,27 @@ class FusedBatchNormOp : public OpKernel {
bool is_training_;
};
+namespace {
+
+template <typename Device>
+void FillZeros(Tensor* t);
+
+#if GOOGLE_CUDA
+template <>
+void FillZeros<GPUDevice>(Tensor* t) {
+ cudaMemset(const_cast<char*>(t->tensor_data().data()), 0,
+ t->tensor_data().size());
+}
+#endif
+
+template <>
+void FillZeros<CPUDevice>(Tensor* t) {
+ memset(const_cast<char*>(t->tensor_data().data()), 0,
+ t->tensor_data().size());
+}
+
+} // namespace
+
template <typename Device, typename T, typename U>
class FusedBatchNormGradOp : public OpKernel {
public:
@@ -623,14 +644,17 @@ class FusedBatchNormGradOp : public OpKernel {
Tensor* offset_backprop = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape,
&offset_backprop));
- // two placeholders for estimated_mean and estimated_variance, which are
+ // Two placeholders for estimated_mean and estimated_variance, which are
// used for inference and thus not needed here for gradient computation.
+ // They are filled with zeros so as to avoid NaN outputs.
Tensor* placeholder_1 = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(3, TensorShape({}), &placeholder_1));
+ FillZeros<Device>(placeholder_1);
Tensor* placeholder_2 = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(4, TensorShape({}), &placeholder_2));
+ FillZeros<Device>(placeholder_2);
if (is_training_) {
functor::FusedBatchNormGrad<Device, T, U>()(
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index caf4c533c3..4d95946467 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -12463,6 +12463,24 @@ op {
}
}
op {
+ name: "DecodeCompressed"
+ input_arg {
+ name: "bytes"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "compression_type"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+}
+op {
name: "DecodeGif"
input_arg {
name: "contents"
@@ -18388,6 +18406,32 @@ op {
}
}
op {
+ name: "HSVToRGB"
+ input_arg {
+ name: "images"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "HashTable"
output_arg {
name: "table_handle"
@@ -31006,6 +31050,32 @@ op {
}
}
op {
+ name: "RGBToHSV"
+ input_arg {
+ name: "images"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "RandomCrop"
input_arg {
name: "image"
@@ -32351,6 +32421,60 @@ op {
is_stateful: true
}
op {
+ name: "RecordInput"
+ output_arg {
+ name: "records"
+ type: DT_STRING
+ }
+ attr {
+ name: "file_pattern"
+ type: "string"
+ }
+ attr {
+ name: "file_random_seed"
+ type: "int"
+ default_value {
+ i: 301
+ }
+ }
+ attr {
+ name: "file_shuffle_shift_ratio"
+ type: "float"
+ default_value {
+ f: 0
+ }
+ }
+ attr {
+ name: "file_buffer_size"
+ type: "int"
+ default_value {
+ i: 10000
+ }
+ }
+ attr {
+ name: "file_parallelism"
+ type: "int"
+ default_value {
+ i: 16
+ }
+ }
+ attr {
+ name: "batch_size"
+ type: "int"
+ default_value {
+ i: 32
+ }
+ }
+ attr {
+ name: "compression_type"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ReduceJoin"
input_arg {
name: "inputs"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index e943a698ae..9f4e0e91a7 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -437,13 +437,11 @@ REGISTER_OP("DenseToSparseBatchDataset")
.Input("batch_size: int64")
.Input("row_shape: int64")
.Output("handle: variant")
- // NOTE(mrry): the 0th and 2nd elements will be DT_INT64.
.Attr("output_types: list(type) >= 1")
- // NOTE(mrry): the 1st and 2nd elements will be vectors.
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
-Creates a dataset that yields a SparseTensor for each element of the input.
+Creates a dataset that batches input elements into a SparseTensor.
input_dataset: A handle to an input dataset. Must have a single component.
batch_size: A scalar representing the number of elements to accumulate in a
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 333c280146..726d376b25 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -6868,6 +6868,29 @@ op {
description: "RFC 4180 format is expected for the CSV records.\n(https://tools.ietf.org/html/rfc4180)\nNote that we allow leading and trailing spaces with int or float field."
}
op {
+ name: "DecodeCompressed"
+ input_arg {
+ name: "bytes"
+ description: "A Tensor of string which is compressed."
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ description: "A Tensor with the same shape as input `bytes`, uncompressed\nfrom bytes."
+ type: DT_STRING
+ }
+ attr {
+ name: "compression_type"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ description: "A scalar containing either (i) the empty string (no\ncompression), (ii) \"ZLIB\", or (iii) \"GZIP\"."
+ }
+ summary: "Decompress strings."
+ description: "This op decompresses each element of the `bytes` input `Tensor`, which\nis assumed to be compressed using the given `compression_type`.\n\nThe `output` is a string `Tensor` of the same shape as `bytes`,\neach element containing the decompressed data from the corresponding\nelement in `bytes`."
+}
+op {
name: "DecodeGif"
input_arg {
name: "contents"
@@ -7169,7 +7192,7 @@ op {
has_minimum: true
minimum: 1
}
- summary: "Creates a dataset that yields a SparseTensor for each element of the input."
+ summary: "Creates a dataset that batches input elements into a SparseTensor."
}
op {
name: "DenseToSparseSetOperation"
@@ -10888,6 +10911,8 @@ op {
}
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -20070,6 +20095,8 @@ op {
}
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -21287,6 +21314,14 @@ op {
}
description: "The batch size."
}
+ attr {
+ name: "compression_type"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ description: "The type of compression for the file. Currently ZLIB and\nGZIP are supported. Defaults to none."
+ }
summary: "Emits randomized records."
is_stateful: true
}
diff --git a/tensorflow/docs_src/mobile/leftnav_files b/tensorflow/docs_src/mobile/leftnav_files
index 4d2c3b6234..ac50f528ba 100644
--- a/tensorflow/docs_src/mobile/leftnav_files
+++ b/tensorflow/docs_src/mobile/leftnav_files
@@ -1,6 +1,7 @@
index.md
### TensorFlow Lite
tflite/index.md
+tflite/demo_android.md
>>>
### TensorFlow Mobile
mobile_intro.md
diff --git a/tensorflow/docs_src/mobile/tflite/demo_android.md b/tensorflow/docs_src/mobile/tflite/demo_android.md
new file mode 100644
index 0000000000..79b567897c
--- /dev/null
+++ b/tensorflow/docs_src/mobile/tflite/demo_android.md
@@ -0,0 +1,39 @@
+# TensorFlow Lite Demo for Android
+
+The TensorFlow Lite demo is a camera app that continuously classifies whatever
+it sees from your device's back camera, using a quantized MobileNet model.
+
+You'll need an Android device running Android 5.0 or higher to run the demo.
+
+To get you started working with TensorFlow Lite on Android, we'll walk you
+through building and deploying our TensorFlow demo app in Android Studio.
+
+It's also possible to build the demo app with Bazel, but we only recommend
+this for advanced users who are very familiar with the Bazel build
+environment. For more information on that, see our page [on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite#building-tensorflow-lite-and-the-demo-app-from-source).
+
+## Build and deploy with Android Studio
+
+1. Clone the TensorFlow repository from GitHub if you haven't already:
+
+ git clone https://github.com/tensorflow/tensorflow
+
+2. Install the latest version of Android Studio from [here](https://developer.android.com/studio/index.html).
+
+3. From the **Welcome to Android Studio** screen, use the **Import Project
+ (Gradle, Eclipse ADT, etc)** option to import the
+ `tensorflow/contrib/lite/java/demo` directory as an existing Android Studio
+ Project.
+
+ Android Studio may prompt you to install Gradle upgrades and other tool
+ versions; you should accept these upgrades.
+
+4. Download the TensorFlow Lite MobileNet model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip).
+
+ Unzip this and copy the `mobilenet_quant_v1_224.tflite` file to the assets
+ directory: `tensorflow/contrib/lite/java/demo/app/src/main/assets/`
+
+5. Build and run the app in Android Studio.
+
+You'll have to grant permissions for the app to use the device's camera. Point
+the camera at various objects and enjoy seeing how the model classifies things!
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 47f46cde50..7d1418d420 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -5953,7 +5953,7 @@ func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source
return op.Output(0)
}
-// Creates a dataset that yields a SparseTensor for each element of the input.
+// Creates a dataset that batches input elements into a SparseTensor.
//
// Arguments:
// input_dataset: A handle to an input dataset. Must have a single component.
@@ -8732,39 +8732,6 @@ func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAtt
return op.Output(0)
}
-// Makes its input available to the next iteration.
-//
-// Arguments:
-// data: The tensor to be made available to the next iteration.
-//
-// Returns The same tensor as `data`.
-func NextIteration(scope *Scope, data tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "NextIteration",
- Input: []tf.Input{
- data,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Does nothing. Only useful as a placeholder for control edges.
-//
-// Returns the created operation.
-func NoOp(scope *Scope) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "NoOp",
- }
- return scope.AddOperation(opspec)
-}
-
// DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative.
type DepthwiseConv2dNativeAttr func(optionalAttr)
@@ -9731,6 +9698,39 @@ func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.
return op.Output(0)
}
+// Makes its input available to the next iteration.
+//
+// Arguments:
+// data: The tensor to be made available to the next iteration.
+//
+// Returns The same tensor as `data`.
+func NextIteration(scope *Scope, data tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "NextIteration",
+ Input: []tf.Input{
+ data,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Does nothing. Only useful as a placeholder for control edges.
+//
+// Returns the created operation.
+func NoOp(scope *Scope) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "NoOp",
+ }
+ return scope.AddOperation(opspec)
+}
+
// Returns the rank of a tensor.
//
// This operation returns an integer representing the rank of `input`.
@@ -9969,6 +9969,53 @@ func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.
return sparse_indices, sparse_values, sparse_shapes, dense_values
}
+// DecodeCompressedAttr is an optional argument to DecodeCompressed.
+type DecodeCompressedAttr func(optionalAttr)
+
+// DecodeCompressedCompressionType sets the optional compression_type attribute to value.
+//
+// value: A scalar containing either (i) the empty string (no
+// compression), (ii) "ZLIB", or (iii) "GZIP".
+// If not specified, defaults to ""
+func DecodeCompressedCompressionType(value string) DecodeCompressedAttr {
+ return func(m optionalAttr) {
+ m["compression_type"] = value
+ }
+}
+
+// Decompress strings.
+//
+// This op decompresses each element of the `bytes` input `Tensor`, which
+// is assumed to be compressed using the given `compression_type`.
+//
+// The `output` is a string `Tensor` of the same shape as `bytes`,
+// each element containing the decompressed data from the corresponding
+// element in `bytes`.
+//
+// Arguments:
+// bytes: A Tensor of string which is compressed.
+//
+// Returns A Tensor with the same shape as input `bytes`, uncompressed
+// from bytes.
+func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DecodeCompressed",
+ Input: []tf.Input{
+ bytes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Copy a tensor setting everything outside a central band in each innermost matrix
//
// to zero.
@@ -24616,6 +24663,17 @@ func RecordInputBatchSize(value int64) RecordInputAttr {
}
}
+// RecordInputCompressionType sets the optional compression_type attribute to value.
+//
+// value: The type of compression for the file. Currently ZLIB and
+// GZIP are supported. Defaults to none.
+// If not specified, defaults to ""
+func RecordInputCompressionType(value string) RecordInputAttr {
+ return func(m optionalAttr) {
+ m["compression_type"] = value
+ }
+}
+
// Emits randomized records.
//
// Arguments:
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 3a2a7015a1..c2d9c67a29 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -4418,7 +4418,10 @@ py_test(
"grappler/datasets_test.py",
],
srcs_version = "PY2AND3",
- tags = ["no_pip"], # tf_optimizer is not available in pip.
+ tags = [
+ "grappler",
+ "no_pip", # tf_optimizer is not available in pip.
+ ],
deps = [
":array_ops",
":client_testlib",
@@ -4547,9 +4550,9 @@ cuda_py_test(
"//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
],
+ shard_count = 10,
tags = [
"grappler",
- "manual",
],
)
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 569ea04f01..e08cc51b4d 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -426,6 +426,11 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
executor = _TrainingExecutor(estimator=estimator, train_spec=train_spec,
eval_spec=eval_spec)
+ _execute_based_on_task_type(executor, config)
+
+
+def _execute_based_on_task_type(executor, config):
+ """Executes the `executor` based on `config.task_type`."""
if (not config.cluster_spec and
config.task_type != run_config_lib.TaskType.EVALUATOR):
logging.info('Running training and evaluation locally (non-distributed).')
@@ -462,7 +467,6 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
'Task type {} is not supported. Supported task types are {}'.format(
config.task_type, [x[len('run_'):] for x in available_tasks]))
getattr(executor, task_to_run)()
- return
class _StopAtSecsHook(session_run_hook.SessionRunHook):
@@ -492,6 +496,7 @@ class _TrainingExecutor(object):
estimator,
train_spec,
eval_spec,
+ train_hooks=None,
continuous_eval_listener=None):
if not isinstance(estimator, estimator_lib.Estimator):
raise TypeError('`estimator` must have type `tf.estimator.Estimator`.')
@@ -505,6 +510,8 @@ class _TrainingExecutor(object):
raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`.')
self._eval_spec = eval_spec
+ self._train_hooks = _validate_hooks(train_hooks)
+
if (continuous_eval_listener and
not isinstance(continuous_eval_listener, _ContinuousEvalListener)):
raise TypeError('`continuous_eval_listener` must have type '
@@ -607,7 +614,8 @@ class _TrainingExecutor(object):
self._eval_spec.throttle_secs))
stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs)
- train_hooks = list(self._train_spec.hooks) + [stop_hook]
+ train_hooks = (
+ list(self._train_spec.hooks) + [stop_hook] + list(self._train_hooks))
logging.info('Start train and evaluate loop. The evaluate will happen '
'after {} secs (eval_spec.throttle_secs) or training is '
'finished.'.format(self._eval_spec.throttle_secs))
@@ -695,10 +703,11 @@ class _TrainingExecutor(object):
start_delay_secs)
time.sleep(start_delay_secs)
- self._estimator.train(input_fn=self._train_spec.input_fn,
- max_steps=self._train_spec.max_steps,
- hooks=self._train_spec.hooks,
- saving_listeners=saving_listeners)
+ self._estimator.train(
+ input_fn=self._train_spec.input_fn,
+ max_steps=self._train_spec.max_steps,
+ hooks=list(self._train_spec.hooks) + list(self._train_hooks),
+ saving_listeners=saving_listeners)
def _start_continuous_evaluation(self):
"""Repeatedly calls `Estimator` evaluate and export until training ends."""
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index 6390a67762..9536ee44d5 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -72,6 +72,7 @@ _NONE_EXPORTER_NAME_MSG = (
'An Exporter cannot have a name that is `None` or empty.')
_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
_INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`'
+_INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`'
_INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG'
_INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`'
_INVALID_TASK_TYPE = '`estimator.config` must have task_type set.'
@@ -522,6 +523,29 @@ class TrainingExecutorConstructorTest(test.TestCase):
with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
training._TrainingExecutor(estimator, train_spec, invalid_eval_spec)
+ def test_invalid_train_hooks(self):
+ estimator = estimator_lib.Estimator(model_fn=lambda features: features)
+ train_spec = training.TrainSpec(input_fn=lambda: 1)
+ eval_spec = training.EvalSpec(input_fn=lambda: 1)
+ invalid_train_hooks = [object()]
+
+ with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
+ training._TrainingExecutor(
+ estimator, train_spec, eval_spec, train_hooks=invalid_train_hooks)
+
+ def test_invalid_continuous_eval_listener(self):
+ estimator = estimator_lib.Estimator(model_fn=lambda features: features)
+ train_spec = training.TrainSpec(input_fn=lambda: 1)
+ eval_spec = training.EvalSpec(input_fn=lambda: 1)
+ invalid_continuous_eval_listener = object()
+
+ with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_LISTENER_MSG):
+ training._TrainingExecutor(
+ estimator,
+ train_spec,
+ eval_spec,
+ continuous_eval_listener=invalid_continuous_eval_listener)
+
class _TrainingExecutorTrainingTest(object):
"""Tests training of _TrainingExecutor."""
@@ -554,19 +578,40 @@ class _TrainingExecutorTrainingTest(object):
self.assertTrue(mock_server_instance.start.called)
- mock_est.train.assert_called_with(input_fn=train_spec.input_fn,
- max_steps=train_spec.max_steps,
- hooks=train_spec.hooks,
- saving_listeners=test.mock.ANY)
+ mock_est.train.assert_called_with(
+ input_fn=train_spec.input_fn,
+ max_steps=train_spec.max_steps,
+ hooks=list(train_spec.hooks),
+ saving_listeners=test.mock.ANY)
mock_est.evaluate.assert_not_called()
mock_est.export_savedmodel.assert_not_called()
@test.mock.patch.object(time, 'sleep')
@test.mock.patch.object(server_lib, 'Server')
+ def test_train_with_train_hooks(self, unused_mock_server, unused_mock_sleep):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.config = self._run_config
+ train_spec = training.TrainSpec(
+ input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
+ extra_hooks = [_FakeHook()]
+
+ executor = training._TrainingExecutor(
+ mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)
+ self._run_task(executor)
+
+ mock_est.train.assert_called_with(
+ input_fn=train_spec.input_fn,
+ max_steps=train_spec.max_steps,
+ hooks=list(train_spec.hooks) + extra_hooks,
+ saving_listeners=test.mock.ANY)
+
+ @test.mock.patch.object(time, 'sleep')
+ @test.mock.patch.object(server_lib, 'Server')
def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.config = self._run_config
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
executor = training._TrainingExecutor(mock_est, mock_train_spec,
@@ -614,7 +659,7 @@ class _TrainingExecutorTrainingTest(object):
def test_single_worker_node_with_empty_tf_master(
self, mock_server, unused_mock_sleep):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
@@ -676,7 +721,7 @@ class TrainingExecutorRunWorkerTest(_TrainingExecutorTrainingTest,
def test_delay_for_worker(self, _):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.config = self._run_config
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
executor = training._TrainingExecutor(mock_est, mock_train_spec,
@@ -703,7 +748,7 @@ class TrainingExecutorRunChiefTest(_TrainingExecutorTrainingTest,
def test_no_delay_for_chief(self, _):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.config = self._run_config
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
executor = training._TrainingExecutor(mock_est, mock_train_spec,
@@ -726,7 +771,8 @@ class TrainingExecutorRunMasterTest(test.TestCase):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
mock_est.config = self._run_config
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec, max_steps=123)
+ mock_train_spec = test.mock.Mock(
+ spec=training.TrainSpec, max_steps=123, hooks=[])
mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
executor = training._TrainingExecutor(mock_est, mock_train_spec,
@@ -759,19 +805,42 @@ class TrainingExecutorRunMasterTest(test.TestCase):
self.assertTrue(mock_server_instance.start.called)
- mock_est.train.assert_called_with(input_fn=train_spec.input_fn,
- max_steps=train_spec.max_steps,
- hooks=train_spec.hooks,
- saving_listeners=test.mock.ANY)
+ mock_est.train.assert_called_with(
+ input_fn=train_spec.input_fn,
+ max_steps=train_spec.max_steps,
+ hooks=list(train_spec.hooks),
+ saving_listeners=test.mock.ANY)
mock_est.export_savedmodel.assert_not_called()
@test.mock.patch.object(time, 'sleep')
@test.mock.patch.object(server_lib, 'Server')
+ def test_train_with_train_hooks(self, mock_server, unused_mock_sleep):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
+ mock_est.config = self._run_config
+ train_spec = training.TrainSpec(
+ input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
+ extra_hooks = [_FakeHook()]
+
+ executor = training._TrainingExecutor(
+ mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)
+ executor.run_master()
+
+ mock_est.train.assert_called_with(
+ input_fn=train_spec.input_fn,
+ max_steps=train_spec.max_steps,
+ hooks=list(train_spec.hooks) + extra_hooks,
+ saving_listeners=test.mock.ANY)
+
+ @test.mock.patch.object(time, 'sleep')
+ @test.mock.patch.object(server_lib, 'Server')
def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
mock_est.config = self._run_config
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec, max_steps=123)
+ mock_train_spec = test.mock.Mock(
+ spec=training.TrainSpec, max_steps=123, hooks=[])
mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
executor = training._TrainingExecutor(mock_est, mock_train_spec,
@@ -821,7 +890,8 @@ class TrainingExecutorRunMasterTest(test.TestCase):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec, max_steps=123)
+ mock_train_spec = test.mock.Mock(
+ spec=training.TrainSpec, max_steps=123, hooks=[])
mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
@@ -1039,6 +1109,28 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
hooks=eval_spec.hooks)
self.assertFalse(mock_est.train.called)
+ def test_evaluate_with_train_hooks(self):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.latest_checkpoint.return_value = 'latest_it_is'
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)
+
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: 1,
+ steps=2,
+ hooks=[_FakeHook()],
+ name='cont_eval',
+ start_delay_secs=0,
+ throttle_secs=0)
+
+ # The train_hooks will not be called during eval.
+ mock_hook = test.mock.Mock(spec=session_run_hook.SessionRunHook)
+ executor = training._TrainingExecutor(
+ mock_est, mock_train_spec, eval_spec, train_hooks=[mock_hook])
+ executor.run_evaluator()
+
+ mock_hook.begin.assert_not_called()
+
def test_evaluate_multiple_times(self):
training_max_step = 200
@@ -1095,7 +1187,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
}]
mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
mock_train_spec.max_steps = training_max_step
class _Listener(training._ContinuousEvalListener):
@@ -1112,8 +1204,9 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
eval_spec = training.EvalSpec(
input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
- training._TrainingExecutor(mock_est, mock_train_spec, eval_spec,
- listener).run_evaluator()
+ training._TrainingExecutor(
+ mock_est, mock_train_spec, eval_spec,
+ continuous_eval_listener=listener).run_evaluator()
# Before_eval returns False during the second time, so, evaluate will be
# called once.
@@ -1152,8 +1245,9 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
eval_spec = training.EvalSpec(
input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
- training._TrainingExecutor(mock_est, mock_train_spec, eval_spec,
- listener).run_evaluator()
+ training._TrainingExecutor(
+ mock_est, mock_train_spec, eval_spec,
+ continuous_eval_listener=listener).run_evaluator()
# after_eval returns False during the first time, so, evaluate will be
# called once.
@@ -1280,8 +1374,11 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
eval_spec = training.EvalSpec(
input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
- executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec,
- continuous_eval_listener)
+ executor = training._TrainingExecutor(
+ mock_est,
+ mock_train_spec,
+ eval_spec,
+ continuous_eval_listener=continuous_eval_listener)
executor.run_evaluator()
# Three checkpoint paths are invalid.
@@ -1667,6 +1764,26 @@ class TrainingExecutorRunLocalTest(test.TestCase):
self.assertEqual(train_spec.input_fn, train_args['input_fn'])
self.assertEqual(train_spec.max_steps, train_args['max_steps'])
+ def test_train_hooks(self):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
+ mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
+ train_spec = training.TrainSpec(
+ input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
+ eval_spec = training.EvalSpec(input_fn=lambda: 1, steps=2)
+ mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
+ extra_hooks = [_FakeHook()]
+
+ executor = training._TrainingExecutor(
+ mock_est, train_spec, eval_spec, train_hooks=extra_hooks)
+ executor.run_local()
+
+ train_args = mock_est.train.call_args[1]
+ self.assertEqual(
+ list(train_spec.hooks) + extra_hooks, [
+ h for h in train_args['hooks']
+ if not isinstance(h, training._StopAtSecsHook)
+ ])
+
def test_errors_out_if_throttle_secs_is_zero(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
train_spec = training.TrainSpec(input_fn=lambda: 1)
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index bb4f6388fd..72e3520272 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -823,7 +823,7 @@ class LayoutOptimizerTest(test.TestCase):
for node in optimized_graph.node:
if node.op in ['Conv2D', 'Conv2DBackpropFilter', 'Conv2DBackpropInput']:
found += 1
- self.assertEqual(node.attr['data_format'].s, 'NCHW')
+ self.assertEqual(node.attr['data_format'].s, b'NCHW')
self.assertEqual(found, 5)
def testDepthwise(self):
@@ -840,7 +840,7 @@ class LayoutOptimizerTest(test.TestCase):
'DepthwiseConv2dNativeBackpropInput'
]:
found += 1
- self.assertEqual(node.attr['data_format'].s, 'NCHW')
+ self.assertEqual(node.attr['data_format'].s, b'NCHW')
self.assertEqual(found, 6)
def testCheckpointCompatibility(self):
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 26ce0a6518..acbbb21322 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -634,8 +634,7 @@ class Layer(object):
except AttributeError:
pass
input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
- with ops.init_scope():
- self.build(input_shapes)
+ self.build(input_shapes)
try:
# Note: not all sub-classes of Layer call Layer.__init__ (especially
# the ones under tensorflow/python/keras). Hence we recompute this
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index e6070e8123..2135f6dd01 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -1565,9 +1565,9 @@ def tf_version_info_genrule():
native.genrule(
name="version_info_gen",
srcs=[
- clean_dep("//tensorflow/tools/git:gen/spec.json"),
- clean_dep("//tensorflow/tools/git:gen/head"),
- clean_dep("//tensorflow/tools/git:gen/branch_ref"),
+ clean_dep("@local_config_git//:gen/spec.json"),
+ clean_dep("@local_config_git//:gen/head"),
+ clean_dep("@local_config_git//:gen/branch_ref"),
],
outs=["util/version_info.cc"],
cmd=
diff --git a/tensorflow/tools/git/BUILD b/tensorflow/tools/git/BUILD
index f502c8dde0..942ceab85f 100644
--- a/tensorflow/tools/git/BUILD
+++ b/tensorflow/tools/git/BUILD
@@ -7,9 +7,7 @@ package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(
- glob(["gen/*"]) + [
- "gen_git_source.py",
- ],
+ ["gen_git_source.py"],
)
# -----------------------------------------------------------------------------
diff --git a/tensorflow/tools/git/gen/branch_ref b/tensorflow/tools/git/gen/branch_ref
deleted file mode 100644
index 8b13789179..0000000000
--- a/tensorflow/tools/git/gen/branch_ref
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/tensorflow/tools/git/gen/head b/tensorflow/tools/git/gen/head
deleted file mode 100644
index 8b13789179..0000000000
--- a/tensorflow/tools/git/gen/head
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/tensorflow/tools/git/gen/spec.json b/tensorflow/tools/git/gen/spec.json
deleted file mode 100644
index 176bbc21cc..0000000000
--- a/tensorflow/tools/git/gen/spec.json
+++ /dev/null
@@ -1,3 +0,0 @@
-{
- "git": false
-}
diff --git a/tensorflow/tools/git/gen_git_source.py b/tensorflow/tools/git/gen_git_source.py
index 2e27487d2f..3630dbd740 100755
--- a/tensorflow/tools/git/gen_git_source.py
+++ b/tensorflow/tools/git/gen_git_source.py
@@ -62,7 +62,7 @@ def parse_branch_ref(filename):
raise RuntimeError("Git directory has unparseable HEAD")
-def configure(src_base_path, debug=False):
+def configure(src_base_path, gen_path, debug=False):
"""Configure `src_base_path` to embed git hashes if available."""
# TODO(aselle): No files generated or symlinked here are deleted by
@@ -71,7 +71,6 @@ def configure(src_base_path, debug=False):
# without running ./configure again.
git_path = os.path.join(src_base_path, ".git")
- gen_path = os.path.join(src_base_path, "tensorflow", "tools", "git", "gen")
# Remove and recreate the path
if os.path.exists(gen_path):
@@ -261,6 +260,10 @@ parser.add_argument(
help="Path to configure as a git repo dependency tracking sentinel")
parser.add_argument(
+ "--gen_root_path", type=str,
+ help="Root path to place generated git files (created by --configure).")
+
+parser.add_argument(
"--generate",
type=str,
help="Generate given spec-file, HEAD-symlink-file, ref-symlink-file",
@@ -274,7 +277,9 @@ parser.add_argument(
args = parser.parse_args()
if args.configure is not None:
- configure(args.configure, debug=args.debug)
+ if args.gen_root_path is None:
+ raise RuntimeError("Must pass --gen_root_path arg when running --configure")
+ configure(args.configure, args.gen_root_path, debug=args.debug)
elif args.generate is not None:
generate(args.generate)
elif args.raw_generate is not None:
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 6a496f53f0..6571d9c463 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -2,6 +2,7 @@
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
+load("//third_party/git:git_configure.bzl", "git_configure")
load("//third_party/py:python_configure.bzl", "python_configure")
load("//third_party/sycl:sycl_configure.bzl", "sycl_configure")
load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", "arm_compiler_configure")
@@ -47,6 +48,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
# version we require here.
check_version("0.5.4")
cuda_configure(name="local_config_cuda")
+ git_configure(name="local_config_git")
sycl_configure(name="local_config_sycl")
python_configure(name="local_config_python")
diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md
index c00d8d9015..6aa95ccce9 100644
--- a/third_party/examples/eager/spinn/README.md
+++ b/third_party/examples/eager/spinn/README.md
@@ -15,6 +15,8 @@ https://github.com/jekbradbury/examples/blob/spinn/snli/spinn.py,
which was released under the BSD 3-Clause License at:
https://github.com/jekbradbury/examples/blob/spinn/LICENSE
+Other eager execution examples can be found under [tensorflow/contrib/eager/python/examples](../../../../tensorflow/contrib/eager/python/examples).
+
## Content
Python source file(s):
diff --git a/third_party/git/BUILD b/third_party/git/BUILD
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/third_party/git/BUILD
diff --git a/third_party/git/BUILD.tpl b/third_party/git/BUILD.tpl
new file mode 100644
index 0000000000..7b031e74d5
--- /dev/null
+++ b/third_party/git/BUILD.tpl
@@ -0,0 +1,10 @@
+# Description:
+# Exports generated files used to generate tensorflow/core/util/version_info.cc
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])
+
+exports_files(
+ glob(["gen/*"]),
+)
diff --git a/third_party/git/git_configure.bzl b/third_party/git/git_configure.bzl
new file mode 100644
index 0000000000..bd197bfd24
--- /dev/null
+++ b/third_party/git/git_configure.bzl
@@ -0,0 +1,20 @@
+"""Repository rule for Git autoconfiguration."""
+
+def _git_conf_impl(repository_ctx):
+ repository_ctx.template(
+ "BUILD",
+ Label("//third_party/git:BUILD.tpl"))
+
+ tensorflow_root_path = str(repository_ctx.path(
+ Label("@org_tensorflow//:BUILD")))[:-len("BUILD")]
+ python_script_path = repository_ctx.path(
+ Label("@org_tensorflow//tensorflow/tools/git:gen_git_source.py"))
+ generated_files_path = repository_ctx.path("gen")
+
+ repository_ctx.execute([
+ python_script_path, "--configure", tensorflow_root_path,
+ "--gen_root_path", generated_files_path], quiet=False)
+
+git_configure = repository_rule(
+ implementation = _git_conf_impl,
+)