aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Frank Chen <frankchn@gmail.com>2018-03-13 15:33:30 -0700
committerGravatar GitHub <noreply@github.com>2018-03-13 15:33:30 -0700
commit0c18589a71dba03edbaaca67a40815006ed26043 (patch)
treed399267fa8e45a4ab4b60c79752e4f60807a46aa
parent9dc58c6c5c28766a52152a7df865ca20dcb434a7 (diff)
parent4780f9e24874903b4b533efcfa0042cc8e9b5e44 (diff)
Merge pull request #17688 from frankchn/branch_188893722
Branch 188893722
-rw-r--r--SECURITY.md18
-rw-r--r--tensorflow/BUILD1
-rw-r--r--tensorflow/c/c_api.cc38
-rw-r--r--tensorflow/c/c_api_internal.h12
-rw-r--r--tensorflow/c/eager/BUILD1
-rw-r--r--tensorflow/c/eager/c_api.cc824
-rw-r--r--tensorflow/c/eager/c_api.h58
-rw-r--r--tensorflow/c/eager/c_api_internal.h206
-rw-r--r--tensorflow/c/eager/c_api_test.cc380
-rw-r--r--tensorflow/c/eager/runtime.h3
-rw-r--r--tensorflow/c/python_api.cc3
-rw-r--r--tensorflow/cc/framework/while_gradients.cc6
-rw-r--r--tensorflow/cc/gradients/nn_grad.cc12
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.cc30
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.h10
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc31
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc25
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc13
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc74
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.h10
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc4
-rw-r--r--tensorflow/compiler/xla/service/service.cc2
-rw-r--r--tensorflow/contrib/bayesflow/BUILD20
-rw-r--r--tensorflow/contrib/bayesflow/__init__.py2
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py157
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/custom_grad.py34
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py110
-rw-r--r--tensorflow/contrib/boosted_trees/ops/quantile_ops.cc14
-rw-r--r--tensorflow/contrib/distributions/BUILD38
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py153
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py263
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py47
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py3
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py58
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/__init__.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/affine.py29
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py138
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py40
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/square.py84
-rw-r--r--tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py2
-rw-r--r--tensorflow/contrib/eager/python/datasets.py128
-rw-r--r--tensorflow/contrib/eager/python/datasets_test.py21
-rw-r--r--tensorflow/contrib/feature_column/BUILD2
-rw-r--r--tensorflow/contrib/feature_column/__init__.py6
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py259
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py194
-rw-r--r--tensorflow/contrib/image/kernels/segmentation_ops.cc4
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py48
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py71
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py320
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py57
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/utils_test.py80
-rw-r--r--tensorflow/contrib/kfac/python/ops/estimator.py13
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks.py349
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py139
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection.py233
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection_lib.py2
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils.py122
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils_lib.py3
-rw-r--r--tensorflow/contrib/lite/context.h2
-rw-r--r--tensorflow/contrib/lite/interpreter.cc115
-rw-r--r--tensorflow/contrib/lite/interpreter.h22
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc18
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.cc4
-rw-r--r--tensorflow/contrib/lite/python/BUILD32
-rw-r--r--tensorflow/contrib/lite/python/interpreter.py135
-rw-r--r--tensorflow/contrib/lite/python/interpreter_test.py82
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/BUILD32
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc313
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h72
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i25
-rw-r--r--tensorflow/contrib/lite/python/lite.py3
-rw-r--r--tensorflow/contrib/lite/python/op_hint.py6
-rw-r--r--tensorflow/contrib/lite/schema/BUILD8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc5
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc54
-rw-r--r--tensorflow/contrib/lite/util.cc16
-rw-r--r--tensorflow/contrib/lite/util.h8
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc6
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc6
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h1
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head.py12
-rw-r--r--tensorflow/core/common_runtime/build_graph_options.cc6
-rw-r--r--tensorflow/core/common_runtime/build_graph_options.h11
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc27
-rw-r--r--tensorflow/core/common_runtime/eval_const_tensor.cc19
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.cc27
-rw-r--r--tensorflow/core/common_runtime/memory_types.cc4
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc2
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc83
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc4
-rw-r--r--tensorflow/core/graph/subgraph.cc19
-rw-r--r--tensorflow/core/graph/subgraph.h6
-rw-r--r--tensorflow/core/grappler/op_types.cc2
-rw-r--r--tensorflow/core/grappler/op_types.h1
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc461
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h1
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc212
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc321
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h2
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc100
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer.h2
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc2
-rw-r--r--tensorflow/core/kernels/data_format_ops.cc36
-rw-r--r--tensorflow/core/kernels/data_format_ops.h79
-rw-r--r--tensorflow/core/kernels/mutex_ops.cc12
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc2
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.h7
-rw-r--r--tensorflow/core/kernels/sparse_cross_op.cc2
-rw-r--r--tensorflow/core/kernels/split_v_op.cc8
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt177
-rw-r--r--tensorflow/core/ops/ops.pbtxt177
-rw-r--r--tensorflow/docs_src/programmers_guide/version_compat.md42
-rw-r--r--tensorflow/go/op/wrappers.go346
-rw-r--r--tensorflow/python/data/ops/BUILD2
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py25
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py148
-rw-r--r--tensorflow/python/debug/README.md6
-rw-r--r--tensorflow/python/eager/core_test.py24
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc6
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc9
-rw-r--r--tensorflow/python/estimator/replicate_model_fn.py55
-rw-r--r--tensorflow/python/estimator/replicate_model_fn_test.py164
-rw-r--r--tensorflow/python/feature_column/feature_column.py163
-rw-r--r--tensorflow/python/framework/framework_lib.py1
-rw-r--r--tensorflow/python/framework/ops.py2
-rw-r--r--tensorflow/python/framework/smart_cond_test.py18
-rw-r--r--tensorflow/python/framework/tensor_spec.py20
-rw-r--r--tensorflow/python/framework/tensor_spec_test.py21
-rw-r--r--tensorflow/python/framework/test_util.py26
-rw-r--r--tensorflow/python/framework/test_util_test.py20
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py6
-rw-r--r--tensorflow/python/layers/core_test.py23
-rw-r--r--tensorflow/python/lib/core/py_func.cc16
-rw-r--r--tensorflow/python/lib/core/py_seq_tensor.cc9
-rw-r--r--tensorflow/python/ops/math_ops.py2
-rw-r--r--tensorflow/python/ops/nn_test.py32
-rw-r--r--tensorflow/python/tools/BUILD1
-rwxr-xr-xtensorflow/tools/integration_tests/gcs_smoke_test/BUILD67
-rwxr-xr-xtensorflow/tools/integration_tests/gcs_smoke_test/gcs_smoke.py253
-rwxr-xr-xtensorflow/tools/integration_tests/gcs_smoke_test/setup.sh20
-rwxr-xr-xtensorflow/tools/integration_tests/gcs_smoke_test/teardown.sh26
-rwxr-xr-xtensorflow/tools/integration_tests/gcs_smoke_test/test_wrapper.sh21
-rw-r--r--tensorflow/tools/pip_package/BUILD1
-rw-r--r--tensorflow/tools/pip_package/pip_smoke_test.py4
-rw-r--r--tensorflow/workspace.bzl18
153 files changed, 6475 insertions, 2945 deletions
diff --git a/SECURITY.md b/SECURITY.md
index 665a480ba7..2aaa9202d5 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -6,7 +6,7 @@ report vulnerabilities in TensorFlow.
## TensorFlow models are programs
-TensorFlow's runtime system interprets and executes programs. What machine
+TensorFlow's runtime system interprets and executes programs. What machine
learning practitioners term
[**models**](https://developers.google.com/machine-learning/glossary/#model) are
expressed as programs that TensorFlow executes. TensorFlow programs are encoded
@@ -28,12 +28,12 @@ data you supply to TensorFlow to train a model, or to use a model to run
inference on the data.
**TensorFlow models are programs, and need to be treated as such from a security
-perspective.**
+perspective.**
## Running untrusted models
As a general rule: **Always** execute untrusted models inside a sandbox (e.g.,
-[nsjail](https://github.com/google/nsjail)).
+[nsjail](https://github.com/google/nsjail)).
There are several ways in which a model could become untrusted. Obviously, if an
untrusted party supplies TensorFlow kernels, arbitrary code may be executed.
@@ -109,7 +109,7 @@ graphs known to the `ModelServer`. This means that an attacker may run
graphs using untrusted inputs as described above, but they would not be able to
execute arbitrary graphs. It is possible to safely expose a `ModelServer`
directly to an untrusted network, **but only if the graphs it is configured to
-use have been carefully audited to be safe**.
+use have been carefully audited to be safe**.
Similar to best practices for other servers, we recommend running any
`ModelServer` with appropriate privileges (i.e., using a separate user with
@@ -133,7 +133,7 @@ which exhibit unexpected or unwanted behaviors. The fact that TensorFlow models
can perform arbitrary computations means that they may read and write files,
communicate via the network, produce deadlocks and infinite loops, or run out
of memory. It is only when these behaviors are outside the specifications of the
-operations involved that such behavior is a vulnerability.
+operations involved that such behavior is a vulnerability.
A `FileWriter` writing a file is not unexpected behavior and therefore is not a
vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution
@@ -168,7 +168,7 @@ below).
Please use a descriptive subject line for your report email. After the initial
reply to your report, the security team will endeavor to keep you informed of
-the progress being made towards a fix and announcement.
+the progress being made towards a fix and announcement.
If you believe that an existing (public) issue is security-related, please send
an email to `security@tensorflow.org`. The email should include the issue ID and
@@ -233,7 +233,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
### Known vulnerabilities
-| Type | Versions affected | Reported by | Additional Information |
-|-------------------|:-----------------:|-----------------------|-----------------------------|
-| out of bounds read| <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
+| Type | Versions affected | Reported by | Additional Information |
+|-------------------|:-----------------:|--------------------|-----------------------------|
+| out of bounds read| <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index d152281d5d..3828ee0ddb 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -682,7 +682,6 @@ filegroup(
"//tensorflow/tools/docs:all_files",
"//tensorflow/tools/git:all_files",
"//tensorflow/tools/graph_transforms:all_files",
- "//tensorflow/tools/integration_tests/gcs_smoke_test:all_files",
"//tensorflow/tools/mlpbtxt:all_files",
"//tensorflow/tools/proto_text:all_files",
"//tensorflow/tools/quantization:all_files",
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 8b9b3da21c..778cb667e2 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -63,6 +63,7 @@ limitations under the License.
// brain namespace because we are defining 'extern "C"' functions.
using tensorflow::AllocationDescription;
using tensorflow::DataType;
+using tensorflow::ExtendSessionGraphHelper;
using tensorflow::Graph;
using tensorflow::GraphDef;
using tensorflow::mutex_lock;
@@ -640,11 +641,11 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
}
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
- const char* mutation_type)
- EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
+ const char* mutation_type) {
// If any session has already run this node_id, mark this session as
// unrunnable.
for (auto it : graph->sessions) {
+ mutex_lock session_lock(it.first->mu);
if (it.first->last_num_graph_nodes > op.node.id()) {
it.second = FailedPrecondition(
"Operation '", op.node.DebugString(), "' was changed by ",
@@ -713,10 +714,12 @@ Status LoadLibrary(const char* library_filename, void** result,
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
-bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
- EXCLUSIVE_LOCKS_REQUIRED(session->mu) {
+bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
if (session->graph != nullptr) {
+ // Take the graph lock before the session lock to avoid deadlock. This is
+ // safe since session->graph does not change.
session->graph->mu.lock();
+ mutex_lock session_lock(session->mu);
const Graph& graph = session->graph->graph;
status->status = session->graph->sessions[session];
@@ -2571,12 +2574,9 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
- {
- mutex_lock l(session->mu);
- if (session->extend_before_run &&
- !tensorflow::ExtendSessionGraphHelper(session, status)) {
- return;
- }
+ if (session->extend_before_run &&
+ !ExtendSessionGraphHelper(session, status)) {
+ return;
}
TF_Run_Setup(noutputs, output_values, status);
@@ -2612,12 +2612,9 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
const char** handle, TF_Status* status) {
*handle = nullptr;
- {
- mutex_lock l(session->mu);
- if (session->extend_before_run &&
- !tensorflow::ExtendSessionGraphHelper(session, status)) {
- return;
- }
+ if (session->extend_before_run &&
+ !ExtendSessionGraphHelper(session, status)) {
+ return;
}
std::vector<string> input_names(ninputs);
@@ -2659,12 +2656,9 @@ void TF_SessionPRun(TF_Session* session, const char* handle,
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
- {
- mutex_lock l(session->mu);
- if (session->extend_before_run &&
- !tensorflow::ExtendSessionGraphHelper(session, status)) {
- return;
- }
+ if (session->extend_before_run &&
+ !ExtendSessionGraphHelper(session, status)) {
+ return;
}
TF_Run_Setup(noutputs, output_values, status);
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index 25233931de..e885a69927 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -124,16 +124,16 @@ struct TF_Session {
TF_Session(tensorflow::Session* s, TF_Graph* g);
tensorflow::Session* session;
- TF_Graph* graph;
+ TF_Graph* const graph;
- tensorflow::mutex mu;
+ tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu);
int last_num_graph_nodes;
// If true, TF_SessionRun and similar methods will call
// ExtendSessionGraphHelper before running the graph (this is the default
// public behavior). Can be set to false if the caller needs to call
// ExtendSessionGraphHelper manually.
- bool extend_before_run GUARDED_BY(mu);
+ std::atomic<bool> extend_before_run;
};
struct TF_ImportGraphDefOptions {
@@ -211,9 +211,11 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
TF_Status* status);
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
- const char* mutation_type);
+ const char* mutation_type)
+ EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
-bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status);
+bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
+ LOCKS_EXCLUDED(session->graph->mu, session->mu);
} // end namespace tensorflow
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index e55cb672e9..3046d9064a 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -58,6 +58,7 @@ tf_cuda_library(
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:framework_lite",
+ "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index b9a47ea244..56cec2d668 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
@@ -67,6 +68,7 @@ string DeviceName(const tensorflow::Device* d) {
#ifdef TENSORFLOW_EAGER_USE_XLA
std::atomic_int_fast64_t func_id_generator(0);
#endif // TENSORFLOW_EAGER_USE_XLA
+
} // namespace
TFE_ContextDevicePlacementPolicy PlacementPolicy(
@@ -90,11 +92,33 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
TF_SetConfig(&options->session_options, proto, proto_len, status);
}
+void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
+ unsigned char async) {
+ options->async = async;
+}
void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
options->policy = policy;
}
+TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
+ unsigned char async,
+ TF_Status* status) {
+ {
+ tensorflow::mutex_lock l(ctx->async_map_mu);
+ ctx->thread_local_async[std::this_thread::get_id()] = async;
+ }
+ if (async) {
+ ctx->executor.EnableAsync();
+ } else {
+ // TODO(agarwal): Currently we add a wait here to handle cases where a sync
+ // op has a control dependency on an async op, and the latter has not
+ // executed yet. This wait can be removed by storing all the control inputs
+ // and waiting for them when executing ops.
+ status->status = ctx->executor.WaitForAllPendingNodes();
+ }
+}
+
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
@@ -113,7 +137,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
}
void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
- status->status = tensorflow::Status::OK();
+ status->status = ctx->executor.WaitForAllPendingNodes();
{
tensorflow::mutex_lock ml(ctx->cache_mu);
tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
@@ -139,6 +163,9 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
ctx->thread_local_policies[std::this_thread::get_id()] = policy;
}
+// Note: this function looks up a thread local policy. So it should be called in
+// the appropriate client thread. In particular, in async mode, it may not be
+// safe to call this function from the async TFE_Executor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) {
tensorflow::mutex_lock ml(ctx->policy_map_mu);
@@ -150,6 +177,18 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
return ctx->policy;
}
+void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
+ status->status = ctx->executor.WaitForAllPendingNodes();
+}
+
+void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
+ status->status = ctx->executor.status();
+}
+
+void TFE_ContextAsyncClearError(TFE_Context* ctx) {
+ ctx->executor.ClearError();
+}
+
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
@@ -157,56 +196,70 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
return new TFE_TensorHandle(tensor, nullptr, nullptr);
}
-void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }
+void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
+ DCHECK(h);
+ h->Unref();
+}
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
- return static_cast<TF_DataType>(h->t.dtype());
+ return static_cast<TF_DataType>(h->dtype);
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
- status->status = tensorflow::Status::OK();
- return h->t.dims();
+ const tensorflow::Tensor* t = nullptr;
+ status->status = h->Tensor(&t);
+ return t == nullptr ? 0 : t->dims();
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
- status->status = tensorflow::Status::OK();
- return h->t.dim_size(dim_index);
+ const tensorflow::Tensor* t = nullptr;
+ status->status = h->Tensor(&t);
+ return t == nullptr ? 0 : t->dim_size(dim_index);
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
- status->status = tensorflow::Status::OK();
- return (h->op_device == nullptr)
- ? "/job:localhost/replica:0/task:0/device:CPU:0"
- : h->op_device->name().c_str();
+ tensorflow::Device* d = nullptr;
+ status->status = h->OpDevice(&d);
+ return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
+ : d->name().c_str();
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
- if (!IsCPU(h->d)) {
+ // TODO(agarwal): move this implementation inside TFE_TensorHandle.
+ tensorflow::Device* d = nullptr;
+ tensorflow::Device* op_device = nullptr;
+ const tensorflow::Tensor* t = nullptr;
+ status->status = h->TensorAndDevice(&t, &d, &op_device);
+ if (!status->status.ok()) return nullptr;
+ if (!IsCPU(d)) {
TF_SetStatus(status, TF_UNIMPLEMENTED,
tensorflow::strings::StrCat(
"TFE_TensorHandle can be resolved iff it is on CPU (this "
"handle is on ",
- h->d->name(),
+ d->name(),
"). Consider using TFE_TensorHandleCopyToDevice to get a "
"copy of the tensor on CPU")
.c_str());
return nullptr;
}
- return tensorflow::TF_TensorFromTensor(h->t, status);
+ return tensorflow::TF_TensorFromTensor(*t, status);
}
+} // extern "C"
-TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
- TFE_Context* ctx,
- const char* device_name,
- TF_Status* status) {
- tensorflow::Device* dstd = ctx->devices[0];
- if (device_name != nullptr && strlen(device_name) > 0) {
- status->status = ctx->device_manager->LookupDevice(device_name, &dstd);
- if (!status->status.ok()) return nullptr;
- }
+namespace {
- tensorflow::Device* srcd = h->d == nullptr ? ctx->devices[0] : h->d;
+tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h,
+ TFE_Context* ctx,
+ tensorflow::Device* dstd,
+ TFE_TensorHandle** output) {
+ const tensorflow::Tensor* src = nullptr;
+ tensorflow::Device* srcd = nullptr;
+ // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept
+ // nullptr.
+ tensorflow::Device* src_opd = nullptr;
+ TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd));
+ if (srcd == nullptr) srcd = ctx->devices[0];
bool is_same_device =
(srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
const bool dst_cpu = IsCPU(dstd);
@@ -216,18 +269,15 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
const bool both_on_cpu = src_cpu && dst_cpu;
if (is_same_device || both_on_cpu) {
dstd = dst_cpu ? nullptr : dstd;
- return new TFE_TensorHandle(h->t, dstd, dstd);
+ *output = new TFE_TensorHandle(*src, dstd, dstd);
+ return tensorflow::Status::OK();
}
- tensorflow::Tensor* src = &(h->t);
if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
!tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
- TF_SetStatus(
- status, TF_INVALID_ARGUMENT,
- tensorflow::strings::StrCat("Can't copy Tensor with type ",
- tensorflow::DataTypeString(src->dtype()),
- " to device ", DeviceName(dstd), ".")
- .c_str());
- return nullptr;
+ return tensorflow::errors::InvalidArgument(
+ "Can't copy Tensor with type ",
+ tensorflow::DataTypeString(src->dtype()), " to device ",
+ DeviceName(dstd), ".");
}
tensorflow::AllocatorAttributes attr;
if (src->dtype() == tensorflow::DT_VARIANT) {
@@ -236,7 +286,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
if (src->shape().num_elements() == 0) {
dstd = dst_cpu ? nullptr : dstd;
- return new TFE_TensorHandle(dst, dstd, dstd);
+ *output = new TFE_TensorHandle(dst, dstd, dstd);
+ return tensorflow::Status::OK();
}
tensorflow::DeviceContext* src_device_context = nullptr;
if (!src_cpu) {
@@ -253,21 +304,26 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
// With that setup, Sync()ing across all 3 streams should be sufficient
// but more than necessary (since it waits for operations that might have
// nothing to do with this tensor to complete).
- status->status = srcd->Sync();
+ TF_RETURN_IF_ERROR(srcd->Sync());
tensorflow::Notification n;
+ tensorflow::Status status;
tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
srcd, dstd, tensorflow::AllocatorAttributes(),
tensorflow::AllocatorAttributes(), src, &dst,
- [status, &n](const tensorflow::Status& s) {
- status->status = s;
+ [&status, &n](const tensorflow::Status& s) {
+ status = s;
n.Notify();
});
n.WaitForNotification();
- return (TF_GetCode(status) == TF_OK)
- ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd,
- dst_cpu ? nullptr : dstd)
- : nullptr;
+ if (status.ok()) {
+ dstd = dst_cpu ? nullptr : dstd;
+ *output = new TFE_TensorHandle(dst, dstd, dstd);
+ }
+ return status;
}
+} // namespace
+
+extern "C" {
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
@@ -311,16 +367,19 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
- // Questionable heuristic ...
- // - If a device was explicitly set on the op, always use that.
- // - If not, place on the first non-host device seen.
- if (op->device == nullptr && !IsCPU(h->d)) {
- op->device = h->d;
+ if (op->device == nullptr) {
+ // Questionable heuristic ...
+ // - If a device was explicitly set on the op, always use that.
+ // - If not, place on the first non-host device seen.
+ tensorflow::Device* d = nullptr;
+ // TODO(agarwal): This call may block if h is not ready. Avoid this if
+ // possible.
+ status->status = h->Device(&d);
+ if (!status->status.ok()) return;
+ if (!IsCPU(d)) op->device = d;
}
- if (!status->status.ok()) return;
- op->inputs.push_back(h->t);
- op->input_devices.push_back(h->d);
- op->input_op_devices.push_back(h->op_device);
+ h->Ref();
+ op->inputs.push_back(h);
op->attrs.NumInputs(op->inputs.size());
}
@@ -482,14 +541,14 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
funcs.get(), num_values));
}
+} // extern "C"
namespace {
tensorflow::Status ValidateInputTypeAndPlacement(
TFE_Context* ctx, tensorflow::Device* host_device,
tensorflow::Device* op_device, TFE_Op* op,
- const tensorflow::OpKernel* kernel,
- std::vector<TFE_TensorHandle*>* copied_tensors) {
+ const tensorflow::OpKernel* kernel) {
const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
if (memtypes.size() != op->inputs.size()) {
return tensorflow::errors::InvalidArgument(
@@ -498,14 +557,17 @@ tensorflow::Status ValidateInputTypeAndPlacement(
for (int i = 0; i < op->inputs.size(); ++i) {
const tensorflow::Device* expected_device =
memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
+ TFE_TensorHandle* handle = op->inputs[i];
+ tensorflow::Device* handle_device = nullptr;
+ TF_RETURN_IF_ERROR(handle->Device(&handle_device));
const tensorflow::Device* actual_device =
- op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
+ handle_device == nullptr ? host_device : handle_device;
if (expected_device != actual_device) {
switch (TFE_ContextGetDevicePlacementPolicy(ctx)) {
case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32:
// TODO(xpan): See if we could bubble python related error up
// to python level.
- if (op->inputs[i].dtype() == tensorflow::DT_INT32) {
+ if (handle->dtype == tensorflow::DT_INT32) {
// Note: enabling silent copies of int32 tensors to match behavior
// of graph mode.
break;
@@ -536,36 +598,245 @@ tensorflow::Status ValidateInputTypeAndPlacement(
}
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
- TFE_TensorHandle original{op->inputs[i], op->input_devices[i],
- op->device};
TF_Status* s = TF_NewStatus();
TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
- &original, ctx, expected_device->name().c_str(), s);
- if (!s->status.ok()) {
- tensorflow::Status status = s->status;
- delete s;
+ handle, ctx, expected_device->name().c_str(), s);
+ tensorflow::Status status = s->status;
+ TF_DeleteStatus(s);
+ if (!status.ok()) {
+ if (copied_tensor != nullptr) copied_tensor->Unref();
return tensorflow::errors::Internal(
"Failed copying input tensor from ", actual_device->name(), " to ",
expected_device->name(), " in order to run ", op->name, ": ",
status.error_message());
}
- op->inputs[i] = copied_tensor->t;
- copied_tensors->push_back(copied_tensor);
- op->input_devices[i] = copied_tensor->d;
- delete s;
+ handle->Unref();
+ handle = copied_tensor;
+ op->inputs[i] = copied_tensor;
}
- if (op->inputs[i].dtype() != kernel->input_type(i)) {
+ if (handle->dtype != kernel->input_type(i)) {
return tensorflow::errors::InvalidArgument(
"cannot compute ", op->name, " as input #", i,
" was expected to be a ",
tensorflow::DataTypeString(kernel->input_type(i)),
- " tensor but is a ",
- tensorflow::DataTypeString(op->inputs[i].dtype()), " tensor");
+ " tensor but is a ", tensorflow::DataTypeString(handle->dtype),
+ " tensor");
}
}
return tensorflow::Status::OK();
}
+tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
+ TFE_Context* ctx, TF_Status* status) {
+ tensorflow::DeviceSet ds;
+ for (tensorflow::Device* d : ctx->devices) {
+ ds.AddDevice(d);
+ }
+ tensorflow::DeviceTypeVector final_devices;
+ status->status = tensorflow::SupportedDeviceTypesForNode(
+ ds.PrioritizedDeviceTypeList(), ndef, &final_devices);
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+ if (final_devices.empty()) {
+ status->status = tensorflow::errors::Internal(
+ "Could not find valid device for node ", ndef.DebugString());
+ return nullptr;
+ }
+ for (tensorflow::Device* d : ctx->devices) {
+ if (d->device_type() == final_devices[0].type_string()) {
+ return d;
+ }
+ }
+ status->status = tensorflow::errors::Unknown(
+ "Could not find a device for node ", ndef.DebugString());
+ return nullptr;
+}
+
+tensorflow::Status Execute(
+ TFE_Context* ctx, tensorflow::Device* device,
+ const tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>& op_inputs,
+ tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats,
+ TFE_TensorHandle** retvals, int num_retvals) {
+ if (!ctx->soft_placement && device == nullptr) {
+ // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU
+ device = ctx->devices[0];
+ }
+
+ if (device == nullptr) {
+ // TODO(apassos) debug how the assignment below might return a different
+ // device from the one requested above.
+ device = kernel->device();
+ }
+
+ std::vector<tensorflow::Tensor> outputs(1);
+ const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
+ output_memory_types = &kernel->kernel()->output_memory_types();
+ std::vector<tensorflow::Tensor> inputs(op_inputs.size());
+ for (int i = 0; i < op_inputs.size(); ++i) {
+ const tensorflow::Tensor* input_tensor = nullptr;
+ TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
+ inputs[i] = *input_tensor;
+ }
+ // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
+ // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
+ // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
+ // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
+ // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
+ // This is quite subtle. Re-work things to make this better? (Would it make
+ // sense for FunctionLibraryRuntime to ensure thread-safe access to
+ // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats
+ // for ops which are a part of functions.
+ // TODO(agarwal): change Run to take vector of handles ?
+ TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats));
+ if (maybe_stats != nullptr) {
+ maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
+ maybe_stats->all_start_micros());
+ tensorflow::mutex_lock ml(ctx->metadata_mu);
+ if (ctx->should_store_metadata.load()) {
+ auto* step_stats = ctx->run_metadata.mutable_step_stats();
+ // Lazily initialize the RunMetadata with information about all devices if
+ // this is the first call.
+ while (step_stats->dev_stats_size() < ctx->devices.size()) {
+ step_stats->add_dev_stats();
+ }
+ // Find the current device's index.
+ int device_idx = 0;
+ for (int i = 0; i < ctx->devices.size(); ++i) {
+ if (ctx->devices[i] == device) {
+ device_idx = i;
+ break;
+ }
+ }
+ // Populate the device stats for this device.
+ auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
+ dev_stats->set_device(device->name());
+ *dev_stats->add_node_stats() = *maybe_stats;
+ }
+ }
+ if (num_retvals != outputs.size()) {
+ return tensorflow::errors::InvalidArgument(
+ "Expecting ", num_retvals, " outputs but got ", outputs.size());
+ }
+ tensorflow::Device* op_device = IsCPU(device) ? nullptr : device;
+ for (int i = 0; i < num_retvals; ++i) {
+ tensorflow::Device* d = op_device;
+ if (d != nullptr && output_memory_types != nullptr &&
+ (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
+ d = nullptr;
+ }
+ if (retvals[i] == nullptr) {
+ retvals[i] = new TFE_TensorHandle(outputs[i], d, op_device);
+ } else {
+ retvals[i]->SetTensorAndDevice(outputs[i], d, op_device);
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+// TODO(agarwal): move TFE_Executor and TFE_Node related code to a separate
+// file.
+class ExecuteNode : public TFE_Node {
+ public:
+ ExecuteNode(TFE_Op* op, tensorflow::KernelAndDevice* kernel,
+ tensorflow::NodeExecStats* maybe_stats,
+ const tensorflow::DataTypeVector& output_dtypes,
+ TFE_TensorHandle** retvals, int num_retvals)
+ : TFE_Node(op->ctx->executor.NextId()),
+ ctx_(op->ctx),
+ op_device_(op->device),
+ inputs_(op->inputs),
+ kernel_(kernel),
+ maybe_stats_(maybe_stats),
+ retvals_(num_retvals) {
+ for (auto handle : inputs_) {
+ handle->Ref();
+ }
+ TFE_Context* ctx = op->ctx;
+ for (int i = 0; i < num_retvals; ++i) {
+ TFE_TensorHandle* h = new TFE_TensorHandle(id, output_dtypes[i], ctx);
+ h->Ref();
+ retvals[i] = h;
+ retvals_[i] = h;
+ }
+ }
+
+ ~ExecuteNode() override {
+ for (auto handle : inputs_) {
+ handle->Unref();
+ }
+ for (auto handle : retvals_) {
+ handle->Unref();
+ }
+ }
+
+ tensorflow::Status Run() override {
+ const tensorflow::Status status =
+ Execute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
+ retvals_.begin(), retvals_.size());
+ if (status.ok()) {
+ return status;
+ } else {
+ return tensorflow::Status(
+ status.code(),
+ tensorflow::strings::StrCat("Got error, \"", status.error_message(),
+ "\" while executing kernel ",
+ kernel_->kernel()->def().DebugString()));
+ }
+ }
+
+ private:
+ TFE_Context* ctx_;
+ tensorflow::Device* op_device_;
+ tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> inputs_;
+ tensorflow::KernelAndDevice* kernel_;
+ std::unique_ptr<tensorflow::NodeExecStats> maybe_stats_;
+ tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals_;
+};
+
+class CopyToDeviceNode : public TFE_Node {
+ public:
+ CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd,
+ TFE_Context* ctx)
+ : TFE_Node(ctx->executor.NextId()),
+ src_(src),
+ dstd_(dstd),
+ ctx_(ctx),
+ dst_(new TFE_TensorHandle(id, src_->dtype, ctx)) {
+ src_->Ref();
+ dst_->Ref();
+ }
+
+ ~CopyToDeviceNode() override {
+ src_->Unref();
+ dst_->Unref();
+ }
+
+ tensorflow::Status Run() override {
+ TFE_TensorHandle* temp = nullptr;
+ TF_RETURN_IF_ERROR(TensorHandleCopyToDevice(src_, ctx_, dstd_, &temp));
+ const tensorflow::Tensor* tensor = nullptr;
+ tensorflow::Device* device = nullptr;
+ tensorflow::Device* op_device = nullptr;
+ tensorflow::Status status =
+ temp->TensorAndDevice(&tensor, &device, &op_device);
+ // `temp` is a ready handle. So the following call should return OK.
+ TF_DCHECK_OK(status) << status.error_message();
+ DCHECK(tensor);
+ dst_->SetTensorAndDevice(*tensor, device, op_device);
+ temp->Unref();
+ return tensorflow::Status::OK();
+ }
+
+ TFE_TensorHandle* dst() { return dst_; }
+
+ private:
+ TFE_TensorHandle* src_;
+ tensorflow::Device* dstd_;
+ TFE_Context* ctx_;
+ TFE_TensorHandle* dst_;
+};
+
#ifdef TENSORFLOW_EAGER_USE_XLA
// Synthesizes and returns a wrapper function over `op`, which must be a
// primitive op (e.g. matmul).
@@ -631,7 +902,7 @@ const tensorflow::FunctionDef* OpToFunction(
(*op_input_to_func_input)[i] = const_index;
func_input_arg = signature->mutable_input_arg(const_index++);
const_input_types->push_back(
- static_cast<TF_DataType>(op->inputs[i].dtype()));
+ static_cast<TF_DataType>(op->inputs[i]->dtype));
} else if (op_input_arg.type() == tensorflow::DT_RESOURCE) {
VLOG(1) << "For resource input, mapping op input " << i
<< " to func input " << resource_index;
@@ -643,11 +914,11 @@ const tensorflow::FunctionDef* OpToFunction(
(*op_input_to_func_input)[i] = arg_index;
func_input_arg = signature->mutable_input_arg(arg_index++);
arg_input_types->push_back(
- static_cast<TF_DataType>(op->inputs[i].dtype()));
+ static_cast<TF_DataType>(op->inputs[i]->dtype));
}
func_input_arg->set_name(op_input_arg.name());
- func_input_arg->set_type(op->inputs[i].dtype());
+ func_input_arg->set_type(op->inputs[i]->dtype);
}
VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();
@@ -740,22 +1011,16 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
// Since input param reordering may have occurred between `op` and `launch_op`
// via `op_input_to_func_input`, adjust the actual inputs accordingly.
launch_op->inputs = op->inputs;
- launch_op->input_devices = op->input_devices;
- launch_op->input_op_devices = op->input_op_devices;
+ for (TFE_TensorHandle* h : launch_op->inputs) {
+ h->Ref();
+ }
if (!op_input_to_func_input.empty()) {
DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
- if (!op->input_devices.empty()) {
- DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size());
- }
for (int i = 0; i < op_input_to_func_input.size(); ++i) {
VLOG(1) << "mapping op input " << i << " to func input "
<< op_input_to_func_input[i];
launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i];
- if (!op->input_devices.empty()) {
- launch_op->input_devices[op_input_to_func_input[i]] =
- op->input_devices[i];
- }
}
}
launch_op->attrs.NumInputs(op->inputs.size());
@@ -789,37 +1054,17 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
}
#endif // TENSORFLOW_EAGER_USE_XLA
-tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
- TFE_Context* ctx, TF_Status* status) {
- tensorflow::DeviceSet ds;
- for (tensorflow::Device* d : ctx->devices) {
- ds.AddDevice(d);
- }
- tensorflow::DeviceTypeVector final_devices;
- status->status = tensorflow::SupportedDeviceTypesForNode(
- ds.PrioritizedDeviceTypeList(), ndef, &final_devices);
- if (!status->status.ok()) {
- return nullptr;
- }
- if (final_devices.empty()) {
- status->status = tensorflow::errors::Internal(
- "Could not find valid device for node ", ndef.DebugString());
- return nullptr;
- }
- for (tensorflow::Device* d : ctx->devices) {
- if (d->device_type() == final_devices[0].type_string()) {
- return d;
- }
- }
- status->status = tensorflow::errors::Unknown(
- "Could not find a device for node ", ndef.DebugString());
- return nullptr;
-}
-
} // namespace
+extern "C" {
+
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
+ TFE_Context* ctx = op->ctx;
+ status->status = ctx->executor.status();
+ if (!status->status.ok()) {
+ return;
+ }
#ifdef TENSORFLOW_EAGER_USE_XLA
std::unique_ptr<TFE_Op> xla_launch_op;
if (op->use_xla && op->name != "_XlaLaunch") {
@@ -830,31 +1075,29 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
op = xla_launch_op.get();
}
#endif // TENSORFLOW_EAGER_USE_XLA
- TFE_Context* ctx = op->ctx;
- tensorflow::Device* device = op->device;
// Ensure all resource-touching ops run in the device the resource is,
// regardless of anything else that has been specified. This is identical to
// the graph mode behavior.
for (int i = 0; i < op->inputs.size(); ++i) {
- if (op->inputs[i].dtype() == tensorflow::DT_RESOURCE &&
- op->input_op_devices[i] != device) {
- tensorflow::Device* d = op->input_op_devices[i] == nullptr
- ? ctx->devices[0]
- : op->input_op_devices[i];
+ tensorflow::Device* input_op_device = nullptr;
+ status->status = op->inputs[i]->OpDevice(&input_op_device);
+ if (!status->status.ok()) return;
+ if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE &&
+ input_op_device != op->device) {
+ tensorflow::Device* d =
+ input_op_device == nullptr ? ctx->devices[0] : input_op_device;
VLOG(1) << "Changing device of operation " << op->name << " to "
<< d->name() << " because input #" << i
<< " is a resource in this device.";
- device = d;
op->device = d;
}
}
+ tensorflow::Device* device = op->device;
if (!ctx->soft_placement && device == nullptr) {
// TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU
device = ctx->devices[0];
}
- std::vector<tensorflow::Tensor> outputs(1);
- const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
tensorflow::Fprint128 cache_key =
op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name());
tensorflow::KernelAndDevice* kernel;
@@ -879,8 +1122,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// Knowledge of the implementation of Init (and in-turn
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
// will be accessed, so grab on to the lock.
- // See WARNING comment below - would be nice to rework to avoid this
- // subtlety.
+ // See WARNING comment in Execute (before kernel->Run) - would be nice to
+ // rework to avoid this subtlety.
tensorflow::tf_shared_lock l(ctx->functions_mu);
status->status =
tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
@@ -903,29 +1146,30 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
}
tensorflow::DataTypeVector input_dtypes;
status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes,
- kernel->output_dtypes());
+ kernel->mutable_output_dtypes());
if (!status->status.ok()) {
return;
}
tensorflow::mutex_lock ml(ctx->cache_mu);
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
}
+ const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes();
+ if (output_dtypes.size() != *num_retvals) {
+ TF_SetStatus(status, TF_INVALID_ARGUMENT,
+ tensorflow::strings::StrCat("Expecting ", output_dtypes.size(),
+ " outputs, but *num_retvals is ",
+ *num_retvals)
+ .c_str());
+ return;
+ }
if (device == nullptr) {
// TODO(apassos) debug how the assignment below might return a different
// device from the one requested above.
device = kernel->device();
}
-
- std::vector<TFE_TensorHandle*> copied_tensors;
- status->status = ValidateInputTypeAndPlacement(
- ctx, ctx->devices[0], device, op, kernel->kernel(), &copied_tensors);
- output_memory_types = &kernel->kernel()->output_memory_types();
- if (!status->status.ok()) {
- for (auto* t : copied_tensors) {
- TFE_DeleteTensorHandle(t);
- }
- return;
- }
+ status->status = ValidateInputTypeAndPlacement(ctx, ctx->devices[0], device,
+ op, kernel->kernel());
+ if (!status->status.ok()) return;
std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
if (ctx->should_store_metadata.load()) {
maybe_stats.reset(new tensorflow::NodeExecStats);
@@ -935,53 +1179,47 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
// TODO(apassos) track referenced tensors
}
- // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
- // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
- // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
- // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
- // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
- // This is quite subtle. Re-work things to make this better? (Would it make
- // sense for FunctionLibraryRuntime to ensure thread-safe access to
- // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats
- // for ops which are a part of functions.
- status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get());
- for (auto* t : copied_tensors) {
- TFE_DeleteTensorHandle(t);
- }
- if (!status->status.ok()) return;
- if (maybe_stats != nullptr) {
- maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
- maybe_stats->all_start_micros());
- tensorflow::mutex_lock ml(ctx->metadata_mu);
- if (ctx->should_store_metadata.load()) {
- auto* step_stats = ctx->run_metadata.mutable_step_stats();
- // Lazily initialize the RunMetadata with information about all devices if
- // this is the first call.
- while (step_stats->dev_stats_size() < ctx->devices.size()) {
- step_stats->add_dev_stats();
- }
- // Find the current device's index.
- int device_idx = 0;
- for (int i = 0; i < ctx->devices.size(); ++i) {
- if (ctx->devices[i] == device) {
- device_idx = i;
- break;
- }
- }
- // Populate the device stats for this device.
- auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
- dev_stats->set_device(device->name());
- *dev_stats->add_node_stats() = *maybe_stats;
+ if (ctx->Async()) {
+ // Note that for async mode, execution order will make sure that all
+ // input handles are ready before executing them.
+ // TODO(agarwal): Consider executing "cheap" kernels inline for performance.
+ TFE_Node* node = new ExecuteNode(op, kernel, maybe_stats.release(),
+ output_dtypes, retvals, *num_retvals);
+ ctx->executor.Add(node);
+ } else {
+ // Execute checks if retvals[i] is nullptr or not to figure if it needs to
+ // allocate it.
+ for (int i = 0; i < *num_retvals; ++i) {
+ retvals[i] = nullptr;
}
+ status->status = Execute(op->ctx, op->device, op->inputs, kernel,
+ maybe_stats.get(), retvals, *num_retvals);
}
- *num_retvals = std::min<int>(*num_retvals, outputs.size());
- for (int i = 0; i < *num_retvals; ++i) {
- tensorflow::Device* d = IsCPU(device) ? nullptr : device;
- if (d != nullptr && output_memory_types != nullptr &&
- (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
- d = nullptr;
- }
- retvals[i] = new TFE_TensorHandle(outputs[i], d, device);
+}
+
+TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
+ TFE_Context* ctx,
+ const char* device_name,
+ TF_Status* status) {
+ status->status = ctx->executor.status();
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+ tensorflow::Device* dstd = ctx->devices[0];
+ if (device_name != nullptr && strlen(device_name) > 0) {
+ status->status = ctx->device_manager->LookupDevice(device_name, &dstd);
+ if (!status->status.ok()) return nullptr;
+ }
+ if (ctx->Async()) {
+ // Note that `h` may not be currently ready. However execution order will
+ // make sure that `h` is ready before the copy is actually done.
+ CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
+ ctx->executor.Add(node);
+ return node->dst();
+ } else {
+ TFE_TensorHandle* output = nullptr;
+ status->status = TensorHandleCopyToDevice(h, ctx, dstd, &output);
+ return output;
}
}
@@ -1004,6 +1242,16 @@ void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
status->status = ctx->func_lib_def.AddFunctionDef(function->fdef);
}
+void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
+ ctx->should_store_metadata.store(true);
+}
+
+void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
+ tensorflow::mutex_lock ml(ctx->metadata_mu);
+ ctx->should_store_metadata.store(false);
+ ctx->run_metadata.Clear();
+}
+
} // extern "C"
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
@@ -1012,27 +1260,24 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
TFE_TensorHandle* h, TF_Status* status) {
- if (h->d != nullptr) {
+ tensorflow::Device* d = nullptr;
+ tensorflow::Device* op_device = nullptr;
+ const tensorflow::Tensor* t = nullptr;
+ status->status = h->TensorAndDevice(&t, &d, &op_device);
+ if (!status->status.ok()) return nullptr;
+ if (d != nullptr) {
status->status = tensorflow::errors::FailedPrecondition(
"TFE_TensorHandle is placed in device (not host) memory. Cannot return "
"a tensorflow::Tensor");
return nullptr;
}
- return &h->t;
-}
-
-void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
- ctx->should_store_metadata.store(true);
-}
-
-void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
- tensorflow::mutex_lock ml(ctx->metadata_mu);
- ctx->should_store_metadata.store(false);
- ctx->run_metadata.Clear();
+ return t;
}
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
+ TFE_ContextAsyncWait(ctx, status);
+ if (!status->status.ok()) return;
tensorflow::mutex_lock ml(ctx->metadata_mu);
status->status = MessageToBuffer(ctx->run_metadata, buf);
ctx->run_metadata.Clear();
@@ -1108,3 +1353,208 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
}
}
} // namespace tensorflow
+
+TFE_Node::TFE_Node(tensorflow::uint64 id) : id(id) {}
+
+TFE_Executor::~TFE_Executor() {
+ tensorflow::mutex_lock l(node_queue_mutex_);
+ thread_done_ = true;
+ nodes_pending_.notify_all();
+}
+
+tensorflow::uint64 TFE_Executor::NextId() {
+ tensorflow::mutex_lock l(next_id_mutex_);
+ return next_id_++;
+}
+
+void TFE_Executor::EnableAsync() {
+ tensorflow::mutex_lock l(node_queue_mutex_);
+ if (thread_ == nullptr) {
+ thread_.reset(tensorflow::Env::Default()->StartThread(
+ tensorflow::ThreadOptions(), "eager_async_executor",
+ std::bind(&TFE_Executor::Run, this)));
+ }
+}
+
+void TFE_Executor::Add(TFE_Node* node) {
+ tensorflow::mutex_lock l(node_queue_mutex_);
+ DCHECK(thread_) << "EnableAsync should have been called before Add";
+ if (!status_.ok()) {
+ delete node;
+ return;
+ }
+ int qlen = node_queue_.size();
+ if (qlen > 0) {
+ if (node_queue_.back()->id >= node->id) {
+ status_ = tensorflow::errors::InvalidArgument(
+ "Inserting TFE_Node with non-increasing ids:", node_queue_.back()->id,
+ " vs ", node->id);
+ delete node;
+ return;
+ }
+ node_queue_.push(node);
+ } else {
+ node_queue_.push(node);
+ nodes_pending_.notify_all();
+ }
+}
+
+tensorflow::Status TFE_Executor::WaitFor(tensorflow::uint64 node_id) {
+ return WaitImpl(false, node_id);
+}
+
+tensorflow::Status TFE_Executor::WaitForAllPendingNodes() {
+ return WaitImpl(true, 0);
+}
+
+tensorflow::Status TFE_Executor::WaitImpl(bool wait_all,
+ tensorflow::uint64 node_id) {
+ tensorflow::condition_variable cond;
+ tensorflow::mutex_lock l(node_queue_mutex_);
+ // Don't wait if an error is already set.
+ if (!status_.ok()) return status_;
+ if (node_queue_.empty()) return tensorflow::Status::OK();
+ if (wait_all) {
+ node_id = node_queue_.back()->id;
+ } else if (node_id < node_queue_.front()->id) {
+ // Note that we are relying on the ops being dispatched sequentially from
+ // the queue.
+ return tensorflow::Status::OK();
+ }
+ node_done_notifications_.insert(std::make_pair(node_id, &cond));
+ cond.wait(l);
+ // Note that we could be woken up if an error occurs, even though the node has
+ // not actually executed.
+ return status_;
+}
+
+void TFE_Executor::ClearError() {
+ tensorflow::mutex_lock l(node_queue_mutex_);
+ if (status_.ok()) return;
+ // If an error was set, node_done_notifications_ and node_queue_ should have
+ // been cleared, and no new entries should have been added since.
+ DCHECK(node_done_notifications_.empty());
+ DCHECK(node_queue_.empty());
+ status_ = tensorflow::Status::OK();
+ nodes_pending_.notify_all();
+}
+
+tensorflow::Status TFE_Executor::status() {
+ tensorflow::mutex_lock l(node_queue_mutex_);
+ return status_;
+}
+
+void TFE_Executor::Run() {
+ while (true) {
+ std::unique_ptr<TFE_Node> curr_node;
+ {
+ tensorflow::mutex_lock l(node_queue_mutex_);
+ while (node_queue_.empty() || !status_.ok()) {
+ if (thread_done_) return;
+ nodes_pending_.wait(l);
+ }
+ curr_node.reset(node_queue_.front());
+ }
+ tensorflow::Status status = curr_node->Run();
+ const bool ok = status.ok();
+ tensorflow::mutex_lock l(node_queue_mutex_);
+ node_queue_.pop();
+ if (!ok) {
+ status_ = status;
+ // TODO(agarwal): mark all affected handles as corrupted before clearing
+ // this queue.
+ // We remove any pending ops so that we don't try to execute them if
+ // ClearError is called.
+ for (int i = 0; i < node_queue_.size(); ++i) {
+ delete node_queue_.front();
+ node_queue_.pop();
+ }
+ }
+ if (!node_done_notifications_.empty()) {
+ tensorflow::uint64 node_id = curr_node->id;
+ // Note that we notify all waiting threads in case an error has occurred.
+ // These calling threads are responsible for checking status_ before
+ // proceeding.
+ const auto range = ok ? node_done_notifications_.equal_range(node_id)
+ : make_pair(node_done_notifications_.begin(),
+ node_done_notifications_.end());
+ for (auto it = range.first; it != range.second; ++it) {
+ it->second->notify_all();
+ }
+ node_done_notifications_.erase(range.first, range.second);
+ }
+ }
+}
+
+bool TFE_Context::Async() const {
+ tensorflow::mutex_lock l(async_map_mu);
+ return tensorflow::gtl::FindWithDefault(
+ thread_local_async, std::this_thread::get_id(), async_default);
+}
+
+bool TFE_TensorHandle::IsReady() {
+ if (node_id == 0) return true;
+ tensorflow::mutex_lock l(ctx_mutex_);
+ return ctx_ == nullptr;
+}
+
+tensorflow::Status TFE_TensorHandle::WaitReady() {
+ if (node_id == 0) return tensorflow::Status::OK();
+ TFE_Executor* executor = nullptr;
+ {
+ tensorflow::mutex_lock l(ctx_mutex_);
+ if (ctx_ == nullptr) return tensorflow::Status::OK();
+ executor = &ctx_->executor;
+ }
+ return executor->WaitFor(node_id);
+}
+
+tensorflow::Status TFE_TensorHandle::Tensor(const tensorflow::Tensor** t) {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *t = &tensor_;
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status TFE_TensorHandle::Device(tensorflow::Device** d) {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *d = device_;
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status TFE_TensorHandle::OpDevice(tensorflow::Device** d) {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *d = op_device_;
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status TFE_TensorHandle::TensorAndDevice(
+ const tensorflow::Tensor** tensor, tensorflow::Device** device,
+ tensorflow::Device** op_device) {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *tensor = &tensor_;
+ *device = device_;
+ *op_device = op_device_;
+ return tensorflow::Status::OK();
+}
+
+void TFE_TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor,
+ tensorflow::Device* device,
+ tensorflow::Device* op_device) {
+ tensorflow::mutex_lock l(ctx_mutex_);
+ DCHECK(node_id > 0 && ctx_) << "SetTensorAndDevice should be only called "
+ << "on non-ready handles.";
+ ctx_ = nullptr;
+ tensor_ = tensor;
+ device_ = device;
+ op_device_ = op_device;
+}
+
+TFE_Op::~TFE_Op() {
+ for (TFE_TensorHandle* h : inputs) {
+ h->Unref();
+ }
+}
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 9610ca1b3b..316006bafb 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -75,6 +75,11 @@ typedef enum TFE_ContextDevicePlacementPolicy {
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
} TFE_ContextDevicePlacementPolicy;
+// Sets the default execution mode (sync/async). Note that this can be
+// overridden per thread using TFE_ContextSetAsyncForThread.
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
+ unsigned char async);
+
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
@@ -110,6 +115,30 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy
TFE_ContextGetDevicePlacementPolicy(TFE_Context*);
+// Overrides the execution mode (sync/async) for the current thread.
+TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
+ unsigned char async,
+ TF_Status* status);
+
+// Causes the calling thread to block till all ops dispatched in async mode
+// have been executed. Note that "execution" here refers to kernel execution /
+// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
+// that lower level device queues (like GPU streams) have been flushed.
+//
+// This call may not block for execution of ops enqueued concurrently with this
+// call.
+TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context*,
+ TF_Status* status);
+
+// When an error happens, any pending operations are discarded and newly issued
+// ops return an error. This call clears the error state and re-enables
+// execution of newly issued ops.
+//
+// Note that outputs of discarded ops remain in a corrupt state and should not
+// be used for future calls.
+// TODO(agarwal): mark the affected handles and raise errors if they are used.
+TF_CAPI_EXPORT extern void TFE_ContextAsyncClearError(TFE_Context*);
+
// A handle to a tensor on a device.
//
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
@@ -119,15 +148,21 @@ typedef struct TFE_TensorHandle TFE_TensorHandle;
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
TF_Status* status);
+// Indicates that the caller will not be using `h` any more.
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
+// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h,
TF_Status* status);
+// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
int dim_index,
TF_Status* status);
+// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
TFE_TensorHandle* h, TF_Status* status);
+
+// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
TF_Status* status);
@@ -137,6 +172,9 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
// that shares the underlying buffer. Otherwise, it currently requires at least
// one of the source or destination devices to be CPU (i.e., for the source or
// destination tensor to be placed in host memory).
+// If async execution is enabled, the copy may be enqueued and the call will
+// return "non-ready" handle. Else, this function returns after the copy has
+// been done.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(
TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name,
TF_Status* status);
@@ -157,6 +195,7 @@ typedef struct TFE_Op TFE_Op;
TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx,
const char* op_or_function_name,
TF_Status* status);
+
TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);
TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
@@ -242,13 +281,20 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op,
int num_values);
// Execute the operation defined by 'op' and return handles to computed
-// tensors in 'retvals'.
+// tensors in `retvals`.
+//
+// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and
+// '*num_retvals' should be set to the size of this array. It is an error if
+// the number of outputs is different from *num_retvals.
//
-// 'retvals' must point to a pre-allocated array of TFE_TensorHandle*
-// and '*num_retvals' should be set to the size of this array.
+// If async execution is enabled, the call may simply enqueue the execution
+// and return "non-ready" handles in `retvals`. Note that any handles contained
+// in 'op' should not be mutated till the kernel execution actually finishes.
//
-// On return, 'num_retvals' will be set to the actual number of outputs
-// returned by the operation.
+// For sync execution, if any of the inputs to `op` are not ready, this call
+// will block till they become ready and then return when the kernel execution
+// is done.
+// TODO(agarwal): change num_retvals to int from int*.
TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
int* num_retvals, TF_Status* status);
@@ -274,6 +320,8 @@ TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx);
// Populates the passed-in buffer with a serialized RunMetadata protocol buffer
// containing any run metadata information accumulated so far and clears this
// information.
+// If async mode is enabled, this call blocks till all currently pending ops are
+// done.
TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
TF_Buffer* buf,
TF_Status* status);
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 49b9434457..8dba12f47b 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -19,7 +19,9 @@ limitations under the License.
#include <algorithm>
#include <cstddef>
+#include <map>
#include <memory>
+#include <queue>
#include <string>
#include <thread>
#include <vector>
@@ -31,14 +33,113 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
+// A unit of execution for the TFE_Executor class below. Example subclasses
+// encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one
+// device to another.
+class TFE_Node {
+ public:
+ explicit TFE_Node(tensorflow::uint64 id);
+
+ virtual ~TFE_Node() {}
+
+ // Runs the computation corresponding to this node and blocks till the
+ // execution is done.
+ virtual tensorflow::Status Run() = 0;
+
+ // An id unique to the TFE_Context under which this node is created. Allocated
+ // monotonically.
+ const tensorflow::uint64 id;
+};
+
+// A class for handling async execution (see TFE_ContextSetAsync).
+// Note that this class is thread-safe.
+// TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the
+// device of the input handle. Fix that.
+// TODO(agarwal): On error, mark all affected handles as corrupted.
+// TODO(agarwal): Implement support for control dependencies.
+// TODO(agarwal): Support out-of-order execution and dispatching multiple
+// TFE_Node in parallel.
+// TODO(agarwal): Implement optimizations over TFE_Node traces.
+class TFE_Executor {
+ public:
+ ~TFE_Executor();
+
+ // This is called whenever async mode is enabled. Note that it may be called
+ // multiple times as different calling threads may switch async mode on or off
+ // independently.
+ void EnableAsync();
+
+ // Helper function to create monotonically increasing ids unique to this
+ // object.
+ tensorflow::uint64 NextId();
+
+ // Schedules `node` for execution.
+ // Note that Add must be called in monotonically increasing order of node->id.
+ void Add(TFE_Node* node);
+
+ // Causes the caller to block till node with id `node_id` has finished
+ // execution.
+ tensorflow::Status WaitFor(tensorflow::uint64 node_id);
+
+ // Blocks till all currently pending ops are done.
+ tensorflow::Status WaitForAllPendingNodes();
+
+ // Clears all currently set errors which re-enables async execution.
+ void ClearError();
+
+ // Returns Status based on any errors that occurred during async execution.
+ tensorflow::Status status();
+
+ private:
+ // Starts execution of pending TFE_Nodes. This function loops till
+ // thread_done_ is set to true. If any errors are encontered, these are set
+ // inside `status_`. The loop blocks anytime there are no pending nodes, or if
+ // `status_` is not ok.
+ void Run();
+
+ tensorflow::Status WaitImpl(bool wait_all, tensorflow::uint64 node_id);
+
+ tensorflow::mutex node_queue_mutex_;
+
+ // Used to signal that some TFE_Nodes are pending execution.
+ tensorflow::condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_);
+
+ // Queue of pending TFE_Nodes.
+ std::queue<TFE_Node*> node_queue_ GUARDED_BY(node_queue_mutex_);
+
+ // `status_` is set based on any errors raised during execution of a TFE_Node.
+ // It remains set until ClearError is called.
+ tensorflow::Status status_ GUARDED_BY(node_queue_mutex_);
+
+ // Map from id of a TFE_Node to condition_variables (not owned by the map).
+ // These condition_variables are notified and removed when that TFE_Node is
+ // done executing, or if an error is found in execution of any TFE_Node.
+ std::multimap<tensorflow::uint64, tensorflow::condition_variable*>
+ node_done_notifications_ GUARDED_BY(node_queue_mutex_);
+
+ // Thread object that calls the `Run` method. Currently we use only one thread
+ // for executing the TFE_Nodes one-by-one.
+ std::unique_ptr<tensorflow::Thread> thread_ GUARDED_BY(node_queue_mutex_);
+
+ // Indicates that `thread_` should stop as soon as it is done executing the
+ // current TFE_Node.
+ bool thread_done_ GUARDED_BY(node_queue_mutex_) = false;
+
+ tensorflow::mutex next_id_mutex_;
+ tensorflow::uint64 next_id_ GUARDED_BY(next_id_mutex_) = 1;
+};
+
struct TFE_ContextOptions {
TF_SessionOptions session_options;
+ // true if async execution is enabled.
+ bool async = false;
TFE_ContextDevicePlacementPolicy policy{
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32};
};
@@ -60,7 +161,10 @@ struct TFE_Context {
device_manager.get(), opts.session_options.options.env,
TF_GRAPH_DEF_VERSION, &func_lib_def, {})),
log_device_placement(
- opts.session_options.options.config.log_device_placement()) {}
+ opts.session_options.options.config.log_device_placement()),
+ async_default(opts.async) {
+ if (async_default) executor.EnableAsync();
+ }
const bool soft_placement;
const TFE_ContextDevicePlacementPolicy policy;
@@ -98,29 +202,99 @@ struct TFE_Context {
std::atomic<bool> should_store_metadata{false};
tensorflow::mutex metadata_mu;
tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu);
-
const bool log_device_placement;
+ // TFE_Executor for async execution.
+ TFE_Executor executor;
+
+ // True if running in asynchronous mode.
+ bool Async() const;
+
+ // True if the default value for execution mode is async. Note that this value
+ // can be overridden per thread based on `thread_local_async` overrides.
+ const bool async_default;
+ mutable tensorflow::mutex async_map_mu;
+ std::unordered_map<std::thread::id, bool> thread_local_async
+ GUARDED_BY(async_map_mu);
};
-struct TFE_TensorHandle {
+struct TFE_TensorHandle : public tensorflow::core::RefCounted {
+ public:
TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d,
tensorflow::Device* op_device)
- : t(t), d(d), op_device(op_device) {}
+ : dtype(t.dtype()),
+ node_id(0),
+ tensor_(t),
+ device_(d),
+ op_device_(op_device),
+ ctx_(nullptr) {}
+
+ TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype,
+ TFE_Context* ctx)
+ : dtype(dtype),
+ node_id(node_id),
+ tensor_(dtype),
+ device_(nullptr),
+ op_device_(nullptr),
+ ctx_(ctx) {
+ DCHECK_GT(node_id, 0);
+ }
+
+ ~TFE_TensorHandle() override {}
+
+ tensorflow::Status Tensor(const tensorflow::Tensor** t);
+
+ tensorflow::Status Device(tensorflow::Device** d);
- tensorflow::Tensor t;
- // TODO(ashankar): d == nullptr iff local CPU
- // This was expedient, but perhaps worth revisiting ('d' should always be a
- // valid pointer?)
+ tensorflow::Status OpDevice(tensorflow::Device** d);
+
+ tensorflow::Status TensorAndDevice(const tensorflow::Tensor** tensor,
+ tensorflow::Device** device,
+ tensorflow::Device** op_device);
+
+ // Note that this can be called at most once, and only on non-ready handles,
+ // and makes them ready.
+ void SetTensorAndDevice(const tensorflow::Tensor& tensor,
+ tensorflow::Device* device,
+ tensorflow::Device* op_device);
+
+ // dtype for the handle. It must be the same as t.dtype() once the handle is
+ // ready.
+ const tensorflow::DataType dtype;
+
+ private:
+ // If the contents of the Tensor pointed to by this handle is yet to be
+ // computed by a TFE_Node, this function will block till that compuatation is
+ // done and the handle is "ready".
+ tensorflow::Status WaitReady();
+
+ bool IsReady();
+
+ // Id for the TFE_Node that will compute the value pointed to by this handle.
+ // If the value is 0, the handle is already ready, but not vice-versa.
+ const tensorflow::uint64 node_id;
+
+ tensorflow::Tensor tensor_;
+
+ // TODO(ashankar): device_ == nullptr iff local CPU
+ // This was expedient, but perhaps worth revisiting ('device_' should always
+ // be a valid pointer?)
// This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
// provided with the appropriate TFE_Context.
//
- // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a
+ // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
// TFE_TensorHandle does not outlive the TFE_Context from which it came?
- tensorflow::Device* d;
+ tensorflow::Device* device_;
+
+ // Device in which the op producing this tensor was executed. Equals to
+ // device_ for constant tensors.
+ tensorflow::Device* op_device_;
- // Device in which the op producing this tensor was executed. Equals to d for
- // constant tensors.
- tensorflow::Device* op_device;
+ tensorflow::mutex ctx_mutex_;
+
+ // `ctx` is only guaranteed to be set if the handle is not "ready". This is
+ // typically true when the handle was produced during async execution.
+ // `ctx` object is not owned and should outlive this handle.
+ TFE_Context* ctx_ GUARDED_BY(ctx_mutex_);
};
struct TFE_Op {
@@ -129,15 +303,15 @@ struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
: ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}
+ ~TFE_Op();
+
bool const is_function() const { return attr_types == nullptr; }
TFE_Context* ctx; // Must outlive the TFE_Op.
const tensorflow::string name;
tensorflow::AttrBuilder attrs;
const tensorflow::AttrTypeMap* attr_types;
- std::vector<tensorflow::Tensor> inputs;
- std::vector<tensorflow::Device*> input_devices;
- std::vector<tensorflow::Device*> input_op_devices;
+ tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> inputs;
tensorflow::Device* device;
bool use_xla = false;
};
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 00fb7e68d0..927d119389 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -29,6 +29,20 @@ using tensorflow::string;
namespace {
+TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
+ int64_t dims[] = {2, 2};
+ double data[] = {1.0, 2.0, 3.0, 4.0};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
TFE_TensorHandle* TestMatrixTensorHandle() {
int64_t dims[] = {2, 2};
float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
@@ -43,6 +57,20 @@ TFE_TensorHandle* TestMatrixTensorHandle() {
return th;
}
+TFE_TensorHandle* TestMatrixTensorHandle3X2() {
+ int64_t dims[] = {3, 2};
+ double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();
@@ -139,10 +167,12 @@ void BM_InitOp(int iters) {
}
BENCHMARK(BM_InitOp);
-void BM_Execute(int iters) {
+void BM_Execute(int iters, int async) {
tensorflow::testing::StopTiming();
+ tensorflow::testing::SetLabel(async ? "ExecuteAsync" : "Execute");
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
@@ -156,6 +186,9 @@ void BM_Execute(int iters) {
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
+ if (async) {
+ TFE_ContextAsyncWait(ctx, status);
+ }
tensorflow::testing::StopTiming();
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
@@ -163,7 +196,7 @@ void BM_Execute(int iters) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
-BENCHMARK(BM_Execute);
+BENCHMARK(BM_Execute)->Arg(0)->Arg(1);
TEST(CAPI, Context) {
TF_Status* status = TF_NewStatus();
@@ -205,10 +238,11 @@ TEST(CAPI, TensorHandle) {
TFE_DeleteTensorHandle(h);
}
-TEST(CAPI, TensorHandleCopyBetweenDevices) {
+void TensorHandleCopyBetweenDevices(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@@ -274,10 +308,56 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
-TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
+TEST(CAPI, TensorHandleCopyBetweenDevices) {
+ TensorHandleCopyBetweenDevices(false);
+}
+
+TEST(CAPI, TensorHandleCopyBetweenDevicesAsync) {
+ TensorHandleCopyBetweenDevices(true);
+}
+
+void TensorHandleCopyBetweenDevicesError(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+ TFE_Context* ctx = TFE_NewContext(opts, status.get());
+ TFE_DeleteContextOptions(opts);
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+ const char* kErrorDevice = "NoSuchDevice:0";
+ TFE_TensorHandle* hdevice =
+ TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get());
+ EXPECT_NE(TF_OK, TF_GetCode(status.get()));
+ const char* msg = "NoSuchDevice:0 unknown device";
+ EXPECT_TRUE(strstr(TF_Message(status.get()), msg) != nullptr)
+ << TF_Message(status.get());
+ TF_SetStatus(status.get(), TF_OK, "");
+ const char* kCPUDevice = "CPU:0";
+ TFE_TensorHandle* hcopy =
+ TFE_TensorHandleCopyToDevice(hcpu, ctx, kCPUDevice, status.get());
+ EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_ContextAsyncWait(ctx, status.get());
+ EXPECT_EQ(TF_OK, TF_GetCode(status.get()));
+ TFE_DeleteTensorHandle(hcopy);
+ TFE_DeleteTensorHandle(hcpu);
+ if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice);
+ TFE_DeleteContext(ctx, status.get());
+}
+
+TEST(CAPI, TensorHandleCopyBetweenDevicesError) {
+ TensorHandleCopyBetweenDevicesError(false);
+}
+
+TEST(CAPI, TensorHandleCopyBetweenDevicesErrorAsync) {
+ TensorHandleCopyBetweenDevicesError(true);
+}
+
+void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@@ -332,11 +412,20 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
-TEST(CAPI, TensorHandleSilentCopy) {
+TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
+ TensorHandleCopyBetweenTwoGPUDevices(false);
+}
+
+TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
+ TensorHandleCopyBetweenTwoGPUDevices(true);
+}
+
+void TensorHandleSilentCopy(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@@ -366,14 +455,20 @@ TEST(CAPI, TensorHandleSilentCopy) {
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
+ TFE_ContextAsyncWait(ctx, status.get());
+ EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContext(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
-TEST(CAPI, TensorHandleSilentCopyLocal) {
+TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); }
+TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true); }
+
+void TensorHandleSilentCopyLocal(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status.get());
@@ -407,11 +502,17 @@ TEST(CAPI, TensorHandleSilentCopyLocal) {
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
+ TFE_ContextAsyncWait(ctx, status.get());
+ EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContext(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
+TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); }
+TEST(CAPI, TensorHandleSilentCopyLocalAsync) {
+ TensorHandleSilentCopyLocal(true);
+}
-TEST(CAPI, SetAndGetOpDevices) {
+void SetAndGetOpDevices(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
@@ -442,27 +543,27 @@ TEST(CAPI, SetAndGetOpDevices) {
TF_DeleteStatus(status);
}
-TEST(CAPI, Execute_MatMul_CPU) {
+void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
- TFE_TensorHandle* retvals[2] = {nullptr};
- int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
+ TFE_TensorHandle* retvals[1] = {nullptr};
+ int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
- TFE_DeleteContext(ctx, status);
- ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteContext(ctx, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
@@ -474,7 +575,101 @@ TEST(CAPI, Execute_MatMul_CPU) {
EXPECT_EQ(22, product[3]);
TF_DeleteStatus(status);
}
+TEST(CAPI, Execute_MatMul_CPU) { Execute_MatMul_CPU(false); }
+TEST(CAPI, Execute_MatMul_CPUAsync) { Execute_MatMul_CPU(true); }
+
+void Execute_MatMul_CPU_Runtime_Error(bool async) {
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* m1 = TestMatrixTensorHandle();
+ TFE_TensorHandle* m2 = TestMatrixTensorHandle3X2();
+ TFE_Op* matmul = MatMulOp(ctx, m1, m2);
+ TFE_Op* matmul2 = MatMulOp(ctx, m1, m1);
+ TFE_TensorHandle* retvals[1] = {nullptr};
+ int num_retvals = 1;
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ TFE_DeleteOp(matmul);
+ if (!async) {
+ EXPECT_NE(TF_OK, TF_GetCode(status));
+ } else {
+ TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ EXPECT_NE(TF_OK, TF_GetCode(status));
+ EXPECT_EQ(nullptr, t);
+ const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]";
+ EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr)
+ << TF_Message(status);
+ // Since error is not cleared, the following copy with correct device will
+ // still fail.
+ TF_SetStatus(status, TF_OK, "");
+ TFE_DeleteTensorHandle(retvals[0]);
+ retvals[0] = nullptr;
+ TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
+ EXPECT_NE(TF_OK, TF_GetCode(status));
+ TFE_ContextAsyncClearError(ctx);
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status));
+ }
+ // Following works in async mode since TFE_ContextAsyncClearError was called.
+ TF_SetStatus(status, TF_OK, "");
+ if (retvals[0] != nullptr) {
+ TFE_DeleteTensorHandle(retvals[0]);
+ }
+ retvals[0] = nullptr;
+ TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status));
+ TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status));
+ TF_DeleteTensor(t);
+ TFE_DeleteOp(matmul2);
+ TFE_DeleteTensorHandle(m1);
+ TFE_DeleteTensorHandle(m2);
+ TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteContext(ctx, status);
+ TF_DeleteStatus(status);
+}
+TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) {
+ Execute_MatMul_CPU_Runtime_Error(false);
+}
+TEST(CAPI, Execute_MatMul_CPU_Runtime_ErrorAsync) {
+ Execute_MatMul_CPU_Runtime_Error(true);
+}
+
+void Execute_MatMul_CPU_Type_Error(bool async) {
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* m1 = TestMatrixTensorHandle();
+ TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle();
+ TFE_Op* matmul = MatMulOp(ctx, m1, m2);
+ TFE_TensorHandle* retvals[1] = {nullptr};
+ int num_retvals = 1;
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_NE(TF_OK, TF_GetCode(status));
+ TFE_DeleteOp(matmul);
+ TFE_DeleteTensorHandle(m1);
+ TFE_DeleteTensorHandle(m2);
+ if (retvals[0] != nullptr) {
+ TFE_DeleteTensorHandle(retvals[0]);
+ }
+ TFE_DeleteContext(ctx, status);
+ TF_DeleteStatus(status);
+}
+TEST(CAPI, Execute_MatMul_CPU_Type_Error) {
+ Execute_MatMul_CPU_Type_Error(false);
+}
+TEST(CAPI, Execute_MatMul_CPU_Type_ErrorAsync) {
+ Execute_MatMul_CPU_Type_Error(true);
+}
TEST(CAPI, Execute_Min_CPU) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -485,8 +680,8 @@ TEST(CAPI, Execute_Min_CPU) {
TFE_TensorHandle* input = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_Op* minOp = MinOp(ctx, input, axis);
- TFE_TensorHandle* retvals[2] = {nullptr};
- int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
+ TFE_TensorHandle* retvals[1] = {nullptr};
+ int num_retvals = 1;
TFE_Execute(minOp, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(minOp);
@@ -509,9 +704,10 @@ TEST(CAPI, Execute_Min_CPU) {
}
#ifdef TENSORFLOW_EAGER_USE_XLA
-TEST(CAPI, Execute_MatMul_XLA_CPU) {
+void Execute_MatMul_XLA_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
@@ -521,15 +717,14 @@ TEST(CAPI, Execute_MatMul_XLA_CPU) {
TFE_OpSetXLACompilation(matmul, true);
- TFE_TensorHandle* retvals[2] = {nullptr};
- int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
+ TFE_TensorHandle* retvals[1] = {nullptr};
+ int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
// Running a primitive TF operator via XLA is not yet supported.
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
- TFE_DeleteContext(ctx, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(1, num_retvals);
@@ -545,13 +740,16 @@ TEST(CAPI, Execute_MatMul_XLA_CPU) {
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
-
+ TFE_DeleteContext(ctx, status);
TF_DeleteStatus(status);
}
+TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); }
+TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); }
-TEST(CAPI, Execute_Min_XLA_CPU) {
+void Execute_Min_XLA_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
@@ -562,14 +760,13 @@ TEST(CAPI, Execute_Min_XLA_CPU) {
TFE_OpSetXLACompilation(minOp, true);
- TFE_TensorHandle* retvals[2] = {nullptr};
- int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
+ TFE_TensorHandle* retvals[1] = {nullptr};
+ int num_retvals = 1;
TFE_Execute(minOp, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(minOp);
TFE_DeleteTensorHandle(input);
TFE_DeleteTensorHandle(axis);
- TFE_DeleteContext(ctx, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
@@ -582,13 +779,17 @@ TEST(CAPI, Execute_Min_XLA_CPU) {
TF_DeleteTensor(t);
EXPECT_EQ(1, output[0]);
EXPECT_EQ(3, output[1]);
+ TFE_DeleteContext(ctx, status);
TF_DeleteStatus(status);
}
+TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); }
+TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); }
#endif // TENSORFLOW_EAGER_USE_XLA
-TEST(CAPI, ExecuteWithTracing) {
+void ExecuteWithTracing(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_ContextEnableRunMetadata(ctx);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -596,8 +797,8 @@ TEST(CAPI, ExecuteWithTracing) {
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
- TFE_TensorHandle* retvals[2] = {nullptr};
- int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call.
+ TFE_TensorHandle* retvals[1] = {nullptr};
+ int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
@@ -609,12 +810,12 @@ TEST(CAPI, ExecuteWithTracing) {
EXPECT_TRUE(
rm.ParseFromString({reinterpret_cast<const char*>(b->data), b->length}));
TF_DeleteBuffer(b);
- TFE_DeleteContext(ctx, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteContext(ctx, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
@@ -626,6 +827,8 @@ TEST(CAPI, ExecuteWithTracing) {
EXPECT_EQ(22, product[3]);
TF_DeleteStatus(status);
}
+TEST(CAPI, ExecuteWithTracing) { ExecuteWithTracing(false); }
+TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithTracing(true); }
TEST(CAPI, Function_ident_CPU) {
// First create a simple identity function.
@@ -657,32 +860,37 @@ TEST(CAPI, Function_ident_CPU) {
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteFunction(fn);
- TF_Tensor* t =
- TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
- *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
- TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
- ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
- TF_DeleteTensor(t);
+ for (bool async : {false, true, false}) {
+ TFE_ContextSetAsyncForThread(ctx, static_cast<unsigned char>(async),
+ status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK);
+ TF_Tensor* t =
+ TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
+ *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
+ TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_DeleteTensor(t);
- TFE_Op* op = TFE_NewOp(ctx, "ident", status);
- ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
- TFE_OpAddInput(op, h, status);
- ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TFE_Op* op = TFE_NewOp(ctx, "ident", status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TFE_OpAddInput(op, h, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
- std::vector<TFE_TensorHandle*> result;
- result.push_back(nullptr);
- int num_retvals = 1;
- TFE_Execute(op, result.data(), &num_retvals, status);
- TFE_DeleteOp(op);
- ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
- ASSERT_EQ(num_retvals, 1);
+ std::vector<TFE_TensorHandle*> result;
+ result.push_back(nullptr);
+ int num_retvals = 1;
+ TFE_Execute(op, result.data(), &num_retvals, status);
+ TFE_DeleteOp(op);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ ASSERT_EQ(num_retvals, 1);
- TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
- ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
- EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
- TFE_DeleteTensorHandle(h);
- TF_DeleteTensor(r);
- TFE_DeleteTensorHandle(result[0]);
+ TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
+ TFE_DeleteTensorHandle(h);
+ TF_DeleteTensor(r);
+ TFE_DeleteTensorHandle(result[0]);
+ }
TFE_DeleteContext(ctx, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
@@ -719,35 +927,40 @@ TEST(CAPI, Function_ident_XLA_CPU) {
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteFunction(fn);
- TF_Tensor* t =
- TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
- *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
- TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
- ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
- TF_DeleteTensor(t);
+ for (bool async : {false, true, false}) {
+ TFE_ContextSetAsyncForThread(ctx, static_cast<unsigned char>(async),
+ status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK);
+ TF_Tensor* t =
+ TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
+ *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
+ TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TF_DeleteTensor(t);
- TFE_Op* op = TFE_NewOp(ctx, "ident", status);
- ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
- TFE_OpAddInput(op, h, status);
- ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TFE_Op* op = TFE_NewOp(ctx, "ident", status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TFE_OpAddInput(op, h, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
- // Now run it via XLA.
- TFE_OpSetXLACompilation(op, true);
+ // Now run it via XLA.
+ TFE_OpSetXLACompilation(op, true);
- std::vector<TFE_TensorHandle*> result;
- result.push_back(nullptr);
- int num_retvals = 1;
- TFE_Execute(op, result.data(), &num_retvals, status);
- TFE_DeleteOp(op);
- ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
- ASSERT_EQ(num_retvals, 1);
+ std::vector<TFE_TensorHandle*> result;
+ result.push_back(nullptr);
+ int num_retvals = 1;
+ TFE_Execute(op, result.data(), &num_retvals, status);
+ TFE_DeleteOp(op);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ ASSERT_EQ(num_retvals, 1);
- TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
- ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
- EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
- TFE_DeleteTensorHandle(h);
- TF_DeleteTensor(r);
- TFE_DeleteTensorHandle(result[0]);
+ TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
+ TFE_DeleteTensorHandle(h);
+ TF_DeleteTensor(r);
+ TFE_DeleteTensorHandle(result[0]);
+ }
TFE_DeleteContext(ctx, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
@@ -788,9 +1001,10 @@ string MatMulFunction() {
return def.SerializeAsString();
}
-TEST(CAPI, FunctionDefAndExecute) {
+void FunctionDefAndExecute(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
@@ -827,11 +1041,16 @@ TEST(CAPI, FunctionDefAndExecute) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
+TEST(CAPI, FunctionDefAndExecute) { FunctionDefAndExecute(false); }
+TEST(CAPI, FunctionDefAndExecuteAsync) { FunctionDefAndExecute(true); }
-void BM_ExecuteFunction(int iters) {
+void BM_ExecuteFunction(int iters, int async) {
tensorflow::testing::StopTiming();
+ tensorflow::testing::SetLabel(async ? "ExecuteFunctionAsync"
+ : "ExecuteFunction");
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
@@ -853,6 +1072,9 @@ void BM_ExecuteFunction(int iters) {
TFE_Execute(matmul, &retval[0], &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
+ if (async) {
+ TFE_ContextAsyncWait(ctx, status);
+ }
tensorflow::testing::StopTiming();
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retval[0]);
@@ -860,7 +1082,7 @@ void BM_ExecuteFunction(int iters) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
-BENCHMARK(BM_ExecuteFunction);
+BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1);
TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
TF_Status* status) {
diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h
index 985ed96735..ad16f65495 100644
--- a/tensorflow/c/eager/runtime.h
+++ b/tensorflow/c/eager/runtime.h
@@ -185,7 +185,8 @@ class KernelAndDevice {
Device* device() const { return device_; }
- DataTypeVector* output_dtypes() { return &output_dtypes_; }
+ DataTypeVector* mutable_output_dtypes() { return &output_dtypes_; }
+ const DataTypeVector& output_dtypes() { return output_dtypes_; }
private:
std::unique_ptr<OpKernel> kernel_;
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 26683f50ec..cd604538f1 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -105,9 +105,8 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
}
void ExtendSession(TF_Session* session, TF_Status* status) {
- mutex_lock l(session->mu);
- session->extend_before_run = false;
ExtendSessionGraphHelper(session, status);
+ session->extend_before_run = false;
}
} // namespace tensorflow
diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc
index 0734075fc6..81870a0efa 100644
--- a/tensorflow/cc/framework/while_gradients.cc
+++ b/tensorflow/cc/framework/while_gradients.cc
@@ -72,9 +72,9 @@ Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
};
// Body function that adds one to input.
- BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
- const std::vector<Output>& inputs,
- std::vector<Output>* outputs) {
+ BodyGraphBuilderFn body_fn = [](const Scope& scope,
+ const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
DCHECK_EQ(inputs.size(), 1);
outputs->emplace_back(ops::Add(scope, inputs[0], 1));
return scope.status();
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index 1c23f3257e..0cb3132e94 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -195,9 +195,9 @@ Status MaxPool3DGradHelper(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
MaxPool3DGrad::Attrs grad_attrs;
- grad_attrs.DataFormat(data_format);
auto dx = MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0],
- ksize, strides, padding, grad_attrs);
+ ksize, strides, padding,
+ grad_attrs.DataFormat(data_format));
grad_outputs->push_back(dx);
return scope.status();
}
@@ -216,10 +216,10 @@ Status AvgPoolGradHelper(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
internal::AvgPoolGrad::Attrs grad_attrs;
- grad_attrs.DataFormat(data_format);
auto dx =
internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), grad_inputs[0],
- ksize, strides, padding, grad_attrs);
+ ksize, strides, padding,
+ grad_attrs.DataFormat(data_format));
grad_outputs->push_back(dx);
return scope.status();
}
@@ -238,9 +238,9 @@ Status AvgPool3DGradHelper(const Scope& scope, const Operation& op,
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
AvgPool3DGrad::Attrs grad_attrs;
- grad_attrs.DataFormat(data_format);
auto dx = AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0],
- ksize, strides, padding, grad_attrs);
+ ksize, strides, padding,
+ grad_attrs.DataFormat(data_format));
grad_outputs->push_back(dx);
return scope.status();
}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index df73c28597..e9c974a046 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -960,7 +960,7 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
// Identify which shape indices of which instructions need to be copied. Store
// these results in 'instructions_to_copy'.
- std::unordered_map<HloInstruction*, ShapeTree<bool>> instructions_to_copy;
+ HloInstructionMap<ShapeTree<bool>> instructions_to_copy;
auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
const ShapeIndex& index) {
auto it = instructions_to_copy.find(instruction);
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index ca54b2eed8..38668ff455 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -49,7 +49,7 @@ struct MatrixDescriptor {
// rhs_matrix, and stores the result to output_matrix.
template <typename Element>
bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
- MatrixDescriptor output_matrix, se::Stream* stream) {
+ MatrixDescriptor output_matrix, double alpha, se::Stream* stream) {
DCHECK(!output_matrix.transpose);
se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
@@ -65,7 +65,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
return stream
->ThenBlasGemm(
lhs_transpose, rhs_transpose, output_matrix.num_rows,
- output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0,
+ output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
&output_data, /*leading dim of output=*/output_matrix.num_rows)
@@ -89,7 +89,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
template <typename Element>
bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
MatrixDescriptor rhs_matrix,
- MatrixDescriptor output_matrix,
+ MatrixDescriptor output_matrix, double alpha,
se::blas::ComputationType computation_type,
se::blas::AlgorithmType algorithm, se::Stream* stream,
se::blas::ProfileResult* output_profile_result) {
@@ -109,7 +109,7 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
->ThenBlasGemmWithAlgorithm(
lhs_transpose, rhs_transpose, output_matrix.num_rows,
output_matrix.num_cols, /*size of reduce dim=*/k,
- /*alpha=*/static_cast<Element>(1.0f), lhs_data,
+ /*alpha=*/static_cast<Element>(alpha), lhs_data,
/*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
/*leading dim of RHS=*/rhs_matrix.num_rows,
/*beta=*/static_cast<Element>(0.0f), &output_data,
@@ -127,8 +127,8 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
template <typename Element>
StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
- MatrixDescriptor output_matrix, se::blas::ComputationType computation_type,
- se::Stream* stream) {
+ MatrixDescriptor output_matrix, double alpha,
+ se::blas::ComputationType computation_type, se::Stream* stream) {
std::vector<se::blas::AlgorithmType> algorithms;
CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms));
@@ -140,8 +140,8 @@ StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
// non-null ProfileResult, DoGemmWithAlgorithm should always return true,
// and the actual success-ness is returned in ProfileResult::is_valid.
CHECK(DoGemmWithAlgorithm<Element>(lhs_matrix, rhs_matrix, output_matrix,
- computation_type, algorithm, stream,
- &profile_result));
+ alpha, computation_type, algorithm,
+ stream, &profile_result));
if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
@@ -224,7 +224,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& lhs_shape, const Shape& rhs_shape,
const Shape& output_shape, bool transpose_lhs,
- bool transpose_rhs, const HloInstruction* hlo_instruction)
+ bool transpose_rhs, double alpha,
+ const HloInstruction* hlo_instruction)
: Thunk(Kind::kGemm, hlo_instruction),
lhs_buffer_(lhs_buffer),
rhs_buffer_(rhs_buffer),
@@ -233,7 +234,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
rhs_shape_(rhs_shape),
output_shape_(output_shape),
transpose_lhs_(transpose_lhs),
- transpose_rhs_(transpose_rhs) {}
+ transpose_rhs_(transpose_rhs),
+ alpha_(alpha) {}
tensorflow::Status GemmThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream) {
@@ -302,7 +304,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream(
if (autotune_it == autotune_results_.end()) {
StatusOr<se::blas::AlgorithmType> best_algorithm =
GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
- computation_type, stream);
+ alpha_, computation_type, stream);
autotune_it =
autotune_results_.insert({device_name, best_algorithm}).first;
@@ -323,15 +325,15 @@ tensorflow::Status GemmThunk::ExecuteOnStream(
VLOG(2) << "Using algorithm " << algorithm
<< " chosen by autotuning on GemmThunk " << this;
return GetGemmWithAlgorithmFn(element_type)(
- lhs_matrix, rhs_matrix, output_matrix, computation_type, algorithm,
- stream,
+ lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type,
+ algorithm, stream,
/*output_profile_result=*/nullptr);
}
// Autotune will fail when CUDA 8 and GPU sm_50 or older are used.
// Use the older Gemm API in this case.
return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
- stream);
+ alpha_, stream);
};
bool launch_ok;
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
index 8c6a1f51a8..df3edcefef 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
@@ -34,15 +34,16 @@ namespace gpu {
// This is thread-compatible.
class GemmThunk : public Thunk {
public:
- // Constructs a thunk that computes "output = lhs <dot> rhs" using BLAS gemm.
- // transpose_lhs and transpose_rhs indicate whether gemm should transpose the
- // lhs and rhs operand. hlo_instruction is as in Thunk.
+ // Constructs a thunk that computes "output = (lhs <dot> rhs) * alpha" using
+ // BLAS gemm. transpose_lhs and transpose_rhs indicate whether gemm should
+ // transpose the lhs and rhs operand. hlo_instruction is as in Thunk. alpha is
+ // a constant.
GemmThunk(const BufferAllocation::Slice& lhs_buffer,
const BufferAllocation::Slice& rhs_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& lhs_shape, const Shape& rhs_shape,
const Shape& output_shape, bool transpose_lhs, bool transpose_rhs,
- const HloInstruction* hlo_instruction);
+ double alpha, const HloInstruction* hlo_instruction);
GemmThunk(const GemmThunk&) = delete;
GemmThunk& operator=(const GemmThunk&) = delete;
@@ -72,6 +73,7 @@ class GemmThunk : public Thunk {
const bool transpose_lhs_;
const bool transpose_rhs_;
+ const double alpha_;
// Maps device names (StreamExecutor::DeviceDescription::name()) to autotune
// results. The map's value is the best algorithm we've found for this thunk
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index f6576cd8e0..85ecbe8fdb 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -52,6 +52,34 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
int64 operand_index) {
HloInstruction* producer = consumer->mutable_operand(operand_index);
+ // Check if we can use output fusion for (A @ B) * alpha
+ if (producer->opcode() == HloOpcode::kDot) {
+ if (consumer->opcode() == HloOpcode::kMultiply) {
+ CHECK_EQ(consumer->operand_count(), 2);
+ int64 other_operand_index = 1 - operand_index;
+ const HloInstruction* alpha = consumer->operand(other_operand_index);
+ if (alpha->opcode() == HloOpcode::kConstant &&
+ ShapeUtil::IsScalar(alpha->shape())) {
+ return true;
+ }
+ }
+ }
+
+ // Only allow to fuse transpose into an output fusion.
+ if (consumer->opcode() == HloOpcode::kFusion &&
+ consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) {
+ if (producer->opcode() != HloOpcode::kTranspose) {
+ return false;
+ }
+ // Check that the transpose is the operand of a dot.
+ auto producer_operand_index = consumer->operand_index(producer);
+ auto fused_parameter = consumer->fused_parameter(producer_operand_index);
+ const std::vector<HloInstruction*>& fused_parameter_users =
+ fused_parameter->users();
+ return (fused_parameter_users.size() == 1 &&
+ fused_parameter_users[0]->opcode() == HloOpcode::kDot);
+ }
+
// Output fusion is not currently supported on GPUs.
if (producer->opcode() == HloOpcode::kFusion) {
return false;
@@ -93,6 +121,9 @@ HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
if (IsReductionToVector(*consumer)) {
return HloInstruction::FusionKind::kInput;
}
+ if (producer->opcode() == HloOpcode::kDot) {
+ return HloInstruction::FusionKind::kOutput;
+ }
if (HloOpcode::kFusion == consumer->opcode()) {
return consumer->fusion_kind();
}
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index f383d19035..4b231c449f 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -228,5 +228,30 @@ TEST_F(InstructionFusionTest, DontFuseGTE) {
.ValueOrDie());
}
+TEST_F(InstructionFusionTest, DotOutputFusion) {
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ constant = f32[] constant(3)
+ p0 = f32[4,3]{1,0} parameter(0)
+ p1 = f32[4,3]{1,0} parameter(1)
+ transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0}
+ dot = f32[4,4]{1,0} dot(p0, transpose)
+ ROOT mul = f32[4,4] multiply(constant, dot)
+ })")
+ .ValueOrDie();
+
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_THAT(
+ root->fused_expression_root(),
+ op::Multiply(op::Parameter(),
+ op::Dot(op::Parameter(), op::Transpose(op::Parameter()))));
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 1b89dfa7ae..32413f975a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -89,6 +89,19 @@ bool ImplementedAsGemm(const HloInstruction& hlo) {
return true;
}
+ if (hlo.opcode() == HloOpcode::kFusion &&
+ hlo.fusion_kind() == HloInstruction::FusionKind::kOutput &&
+ hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply) {
+ // Try to find the dot inside the output fusion node.
+ const HloInstruction* dot = hlo.fused_expression_root()->operand(0);
+ if (dot->opcode() != HloOpcode::kDot) {
+ dot = hlo.fused_expression_root()->operand(1);
+ }
+ if (dot->opcode() == HloOpcode::kDot) {
+ return ImplementedAsGemm(*dot);
+ }
+ }
+
return false;
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 4cfb613ae9..2381d7a7d5 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2188,31 +2188,63 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
inst->shape(), // The shape of the output.
false, // Do not transpose LHS.
false, // Do not transpose RHS.
+ 1.0, // alpha.
inst);
}
if (inst->opcode() == HloOpcode::kFusion) {
- const HloInstruction* dot = inst->fused_expression_root();
- DCHECK(dot->opcode() == HloOpcode::kDot);
- const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
- const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
- DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
- rhs_parameter->opcode() == HloOpcode::kParameter);
- const HloInstruction* lhs =
- inst->operand(lhs_parameter->parameter_number());
- const HloInstruction* rhs =
- inst->operand(rhs_parameter->parameter_number());
-
- return MakeUnique<GemmThunk>(
- GetAllocationSlice(*lhs), // The buffer assigned to LHS.
- GetAllocationSlice(*rhs), // The buffer assigned to RHS.
- GetAllocationSlice(*inst), // The output buffer.
- lhs->shape(), // The shape of LHS.
- rhs->shape(), // The shape of RHS.
- inst->shape(), // The shape of the output.
- dot->operand(0)->IsRank2Transpose(), // Transpose LHS.
- dot->operand(1)->IsRank2Transpose(), // Trasnpose RHS.
- inst);
+ if (inst->fusion_kind() == HloInstruction::FusionKind::kOutput) {
+ const HloInstruction* mul = inst->fused_expression_root();
+ const HloInstruction* dot = mul->operand(0);
+ const HloInstruction* alpha = mul->operand(1);
+ if (dot->opcode() != HloOpcode::kDot) {
+ std::swap(dot, alpha);
+ }
+ DCHECK(dot->opcode() == HloOpcode::kDot);
+ const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
+ const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
+ DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
+ rhs_parameter->opcode() == HloOpcode::kParameter);
+ const HloInstruction* lhs =
+ inst->operand(lhs_parameter->parameter_number());
+ const HloInstruction* rhs =
+ inst->operand(rhs_parameter->parameter_number());
+
+ return MakeUnique<GemmThunk>(
+ GetAllocationSlice(*lhs), // The buffer assigned to LHS.
+ GetAllocationSlice(*rhs), // The buffer assigned to RHS.
+ GetAllocationSlice(*mul), // The output buffer.
+ lhs->shape(), // The shape of LHS.
+ rhs->shape(), // The shape of RHS.
+ inst->shape(), // The shape of the output.
+ dot->operand(0)->IsRank2Transpose(), // Transpose LHS.
+ dot->operand(1)->IsRank2Transpose(), // Transpose RHS.
+ alpha->literal().Get<double>({0}), // alpha.
+ inst);
+ } else {
+ const HloInstruction* dot = inst->fused_expression_root();
+ DCHECK(dot->opcode() == HloOpcode::kDot);
+ const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
+ const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
+ DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
+ rhs_parameter->opcode() == HloOpcode::kParameter);
+ const HloInstruction* lhs =
+ inst->operand(lhs_parameter->parameter_number());
+ const HloInstruction* rhs =
+ inst->operand(rhs_parameter->parameter_number());
+
+ return MakeUnique<GemmThunk>(
+ GetAllocationSlice(*lhs), // The buffer assigned to LHS.
+ GetAllocationSlice(*rhs), // The buffer assigned to RHS.
+ GetAllocationSlice(*inst), // The output buffer.
+ lhs->shape(), // The shape of LHS.
+ rhs->shape(), // The shape of RHS.
+ inst->shape(), // The shape of the output.
+ dot->operand(0)->IsRank2Transpose(), // Transpose LHS.
+ dot->operand(1)->IsRank2Transpose(), // Transpose RHS.
+ 1.0, // Alpha.
+ inst);
+ }
}
LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString();
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index 822e2f1f53..4205b0402c 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -40,7 +40,7 @@ void HloModuleConfig::SetDefaultComputationLayout(
string HloModuleConfig::compilation_cache_key() const {
string key =
- tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_);
+ tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled());
StrAppend(&key, "::(");
std::vector<string> params;
for (const ShapeLayout& param_layout :
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index d3c1fae592..586a03d412 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -63,9 +63,10 @@ class HloModuleConfig {
return &(*entry_computation_layout_);
}
- // Sets/returns whether to enable HLO-level profiling.
- bool hlo_profiling_enabled() const { return hlo_profiling_enabled_; }
- void enable_hlo_profiling(bool enabled) { hlo_profiling_enabled_ = enabled; }
+ // Returns whether to enable HLO-level profiling.
+ bool hlo_profiling_enabled() const {
+ return debug_options_.xla_hlo_profile();
+ }
// Sets/returns whether this is a "host module". Host modules are used to
// record the data- and control-flow dependencies of host side computation
@@ -110,9 +111,6 @@ class HloModuleConfig {
tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_;
- // Whether to enable HLO-level profiling.
- bool hlo_profiling_enabled_ = false;
-
// Whether this is a 'host module'.
bool is_host_module_ = false;
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index d65befaf84..e5b1c2efa3 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -158,8 +158,8 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<ShapedBuffer> result,
- executable->ExecuteOnStream(&service_run_options, argument_buffer_ptrs,
- /*hlo_execution_profile=*/nullptr));
+ executable->ExecuteOnStreamWrapper(
+ &service_run_options, /*profile=*/nullptr, argument_buffer_ptrs));
// Create a ScopedShapedBuffer of the result to manage deallocation. This will
// deallocate all the device memory when it goes out of scope.
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 8edd457281..0becc9d8f8 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -314,8 +314,6 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
if (execution_options != nullptr) {
config->set_seed(execution_options->seed());
config->set_debug_options(execution_options->debug_options());
- config->enable_hlo_profiling(
- execution_options->debug_options().xla_hlo_profile());
} else {
config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
}
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index 88956f0512..c6feec68e0 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -57,26 +57,6 @@ cuda_py_test(
)
cuda_py_test(
- name = "custom_grad_test",
- size = "small",
- srcs = ["python/kernel_tests/custom_grad_test.py"],
- additional_deps = [
- ":bayesflow_py",
- "//third_party/py/numpy",
- "//tensorflow/contrib/layers:layers_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:gradients",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:platform_test",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- ],
-)
-
-cuda_py_test(
name = "monte_carlo_test",
size = "small",
srcs = ["python/kernel_tests/monte_carlo_test.py"],
diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py
index 89dfa583a4..f868203826 100644
--- a/tensorflow/contrib/bayesflow/__init__.py
+++ b/tensorflow/contrib/bayesflow/__init__.py
@@ -21,7 +21,6 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long
-from tensorflow.contrib.bayesflow.python.ops import custom_grad
from tensorflow.contrib.bayesflow.python.ops import hmc
from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings
from tensorflow.contrib.bayesflow.python.ops import monte_carlo
@@ -31,7 +30,6 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
- 'custom_grad',
'entropy',
'hmc',
'metropolis_hastings',
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py
deleted file mode 100644
index a95df31ac1..0000000000
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for Custom Gradient Ops."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.bayesflow.python.ops import custom_grad_impl
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-
-
-cg = custom_grad_impl
-
-
-class CustomGradientTest(test.TestCase):
-
- def test_works_correctly(self):
- with self.test_session() as sess:
- f = lambda x: x**2 / 2
- g = lambda x: (x - 1)**3 / 3
- x_ = np.linspace(-100, 100, int(1e4)) + [0.]
-
- x = constant_op.constant(x_)
- fx = cg.custom_gradient(f(x), g(x), x)
- gx = gradients_impl.gradients(fx, x)[0]
- [fx_, gx_] = sess.run([fx, gx])
-
- self.assertAllClose(f(x_), fx_)
- self.assertAllClose(g(x_), gx_)
-
- def test_works_correctly_both_f_g_zero(self):
- with self.test_session() as sess:
- f = lambda x: x**2 / 2
- g = lambda x: x**3 / 3
- x_ = np.linspace(-100, 100, int(1e4)) + [0.]
-
- x = constant_op.constant(x_)
- fx = cg.custom_gradient(f(x), g(x), x)
- gx = gradients_impl.gradients(fx, x)[0]
- [fx_, gx_] = sess.run([fx, gx])
-
- self.assertAllClose(f(x_), fx_)
- self.assertAllClose(g(x_), gx_)
-
- def test_works_correctly_vector_of_vars(self):
- with self.test_session() as sess:
- x = variable_scope.get_variable(
- name="x",
- shape=[],
- dtype=dtypes.float32,
- initializer=init_ops.constant_initializer(2))
- y = variable_scope.get_variable(
- name="y",
- shape=[],
- dtype=dtypes.float32,
- initializer=init_ops.constant_initializer(3))
- sess.run([variables.global_variables_initializer()])
-
- f = lambda z: z[0] * z[1]
- g = lambda z: z[0]**2 * z[1]**2 / 2
-
- z = array_ops.stack([x, y])
- fz = cg.custom_gradient(f(z), g(z), z, axis=0)
- gz = gradients_impl.gradients(fz, variables.trainable_variables())
- [z_, fz_, gx_, gy_] = sess.run([z, fz, gz[0], gz[1]])
-
- self.assertEqual(f(z_), fz_)
- self.assertEqual(g(z_), gx_)
- self.assertEqual(g(z_), gy_)
-
- def test_works_correctly_side_vars(self):
- with self.test_session() as sess:
- x_ = np.float32(2.1) # Adding extra tenth to force imprecision.
- y_ = np.float32(3.1)
- x = variable_scope.get_variable(
- name="x",
- shape=[],
- dtype=dtypes.float32,
- initializer=init_ops.constant_initializer(x_))
- y = variable_scope.get_variable(
- name="y",
- shape=[],
- dtype=dtypes.float32,
- initializer=init_ops.constant_initializer(y_))
- sess.run([variables.global_variables_initializer()])
-
- f = lambda x: x * y
- g = lambda z: math_ops.square(x) * y
-
- fx = cg.custom_gradient(f(x), g(x), x)
- gx = gradients_impl.gradients(fx, variables.trainable_variables())
- [x_, fx_, gx_] = sess.run([x, fx, gx[0]])
- gy_ = gx[1]
-
- self.assertEqual(x_ * y_, fx_)
- self.assertEqual(np.square(x_) * y_, gx_)
- self.assertEqual(None, gy_)
-
- def test_works_correctly_fx_gx_manually_stopped(self):
- with self.test_session() as sess:
- x_ = np.float32(2.1) # Adding extra tenth to force imprecision.
- y_ = np.float32(3.1)
- x = variable_scope.get_variable(
- name="x",
- shape=[],
- dtype=dtypes.float32,
- initializer=init_ops.constant_initializer(x_))
- y = variable_scope.get_variable(
- name="y",
- shape=[],
- dtype=dtypes.float32,
- initializer=init_ops.constant_initializer(y_))
- sess.run([variables.global_variables_initializer()])
-
- stop = array_ops.stop_gradient # For readability.
-
- # Basically we need to stop the `x` portion of `f`. And when we supply the
- # arg to `custom_gradient` we need to stop the complement, i.e., the `y`
- # part.
- f = lambda x: stop(x) * y
- g = lambda x: stop(math_ops.square(x)) * y
- fx = cg.custom_gradient(f(x), g(x), x + stop(y),
- fx_gx_manually_stopped=True)
-
- gx = gradients_impl.gradients(fx, variables.trainable_variables())
- [x_, fx_, gx_, gy_] = sess.run([x, fx, gx[0], gx[1]])
-
- self.assertEqual(x_ * y_, fx_)
- self.assertEqual(np.square(x_) * y_, gx_)
- self.assertEqual(x_, gy_)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad.py
deleted file mode 100644
index c8218c57cc..0000000000
--- a/tensorflow/contrib/bayesflow/python/ops/custom_grad.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Functions for specifying custom gradients.
-
-See @{tf.contrib.bayesflow.custom_grad.custom_gradient}.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# go/tf-wildcard-import
-# pylint: disable=wildcard-import
-from tensorflow.contrib.bayesflow.python.ops.custom_grad_impl import *
-# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
- 'custom_gradient',
-]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py
deleted file mode 100644
index d44fe6529a..0000000000
--- a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py
+++ /dev/null
@@ -1,110 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Functions for specifying custom gradients.
-
-@@custom_gradient
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-
-__all__ = [
- "custom_gradient",
-]
-
-
-def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False,
- name=None):
- """Enables specifying a custom gradient.
-
- This function works by clever application of `stop_gradient`. I.e., observe
- that:
-
- ```none
- h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x))
- ```
-
- is such that `h(x) = stop_gradient(f(x))` and `grad[h(x), x] =
- stop_gradient(g(x)).`
-
- In addition to scalar-domain/scalar-range functions, this function also
- supports tensor-domain/scalar-range functions. However, in the latter case it
- is necessary to reduce `x` to a scalar. This can be done by indicating the
- `axis` over which `f` operates or by appropriately `reduce_sum`-ing `x`, prior
- to calling this function.
-
- Partial Custom Gradient:
-
- Suppose `h(x) = htilde(x, y)`. Note that `dh/dx = stop(g(x))` but `dh/dy =
- None`. This is because a `Tensor` cannot have only a portion of its gradient
- stopped. To circumvent this issue, one must manually `stop_gradient` the
- relevant portions of `f`, `g`. For example see the unit-test,
- `test_works_correctly_fx_gx_manually_stopped`.
-
- Args:
- fx: `Tensor`. Output of function evaluated at `x`.
- gx: `Tensor`. Gradient of function evaluated at `x`.
- x: `Tensor`. Point of evaluation for `f, g`.
- axis: 1D `int` `Tensor` representing dimensions of `x` which are the domain
- of `f`. If `()` (the default), `f` is assumed scalar-domain/scalar-range.
- If `None` `f` is assumed to render one scalar given all of `x`. Otherwise
- `f` is assumed to output one scalar for each of `axis` dimensions of `x`.
- fx_gx_manually_stopped: Python `bool` indicating that `fx`, `gx` manually
- have `stop_gradient` applied.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- fx: Floating-type `Tensor` equal to `f(x)` but which has gradient
- `stop_gradient(g(x))`.
- """
- with ops.name_scope(name, "custom_gradient", [fx, gx, x]):
- fx = ops.convert_to_tensor(fx, name="fx")
- # We don't want to bother eagerly computing `gx` since we may not even need
- # it.
- with ops.control_dependencies([fx]):
- gx = ops.convert_to_tensor(gx, dtype=fx.dtype, name="gx")
- gx = array_ops.identity(gx, name="gx")
- # Proof of correctness:
- #
- # f(x) = x * stop[gx] + stop[fx - x * gx]
- # = stop[fx]
- #
- # g(x) = grad[fx]
- # = stop[gx] + grad[stop[fx - x * gx]]
- # = stop[gx] + 0
- #
- # Notice that when x is zero it still works:
- # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx]
- #
- # The proof is similar for the tensor-domain case, except that `x` is
- # replaced by `reduce_sum(x)`.
- sum_x = math_ops.reduce_sum(x, axis=axis, name="sum_x")
- if not fx_gx_manually_stopped:
- fx = array_ops.stop_gradient(fx)
- gx = array_ops.stop_gradient(gx)
- # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to write
- # the code this way, rather than, e.g.,
- # `sum_x * stop(gx) + stop(fx - sum_x * gx)`.
- # For more discussion regarding the relevant portions of the IEEE754
- # standard, see the StackOverflow question,
- # "Is there a floating point value of x, for which x-x == 0 is false?"
- # http://stackoverflow.com/q/2686644
- return (sum_x - array_ops.stop_gradient(sum_x)) * gx + fx
diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc
index ae99d53a2c..6aa5246398 100644
--- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc
@@ -272,6 +272,20 @@ REGISTER_OP("Quantiles")
.Input("sparse_indices: num_sparse_features * int64")
.Output("dense_quantiles: num_dense_features * int32")
.Output("sparse_quantiles: num_sparse_features * int32")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_dense_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_dense_features", &num_dense_features));
+ int num_sparse_features;
+ TF_RETURN_IF_ERROR(
+ c->GetAttr("num_sparse_features", &num_sparse_features));
+ // Set output shapes (dense_quantiles and sparse_quantiles) by the
+ // relevant inputs (dense_values and sparse_values). Note that the output
+ // has an additional dimension for dimension_ids.
+ for (int i = 0; i < num_dense_features + num_sparse_features; ++i) {
+ c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 2}));
+ }
+ return Status::OK();
+ })
.Doc(R"doc(
Computes quantile for each a given list of dense and sparse feature values using
the given buckets.
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index e463ef2fb4..1bd73ee704 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -818,6 +818,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "affine_scalar_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/affine_scalar_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "affine_linear_operator_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/affine_linear_operator_test.py"],
@@ -1167,6 +1186,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "square_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/square_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "weibull_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/weibull_test.py"],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py
new file mode 100644
index 0000000000..16173a166f
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py
@@ -0,0 +1,153 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Affine Scalar Tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops.bijectors.affine_scalar import AffineScalar
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
+from tensorflow.python.platform import test
+
+
+class AffineScalarBijectorTest(test.TestCase):
+ """Tests correctness of the Y = scale @ x + shift transformation."""
+
+ def testProperties(self):
+ with self.test_session():
+ mu = -1.
+ # scale corresponds to 1.
+ bijector = AffineScalar(shift=mu)
+ self.assertEqual("affine_scalar", bijector.name)
+
+ def testNoBatchScalar(self):
+ with self.test_session() as sess:
+
+ def static_run(fun, x):
+ return fun(x).eval()
+
+ def dynamic_run(fun, x_value):
+ x_value = np.array(x_value)
+ x = array_ops.placeholder(dtypes.float32, name="x")
+ return sess.run(fun(x), feed_dict={x: x_value})
+
+ for run in (static_run, dynamic_run):
+ mu = -1.
+ # Corresponds to scale = 2
+ bijector = AffineScalar(shift=mu, scale=2.)
+ x = [1., 2, 3] # Three scalar samples (no batches).
+ self.assertAllClose([1., 3, 5], run(bijector.forward, x))
+ self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
+ self.assertAllClose([-np.log(2.)] * 3,
+ run(bijector.inverse_log_det_jacobian, x))
+
+ def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self):
+ with self.test_session() as sess:
+
+ def static_run(fun, x):
+ return fun(x).eval()
+
+ def dynamic_run(fun, x_value):
+ x_value = np.array(x_value).astype(np.float64)
+ x = array_ops.placeholder(dtypes.float64, name="x")
+ return sess.run(fun(x), feed_dict={x: x_value})
+
+ for run in (static_run, dynamic_run):
+ mu = np.float64([1.])
+ # One batch, scalar.
+ # Corresponds to scale = 1.
+ bijector = AffineScalar(shift=mu)
+ x = np.float64([1.]) # One sample from one batches.
+ self.assertAllClose([2.], run(bijector.forward, x))
+ self.assertAllClose([0.], run(bijector.inverse, x))
+ self.assertAllClose([0.], run(bijector.inverse_log_det_jacobian, x))
+
+ def testOneBatchScalarViaIdentityIn64BitUserProvidesScaleOnly(self):
+ with self.test_session() as sess:
+
+ def static_run(fun, x):
+ return fun(x).eval()
+
+ def dynamic_run(fun, x_value):
+ x_value = np.array(x_value).astype(np.float64)
+ x = array_ops.placeholder(dtypes.float64, name="x")
+ return sess.run(fun(x), feed_dict={x: x_value})
+
+ for run in (static_run, dynamic_run):
+ multiplier = np.float64([2.])
+ # One batch, scalar.
+ # Corresponds to scale = 2, shift = 0.
+ bijector = AffineScalar(scale=multiplier)
+ x = np.float64([1.]) # One sample from one batches.
+ self.assertAllClose([2.], run(bijector.forward, x))
+ self.assertAllClose([0.5], run(bijector.inverse, x))
+ self.assertAllClose([np.log(0.5)],
+ run(bijector.inverse_log_det_jacobian, x))
+
+ def testTwoBatchScalarIdentityViaIdentity(self):
+ with self.test_session() as sess:
+
+ def static_run(fun, x):
+ return fun(x).eval()
+
+ def dynamic_run(fun, x_value):
+ x_value = np.array(x_value)
+ x = array_ops.placeholder(dtypes.float32, name="x")
+ return sess.run(fun(x), feed_dict={x: x_value})
+
+ for run in (static_run, dynamic_run):
+ mu = [1., -1]
+ # Univariate, two batches.
+ # Corresponds to scale = 1.
+ bijector = AffineScalar(shift=mu)
+ x = [1., 1] # One sample from each of two batches.
+ self.assertAllClose([2., 0], run(bijector.forward, x))
+ self.assertAllClose([0., 2], run(bijector.inverse, x))
+ self.assertAllClose([0., 0.], run(bijector.inverse_log_det_jacobian, x))
+
+ def testTwoBatchScalarIdentityViaScale(self):
+ with self.test_session() as sess:
+
+ def static_run(fun, x):
+ return fun(x).eval()
+
+ def dynamic_run(fun, x_value):
+ x_value = np.array(x_value)
+ x = array_ops.placeholder(dtypes.float32, name="x")
+ return sess.run(fun(x), feed_dict={x: x_value})
+
+ for run in (static_run, dynamic_run):
+ mu = [1., -1]
+ # Univariate, two batches.
+ # Corresponds to scale = 1.
+ bijector = AffineScalar(shift=mu, scale=[2., 1])
+ x = [1., 1] # One sample from each of two batches.
+ self.assertAllClose([3., 0], run(bijector.forward, x))
+ self.assertAllClose([0., 2], run(bijector.inverse, x))
+ self.assertAllClose(
+ [-np.log(2), 0.], run(bijector.inverse_log_det_jacobian, x))
+
+ def testScalarCongruency(self):
+ with self.test_session():
+ bijector = AffineScalar(shift=3.6, scale=0.42)
+ assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
index c9158117f7..077e6176b4 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
@@ -25,7 +25,6 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
@@ -36,192 +35,9 @@ class AffineBijectorTest(test.TestCase):
with self.test_session():
mu = -1.
# scale corresponds to 1.
- bijector = Affine(shift=mu, event_ndims=0)
+ bijector = Affine(shift=mu)
self.assertEqual("affine", bijector.name)
- def testNoBatchScalarViaIdentity(self):
- with self.test_session() as sess:
-
- def static_run(fun, x):
- return fun(x).eval()
-
- def dynamic_run(fun, x_value):
- x_value = np.array(x_value)
- x = array_ops.placeholder(dtypes.float32, name="x")
- return sess.run(fun(x), feed_dict={x: x_value})
-
- for run in (static_run, dynamic_run):
- mu = -1.
- # Corresponds to scale = 2
- bijector = Affine(
- shift=mu, scale_identity_multiplier=2., event_ndims=0)
- self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
- x = [1., 2, 3] # Three scalar samples (no batches).
- self.assertAllClose([1., 3, 5], run(bijector.forward, x))
- self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
- self.assertAllClose(-np.log(2.),
- run(bijector.inverse_log_det_jacobian, x))
-
- def testNoBatchScalarViaDiag(self):
- with self.test_session() as sess:
-
- def static_run(fun, x):
- return fun(x).eval()
-
- def dynamic_run(fun, x_value):
- x_value = np.array(x_value)
- x = array_ops.placeholder(dtypes.float32, name="x")
- return sess.run(fun(x), feed_dict={x: x_value})
-
- for run in (static_run, dynamic_run):
- mu = -1.
- # Corresponds to scale = 2
- bijector = Affine(shift=mu, scale_identity_multiplier=2., event_ndims=0)
- self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
- x = [1., 2, 3] # Three scalar samples (no batches).
- self.assertAllClose([1., 3, 5], run(bijector.forward, x))
- self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
- self.assertAllClose(-np.log(2.),
- run(bijector.inverse_log_det_jacobian, x))
-
- def testWeirdSampleNoBatchScalarViaDiagMultiplier(self):
- with self.test_session() as sess:
-
- def static_run(fun, x):
- return fun(x).eval()
-
- def dynamic_run(fun, x_value):
- x_value = np.array(x_value)
- x = array_ops.placeholder(dtypes.float32, name="x")
- return sess.run(fun(x), feed_dict={x: x_value})
-
- for run in (static_run, dynamic_run):
- mu = -1.
- # Corresponds to scale = 2.
- bijector = Affine(
- shift=mu, scale_identity_multiplier=2., event_ndims=0)
- self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
- x = [[1., 2, 3], [4, 5, 6]] # Weird sample shape.
- self.assertAllClose([[1., 3, 5],
- [7, 9, 11]],
- run(bijector.forward, x))
- self.assertAllClose([[1., 1.5, 2.],
- [2.5, 3, 3.5]],
- run(bijector.inverse, x))
- self.assertAllClose(-np.log(2.),
- run(bijector.inverse_log_det_jacobian, x))
-
- def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self):
- with self.test_session() as sess:
-
- def static_run(fun, x):
- return fun(x).eval()
-
- def dynamic_run(fun, x_value):
- x_value = np.array(x_value).astype(np.float64)
- x = array_ops.placeholder(dtypes.float64, name="x")
- return sess.run(fun(x), feed_dict={x: x_value})
-
- for run in (static_run, dynamic_run):
- mu = np.float64([1.])
- # One batch, scalar.
- # Corresponds to scale = 1.
- bijector = Affine(shift=mu, event_ndims=0)
- self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
- x = np.float64([1.]) # One sample from one batches.
- self.assertAllClose([2.], run(bijector.forward, x))
- self.assertAllClose([0.], run(bijector.inverse, x))
- self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x))
-
- def testOneBatchScalarViaIdentityIn64BitUserProvidesMultiplierOnly(self):
- with self.test_session() as sess:
-
- def static_run(fun, x):
- return fun(x).eval()
-
- def dynamic_run(fun, x_value):
- x_value = np.array(x_value).astype(np.float64)
- x = array_ops.placeholder(dtypes.float64, name="x")
- return sess.run(fun(x), feed_dict={x: x_value})
-
- for run in (static_run, dynamic_run):
- multiplier = np.float64([2.])
- # One batch, scalar.
- # Corresponds to scale = 2, shift = 0.
- bijector = Affine(scale_identity_multiplier=multiplier, event_ndims=0)
- self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
- x = np.float64([1.]) # One sample from one batches.
- self.assertAllClose([2.], run(bijector.forward, x))
- self.assertAllClose([0.5], run(bijector.inverse, x))
- self.assertAllClose([np.log(0.5)],
- run(bijector.inverse_log_det_jacobian, x))
-
- def testOneBatchScalarViaDiagMultiplier(self):
- with self.test_session() as sess:
-
- def static_run(fun, x):
- return fun(x).eval()
-
- def dynamic_run(fun, x_value):
- x_value = np.array(x_value)
- x = array_ops.placeholder(dtypes.float32, name="x")
- return sess.run(fun(x), feed_dict={x: x_value})
-
- for run in (static_run, dynamic_run):
- mu = [1.]
- # One batch, scalar.
- # Corresponds to scale = 1.
- bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0)
- self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
- x = [1.] # One sample from one batches.
- self.assertAllClose([2.], run(bijector.forward, x))
- self.assertAllClose([0.], run(bijector.inverse, x))
- self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x))
-
- def testTwoBatchScalarIdentityViaIdentity(self):
- with self.test_session() as sess:
-
- def static_run(fun, x):
- return fun(x).eval()
-
- def dynamic_run(fun, x_value):
- x_value = np.array(x_value)
- x = array_ops.placeholder(dtypes.float32, name="x")
- return sess.run(fun(x), feed_dict={x: x_value})
-
- for run in (static_run, dynamic_run):
- mu = [1., -1]
- # Univariate, two batches.
- # Corresponds to scale = 1.
- bijector = Affine(shift=mu, event_ndims=0)
- self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
- x = [1., 1] # One sample from each of two batches.
- self.assertAllClose([2., 0], run(bijector.forward, x))
- self.assertAllClose([0., 2], run(bijector.inverse, x))
- self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x))
-
- def testTwoBatchScalarIdentityViaDiagMultiplier(self):
- with self.test_session() as sess:
-
- def static_run(fun, x):
- return fun(x).eval()
-
- def dynamic_run(fun, x_value):
- x_value = np.array(x_value)
- x = array_ops.placeholder(dtypes.float32, name="x")
- return sess.run(fun(x), feed_dict={x: x_value})
-
- for run in (static_run, dynamic_run):
- mu = [1., -1]
- # Univariate, two batches.
- # Corresponds to scale = 1.
- bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0)
- self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar"
- x = [1., 1] # One sample from each of two batches.
- self.assertAllClose([2., 0], run(bijector.forward, x))
- self.assertAllClose([0., 2], run(bijector.inverse, x))
- self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x))
-
def testNoBatchMultivariateIdentity(self):
with self.test_session() as sess:
@@ -238,7 +54,6 @@ class AffineBijectorTest(test.TestCase):
# Multivariate
# Corresponds to scale = [[1., 0], [0, 1.]]
bijector = Affine(shift=mu)
- self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 1]
# matmul(sigma, x) + shift
# = [-1, -1] + [1, -1]
@@ -269,7 +84,6 @@ class AffineBijectorTest(test.TestCase):
# Multivariate
# Corresponds to scale = [[2., 0], [0, 1.]]
bijector = Affine(shift=mu, scale_diag=[2., 1])
- self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 1]
# matmul(sigma, x) + shift
# = [-1, -1] + [1, -1]
@@ -297,22 +111,17 @@ class AffineBijectorTest(test.TestCase):
x = array_ops.placeholder(dtypes.float32, name="x")
mu = array_ops.placeholder(dtypes.float32, name="mu")
scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag")
- event_ndims = array_ops.placeholder(dtypes.int32, name="event_ndims")
x_value = np.array([[1., 1]], dtype=np.float32)
mu_value = np.array([1., -1], dtype=np.float32)
scale_diag_value = np.array([2., 2], dtype=np.float32)
- event_ndims_value = np.array(1, dtype=np.int32)
feed_dict = {
x: x_value,
mu: mu_value,
scale_diag: scale_diag_value,
- event_ndims: event_ndims_value
}
- bijector = Affine(
- shift=mu, scale_diag=scale_diag, event_ndims=event_ndims)
- self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict))
+ bijector = Affine(shift=mu, scale_diag=scale_diag)
self.assertAllClose([[3., 1]], sess.run(bijector.forward(x), feed_dict))
self.assertAllClose([[0., 1]], sess.run(bijector.inverse(x), feed_dict))
self.assertAllClose(
@@ -335,7 +144,6 @@ class AffineBijectorTest(test.TestCase):
# Corresponds to 1 2x2 matrix, with twos on the diagonal.
scale = 2.
bijector = Affine(shift=mu, scale_identity_multiplier=scale)
- self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [[[1., 1]]]
self.assertAllClose([[[3., 1]]], run(bijector.forward, x))
self.assertAllClose([[[0., 1]]], run(bijector.inverse, x))
@@ -358,7 +166,6 @@ class AffineBijectorTest(test.TestCase):
# Corresponds to 1 2x2 matrix, with twos on the diagonal.
scale_diag = [[2., 2]]
bijector = Affine(shift=mu, scale_diag=scale_diag)
- self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [[[1., 1]]]
self.assertAllClose([[[3., 1]]], run(bijector.forward, x))
self.assertAllClose([[[0., 1]]], run(bijector.inverse, x))
@@ -370,23 +177,18 @@ class AffineBijectorTest(test.TestCase):
x = array_ops.placeholder(dtypes.float32, name="x")
mu = array_ops.placeholder(dtypes.float32, name="mu")
scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag")
- event_ndims = array_ops.placeholder(dtypes.int32, name="event_ndims")
x_value = np.array([[[1., 1]]], dtype=np.float32)
mu_value = np.array([[1., -1]], dtype=np.float32)
scale_diag_value = np.array([[2., 2]], dtype=np.float32)
- event_ndims_value = 1
feed_dict = {
x: x_value,
mu: mu_value,
scale_diag: scale_diag_value,
- event_ndims: event_ndims_value
}
- bijector = Affine(
- shift=mu, scale_diag=scale_diag, event_ndims=event_ndims)
- self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict))
+ bijector = Affine(shift=mu, scale_diag=scale_diag)
self.assertAllClose([[[3., 1]]], sess.run(bijector.forward(x), feed_dict))
self.assertAllClose([[[0., 1]]], sess.run(bijector.inverse(x), feed_dict))
self.assertAllClose([-np.log(4)],
@@ -410,9 +212,7 @@ class AffineBijectorTest(test.TestCase):
bijector = Affine(
shift=mu,
scale_identity_multiplier=1.,
- scale_diag=[1., 1., 1.],
- event_ndims=1)
- self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
+ scale_diag=[1., 1., 1.])
x = [1., 2, 3] # Three scalar samples (no batches).
self.assertAllClose([1., 3, 5], run(bijector.forward, x))
self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
@@ -437,7 +237,6 @@ class AffineBijectorTest(test.TestCase):
shift=mu,
scale_identity_multiplier=1.,
scale_tril=[[1., 0], [2., 1]])
- self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [[1., 2]] # One multivariate sample.
self.assertAllClose([[1., 5]], run(bijector.forward, x))
self.assertAllClose([[1., 0.5]], run(bijector.inverse, x))
@@ -460,7 +259,6 @@ class AffineBijectorTest(test.TestCase):
# scale = [[2., 0], [2, 3]]
bijector = Affine(
shift=mu, scale_diag=[1., 2.], scale_tril=[[1., 0], [2., 1]])
- self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [[1., 2]] # One multivariate sample.
self.assertAllClose([[1., 7]], run(bijector.forward, x))
self.assertAllClose([[1., 1 / 3.]], run(bijector.inverse, x))
@@ -486,7 +284,6 @@ class AffineBijectorTest(test.TestCase):
scale_identity_multiplier=1.0,
scale_diag=[1., 2.],
scale_tril=[[1., 0], [2., 1]])
- self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [[1., 2]] # One multivariate sample.
self.assertAllClose([[2., 9]], run(bijector.forward, x))
self.assertAllClose([[2 / 3., 5 / 12.]], run(bijector.inverse, x))
@@ -514,7 +311,6 @@ class AffineBijectorTest(test.TestCase):
scale_perturb_factor=[[2., 0], [0., 0], [0, 1]])
bijector_ref = Affine(shift=mu, scale_diag=[10., 2, 3])
- self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 2, 3] # Vector.
self.assertAllClose([9., 3, 8], run(bijector.forward, x))
self.assertAllClose(
@@ -550,7 +346,6 @@ class AffineBijectorTest(test.TestCase):
scale_perturb_factor=[[2., 0], [0., 0], [0, 1]])
bijector_ref = Affine(shift=mu, scale_diag=[10., 3, 5])
- self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 2, 3] # Vector.
self.assertAllClose([9., 5, 14], run(bijector.forward, x))
self.assertAllClose(
@@ -586,7 +381,6 @@ class AffineBijectorTest(test.TestCase):
bijector_ref = Affine(
shift=mu, scale_tril=[[10., 0, 0], [1, 3, 0], [2, 3, 5]])
- self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 2, 3] # Vector.
self.assertAllClose([9., 6, 22], run(bijector.forward, x))
self.assertAllClose(
@@ -622,7 +416,6 @@ class AffineBijectorTest(test.TestCase):
bijector_ref = Affine(
shift=mu, scale_tril=[[6., 0, 0], [1, 3, 0], [2, 3, 5]])
- self.assertEqual(1, bijector.event_ndims.eval()) # "is vector"
x = [1., 2, 3] # Vector.
self.assertAllClose([5., 6, 22], run(bijector.forward, x))
self.assertAllClose(
@@ -647,38 +440,6 @@ class AffineBijectorTest(test.TestCase):
with self.assertRaisesOpError("diagonal part must be non-zero"):
bijector.forward([1., 1.]).eval()
- def testEventNdimsLargerThanOneRaises(self):
- with self.test_session():
- mu = [1., -1]
- with self.assertRaisesRegexp(
- ValueError, (r"event_ndims\(2\) was not 0 or 1")):
- # Scale corresponds to 2x2 identity matrix.
- bijector = Affine(shift=mu, event_ndims=2, validate_args=True)
- bijector.forward([1., 1.]).eval()
-
- def testScaleZeroScalarRaises(self):
- with self.test_session():
- mu = -1.
- # Check Identity matrix with zero scaling.
- bijector = Affine(
- shift=mu,
- scale_identity_multiplier=0.,
- event_ndims=0,
- validate_args=True)
- with self.assertRaisesOpError("identity_multiplier should be non-zero"):
- bijector.forward(1.).eval()
-
- def testScaleDiagAndEventNdimsZeroRaises(self):
- # Check Diag matrix with zero scaling.
- with self.assertRaisesRegexp(ValueError, "only scale argument"):
- Affine(shift=None, scale_diag=[0.0], event_ndims=0, validate_args=True)
-
- def testScalarCongruency(self):
- with self.test_session():
- bijector = Affine(
- shift=3.6, scale_identity_multiplier=0.42, event_ndims=0)
- assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.)
-
def _makeScale(self,
x,
scale_identity_multiplier=None,
@@ -747,14 +508,12 @@ class AffineBijectorTest(test.TestCase):
scale_args = dict({"x": x}, **args)
scale = self._makeScale(**scale_args)
- bijector_args = dict({"event_ndims": 1}, **args)
-
# We haven't specified enough information for the scale.
if scale is None:
with self.assertRaisesRegexp(ValueError, ("must be specified.")):
- bijector = Affine(shift=shift, **bijector_args)
+ bijector = Affine(shift=shift, **args)
else:
- bijector = Affine(shift=shift, **bijector_args)
+ bijector = Affine(shift=shift, **args)
np_x = x
# For the case a vector is passed in, we need to make the shape
# match the matrix for matmul to work.
@@ -829,15 +588,5 @@ class AffineBijectorTest(test.TestCase):
x=np.array(
[1., 2], dtype=np.float32))
- def testScalarEventIdentityScale(self):
- with self.test_session() as sess:
- doubler = Affine(
- scale_identity_multiplier=2.,
- event_ndims=0)
- doubler2 = doubler.inverse_log_det_jacobian(2.)
- doubler2_ildj_ = sess.run([doubler2])
- self.assertAllClose([-np.log(2.)], doubler2_ildj_)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
index ab2338f4cb..f392e83d2c 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
@@ -23,7 +23,6 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
@@ -32,8 +31,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
def testBijectorMatrix(self):
with self.test_session():
- bijector = bijectors.CholeskyOuterProduct(
- event_ndims=2, validate_args=True)
+ bijector = bijectors.CholeskyOuterProduct(validate_args=True)
self.assertEqual("cholesky_outer_product", bijector.name)
x = [[[1., 0], [2, 1]], [[np.sqrt(2.), 0], [np.sqrt(8.), 1]]]
y = np.matmul(x, np.transpose(x, axes=(0, 2, 1)))
@@ -60,39 +58,12 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
atol=0.,
rtol=1e-7)
- def testBijectorScalar(self):
- with self.test_session():
- bijector = bijectors.CholeskyOuterProduct(
- event_ndims=0, validate_args=True)
- self.assertEqual("cholesky_outer_product", bijector.name)
- x = [[[1., 5],
- [2, 1]],
- [[np.sqrt(2.), 3],
- [np.sqrt(8.), 1]]]
- y = np.square(x)
- ildj = -np.log(2.) - np.log(x)
- self.assertAllClose(y, bijector.forward(x).eval())
- self.assertAllClose(x, bijector.inverse(y).eval())
- self.assertAllClose(
- ildj, bijector.inverse_log_det_jacobian(y).eval(), atol=0., rtol=1e-7)
- self.assertAllClose(
- -bijector.inverse_log_det_jacobian(y).eval(),
- bijector.forward_log_det_jacobian(x).eval(),
- atol=0.,
- rtol=1e-7)
-
- def testScalarCongruency(self):
- with self.test_session():
- bijector = bijectors.CholeskyOuterProduct(
- event_ndims=0, validate_args=True)
- assert_scalar_congruency(bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05)
-
def testNoBatchStatic(self):
x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y)
y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T)
with self.test_session() as sess:
- y_actual = bijectors.CholeskyOuterProduct(event_ndims=2).forward(x=x)
- x_actual = bijectors.CholeskyOuterProduct(event_ndims=2).inverse(y=y)
+ y_actual = bijectors.CholeskyOuterProduct().forward(x=x)
+ x_actual = bijectors.CholeskyOuterProduct().inverse(y=y)
[y_actual_, x_actual_] = sess.run([y_actual, x_actual])
self.assertAllEqual([2, 2], y_actual.get_shape())
self.assertAllEqual([2, 2], x_actual.get_shape())
@@ -105,8 +76,8 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
with self.test_session() as sess:
x_pl = array_ops.placeholder(dtypes.float32)
y_pl = array_ops.placeholder(dtypes.float32)
- y_actual = bijectors.CholeskyOuterProduct(event_ndims=2).forward(x=x_pl)
- x_actual = bijectors.CholeskyOuterProduct(event_ndims=2).inverse(y=y_pl)
+ y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl)
+ x_actual = bijectors.CholeskyOuterProduct().inverse(y=y_pl)
[y_actual_, x_actual_] = sess.run([y_actual, x_actual],
feed_dict={x_pl: x, y_pl: y})
self.assertEqual(None, y_actual.get_shape())
@@ -124,8 +95,8 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
[[9., 3],
[3, 5]]]) # np.matmul(x, x.T)
with self.test_session() as sess:
- y_actual = bijectors.CholeskyOuterProduct(event_ndims=2).forward(x=x)
- x_actual = bijectors.CholeskyOuterProduct(event_ndims=2).inverse(y=y)
+ y_actual = bijectors.CholeskyOuterProduct().forward(x=x)
+ x_actual = bijectors.CholeskyOuterProduct().inverse(y=y)
[y_actual_, x_actual_] = sess.run([y_actual, x_actual])
self.assertEqual([2, 2, 2], y_actual.get_shape())
self.assertEqual([2, 2, 2], x_actual.get_shape())
@@ -144,8 +115,8 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
with self.test_session() as sess:
x_pl = array_ops.placeholder(dtypes.float32)
y_pl = array_ops.placeholder(dtypes.float32)
- y_actual = bijectors.CholeskyOuterProduct(event_ndims=2).forward(x=x_pl)
- x_actual = bijectors.CholeskyOuterProduct(event_ndims=2).inverse(y=y_pl)
+ y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl)
+ x_actual = bijectors.CholeskyOuterProduct().inverse(y=y_pl)
[y_actual_, x_actual_] = sess.run([y_actual, x_actual],
feed_dict={x_pl: x, y_pl: y})
self.assertEqual(None, y_actual.get_shape())
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
index 0ff3530428..28e3e31354 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
@@ -35,8 +35,7 @@ class InvertBijectorTest(test.TestCase):
for fwd in [
bijectors.Identity(),
bijectors.Exp(event_ndims=1),
- bijectors.Affine(
- shift=[0., 1.], scale_diag=[2., 3.], event_ndims=1),
+ bijectors.Affine(shift=[0., 1.], scale_diag=[2., 3.]),
bijectors.Softplus(event_ndims=1),
bijectors.SoftmaxCentered(event_ndims=1),
bijectors.SigmoidCentered(),
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py
new file mode 100644
index 0000000000..f03d6f1343
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py
@@ -0,0 +1,58 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
+from tensorflow.python.platform import test
+
+
+class SquareBijectorTest(test.TestCase):
+ """Tests the correctness of the Y = X ** 2 transformation."""
+
+ def testBijectorScalar(self):
+ with self.test_session():
+ bijector = bijectors.Square(validate_args=True)
+ self.assertEqual("square", bijector.name)
+ x = [[[1., 5],
+ [2, 1]],
+ [[np.sqrt(2.), 3],
+ [np.sqrt(8.), 1]]]
+ y = np.square(x)
+ ildj = -np.log(2.) - np.log(x)
+ self.assertAllClose(y, bijector.forward(x).eval())
+ self.assertAllClose(x, bijector.inverse(y).eval())
+ self.assertAllClose(
+ ildj, bijector.inverse_log_det_jacobian(y).eval(), atol=0., rtol=1e-7)
+ self.assertAllClose(
+ -bijector.inverse_log_det_jacobian(y).eval(),
+ bijector.forward_log_det_jacobian(x).eval(),
+ atol=0.,
+ rtol=1e-7)
+
+ def testScalarCongruency(self):
+ with self.test_session():
+ bijector = bijectors.Square(validate_args=True)
+ assert_scalar_congruency(bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
index cbaf74d3f6..af13553c32 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
@@ -245,9 +245,8 @@ class TransformedDistributionTest(test.TestCase):
with self.test_session() as sess:
exp2 = self._cls()(
ds.Exponential(rate=0.25),
- bijector=ds.bijectors.Affine(
- scale_identity_multiplier=2.,
- event_ndims=0))
+ bijector=ds.bijectors.AffineScalar(scale=2.)
+ )
log_prob = exp2.log_prob(1.)
log_prob_ = sess.run(log_prob)
base_log_prob = -0.5 * 0.25 + np.log(0.25)
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index 46ec49754a..452f1caa30 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -17,6 +17,7 @@
@@AbsoluteValue
@@Affine
@@AffineLinearOperator
+@@AffineScalar
@@Bijector
@@BatchNormalization
@@Chain
@@ -38,6 +39,7 @@
@@SinhArcsinh
@@SoftmaxCentered
@@Softplus
+@@Square
@@Weibull
@@masked_autoregressive_default_template
@@ -54,6 +56,7 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import *
from tensorflow.contrib.distributions.python.ops.bijectors.affine import *
from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import *
+from tensorflow.contrib.distributions.python.ops.bijectors.affine_scalar import *
from tensorflow.contrib.distributions.python.ops.bijectors.batch_normalization import *
from tensorflow.contrib.distributions.python.ops.bijectors.chain import *
from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
@@ -73,6 +76,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered impo
from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import *
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import *
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import *
+from tensorflow.contrib.distributions.python.ops.bijectors.square import *
from tensorflow.python.ops.distributions.bijector import *
from tensorflow.python.ops.distributions.identity_bijector import Identity
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
index 05bb9c2f9b..7fe73ada44 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
@@ -104,7 +104,6 @@ class Affine(bijector.Bijector):
scale_tril=None,
scale_perturb_factor=None,
scale_perturb_diag=None,
- event_ndims=1,
validate_args=False,
name="affine"):
"""Instantiates the `Affine` bijector.
@@ -157,8 +156,6 @@ class Affine(bijector.Bijector):
matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which
represents an `r x r` diagonal matrix. When `None` low rank updates will
take the form `scale_perturb_factor * scale_perturb_factor.T`.
- event_ndims: Scalar `int` `Tensor` indicating the number of dimensions
- associated with a particular draw from the distribution. Must be 0 or 1.
validate_args: Python `bool` indicating whether arguments should be
checked for correctness.
name: Python `str` name given to ops managed by this object.
@@ -187,23 +184,6 @@ class Affine(bijector.Bijector):
with self._name_scope("init", values=[
shift, scale_identity_multiplier, scale_diag, scale_tril,
scale_perturb_diag, scale_perturb_factor]):
- event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
- event_ndims_const = tensor_util.constant_value(event_ndims)
- if event_ndims_const is not None and event_ndims_const not in (0, 1):
- raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const)
- else:
- if validate_args:
- # Shape tool will catch if event_ndims is negative.
- event_ndims = control_flow_ops.with_dependencies(
- [check_ops.assert_less(
- event_ndims, 2, message="event_ndims must be 0 or 1")],
- event_ndims)
-
- if event_ndims_const == 0 and not self._is_only_identity_multiplier:
- raise ValueError(
- "If event_ndims == 0, the only scale argument you can pass is "
- "scale_identity_multiplier. All others operate on vectors.")
-
# In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`.
dtype = dtypes.float32
@@ -251,12 +231,11 @@ class Affine(bijector.Bijector):
self._scale = scale
self._shaper = _DistributionShape(
batch_ndims=batch_ndims,
- event_ndims=event_ndims,
+ event_ndims=1,
validate_args=validate_args)
super(Affine, self).__init__(
- event_ndims=event_ndims,
+ event_ndims=1,
graph_parents=(
- [event_ndims] +
[self._scale] if tensor_util.is_tensor(self._scale)
else self._scale.graph_parents +
[self._shift] if self._shift is not None else []),
@@ -388,9 +367,7 @@ class Affine(bijector.Bijector):
if self._is_only_identity_multiplier:
# We don't pad in this case and instead let the fldj be applied
# via broadcast.
- event_size = distribution_util.pick_vector(
- math_ops.equal(self._shaper.event_ndims, 0),
- [1], array_ops.shape(x))[-1]
+ event_size = array_ops.shape(x)[-1]
event_size = math_ops.cast(event_size, dtype=self._scale.dtype)
return math_ops.log(math_ops.abs(self._scale)) * event_size
return self.scale.log_abs_determinant()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py
new file mode 100644
index 0000000000..8adaa54c84
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py
@@ -0,0 +1,138 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Affine bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+
+
+__all__ = [
+ "AffineScalar",
+]
+
+
+class AffineScalar(bijector.Bijector):
+ """Compute `Y = g(X; shift, scale) = scale * X + shift`.
+
+ Examples:
+
+ ```python
+ # Y = X
+ b = AffineScalar()
+
+ # Y = X + shift
+ b = AffineScalar(shift=[1., 2, 3])
+
+ # Y = 2 * X + shift
+ b = AffineScalar(
+ shift=[1., 2, 3],
+ scale=2.)
+ ```
+
+ """
+
+ def __init__(self,
+ shift=None,
+ scale=None,
+ validate_args=False,
+ name="affine_scalar"):
+ """Instantiates the `AffineScalar` bijector.
+
+ This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments,
+ giving the forward operation:
+
+ ```none
+ Y = g(X) = scale * X + shift
+ ```
+
+ if `scale` is not specified, then the bijector has the semantics of
+ `scale = 1.`. Similarly, if `shift` is not specified, then the bijector
+ has the semantics of `shift = 0.`.
+
+ Args:
+ shift: Floating-point `Tensor`. If this is set to `None`, no shift is
+ applied.
+ scale: Floating-point `Tensor`. If this is set to `None`, no scale is
+ applied.
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+ """
+ self._graph_parents = []
+ self._name = name
+ self._validate_args = validate_args
+
+ with self._name_scope("init", values=[scale, shift]):
+ self._shift = shift
+ self._scale = scale
+
+ if self._shift is not None:
+ self._shift = ops.convert_to_tensor(shift, name="shift")
+
+ if self._scale is not None:
+ self._scale = ops.convert_to_tensor(self._scale, name="scale")
+ if validate_args:
+ self._scale = control_flow_ops.with_dependencies(
+ [check_ops.assert_none_equal(
+ self._scale,
+ array_ops.zeros([], dtype=self._scale.dtype))],
+ self._scale)
+
+ super(AffineScalar, self).__init__(
+ event_ndims=0,
+ is_constant_jacobian=True,
+ validate_args=validate_args,
+ name=name)
+
+ @property
+ def shift(self):
+ """The `shift` `Tensor` in `Y = scale @ X + shift`."""
+ return self._shift
+
+ @property
+ def scale(self):
+ """The `scale` `LinearOperator` in `Y = scale @ X + shift`."""
+ return self._scale
+
+ def _forward(self, x):
+ y = array_ops.identity(x)
+ if self.scale is not None:
+ y *= self.scale
+ if self.shift is not None:
+ y += self.shift
+ return y
+
+ def _inverse(self, y):
+ x = array_ops.identity(y)
+ if self.shift is not None:
+ x -= self.shift
+ if self.scale is not None:
+ x /= self.scale
+ return x
+
+ def _forward_log_det_jacobian(self, x):
+ log_det_jacobian = array_ops.zeros_like(x)
+ if self.scale is None:
+ return log_det_jacobian
+ log_det_jacobian += math_ops.log(math_ops.abs(self.scale))
+ return log_det_jacobian
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
index cbd60f92a6..43208ff088 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
@@ -20,8 +20,6 @@ from __future__ import print_function
import numpy as np
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -39,8 +37,6 @@ __all__ = [
class CholeskyOuterProduct(bijector.Bijector):
"""Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix.
- `event_ndims` must be 0 or 2, i.e., scalar or matrix.
-
Note: the upper-triangular part of X is ignored (whether or not its zero).
The surjectivity of g as a map from the set of n x n positive-diagonal
@@ -64,46 +60,31 @@ class CholeskyOuterProduct(bijector.Bijector):
Examples:
```python
- bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]])
+ bijector.CholeskyOuterProduct().forward(x=[[1., 0], [2, 1]])
# Result: [[1., 2], [2, 5]], i.e., x @ x.T
- bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]])
+ bijector.CholeskyOuterProduct().inverse(y=[[1., 2], [2, 5]])
# Result: [[1., 0], [2, 1]], i.e., cholesky(y).
```
"""
- def __init__(self, event_ndims=2, validate_args=False,
- name="cholesky_outer_product"):
+ def __init__(self, validate_args=False, name="cholesky_outer_product"):
"""Instantiates the `CholeskyOuterProduct` bijector.
Args:
- event_ndims: `constant` `int32` scalar `Tensor` indicating the number of
- dimensions associated with a particular draw from the distribution. Must
- be 0 or 2.
validate_args: Python `bool` indicating whether arguments should be
checked for correctness.
name: Python `str` name given to ops managed by this object.
-
- Raises:
- ValueError: if event_ndims is neither 0 or 2.
"""
self._graph_parents = []
self._name = name
- with self._name_scope("init", values=[event_ndims]):
- event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
- event_ndims = tensor_util.constant_value(event_ndims)
- if event_ndims is None or event_ndims not in [0, 2]:
- raise ValueError("`event_ndims` must be a TF constant which is 0 or 2")
- self._static_event_ndims = event_ndims
super(CholeskyOuterProduct, self).__init__(
- event_ndims=event_ndims,
+ event_ndims=2,
validate_args=validate_args,
name=name)
def _forward(self, x):
- if self._static_event_ndims == 0:
- return math_ops.square(x)
if self.validate_args:
is_matrix = check_ops.assert_rank_at_least(x, 2)
shape = array_ops.shape(x)
@@ -114,11 +95,7 @@ class CholeskyOuterProduct(bijector.Bijector):
return math_ops.matmul(x, x, adjoint_b=True)
def _inverse(self, y):
- return (math_ops.sqrt(y) if self._static_event_ndims == 0
- else linalg_ops.cholesky(y))
-
- def _inverse_log_det_jacobian(self, y):
- return -self._forward_log_det_jacobian(x=self._inverse(y))
+ return linalg_ops.cholesky(y)
def _forward_log_det_jacobian(self, x):
# Let Y be a symmetric, positive definite matrix and write:
@@ -161,13 +138,6 @@ class CholeskyOuterProduct(bijector.Bijector):
# Since there is a 2 X[j,j] term for every lower-triangular element of X we
# conclude:
# |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
- if self._static_event_ndims == 0:
- if self.validate_args:
- is_positive = check_ops.assert_positive(
- x, message="All elements must be positive.")
- x = control_flow_ops.with_dependencies([is_positive], x)
- return np.log(2.) + math_ops.log(x)
-
diag = array_ops.matrix_diag_part(x)
# We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/square.py b/tensorflow/contrib/distributions/python/ops/bijectors/square.py
new file mode 100644
index 0000000000..2831a92df8
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/square.py
@@ -0,0 +1,84 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Square bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+
+
+__all__ = [
+ "Square",
+]
+
+
+class Square(bijector.Bijector):
+ """Compute `g(X) = X^2`; X is a positive real number.
+
+ g is a bijection between the non-negative real numbers (R_+) and the
+ non-negative real numbers.
+
+ Examples:
+
+ ```python
+ bijector.Square().forward(x=[[1., 0], [2, 1]])
+ # Result: [[1., 0], [4, 1]], i.e., x^2
+
+ bijector.Square().inverse(y=[[1., 4], [9, 1]])
+ # Result: [[1., 2], [3, 1]], i.e., sqrt(y).
+ ```
+
+ """
+
+ def __init__(self, validate_args=False, name="square"):
+ """Instantiates the `Square` bijector.
+
+ Args:
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+ """
+ self._name = name
+ super(Square, self).__init__(
+ event_ndims=0,
+ validate_args=validate_args,
+ name=name)
+
+ def _forward(self, x):
+ x = self._maybe_assert_valid(x)
+ return math_ops.square(x)
+
+ def _inverse(self, y):
+ y = self._maybe_assert_valid(y)
+ return math_ops.sqrt(y)
+
+ def _forward_log_det_jacobian(self, x):
+ x = self._maybe_assert_valid(x)
+ return np.log(2.) + math_ops.log(x)
+
+ def _maybe_assert_valid(self, t):
+ if not self.validate_args:
+ return t
+ is_valid = check_ops.assert_non_negative(
+ t, message="All elements must be non-negative.")
+ return control_flow_ops.with_dependencies([is_valid], t)
+
diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
index c4b8f055b7..0d8a192691 100644
--- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
+++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
@@ -174,13 +174,12 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution):
skewness=skewness.dtype.as_numpy_dtype(0.),
tailweight=tailweight, event_ndims=0)
- # Make the Affine bijector, Z --> loc + scale * Z (2 / F_0(2))
+ # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2))
c = 2 * scale / f_noskew.forward(ops.convert_to_tensor(2, dtype=dtype))
- affine = bijectors.Affine(
+ affine = bijectors.AffineScalar(
shift=loc,
- scale_identity_multiplier=c,
- validate_args=validate_args,
- event_ndims=0)
+ scale=c,
+ validate_args=validate_args)
bijector = bijectors.Chain([affine, f])
diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
index e1ccf11645..003c66b941 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
@@ -227,7 +227,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
c = 2 * scale_diag_part / f_noskew.forward(
ops.convert_to_tensor(2, dtype=dtype))
affine = bijectors.Affine(
- shift=loc, scale_diag=c, validate_args=validate_args, event_ndims=1)
+ shift=loc, scale_diag=c, validate_args=validate_args)
bijector = bijectors.Chain([affine, f])
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py
index 30a7642dd3..332bada57b 100644
--- a/tensorflow/contrib/eager/python/datasets.py
+++ b/tensorflow/contrib/eager/python/datasets.py
@@ -27,7 +27,6 @@ from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
@@ -45,8 +44,13 @@ def _generate_shared_name(prefix):
return "{}{}".format(prefix, uid)
-class Iterator(object):
- """An iterator producing tf.Tensor objects from a tf.data.Dataset."""
+class Iterator(iterator_ops.EagerIterator):
+ """An iterator producing tf.Tensor objects from a tf.data.Dataset.
+
+ NOTE: Unlike the iterator created by the
+ @{tf.data.Dataset.make_one_shot_iterator} method, this class enables
+ additional experimental functionality, such as prefetching to the GPU.
+ """
def __init__(self, dataset):
"""Creates a new iterator over the given dataset.
@@ -67,37 +71,12 @@ class Iterator(object):
Raises:
RuntimeError: When invoked without eager execution enabled.
"""
-
- if not context.executing_eagerly():
- raise RuntimeError(
- "{} objects can only be used when eager execution is enabled, use "
- "tf.data.Dataset.make_initializable_iterator or "
- "tf.data.Dataset.make_one_shot_iterator for graph construction".
- format(type(self)))
- with ops.device("/device:CPU:0"):
- ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access
- self._output_classes = dataset.output_classes
- self._output_types = dataset.output_types
- self._output_shapes = dataset.output_shapes
- self._flat_output_types = nest.flatten(
- sparse.as_dense_types(self._output_types, self._output_classes))
- self._flat_output_shapes = nest.flatten(
- sparse.as_dense_shapes(self._output_shapes, self._output_classes))
- self._resource = gen_dataset_ops.iterator(
- shared_name="",
- container=_generate_shared_name("eageriterator"),
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
- gen_dataset_ops.make_iterator(ds_variant, self._resource)
- # Delete the resource when this object is deleted
- self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
- handle=self._resource, handle_device="/device:CPU:0")
- self._device = context.context().device_name
- self._buffer_resource_handle = None
+ super(Iterator, self).__init__(dataset)
if not context.context().device_spec.device_type:
is_remote_device = False
else:
is_remote_device = context.context().device_spec.device_type != "CPU"
+ self._buffer_resource_handle = None
if is_remote_device:
with ops.device("/device:CPU:0"):
iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
@@ -106,7 +85,7 @@ class Iterator(object):
@function.Defun(dtypes.string)
def remote_fn(h):
remote_iterator = iterator_ops.Iterator.from_string_handle(
- h, self._output_types, self._output_shapes)
+ h, self.output_types, self.output_shapes, self.output_classes)
return remote_iterator.get_next()
remote_fn.add_to_graph(None)
@@ -124,89 +103,16 @@ class Iterator(object):
handle=self._buffer_resource_handle,
handle_device=self._device)
- def __iter__(self):
- return self
-
- def __next__(self): # For Python 3 compatibility
- return self.next()
-
def _next_internal(self):
"""Returns a nested structure of `tf.Tensor`s containing the next element.
"""
- with ops.device(self._device):
- if self._buffer_resource_handle is not None:
+ if self._buffer_resource_handle is not None:
+ with ops.device(self._device):
ret = prefetching_ops.function_buffering_resource_get_next(
function_buffer_resource=self._buffer_resource_handle,
output_types=self._flat_output_types)
- else:
- # TODO(ashankar): Consider removing this ops.device() contextmanager
- # and instead mimic ops placement in graphs: Operations on resource
- # handles execute on the same device as where the resource is placed.
- # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
- # because in eager mode this code will run synchronously on the calling
- # thread. Therefore we do not need to make a defensive context switch
- # to a background thread, and can achieve a small constant performance
- # boost by invoking the iterator synchronously.
- ret = gen_dataset_ops.iterator_get_next_sync(
- self._resource,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
-
- return sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(self._output_types, ret), self._output_types,
- self._output_shapes, self._output_classes)
-
- def next(self):
- """Returns a nested structure of `tf.Tensor`s containing the next element.
- """
- try:
- return self._next_internal()
- except errors.OutOfRangeError:
- raise StopIteration
-
- @property
- def output_classes(self):
- """Returns the class of each component of an element of this iterator.
-
- The expected values are `tf.Tensor` and `tf.SparseTensor`.
-
- Returns:
- A nested structure of Python `type` objects corresponding to each
- component of an element of this dataset.
- """
- return self._output_classes
-
- @property
- def output_shapes(self):
- """Returns the shape of each component of an element of this iterator.
-
- Returns:
- A nested structure of `tf.TensorShape` objects corresponding to each
- component of an element of this dataset.
- """
- return self._output_shapes
-
- @property
- def output_types(self):
- """Returns the type of each component of an element of this iterator.
-
- Returns:
- A nested structure of `tf.DType` objects corresponding to each component
- of an element of this dataset.
- """
- return self._output_types
-
- def get_next(self, name=None):
- """Returns a nested structure of `tf.Tensor`s containing the next element.
-
- Args:
- name: (Optional.) A name for the created operation. Currently unused.
-
- Returns:
- A nested structure of `tf.Tensor` objects.
-
- Raises:
- `tf.errors.OutOfRangeError`: If the end of the dataset has been reached.
- """
- del name
- return self._next_internal()
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self._output_types, ret), self._output_types,
+ self._output_shapes, self._output_classes)
+ else:
+ return super(Iterator, self)._next_internal()
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index 35c3c5d3fa..4afadd88f5 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -44,6 +44,18 @@ class IteratorTest(test.TestCase):
got.append(t.numpy())
self.assertAllEqual([0, 1, 2, 3], got)
+ def testBasicOneShotIterator(self):
+ got = []
+ for t in Dataset.range(4).make_one_shot_iterator():
+ got.append(t.numpy())
+ self.assertAllEqual([0, 1, 2, 3], got)
+
+ def testBasicImplicitIterator(self):
+ got = []
+ for t in Dataset.range(4):
+ got.append(t.numpy())
+ self.assertAllEqual([0, 1, 2, 3], got)
+
def testGetNext(self):
iterator = datasets.Iterator(Dataset.range(4))
self.assertEqual(0, iterator.get_next().numpy())
@@ -53,6 +65,15 @@ class IteratorTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
iterator.get_next()
+ def testGetNextOneShotIterator(self):
+ iterator = Dataset.range(4).make_one_shot_iterator()
+ self.assertEqual(0, iterator.get_next().numpy())
+ self.assertEqual(1, iterator.get_next().numpy())
+ self.assertEqual(2, iterator.get_next().numpy())
+ self.assertEqual(3, iterator.get_next().numpy())
+ with self.assertRaises(errors.OutOfRangeError):
+ iterator.get_next()
+
def testMultipleIteratorsOnTheSameDataset(self):
ds = Dataset.range(4)
it1 = datasets.Iterator(ds)
diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD
index 8ba0823a71..3614b2b15a 100644
--- a/tensorflow/contrib/feature_column/BUILD
+++ b/tensorflow/contrib/feature_column/BUILD
@@ -26,6 +26,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":sequence_feature_column",
+ "//tensorflow/python:util",
],
)
@@ -38,7 +39,6 @@ py_library(
"//tensorflow/python:check_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:tensor_shape",
diff --git a/tensorflow/contrib/feature_column/__init__.py b/tensorflow/contrib/feature_column/__init__.py
index 650a80144f..baa8c1567a 100644
--- a/tensorflow/contrib/feature_column/__init__.py
+++ b/tensorflow/contrib/feature_column/__init__.py
@@ -25,6 +25,12 @@ from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
+ 'sequence_categorical_column_with_hash_bucket',
+ 'sequence_categorical_column_with_identity',
+ 'sequence_categorical_column_with_vocabulary_list',
+ 'sequence_categorical_column_with_vocabulary_file',
+ 'sequence_input_layer',
+ 'sequence_numeric_column',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
index f57557c1cc..e60116966f 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
-import abc
import collections
@@ -29,7 +28,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
@@ -99,9 +97,11 @@ def sequence_input_layer(
"""
feature_columns = fc._clean_feature_columns(feature_columns)
for c in feature_columns:
- if not isinstance(c, _SequenceDenseColumn):
+ if not isinstance(c, fc._SequenceDenseColumn):
raise ValueError(
'All feature_columns must be of type _SequenceDenseColumn. '
+ 'You can wrap a sequence_categorical_column with an embedding_column '
+ 'or indicator_column. '
'Given (type {}): {}'.format(type(c), c))
with variable_scope.variable_scope(
@@ -136,6 +136,10 @@ def sequence_categorical_column_with_identity(
key, num_buckets, default_value=None):
"""Returns a feature column that represents sequences of integers.
+ Pass this to `embedding_column` or `indicator_column` to convert sequence
+ categorical data into dense representation for input to sequence NN, such as
+ RNN.
+
Example:
```python
@@ -163,7 +167,7 @@ def sequence_categorical_column_with_identity(
Returns:
A `_SequenceCategoricalColumn`.
"""
- return _SequenceCategoricalColumn(
+ return fc._SequenceCategoricalColumn(
fc.categorical_column_with_identity(
key=key,
num_buckets=num_buckets,
@@ -174,6 +178,10 @@ def sequence_categorical_column_with_hash_bucket(
key, hash_bucket_size, dtype=dtypes.string):
"""A sequence of categorical terms where ids are set by hashing.
+ Pass this to `embedding_column` or `indicator_column` to convert sequence
+ categorical data into dense representation for input to sequence NN, such as
+ RNN.
+
Example:
```python
@@ -198,7 +206,7 @@ def sequence_categorical_column_with_hash_bucket(
Returns:
A `_SequenceCategoricalColumn`.
"""
- return _SequenceCategoricalColumn(
+ return fc._SequenceCategoricalColumn(
fc.categorical_column_with_hash_bucket(
key=key,
hash_bucket_size=hash_bucket_size,
@@ -210,6 +218,10 @@ def sequence_categorical_column_with_vocabulary_file(
default_value=None, dtype=dtypes.string):
"""A sequence of categorical terms where ids use a vocabulary file.
+ Pass this to `embedding_column` or `indicator_column` to convert sequence
+ categorical data into dense representation for input to sequence NN, such as
+ RNN.
+
Example:
```python
@@ -246,7 +258,7 @@ def sequence_categorical_column_with_vocabulary_file(
Returns:
A `_SequenceCategoricalColumn`.
"""
- return _SequenceCategoricalColumn(
+ return fc._SequenceCategoricalColumn(
fc.categorical_column_with_vocabulary_file(
key=key,
vocabulary_file=vocabulary_file,
@@ -260,6 +272,10 @@ def sequence_categorical_column_with_vocabulary_list(
key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0):
"""A sequence of categorical terms where ids use an in-memory list.
+ Pass this to `embedding_column` or `indicator_column` to convert sequence
+ categorical data into dense representation for input to sequence NN, such as
+ RNN.
+
Example:
```python
@@ -296,7 +312,7 @@ def sequence_categorical_column_with_vocabulary_list(
Returns:
A `_SequenceCategoricalColumn`.
"""
- return _SequenceCategoricalColumn(
+ return fc._SequenceCategoricalColumn(
fc.categorical_column_with_vocabulary_list(
key=key,
vocabulary_list=vocabulary_list,
@@ -305,108 +321,6 @@ def sequence_categorical_column_with_vocabulary_list(
num_oov_buckets=num_oov_buckets))
-# TODO(b/73160931): Merge with embedding_column
-def _sequence_embedding_column(
- categorical_column, dimension, initializer=None, ckpt_to_load_from=None,
- tensor_name_in_ckpt=None, max_norm=None, trainable=True):
- """Returns a feature column that represents sequences of embeddings.
-
- Use this to convert sequence categorical data into dense representation for
- input to sequence NN, such as RNN.
-
- Example:
-
- ```python
- watches = sequence_categorical_column_with_identity(
- 'watches', num_buckets=1000)
- watches_embedding = _sequence_embedding_column(watches, dimension=10)
- columns = [watches]
-
- features = tf.parse_example(..., features=make_parse_example_spec(columns))
- input_layer, sequence_length = sequence_input_layer(features, columns)
-
- rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
- outputs, state = tf.nn.dynamic_rnn(
- rnn_cell, inputs=input_layer, sequence_length=sequence_length)
- ```
-
- Args:
- categorical_column: A `_SequenceCategoricalColumn` created with a
- `sequence_cateogrical_column_with_*` function.
- dimension: Integer dimension of the embedding.
- initializer: Initializer function used to initialize the embeddings.
- ckpt_to_load_from: String representing checkpoint name/pattern from which to
- restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
- tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
- which to restore the column weights. Required if `ckpt_to_load_from` is
- not `None`.
- max_norm: If not `None`, embedding values are l2-normalized to this value.
- trainable: Whether or not the embedding is trainable. Default is True.
-
- Returns:
- A `_SequenceCategoricalToDenseColumn`.
-
- Raises:
- ValueError: If `categorical_column` is not the right type.
- """
- if not isinstance(categorical_column, _SequenceCategoricalColumn):
- raise ValueError(
- 'categorical_column must be of type _SequenceCategoricalColumn. '
- 'Given (type {}): {}'.format(
- type(categorical_column), categorical_column))
- return _SequenceCategoricalToDenseColumn(
- fc.embedding_column(
- categorical_column,
- dimension=dimension,
- initializer=initializer,
- ckpt_to_load_from=ckpt_to_load_from,
- tensor_name_in_ckpt=tensor_name_in_ckpt,
- max_norm=max_norm,
- trainable=trainable))
-
-
-# TODO(b/73160931): Merge with indicator_column
-def _sequence_indicator_column(categorical_column):
- """Returns a feature column that represents sequences of multi-hot tensors.
-
- Use this to convert sequence categorical data into dense representation for
- input to sequence NN, such as RNN.
-
- Example:
-
- ```python
- colors = sequence_categorical_column_with_vocabulary_list(
- key='colors', vocabulary_list=('R', 'G', 'B', 'Y'))
- colors_indicator = _sequence_indicator_column(colors)
- columns = [colors]
-
- features = tf.parse_example(..., features=make_parse_example_spec(columns))
- input_layer, sequence_length = sequence_input_layer(features, columns)
-
- rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
- outputs, state = tf.nn.dynamic_rnn(
- rnn_cell, inputs=input_layer, sequence_length=sequence_length)
- ```
-
- Args:
- categorical_column: A `_SequenceCategoricalColumn` created with a
- `sequence_cateogrical_column_with_*` function.
-
- Returns:
- A `_SequenceCategoricalToDenseColumn`.
-
- Raises:
- ValueError: If `categorical_column` is not the right type.
- """
- if not isinstance(categorical_column, _SequenceCategoricalColumn):
- raise ValueError(
- 'categorical_column must be of type _SequenceCategoricalColumn. '
- 'Given (type {}): {}'.format(
- type(categorical_column), categorical_column))
- return _SequenceCategoricalToDenseColumn(
- fc.indicator_column(categorical_column))
-
-
def sequence_numeric_column(
key,
shape=(1,),
@@ -459,129 +373,8 @@ def _assert_all_equal_and_return(tensors, name=None):
return array_ops.identity(tensors[0])
-class _SequenceDenseColumn(fc._FeatureColumn):
- """Represents dense sequence data."""
-
- __metaclass__ = abc.ABCMeta
-
- TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name
- 'TensorSequenceLengthPair', ['dense_tensor', 'sequence_length'])
-
- @abc.abstractproperty
- def _variable_shape(self):
- """`TensorShape` without batch and sequence dimensions."""
- pass
-
- @abc.abstractmethod
- def _get_sequence_dense_tensor(
- self, inputs, weight_collections=None, trainable=None):
- """Returns a `TensorSequenceLengthPair`."""
- pass
-
-
-def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
- with ops.name_scope(None, 'sequence_length') as name_scope:
- row_ids = sp_tensor.indices[:, 0]
- column_ids = sp_tensor.indices[:, 1]
- column_ids += array_ops.ones_like(column_ids)
- seq_length = math_ops.to_int64(
- math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
- # If the last n rows do not have ids, seq_length will have shape
- # [batch_size - n]. Pad the remaining values with zeros.
- n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
- padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
- return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
-
-
-class _SequenceCategoricalColumn(
- fc._CategoricalColumn,
- collections.namedtuple(
- '_SequenceCategoricalColumn', ['categorical_column'])):
- """Represents sequences of categorical data."""
-
- @property
- def name(self):
- return self.categorical_column.name
-
- @property
- def _parse_example_spec(self):
- return self.categorical_column._parse_example_spec
-
- def _transform_feature(self, inputs):
- return self.categorical_column._transform_feature(inputs)
-
- @property
- def _num_buckets(self):
- return self.categorical_column._num_buckets
-
- def _get_sparse_tensors(self, inputs, weight_collections=None,
- trainable=None):
- sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)
- id_tensor = sparse_tensors.id_tensor
- weight_tensor = sparse_tensors.weight_tensor
- # Expands final dimension, so that embeddings are not combined during
- # embedding lookup.
- check_id_rank = check_ops.assert_equal(
- array_ops.rank(id_tensor), 2,
- data=[
- 'Column {} expected ID tensor of rank 2. '.format(self.name),
- 'id_tensor shape: ', array_ops.shape(id_tensor)])
- with ops.control_dependencies([check_id_rank]):
- id_tensor = sparse_ops.sparse_reshape(
- id_tensor,
- shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))
- if weight_tensor is not None:
- check_weight_rank = check_ops.assert_equal(
- array_ops.rank(weight_tensor), 2,
- data=[
- 'Column {} expected weight tensor of rank 2.'.format(self.name),
- 'weight_tensor shape:', array_ops.shape(weight_tensor)])
- with ops.control_dependencies([check_weight_rank]):
- weight_tensor = sparse_ops.sparse_reshape(
- weight_tensor,
- shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
- return fc._CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
-
- def _sequence_length(self, inputs):
- sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)
- return _sequence_length_from_sparse_tensor(sparse_tensors.id_tensor)
-
-
-class _SequenceCategoricalToDenseColumn(
- _SequenceDenseColumn,
- collections.namedtuple(
- '_SequenceCategoricalToDenseColumn', ['dense_column'])):
- """Densifies a _SequenceCategoricalColumn using the specified column."""
-
- @property
- def name(self):
- return self.dense_column.name
-
- @property
- def _parse_example_spec(self):
- return self.dense_column._parse_example_spec
-
- def _transform_feature(self, inputs):
- return self.dense_column._transform_feature(inputs)
-
- @property
- def _variable_shape(self):
- return self.dense_column._variable_shape
-
- def _get_sequence_dense_tensor(
- self, inputs, weight_collections=None, trainable=None):
- dense_tensor = self.dense_column._get_dense_tensor(
- inputs=inputs,
- weight_collections=weight_collections,
- trainable=trainable)
- sequence_length = self.dense_column.categorical_column._sequence_length(
- inputs)
- return _SequenceDenseColumn.TensorSequenceLengthPair(
- dense_tensor=dense_tensor, sequence_length=sequence_length)
-
-
class _SequenceNumericColumn(
- _SequenceDenseColumn,
+ fc._SequenceDenseColumn,
collections.namedtuple(
'_SequenceNumericColumn',
['key', 'shape', 'default_value', 'dtype'])):
@@ -616,9 +409,9 @@ class _SequenceNumericColumn(
[array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape],
axis=0)
dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape)
- sequence_length = _sequence_length_from_sparse_tensor(
+ sequence_length = fc._sequence_length_from_sparse_tensor(
sp_tensor, num_elements=self._variable_shape.num_elements())
- return _SequenceDenseColumn.TensorSequenceLengthPair(
+ return fc._SequenceDenseColumn.TensorSequenceLengthPair(
dense_tensor=dense_tensor, sequence_length=sequence_length)
# pylint: enable=protected-access
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
index c077f03291..b64f086376 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
@@ -22,6 +22,7 @@ import os
import numpy as np
from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc
+from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column.feature_column import _LazyBuilder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -78,12 +79,12 @@ class SequenceInputLayerTest(test.TestCase):
categorical_column_a = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column_a = sfc._sequence_embedding_column(
+ embedding_column_a = fc.embedding_column(
categorical_column_a, dimension=embedding_dimension_a,
initializer=_get_initializer(embedding_dimension_a, embedding_values_a))
categorical_column_b = sfc.sequence_categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_b = sfc._sequence_embedding_column(
+ embedding_column_b = fc.embedding_column(
categorical_column_b, dimension=embedding_dimension_b,
initializer=_get_initializer(embedding_dimension_b, embedding_values_b))
@@ -107,6 +108,29 @@ class SequenceInputLayerTest(test.TestCase):
self.assertAllEqual(
expected_sequence_length, sequence_length.eval(session=sess))
+ def test_embedding_column_with_non_sequence_categorical(self):
+ """Tests that error is raised for non-sequence categorical column."""
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column_a = fc.embedding_column(
+ categorical_column_a, dimension=2)
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'In embedding_column: aaa_embedding\. categorical_column must be of '
+ r'type _SequenceCategoricalColumn to use sequence_input_layer\.'):
+ _, _ = sfc.sequence_input_layer(
+ features={'aaa': sparse_input},
+ feature_columns=[embedding_column_a])
+
def test_indicator_column(self):
vocabulary_size_a = 3
sparse_input_a = sparse_tensor.SparseTensorValue(
@@ -133,10 +157,10 @@ class SequenceInputLayerTest(test.TestCase):
categorical_column_a = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size_a)
- indicator_column_a = sfc._sequence_indicator_column(categorical_column_a)
+ indicator_column_a = fc.indicator_column(categorical_column_a)
categorical_column_b = sfc.sequence_categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size_b)
- indicator_column_b = sfc._sequence_indicator_column(categorical_column_b)
+ indicator_column_b = fc.indicator_column(categorical_column_b)
input_layer, sequence_length = sfc.sequence_input_layer(
features={
'aaa': sparse_input_a,
@@ -150,6 +174,28 @@ class SequenceInputLayerTest(test.TestCase):
self.assertAllEqual(
expected_sequence_length, sequence_length.eval(session=sess))
+ def test_indicator_column_with_non_sequence_categorical(self):
+ """Tests that error is raised for non-sequence categorical column."""
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ indicator_column_a = fc.indicator_column(categorical_column_a)
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'In indicator_column: aaa_indicator\. categorical_column must be of '
+ r'type _SequenceCategoricalColumn to use sequence_input_layer\.'):
+ _, _ = sfc.sequence_input_layer(
+ features={'aaa': sparse_input},
+ feature_columns=[indicator_column_a])
+
def test_numeric_column(self):
sparse_input = sparse_tensor.SparseTensorValue(
# example 0, values [[0.], [1]]
@@ -230,6 +276,55 @@ class SequenceInputLayerTest(test.TestCase):
sess.run(sequence_length)
+class InputLayerTest(test.TestCase):
+ """Tests input_layer with sequence feature columns."""
+
+ def test_embedding_column(self):
+ """Tests that error is raised for sequence embedding column."""
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+
+ categorical_column_a = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column_a = fc.embedding_column(
+ categorical_column_a, dimension=2)
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'In embedding_column: aaa_embedding\. categorical_column must not be '
+ r'of type _SequenceCategoricalColumn\.'):
+ _ = fc.input_layer(
+ features={'aaa': sparse_input},
+ feature_columns=[embedding_column_a])
+
+ def test_indicator_column(self):
+ """Tests that error is raised for sequence indicator column."""
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+
+ categorical_column_a = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ indicator_column_a = fc.indicator_column(categorical_column_a)
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'In indicator_column: aaa_indicator\. categorical_column must not be '
+ r'of type _SequenceCategoricalColumn\.'):
+ _ = fc.input_layer(
+ features={'aaa': sparse_input},
+ feature_columns=[indicator_column_a])
+
+
def _assert_sparse_tensor_value(test_case, expected, actual):
_assert_sparse_tensor_indices_shape(test_case, expected, actual)
@@ -287,37 +382,6 @@ class SequenceCategoricalColumnWithIdentityTest(test.TestCase):
with monitored_session.MonitoredSession() as sess:
id_weight_pair.id_tensor.eval(session=sess)
- def test_sequence_length(self):
- column = sfc.sequence_categorical_column_with_identity(
- 'aaa', num_buckets=3)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(1, 2, 0),
- dense_shape=(2, 2))
- expected_sequence_length = [1, 2]
-
- sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
-
- with monitored_session.MonitoredSession() as sess:
- sequence_length = sess.run(sequence_length)
- self.assertAllEqual(expected_sequence_length, sequence_length)
- self.assertEqual(np.int64, sequence_length.dtype)
-
- def test_sequence_length_with_zeros(self):
- column = sfc.sequence_categorical_column_with_identity(
- 'aaa', num_buckets=3)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((1, 0), (3, 0), (3, 1)),
- values=(1, 2, 0),
- dense_shape=(5, 2))
- expected_sequence_length = [0, 1, 0, 2, 0]
-
- sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
-
- with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_sequence_length, sequence_length.eval(session=sess))
-
class SequenceCategoricalColumnWithHashBucketTest(test.TestCase):
@@ -344,21 +408,6 @@ class SequenceCategoricalColumnWithHashBucketTest(test.TestCase):
expected_sparse_ids,
id_weight_pair.id_tensor.eval(session=sess))
- def test_sequence_length(self):
- column = sfc.sequence_categorical_column_with_hash_bucket(
- 'aaa', hash_bucket_size=10)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('omar', 'stringer', 'marlo'),
- dense_shape=(2, 2))
- expected_sequence_length = [1, 2]
-
- sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
-
- with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_sequence_length, sequence_length.eval(session=sess))
-
class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase):
@@ -399,23 +448,6 @@ class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase):
expected_sparse_ids,
id_weight_pair.id_tensor.eval(session=sess))
- def test_sequence_length(self):
- column = sfc.sequence_categorical_column_with_vocabulary_file(
- key='aaa',
- vocabulary_file=self._wire_vocabulary_file_name,
- vocabulary_size=self._wire_vocabulary_size)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- expected_sequence_length = [1, 2]
-
- sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
-
- with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_sequence_length, sequence_length.eval(session=sess))
-
class SequenceCategoricalColumnWithVocabularyListTest(test.TestCase):
@@ -441,22 +473,6 @@ class SequenceCategoricalColumnWithVocabularyListTest(test.TestCase):
expected_sparse_ids,
id_weight_pair.id_tensor.eval(session=sess))
- def test_sequence_length(self):
- column = sfc.sequence_categorical_column_with_vocabulary_list(
- key='aaa',
- vocabulary_list=('omar', 'stringer', 'marlo'))
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- expected_sequence_length = [1, 2]
-
- sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
-
- with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_sequence_length, sequence_length.eval(session=sess))
-
class SequenceEmbeddingColumnTest(test.TestCase):
@@ -496,7 +512,7 @@ class SequenceEmbeddingColumnTest(test.TestCase):
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = sfc._sequence_embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column, dimension=embedding_dimension,
initializer=_initializer)
@@ -522,7 +538,7 @@ class SequenceEmbeddingColumnTest(test.TestCase):
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = sfc._sequence_embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column, dimension=2)
_, sequence_length = embedding_column._get_sequence_dense_tensor(
@@ -550,7 +566,7 @@ class SequenceEmbeddingColumnTest(test.TestCase):
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = sfc._sequence_embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column, dimension=2)
_, sequence_length = embedding_column._get_sequence_dense_tensor(
@@ -587,7 +603,7 @@ class SequenceIndicatorColumnTest(test.TestCase):
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- indicator_column = sfc._sequence_indicator_column(categorical_column)
+ indicator_column = fc.indicator_column(categorical_column)
indicator_tensor, _ = indicator_column._get_sequence_dense_tensor(
_LazyBuilder({'aaa': sparse_input}))
@@ -607,7 +623,7 @@ class SequenceIndicatorColumnTest(test.TestCase):
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- indicator_column = sfc._sequence_indicator_column(categorical_column)
+ indicator_column = fc.indicator_column(categorical_column)
_, sequence_length = indicator_column._get_sequence_dense_tensor(
_LazyBuilder({'aaa': sparse_input}))
@@ -634,7 +650,7 @@ class SequenceIndicatorColumnTest(test.TestCase):
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- indicator_column = sfc._sequence_indicator_column(categorical_column)
+ indicator_column = fc.indicator_column(categorical_column)
_, sequence_length = indicator_column._get_sequence_dense_tensor(
_LazyBuilder({'aaa': sparse_input}))
diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.cc b/tensorflow/contrib/image/kernels/segmentation_ops.cc
index fe8bf6e21c..9372289623 100644
--- a/tensorflow/contrib/image/kernels/segmentation_ops.cc
+++ b/tensorflow/contrib/image/kernels/segmentation_ops.cc
@@ -101,8 +101,8 @@ struct ImageConnectedComponentsFunctor<CPUDevice, T> {
int cost = (union_find.block_height() + union_find.block_width()) * 20;
Shard(worker_threads->num_threads, worker_threads->workers,
num_images * num_blocks_vertically * num_blocks_horizontally, cost,
- [&union_find, num_images, num_blocks_vertically,
- num_blocks_horizontally](int64 start_block, int64 limit_block) {
+ [&union_find, num_blocks_vertically, num_blocks_horizontally](
+ int64 start_block, int64 limit_block) {
for (int64 i = start_block; i < limit_block; i++) {
int64 block_x = i % num_blocks_horizontally;
int64 block_y =
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
index c1ea296b43..30c5404e03 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
@@ -96,49 +96,57 @@ class EstimatorTest(test.TestCase):
# Check that we throw an error if we try to build an estimator for vars
# that were not manually registered.
with self.assertRaises(ValueError):
- estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
- self.layer_collection)
+ est = estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
+ self.layer_collection)
+ est.make_ops_and_vars()
# Check that we throw an error if we don't include registered variables,
# i.e. self.weights
with self.assertRaises(ValueError):
- estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
+ est = estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
+ est.make_ops_and_vars()
@test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
def testVariableWrongNumberOfUses(self, mock_uses):
with self.assertRaises(ValueError):
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection)
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection)
+ est.make_ops_and_vars()
def testInvalidEstimationMode(self):
with self.assertRaises(ValueError):
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="not_a_real_mode")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="not_a_real_mode")
+ est.make_ops_and_vars()
def testGradientsModeBuild(self):
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="gradients")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="gradients")
+ est.make_ops_and_vars()
def testEmpiricalModeBuild(self):
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="empirical")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="empirical")
+ est.make_ops_and_vars()
def testCurvaturePropModeBuild(self):
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="curvature_prop")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="curvature_prop")
+ est.make_ops_and_vars()
def testExactModeBuild(self):
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="exact")
+ est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ self.layer_collection,
+ estimation_mode="exact")
+ est.make_ops_and_vars()
def test_cov_update_thunks(self):
"""Ensures covariance update ops run once per global_step."""
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index c9c0f8e0ae..b70c700f09 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -764,6 +764,54 @@ class ConvDiagonalFBTest(test.TestCase):
return multiply_result, multiply_inverse_result
+class DepthwiseConvKFCBasicFBTest(test.TestCase):
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((3, 3, 8, 2))
+ inputs = random_ops.random_normal((32, 5, 5, 8))
+ outputs = random_ops.random_normal((32, 5, 5, 16))
+ layer_collection = lc.LayerCollection()
+ block = fb.DepthwiseConvKFCBasicFB(
+ layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
+ block.register_additional_minibatch(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(([grads],), 0.5)
+
+ def testMultiplyInverse(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((3, 3, 8, 2))
+ inputs = random_ops.random_normal((32, 5, 5, 8))
+ outputs = random_ops.random_normal((32, 5, 5, 16))
+ layer_collection = lc.LayerCollection()
+ block = fb.DepthwiseConvKFCBasicFB(
+ layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
+ block.register_additional_minibatch(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(([grads],), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Ensure inverse update op doesn't crash.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run([
+ factor.make_inverse_update_ops()
+ for factor in layer_collection.get_factors()
+ ])
+
+ # Ensure inverse-vector multiply doesn't crash.
+ output = block.multiply_inverse(params)
+ sess.run(output)
+
+ # Ensure same shape.
+ self.assertAllEqual(output.shape, params.shape)
+
+
class ConvKFCBasicFBTest(test.TestCase):
def _testConvKFCBasicFBInitParams(self, params):
@@ -775,16 +823,17 @@ class ConvKFCBasicFBTest(test.TestCase):
params = array_ops.constant(params)
inputs = random_ops.random_normal((2, 2, 2))
outputs = random_ops.random_normal((2, 2, 2))
- block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME')
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
block.register_additional_minibatch(inputs, outputs)
self.assertAllEqual([outputs], block.tensors_to_compute_grads())
def testConvKFCBasicFBInitParamsParamsTuple(self):
- self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)])
+ self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])])
def testConvKFCBasicFBInitParamsParamsSingle(self):
- self._testConvKFCBasicFBInitParams([np.array([1., 2.])])
+ self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])])
def testMultiplyInverseTuple(self):
with ops.Graph().as_default(), self.test_session() as sess:
@@ -792,8 +841,8 @@ class ConvKFCBasicFBTest(test.TestCase):
params = random_ops.random_normal((2, 2, 2, 2))
inputs = random_ops.random_normal((2, 2, 2, 2))
outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
- 'SAME')
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
block.instantiate_factors(((grads,),), 0.5)
@@ -823,8 +872,8 @@ class ConvKFCBasicFBTest(test.TestCase):
params = random_ops.random_normal((2, 2, 2, 2))
inputs = random_ops.random_normal((2, 2, 2, 2))
outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
- 'SAME')
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
block.register_additional_minibatch(inputs, outputs)
self.assertFalse(block._has_bias)
grads = outputs**2
@@ -851,8 +900,8 @@ class ConvKFCBasicFBTest(test.TestCase):
params = [random_ops.random_normal((2, 2, 2, 2))]
inputs = random_ops.random_normal((2, 2, 2, 2))
outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
- 'SAME')
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
block.register_additional_minibatch(inputs, outputs)
self.assertTrue(block._has_bias)
grads = outputs**2
@@ -879,8 +928,8 @@ class ConvKFCBasicFBTest(test.TestCase):
params = array_ops.zeros((2, 2, 2, 2))
inputs = array_ops.zeros((2, 2, 2, 2))
outputs = array_ops.zeros((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
- 'SAME')
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
block.register_additional_minibatch(inputs, outputs)
grads = outputs**2
damping = 0. # This test is only valid without damping.
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index beb427bdcc..16f02f1199 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -23,12 +23,14 @@ import numpy.random as npr
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test
@@ -447,6 +449,117 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov)
+class ConvDiagonalFactorTest(test.TestCase):
+
+ def setUp(self):
+ self.batch_size = 10
+ self.height = self.width = 32
+ self.in_channels = 3
+ self.out_channels = 1
+ self.kernel_height = self.kernel_width = 3
+ self.strides = [1, 2, 2, 1]
+ self.data_format = 'NHWC'
+ self.padding = 'SAME'
+ self.kernel_shape = [
+ self.kernel_height, self.kernel_width, self.in_channels,
+ self.out_channels
+ ]
+
+ def testInit(self):
+ with tf_ops.Graph().as_default():
+ inputs = random_ops.random_uniform(
+ [self.batch_size, self.height, self.width, self.in_channels])
+ outputs_grads = [
+ random_ops.random_uniform([
+ self.batch_size, self.height // self.strides[1],
+ self.width // self.strides[2], self.out_channels
+ ]) for _ in range(3)
+ ]
+
+ factor = ff.ConvDiagonalFactor(
+ inputs,
+ outputs_grads,
+ self.kernel_shape,
+ self.strides,
+ self.padding,
+ data_format=self.data_format)
+ factor.instantiate_cov_variables()
+
+ # Ensure covariance matrix's shape makes sense.
+ self.assertEqual([
+ self.kernel_height * self.kernel_width * self.in_channels,
+ self.out_channels
+ ],
+ factor.get_cov_var().shape.as_list())
+
+ def testMakeCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default():
+ # Construct all arguments such that convolution kernel is applied in
+ # exactly one spatial location.
+ inputs = np.random.randn(
+ 1, # batch_size
+ self.kernel_height,
+ self.kernel_width,
+ self.in_channels) # in_channels
+ outputs_grad = np.random.randn(
+ 1, # batch_size
+ 1, # output_height
+ 1, # output_width
+ self.out_channels)
+
+ factor = ff.ConvDiagonalFactor(
+ constant_op.constant(inputs), [constant_op.constant(outputs_grad)],
+ self.kernel_shape,
+ strides=[1, 1, 1, 1],
+ padding='VALID')
+ factor.instantiate_cov_variables()
+
+ # Completely forget initial value on first update.
+ cov_update_op = factor.make_covariance_update_op(0.0)
+
+ # Ensure new covariance value is same as outer-product of inputs/outputs
+ # vectorized, squared.
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ cov = sess.run(cov_update_op)
+ expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2
+ self.assertAllClose(expected_cov, cov)
+
+ def testHasBias(self):
+ with tf_ops.Graph().as_default():
+ inputs = random_ops.random_uniform(
+ [self.batch_size, self.height, self.width, self.in_channels])
+ outputs_grads = [
+ random_ops.random_uniform([
+ self.batch_size, self.height // self.strides[1],
+ self.width // self.strides[2], self.out_channels
+ ]) for _ in range(3)
+ ]
+
+ factor = ff.ConvDiagonalFactor(
+ inputs,
+ outputs_grads,
+ self.kernel_shape,
+ self.strides,
+ self.padding,
+ data_format=self.data_format,
+ has_bias=True)
+ factor.instantiate_cov_variables()
+
+ # Ensure shape accounts for bias.
+ self.assertEqual([
+ self.kernel_height * self.kernel_width * self.in_channels + 1,
+ self.out_channels
+ ],
+ factor.get_cov_var().shape.as_list())
+
+ # Ensure update op doesn't crash.
+ cov_update_op = factor.make_covariance_update_op(0.0)
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(cov_update_op)
+
+
class FullyConnectedKroneckerFactorTest(test.TestCase):
def _testFullyConnectedKroneckerFactorInit(self,
@@ -493,24 +606,152 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
-class ConvInputKroneckerFactorTest(test.TestCase):
+class ConvFactorTestCase(test.TestCase):
+
+ def assertMatrixRank(self, rank, matrix, atol=1e-5):
+ assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.'
+ eigvals = np.linalg.eigvals(matrix)
+ nnz_eigvals = np.sum(eigvals > atol)
+ self.assertEqual(
+ rank,
+ nnz_eigvals,
+ msg=('Found %d of %d expected non-zero eigenvalues: %s.' %
+ (nnz_eigvals, rank, eigvals)))
+
+
+class ConvInputKroneckerFactorTest(ConvFactorTestCase):
+
+ def test3DConvolution(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 3**3
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=random_ops.random_uniform(
+ (batch_size, width, width, width, in_channels), seed=0),
+ filter_shape=(width, width, width, in_channels, out_channels),
+ padding='SAME',
+ strides=(2, 2, 2),
+ extract_patches_fn='extract_convolution_patches',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ # Ensure shape of covariance matches input size of filter.
+ input_size = in_channels * (width**3)
+ self.assertEqual([input_size, input_size],
+ factor.get_cov_var().shape.as_list())
+
+ # Ensure cov_update_op doesn't crash.
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov_var())
+
+ # Cov should be rank-8, as the filter will be applied at each corner of
+ # the 4-D cube.
+ self.assertMatrixRank(8, cov)
+
+ def testPointwiseConv2d(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 3**2
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),
+ filter_shape=(1, 1, in_channels, out_channels),
+ padding='SAME',
+ strides=(1, 1, 1, 1),
+ extract_patches_fn='extract_pointwise_conv2d_patches',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ # Ensure shape of covariance matches input size of filter.
+ self.assertEqual([in_channels, in_channels],
+ factor.get_cov_var().shape.as_list())
+
+ # Ensure cov_update_op doesn't crash.
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov_var())
+
+ # Cov should be rank-9, as the filter will be applied at each location.
+ self.assertMatrixRank(9, cov)
+
+ def testStrides(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 3**2
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),
+ filter_shape=(1, 1, in_channels, out_channels),
+ padding='SAME',
+ strides=(1, 2, 1, 1),
+ extract_patches_fn='extract_image_patches',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov_var())
+
+ # Cov should be the sum of 3 * 2 = 6 outer products.
+ self.assertMatrixRank(6, cov)
+
+ def testDilationRate(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 2
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),
+ filter_shape=(3, 3, in_channels, out_channels),
+ padding='SAME',
+ extract_patches_fn='extract_image_patches',
+ strides=(1, 1, 1, 1),
+ dilation_rate=(1, width, width, 1),
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov_var())
+
+ # Cov should be rank = in_channels, as only the center of the filter
+ # receives non-zero input for each input channel.
+ self.assertMatrixRank(in_channels, cov)
def testConvInputKroneckerFactorInitNoBias(self):
with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
+ tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
- tensor, (1, 2, 3, 4), 3, 2, has_bias=False)
+ inputs=tensor,
+ filter_shape=(1, 2, 3, 4),
+ padding='SAME',
+ has_bias=False)
factor.instantiate_cov_variables()
self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
factor.get_cov().get_shape().as_list())
def testConvInputKroneckerFactorInit(self):
with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
+ tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
- tensor, (1, 2, 3, 4), 3, 2, has_bias=True)
+ tensor, filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
factor.instantiate_cov_variables()
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
factor.get_cov().get_shape().as_list())
@@ -518,10 +759,9 @@ class ConvInputKroneckerFactorTest(test.TestCase):
def testConvInputKroneckerFactorInitFloat64(self):
with tf_ops.Graph().as_default():
dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
+ tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
factor = ff.ConvInputKroneckerFactor(
- tensor, (1, 2, 3, 4), 3, 2, has_bias=True)
+ tensor, filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
@@ -530,33 +770,60 @@ class ConvInputKroneckerFactorTest(test.TestCase):
def testMakeCovarianceUpdateOpWithBias(self):
with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
+ input_shape = (2, 1, 1, 1)
tensor = array_ops.constant(
- np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32)
+ np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
+ np.float32))
factor = ff.ConvInputKroneckerFactor(
- tensor, (1, 2, 1, 1), [1, 1, 1, 1], 'SAME', has_bias=True)
+ tensor, filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[34.375, 37, 3.125], [37, 41, 3.5], [3.125, 3.5, 1]],
- new_cov)
+ new_cov = sess.run(factor.make_covariance_update_op(0.))
+ self.assertAllClose(
+ [
+ [(1. + 4.) / 2., (1. + 2.) / 2.], #
+ [(1. + 2.) / 2., (1. + 1.) / 2.]
+ ], #
+ new_cov)
def testMakeCovarianceUpdateOpNoBias(self):
with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
+ input_shape = (2, 1, 1, 1)
tensor = array_ops.constant(
- np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32)
- factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1),
- [1, 1, 1, 1], 'SAME')
+ np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
+ np.float32))
+ factor = ff.ConvInputKroneckerFactor(
+ tensor, filter_shape=(1, 1, 1, 1), padding='SAME')
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[34.375, 37], [37, 41]], new_cov)
+ new_cov = sess.run(factor.make_covariance_update_op(0.))
+ self.assertAllClose([[(1. + 4.) / 2.]], new_cov)
-class ConvOutputKroneckerFactorTest(test.TestCase):
+class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
+
+ def test3DConvolution(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ out_channels = width**3
+
+ factor = ff.ConvOutputKroneckerFactor(outputs_grads=[
+ random_ops.random_uniform(
+ (batch_size, width, width, width, out_channels), seed=0)
+ ])
+ factor.instantiate_cov_variables()
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov())
+
+ # Cov should be rank 3^3, as each spatial position donates a rank-1
+ # update.
+ self.assertMatrixRank(width**3, cov)
def testConvOutputKroneckerFactorInit(self):
with tf_ops.Graph().as_default():
@@ -577,13 +844,6 @@ class ConvOutputKroneckerFactorTest(test.TestCase):
self.assertEqual(cov.dtype, dtype)
self.assertEqual([5, 5], cov.get_shape().as_list())
- def testConvOutputKroneckerFactorInitNotEnoughDims(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
- with self.assertRaises(IndexError):
- ff.ConvOutputKroneckerFactor((tensor,))
-
def testMakeCovarianceUpdateOp(self):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index 889f336811..bae6bd7a3b 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -104,14 +104,31 @@ class LayerCollectionTest(test.TestCase):
array_ops.constant(3),
approx=layer_collection.APPROX_DIAGONAL_NAME)
lc.register_conv2d(
- array_ops.constant(4), [1, 1, 1, 1], 'SAME',
- array_ops.ones((1, 1, 1, 1)), array_ops.constant(3))
+ params=array_ops.ones((2, 3, 4, 5)),
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=array_ops.ones((1, 2, 3, 4)),
+ outputs=array_ops.ones((1, 1, 1, 5)))
lc.register_conv2d(
- array_ops.constant(4), [1, 1, 1, 1],
- 'SAME',
- array_ops.ones((1, 1, 1, 1)),
- array_ops.constant(3),
+ params=array_ops.ones((2, 3, 4, 5)),
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=array_ops.ones((1, 2, 3, 4)),
+ outputs=array_ops.ones((1, 1, 1, 5)),
approx=layer_collection.APPROX_DIAGONAL_NAME)
+ lc.register_separable_conv2d(
+ depthwise_params=array_ops.ones((3, 3, 1, 2)),
+ pointwise_params=array_ops.ones((1, 1, 2, 4)),
+ inputs=array_ops.ones((32, 5, 5, 1)),
+ depthwise_outputs=array_ops.ones((32, 5, 5, 2)),
+ pointwise_outputs=array_ops.ones((32, 5, 5, 4)),
+ strides=[1, 1, 1, 1],
+ padding='SAME')
+ lc.register_convolution(
+ params=array_ops.ones((3, 3, 1, 8)),
+ inputs=array_ops.ones((32, 5, 5, 1)),
+ outputs=array_ops.ones((32, 5, 5, 8)),
+ padding='SAME')
lc.register_generic(
array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
lc.register_generic(
@@ -119,7 +136,7 @@ class LayerCollectionTest(test.TestCase):
16,
approx=layer_collection.APPROX_DIAGONAL_NAME)
- self.assertEqual(6, len(lc.get_blocks()))
+ self.assertEqual(9, len(lc.get_blocks()))
def testRegisterBlocksMultipleRegistrations(self):
with ops.Graph().as_default():
@@ -535,6 +552,32 @@ class LayerCollectionTest(test.TestCase):
self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB)
self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)
+ def testDefaultLayerCollection(self):
+ with ops.Graph().as_default():
+ # Can't get default if there isn't one set.
+ with self.assertRaises(ValueError):
+ layer_collection.get_default_layer_collection()
+
+ # Can't set default twice.
+ lc = layer_collection.LayerCollection()
+ layer_collection.set_default_layer_collection(lc)
+ with self.assertRaises(ValueError):
+ layer_collection.set_default_layer_collection(lc)
+
+ # Same as one set.
+ self.assertTrue(lc is layer_collection.get_default_layer_collection())
+
+ # Can set to None.
+ layer_collection.set_default_layer_collection(None)
+ with self.assertRaises(ValueError):
+ layer_collection.get_default_layer_collection()
+
+ # as_default() is the same as setting/clearing.
+ with lc.as_default():
+ self.assertTrue(lc is layer_collection.get_default_layer_collection())
+ with self.assertRaises(ValueError):
+ layer_collection.get_default_layer_collection()
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
index 97a97adbf5..2cee01212a 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
@@ -29,6 +29,8 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -325,6 +327,84 @@ class UtilsTest(test.TestCase):
],
values)
+ def testExtractConvolutionPatches(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ batch_size = 10
+ image_spatial_shape = [9, 10, 11]
+ in_channels = out_channels = 32
+ kernel_spatial_shape = [5, 3, 3]
+ spatial_strides = [1, 2, 1]
+ spatial_dilation = [1, 1, 1]
+ padding = 'SAME'
+
+ images = random_ops.random_uniform(
+ [batch_size] + image_spatial_shape + [in_channels], seed=0)
+ kernel_shape = kernel_spatial_shape + [in_channels, out_channels]
+ kernel = random_ops.random_uniform(kernel_shape, seed=1)
+
+ # Ensure shape matches expectation.
+ patches = utils.extract_convolution_patches(
+ images,
+ kernel_shape,
+ padding,
+ strides=spatial_strides,
+ dilation_rate=spatial_dilation)
+ result_spatial_shape = (
+ patches.shape.as_list()[1:1 + len(image_spatial_shape)])
+ self.assertEqual(patches.shape.as_list(),
+ [batch_size] + result_spatial_shape +
+ kernel_spatial_shape + [in_channels])
+
+ # Ensure extract...patches() + matmul() and convolution() implementation
+ # give the same answer.
+ outputs = nn_ops.convolution(
+ images,
+ kernel,
+ padding,
+ strides=spatial_strides,
+ dilation_rate=spatial_dilation)
+
+ patches_flat = array_ops.reshape(
+ patches, [-1, np.prod(kernel_spatial_shape) * in_channels])
+ kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
+ outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
+
+ outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
+ self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
+
+ def testExtractPointwiseConv2dPatches(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ batch_size = 10
+ image_height = image_width = 8
+ in_channels = out_channels = 3
+ kernel_height = kernel_width = 1
+ strides = [1, 1, 1, 1]
+ padding = 'VALID'
+
+ images = random_ops.random_uniform(
+ [batch_size, image_height, image_width, in_channels], seed=0)
+ kernel_shape = [kernel_height, kernel_width, in_channels, out_channels]
+ kernel = random_ops.random_uniform(kernel_shape, seed=1)
+
+ # Ensure shape matches expectation.
+ patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape)
+ self.assertEqual(patches.shape.as_list(), [
+ batch_size, image_height, image_width, kernel_height, kernel_width,
+ in_channels
+ ])
+
+ # Ensure extract...patches() + matmul() and conv2d() implementation
+ # give the same answer.
+ outputs = nn_ops.conv2d(images, kernel, strides, padding)
+
+ patches_flat = array_ops.reshape(
+ patches, [-1, kernel_height * kernel_width * in_channels])
+ kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
+ outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
+
+ outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
+ self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index fdfd9599f4..64755be65c 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -149,8 +149,6 @@ class FisherEstimator(object):
self._damping = damping
self._estimation_mode = estimation_mode
self._layers = layer_collection
- self._layers.create_subgraph()
- self._layers.check_registration(variables)
self._gradient_fns = {
"gradients": self._get_grads_lists_gradients,
"empirical": self._get_grads_lists_empirical,
@@ -164,9 +162,6 @@ class FisherEstimator(object):
self._name = name
- self._instantiate_factors()
- self._register_matrix_functions()
-
@property
def variables(self):
return self._variables
@@ -285,6 +280,12 @@ class FisherEstimator(object):
for block in self.blocks:
block.register_matpower(exp)
+ def _finalize_layer_collection(self):
+ self._layers.create_subgraph()
+ self._layers.check_registration(self.variables)
+ self._instantiate_factors()
+ self._register_matrix_functions()
+
def make_ops_and_vars(self, scope=None):
"""Make ops and vars with no specific device placement.
@@ -467,6 +468,8 @@ class FisherEstimator(object):
"""
self._check_vars_unmade_and_set_made_flag()
+ self._finalize_layer_collection()
+
scope = self.name if scope is None else scope
cov_variable_thunks = [
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 521a98866b..31f4689fbf 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -40,10 +40,12 @@ from __future__ import print_function
import abc
import enum # pylint: disable=g-bad-import-order
+import numpy as np
import six
from tensorflow.contrib.kfac.python.ops import fisher_factors
from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -517,7 +519,7 @@ class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
- """FisherBlock for convolutional layers using a diagonal approx.
+ """FisherBlock for 2-D convolutional layers using a diagonal approx.
Estimates the Fisher Information matrix's diagonal entries for a convolutional
layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares"
@@ -541,7 +543,13 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
to the layer's parameters 'w'.
"""
- def __init__(self, layer_collection, params, strides, padding):
+ def __init__(self,
+ layer_collection,
+ params,
+ strides,
+ padding,
+ data_format=None,
+ dilations=None):
"""Creates a ConvDiagonalFB block.
Args:
@@ -553,29 +561,53 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
containing the previous and a Tensor of shape [out_channels].
strides: The stride size in this layer (1-D Tensor of length 4).
padding: The padding in this layer (e.g. "SAME").
+ data_format: str or None. Format of input data.
+ dilations: List of 4 ints or None. Rate for dilation along all dimensions.
+
+ Raises:
+ ValueError: if strides is not length-4.
+ ValueError: if dilations is not length-4.
+ ValueError: if channel is not last dimension.
"""
- self._strides = tuple(strides) if isinstance(strides, list) else strides
+ if len(strides) != 4:
+ raise ValueError("strides must contain 4 numbers.")
+
+ if dilations is None:
+ dilations = [1, 1, 1, 1]
+
+ if len(dilations) != 4:
+ raise ValueError("dilations must contain 4 numbers.")
+
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels-last.")
+
+ self._strides = maybe_tuple(strides)
self._padding = padding
+ self._data_format = data_format
+ self._dilations = maybe_tuple(dilations)
self._has_bias = isinstance(params, (tuple, list))
fltr = params[0] if self._has_bias else params
self._filter_shape = tuple(fltr.shape.as_list())
+ if len(self._filter_shape) != 4:
+ raise ValueError(
+ "Convolution filter must be of shape"
+ " [filter_height, filter_width, in_channels, out_channels].")
+
super(ConvDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
- # Infer number of locations upon which convolution is applied.
- inputs_shape = tuple(self._inputs[0].shape.as_list())
- self._num_locations = (
- inputs_shape[1] * inputs_shape[2] //
- (self._strides[1] * self._strides[2]))
-
inputs, grads_list = self._package_minibatches(grads_list)
+ # Infer number of locations upon which convolution is applied.
+ self._num_locations = num_conv_locations(inputs.shape.as_list(),
+ self._strides)
+
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvDiagonalFactor,
- (inputs, grads_list, self._filter_shape, self._strides,
- self._padding, self._has_bias))
+ (inputs, grads_list, self._filter_shape, self._strides, self._padding,
+ self._data_format, self._dilations, self._has_bias))
def damping_func():
return self._num_locations * normalize_damping(damping,
@@ -658,8 +690,8 @@ class KroneckerProductFB(FisherBlock):
reshaped_out = self._input_factor.left_multiply_matpower(
reshaped_out, exp, self._input_damping_func)
if self._renorm_coeff != 1.0:
- reshaped_out *= math_ops.cast(
- self._renorm_coeff**exp, dtype=reshaped_out.dtype)
+ renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype)
+ reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def full_fisher_block(self):
@@ -761,7 +793,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
- """FisherBlock for 2D convolutional layers using the basic KFC approx.
+ """FisherBlock for convolutional layers using the basic KFC approx.
Estimates the Fisher Information matrix's blog for a convolutional
layer.
@@ -784,21 +816,40 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
See equation 23 in https://arxiv.org/abs/1602.01407 for details.
"""
- def __init__(self, layer_collection, params, strides, padding):
+ def __init__(self,
+ layer_collection,
+ params,
+ padding,
+ strides=None,
+ dilation_rate=None,
+ data_format=None,
+ extract_patches_fn=None):
"""Creates a ConvKFCBasicFB block.
Args:
layer_collection: The collection of all layers in the K-FAC approximate
Fisher information matrix to which this FisherBlock belongs.
params: The parameters (Tensor or tuple of Tensors) of this layer. If
- kernel alone, a Tensor of shape [kernel_height, kernel_width,
+ kernel alone, a Tensor of shape [..spatial_filter_shape..,
in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
containing the previous and a Tensor of shape [out_channels].
- strides: The stride size in this layer (1-D Tensor of length 4).
- padding: The padding in this layer (1-D of Tensor length 4).
+ padding: str. Padding method.
+ strides: List of ints or None. Contains [..spatial_filter_strides..] if
+ 'extract_patches_fn' is compatible with tf.nn.convolution(), else
+ [1, ..spatial_filter_strides, 1].
+ dilation_rate: List of ints or None. Rate for dilation along each spatial
+ dimension if 'extract_patches_fn' is compatible with
+ tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+ data_format: str or None. Format of input data.
+ extract_patches_fn: str or None. Name of function that extracts image
+ patches. One of "extract_convolution_patches", "extract_image_patches",
+ "extract_pointwise_conv2d_patches".
"""
- self._strides = tuple(strides) if isinstance(strides, list) else strides
self._padding = padding
+ self._strides = maybe_tuple(strides)
+ self._dilation_rate = maybe_tuple(dilation_rate)
+ self._data_format = data_format
+ self._extract_patches_fn = extract_patches_fn
self._has_bias = isinstance(params, (tuple, list))
fltr = params[0] if self._has_bias else params
@@ -807,15 +858,16 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
super(ConvKFCBasicFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._package_minibatches(grads_list)
+
# Infer number of locations upon which convolution is applied.
self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
self._strides)
- inputs, grads_list = self._package_minibatches(grads_list)
-
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
- (inputs, self._filter_shape, self._strides, self._padding,
+ (inputs, self._filter_shape, self._padding, self._strides,
+ self._dilation_rate, self._data_format, self._extract_patches_fn,
self._has_bias))
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
@@ -827,17 +879,262 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
return self._num_locations
+class DepthwiseConvDiagonalFB(ConvDiagonalFB):
+ """FisherBlock for depthwise_conv2d().
+
+ Equivalent to ConvDiagonalFB applied to each input channel in isolation.
+ """
+
+ def __init__(self,
+ layer_collection,
+ params,
+ strides,
+ padding,
+ rate=None,
+ data_format=None):
+ """Creates a DepthwiseConvKFCBasicFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: Tensor of shape [filter_height, filter_width, in_channels,
+ channel_multiplier].
+ strides: List of 4 ints. Strides along all dimensions.
+ padding: str. Padding method.
+ rate: List of 4 ints or None. Rate for dilation along all dimensions.
+ data_format: str or None. Format of input data.
+
+ Raises:
+ NotImplementedError: If parameters contains bias.
+ ValueError: If filter is not 4-D.
+ ValueError: If strides is not length-4.
+ ValueError: If rates is not length-2.
+ ValueError: If channels are not last dimension.
+ """
+ if isinstance(params, (tuple, list)):
+ raise NotImplementedError("Bias not yet supported.")
+
+ if params.shape.ndims != 4:
+ raise ValueError("Filter must be 4-D.")
+
+ if len(strides) != 4:
+ raise ValueError("strides must account for 4 dimensions.")
+
+ if rate is not None:
+ if len(rate) != 2:
+ raise ValueError("rate must only account for spatial dimensions.")
+ rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
+
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels-last.")
+
+ super(DepthwiseConvDiagonalFB, self).__init__(
+ layer_collection=layer_collection,
+ params=params,
+ strides=strides,
+ padding=padding,
+ dilations=rate,
+ data_format=data_format)
+
+ # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().
+ filter_height, filter_width, in_channels, channel_multiplier = (
+ params.shape.as_list())
+ self._filter_shape = (filter_height, filter_width, in_channels,
+ in_channels * channel_multiplier)
+
+ def multiply_matpower(self, vector, exp):
+ conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
+ conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower(
+ conv2d_vector, exp)
+ return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
+
+
+class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
+ """FisherBlock for depthwise_conv2d().
+
+ Equivalent to ConvKFCBasicFB applied to each input channel in isolation.
+ """
+
+ def __init__(self,
+ layer_collection,
+ params,
+ strides,
+ padding,
+ rate=None,
+ data_format=None):
+ """Creates a DepthwiseConvKFCBasicFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: Tensor of shape [filter_height, filter_width, in_channels,
+ channel_multiplier].
+ strides: List of 4 ints. Strides along all dimensions.
+ padding: str. Padding method.
+ rate: List of 4 ints or None. Rate for dilation along all dimensions.
+ data_format: str or None. Format of input data.
+
+ Raises:
+ NotImplementedError: If parameters contains bias.
+ ValueError: If filter is not 4-D.
+ ValueError: If strides is not length-4.
+ ValueError: If rates is not length-2.
+ ValueError: If channels are not last dimension.
+ """
+ if isinstance(params, (tuple, list)):
+ raise NotImplementedError("Bias not yet supported.")
+
+ if params.shape.ndims != 4:
+ raise ValueError("Filter must be 4-D.")
+
+ if len(strides) != 4:
+ raise ValueError("strides must account for 4 dimensions.")
+
+ if rate is not None:
+ if len(rate) != 2:
+ raise ValueError("rate must only account for spatial dimensions.")
+ rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
+
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels-last.")
+
+ super(DepthwiseConvKFCBasicFB, self).__init__(
+ layer_collection=layer_collection,
+ params=params,
+ padding=padding,
+ strides=strides,
+ dilation_rate=rate,
+ data_format=data_format,
+ extract_patches_fn="extract_image_patches")
+
+ # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().
+ filter_height, filter_width, in_channels, channel_multiplier = (
+ params.shape.as_list())
+ self._filter_shape = (filter_height, filter_width, in_channels,
+ in_channels * channel_multiplier)
+
+ def multiply_matpower(self, vector, exp):
+ conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
+ conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower(
+ conv2d_vector, exp)
+ return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
+
+
+def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin
+ """Converts a convolution filter for use with conv2d.
+
+ Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's
+ compatible with tf.nn.conv2d().
+
+ Args:
+ filter: Tensor of shape [height, width, in_channels, channel_multiplier].
+ name: None or str. Name of Op.
+
+ Returns:
+ Tensor of shape [height, width, in_channels, out_channels].
+
+ """
+ with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter",
+ [filter]):
+ filter = ops.convert_to_tensor(filter)
+ filter_height, filter_width, in_channels, channel_multiplier = (
+ filter.shape.as_list())
+
+ results = []
+ for i in range(in_channels):
+ # Slice out one in_channel's filter. Insert zeros around it to force it
+ # to affect that channel and that channel alone.
+ elements = []
+ if i > 0:
+ elements.append(
+ array_ops.zeros(
+ [filter_height, filter_width, i, channel_multiplier]))
+ elements.append(filter[:, :, i:(i + 1), :])
+ if i + 1 < in_channels:
+ elements.append(
+ array_ops.zeros([
+ filter_height, filter_width, in_channels - (i + 1),
+ channel_multiplier
+ ]))
+
+ # Concat along in_channel.
+ results.append(
+ array_ops.concat(elements, axis=-2, name="in_channel_%d" % i))
+
+ # Concat along out_channel.
+ return array_ops.concat(results, axis=-1, name="out_channel")
+
+
+def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin
+ """Converts a convolution filter for use with depthwise_conv2d.
+
+ Transforms a filter for use with tf.nn.conv2d() to one that's
+ compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along
+ the diagonal.
+
+ Args:
+ filter: Tensor of shape [height, width, in_channels, out_channels].
+ name: None or str. Name of Op.
+
+ Returns:
+ Tensor of shape,
+ [height, width, in_channels, channel_multiplier]
+
+ Raises:
+ ValueError: if out_channels is not evenly divisible by in_channels.
+ """
+ with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter",
+ [filter]):
+ filter = ops.convert_to_tensor(filter)
+ filter_height, filter_width, in_channels, out_channels = (
+ filter.shape.as_list())
+
+ if out_channels % in_channels != 0:
+ raise ValueError("out_channels must be evenly divisible by in_channels.")
+ channel_multiplier = out_channels // in_channels
+
+ results = []
+ filter = array_ops.reshape(filter, [
+ filter_height, filter_width, in_channels, in_channels,
+ channel_multiplier
+ ])
+ for i in range(in_channels):
+ # Slice out output corresponding to the correct filter.
+ filter_slice = array_ops.reshape(
+ filter[:, :, i, i, :],
+ [filter_height, filter_width, 1, channel_multiplier])
+ results.append(filter_slice)
+
+ # Concat along out_channel.
+ return array_ops.concat(results, axis=-2, name="in_channels")
+
+
+def maybe_tuple(obj):
+ if not isinstance(obj, list):
+ return obj
+ return tuple(obj)
+
+
def num_conv_locations(input_shape, strides):
"""Returns the number of spatial locations a 2D Conv kernel is applied to.
Args:
- input_shape: list representing shape of inputs to the Conv layer.
- strides: list representing strides for the Conv kernel.
+ input_shape: List of ints representing shape of inputs to
+ tf.nn.convolution().
+ strides: List of ints representing strides along spatial dimensions as
+ passed in to tf.nn.convolution().
Returns:
A scalar |T| denoting the number of spatial locations for the Conv layer.
"""
- return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
+ spatial_input_locations = np.prod(input_shape[1:-1])
+
+ if strides is None:
+ spatial_strides_divisor = 1
+ else:
+ spatial_strides_divisor = np.prod(strides)
+
+ return spatial_input_locations // spatial_strides_divisor
class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
@@ -858,7 +1155,7 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
def instantiate_factors(self, grads_list, damping):
- self._num_uses = len(self._inputs[0])
+ self._num_uses = float(len(self._inputs[0]))
inputs, grads_list = self._package_minibatches_multi(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor(
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 8ac63bc764..6fc163e232 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -159,7 +159,9 @@ def scope_string_from_params(params):
name_parts = []
for param in params:
- if isinstance(param, (tuple, list)):
+ if param is None:
+ name_parts.append("None")
+ elif isinstance(param, (tuple, list)):
if all([isinstance(p, int) for p in param]):
name_parts.append("-".join([str(p) for p in param]))
else:
@@ -867,6 +869,8 @@ class ConvDiagonalFactor(DiagonalFactor):
filter_shape,
strides,
padding,
+ data_format=None,
+ dilations=None,
has_bias=False):
"""Creates a ConvDiagonalFactor object.
@@ -880,15 +884,42 @@ class ConvDiagonalFactor(DiagonalFactor):
out_channels). Represents shape of kernel used in this layer.
strides: The stride size in this layer (1-D Tensor of length 4).
padding: The padding in this layer (1-D of Tensor length 4).
+ data_format: None or str. Format of conv2d inputs.
+ dilations: None or tuple of 4 ints.
has_bias: Python bool. If True, the layer is assumed to have a bias
parameter in addition to its filter parameter.
+
+ Raises:
+ ValueError: If inputs, output_grads, and filter_shape do not agree on
+ in_channels or out_channels.
+ ValueError: If strides, dilations are not length-4 lists of ints.
+ ValueError: If data_format does not put channel last.
"""
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("Channel must be last.")
+ if inputs.shape.ndims != 4:
+ raise ValueError("inputs must be 4-D Tensor.")
+ if inputs.shape.as_list()[-1] != filter_shape[-2]:
+ raise ValueError("inputs and filter_shape must agree on in_channels.")
+ for i, outputs_grad in enumerate(outputs_grads):
+ if outputs_grad.shape.ndims != 4:
+ raise ValueError("outputs[%d] must be 4-D Tensor." % i)
+ if outputs_grad.shape.as_list()[-1] != filter_shape[-1]:
+ raise ValueError(
+ "outputs[%d] and filter_shape must agree on out_channels." % i)
+ if len(strides) != 4:
+ raise ValueError("strides must be length-4 list of ints.")
+ if dilations is not None and len(dilations) != 4:
+ raise ValueError("dilations must be length-4 list of ints.")
+
self._inputs = inputs
+ self._outputs_grads = outputs_grads
self._filter_shape = filter_shape
self._strides = strides
self._padding = padding
+ self._data_format = data_format
+ self._dilations = dilations
self._has_bias = has_bias
- self._outputs_grads = outputs_grads
self._patches = None
super(ConvDiagonalFactor, self).__init__()
@@ -919,11 +950,15 @@ class ConvDiagonalFactor(DiagonalFactor):
# TODO(b/64144716): there is potential here for a big savings in terms
# of memory use.
+ if self._dilations is None:
+ rates = (1, 1, 1, 1)
+ else:
+ rates = tuple(self._dilations)
patches = array_ops.extract_image_patches(
self._inputs,
ksizes=[1, filter_height, filter_width, 1],
strides=self._strides,
- rates=[1, 1, 1, 1],
+ rates=rates,
padding=self._padding)
if self._has_bias:
@@ -1010,39 +1045,55 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
def __init__(self,
inputs,
filter_shape,
- strides,
padding,
+ strides=None,
+ dilation_rate=None,
+ data_format=None,
+ extract_patches_fn=None,
has_bias=False):
"""Initializes ConvInputKroneckerFactor.
Args:
- inputs: A Tensor of shape [batch_size, height, width, in_channels]
- which is the inputs to the layer (before being processed into patches).
- filter_shape: 1-D Tensor of length 4. Contains [kernel_height,
- kernel_width, in_channels, out_channels].
- strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride,
- width_stride, in_channel_stride].
+ inputs: Tensor of shape [batch_size, ..spatial_input_size.., in_channels].
+ Inputs to layer.
+ filter_shape: List of ints. Contains [..spatial_filter_size..,
+ in_channels, out_channels]. Shape of convolution kernel.
padding: str. Padding method for layer. "SAME" or "VALID".
+ strides: List of ints or None. Contains [..spatial_filter_strides..] if
+ 'extract_patches_fn' is compatible with tf.nn.convolution(), else
+ [1, ..spatial_filter_strides, 1].
+ dilation_rate: List of ints or None. Rate for dilation along each spatial
+ dimension if 'extract_patches_fn' is compatible with
+ tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+ data_format: str or None. Format of input data.
+ extract_patches_fn: str or None. Name of function that extracts image
+ patches. One of "extract_convolution_patches", "extract_image_patches",
+ "extract_pointwise_conv2d_patches".
has_bias: bool. If True, append 1 to in_channel.
"""
+ self._inputs = inputs
self._filter_shape = filter_shape
self._strides = strides
self._padding = padding
+ self._dilation_rate = dilation_rate
+ self._data_format = data_format
+ self._extract_patches_fn = extract_patches_fn
self._has_bias = has_bias
- self._inputs = inputs
+
super(ConvInputKroneckerFactor, self).__init__()
@property
def _var_scope(self):
return "ff_convinkron_" + scope_string_from_params([
self._inputs, self._filter_shape, self._strides, self._padding,
- self._has_bias
+ self._dilation_rate, self._data_format, self._has_bias
])
@property
def _cov_shape(self):
- filter_height, filter_width, in_channels, _ = self._filter_shape
- size = filter_height * filter_width * in_channels + self._has_bias
+ spatial_filter_shape = self._filter_shape[0:-2]
+ in_channels = self._filter_shape[-2]
+ size = np.prod(spatial_filter_shape) * in_channels + self._has_bias
return [size, size]
@property
@@ -1057,18 +1108,44 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
if idx != 0:
raise ValueError("ConvInputKroneckerFactor only supports idx = 0")
- filter_height, filter_width, in_channels, _ = self._filter_shape
-
# TODO(b/64144716): there is potential here for a big savings in terms of
# memory use.
- patches = array_ops.extract_image_patches(
- self._inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=[1, 1, 1, 1],
- padding=self._padding)
+ if self._extract_patches_fn in [None, "extract_convolution_patches"]:
+ patches = utils.extract_convolution_patches(
+ self._inputs,
+ self._filter_shape,
+ padding=self._padding,
+ strides=self._strides,
+ dilation_rate=self._dilation_rate,
+ data_format=self._data_format)
+
+ elif self._extract_patches_fn == "extract_image_patches":
+ assert self._inputs.shape.ndims == 4
+ assert len(self._filter_shape) == 4
+ assert len(self._strides) == 4, self._strides
+ if self._dilation_rate is None:
+ rates = [1, 1, 1, 1]
+ else:
+ rates = self._dilation_rate
+ assert len(rates) == 4
+ assert rates[0] == rates[-1] == 1
+ patches = array_ops.extract_image_patches(
+ self._inputs,
+ ksizes=[1] + list(self._filter_shape[0:-2]) + [1],
+ strides=self._strides,
+ rates=rates,
+ padding=self._padding)
+
+ elif self._extract_patches_fn == "extract_pointwise_conv2d_patches":
+ assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)]
+ assert self._filter_shape[0] == self._filter_shape[1] == 1
+ patches = utils.extract_pointwise_conv2d_patches(
+ self._inputs, self._filter_shape, data_format=None)
- flatten_size = (filter_height * filter_width * in_channels)
+ else:
+ raise NotImplementedError(self._extract_patches_fn)
+
+ flatten_size = np.prod(self._filter_shape[0:-1])
# patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
# omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
# where M = minibatch size, |T| = number of spatial locations,
@@ -1100,14 +1177,21 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
Section 3.1 Estimating the factors.
"""
- def __init__(self, outputs_grads):
+ def __init__(self, outputs_grads, data_format=None):
"""Initializes ConvOutputKroneckerFactor.
Args:
- outputs_grads: List of Tensors, each of shape [batch_size,
- height, width, out_channels]. One Tensor for each "source".
+ outputs_grads: list of Tensors. Each Tensor is of shape
+ [batch_size, ..spatial_input_size.., out_channels]. One Tensor per
+ source.
+ data_format: None or str. Format of outputs_grads.
+
+ Raises:
+ ValueError: If channels are not final dimension.
"""
- self._out_channels = outputs_grads[0].shape.as_list()[3]
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("Channel must be last.")
+ self._out_channels = outputs_grads[0].shape.as_list()[-1]
self._outputs_grads = outputs_grads
super(ConvOutputKroneckerFactor, self).__init__()
@@ -1433,4 +1517,3 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
return [control_flow_ops.group(*ops)]
# pylint: enable=invalid-name
-
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 60894ed951..4eb5e4c092 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -26,6 +26,7 @@ from __future__ import print_function
from collections import defaultdict
from collections import OrderedDict
+from contextlib import contextmanager
from functools import partial
import math
@@ -75,6 +76,27 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
# tf.get_variable_scope().reuse.
VARIABLE_SCOPE = "VARIABLE_SCOPE"
+_DEFAULT_LAYER_COLLECTION = None
+
+
+def get_default_layer_collection():
+ """Get default LayerCollection."""
+ if _DEFAULT_LAYER_COLLECTION is None:
+ raise ValueError(
+ "Attempted to retrieve default LayerCollection when none is set. Use "
+ "LayerCollection.as_default().")
+
+ return _DEFAULT_LAYER_COLLECTION
+
+
+def set_default_layer_collection(layer_collection):
+ global _DEFAULT_LAYER_COLLECTION
+
+ if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None:
+ raise ValueError("Default LayerCollection is already set.")
+
+ _DEFAULT_LAYER_COLLECTION = layer_collection
+
class LayerParametersDict(OrderedDict):
"""An OrderedDict where keys are Tensors or tuples of Tensors.
@@ -594,21 +616,25 @@ class LayerCollection(object):
padding,
inputs,
outputs,
+ data_format=None,
+ dilations=None,
approx=None,
reuse=VARIABLE_SCOPE):
- """Registers a convolutional layer.
+ """Registers a call to tf.nn.conv2d().
Args:
params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
this layer. Weight matrix should have shape [kernel_height,
kernel_width, in_channels, out_channels]. Bias should have shape
[out_channels].
- strides: 1-D Tensor of length 4. Strides for convolution kernel.
+ strides: List of 4 ints. Strides for convolution kernel.
padding: string. see tf.nn.conv2d for valid values.
inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
to layer.
outputs: Tensor of shape [batch_size, height, width, out_channels].
Output produced by layer.
+ data_format: str or None. Format of data.
+ dilations: List of 4 ints. Dilations along each dimension.
approx: str. One of "kron" or "diagonal".
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If "VARIABLE_SCOPE", use
@@ -629,12 +655,206 @@ class LayerCollection(object):
raise ValueError("Bad value {} for approx.".format(approx))
block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx]
+ if approx == APPROX_KRONECKER_NAME:
+ block = self.register_block(
+ params,
+ block_type(
+ layer_collection=self,
+ params=params,
+ padding=padding,
+ strides=strides,
+ data_format=data_format,
+ dilation_rate=dilations,
+ extract_patches_fn="extract_image_patches"),
+ reuse=reuse)
+ elif approx == APPROX_DIAGONAL_NAME:
+ assert strides[0] == strides[-1] == 1
+ block = self.register_block(
+ params,
+ block_type(
+ layer_collection=self,
+ params=params,
+ padding=padding,
+ strides=strides,
+ dilations=dilations,
+ data_format=data_format),
+ reuse=reuse)
+ else:
+ raise NotImplementedError
+
+ block.register_additional_minibatch(inputs, outputs)
+
+ self._add_uses(params, 1)
+
+ def register_convolution(self,
+ params,
+ inputs,
+ outputs,
+ padding,
+ strides=None,
+ dilation_rate=None,
+ data_format=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Register a call to tf.nn.convolution().
+
+ Args:
+ params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+ this layer. Weight matrix should have shape [..filter_spatial_size..,
+ in_channels, out_channels]. Bias should have shape [out_channels].
+ inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels].
+ Inputs to layer.
+ outputs: Tensor of shape [batch_size, ..output_spatial_size..,
+ out_channels]. Output produced by layer.
+ padding: string. see tf.nn.conv2d for valid values.
+ strides: List of ints of length len(..input_spatial_size..). Strides for
+ convolution kernel in spatial dimensions.
+ dilation_rate: List of ints of length len(..input_spatial_size..).
+ Dilations along spatial dimension.
+ data_format: str or None. Format of data.
+ approx: str. One of "kron" or "diagonal".
+ reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+ create a new FisherBlock. If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse.
+
+ Raises:
+ ValueError: For improper value to 'approx'.
+ KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ assert approx is None or approx == APPROX_KRONECKER_NAME
+
block = self.register_block(
- params, block_type(self, params, strides, padding), reuse=reuse)
+ params,
+ fb.ConvKFCBasicFB(
+ layer_collection=self,
+ params=params,
+ padding=padding,
+ strides=strides,
+ dilation_rate=dilation_rate,
+ data_format=data_format),
+ reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
self._add_uses(params, 1)
+ def register_depthwise_conv2d(self,
+ params,
+ inputs,
+ outputs,
+ strides,
+ padding,
+ rate=None,
+ data_format=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Register a call to tf.nn.depthwise_conv2d().
+
+ Args:
+ params: 4-D Tensor of shape [filter_height, filter_width,
+ in_channels, channel_multiplier]. Convolutional filter.
+ inputs: Tensor of shape [batch_size, input_height, input_width,
+ in_channels]. Inputs to layer.
+ outputs: Tensor of shape [batch_size, output_height, output_width,
+ in_channels * channel_multiplier]. Output produced by depthwise conv2d.
+ strides: List of ints of length 4. Strides along all dimensions.
+ padding: string. see tf.nn.conv2d for valid values.
+ rate: None or List of ints of length 2. Dilation rates in spatial
+ dimensions.
+ data_format: str or None. Format of data.
+ approx: None or str. Must be "diagonal" if non-None.
+ reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+ create a new FisherBlock. If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse.
+
+ Raises:
+ ValueError: For improper value to 'approx'.
+ KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ assert approx is None or approx == APPROX_DIAGONAL_NAME
+ assert data_format in [None, "NHWC"]
+
+ block = self.register_block(
+ params,
+ fb.DepthwiseConvDiagonalFB(
+ layer_collection=self,
+ params=params,
+ strides=strides,
+ padding=padding,
+ rate=rate,
+ data_format=data_format),
+ reuse=reuse)
+ block.register_additional_minibatch(inputs, outputs)
+
+ self._add_uses(params, 1)
+
+ def register_separable_conv2d(self,
+ depthwise_params,
+ pointwise_params,
+ inputs,
+ depthwise_outputs,
+ pointwise_outputs,
+ strides,
+ padding,
+ rate=None,
+ data_format=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Register a call to tf.nn.separable_conv2d().
+
+ Note: This requires access to intermediate outputs betwee depthwise and
+ pointwise convolutions.
+
+ Args:
+ depthwise_params: 4-D Tensor of shape [filter_height, filter_width,
+ in_channels, channel_multiplier]. Filter for depthwise conv2d.
+ pointwise_params: 4-D Tensor of shape [1, 1, in_channels *
+ channel_multiplier, out_channels]. Filter for pointwise conv2d.
+ inputs: Tensor of shape [batch_size, input_height, input_width,
+ in_channels]. Inputs to layer.
+ depthwise_outputs: Tensor of shape [batch_size, output_height,
+ output_width, in_channels * channel_multiplier]. Output produced by
+ depthwise conv2d.
+ pointwise_outputs: Tensor of shape [batch_size, output_height,
+ output_width, out_channels]. Output produced by pointwise conv2d.
+ strides: List of ints of length 4. Strides for depthwise conv2d kernel in
+ all dimensions.
+ padding: string. see tf.nn.conv2d for valid values.
+ rate: None or List of ints of length 2. Dilation rate of depthwise conv2d
+ kernel in spatial dimensions.
+ data_format: str or None. Format of data.
+ approx: None or str. Must be "kron" if non-None.
+ reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+ create a new FisherBlock. If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse.
+
+ Raises:
+ ValueError: For improper value to 'approx'.
+ KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ self.register_depthwise_conv2d(
+ params=depthwise_params,
+ inputs=inputs,
+ outputs=depthwise_outputs,
+ strides=strides,
+ padding=padding,
+ rate=rate,
+ data_format=data_format,
+ approx=APPROX_DIAGONAL_NAME,
+ reuse=reuse)
+
+ self.register_conv2d(
+ params=pointwise_params,
+ inputs=depthwise_outputs,
+ outputs=pointwise_outputs,
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ data_format=data_format,
+ approx=approx,
+ reuse=reuse)
+
def register_generic(self,
params,
batch_size,
@@ -833,3 +1053,10 @@ class LayerCollection(object):
with variable_scope.variable_scope(self._var_scope):
self.fisher_factors[key] = cls(*args)
return self.fisher_factors[key]
+
+ @contextmanager
+ def as_default(self):
+ """Sets this LayerCollection as the default."""
+ set_default_layer_collection(self)
+ yield
+ set_default_layer_collection(None)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
index f8aa230d9c..9f46853807 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
@@ -30,6 +30,8 @@ from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
+ "get_default_layer_collection",
+ "set_default_layer_collection",
"LayerParametersDict",
"LayerCollection",
"APPROX_KRONECKER_NAME",
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index 5ce5338a9f..af26f5e56b 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
@@ -431,6 +432,127 @@ def batch_execute(global_step, thunks, batch_size, name=None):
return result
+def extract_convolution_patches(inputs,
+ filter_shape,
+ padding,
+ strides=None,
+ dilation_rate=None,
+ name=None,
+ data_format=None):
+ """Extracts inputs to each output coordinate in tf.nn.convolution.
+
+ This is a generalization of tf.extract_image_patches() to tf.nn.convolution(),
+ where the number of spatial dimensions may be something other than 2.
+
+ Assumes,
+ - First dimension of inputs is batch_size
+ - Convolution filter is applied to all input channels.
+
+ Args:
+ inputs: Tensor of shape [batch_size, ..spatial_image_shape..,
+ ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution().
+ filter_shape: List of ints. Shape of filter passed to tf.nn.convolution().
+ padding: string. Padding method. One of "VALID", "SAME".
+ strides: None or list of ints. Strides along spatial dimensions.
+ dilation_rate: None or list of ints. Dilation along spatial dimensions.
+ name: None or str. Name of Op.
+ data_format: None or str. Format of data.
+
+ Returns:
+ Tensor of shape [batch_size, ..spatial_image_shape..,
+ ..spatial_filter_shape.., in_channels]
+
+ Raises:
+ ValueError: If data_format does not put channel last.
+ ValueError: If inputs and filter disagree on in_channels.
+ """
+ if not is_data_format_channel_last(data_format):
+ raise ValueError("Channel must be last dimension.")
+ with ops.name_scope(name, "extract_convolution_patches",
+ [inputs, filter_shape, padding, strides, dilation_rate]):
+ batch_size = inputs.shape.as_list()[0]
+ in_channels = inputs.shape.as_list()[-1]
+
+ # filter_shape = spatial_filter_shape + [in_channels, out_channels]
+ spatial_filter_shape = filter_shape[:-2]
+ if in_channels != filter_shape[-2]:
+ raise ValueError("inputs and filter_shape must agree on in_channels.")
+
+ # Map each input feature to a location in the output.
+ out_channels = np.prod(spatial_filter_shape) * in_channels
+ filters = linalg_ops.eye(out_channels)
+ filters = array_ops.reshape(
+ filters,
+ list(spatial_filter_shape) + [in_channels, out_channels])
+
+ result = nn_ops.convolution(
+ inputs,
+ filters,
+ padding=padding,
+ strides=strides,
+ dilation_rate=dilation_rate)
+ spatial_output_shape = result.shape.as_list()[1:-1]
+ result = array_ops.reshape(result,
+ [batch_size or -1] + spatial_output_shape +
+ list(spatial_filter_shape) + [in_channels])
+
+ return result
+
+
+def extract_pointwise_conv2d_patches(inputs,
+ filter_shape,
+ name=None,
+ data_format=None):
+ """Extract patches for a 1x1 conv2d.
+
+ Args:
+ inputs: 4-D Tensor of shape [batch_size, height, width, in_channels].
+ filter_shape: List of 4 ints. Shape of filter to apply with conv2d()
+ name: None or str. Name for Op.
+ data_format: None or str. Format for data. See 'data_format' in
+ tf.nn.conv2d() for details.
+
+ Returns:
+ Tensor of shape [batch_size, ..spatial_input_shape..,
+ ..spatial_filter_shape.., in_channels]
+
+ Raises:
+ ValueError: if inputs is not 4-D.
+ ValueError: if filter_shape is not [1, 1, ?, ?]
+ ValueError: if data_format is not channels-last.
+ """
+ if inputs.shape.ndims != 4:
+ raise ValueError("inputs must have 4 dims.")
+ if len(filter_shape) != 4:
+ raise ValueError("filter_shape must have 4 dims.")
+ if filter_shape[0] != 1 or filter_shape[1] != 1:
+ raise ValueError("filter_shape must have shape 1 along spatial dimensions.")
+ if not is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels last.")
+ with ops.name_scope(name, "extract_pointwise_conv2d_patches",
+ [inputs, filter_shape]):
+ ksizes = [1, 1, 1, 1] # Spatial shape is 1x1.
+ strides = [1, 1, 1, 1] # Operate on all pixels.
+ rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1.
+ padding = "VALID" # Doesn't matter.
+ result = array_ops.extract_image_patches(inputs, ksizes, strides, rates,
+ padding)
+
+ batch_size, input_height, input_width, in_channels = inputs.shape.as_list()
+ filter_height, filter_width, in_channels, _ = filter_shape
+ return array_ops.reshape(result, [
+ batch_size, input_height, input_width, filter_height, filter_width,
+ in_channels
+ ])
+
+
+def is_data_format_channel_last(data_format):
+ """True if data_format puts channel last."""
+ if data_format is None:
+ return True
+ return data_format.endswith("C")
+
+
def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name
"""Computes matmul(A, B) where A is sparse, B is dense.
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
index 8e424a7946..330d222dbf 100644
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py
@@ -40,6 +40,9 @@ _allowed_symbols = [
"fwd_gradients",
"ensure_sequence",
"batch_execute",
+ "extract_convolution_patches",
+ "extract_pointwise_conv2d_patches",
+ "is_data_format_channel_last",
"matmul_sparse_dense",
"matmul_diag_sparse",
]
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index 6491d8c86a..45184b05ec 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -415,6 +415,8 @@ typedef struct _TfLiteDelegate {
typedef struct {
TfLiteDelegate* delegate;
TfLiteIntArray* nodes_to_replace;
+ TfLiteIntArray* input_tensors;
+ TfLiteIntArray* output_tensors;
} TfLiteDelegateParams;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 831cfafeae..cee57bba5e 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -94,7 +94,7 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
context_.tensors_size = 0;
context_.eigen_context = nullptr;
context_.gemm_context = nullptr;
- context_.recommended_num_threads = 0;
+ context_.recommended_num_threads = -1;
// Invalid to call these these except from TfLiteDelegate
SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
@@ -139,31 +139,76 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
namespace {
+// Copy a std::vector<int> to an existing TfLiteIntArray.
+// This is a low-level data manipulation function, and it's caller's
+// responsibility to ensure TfLiteIntArray has enough size.
+void CopyVectorToTfLiteIntArray(const std::vector<int>& vec,
+ TfLiteIntArray* arr) {
+ arr->size = vec.size();
+ memcpy(arr->data, vec.data(), sizeof(int) * arr->size);
+}
+
// This function allocates a continuous memory space that contains a
-// TfLiteDelegateParams followed by a TfLiteIntArray. The pointer will be
-// deallocated by C `free` function later.
-TfLiteDelegateParams* CreateDelegateParams(
- TfLiteDelegate* delegate, const std::vector<int>& nodes_to_replace) {
- int nodes_to_replace_size_in_bytes =
- TfLiteIntArrayGetSizeInBytes(nodes_to_replace.size());
- void* allocation =
- malloc(sizeof(TfLiteDelegateParams) + nodes_to_replace_size_in_bytes);
+// TfLiteDelegateParams followed by a several TfLiteIntArray.
+// When calling `free` at TfLiteDelegateParams*, all the allocated space
+// will be freed together.
+//
+// +-----------------------------------+
+// | TfLiteDelegateParams |
+// | TfLiteDelegate* delegate; |
+// | TfLiteIntArray* nodes_to_replace; |--\
+// | TfLiteIntArray* input_tensors; |--+--\
+// | TfLiteIntArray* output_tensors; |--+--+--\
+// +-----------------------------------+ | | |
+// | TfLiteIntArray (variable size) |<-/ | |
+// +-----------------------------------+ | |
+// | TfLiteIntArray (variable size) |<----/ |
+// +-----------------------------------+ |
+// | TfLiteIntArray (variable size) |<-------/
+// +-----------------------------------+
+TfLiteDelegateParams* CreateDelegateParams(TfLiteDelegate* delegate,
+ const Subgraph& subgraph) {
+ // Step 1: Calculate the allocation size.
+ int allocation_size = sizeof(TfLiteDelegateParams);
+
+ int nodes_to_replace_size =
+ TfLiteIntArrayGetSizeInBytes(subgraph.nodes.size());
+ allocation_size += nodes_to_replace_size;
+
+ int input_tensors_size =
+ TfLiteIntArrayGetSizeInBytes(subgraph.input_tensors.size());
+ allocation_size += input_tensors_size;
+
+ int output_tensors_size =
+ TfLiteIntArrayGetSizeInBytes(subgraph.output_tensors.size());
+ allocation_size += output_tensors_size;
+
+ // Step 2: Allocate the memory.
+ // Use `char*` for conveniently step through the allocated space by bytes.
+ char* allocation = reinterpret_cast<char*>(malloc(allocation_size));
+
+ // Step 3: Fill all data structures structures.
TfLiteDelegateParams* params =
reinterpret_cast<TfLiteDelegateParams*>(allocation);
- TfLiteIntArray* nodes_to_replace_arr = reinterpret_cast<TfLiteIntArray*>(
- static_cast<char*>(allocation) + sizeof(TfLiteDelegateParams));
+ params->delegate = delegate;
+ allocation += sizeof(TfLiteDelegateParams);
- nodes_to_replace_arr->size = nodes_to_replace.size();
- for (int i = 0; i < nodes_to_replace.size(); ++i) {
- nodes_to_replace_arr->data[i] = nodes_to_replace[i];
- }
+ params->nodes_to_replace = reinterpret_cast<TfLiteIntArray*>(allocation);
+ CopyVectorToTfLiteIntArray(subgraph.nodes, params->nodes_to_replace);
+ allocation += nodes_to_replace_size;
+
+ params->input_tensors = reinterpret_cast<TfLiteIntArray*>(allocation);
+ CopyVectorToTfLiteIntArray(subgraph.input_tensors, params->input_tensors);
+ allocation += input_tensors_size;
+
+ params->output_tensors = reinterpret_cast<TfLiteIntArray*>(allocation);
+ CopyVectorToTfLiteIntArray(subgraph.output_tensors, params->output_tensors);
+ allocation += output_tensors_size;
- params->delegate = delegate;
- params->nodes_to_replace = nodes_to_replace_arr;
return params;
}
-} // Anonymous namespace
+} // namespace
TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
@@ -192,8 +237,7 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
case Subgraph::kTfPartition: {
int node_index;
- TfLiteDelegateParams* params =
- CreateDelegateParams(delegate, subgraph.nodes);
+ TfLiteDelegateParams* params = CreateDelegateParams(delegate, subgraph);
AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors,
nullptr, 0, params, &registration, &node_index);
@@ -229,8 +273,8 @@ TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) {
*execution_plan = plan_cache_.get();
static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]),
"TfLiteIntArray and execution_plan do not contain same type.");
- memcpy(plan_cache_->data, execution_plan_.data(),
- sizeof(plan_cache_->data[0]) * execution_plan_.size());
+ std::memcpy(plan_cache_->data, execution_plan_.data(),
+ sizeof(plan_cache_->data[0]) * execution_plan_.size());
return kTfLiteOk;
}
@@ -575,9 +619,9 @@ TfLiteStatus Interpreter::GetNodeAndRegistration(
}
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
- int tensor_index, TfLiteType type, const char* name,
- const std::vector<int>& dims, TfLiteQuantizationParams quantization,
- const char* buffer, size_t bytes, const Allocation* allocation) {
+ int tensor_index, TfLiteType type, const char* name, const int rank,
+ const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
+ size_t bytes, const Allocation* allocation) {
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
// For most tensors we know exactly how much memory is necessary so we can
@@ -585,23 +629,24 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
// because their sizes change with the contents of the individual strings.
if (type != kTfLiteString) {
size_t required_bytes;
- TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
- &required_bytes));
+ TF_LITE_ENSURE_OK(&context_,
+ BytesRequired(type, dims, rank, &required_bytes));
TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
}
TfLiteTensor& tensor = context_.tensors[tensor_index];
- if (type == tensor.type && EqualVectorAndTfLiteIntArray(tensor.dims, dims)) {
+ if (type == tensor.type &&
+ EqualArrayAndTfLiteIntArray(tensor.dims, rank, dims)) {
// Fast path which does not invalidate the invokable property.
TfLiteTensorDataFree(&tensor);
tensor.data.raw = const_cast<char*>(buffer);
- if (!tensor.dims) tensor.dims = ConvertVectorToTfLiteIntArray(dims);
+ if (!tensor.dims) tensor.dims = ConvertArrayToTfLiteIntArray(rank, dims);
tensor.params = quantization;
tensor.allocation_type = kTfLiteMmapRo;
tensor.allocation = allocation;
} else {
invokable_ = false;
- TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims),
+ TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization, const_cast<char*>(buffer), bytes,
kTfLiteMmapRo, allocation, &tensor);
}
@@ -613,8 +658,8 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
- int tensor_index, TfLiteType type, const char* name,
- const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
+ int tensor_index, TfLiteType type, const char* name, const int rank,
+ const int* dims, TfLiteQuantizationParams quantization) {
invokable_ = false;
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
@@ -624,10 +669,10 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite(
// many bytes we will need based on the dimensions. String tensors are
// allocated dynamically and we can't know ahead of time how much space
// they will require.
- TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
- &required_bytes));
+ TF_LITE_ENSURE_OK(&context_,
+ BytesRequired(type, dims, rank, &required_bytes));
}
- TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims),
+ TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization,
/*buffer=*/nullptr, required_bytes,
type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 276dc0e0ae..b481ee0891 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -134,18 +134,34 @@ class Interpreter {
// This variant assumes an external buffer has been allocated of size
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
- TfLiteStatus SetTensorParametersReadOnly(
+ inline TfLiteStatus SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantizationParams quantization,
+ const char* buffer, size_t bytes,
+ const Allocation* allocation = nullptr) {
+ return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
+ dims.data(), quantization, buffer, bytes,
+ allocation);
+ };
+
+ TfLiteStatus SetTensorParametersReadOnly(
+ int tensor_index, TfLiteType type, const char* name, const int rank,
+ const int* dims, TfLiteQuantizationParams quantization,
const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
// Set description of inputs/outputs/data/fptrs for node `node_index`.
// This variant assumes an external buffer has been allocated of size
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
- TfLiteStatus SetTensorParametersReadWrite(
+ inline TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name,
- const std::vector<int>& dims, TfLiteQuantizationParams quantization);
+ const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
+ return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
+ dims.data(), quantization);
+ }
+ TfLiteStatus SetTensorParametersReadWrite(
+ int tensor_index, TfLiteType type, const char* name, const int rank,
+ const int* dims, TfLiteQuantizationParams quantization);
// Functions to access tensor data
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 008b2c9a3e..72d4acedbe 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -923,8 +923,24 @@ TEST_F(TestDelegate, BasicDelegate) {
ASSERT_EQ(interpreter_->execution_plan().size(), 1);
int node = interpreter_->execution_plan()[0];
const auto* node_and_reg = interpreter_->node_and_registration(node);
- ASSERT_EQ(node_and_reg->second.custom_name,
+ EXPECT_EQ(node_and_reg->second.custom_name,
SimpleDelegate::FakeFusedRegistration().custom_name);
+
+ const TfLiteDelegateParams* params =
+ reinterpret_cast<const TfLiteDelegateParams*>(
+ node_and_reg->first.builtin_data);
+ ASSERT_EQ(params->nodes_to_replace->size, 3);
+ EXPECT_EQ(params->nodes_to_replace->data[0], 0);
+ EXPECT_EQ(params->nodes_to_replace->data[1], 1);
+ EXPECT_EQ(params->nodes_to_replace->data[2], 2);
+
+ ASSERT_EQ(params->input_tensors->size, 2);
+ EXPECT_EQ(params->input_tensors->data[0], 0);
+ EXPECT_EQ(params->input_tensors->data[1], 1);
+
+ ASSERT_EQ(params->output_tensors->size, 2);
+ EXPECT_EQ(params->output_tensors->data[0], 3);
+ EXPECT_EQ(params->output_tensors->data[1], 4);
}
TEST_F(TestDelegate, ComplexDeligate) {
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc
index 1435a45672..213e465552 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.cc
+++ b/tensorflow/contrib/lite/kernels/eigen_support.cc
@@ -27,8 +27,9 @@ struct RefCountedEigenContext {
void IncrementUsageCounter(TfLiteContext* context) {
auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
if (ptr == nullptr) {
- Eigen::setNbThreads(context->recommended_num_threads);
-
+ if (context->recommended_num_threads != -1) {
+ Eigen::setNbThreads(context->recommended_num_threads);
+ }
ptr = new RefCountedEigenContext;
ptr->num_references = 0;
context->eigen_context = ptr;
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc
index df8a9c8cee..76a5165d14 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.cc
+++ b/tensorflow/contrib/lite/kernels/gemm_support.cc
@@ -29,7 +29,9 @@ void IncrementUsageCounter(TfLiteContext* context) {
if (ptr == nullptr) {
ptr = new RefCountedGemmContext;
ptr->gemm_context_ = new gemmlowp::GemmContext();
- ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads);
+ if (context->recommended_num_threads != -1) {
+ ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads);
+ }
ptr->num_references_ = 0;
context->gemm_context = ptr;
}
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 82feae0f00..76607af079 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -4,6 +4,38 @@ package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "py_test")
+filegroup(
+ name = "interpreter_test_data",
+ srcs = glob(["**/testdata/*"]),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "interpreter",
+ srcs = [
+ "interpreter.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
+ ],
+)
+
+py_test(
+ name = "interpreter_test",
+ srcs = ["interpreter_test.py"],
+ data = [":interpreter_test_data"],
+ srcs_version = "PY2AND3",
+ tags = ["no_oss"],
+ deps = [
+ ":interpreter",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
py_library(
name = "lite",
srcs = ["lite.py"],
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py
new file mode 100644
index 0000000000..5b5a7c3199
--- /dev/null
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -0,0 +1,135 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python TF-Lite interpreter."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.lite.python.interpreter_wrapper import tensorflow_wrap_interpreter_wrapper as interpreter_wrapper
+
+
+class Interpreter(object):
+ """Interpreter inferace for TF-Lite Models."""
+
+ def __init__(self, model_path):
+ """Constructor.
+
+ Args:
+ model_path: Path to TF-Lite Flatbuffer file.
+
+ Raises:
+ ValueError: If the interpreter was unable to open the model.
+ """
+ self._interpreter = (
+ interpreter_wrapper.InterpreterWrapper_CreateWrapperCPP(model_path))
+ if not self._interpreter:
+ raise ValueError('Failed to open {}'.format(model_path))
+
+ def allocate_tensors(self):
+ if not self._interpreter.AllocateTensors():
+ raise ValueError('Failed to allocate tensors')
+
+ def _get_tensor_details(self, tensor_index):
+ """Gets tensor details.
+
+ Args:
+ tensor_index: Tensor index of tensor to query.
+
+ Returns:
+ a dictionary containing the name, index, shape and type of the tensor.
+
+ Raises:
+ ValueError: If tensor_index is invalid.
+ """
+ tensor_index = int(tensor_index)
+ tensor_name = self._interpreter.TensorName(tensor_index)
+ tensor_size = self._interpreter.TensorSize(tensor_index)
+ tensor_type = self._interpreter.TensorType(tensor_index)
+
+ if not tensor_name or not tensor_type:
+ raise ValueError('Could not get tensor details')
+
+ details = {
+ 'name': tensor_name,
+ 'index': tensor_index,
+ 'shape': tensor_size,
+ 'dtype': tensor_type,
+ }
+
+ return details
+
+ def get_input_details(self):
+ """Gets model input details.
+
+ Returns:
+ A list of input details.
+ """
+ return [
+ self._get_tensor_details(i) for i in self._interpreter.InputIndices()
+ ]
+
+ def set_tensor(self, tensor_index, value):
+ """Sets the value of the input.
+
+ Args:
+ tensor_index: Tensor index of tensor to set. This value can be gotten from
+ the 'index' field in get_input_details.
+ value: Value of tensor to set.
+
+ Raises:
+ ValueError: If the interpreter could not set the tensor.
+ """
+ if not self._interpreter.SetTensor(tensor_index, value):
+ raise ValueError('Failed to set tensor')
+
+ def resize_tensor_input(self, input_index, tensor_size):
+ """Resizes an input tensor.
+
+ Args:
+ input_index: Tensor index of input to set. This value can be gotten from
+ the 'index' field in get_input_details.
+ tensor_size: The tensor_shape to resize the input to.
+
+ Raises:
+ ValueError: If the interpreter could not resize the input tensor.
+ """
+ if not self.ResizeInputTensor.SetTensor(input_index, tensor_size):
+ raise ValueError('Failed to set input')
+
+ def get_output_details(self):
+ """Gets model output details.
+
+ Returns:
+ A list of output details.
+ """
+ return [
+ self._get_tensor_details(i) for i in self._interpreter.OutputIndices()
+ ]
+
+ def get_tensor(self, tensor_index):
+ """Sets the value of the input.
+
+ Args:
+ tensor_index: Tensor index of tensor to get. This value can be gotten from
+ the 'index' field in get_output_details.
+
+ Returns:
+ a numpy array.
+ """
+ return self._interpreter.GetTensor(tensor_index)
+
+ def invoke(self):
+ if not self._interpreter.Invoke():
+ raise ValueError('Failed to invoke TFLite model')
diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py
new file mode 100644
index 0000000000..e0215b721c
--- /dev/null
+++ b/tensorflow/contrib/lite/python/interpreter_test.py
@@ -0,0 +1,82 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow Lite Python Interface: Sanity check."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import test
+
+
+class InterpreterTest(test_util.TensorFlowTestCase):
+
+ def testFloat(self):
+ interpreter = interpreter_wrapper.Interpreter(
+ resource_loader.get_path_to_datafile('testdata/permute_float.tflite'))
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 4] == input_details[0]['shape']).all())
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('output', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+
+ test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
+ expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
+ interpreter.set_tensor(input_details[0]['index'], test_input)
+ interpreter.invoke()
+
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ self.assertTrue((expected_output == output_data).all())
+
+ def testUint8(self):
+ interpreter = interpreter_wrapper.Interpreter(
+ resource_loader.get_path_to_datafile('testdata/permute_uint8.tflite'))
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(np.uint8, input_details[0]['dtype'])
+ self.assertTrue(([1, 4] == input_details[0]['shape']).all())
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('output', output_details[0]['name'])
+ self.assertEqual(np.uint8, output_details[0]['dtype'])
+ self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+
+ test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
+ expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
+ interpreter.set_tensor(input_details[0]['index'], test_input)
+ interpreter.invoke()
+
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ self.assertTrue((expected_output == output_data).all())
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
new file mode 100644
index 0000000000..453eda6e73
--- /dev/null
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
@@ -0,0 +1,32 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
+
+cc_library(
+ name = "interpreter_wrapper_lib",
+ srcs = ["interpreter_wrapper.cc"],
+ hdrs = ["interpreter_wrapper.h"],
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/core:lib",
+ "//tensorflow/python:numpy_lib",
+ "//util/python:python_headers",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+tf_py_wrap_cc(
+ name = "tensorflow_wrap_interpreter_wrapper",
+ srcs = [
+ "interpreter_wrapper.i",
+ ],
+ deps = [
+ ":interpreter_wrapper_lib",
+ "//util/python:python_headers",
+ ],
+)
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
new file mode 100644
index 0000000000..f30067de94
--- /dev/null
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -0,0 +1,313 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
+
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/python/lib/core/numpy.h"
+
+#if PY_MAJOR_VERSION >= 3
+#define PY_TO_CPPSTRING PyBytes_AsStringAndSize
+#define CPP_TO_PYSTRING PyBytes_FromStringAndSize
+#else
+#define PY_TO_CPPSTRING PyString_AsStringAndSize
+#define CPP_TO_PYSTRING PyString_FromStringAndSize
+#endif
+
+namespace tflite {
+namespace interpreter_wrapper {
+
+namespace {
+std::unique_ptr<tflite::Interpreter> CreateInterpreter(
+ const tflite::FlatBufferModel* model,
+ const tflite::ops::builtin::BuiltinOpResolver& resolver) {
+ if (!model) {
+ return nullptr;
+ }
+
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+ if (interpreter) {
+ for (const int input_index : interpreter->inputs()) {
+ const TfLiteTensor* tensor = interpreter->tensor(input_index);
+ CHECK(tensor);
+ const TfLiteIntArray* dims = tensor->dims;
+ if (!dims) {
+ continue;
+ }
+
+ std::vector<int> input_dims(dims->data, dims->data + dims->size);
+ interpreter->ResizeInputTensor(input_index, input_dims);
+ }
+ }
+ return interpreter;
+}
+
+int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
+ switch (tf_lite_type) {
+ case kTfLiteFloat32:
+ return NPY_FLOAT32;
+ case kTfLiteInt32:
+ return NPY_INT32;
+ case kTfLiteUInt8:
+ return NPY_UINT8;
+ case kTfLiteInt64:
+ return NPY_INT64;
+ case kTfLiteString:
+ return NPY_OBJECT;
+ case kTfLiteNoType:
+ return -1;
+ }
+ LOG(ERROR) << "Unknown TfLiteType " << tf_lite_type;
+ return -1;
+}
+
+TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
+ int pyarray_type = PyArray_TYPE(array);
+ switch (pyarray_type) {
+ case NPY_FLOAT32:
+ return kTfLiteFloat32;
+ case NPY_INT32:
+ return kTfLiteInt32;
+ case NPY_UINT8:
+ return kTfLiteUInt8;
+ case NPY_INT64:
+ return kTfLiteInt64;
+ case NPY_OBJECT:
+ case NPY_STRING:
+ case NPY_UNICODE:
+ return kTfLiteString;
+ }
+ LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type;
+ return kTfLiteNoType;
+}
+
+struct PyDecrefDeleter {
+ void operator()(PyObject* p) const { Py_DECREF(p); }
+};
+
+PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
+ void* pydata = malloc(size * sizeof(int));
+ memcpy(pydata, data, size * sizeof(int));
+ return PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata);
+}
+
+} // namespace
+
+InterpreterWrapper::InterpreterWrapper(
+ std::unique_ptr<tflite::FlatBufferModel> model)
+ : model_(std::move(model)),
+ resolver_(absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()),
+ interpreter_(CreateInterpreter(model_.get(), *resolver_)) {}
+
+InterpreterWrapper::~InterpreterWrapper() {}
+
+bool InterpreterWrapper::AllocateTensors() {
+ if (!interpreter_) {
+ LOG(ERROR) << "Cannot allocate tensors: invalid interpreter.";
+ return false;
+ }
+
+ if (interpreter_->AllocateTensors() != kTfLiteOk) {
+ LOG(ERROR) << "Unable to allocate tensors.";
+ return false;
+ }
+
+ return true;
+}
+
+bool InterpreterWrapper::Invoke() {
+ return interpreter_ ? (interpreter_->Invoke() == kTfLiteOk) : false;
+}
+
+PyObject* InterpreterWrapper::InputIndices() const {
+ PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
+ interpreter_->inputs().size());
+
+ return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
+}
+
+PyObject* InterpreterWrapper::OutputIndices() const {
+ PyObject* np_array = PyArrayFromIntVector(interpreter_->outputs().data(),
+ interpreter_->outputs().size());
+
+ return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
+}
+
+bool InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
+ if (!interpreter_) {
+ LOG(ERROR) << "Invalid interpreter.";
+ return false;
+ }
+
+ std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
+ PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
+ if (!array_safe) {
+ LOG(ERROR) << "Failed to convert value into readable tensor.";
+ return false;
+ }
+
+ PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
+
+ if (PyArray_NDIM(array) != 1) {
+ LOG(ERROR) << "Expected 1-D defining input shape.";
+ return false;
+ }
+
+ if (PyArray_TYPE(array) != NPY_INT32) {
+ LOG(ERROR) << "Shape must be an int32 array";
+ return false;
+ }
+
+ std::vector<int> dims(PyArray_SHAPE(array)[0]);
+ memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
+
+ return interpreter_->ResizeInputTensor(i, dims);
+}
+
+std::string InterpreterWrapper::TensorName(int i) const {
+ if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
+ return "";
+ }
+
+ const TfLiteTensor* tensor = interpreter_->tensor(i);
+ return tensor->name;
+}
+
+PyObject* InterpreterWrapper::TensorType(int i) const {
+ if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
+ return nullptr;
+ }
+
+ const TfLiteTensor* tensor = interpreter_->tensor(i);
+ int typenum = TfLiteTypeToPyArrayType(tensor->type);
+ return PyArray_TypeObjectFromType(typenum);
+}
+
+PyObject* InterpreterWrapper::TensorSize(int i) const {
+ if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
+ Py_INCREF(Py_None);
+ return Py_None;
+ }
+
+ const TfLiteTensor* tensor = interpreter_->tensor(i);
+ PyObject* np_array =
+ PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
+
+ return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
+}
+
+bool InterpreterWrapper::SetTensor(int i, PyObject* value) {
+ if (!interpreter_) {
+ LOG(ERROR) << "Invalid interpreter.";
+ return false;
+ }
+
+ if (i >= interpreter_->tensors_size()) {
+ LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index "
+ << interpreter_->tensors_size();
+ return false;
+ }
+
+ std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
+ PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
+ if (!array_safe) {
+ LOG(ERROR) << "Failed to convert value into readable tensor.";
+ return false;
+ }
+
+ PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
+ const TfLiteTensor* tensor = interpreter_->tensor(i);
+
+ if (TfLiteTypeFromPyArray(array) != tensor->type) {
+ LOG(ERROR) << "Cannot set tensor:"
+ << " Got tensor of type " << TfLiteTypeFromPyArray(array)
+ << " but expected type " << tensor->type << " for input " << i;
+ return false;
+ }
+
+ if (PyArray_NDIM(array) != tensor->dims->size) {
+ LOG(ERROR) << "Cannot set tensor: Dimension mismatch";
+ return false;
+ }
+
+ for (int j = 0; j < PyArray_NDIM(array); j++) {
+ if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
+ LOG(ERROR) << "Cannot set tensor: Dimension mismatch";
+ return false;
+ }
+ }
+
+ size_t size = PyArray_NBYTES(array);
+ DCHECK_EQ(size, tensor->bytes);
+ memcpy(tensor->data.raw, PyArray_DATA(array), size);
+ return true;
+}
+
+PyObject* InterpreterWrapper::GetTensor(int i) const {
+ if (!interpreter_) {
+ LOG(ERROR) << "Invalid interpreter.";
+ Py_INCREF(Py_None);
+ return Py_None;
+ }
+
+ if (i >= interpreter_->tensors_size()) {
+ LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index "
+ << interpreter_->inputs().size();
+ Py_INCREF(Py_None);
+ return Py_None;
+ }
+
+ const TfLiteTensor* output_tensor = interpreter_->tensor(i);
+ const int tensor_size = output_tensor->bytes;
+ if (tensor_size <= 0) {
+ LOG(ERROR) << "Invalid tensor size";
+ Py_INCREF(Py_None);
+ return Py_None;
+ }
+
+ int type_num = TfLiteTypeToPyArrayType(output_tensor->type);
+ if (type_num == -1) {
+ LOG(ERROR) << "Unknown tensor type " << output_tensor->type;
+ Py_INCREF(Py_None);
+ return Py_None;
+ }
+
+ void* data = malloc(tensor_size);
+ memcpy(data, output_tensor->data.raw, tensor_size);
+
+ const TfLiteIntArray* output_dims = output_tensor->dims;
+ std::vector<npy_intp> dims(output_dims->data,
+ output_dims->data + output_dims->size);
+ PyObject* np_array =
+ PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
+
+ return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
+}
+
+InterpreterWrapper* InterpreterWrapper::CreateWrapperCPP(
+ const char* model_path) {
+ std::unique_ptr<tflite::FlatBufferModel> model =
+ tflite::FlatBufferModel::BuildFromFile(model_path);
+ return model ? new InterpreterWrapper(std::move(model)) : nullptr;
+}
+
+} // namespace interpreter_wrapper
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
new file mode 100644
index 0000000000..dea71ca879
--- /dev/null
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -0,0 +1,72 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
+#define TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <Python.h>
+
+// We forward declare TFLite classes here to avoid exposing them to SWIG.
+namespace tflite {
+namespace ops {
+namespace builtin {
+class BuiltinOpResolver;
+} // namespace builtin
+} // namespace ops
+
+class FlatBufferModel;
+class Interpreter;
+
+namespace interpreter_wrapper {
+
+class InterpreterWrapper {
+ public:
+ // SWIG caller takes ownership of pointer.
+ static InterpreterWrapper* CreateWrapperCPP(const char* model_path);
+
+ ~InterpreterWrapper();
+ bool AllocateTensors();
+ bool Invoke();
+
+ PyObject* InputIndices() const;
+ PyObject* OutputIndices() const;
+ bool ResizeInputTensor(int i, PyObject* value);
+
+ std::string TensorName(int i) const;
+ PyObject* TensorType(int i) const;
+ PyObject* TensorSize(int i) const;
+ bool SetTensor(int i, PyObject* value);
+ PyObject* GetTensor(int i) const;
+
+ private:
+ InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model);
+
+ // InterpreterWrapper is not copyable or assignable. We avoid the use of
+ // InterpreterWrapper() = delete here for SWIG compatibility.
+ InterpreterWrapper();
+ InterpreterWrapper(const InterpreterWrapper& rhs);
+
+ const std::unique_ptr<tflite::FlatBufferModel> model_;
+ const std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver_;
+ const std::unique_ptr<tflite::Interpreter> interpreter_;
+};
+
+} // namespace interpreter_wrapper
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i
new file mode 100644
index 0000000000..7f51f9f00d
--- /dev/null
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i
@@ -0,0 +1,25 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%include "std_string.i"
+
+
+%{
+#define SWIG_FILE_WITH_INIT
+#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
+%}
+
+
+%include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 5d2f216537..35d224924e 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -202,11 +202,12 @@ def toco_convert(input_data,
input_array.name = _tensor_name(input_tensor)
input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
- toco.inference_input_type = tflite_input_type
for output_tensor in output_tensors:
model.output_arrays.append(_tensor_name(output_tensor))
+ # TODO(aselle): Consider handling the case of allowing quantized
+ # inputs to be converted to float (via the toco.inference_input_type field).
data = toco_convert_protos(model.SerializeToString(),
toco.SerializeToString(),
input_data.SerializeToString())
diff --git a/tensorflow/contrib/lite/python/op_hint.py b/tensorflow/contrib/lite/python/op_hint.py
index 9a3971228a..7908689ce4 100644
--- a/tensorflow/contrib/lite/python/op_hint.py
+++ b/tensorflow/contrib/lite/python/op_hint.py
@@ -119,8 +119,10 @@ class OpHint(object):
def _setattr(self, dest_op, name, value):
tensor_value = _ops.convert_to_tensor(value)
- dest_op.op.node_def.attr[name].tensor.CopyFrom(
- tensor_value.op.node_def.attr["value"].tensor)
+ # pylint: disable=protected-access
+ dest_op.op._set_attr(name, _attr_value_pb2.AttrValue(
+ tensor=tensor_value.op.node_def.attr["value"].tensor))
+ # pylint: enable=protected-access
def add_inputs(self, *args):
"""Add a sequence of inputs to the function invocation.
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
index da65ec659c..a758c5e7e1 100644
--- a/tensorflow/contrib/lite/schema/BUILD
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -1,6 +1,8 @@
-package(default_visibility = [
- "//visibility:public",
-])
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+)
licenses(["notice"]) # Apache 2.0
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
index e601284495..81cedb5dad 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
@@ -57,6 +57,11 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
// We use the output shape that has been calculated by shape propagation.
const auto& output_shape = model->GetArray(squeeze_op->outputs[0]).shape();
+ // Empty shapes will not work as empty data arrays.
+ if (output_shape.dimensions_count() == 0) {
+ return false;
+ }
+
auto* reshape_op = new TensorFlowReshapeOperator;
reshape_op->inputs = {
squeeze_op->inputs[0],
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
index c2b166033c..5a36a90b38 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
@@ -21,6 +21,33 @@ limitations under the License.
namespace toco {
+namespace {
+
+bool TransposeAffectsMemoryOrder(std::vector<int> perm,
+ std::vector<int> in_shape) {
+ CHECK_EQ(perm.size(), in_shape.size());
+ // See what the ordering of the non-unary columns are before and after
+ // transpose permutation. If the major indices stay in the same order (not
+ // just the shape) then the flat buffer representation shouldn't change.
+ std::vector<int> old_major_index_ordering;
+ std::vector<int> new_major_index_ordering;
+ for (int i = 0; i < in_shape.size(); i++) {
+ if (in_shape[i] != 1) {
+ old_major_index_ordering.push_back(i);
+ }
+
+ if (in_shape[perm[i]] != 1) {
+ new_major_index_ordering.push_back(perm[i]);
+ }
+ }
+
+ CHECK_EQ(new_major_index_ordering.size(), old_major_index_ordering.size());
+
+ return old_major_index_ordering != new_major_index_ordering;
+}
+
+} // namespace
+
bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
auto transpose_it = model->operators.begin() + op_index;
if (transpose_it->get()->type != OperatorType::kTranspose) {
@@ -29,23 +56,26 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
TransposeOperator* transpose_op =
static_cast<TransposeOperator*>(transpose_it->get());
+ const auto& input_array = model->GetArray(transpose_op->inputs[0]);
const auto& output_array = model->GetArray(transpose_op->outputs[0]);
- if (!output_array.has_shape()) {
+ if (!input_array.has_shape() || !output_array.has_shape()) {
// Yield until PropagateFixedSizes has been run on this op.
return false;
}
// Note: We can assume we have error checked inputs in PropagateFixedSizes.
- // This transpose is trivial if we only have one non-unitary dimension.
- std::vector<int> const& dims = output_array.shape().dims();
- unsigned non_unitary_axis_count = 0;
- for (int i = 0; i < dims.size(); i++) {
- if (dims[i] != 1) {
- non_unitary_axis_count++;
- }
+ // Check that the permutation has propogated.
+ std::vector<int> const& perm = transpose_op->perm;
+ if (perm.empty()) {
+ return false;
}
- if (non_unitary_axis_count > 1) {
- // Transpose is not trivial
+
+ // This transpose is trivial if non-unitary dimensions remain in the same
+ // order.
+ std::vector<int> const& input_dims = input_array.shape().dims();
+ std::vector<int> const& output_dims = output_array.shape().dims();
+
+ if (TransposeAffectsMemoryOrder(perm, input_dims)) {
return false;
}
@@ -61,11 +91,11 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
string shape_array_name = toco::AvailableArrayName(*model, perm_array_name);
Array& shape_array = model->GetOrCreateArray(shape_array_name);
*(shape_array.mutable_shape()->mutable_dims()) = {
- 1, static_cast<int>(dims.size())};
+ 1, static_cast<int>(output_dims.size())};
reshape_op->inputs.push_back(shape_array_name);
shape_array.data_type = ArrayDataType::kInt32;
auto& shape_buffer = shape_array.GetMutableBuffer<ArrayDataType::kInt32>();
- shape_buffer.data = dims;
+ shape_buffer.data = output_dims;
// Delete perm array if unused
if (IsDiscardableArray(*model, perm_array_name) &&
diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc
index b7f31e2731..fb4af07d06 100644
--- a/tensorflow/contrib/lite/util.cc
+++ b/tensorflow/contrib/lite/util.cc
@@ -17,17 +17,21 @@ limitations under the License.
namespace tflite {
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
- TfLiteIntArray* output = TfLiteIntArrayCreate(input.size());
- for (size_t i = 0; i < input.size(); i++) {
- output->data[i] = input[i];
+ return ConvertArrayToTfLiteIntArray(input.size(), input.data());
+}
+
+TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims) {
+ TfLiteIntArray* output = TfLiteIntArrayCreate(rank);
+ for (size_t i = 0; i < rank; i++) {
+ output->data[i] = dims[i];
}
return output;
}
-bool EqualVectorAndTfLiteIntArray(const TfLiteIntArray* a,
- const std::vector<int>& b) {
+bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
+ const int* b) {
if (!a) return false;
- if (a->size != b.size()) return false;
+ if (a->size != b_size) return false;
for (int i = 0; i < a->size; ++i) {
if (a->data[i] != b[i]) return false;
}
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
index f505d82a11..a34db35823 100644
--- a/tensorflow/contrib/lite/util.h
+++ b/tensorflow/contrib/lite/util.h
@@ -29,9 +29,11 @@ namespace tflite {
// Converts a `std::vector` to a `TfLiteIntArray`.
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input);
-// Checks whether a `TfLiteIntArray` and `std::vector` have matching elements.
-bool EqualVectorAndTfLiteIntArray(const TfLiteIntArray* a,
- const std::vector<int>& b);
+TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims);
+
+// Checks whether a `TfLiteIntArray` and an int array have matching elements.
+bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
+ const int* b);
} // namespace tflite
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 4c00630cfe..75a3c3d034 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -1134,9 +1134,9 @@ tensorflow::Status BinaryTensorOpTensor(
CHECK_EQ_TYPE(tensor_r->getType(), dtype);
auto op_pair = ops.find(node_def.op());
if (op_pair == ops.end())
- return tensorflow::errors::Unimplemented(
- "binary op: " + node_def.op() +
- " not supported at: " + node_def.name());
+ return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
+ " not supported at: " +
+ node_def.name());
nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
*const_cast<nvinfer1::ITensor*>(tensor_l),
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
index 1ae6347220..74df75902e 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
@@ -70,6 +70,7 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
<< "' failed with " << status;
}
}
+
// TODO(Sami, aaorey): Find an alternative way!
cudaStreamSynchronize(
stream); // we have to wait for the stream before returning!
@@ -85,11 +86,12 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
cond_.notify_all();
while ((!batch_is_set_ && !done_)) { // wait until new batch arrives
cond_.wait(lock);
+
}
if (done_) {
return false;
}
- CHECK(!calib_running_ && batch_is_set_);
+
for (int i = 0; i < num_bindings; i++) {
auto it = dev_buffers_.find(names[i]);
if (it == dev_buffers_.end()) {
@@ -107,11 +109,13 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
return nullptr;
}
+
void TRTInt8Calibrator::setDone() {
tensorflow::mutex_lock lock(cond_mtx_);
done_ = true;
cond_.notify_all();
}
+
void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
std::size_t length) {}
TRTInt8Calibrator::~TRTInt8Calibrator() {
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
index 4e7b74d620..d77aa2c5ab 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
@@ -69,5 +69,4 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
#endif
#endif
-
#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index 1d96145e59..f4d9351432 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -95,12 +95,12 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc
def _train_ops(self, features):
"""Add training ops to the graph."""
+ mode = estimator_lib.ModeKeys.TRAIN
with variable_scope.variable_scope(
"model",
# Use ResourceVariables to avoid race conditions.
use_resource=True):
- model_outputs = self.state_manager.define_loss(
- self.model, features, estimator_lib.ModeKeys.TRAIN)
+ model_outputs = self.create_loss(features, mode)
train_op = optimizers.optimize_loss(
model_outputs.loss,
@@ -110,14 +110,14 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc
learning_rate=None)
return estimator_lib.EstimatorSpec(
loss=model_outputs.loss,
- mode=estimator_lib.ModeKeys.TRAIN,
+ mode=mode,
train_op=train_op)
def _evaluate_ops(self, features):
"""Add ops for evaluation (aka filtering) to the graph."""
+ mode = estimator_lib.ModeKeys.EVAL
with variable_scope.variable_scope("model", use_resource=True):
- model_outputs = self.state_manager.define_loss(
- self.model, features, estimator_lib.ModeKeys.EVAL)
+ model_outputs = self.create_loss(features, mode)
metrics = {}
# Just output in-sample predictions for the last chunk seen
for prediction_key, prediction_value in model_outputs.predictions.items():
@@ -130,7 +130,7 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc
model_outputs.end_state))
return estimator_lib.EstimatorSpec(
loss=model_outputs.loss,
- mode=estimator_lib.ModeKeys.EVAL,
+ mode=mode,
eval_metric_ops=metrics,
predictions={})
diff --git a/tensorflow/core/common_runtime/build_graph_options.cc b/tensorflow/core/common_runtime/build_graph_options.cc
index 811d459758..a9dc6ca6cd 100644
--- a/tensorflow/core/common_runtime/build_graph_options.cc
+++ b/tensorflow/core/common_runtime/build_graph_options.cc
@@ -21,15 +21,15 @@ namespace tensorflow {
string BuildGraphOptions::DebugString() const {
string rv = "Feed endpoints: ";
- for (auto& s : feed_endpoints) {
+ for (auto& s : callable_options.feed()) {
strings::StrAppend(&rv, s, ", ");
}
strings::StrAppend(&rv, "\nFetch endpoints: ");
- for (auto& s : fetch_endpoints) {
+ for (auto& s : callable_options.fetch()) {
strings::StrAppend(&rv, s, ", ");
}
strings::StrAppend(&rv, "\nTarget nodes: ");
- for (auto& s : target_nodes) {
+ for (auto& s : callable_options.target()) {
strings::StrAppend(&rv, s, ", ");
}
return rv;
diff --git a/tensorflow/core/common_runtime/build_graph_options.h b/tensorflow/core/common_runtime/build_graph_options.h
index 5f0e8f170b..5ca170e922 100644
--- a/tensorflow/core/common_runtime/build_graph_options.h
+++ b/tensorflow/core/common_runtime/build_graph_options.h
@@ -19,25 +19,18 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/protobuf/debug.pb.h"
+#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
struct BuildGraphOptions {
- std::vector<string> feed_endpoints;
- std::vector<string> fetch_endpoints;
-
- // TODO(vrv): Remove this when we unify target_nodes and fetch_endpoint,
- // the former via "ref" fetch_endpoints.
- std::vector<string> target_nodes;
+ CallableOptions callable_options;
// If `true`, uses Arg/Retval to implement feeds/fetches; otherwise
// uses Recv/Send to implement feeds/fetches.
// TODO(mrry): Remove this when the distributed runtime supports Arg/Retval.
bool use_function_convention = false;
- DebugOptions debug_options;
-
string DebugString() const;
};
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 1fbc314e2e..25cfb9e524 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1083,19 +1083,8 @@ Status DirectSession::CreateExecutors(
std::unique_ptr<FunctionInfo>* out_func_info,
RunStateArgs* run_state_args) {
BuildGraphOptions options;
- options.feed_endpoints = std::vector<string>(callable_options.feed().begin(),
- callable_options.feed().end());
- options.fetch_endpoints = std::vector<string>(
- callable_options.fetch().begin(), callable_options.fetch().end());
- options.target_nodes = std::vector<string>(callable_options.target().begin(),
- callable_options.target().end());
+ options.callable_options = callable_options;
options.use_function_convention = !run_state_args->is_partial_run;
- if (!callable_options.run_options()
- .debug_options()
- .debug_tensor_watch_opts()
- .empty()) {
- options.debug_options = callable_options.run_options().debug_options();
- }
std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
@@ -1191,9 +1180,11 @@ Status DirectSession::CreateExecutors(
/*shape_map=*/nullptr);
// EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
- if (!options.debug_options.debug_tensor_watch_opts().empty()) {
+ const DebugOptions& debug_options =
+ options.callable_options.run_options().debug_options();
+ if (!debug_options.debug_tensor_watch_opts().empty()) {
TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
- options.debug_options, partition_graph.get(), params.device));
+ debug_options, partition_graph.get(), params.device));
}
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
@@ -1384,19 +1375,19 @@ Status DirectSession::CreateGraphs(
execution_state->BuildGraph(subgraph_options, &client_graph));
}
- if (subgraph_options.feed_endpoints.size() !=
+ if (subgraph_options.callable_options.feed_size() !=
client_graph->feed_types.size()) {
return errors::Internal(
"Graph pruning failed: requested number of feed endpoints = ",
- subgraph_options.feed_endpoints.size(),
+ subgraph_options.callable_options.feed_size(),
" versus number of pruned feed endpoints = ",
client_graph->feed_types.size());
}
- if (subgraph_options.fetch_endpoints.size() !=
+ if (subgraph_options.callable_options.fetch_size() !=
client_graph->fetch_types.size()) {
return errors::Internal(
"Graph pruning failed: requested number of fetch endpoints = ",
- subgraph_options.fetch_endpoints.size(),
+ subgraph_options.callable_options.fetch_size(),
" versus number of pruned fetch endpoints = ",
client_graph->fetch_types.size());
}
diff --git a/tensorflow/core/common_runtime/eval_const_tensor.cc b/tensorflow/core/common_runtime/eval_const_tensor.cc
index 6370bb5028..c1542f1f57 100644
--- a/tensorflow/core/common_runtime/eval_const_tensor.cc
+++ b/tensorflow/core/common_runtime/eval_const_tensor.cc
@@ -128,12 +128,16 @@ Status ExtractConstantSubgraph(
return Status::OK();
}
+ if (IsMerge(&target_node)) {
+ return Status::OK();
+ }
+
if (target_node.type_string() == "PlaceholderWithDefault") {
return Status::OK();
}
- // TODO(skyewm): more of the filtering applied in input nodes below should be
- // applied to target_node here
+ // TODO(skyewm): should more of the filtering applied in input nodes below be
+ // applied to target_node here?
// Identify the possibly constant subgraph by recursively iterating backwards
// through the inputs to 'target_node' until we either 1) find an already
@@ -153,11 +157,8 @@ Status ExtractConstantSubgraph(
// Add the target node's inputs to seed the recursion.
std::deque<const Edge*> edges_to_visit;
for (const Edge* e : target_node.in_edges()) {
- // TODO(vrv): What do we do about control edges? Based on our
- // definition of a constant graph, we should be free to ignore
- // control edges since the order in which a constant graph is
- // executed should be the same regardless of when nodes run: we
- // should only need to recurse down data edges.
+ // TODO(skyewm): control edges will be meaningful if/when we handle control
+ // flow (e.g. constants in cond branches are triggered via control edges).
if (e->IsControlEdge()) continue;
edges_to_visit.push_back(e);
}
@@ -177,7 +178,9 @@ Status ExtractConstantSubgraph(
}
// During construction or import from GraphConstructor, back edges may not
- // be filled in. Don't constant fold through merges at all for now.
+ // be filled in. In addition, control flow constructs may depend on control
+ // edges which aren't handled by this method. Don't constant fold through
+ // merges at all for now.
if (IsMerge(current_node)) {
*is_constant_graph = false;
return Status::OK();
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 785ec3d227..f5e3d78242 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -252,8 +252,8 @@ Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) {
// Rewrite the graph before placement.
rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
- new_graph.get(), options.feed_endpoints, options.fetch_endpoints,
- options.target_nodes, device_set_->client_device()->attributes(),
+ new_graph.get(), options.callable_options,
+ device_set_->client_device()->attributes(),
options.use_function_convention, rewrite_metadata_.get()));
}
@@ -299,13 +299,16 @@ Status GraphExecutionState::OptimizeGraph(
item.id = "tf_graph";
graph_->ToGraphDef(&item.graph);
- item.fetch = options.fetch_endpoints;
- item.fetch.insert(item.fetch.end(), options.target_nodes.begin(),
- options.target_nodes.end());
+ item.fetch.insert(item.fetch.end(),
+ options.callable_options.fetch().begin(),
+ options.callable_options.fetch().end());
+ item.fetch.insert(item.fetch.end(),
+ options.callable_options.target().begin(),
+ options.callable_options.target().end());
- if (!options.feed_endpoints.empty()) {
+ if (!options.callable_options.feed().empty()) {
std::unordered_set<string> feeds;
- for (const string& feed : options.feed_endpoints) {
+ for (const string& feed : options.callable_options.feed()) {
TensorId id = ParseTensorName(feed);
if (id.second != 0) {
return errors::InvalidArgument("Unsupported feed: ", feed);
@@ -404,8 +407,8 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
// Extract the subset of the graph that needs to be run, adding feed/fetch
// ops as needed.
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
- ng.get(), options.feed_endpoints, options.fetch_endpoints,
- options.target_nodes, device_set_->client_device()->attributes(),
+ ng.get(), options.callable_options,
+ device_set_->client_device()->attributes(),
options.use_function_convention, &rewrite_metadata));
} else {
// This GraphExecutionState represents a graph that was
@@ -415,8 +418,10 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
rewrite_metadata = *rewrite_metadata_;
}
- CHECK_EQ(options.feed_endpoints.size(), rewrite_metadata.feed_types.size());
- CHECK_EQ(options.fetch_endpoints.size(), rewrite_metadata.fetch_types.size());
+ CHECK_EQ(options.callable_options.feed_size(),
+ rewrite_metadata.feed_types.size());
+ CHECK_EQ(options.callable_options.fetch_size(),
+ rewrite_metadata.fetch_types.size());
// Make a fresh copy of the function library for the client graph.
std::unique_ptr<FunctionLibraryDefinition> flib(
diff --git a/tensorflow/core/common_runtime/memory_types.cc b/tensorflow/core/common_runtime/memory_types.cc
index 090a16ebeb..116750fbfd 100644
--- a/tensorflow/core/common_runtime/memory_types.cc
+++ b/tensorflow/core/common_runtime/memory_types.cc
@@ -92,7 +92,7 @@ static Status ProcessMemoryTypes(
Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) {
return ProcessMemoryTypes(
- device_type, g, [g](const Edge* e, MemoryType sm, MemoryType dm) {
+ device_type, g, [](const Edge* e, MemoryType sm, MemoryType dm) {
if (sm == dm) {
return Status::OK();
}
@@ -155,7 +155,7 @@ Status EnsureMemoryTypes(const DeviceType& device_type,
};
std::vector<Item> edges;
TF_RETURN_IF_ERROR(ProcessMemoryTypes(
- device_type, g, [g, &edges](const Edge* e, MemoryType sm, MemoryType dm) {
+ device_type, g, [&edges](const Edge* e, MemoryType sm, MemoryType dm) {
if (sm == dm) {
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 9768a244f2..8447c55bf4 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -438,7 +438,7 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
StartParallelExecutors(handle, step_id, item, rendezvous, collector,
cost_graph, cancellation_manager,
- [this, item, rendezvous, done](const Status& s) {
+ [item, rendezvous, done](const Status& s) {
done(s);
rendezvous->Unref();
item->Unref();
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 878a1398c9..01da54fcb3 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -72,7 +72,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
client_graph_(std::move(cg)),
session_opts_(session_opts),
is_partial_(is_partial),
- debug_opts_(bopts.debug_options),
+ debug_opts_(bopts.callable_options.run_options().debug_options()),
worker_cache_(worker_cache),
should_deregister_(should_deregister) {
VLOG(1) << "Created ReffedClientGraph for node with "
@@ -921,61 +921,70 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() {
}
}
+namespace {
+void CopyAndSortStrings(size_t size,
+ const std::function<string(size_t)>& input_accessor,
+ protobuf::RepeatedPtrField<string>* output) {
+ std::vector<string> temp;
+ temp.reserve(size);
+ for (size_t i = 0; i < size; ++i) {
+ output->Add(input_accessor(i));
+ }
+ std::sort(output->begin(), output->end());
+}
+} // namespace
+
void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
BuildGraphOptions* opts) {
- for (size_t i = 0; i < req.num_feeds(); ++i) {
- opts->feed_endpoints.push_back(req.feed_name(i));
- }
- for (size_t i = 0; i < req.num_fetches(); ++i) {
- opts->fetch_endpoints.push_back(req.fetch_name(i));
- }
- for (size_t i = 0; i < req.num_targets(); ++i) {
- opts->target_nodes.push_back(req.target_name(i));
- }
+ CallableOptions* callable_opts = &opts->callable_options;
+ CopyAndSortStrings(req.num_feeds(),
+ [&req](size_t i) { return req.feed_name(i); },
+ callable_opts->mutable_feed());
+ CopyAndSortStrings(req.num_fetches(),
+ [&req](size_t i) { return req.fetch_name(i); },
+ callable_opts->mutable_fetch());
+ CopyAndSortStrings(req.num_targets(),
+ [&req](size_t i) { return req.target_name(i); },
+ callable_opts->mutable_target());
if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
- opts->debug_options = req.options().debug_options();
+ *callable_opts->mutable_run_options()->mutable_debug_options() =
+ req.options().debug_options();
}
-
- std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
- std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
- std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
}
void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
BuildGraphOptions* opts) {
- for (const auto& feed : req.feed()) {
- opts->feed_endpoints.push_back(feed);
- }
- for (const auto& fetch : req.fetch()) {
- opts->fetch_endpoints.push_back(fetch);
- }
- for (const auto& target : req.target()) {
- opts->target_nodes.push_back(target);
- }
+ CallableOptions* callable_opts = &opts->callable_options;
+ CopyAndSortStrings(req.feed_size(), [&req](size_t i) { return req.feed(i); },
+ callable_opts->mutable_feed());
+ CopyAndSortStrings(req.fetch_size(),
+ [&req](size_t i) { return req.fetch(i); },
+ callable_opts->mutable_fetch());
+ CopyAndSortStrings(req.target_size(),
+ [&req](size_t i) { return req.target(i); },
+ callable_opts->mutable_target());
// TODO(cais): Add TFDBG support to partial runs.
-
- std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
- std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
- std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
}
uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
uint64 h = 0x2b992ddfa23249d6ull;
- for (const string& name : opts.feed_endpoints) {
+ for (const string& name : opts.callable_options.feed()) {
h = Hash64(name.c_str(), name.size(), h);
}
- for (const string& name : opts.target_nodes) {
+ for (const string& name : opts.callable_options.target()) {
h = Hash64(name.c_str(), name.size(), h);
}
- for (const string& name : opts.fetch_endpoints) {
+ for (const string& name : opts.callable_options.fetch()) {
h = Hash64(name.c_str(), name.size(), h);
}
- if (!opts.debug_options.debug_tensor_watch_opts().empty()) {
- const string watch_summary = SummarizeDebugTensorWatches(
- opts.debug_options.debug_tensor_watch_opts());
+ const DebugOptions& debug_options =
+ opts.callable_options.run_options().debug_options();
+ if (!debug_options.debug_tensor_watch_opts().empty()) {
+ const string watch_summary =
+ SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts());
h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
}
@@ -984,15 +993,15 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
string BuildGraphOptionsString(const BuildGraphOptions& opts) {
string buf;
- for (const string& name : opts.feed_endpoints) {
+ for (const string& name : opts.callable_options.feed()) {
strings::StrAppend(&buf, " FdE: ", name);
}
strings::StrAppend(&buf, "\n");
- for (const string& name : opts.target_nodes) {
+ for (const string& name : opts.callable_options.target()) {
strings::StrAppend(&buf, " TN: ", name);
}
strings::StrAppend(&buf, "\n");
- for (const string& name : opts.fetch_endpoints) {
+ for (const string& name : opts.callable_options.fetch()) {
strings::StrAppend(&buf, " FeE: ", name);
}
strings::StrAppend(&buf, "\n");
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 6345549367..598652fb98 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -215,7 +215,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(request, &in, out);
- auto finish = [this, done, out, opts](const Status& s) {
+ auto finish = [done, out, opts](const Status& s) {
opts->ClearCancelCallback();
delete out;
done(s);
@@ -247,7 +247,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
session->graph_mgr->ExecuteAsync(
graph_handle, step_id, session.get(), request->exec_opts(),
nullptr /* collector */, nullptr /* response */, cm, in,
- [this, token, step_id, session, cm](Status s) {
+ [this, token, step_id, session](Status s) {
{
mutex_lock l(mu_);
cancellation_manager_->DeregisterCallback(token);
diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc
index 2a08bf8ca0..ca93d049d0 100644
--- a/tensorflow/core/graph/subgraph.cc
+++ b/tensorflow/core/graph/subgraph.cc
@@ -323,6 +323,25 @@ Status RewriteGraphForExecution(
return Status::OK();
}
+namespace {
+template <typename StringContainer>
+std::vector<string> ConvertToVector(StringContainer field) {
+ return std::vector<string>(field.begin(), field.end());
+}
+} // namespace
+
+Status RewriteGraphForExecution(Graph* g,
+ const CallableOptions& callable_options,
+ const DeviceAttributes& device_info,
+ bool use_function_convention,
+ RewriteGraphMetadata* out_metadata) {
+ return RewriteGraphForExecution(g, ConvertToVector(callable_options.feed()),
+ ConvertToVector(callable_options.fetch()),
+ ConvertToVector(callable_options.target()),
+ device_info, use_function_convention,
+ out_metadata);
+}
+
} // namespace subgraph
} // namespace tensorflow
diff --git a/tensorflow/core/graph/subgraph.h b/tensorflow/core/graph/subgraph.h
index 3c1f8870f5..0dc59582f4 100644
--- a/tensorflow/core/graph/subgraph.h
+++ b/tensorflow/core/graph/subgraph.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
namespace subgraph {
@@ -70,6 +71,11 @@ Status RewriteGraphForExecution(
const gtl::ArraySlice<string>& target_node_names,
const DeviceAttributes& device_info, bool use_function_convention,
RewriteGraphMetadata* out_metadata);
+Status RewriteGraphForExecution(Graph* g,
+ const CallableOptions& callable_options,
+ const DeviceAttributes& device_info,
+ bool use_function_convention,
+ RewriteGraphMetadata* out_metadata);
typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameIndex;
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index ca56833ef6..53c177befc 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -217,6 +217,8 @@ bool IsNextIteration(const NodeDef& node) {
return op == "NextIteration" || op == "RefNextIteration";
}
+bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }
+
bool IsPad(const NodeDef& node) {
const auto& op = node.op();
return op == "Pad" || op == "PadV2";
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index a0946ee1ad..cd5b464099 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -86,6 +86,7 @@ bool IsMod(const NodeDef& node);
bool IsMul(const NodeDef& node);
bool IsMatMul(const NodeDef& node);
bool IsNextIteration(const NodeDef& node);
+bool IsPack(const NodeDef& node);
bool IsPad(const NodeDef& node);
bool IsNoOp(const NodeDef& node);
bool IsNotEqual(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index e3ed1865f7..fe095a725a 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -543,6 +543,7 @@ tf_cc_test(
tags = [
"manual",
"no_oss", # b/74111495
+ "notap",
],
deps = [
":loop_optimizer",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 177b0735e9..c0fcfaf428 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -290,25 +290,30 @@ NodeDef* GetTailOfValuePreservingChain(
struct ArithmeticOptimizerContext {
ArithmeticOptimizerContext(
const std::unordered_set<string>* nodes_to_preserve,
- GraphDef* optimized_graph, NodeMap* node_map,
+ GraphDef* optimized_graph, NodeMap* node_map, FrameMap* frame_map,
SetVector<NodeDef*>* nodes_to_simplify)
: nodes_to_preserve(nodes_to_preserve),
optimized_graph(optimized_graph),
node_map(node_map),
+ frame_map(frame_map),
nodes_to_simplify(nodes_to_simplify) {}
const std::unordered_set<string>* nodes_to_preserve;
GraphDef* optimized_graph;
NodeMap* node_map;
+ FrameMap* frame_map;
SetVector<NodeDef*>* nodes_to_simplify;
};
// Base class for single arithmetic optimization: e.g. Bitcast optimization,
// AddOps optimization, etc...
+// TODO(ezhulenev): extract this class to be reused by other multi-stage
+// graph optimizers (const_folding, dependency_optimizer, etc...)
class ArithmeticOptimizerStage {
public:
- explicit ArithmeticOptimizerStage(ArithmeticOptimizerContext ctx)
- : ctx_(ctx) {}
+ explicit ArithmeticOptimizerStage(const string& name,
+ const ArithmeticOptimizerContext& ctx)
+ : name_(name), ctx_(ctx) {}
virtual ~ArithmeticOptimizerStage() = default;
// Check if we should try to simplify node. Returning true doesn't
@@ -336,6 +341,46 @@ class ArithmeticOptimizerStage {
string* simplified_node_name) = 0;
protected:
+ struct ScopedNodeName {
+ string scope;
+ string name;
+ };
+
+ const ScopedNodeName ParseScopedNodeName(const string& name) const {
+ auto pos = name.find_last_of("/");
+ if (pos == string::npos) {
+ return {"", name};
+ } else {
+ return {name.substr(0, pos), name.substr(pos + 1)};
+ }
+ }
+
+ // Prefix optimized node name with stage name and rewrite_rule
+ const string OptimizedNodeName(const string& rewrite_rule,
+ const ScopedNodeName& scoped_node_name) const {
+ return MakeOptimizedNodeName(strings::StrCat(name_, "_", rewrite_rule),
+ scoped_node_name);
+ }
+
+ // Prefix optimized node name with stage name and rewrite_rule
+ const string OptimizedNodeName(const string& rewrite_rule,
+ const ScopedNodeName& scoped_node_name,
+ const std::vector<string>& node_names) const {
+ return MakeOptimizedNodeName(strings::StrCat(name_, "_", rewrite_rule),
+ scoped_node_name, node_names);
+ }
+
+ // Prefix optimized node name with stage name
+ const string OptimizedNodeName(const ScopedNodeName& scoped_node_name) const {
+ return MakeOptimizedNodeName(name_, scoped_node_name);
+ }
+
+ // Prefix optimized node name with stage name
+ const string OptimizedNodeName(const ScopedNodeName& scoped_node_name,
+ const std::vector<string>& node_names) const {
+ return MakeOptimizedNodeName(name_, scoped_node_name, node_names);
+ }
+
// Simplification graph rewrite can create additional nodes that are inputs
// to final simplified node, they can be also added to the arithmetic
// optimizer queue for further optimization.
@@ -374,7 +419,91 @@ class ArithmeticOptimizerStage {
}
}
- ArithmeticOptimizerContext ctx_;
+ NodeDef* AddCopyNode(const string& name, const NodeDef* node_to_copy) {
+ CHECK(node_to_copy != nullptr);
+ CHECK(!ctx_.node_map->NodeExists(name))
+ << "Node " << name << " already exists in a graph";
+ NodeDef* new_node = ctx_.optimized_graph->add_node();
+ *new_node = *node_to_copy;
+ new_node->set_name(name);
+ ctx_.node_map->AddNode(name, new_node);
+ return new_node;
+ }
+
+ NodeDef* AddEmptyNode(const string& name) {
+ CHECK(!ctx_.node_map->NodeExists(name))
+ << "Node " << name << " already exists in a graph";
+ NodeDef* new_node = ctx_.optimized_graph->add_node();
+ new_node->set_name(name);
+ ctx_.node_map->AddNode(name, new_node);
+ return new_node;
+ }
+
+ // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all
+ // optimizations will be migrated to stages
+ void AddFrameControlDeps(const NodeDef* old_node,
+ const std::vector<NodeDef*>& new_nodes,
+ const string& source_for_ctrl_dep,
+ const std::vector<NodeDef*>& sinks_for_control_dep) {
+ const auto frame_it = ctx_.frame_map->find(old_node);
+ if (frame_it != ctx_.frame_map->end()) {
+ for (auto node : new_nodes) {
+ ctx_.frame_map->emplace(node, frame_it->second);
+ }
+ if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) {
+ const string ctrl_dep = ConstantFolding::AddControlDependency(
+ source_for_ctrl_dep, ctx_.optimized_graph, ctx_.node_map);
+ for (auto node : sinks_for_control_dep) {
+ MaybeAddControlInput(ctrl_dep, node, ctx_.optimized_graph,
+ ctx_.node_map);
+ }
+ }
+ }
+ }
+
+ const string name_;
+ const ArithmeticOptimizerContext ctx_;
+
+ private:
+ // Get a name for a new node obtained by optimizing a single node of the
+ // original graph. The optimized node is placed under the original node scope.
+ //
+ // Node name uniqueness is guaranteed by unique name of an original node in
+ // a same scope.
+ //
+ // Example: MakeOptimizedNodeName("AwesomeRewrite", "a/b/c/Add_1")
+ // Optimized name: "a/b/c/ArithmeticOptimizer/AwesomeRewrite_Add_1"
+ const string MakeOptimizedNodeName(
+ const string& prefix, const ScopedNodeName& scoped_node_name) const {
+ string node_name;
+ strings::StrAppend(&node_name, scoped_node_name.scope);
+ if (!node_name.empty()) strings::StrAppend(&node_name, "/");
+ strings::StrAppend(&node_name, kArithmeticOptimizer, "/", prefix, "_",
+ scoped_node_name.name);
+ return node_name;
+ }
+
+ // Get a name for a new node obtained by optimizing multiple nodes of the
+ // original graph, starting from "root". The optimized node is placed under
+ // the original scope of a "root" node.
+ //
+ // Node name uniqueness is guaranteed by unique name of a "root" node in
+ // a same scope.
+ //
+ // Example:
+ // MakeOptimizedNodeName("AwesomeRewrite", "a/b/Add_AB", ["x/y/Add_XY"])
+ // Optimized name:
+ // "a/b/ArithmeticOptimizer/AwesomeRewrite_Add_AB_Add_XY"
+ const string MakeOptimizedNodeName(
+ const string& prefix, const ScopedNodeName& scoped_node_name,
+ const std::vector<string>& node_names) const {
+ string node_name = MakeOptimizedNodeName(prefix, scoped_node_name);
+ for (const string& optimized : node_names) {
+ auto scoped_node = ParseScopedNodeName(optimized);
+ strings::StrAppend(&node_name, "_", scoped_node.name);
+ }
+ return node_name;
+ }
};
// Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
@@ -393,8 +522,8 @@ class ArithmeticOptimizerStage {
// q e
class AddOpsRewriteStage : public ArithmeticOptimizerStage {
public:
- explicit AddOpsRewriteStage(ArithmeticOptimizerContext ctx)
- : ArithmeticOptimizerStage(ctx), rewritten_nodes_() {}
+ explicit AddOpsRewriteStage(const ArithmeticOptimizerContext& ctx)
+ : ArithmeticOptimizerStage("AddOpsRewrite", ctx), rewritten_nodes_() {}
~AddOpsRewriteStage() override = default;
@@ -422,7 +551,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
AddOpsGroup group;
TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group));
- if (!group.absorbed_nodes.empty()) {
+ if (!group.absorbed_nodes.empty() && !IsRewritten(group)) {
*simplified_node_name = RewriteAddOpsGroup(group);
}
@@ -530,6 +659,12 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
DrivesControlDependency(*node));
}
+ // Check that optimized group node name doesn't exists. It might happen if
+ // graph optimized multiple times without pruning beween invocations.
+ bool IsRewritten(const AddOpsGroup& group) const {
+ return ctx_.node_map->NodeExists(AddOpsGroupName(group));
+ }
+
// Create an AddOpsGroup with a root in a given node
Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
group->root_node = root_node;
@@ -559,39 +694,23 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
return Status::OK();
}
- const std::pair<string, string> ParseNodeScopeAndName(const string& name) {
- auto pos = name.find_last_of("/");
- if (pos == string::npos) {
- return {"", name};
- } else {
- return {name.substr(0, pos), name.substr(pos + 1)};
- }
- }
-
// New node for AddOpsGroup is added to the same scope as a root_node. All
// absorbed nodes are stripped of their scope, and only names are used in a
// new node name.
//
// Example: AddOpsGroup(root="a/b/c/Add_2", absorbed=["d/Add_1", "e/Add"])
// node_name="a/b/c/AddOpsGroup_Add_2_Add_1_Add
- string AddOpsGroupName(const AddOpsGroup& group) {
+ string AddOpsGroupName(const AddOpsGroup& group) const {
CHECK_NOTNULL(group.root_node);
- string node_name;
- auto root_node = ParseNodeScopeAndName(group.root_node->name());
- auto root_scope = root_node.first;
- auto root_name = root_node.second;
- if (!root_scope.empty()) {
- strings::StrAppend(&node_name, root_scope, "/");
- }
+ auto root = ParseScopedNodeName(group.root_node->name());
- strings::StrAppend(&node_name, kArithmeticOptimizer, "/", "AddOpsGroup_",
- root_name);
- for (const NodeDef* absorbed : group.absorbed_nodes) {
- auto absorbed_node = ParseNodeScopeAndName(absorbed->name());
- strings::StrAppend(&node_name, "_", absorbed_node.second);
- }
- return node_name;
+ std::vector<string> absorbed_node_names(group.absorbed_nodes.size());
+ std::transform(group.absorbed_nodes.begin(), group.absorbed_nodes.end(),
+ absorbed_node_names.begin(),
+ [](const NodeDef* node) { return node->name(); });
+
+ return OptimizedNodeName(root, absorbed_node_names);
}
// Create a new node for a AddOpsGroup and return it's name.
@@ -605,18 +724,17 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
// copy attributes from a root node
DataType dtype = group.root_node->attr().at("T").type();
- // add new node
- NodeDef* added_node = ctx_.optimized_graph->add_node();
- added_node->set_name(node_name);
+ // add new AddN node
+ NodeDef* added_node = AddEmptyNode(node_name);
added_node->set_op("AddN");
added_node->set_device(group.root_node->device());
(*added_node->mutable_attr())["T"].set_type(dtype);
(*added_node->mutable_attr())["N"].set_i(group.inputs.size());
- ctx_.node_map->AddNode(node_name, added_node);
- for (string input : group.inputs) {
+ // all inputs of absorbed nodes are added to the new node
+ for (const string& input : group.inputs) {
ctx_.node_map->AddOutput(input, node_name);
- added_node->add_input(std::move(input));
+ added_node->add_input(input);
}
VLOG(1) << "Absorbed " << group.absorbed_nodes.size()
@@ -635,11 +753,167 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
std::unordered_set<string> rewritten_nodes_;
};
+// Use the commutativity and (left- and right-) distributive property of
+// multiplication over addition to hoist common factors out of aggregate nodes
+// where all the inputs are Mul nodes. This pattern occurs frequently in
+// regularization terms for the gradients during training.
+//
+// For example, we can rewrite an expression of the form:
+// AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
+// to the following:
+// Mul(x, AddN(y1, y2, y3, ... yn))
+class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
+ public:
+ explicit HoistCommonFactorOutOfAggregation(
+ const ArithmeticOptimizerContext& ctx)
+ : ArithmeticOptimizerStage("HoistCommonFactor", ctx) {}
+ ~HoistCommonFactorOutOfAggregation() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
+ !IsRewritten(node);
+ }
+
+ Status TrySimplify(const NodeDef* node,
+ string* simplified_node_name) override {
+ CHECK(IsSupported(node));
+
+ std::set<string> common_factors;
+ TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors));
+
+ if (common_factors.size() == 1) {
+ const string& common_factor = *common_factors.begin();
+
+ // Gather up the non-shared factors
+ bool shapes_match = true;
+ std::vector<string> unique_factors;
+ TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor, &shapes_match,
+ &unique_factors));
+
+ if (shapes_match) {
+ NodeDef* input_0;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));
+
+ // Use a copy of the first Mul node for the outer multiplication.
+ NodeDef* new_mul_node = AddCopyNode(OuterMulNodeName(node), input_0);
+ // And a copy of aggregation node as one of the inner operands
+ NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);
+
+ new_mul_node->set_device(node->device());
+ new_mul_node->set_input(0, common_factor);
+ new_mul_node->set_input(1, new_add_node->name());
+
+ ctx_.node_map->AddOutput(common_factor, new_mul_node->name());
+ ctx_.node_map->AddOutput(new_add_node->name(), new_mul_node->name());
+
+ // Hoist non-shared factors up into the new AddN node.
+ for (int i = 0; i < unique_factors.size(); ++i) {
+ new_add_node->set_input(i, unique_factors[i]);
+ }
+
+ // Add frame dependencies that the original node might have had.
+ AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
+ {new_add_node});
+
+ // optimize new inner aggregation node
+ AddToOptimizationQueue(new_add_node);
+ // do not optimize the same node twice
+ rewritten_nodes_.insert(node->name());
+ *simplified_node_name = new_mul_node->name();
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ // Get a name for new outer Mul node
+ string OuterMulNodeName(const NodeDef* node) const {
+ auto scoped_node = ParseScopedNodeName(node->name());
+ return OptimizedNodeName("Mul", scoped_node);
+ }
+
+ // Get a name new inner Add node
+ string InnerAddNodeName(const NodeDef* node) const {
+ auto scoped_node = ParseScopedNodeName(node->name());
+ return OptimizedNodeName("Add", scoped_node);
+ }
+
+ // Determine the set of common factors if the input nodes are all Mul nodes.
+ Status GetCommonFactors(const NodeDef* node,
+ std::set<string>* common_factors) const {
+ CHECK(common_factors->empty());
+
+ for (int i = 0; i < node->input_size(); ++i) {
+ if (i > 0 && common_factors->empty()) break;
+ if (IsControlInput(node->input(i))) break;
+
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
+
+ if (!IsMul(*input)) {
+ common_factors->clear();
+ break;
+ }
+
+ std::set<string> factors_i{input->input(0), input->input(1)};
+ if (i == 0) {
+ std::swap(*common_factors, factors_i);
+ } else {
+ std::set<string> intersection;
+ std::set_intersection(
+ factors_i.begin(), factors_i.end(), common_factors->begin(),
+ common_factors->end(),
+ std::inserter(intersection, intersection.begin()));
+ std::swap(*common_factors, intersection);
+ }
+ }
+ return Status::OK();
+ }
+
+ // Gather up the non-shared factors (the y's in the example).
+ // Unless the aggregation is Add, we have to make sure that all the y's
+ // have the same shape since the other aggregation ops do not support
+ // broadcasting.
+ Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
+ bool* shapes_match,
+ std::vector<string>* unique_factors) const {
+ *shapes_match = true;
+ unique_factors->reserve(node->input_size());
+
+ for (int i = 0; i < node->input_size() && shapes_match; ++i) {
+ const string& input = node->input(i);
+ if (IsControlInput(input)) {
+ break;
+ }
+ NodeDef* mul_node;
+ TF_RETURN_IF_ERROR(GetInputNode(input, &mul_node));
+ const int unique_factor_index =
+ mul_node->input(0) == common_factor ? 1 : 0;
+ unique_factors->push_back(mul_node->input(unique_factor_index));
+ if (i > 0 && !IsAdd(*node)) {
+ *shapes_match = ShapesEqual(unique_factors->front(),
+ unique_factors->back(), *ctx_.node_map);
+ }
+ }
+ return Status::OK();
+ }
+
+ bool IsRewritten(const NodeDef* node) const {
+ // if graph rewrite happens in multiple passes without graph pruning between
+ // them, it's possible that rewritten node already exists in a graph
+ return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
+ ctx_.node_map->NodeExists(OuterMulNodeName(node));
+ }
+
+ // keep names of the nodes that were optimized by this stage
+ std::unordered_set<string> rewritten_nodes_;
+};
+
// Removes inverse transpose nodes
class RemoveInverseTranspose : public ArithmeticOptimizerStage {
public:
- explicit RemoveInverseTranspose(ArithmeticOptimizerContext ctx)
- : ArithmeticOptimizerStage(ctx) {}
+ explicit RemoveInverseTranspose(const ArithmeticOptimizerContext& ctx)
+ : ArithmeticOptimizerStage("RemoveInverseTranspose", ctx) {}
~RemoveInverseTranspose() override = default;
bool IsSupported(const NodeDef* node) const override {
@@ -702,8 +976,8 @@ class RemoveInverseTranspose : public ArithmeticOptimizerStage {
// 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
public:
- explicit RemoveRedundantBitcastStage(ArithmeticOptimizerContext ctx)
- : ArithmeticOptimizerStage(ctx) {}
+ explicit RemoveRedundantBitcastStage(const ArithmeticOptimizerContext& ctx)
+ : ArithmeticOptimizerStage("RemoveRedundantBitcast", ctx) {}
~RemoveRedundantBitcastStage() override = default;
bool IsSupported(const NodeDef* node) const override {
@@ -742,8 +1016,8 @@ class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
// Remove Casts whose source type and destination type are equal.
class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
public:
- explicit RemoveRedundantCastStage(ArithmeticOptimizerContext ctx)
- : ArithmeticOptimizerStage(ctx) {}
+ explicit RemoveRedundantCastStage(const ArithmeticOptimizerContext& ctx)
+ : ArithmeticOptimizerStage("RemoveRedundantCast", ctx) {}
~RemoveRedundantCastStage() override = default;
bool IsSupported(const NodeDef* node) const override { return IsCast(*node); }
@@ -1276,98 +1550,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
}
- // Use the commutativity and (left- and right-) distributive property of
- // multiplication over addition to hoist common factors out of aggregate nodes
- // where all the inputs are Mul nodes. This pattern occurs frequently in
- // regularization terms for the gradients during training.
- // For example, we can rewrite an expression of the form:
- // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
- // to the following:
- // Mul(x, AddN(y1, y2, y3, ... yn))
- if (IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
- !OptimizedNodeExists(*node, "hoist_add") &&
- !OptimizedNodeExists(*node, "hoist_mul")) {
- // Determine the set of common factors if the input nodes are all Mul nodes.
- std::set<string> common_factors;
- for (int i = 0; i < node->input_size(); ++i) {
- if (i > 0 && common_factors.empty()) {
- break;
- }
- if (IsControlInput(node->input(i))) {
- break;
- }
- const NodeDef* input = node_map_->GetNode(node->input(i));
- if (input->op() == "Mul") {
- std::set<string> factors_i{input->input(0), input->input(1)};
- if (i == 0) {
- std::swap(common_factors, factors_i);
- } else {
- std::set<string> intersection;
- std::set_intersection(
- factors_i.begin(), factors_i.end(), common_factors.begin(),
- common_factors.end(),
- std::inserter(intersection, intersection.begin()));
- std::swap(common_factors, intersection);
- }
- } else {
- common_factors.clear();
- }
- }
- if (common_factors.size() == 1) {
- const string& common_factor = *common_factors.begin();
-
- // Gather up the non-shared factors (the y's in the example).
- // Unless the aggregation is Add, we have to make sure that all the y's
- // have the same shape since the other aggregation ops do not support
- // broadcasting.
- std::vector<string> unique_factors;
- unique_factors.reserve(node->input_size());
- bool shapes_match = true;
- for (int i = 0; i < node->input_size() && shapes_match; ++i) {
- const string& input = node->input(i);
- if (IsControlInput(input)) {
- break;
- }
- const NodeDef* mul_node = node_map_->GetNode(input);
- const int unique_factor_index =
- mul_node->input(0) == common_factor ? 1 : 0;
- unique_factors.push_back(mul_node->input(unique_factor_index));
- if (i > 0 && !IsAdd(*node)) {
- shapes_match = ShapesEqual(unique_factors.front(),
- unique_factors.back(), *node_map_);
- }
- }
-
- if (shapes_match) {
- // 1. Use a copy of the first Mul node for the outer multiplication.
- NodeDef* new_mul_node = AddNode(OptimizedNodeName(*node, "hoist_mul"),
- node_map_->GetNode(node->input(0)));
- NodeDef* new_add_node = AddNode(*node, "hoist_add", /*copy_node=*/true);
- new_mul_node->set_device(node->device());
- new_mul_node->set_input(0, common_factor);
- node_map_->AddOutput(common_factor, new_mul_node->name());
- new_mul_node->set_input(1, new_add_node->name());
- node_map_->AddOutput(new_add_node->name(), new_mul_node->name());
-
- // 2. Hoist non-shared factors up into the new AddN node.
- nodes_to_simplify->PushBack(new_add_node);
- for (int i = 0; i < node->input_size(); ++i) {
- const string& input = node->input(i);
- if (IsControlInput(input)) {
- break;
- }
- new_add_node->set_input(i, unique_factors[i]);
- }
-
- // 3. Add frame dependencies that the original node might have had.
- AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
- {new_add_node});
-
- return new_mul_node->name();
- }
- }
- }
-
// Fold Transpose into matrix multiplication.
if ((node->op() == "MatMul" || node->op() == "SparseMatMul" ||
node->op() == "BatchMatMul") &&
@@ -1444,8 +1626,9 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
}
- ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
- node_map_.get(), &nodes_to_simplify);
+ const ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
+ node_map_.get(), &frame_map_,
+ &nodes_to_simplify);
std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
@@ -1453,6 +1636,10 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
stages.push_back(
std::unique_ptr<ArithmeticOptimizerStage>(new AddOpsRewriteStage(ctx)));
}
+ if (options_.hoist_common_factor_out_of_aggregation) {
+ stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
+ new HoistCommonFactorOutOfAggregation(ctx)));
+ }
if (options_.remove_inverse_transpose) {
stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
new RemoveInverseTranspose(ctx)));
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 787084454d..d5a7af5ba6 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -56,6 +56,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
// Granular control for arithmetic optimizer stages
struct ArithmeticOptimizerOptions {
bool combine_add_to_addn = true;
+ bool hoist_common_factor_out_of_aggregation = true;
bool remove_inverse_transpose = true;
bool remove_redundant_bitcast = true;
bool remove_redundant_cast = true;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 98842b29f1..e1f47625c1 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -30,6 +30,22 @@ namespace grappler {
namespace {
+constexpr char kHoistFactorOptimizerMul[] =
+ "ArithmeticOptimizer/HoistCommonFactor_Mul_";
+
+constexpr char kHoistFactorOptimizerAdd[] =
+ "ArithmeticOptimizer/HoistCommonFactor_Add_";
+
+// Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation
+string HoistMulName(const string& name) {
+ return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
+}
+
+// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation
+string HoistAddName(const string& name) {
+ return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
+}
+
string OptimizedName(const string& name) {
return AddPrefixToNodeName(name, kArithmeticOptimizer);
}
@@ -61,22 +77,40 @@ class ArithmeticOptimizerTest : public GrapplerTest {
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
}
+ // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
+ void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item,
+ GraphDef* output) {
+ TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+ item->graph.Swap(output);
+ TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+ }
+
// TODO(ezhulenev): Make private. After migration to stages each test
// should explicitly enable required optimization for tests isolation
void DisableAllStages(ArithmeticOptimizer* optimizer) {
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
options.combine_add_to_addn = false;
+ options.hoist_common_factor_out_of_aggregation = false;
options.remove_inverse_transpose = false;
options.remove_redundant_bitcast = false;
options.remove_redundant_cast = false;
optimizer->options_ = options;
}
+ void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
+ optimizer->options_.combine_add_to_addn = false;
+ }
+
void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.combine_add_to_addn = true;
}
+ void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.hoist_common_factor_out_of_aggregation = true;
+ }
+
void EnableOnlyRemoveInverseTranspose(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_inverse_transpose = true;
@@ -396,59 +430,66 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
}
ArithmeticOptimizer optimizer;
- DisableAllStages(&optimizer);
+ DisableAddToAddNCombining(&optimizer);
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
- EXPECT_EQ(17, output.node_size());
- // The graph gets optimized to
+ // We expect the following rewrite(s) to occur:
+ //
// Mul(p,
- // Add(Add(Const(2), Const(2)),
- // Add(Const(2), Const(2))))
+ // Add_6(Add_4(Const(2), Const(2)),
+ // Add_5(Const(2), Const(2))))
+ NodeMap node_map(&output);
+
EXPECT_EQ(17, output.node_size());
- for (const auto& node : output.node()) {
- if ("id" == node.name()) {
- EXPECT_EQ(1, node.input_size());
- EXPECT_EQ(OptimizedName("Add_6_hoist_mul"), node.input(0));
- } else if (OptimizedName("Add_6_hoist_mul") == node.name()) {
- EXPECT_EQ("Mul", node.op());
- EXPECT_EQ(2, node.input_size());
- EXPECT_EQ("Placeholder", node.input(0));
- EXPECT_EQ(OptimizedName("Add_6_hoist_add"), node.input(1));
- } else if (OptimizedName("Add_6_hoist_add") == node.name()) {
- EXPECT_EQ("Add", node.op());
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ(OptimizedName("Add_4_hoist_add"), node.input(0));
- EXPECT_EQ(OptimizedName("Add_5_hoist_add"), node.input(1));
- EXPECT_EQ("^Placeholder", node.input(2));
- } else if (OptimizedName("Add_4_hoist_add") == node.name()) {
- EXPECT_EQ("Add", node.op());
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ(OptimizedName("Add_const"), node.input(0));
- EXPECT_EQ(OptimizedName("Add_1_const"), node.input(1));
- EXPECT_EQ("^Placeholder", node.input(2));
- } else if (OptimizedName("Add_5_hoist_add") == node.name()) {
- EXPECT_EQ("Add", node.op());
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ(OptimizedName("Add_const"), node.input(0));
- EXPECT_EQ(OptimizedName("Add_1_const"), node.input(1));
- EXPECT_EQ("^Placeholder", node.input(2));
- } else if (OptimizedName("Add_const") == node.name()) {
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("^Placeholder", node.input(0));
- } else if (OptimizedName("Add_1_const") == node.name()) {
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("^Placeholder", node.input(0));
- }
- }
+
+ const NodeDef* id_node = node_map.GetNode("id");
+ ASSERT_TRUE(id_node != nullptr);
+ EXPECT_EQ(1, id_node->input_size());
+ EXPECT_EQ(HoistMulName("Add_6"), id_node->input(0));
+
+ const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
+ ASSERT_TRUE(mul_node != nullptr);
+ EXPECT_EQ(2, mul_node->input_size());
+ EXPECT_EQ("Placeholder", mul_node->input(0));
+ EXPECT_EQ(HoistAddName("Add_6"), mul_node->input(1));
+
+ const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
+ ASSERT_TRUE(add_6_node != nullptr);
+ EXPECT_EQ(3, add_6_node->input_size());
+ EXPECT_EQ(HoistAddName("Add_4"), add_6_node->input(0));
+ EXPECT_EQ(HoistAddName("Add_5"), add_6_node->input(1));
+ EXPECT_EQ("^Placeholder", add_6_node->input(2));
+
+ const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
+ ASSERT_TRUE(add_4_node != nullptr);
+ EXPECT_EQ("Add", add_4_node->op());
+ EXPECT_EQ(3, add_4_node->input_size());
+ EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0));
+ EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1));
+ EXPECT_EQ("^Placeholder", add_4_node->input(2));
+
+ const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
+ ASSERT_TRUE(add_5_node != nullptr);
+ EXPECT_EQ("Add", add_5_node->op());
+ EXPECT_EQ(3, add_5_node->input_size());
+ EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0));
+ EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1));
+ EXPECT_EQ("^Placeholder", add_5_node->input(2));
+
+ const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const"));
+ ASSERT_TRUE(add_const_node != nullptr);
+ EXPECT_EQ("Const", add_const_node->op());
+ EXPECT_EQ(1, add_const_node->input_size());
+ EXPECT_EQ("^Placeholder", add_const_node->input(0));
+
+ const NodeDef* add_1_const_node =
+ node_map.GetNode(OptimizedName("Add_1_const"));
+ ASSERT_TRUE(add_1_const_node != nullptr);
+ EXPECT_EQ("Const", add_1_const_node->op());
+ EXPECT_EQ(1, add_1_const_node->input_size());
+ EXPECT_EQ("^Placeholder", add_1_const_node->input(0));
}
TEST_F(ArithmeticOptimizerTest, HoistFactor) {
@@ -469,31 +510,46 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) {
ops::Add(s.WithOpName("add"), mul1, mul2));
GrapplerItem item;
+ item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
ArithmeticOptimizer optimizer;
+ EnableOnlyHoistCommonFactor(&optimizer);
+
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ //
+ // Add Mul
+ // / \ / \
+ // Mul Mul -> x Add
+ // / \ / \ / \
+ // x y1 y2 x y1 y2
+ //
+ // If "root" op is AddN and shapes does not match, this rewrite is not
+ // possible and graph should stay intact.
+ NodeMap node_map(&output);
if (use_addn && !matching_shapes) {
VerifyGraphsMatch(item.graph, output, __LINE__);
} else {
EXPECT_EQ(9, output.node_size());
- const NodeDef& new_add = output.node(8);
- EXPECT_EQ(OptimizedName("add_hoist_add"), new_add.name());
- EXPECT_EQ("y1", new_add.input(0));
- EXPECT_EQ("y2", new_add.input(1));
- const NodeDef& new_mul = output.node(7);
- EXPECT_EQ(OptimizedName("add_hoist_mul"), new_mul.name());
- EXPECT_EQ("x", new_mul.input(0));
- EXPECT_EQ(OptimizedName("add_hoist_add"), new_mul.input(1));
- const NodeDef& new_id = output.node(6);
- EXPECT_EQ("id", new_id.name());
- EXPECT_EQ(OptimizedName("add_hoist_mul"), new_id.input(0));
+
+ const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
+ ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
+ EXPECT_EQ("y1", new_add_node->input(0));
+ EXPECT_EQ("y2", new_add_node->input(1));
+
+ const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add"));
+ ASSERT_TRUE(new_mul_node != nullptr) << "Hoisted Mul node not found";
+ EXPECT_EQ("x", new_mul_node->input(0));
+ EXPECT_EQ(new_add_node->name(), new_mul_node->input(1));
+
+ const NodeDef* id_node = node_map.GetNode("id");
+ ASSERT_TRUE(id_node != nullptr) << "Id node not found";
+ EXPECT_EQ("id", id_node->name());
+ EXPECT_EQ(HoistMulName("add"), id_node->input(0));
}
}
}
@@ -1249,8 +1305,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
NodeMap node_map(&output);
// check add tree was replaced with AddN
- const NodeDef* collapsed_add = CHECK_NOTNULL(
- node_map.GetNode("y/ArithmeticOptimizer/AddOpsGroup_Add_abc_Add_ab"));
+ const NodeDef* collapsed_add =
+ node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
+ ASSERT_TRUE(collapsed_add != nullptr);
EXPECT_EQ("AddN", collapsed_add->op());
EXPECT_EQ(3, collapsed_add->input_size());
@@ -1259,7 +1316,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
EXPECT_EQ("c", collapsed_add->input(2));
// check output was re-wired to new node
- const NodeDef* updated_outputs = CHECK_NOTNULL(node_map.GetNode("outputs"));
+ const NodeDef* updated_outputs = node_map.GetNode("outputs");
+ ASSERT_TRUE(updated_outputs != nullptr);
EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
}
@@ -1306,8 +1364,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
NodeMap node_map(&output);
// check left Add subtree replaced with AddN
- const NodeDef* collapsed_left = CHECK_NOTNULL(
- node_map.GetNode("ArithmeticOptimizer/AddOpsGroup_Add_abc_Add_ab"));
+ const NodeDef* collapsed_left =
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
+ ASSERT_TRUE(collapsed_left != nullptr);
EXPECT_EQ("AddN", collapsed_left->op());
EXPECT_EQ(3, collapsed_left->input_size());
@@ -1316,8 +1375,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
EXPECT_EQ("c", collapsed_left->input(2));
// check right Add subtree replaced with AddN
- const NodeDef* collapsed_right = CHECK_NOTNULL(
- node_map.GetNode("ArithmeticOptimizer/AddOpsGroup_Add_xyz_Add_xy"));
+ const NodeDef* collapsed_right =
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz_Add_xy");
+ ASSERT_TRUE(collapsed_right != nullptr);
EXPECT_EQ("AddN", collapsed_right->op());
EXPECT_EQ(3, collapsed_right->input_size());
@@ -1326,7 +1386,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
EXPECT_EQ("z", collapsed_right->input(2));
// check that Mul inputs re-wired to new Nodes
- const NodeDef* updated_mul = CHECK_NOTNULL(node_map.GetNode("Mul"));
+ const NodeDef* updated_mul = node_map.GetNode("Mul");
+ ASSERT_TRUE(updated_mul != nullptr);
EXPECT_EQ("Mul", updated_mul->op());
EXPECT_EQ(2, updated_mul->input_size());
@@ -1367,8 +1428,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
NodeMap node_map(&output);
// check Add tree replaced with AddN
- const NodeDef* collapsed_add = CHECK_NOTNULL(node_map.GetNode(
- "ArithmeticOptimizer/AddOpsGroup_Add_all_Add_ab_Add_bc"));
+ const NodeDef* collapsed_add = node_map.GetNode(
+ "ArithmeticOptimizer/AddOpsRewrite_Add_all_Add_ab_Add_bc");
+ ASSERT_TRUE(collapsed_add != nullptr);
EXPECT_EQ("AddN", collapsed_add->op());
EXPECT_EQ(4, collapsed_add->input_size());
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 39cc4a9629..21037ff794 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -244,44 +244,41 @@ string ConstantFolding::AddControlDependency(const string& input_name,
}
}
-Status ConvertShapeToConstant(const string& op, const DataType& type,
- const PartialTensorShape& shp, Tensor* value) {
+// Puts the given value into the tensor at the given "flat" index.
+static Status PutValueIntoTensor(const int64 value, const DataType& type,
+ const int index, Tensor* tensor) {
+ if (type == DT_INT32) {
+ if (value >= INT_MAX) {
+ return Status(error::INVALID_ARGUMENT, "int32 overflow");
+ }
+ tensor->flat<int32>()(index) = static_cast<int32>(value);
+ } else {
+ tensor->flat<int64>()(index) = value;
+ }
+ return Status::OK();
+}
+
+// Writes the given tensor shape into the given tensor.
+// Op is assumed to be Shape, ShapeN, Size or Rank.
+static Status ConvertShapeToConstant(const string& op, const DataType& type,
+ const PartialTensorShape& shp,
+ Tensor* tensor) {
if (op == "Shape" || op == "ShapeN") {
- *value = Tensor(type, TensorShape({shp.dims()}));
+ *tensor = Tensor(type, TensorShape({shp.dims()}));
for (int i = 0; i < shp.dims(); ++i) {
- if (type == DT_INT32) {
- if (shp.dim_size(i) >= INT_MAX) {
- return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
- }
- value->flat<int32>()(i) = shp.dim_size(i);
- } else {
- value->flat<int64>()(i) = shp.dim_size(i);
- }
+ TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
}
} else if (op == "Size") {
int64 size = 1;
for (int i = 0; i < shp.dims(); ++i) {
size *= shp.dim_size(i);
}
- *value = Tensor(type, TensorShape({}));
- if (type == DT_INT32) {
- if (size >= INT_MAX) {
- return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
- }
- value->flat<int32>()(0) = size;
- } else {
- value->flat<int64>()(0) = size;
- }
+ *tensor = Tensor(type, TensorShape({}));
+ TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
} else {
- *value = Tensor(type, TensorShape({}));
- if (type == DT_INT32) {
- if (shp.dims() >= INT_MAX) {
- return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
- }
- value->flat<int32>()(0) = shp.dims();
- } else {
- value->flat<int64>()(0) = shp.dims();
- }
+ CHECK_EQ(op, "Rank");
+ *tensor = Tensor(type, TensorShape({}));
+ TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
}
return Status::OK();
}
@@ -306,13 +303,14 @@ bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
return feed_nodes_.find(node.name()) == feed_nodes_.end();
}
+// Materialize the shapes using constants whenever possible.
Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
- // We may add some nodes to the graph to encode control dependencies: there is
- // no need to process these, so only iterate over the nodes of the input
- // graph.
+ // We may add some nodes to the graph to encode control dependencies and hold
+ // the materialized shapes: there is no need to process these added nodes, so
+ // only iterate over the nodes of the input graph.
const int node_count = graph_->node_size();
- for (int i = 0; i < node_count; ++i) {
- NodeDef* node = graph_->mutable_node(i);
+ for (int node_idx = 0; node_idx < node_count; ++node_idx) {
+ NodeDef* node = graph_->mutable_node(node_idx);
const string op = node->op();
if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
continue;
@@ -325,91 +323,109 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
if (input.empty() || output.empty()) {
continue;
}
+
if (op == "Shape" || op == "Size" || op == "Rank") {
CHECK_EQ(1, output.size());
CHECK_EQ(1, input.size());
+
+ const DataType type = output[0].dtype();
+ CHECK(type == DT_INT32 || type == DT_INT64);
+ const PartialTensorShape shape(input[0].shape());
+
+ if ((op != "Rank" && !shape.IsFullyDefined()) ||
+ (op == "Rank" && shape.unknown_rank())) {
+ continue;
+ }
+
+ Tensor constant_value(type);
+ if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
+ continue;
+ }
+
+ // Repurpose the existing node to be the constant.
+ // Device placement is preserved.
+ node->set_op("Const");
+ node->clear_attr();
+ (*node->mutable_attr())["dtype"].set_type(type);
+ constant_value.AsProtoTensorContent(
+ (*node->mutable_attr())["value"].mutable_tensor());
+
+ // Turn the data input into a control dependency: this is needed to
+ // ensure that the constant value will only be run in the
+ // cases where the shape/rank/size would have been run in
+ // the original graph.
+ string ctrl_dep =
+ AddControlDependency(node->input(0), graph_, node_map_.get());
+ node->set_input(0, ctrl_dep);
+ node_map_->AddOutput(NodeName(ctrl_dep), node->name());
+
+ // Done with the Shape/Size/Rank node, move to the next node.
+ continue;
}
- CHECK_EQ(input.size(), output.size());
- for (int j = 0; j < output.size(); ++j) {
- const DataType type = output[j].dtype();
+ // Handle ShapeN materialization case.
+ // It's possible that not all input tensors have known shapes.
+ CHECK_EQ(op, "ShapeN");
+ CHECK_EQ(input.size(), output.size());
+ const NodeDef* const shape_n_node = node;
+ for (int port_idx = 0; port_idx < output.size(); ++port_idx) {
+ const DataType type = output[port_idx].dtype();
CHECK(type == DT_INT32 || type == DT_INT64);
- const TensorShapeProto shape = input[j].shape();
- // Materialize the shapes using constants whenever possible.
- PartialTensorShape shp(shape);
- if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
- Tensor value(type);
- auto status = ConvertShapeToConstant(op, type, shp, &value);
- if (!status.ok()) {
- continue;
- }
- // We rewrite the existing node for the first const output and
- // create new nodes for the remaining const outputs (Note that ShapeN
- // could have multiple outputs).
- if (op == "Shape" || op == "Size" || op == "Rank") {
- // Replace the node with the corresponding constant.
- node->set_op("Const");
- node->clear_attr();
- (*node->mutable_attr())["dtype"].set_type(type);
- value.AsProtoTensorContent(
- (*node->mutable_attr())["value"].mutable_tensor());
-
- // Turn the data input into a control dependency: this is needed to
- // ensure that the constant value will only be run in the
- // cases where the shape/rank/size would have been run in
- // the original graph. Additional inputs are extra control
- string ctrl_dep =
- AddControlDependency(node->input(0), graph_, node_map_.get());
- node->set_input(0, ctrl_dep);
- node_map_->AddOutput(NodeName(ctrl_dep), node->name());
- } else {
- auto outputs = node_map_->GetOutputs(node->name());
- for (NodeDef* output : outputs) {
- for (int k = 0; k < output->input_size(); ++k) {
- int port;
- string node_name = ParseNodeName(output->input(k), &port);
- if (node_name == node->name() && port == j) {
- // Create a const node as ShapeN's output if not already.
- const string const_name =
- OptimizedNodeName(*node, strings::StrCat("-matshapes-", j));
- if (node_map_->GetNode(const_name) == nullptr) {
- NodeDef* added_node = graph_->add_node();
- added_node->set_name(const_name);
- added_node->set_op("Const");
- added_node->set_device(node->device());
- node_map_->AddNode(added_node->name(), added_node);
- (*added_node->mutable_attr())["dtype"].set_type(type);
- value.AsProtoTensorContent(
- (*added_node->mutable_attr())["value"].mutable_tensor());
- // We add a control dependency to the original ShapeN node,
- // so that the node will only be run if all inputs of the
- // original ShapeN node are run.
- string ctrl_dep = AddControlDependency(node->name(), graph_,
- node_map_.get());
- *added_node->add_input() = ctrl_dep;
- node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
- }
- *output->mutable_input(k) = const_name;
- node_map_->AddOutput(const_name, output->name());
- }
- }
- bool remove_output = true;
- for (int k = 0; k < output->input_size(); ++k) {
- int port;
- string node_name = ParseNodeName(output->input(k), &port);
- if (node_name == node->name()) {
- remove_output = false;
- break;
- }
- }
- if (remove_output) {
- node_map_->RemoveOutput(node->name(), output->name());
+ const PartialTensorShape shape(input[port_idx].shape());
+ if (!shape.IsFullyDefined()) {
+ continue;
+ }
+ Tensor constant_value(type);
+ auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
+ if (!status.ok()) {
+ continue;
+ }
+
+ // Find all nodes consuming this shape and connect them through the new
+ // constant node instead.
+ auto outputs = node_map_->GetOutputs(shape_n_node->name());
+ for (NodeDef* output : outputs) {
+ // Track whether there are any direct edges left between shape_n_node
+ // and this output node after the transformation.
+ bool direct_edges_exist = false;
+ for (int k = 0; k < output->input_size(); ++k) {
+ int port;
+ const string node_name = ParseNodeName(output->input(k), &port);
+ if (node_name == shape_n_node->name() && port == port_idx) {
+ // Create a const node as ShapeN's output if not already.
+ const string const_name = OptimizedNodeName(
+ *shape_n_node, strings::StrCat("-matshapes-", port_idx));
+ if (node_map_->GetNode(const_name) == nullptr) {
+ NodeDef* added_node = graph_->add_node();
+ added_node->set_name(const_name);
+ added_node->set_op("Const");
+ added_node->set_device(shape_n_node->device());
+ node_map_->AddNode(added_node->name(), added_node);
+ (*added_node->mutable_attr())["dtype"].set_type(type);
+ constant_value.AsProtoTensorContent(
+ (*added_node->mutable_attr())["value"].mutable_tensor());
+ // We add a control dependency to the original ShapeN node,
+ // so that the node will only be run if all inputs of the
+ // original ShapeN node are run.
+ string ctrl_dep = AddControlDependency(shape_n_node->name(),
+ graph_, node_map_.get());
+ *added_node->add_input() = ctrl_dep;
+ node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
}
+ *output->mutable_input(k) = const_name;
+ node_map_->AddOutput(const_name, output->name());
+ }
+ if (node_name == shape_n_node->name() && port != port_idx) {
+ direct_edges_exist = true;
}
}
+ if (!direct_edges_exist) {
+ node_map_->RemoveOutput(node->name(), output->name());
+ }
}
}
}
+
return Status::OK();
}
@@ -1361,6 +1377,10 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
if (node.op() == "OnesLike") {
return true;
}
+ if (node.op() == "Fill") {
+ NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
+ return values != nullptr && IsOnes(*values);
+ }
if (node.op() != "Const") {
return false;
}
@@ -1392,6 +1412,10 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
if (node.op() == "ZerosLike") {
return true;
}
+ if (node.op() == "Fill") {
+ NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
+ return values != nullptr && IsZeros(*values);
+ }
if (!IsConstant(node)) {
return false;
}
@@ -1510,7 +1534,7 @@ Status ConstantFolding::ReplaceOperationWithConstant(
}
Status ConstantFolding::SimplifyGraph(GraphDef* output,
- const GraphProperties& properties,
+ GraphProperties* properties,
bool use_shape_info) {
const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
for (int i = 0; i < output->node_size(); ++i) {
@@ -1520,7 +1544,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
if (use_shape_info &&
(IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
const auto& shape =
- properties.GetInputProperties(node->name())[0].shape();
+ properties->GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
// unknown_rank == false && (dim_size == 0 || all dims have size 1)
bool replaceable = !shape.unknown_rank();
@@ -1529,14 +1553,15 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
if (replaceable) {
ReplaceOperationWithIdentity(0, node, output);
+ continue;
}
}
if (use_shape_info && IsSlice(*node) &&
- properties.GetInputProperties(node->name()).size() == 3) {
- const auto& input = properties.GetInputProperties(node->name())[0];
- const auto& b = properties.GetInputProperties(node->name())[1];
- const auto& s = properties.GetInputProperties(node->name())[2];
+ properties->GetInputProperties(node->name()).size() == 3) {
+ const auto& input = properties->GetInputProperties(node->name())[0];
+ const auto& b = properties->GetInputProperties(node->name())[1];
+ const auto& s = properties->GetInputProperties(node->name())[2];
if (TensorShape::IsValid(b.shape()) && b.has_value() &&
TensorShape::IsValid(s.shape()) && s.has_value()) {
Tensor begin(b.dtype(), b.shape());
@@ -1569,13 +1594,14 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
if (replaceable) {
ReplaceOperationWithIdentity(0, node, output);
+ continue;
}
}
}
- if (IsTile(*node) &&
- properties.GetInputProperties(node->name()).size() == 2) {
- const auto& m = properties.GetInputProperties(node->name())[1];
+ if (use_shape_info && IsTile(*node) &&
+ properties->GetInputProperties(node->name()).size() == 2) {
+ const auto& m = properties->GetInputProperties(node->name())[1];
if (TensorShape::IsValid(m.shape()) && m.has_value()) {
Tensor multiplies(m.dtype(), m.shape());
if (!multiplies.FromProto(m.value())) {
@@ -1597,13 +1623,14 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
if (replaceable) {
ReplaceOperationWithIdentity(0, node, output);
+ continue;
}
}
}
- if (IsPad(*node) &&
- properties.GetInputProperties(node->name()).size() >= 2) {
- const auto& p = properties.GetInputProperties(node->name())[1];
+ if (use_shape_info && IsPad(*node) &&
+ properties->GetInputProperties(node->name()).size() >= 2) {
+ const auto& p = properties->GetInputProperties(node->name())[1];
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
Tensor paddings(p.dtype(), p.shape());
if (!paddings.FromProto(p.value())) {
@@ -1620,17 +1647,18 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
if (replaceable) {
ReplaceOperationWithIdentity(0, node, output);
+ continue;
}
}
}
if (use_shape_info && IsSqueeze(*node) &&
- !properties.GetInputProperties(node->name()).empty()) {
+ !properties->GetInputProperties(node->name()).empty()) {
// https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
// error to squeeze a dimension that is not 1, so we only need to check
// whether the input has > 1 size for each dimension.
const auto& shape =
- properties.GetInputProperties(node->name())[0].shape();
+ properties->GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
// unknown_rank == false && (dim_size == 0 || all dims have size > 1)
bool replaceable = !shape.unknown_rank();
@@ -1639,6 +1667,39 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
if (replaceable) {
ReplaceOperationWithIdentity(0, node, output);
+ continue;
+ }
+ }
+
+ if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
+ !OptimizedNodeExists(*node, "_const_axis")) {
+ // Create constant axis node.
+ Tensor axis_t(DT_INT32, TensorShape({}));
+ NodeDef* axis_node = output->add_node();
+ axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
+ const int axis = node->attr().at("axis").i();
+ if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
+ !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
+ .ok()) {
+ continue;
+ }
+ VLOG(1) << "*** Rewriting trivial Pack node: " << node->DebugString();
+ // Add a control dependency to make sure axis_node is in the right frame.
+ const string ctrl_dep = ConstantFolding::AddControlDependency(
+ node->input(0), graph_, node_map_.get());
+ axis_node->add_input(ctrl_dep);
+ axis_node->set_device(node->device());
+ node->set_op("ExpandDims");
+ if (node->attr().count("axis") != 0) {
+ node->mutable_attr()->erase("axis");
+ }
+ if (node->attr().count("N") != 0) {
+ node->mutable_attr()->erase("N");
+ }
+ (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
+ node->add_input(axis_node->name());
+ if (node->input_size() > 2) {
+ node->mutable_input()->SwapElements(1, node->input_size() - 1);
}
}
@@ -1759,7 +1820,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
graph_modified_ = true;
continue;
}
- if (use_shape_info && IsSimplifiableReshape(*node, properties)) {
+ if (use_shape_info && IsSimplifiableReshape(*node, *properties)) {
DataType output_type = node->attr().at("T").type();
node->set_op("Identity");
node->clear_attr();
@@ -1777,8 +1838,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
// Simplify arithmetic operations with ones or zeros.
if (use_shape_info &&
(is_mul || is_matmul || is_add || is_sub || is_any_div) &&
- properties.HasInputProperties(node->name()) &&
- properties.HasOutputProperties(node->name())) {
+ properties->HasInputProperties(node->name()) &&
+ properties->HasOutputProperties(node->name())) {
const NodeDef* x = node_map_->GetNode(node->input(0));
const NodeDef* y = node_map_->GetNode(node->input(1));
if (x == nullptr || y == nullptr) {
@@ -1786,14 +1847,14 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
node->DebugString());
}
const TensorShapeProto& output_shape =
- properties.GetOutputProperties(node->name())[0].shape();
+ properties->GetOutputProperties(node->name())[0].shape();
// Simplify element-wise multiplication by ones or addition/subtraction
// of zeros.
const TensorShapeProto& y_shape =
- properties.GetInputProperties(node->name())[1].shape();
+ properties->GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
- const bool x_is_one = IsOnes(*x);
+ const bool x_is_one = x_is_zero ? false : IsOnes(*x);
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
if (y_matches_output_shape &&
((is_mul && x_is_one) || (is_add && x_is_zero))) {
@@ -1818,9 +1879,9 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
const TensorShapeProto& x_shape =
- properties.GetInputProperties(node->name())[0].shape();
+ properties->GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
- const bool y_is_one = IsOnes(*y);
+ const bool y_is_one = y_is_zero ? false : IsOnes(*y);
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero))) {
@@ -2139,7 +2200,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
}
TF_RETURN_IF_ERROR(FoldGraph(output));
node_map_.reset(new NodeMap(output));
- TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
+ TF_RETURN_IF_ERROR(SimplifyGraph(output, &properties, can_use_shape_info));
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 2fd59c7f9c..13ecfcd281 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -92,7 +92,7 @@ class ConstantFolding : public GraphOptimizer {
bool IsSimplifiableReduction(const NodeDef& node) const;
bool IsSimplifiableReshape(const NodeDef& node,
const GraphProperties& properties) const;
- Status SimplifyGraph(GraphDef* output, const GraphProperties& properties,
+ Status SimplifyGraph(GraphDef* output, GraphProperties* properties,
bool use_shape_info);
Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index f421a59989..cf151d4c4b 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -152,7 +152,10 @@ TEST_F(ConstantFoldingTest, AddTree) {
}
TEST_F(ConstantFoldingTest, NeutralElement) {
- for (bool use_const : {true, false}) {
+ int kConst = 0;
+ int kLike = 1;
+ int kFill = 2;
+ for (int const_type : {kConst, kLike, kFill}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2, 2})));
@@ -164,11 +167,19 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
ops::Placeholder::Shape(TensorShape({2, 3})));
Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2})));
- Output zeros = !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x)
- : ops::Const(s.WithOpName("zeros"), 0.0f, {2, 2});
Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
- Output ones = !use_const ? ops::OnesLike(s.WithOpName("ones"), x)
- : ops::Const(s.WithOpName("ones"), 1.0f, {2, 2});
+ Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
+ Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
+ Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
+ Output zeros = const_type == kConst
+ ? zeros_const
+ : (const_type == kLike ? zeros_like : zeros_fill);
+ Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
+ Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
+ Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
+ Output ones = const_type == kConst
+ ? ones_const
+ : (const_type == kLike ? ones_like : ones_fill);
Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
@@ -201,6 +212,13 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ const string suffix =
+ (const_type == kConst ? "_const"
+ : (const_type == kLike ? "_like" : "_fill"));
+ const string zeros_name = strings::StrCat("zeros", suffix);
+ const string ones_name = strings::StrCat("ones", suffix);
+ const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
+ const string ctrl_ones_name = strings::StrCat("^ones", suffix);
EXPECT_EQ(28, output.node_size());
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
@@ -208,19 +226,19 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
if (name == "mul1") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "mul2") {
EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^zeros", node.input(0));
+ EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (name == "mul3") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
+ EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "mul4") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("y", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
+ EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "mul5") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^x", node.input(0));
@@ -232,23 +250,23 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
} else if (name == "div1") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
+ EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "div2") {
EXPECT_EQ("Reciprocal", node.op());
EXPECT_EQ("y", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
+ EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "matmul1") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "matmul2") {
EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^zeros", node.input(0));
+ EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (name == "matmul3") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^a", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
TensorProto t = node.attr().at("value").tensor();
EXPECT_EQ(1, t.float_val_size());
EXPECT_EQ(0, t.float_val(0));
@@ -257,7 +275,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
EXPECT_EQ(2, t.tensor_shape().dim(1).size());
} else if (name == "matmul4") {
EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^zeros", node.input(0));
+ EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^b", node.input(1));
TensorProto t = node.attr().at("value").tensor();
EXPECT_EQ(1, t.float_val_size());
@@ -268,11 +286,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
} else if (name == "add1") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "add2") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("y", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "bias_add1") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
@@ -280,16 +298,16 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
} else if (name == "bias_add2") {
// We don't eliminate this one, because it requires broadcasting.
EXPECT_EQ("BiasAdd", node.op());
- EXPECT_EQ("zeros", node.input(0));
+ EXPECT_EQ(zeros_name, node.input(0));
EXPECT_EQ("bias", node.input(1));
} else if (name == "sub1") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "sub2") {
EXPECT_EQ("Neg", node.op());
EXPECT_EQ("y", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
}
const std::set<string> square_zero_const{"mul1", "mul2", "mul5",
"mul6", "matmul1", "matmul2"};
@@ -1930,6 +1948,48 @@ TEST_F(ConstantFoldingTest, IdenticalN) {
EXPECT_EQ("^id_n", output.node(7).input(2));
}
+TEST_F(ConstantFoldingTest, TrivialPack) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+ Output x =
+ ops::RandomNormal(scope.WithOpName("x"), {2, 2}, DataType::DT_FLOAT);
+ Output y = ops::Const(scope.WithOpName("y"), {2.0f}, {});
+ auto stack =
+ ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
+ ops::Stack::Axis(1));
+
+ GrapplerItem item;
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+ item.fetch.push_back("stack");
+
+ ConstantFolding fold(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ LOG(INFO) << output.DebugString();
+ EXPECT_EQ(5, output.node_size());
+ for (const auto& node : output.node()) {
+ if (node.name() == "stack") {
+ EXPECT_EQ("stack", node.name());
+ EXPECT_EQ("ExpandDims", node.op());
+ EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
+ EXPECT_EQ("^y", node.input(2));
+ } else if (node.name() == "ConstantFolding/stack_const_axis") {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("^x", node.input(0));
+ }
+ }
+
+ std::vector<string> fetch = {"stack"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+ EXPECT_EQ(1, tensors.size());
+ EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
index 2a93dd679e..8f13c4a702 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
@@ -465,7 +465,6 @@ Status LoopOptimizer::LoopInvariantNodeMotion() {
Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
TF_RETURN_IF_ERROR(RemoveStackOps(item.graph, optimized_graph));
-
optimized_graph_ = optimized_graph;
// Set up helper data structures.
@@ -475,6 +474,7 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
&frame_map_, &num_frames));
TF_RETURN_IF_ERROR(LoopInvariantNodeMotion());
+ return Status::OK();
}
void LoopOptimizer::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/,
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h
index b5944cd30b..c1b0321e4e 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h
@@ -47,7 +47,7 @@ class LoopOptimizer : public GraphOptimizer {
Status LoopInvariantNodeMotion();
Status FindInvariantNodes(NodeDef* node);
Status RevertInvariantNodes();
- Status MoveInvariantNodes(const int fname);
+ Status MoveInvariantNodes(const int frame_id);
Status LINMHandleInvariantNode(NodeDef* node, const int num_outputs,
const int frame_id);
Status LINMHandleConst(NodeDef* node, const int num_outputs,
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 6fe3746a73..780f927a4f 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -867,7 +867,7 @@ class IteratorGetNextOp : public AsyncOpKernel {
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
thread_pool_->Schedule(std::bind(
- [this, ctx, iterator](DoneCallback done) {
+ [ctx, iterator](DoneCallback done) {
std::vector<Tensor> components;
bool end_of_sequence = false;
diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc
index fa67545a0d..bea3af98eb 100644
--- a/tensorflow/core/kernels/data_format_ops.cc
+++ b/tensorflow/core/kernels/data_format_ops.cc
@@ -28,6 +28,15 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+namespace {
+inline functor::DataFormat FormatNameToEnum(const string& name) {
+ if (name == "NHWC") return functor::DataFormat::NHWC;
+ if (name == "NCHW") return functor::DataFormat::NCHW;
+ if (name == "HWNC") return functor::DataFormat::HWNC;
+ return functor::DataFormat::UNKNOWN;
+}
+} // namespace
+
template <typename Device, typename T>
class DataFormatDimMapOp : public OpKernel {
public:
@@ -69,12 +78,15 @@ class DataFormatVecPermuteOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
OP_REQUIRES(context,
(src_format == "NHWC" && dst_format == "NCHW") ||
- (src_format == "NCHW" && dst_format == "NHWC"),
+ (src_format == "NCHW" && dst_format == "NHWC") ||
+ (src_format == "NHWC" && dst_format == "HWNC") ||
+ (src_format == "HWNC" && dst_format == "NHWC"),
errors::InvalidArgument(strings::StrCat(
- "Current implementation only supports NCHW-to-NHWC and "
- "NHWC-to-NCHW format conversion; got source format ",
+ "Current implementation only supports NHWC<->NCHW and "
+ "NHWC<->HWNC conversion; got source format ",
src_format, " and destination format ", dst_format)));
- nhwc_to_nchw_ = (src_format == "NHWC") ? true : false;
+ src_format_ = FormatNameToEnum(src_format);
+ dst_format_ = FormatNameToEnum(dst_format);
}
void Compute(OpKernelContext* context) override {
@@ -106,11 +118,12 @@ class DataFormatVecPermuteOp : public OpKernel {
context->allocate_output(0, input.shape(), &output));
functor::DataFormatVecPermute<Device, T>()(
context->eigen_device<Device>(), input.flat<T>(), output->flat<T>(),
- nhwc_to_nchw_);
+ src_format_, dst_format_);
}
private:
- bool nhwc_to_nchw_;
+ functor::DataFormat src_format_;
+ functor::DataFormat dst_format_;
};
#define REGISTER_KERNEL(T) \
@@ -143,11 +156,12 @@ TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void DataFormatVecPermute<GPUDevice, T>::operator()( \
- const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
- typename TTypes<T>::Vec y, bool nhwc_to_nchw); \
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void DataFormatVecPermute<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
+ typename TTypes<T>::Vec y, const DataFormat src_format, \
+ const DataFormat dst_format); \
extern template struct DataFormatVecPermute<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h
index bf704cc35c..d27415ed91 100644
--- a/tensorflow/core/kernels/data_format_ops.h
+++ b/tensorflow/core/kernels/data_format_ops.h
@@ -23,6 +23,13 @@ limitations under the License.
namespace tensorflow {
namespace functor {
+enum class DataFormat {
+ UNKNOWN = 0,
+ NHWC,
+ NCHW,
+ HWNC,
+};
+
// Functor used by DataFormatDimMapOP to do the computations.
template <typename Device, typename T>
struct DataFormatDimMap {
@@ -97,15 +104,81 @@ struct VecPermuteNCHWToNHWC {
}
};
+template <typename T>
+struct VecPermuteNHWCToHWNC {
+ Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
+ typename TTypes<T>::ConstFlat input) const {
+ Eigen::DSizes<Eigen::DenseIndex, 1> result;
+ result[0] = input.dimension(0);
+ return result;
+ }
+ template <typename Output, typename Device>
+ void eval(typename TTypes<T>::ConstFlat input, Output& output,
+ const Device& d) const {
+ if (input.size() == 8) {
+ output.template chip<0>(0).device(d) = input.template chip<0>(2);
+ output.template chip<0>(1).device(d) = input.template chip<0>(3);
+ output.template chip<0>(2).device(d) = input.template chip<0>(4);
+ output.template chip<0>(3).device(d) = input.template chip<0>(5);
+ output.template chip<0>(4).device(d) = input.template chip<0>(0);
+ output.template chip<0>(5).device(d) = input.template chip<0>(1);
+ output.template chip<0>(6).device(d) = input.template chip<0>(6);
+ output.template chip<0>(7).device(d) = input.template chip<0>(7);
+ } else {
+ output.template chip<0>(0).device(d) = input.template chip<0>(1);
+ output.template chip<0>(1).device(d) = input.template chip<0>(2);
+ output.template chip<0>(2).device(d) = input.template chip<0>(0);
+ output.template chip<0>(3).device(d) = input.template chip<0>(3);
+ }
+ }
+};
+
+template <typename T>
+struct VecPermuteHWNCToNHWC {
+ Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
+ typename TTypes<T>::ConstFlat input) const {
+ Eigen::DSizes<Eigen::DenseIndex, 1> result;
+ result[0] = input.dimension(0);
+ return result;
+ }
+ template <typename Output, typename Device>
+ void eval(typename TTypes<T>::ConstFlat input, Output& output,
+ const Device& d) const {
+ if (input.size() == 8) {
+ output.template chip<0>(0).device(d) = input.template chip<0>(4);
+ output.template chip<0>(1).device(d) = input.template chip<0>(5);
+ output.template chip<0>(2).device(d) = input.template chip<0>(0);
+ output.template chip<0>(3).device(d) = input.template chip<0>(1);
+ output.template chip<0>(4).device(d) = input.template chip<0>(2);
+ output.template chip<0>(5).device(d) = input.template chip<0>(3);
+ output.template chip<0>(6).device(d) = input.template chip<0>(6);
+ output.template chip<0>(7).device(d) = input.template chip<0>(7);
+ } else {
+ output.template chip<0>(0).device(d) = input.template chip<0>(2);
+ output.template chip<0>(1).device(d) = input.template chip<0>(0);
+ output.template chip<0>(2).device(d) = input.template chip<0>(1);
+ output.template chip<0>(3).device(d) = input.template chip<0>(3);
+ }
+ }
+};
+
// Functor used by DataFormatVecPermuteOp to do the computations.
template <typename Device, typename T>
struct DataFormatVecPermute {
void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
- typename TTypes<T>::Flat y, bool nhwc_to_nchw) {
- if (nhwc_to_nchw) {
+ typename TTypes<T>::Flat y, const DataFormat src_format,
+ const DataFormat dst_format) {
+ if (src_format == DataFormat::NHWC && dst_format == DataFormat::NCHW) {
y.device(d) = x.customOp(VecPermuteNHWCToNCHW<T>());
- } else {
+ } else if (src_format == DataFormat::NCHW &&
+ dst_format == DataFormat::NHWC) {
y.device(d) = x.customOp(VecPermuteNCHWToNHWC<T>());
+ } else if (src_format == DataFormat::NHWC &&
+ dst_format == DataFormat::HWNC) {
+ y.device(d) = x.customOp(VecPermuteNHWCToHWNC<T>());
+ } else if (src_format == DataFormat::HWNC &&
+ dst_format == DataFormat::NHWC) {
+ y.device(d) = x.customOp(VecPermuteHWNCToNHWC<T>());
}
}
};
diff --git a/tensorflow/core/kernels/mutex_ops.cc b/tensorflow/core/kernels/mutex_ops.cc
index b02a584d73..ddb7a606c1 100644
--- a/tensorflow/core/kernels/mutex_ops.cc
+++ b/tensorflow/core/kernels/mutex_ops.cc
@@ -127,7 +127,7 @@ class Mutex : public ResourceBase {
}
}
thread_pool_->Schedule(std::bind(
- [this, c, cm, cancelled,
+ [this, cm, cancelled,
token](std::function<void(const Status& s, SharedLockReleaser&& lock)>
fn_) {
bool local_locked;
@@ -173,7 +173,7 @@ class MutexLockOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(
c,
LookupOrCreateResource<Mutex>(c, HandleFromInput(c, 0), &mutex,
- [this, c](Mutex** ptr) {
+ [c](Mutex** ptr) {
*ptr = new Mutex(
c, HandleFromInput(c, 0).name());
return Status::OK();
@@ -186,10 +186,10 @@ class MutexLockOp : public AsyncOpKernel {
mutex->AcquireAsync(
c, std::bind(
- [this, c, variant, mutex](DoneCallback done_,
- // End of bound arguments.
- const Status& s,
- Mutex::SharedLockReleaser&& lock) {
+ [c, variant, mutex](DoneCallback done_,
+ // End of bound arguments.
+ const Status& s,
+ Mutex::SharedLockReleaser&& lock) {
VLOG(2) << "Finished locking mutex " << mutex
<< " with lock: " << lock.shared_lock.get()
<< " status: " << s.ToString();
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index f254036ba7..aecad0185f 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -351,7 +351,7 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
Var* variable = nullptr;
OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
context, HandleFromInput(context, 0), &variable,
- [this, context](Var** ptr) {
+ [](Var** ptr) {
// Created on host.
*ptr = new Var(DT_VARIANT);
return Status::OK();
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index fe0a2782f9..d0703d7576 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -24,6 +24,13 @@ limitations under the License.
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and CudaAtomicMax is used in template context.
+// This file requires the following include because it uses CudaAtomicMax:
+// #include "tensorflow/core/util/cuda_kernel_helper.h"
+
+// Unfortunately we can't add the #include, since it breaks compilation for
+// non-GPU targets. This only breaks in clang, because it's more strict for
+// template code and CudaAtomicMax is used in template context.
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc
index 7cd4532ad6..4b5df7aff0 100644
--- a/tensorflow/core/kernels/sparse_cross_op.cc
+++ b/tensorflow/core/kernels/sparse_cross_op.cc
@@ -327,7 +327,7 @@ class SparseCrossOp : public OpKernel {
typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater(
output_start_indices, indices_out, values_out);
- auto do_work = [this, &columns, crosser, updater](int64 begin, int64 end) {
+ auto do_work = [&columns, crosser, updater](int64 begin, int64 end) {
for (int b = begin; b < end; b++) {
ProductIterator<InternalType> product_iterator(columns, b);
int64 cross_count = 0;
diff --git a/tensorflow/core/kernels/split_v_op.cc b/tensorflow/core/kernels/split_v_op.cc
index 0ce0b552e6..5c19a45fb1 100644
--- a/tensorflow/core/kernels/split_v_op.cc
+++ b/tensorflow/core/kernels/split_v_op.cc
@@ -208,10 +208,10 @@ class SplitVOpCPUImpl {
input_element_count >= std::max(num_threads, num_split) * 4096 &&
input_element_count < num_split * 180 * 1024);
- auto range_output_func = [&indices, context, &input_shape, prefix_dim_size,
- split_dim, &split_sizes_vec, &split_start_points,
- suffix_dim_size, use_parallelism_between_outputs,
- &input_reshaped, &make_sizes,
+ auto range_output_func = [&indices, context, &input_shape, split_dim,
+ &split_sizes_vec, &split_start_points,
+ use_parallelism_between_outputs, &input_reshaped,
+ &make_sizes,
&reshape_result](int64 start, int64 limit) {
for (int64 i = start; i < limit; ++i) {
TensorShape output_shape(input_shape);
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 18b8bc5495..85dd1a423a 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -65413,6 +65413,59 @@ op {
}
}
op {
+ name: "UniqueWithCountsV2"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "axis"
+ type_attr: "Taxis"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "idx"
+ type_attr: "out_idx"
+ }
+ output_arg {
+ name: "count"
+ type_attr: "out_idx"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Taxis"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "out_idx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "Unpack"
input_arg {
name: "value"
@@ -65687,6 +65740,130 @@ op {
}
}
op {
+ name: "UnsortedSegmentMin"
+ input_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "segment_ids"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "num_segments"
+ type_attr: "Tnumsegments"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tnumsegments"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
+ name: "UnsortedSegmentProd"
+ input_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "segment_ids"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "num_segments"
+ type_attr: "Tnumsegments"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tnumsegments"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "UnsortedSegmentSum"
input_arg {
name: "data"
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 3d84ab3f25..3faa4eeada 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -30943,6 +30943,59 @@ op {
}
}
op {
+ name: "UniqueWithCountsV2"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "axis"
+ type_attr: "Taxis"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "idx"
+ type_attr: "out_idx"
+ }
+ output_arg {
+ name: "count"
+ type_attr: "out_idx"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Taxis"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "out_idx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "Unpack"
input_arg {
name: "value"
@@ -31061,6 +31114,130 @@ op {
}
}
op {
+ name: "UnsortedSegmentMin"
+ input_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "segment_ids"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "num_segments"
+ type_attr: "Tnumsegments"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tnumsegments"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
+ name: "UnsortedSegmentProd"
+ input_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "segment_ids"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "num_segments"
+ type_attr: "Tnumsegments"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tnumsegments"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "UnsortedSegmentSum"
input_arg {
name: "data"
diff --git a/tensorflow/docs_src/programmers_guide/version_compat.md b/tensorflow/docs_src/programmers_guide/version_compat.md
index 5412fba5d0..72e427c5f8 100644
--- a/tensorflow/docs_src/programmers_guide/version_compat.md
+++ b/tensorflow/docs_src/programmers_guide/version_compat.md
@@ -183,7 +183,7 @@ Our versioning scheme has three requirements:
* **Forward compatibility** to support scenarios where the producer of a
graph or checkpoint is upgraded to a newer version of TensorFlow before
the consumer.
-* Enable evolving TensorFlow in incompatible ways. For example, removing Ops,
+* Enable evolving TensorFlow in incompatible ways. For example, removing ops,
adding attributes, and removing attributes.
Note that while the `GraphDef` version mechanism is separate from the TensorFlow
@@ -245,10 +245,10 @@ contains a main data version which is treated as either `producer` or
`TF_CHECKPOINT_VERSION_MIN_CONSUMER`, and
`TF_CHECKPOINT_VERSION_MIN_PRODUCER`.
-### Add a new attribute with default to an existing Op
+### Add a new attribute with default to an existing op
Following the guidance below gives you forward compatibility only if the set of
-Ops has not changed.
+ops has not changed:
1. If forward compatibility is desired, set `strip_default_attrs` to `True`
while exporting the model using either the
@@ -257,39 +257,39 @@ Ops has not changed.
methods of the `SavedModelBuilder` class, or
@{tf.estimator.Estimator.export_savedmodel$`Estimator.export_savedmodel`}
2. This strips off the default valued attributes at the time of
- producing/exporting the models; thereby making sure that the exported
- @{tf.MetaGraphDef} does not contain the new Op-attribute when the default
+ producing/exporting the models. This makes sure that the exported
+ @{tf.MetaGraphDef} does not contain the new op-attribute when the default
value is used.
-3. Having this control lets potentially old consumers aka serving binaries
- (lagging behind training binaries) continue loading the models
- thereby preventing interruptions in model serving.
+3. Having this control could allow out-of-date consumers (for example, serving
+ binaries that lag behind training binaries) to continue loading the models
+ and prevent interruptions in model serving.
### Evolving GraphDef versions
This section explains how to use this versioning mechanism to make different
types of changes to the `GraphDef` format.
-#### Add an Op
+#### Add an op
-Add the new Op to both consumers and producers at the same time, and do not
+Add the new op to both consumers and producers at the same time, and do not
change any `GraphDef` versions. This type of change is automatically
backward compatible, and does not impact forward compatibility plan since
existing producer scripts will not suddenly use the new functionality.
-#### Add an Op and switch existing Python wrappers to use it
+#### Add an op and switch existing Python wrappers to use it
1. Implement new consumer functionality and increment the `GraphDef` version.
2. If it is possible to make the wrappers use the new functionality only in
cases that did not work before, the wrappers can be updated now.
3. Change Python wrappers to use the new functionality. Do not increment
- `min_consumer`, since models that do not use this Op should not break.
+ `min_consumer`, since models that do not use this op should not break.
-#### Remove or restrict an Op's functionality
+#### Remove or restrict an op's functionality
-1. Fix all producer scripts (not TensorFlow itself) to not use the banned Op or
+1. Fix all producer scripts (not TensorFlow itself) to not use the banned op or
functionality.
2. Increment the `GraphDef` version and implement new consumer functionality
- that bans the removed Op or functionality for GraphDefs at the new version
+ that bans the removed op or functionality for GraphDefs at the new version
and above. If possible, make TensorFlow stop producing `GraphDefs` with the
banned functionality. To do so, add the
[`REGISTER_OP(...).Deprecated(deprecated_at_version,
@@ -298,15 +298,15 @@ existing producer scripts will not suddenly use the new functionality.
4. Increase `min_producer` to the GraphDef version from (2) and remove the
functionality entirely.
-#### Change an Op's functionality
+#### Change an op's functionality
-1. Add a new similar Op named `SomethingV2` or similar and go through the
+1. Add a new similar op named `SomethingV2` or similar and go through the
process of adding it and switching existing Python wrappers to use it, which
may take three weeks if forward compatibility is desired.
-2. Remove the old Op (Can only take place with a major version change due to
+2. Remove the old op (Can only take place with a major version change due to
backward compatibility).
-3. Increase `min_consumer` to rule out consumers with the old Op, add back the
- old Op as an alias for `SomethingV2`, and go through the process to switch
+3. Increase `min_consumer` to rule out consumers with the old op, add back the
+ old op as an alias for `SomethingV2`, and go through the process to switch
existing Python wrappers to use it.
4. Go through the process to remove `SomethingV2`.
@@ -314,6 +314,6 @@ existing producer scripts will not suddenly use the new functionality.
1. Bump the `GraphDef` version and add the bad version to `bad_consumers` for
all new GraphDefs. If possible, add to `bad_consumers` only for GraphDefs
- which contain a certain Op or similar.
+ which contain a certain op or similar.
2. If existing consumers have the bad version, push them out as soon as
possible.
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 336df7c2f7..469d1e9adb 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -384,64 +384,125 @@ func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs t
return op.Output(0), op.Output(1), op.Output(2)
}
-// MutableHashTableV2Attr is an optional argument to MutableHashTableV2.
-type MutableHashTableV2Attr func(optionalAttr)
+// FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient.
+type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr)
-// MutableHashTableV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this table is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func MutableHashTableV2Container(value string) MutableHashTableV2Attr {
+// FakeQuantWithMinMaxArgsGradientMin sets the optional min attribute to value.
+// If not specified, defaults to -6
+func FakeQuantWithMinMaxArgsGradientMin(value float32) FakeQuantWithMinMaxArgsGradientAttr {
return func(m optionalAttr) {
- m["container"] = value
+ m["min"] = value
}
}
-// MutableHashTableV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this table is shared under the given name across
-// multiple sessions.
-// If not specified, defaults to ""
-func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr {
+// FakeQuantWithMinMaxArgsGradientMax sets the optional max attribute to value.
+// If not specified, defaults to 6
+func FakeQuantWithMinMaxArgsGradientMax(value float32) FakeQuantWithMinMaxArgsGradientAttr {
return func(m optionalAttr) {
- m["shared_name"] = value
+ m["max"] = value
}
}
-// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
-//
-// value: If true and shared_name is empty, the table is shared
-// using the node name.
+// FakeQuantWithMinMaxArgsGradientNumBits sets the optional num_bits attribute to value.
+// If not specified, defaults to 8
+func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgsGradientAttr {
+ return func(m optionalAttr) {
+ m["num_bits"] = value
+ }
+}
+
+// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value.
// If not specified, defaults to false
-func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr {
+func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr {
return func(m optionalAttr) {
- m["use_node_name_sharing"] = value
+ m["narrow_range"] = value
}
}
-// Creates an empty hash table.
-//
-// This op creates a mutable hash table, specifying the type of its keys and
-// values. Each value must be a scalar. Data can be inserted into the table using
-// the insert operations. It does not support the initialization operation.
+// Compute gradients for a FakeQuantWithMinMaxArgs operation.
//
// Arguments:
-// key_dtype: Type of the table keys.
-// value_dtype: Type of the table values.
+// gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation.
+// inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation.
//
-// Returns Handle to a table.
-func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) {
+// Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation:
+// `gradients * (inputs >= min && inputs <= max)`.
+func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
+ attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
- Type: "MutableHashTableV2",
+ Type: "FakeQuantWithMinMaxArgsGradient",
+ Input: []tf.Input{
+ gradients, inputs,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+// FakeQuantWithMinMaxArgsAttr is an optional argument to FakeQuantWithMinMaxArgs.
+type FakeQuantWithMinMaxArgsAttr func(optionalAttr)
+
+// FakeQuantWithMinMaxArgsMin sets the optional min attribute to value.
+// If not specified, defaults to -6
+func FakeQuantWithMinMaxArgsMin(value float32) FakeQuantWithMinMaxArgsAttr {
+ return func(m optionalAttr) {
+ m["min"] = value
+ }
+}
+
+// FakeQuantWithMinMaxArgsMax sets the optional max attribute to value.
+// If not specified, defaults to 6
+func FakeQuantWithMinMaxArgsMax(value float32) FakeQuantWithMinMaxArgsAttr {
+ return func(m optionalAttr) {
+ m["max"] = value
+ }
+}
+
+// FakeQuantWithMinMaxArgsNumBits sets the optional num_bits attribute to value.
+// If not specified, defaults to 8
+func FakeQuantWithMinMaxArgsNumBits(value int64) FakeQuantWithMinMaxArgsAttr {
+ return func(m optionalAttr) {
+ m["num_bits"] = value
+ }
+}
+
+// FakeQuantWithMinMaxArgsNarrowRange sets the optional narrow_range attribute to value.
+// If not specified, defaults to false
+func FakeQuantWithMinMaxArgsNarrowRange(value bool) FakeQuantWithMinMaxArgsAttr {
+ return func(m optionalAttr) {
+ m["narrow_range"] = value
+ }
+}
+
+// Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.
+//
+// Attributes `[min; max]` define the clamping range for the `inputs` data.
+// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
+// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
+// then de-quantized and output as floats in `[min; max]` interval.
+// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive.
+//
+// Quantization is called fake since the output is still in floating point.
+func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsAttr) (outputs tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "FakeQuantWithMinMaxArgs",
+ Input: []tf.Input{
+ inputs,
+ },
Attrs: attrs,
}
op := scope.AddOperation(opspec)
@@ -1146,6 +1207,21 @@ func Sinh(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// Computes rectified linear 6: `min(max(features, 0), 6)`.
+func Relu6(scope *Scope, features tf.Output) (activations tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Relu6",
+ Input: []tf.Input{
+ features,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the sum along segments of a tensor.
//
// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
@@ -3861,21 +3937,6 @@ func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_
return op.Output(0)
}
-// Computes rectified linear 6: `min(max(features, 0), 6)`.
-func Relu6(scope *Scope, features tf.Output) (activations tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Relu6",
- Input: []tf.Input{
- features,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes rectified linear gradients for a Relu operation.
//
// Arguments:
@@ -4279,68 +4340,6 @@ func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax
return op.Output(0)
}
-// FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient.
-type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr)
-
-// FakeQuantWithMinMaxArgsGradientMin sets the optional min attribute to value.
-// If not specified, defaults to -6
-func FakeQuantWithMinMaxArgsGradientMin(value float32) FakeQuantWithMinMaxArgsGradientAttr {
- return func(m optionalAttr) {
- m["min"] = value
- }
-}
-
-// FakeQuantWithMinMaxArgsGradientMax sets the optional max attribute to value.
-// If not specified, defaults to 6
-func FakeQuantWithMinMaxArgsGradientMax(value float32) FakeQuantWithMinMaxArgsGradientAttr {
- return func(m optionalAttr) {
- m["max"] = value
- }
-}
-
-// FakeQuantWithMinMaxArgsGradientNumBits sets the optional num_bits attribute to value.
-// If not specified, defaults to 8
-func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgsGradientAttr {
- return func(m optionalAttr) {
- m["num_bits"] = value
- }
-}
-
-// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value.
-// If not specified, defaults to false
-func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr {
- return func(m optionalAttr) {
- m["narrow_range"] = value
- }
-}
-
-// Compute gradients for a FakeQuantWithMinMaxArgs operation.
-//
-// Arguments:
-// gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation.
-// inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation.
-//
-// Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation:
-// `gradients * (inputs >= min && inputs <= max)`.
-func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "FakeQuantWithMinMaxArgsGradient",
- Input: []tf.Input{
- gradients, inputs,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// AvgPool3DAttr is an optional argument to AvgPool3D.
type AvgPool3DAttr func(optionalAttr)
@@ -16864,6 +16863,70 @@ func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Out
return scope.AddOperation(opspec)
}
+// MutableHashTableV2Attr is an optional argument to MutableHashTableV2.
+type MutableHashTableV2Attr func(optionalAttr)
+
+// MutableHashTableV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this table is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func MutableHashTableV2Container(value string) MutableHashTableV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// MutableHashTableV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this table is shared under the given name across
+// multiple sessions.
+// If not specified, defaults to ""
+func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
+//
+// value: If true and shared_name is empty, the table is shared
+// using the node name.
+// If not specified, defaults to false
+func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr {
+ return func(m optionalAttr) {
+ m["use_node_name_sharing"] = value
+ }
+}
+
+// Creates an empty hash table.
+//
+// This op creates a mutable hash table, specifying the type of its keys and
+// values. Each value must be a scalar. Data can be inserted into the table using
+// the insert operations. It does not support the initialization operation.
+//
+// Arguments:
+// key_dtype: Type of the table keys.
+// value_dtype: Type of the table values.
+//
+// Returns Handle to a table.
+func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MutableHashTableV2",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Deprecated. Disallowed in GraphDef version >= 2.
//
// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead
@@ -23901,69 +23964,6 @@ func Conv2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, pa
return op.Output(0)
}
-// FakeQuantWithMinMaxArgsAttr is an optional argument to FakeQuantWithMinMaxArgs.
-type FakeQuantWithMinMaxArgsAttr func(optionalAttr)
-
-// FakeQuantWithMinMaxArgsMin sets the optional min attribute to value.
-// If not specified, defaults to -6
-func FakeQuantWithMinMaxArgsMin(value float32) FakeQuantWithMinMaxArgsAttr {
- return func(m optionalAttr) {
- m["min"] = value
- }
-}
-
-// FakeQuantWithMinMaxArgsMax sets the optional max attribute to value.
-// If not specified, defaults to 6
-func FakeQuantWithMinMaxArgsMax(value float32) FakeQuantWithMinMaxArgsAttr {
- return func(m optionalAttr) {
- m["max"] = value
- }
-}
-
-// FakeQuantWithMinMaxArgsNumBits sets the optional num_bits attribute to value.
-// If not specified, defaults to 8
-func FakeQuantWithMinMaxArgsNumBits(value int64) FakeQuantWithMinMaxArgsAttr {
- return func(m optionalAttr) {
- m["num_bits"] = value
- }
-}
-
-// FakeQuantWithMinMaxArgsNarrowRange sets the optional narrow_range attribute to value.
-// If not specified, defaults to false
-func FakeQuantWithMinMaxArgsNarrowRange(value bool) FakeQuantWithMinMaxArgsAttr {
- return func(m optionalAttr) {
- m["narrow_range"] = value
- }
-}
-
-// Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.
-//
-// Attributes `[min; max]` define the clamping range for the `inputs` data.
-// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
-// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
-// then de-quantized and output as floats in `[min; max]` interval.
-// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive.
-//
-// Quantization is called fake since the output is still in floating point.
-func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsAttr) (outputs tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "FakeQuantWithMinMaxArgs",
- Input: []tf.Input{
- inputs,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// StageAttr is an optional argument to Stage.
type StageAttr func(optionalAttr)
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index a8f2154db8..3119ab0037 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -52,9 +52,11 @@ py_library(
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
],
)
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index e0d63b5ebc..390ce852b1 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -111,11 +111,11 @@ class Dataset(object):
self.output_types, self.output_shapes,
self.output_classes)
- def make_one_shot_iterator(self):
+ def __iter__(self):
"""Creates an `Iterator` for enumerating the elements of this dataset.
- Note: The returned iterator will be initialized automatically.
- A "one-shot" iterator does not currently support re-initialization.
+ The returned iterator implements the Python iterator protocol and therefore
+ can only be used in eager mode.
Returns:
An `Iterator` over the elements of this dataset.
@@ -124,9 +124,22 @@ class Dataset(object):
RuntimeError: If eager execution is enabled.
"""
if context.executing_eagerly():
- raise RuntimeError(
- "dataset.make_one_shot_iterator is not supported when eager "
- "execution is enabled.")
+ return iterator_ops.EagerIterator(self)
+ else:
+ raise RuntimeError("dataset.__iter__() is only supported when eager "
+ "execution is enabled.")
+
+ def make_one_shot_iterator(self):
+ """Creates an `Iterator` for enumerating the elements of this dataset.
+
+ Note: The returned iterator will be initialized automatically.
+ A "one-shot" iterator does not currently support re-initialization.
+
+ Returns:
+ An `Iterator` over the elements of this dataset.
+ """
+ if context.executing_eagerly():
+ return iterator_ops.EagerIterator(self)
# NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
# a 0-argument function.
@function.Defun(capture_by_value=True)
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 4756ec7482..d79b9d6011 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -17,14 +17,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import threading
import warnings
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util.tf_export import tf_export
@@ -412,3 +416,147 @@ class Iterator(object):
of an element of this dataset.
"""
return self._output_types
+
+
+_uid_counter = 0
+_uid_lock = threading.Lock()
+
+
+def _generate_shared_name(prefix):
+ with _uid_lock:
+ global _uid_counter
+ uid = _uid_counter
+ _uid_counter += 1
+ return "{}{}".format(prefix, uid)
+
+
+class EagerIterator(object):
+ """An iterator producing tf.Tensor objects from a tf.data.Dataset."""
+
+ def __init__(self, dataset):
+ """Creates a new iterator over the given dataset.
+
+ For example:
+ ```python
+ dataset = tf.data.Dataset.range(4)
+ for x in Iterator(dataset):
+ print(x)
+ ```
+
+ Tensors produced will be placed on the device on which this iterator object
+ was created.
+
+ Args:
+ dataset: A `tf.data.Dataset` object.
+
+ Raises:
+ RuntimeError: When invoked without eager execution enabled.
+ """
+
+ if not context.executing_eagerly():
+ raise RuntimeError(
+ "{} objects can only be used when eager execution is enabled, use "
+ "tf.data.Dataset.make_initializable_iterator or "
+ "tf.data.Dataset.make_one_shot_iterator for graph construction".
+ format(type(self)))
+ with ops.device("/device:CPU:0"):
+ ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access
+ self._output_classes = dataset.output_classes
+ self._output_types = dataset.output_types
+ self._output_shapes = dataset.output_shapes
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._output_types, self._output_classes))
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._output_shapes, self._output_classes))
+ self._resource = gen_dataset_ops.iterator(
+ shared_name="",
+ container=_generate_shared_name("eageriterator"),
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ gen_dataset_ops.make_iterator(ds_variant, self._resource)
+ # Delete the resource when this object is deleted
+ self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._resource, handle_device="/device:CPU:0")
+ self._device = context.context().device_name
+
+ def __iter__(self):
+ return self
+
+ def __next__(self): # For Python 3 compatibility
+ return self.next()
+
+ def _next_internal(self):
+ """Returns a nested structure of `tf.Tensor`s containing the next element.
+ """
+ with ops.device(self._device):
+ # TODO(ashankar): Consider removing this ops.device() contextmanager
+ # and instead mimic ops placement in graphs: Operations on resource
+ # handles execute on the same device as where the resource is placed.
+ # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
+ # because in eager mode this code will run synchronously on the calling
+ # thread. Therefore we do not need to make a defensive context switch
+ # to a background thread, and can achieve a small constant performance
+ # boost by invoking the iterator synchronously.
+ ret = gen_dataset_ops.iterator_get_next_sync(
+ self._resource,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self._output_types, ret), self._output_types,
+ self._output_shapes, self._output_classes)
+
+ def next(self):
+ """Returns a nested structure of `tf.Tensor`s containing the next element.
+ """
+ try:
+ return self._next_internal()
+ except errors.OutOfRangeError:
+ raise StopIteration
+
+ @property
+ def output_classes(self):
+ """Returns the class of each component of an element of this iterator.
+
+ The expected values are `tf.Tensor` and `tf.SparseTensor`.
+
+ Returns:
+ A nested structure of Python `type` objects corresponding to each
+ component of an element of this dataset.
+ """
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ """Returns the shape of each component of an element of this iterator.
+
+ Returns:
+ A nested structure of `tf.TensorShape` objects corresponding to each
+ component of an element of this dataset.
+ """
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ """Returns the type of each component of an element of this iterator.
+
+ Returns:
+ A nested structure of `tf.DType` objects corresponding to each component
+ of an element of this dataset.
+ """
+ return self._output_types
+
+ def get_next(self, name=None):
+ """Returns a nested structure of `tf.Tensor`s containing the next element.
+
+ Args:
+ name: (Optional.) A name for the created operation. Currently unused.
+
+ Returns:
+ A nested structure of `tf.Tensor` objects.
+
+ Raises:
+ `tf.errors.OutOfRangeError`: If the end of the dataset has been reached.
+ """
+ del name
+ return self._next_internal()
diff --git a/tensorflow/python/debug/README.md b/tensorflow/python/debug/README.md
index a2273b050b..269bbb19bd 100644
--- a/tensorflow/python/debug/README.md
+++ b/tensorflow/python/debug/README.md
@@ -37,12 +37,18 @@ models:
* Association of nodes and tensors in graphs with Python source lines
* Profiling of models at the level of graph nodes and Python source lines.
(Omitted internal-only feature)
+* A [gRPC](https://grpc.io/)-based remote debugging protocol, which allows us to
+ build a browser-based graphical user interface (GUI) for TFDBG: the
+ [TensorBoard Debugger Plugin](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md).
## How to use TFDBG?
* For a walkthrough of TFDBG command-line interface, see https://www.tensorflow.org/programmers_guide/debugger.
+* For information on the web GUI of TFDBG (TensorBoard Debugger Plugin), see
+ [this README](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md).
* For programmatic use of the API of TFDBG, see https://www.tensorflow.org/api_docs/python/tfdbg.
+
## Related Publications
* Cai, S., Breck E., Nielsen E., Salib M., Sculley D. (2016) TensorFlow Debugger:
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index d504ca0b05..012c68f68e 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -250,13 +250,23 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteTooManyNumOutputs(self):
# num_outputs provided is 50, but only one output is produced.
- # That should be okay.
- product = execute(
- b'Mul',
- num_outputs=50,
- inputs=[constant_op.constant(3), constant_op.constant(5)],
- attrs=('T', dtypes.int32.as_datatype_enum))[0]
- self.assertAllEqual(15, product)
+ with self.assertRaises(errors.InvalidArgumentError):
+ _ = execute(
+ b'Mul',
+ num_outputs=50,
+ inputs=[constant_op.constant(3),
+ constant_op.constant(5)],
+ attrs=('T', dtypes.int32.as_datatype_enum))[0]
+
+ def testExecuteTooFewNumOutputs(self):
+ # num_outputs provided is 50, but only one output is produced.
+ with self.assertRaises(errors.InvalidArgumentError):
+ _ = execute(
+ b'Mul',
+ num_outputs=0,
+ inputs=[constant_op.constant(3),
+ constant_op.constant(5)],
+ attrs=('T', dtypes.int32.as_datatype_enum))[0]
def testMatMulGPU(self):
if not context.context().num_gpus():
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 8338bc4343..105c09e81f 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -340,8 +340,10 @@ void EagerTensor_dealloc(EagerTensor* self) {
Py_DECREF(self->handle_data);
Py_DECREF(self->keras_mask);
Py_DECREF(self->tensor_shape);
- TFE_DeleteTensorHandle(self->handle);
- self->handle = nullptr;
+ if (self->handle != nullptr) {
+ TFE_DeleteTensorHandle(self->handle);
+ self->handle = nullptr;
+ }
// We have the global interpreter lock, so use this chance to perform delayed
// refcount decrements.
tensorflow::ClearDecrefCache();
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index fcb0452a14..fe9785dc66 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1012,7 +1012,14 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
tensorflow::int64 id = EagerTensor_id(tensor);
- return tensorflow::eager::TapeTensor{id, t->t.dtype(), t->t.shape()};
+ const tensorflow::Tensor* tensor = nullptr;
+ const tensorflow::Status status = t->Tensor(&tensor);
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+ return tensorflow::eager::TapeTensor{id, t->dtype,
+ tensorflow::TensorShape({})};
+ } else {
+ return tensorflow::eager::TapeTensor{id, t->dtype, tensor->shape()};
+ }
}
tensorflow::int64 id = FastTensorId(tensor);
if (PyErr_Occurred()) {
diff --git a/tensorflow/python/estimator/replicate_model_fn.py b/tensorflow/python/estimator/replicate_model_fn.py
index 7418852096..144d89abf3 100644
--- a/tensorflow/python/estimator/replicate_model_fn.py
+++ b/tensorflow/python/estimator/replicate_model_fn.py
@@ -50,7 +50,6 @@ from tensorflow.python.training import optimizer as optimizer_lib
def _replicate_model_fn(model_fn,
- loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
devices=None):
"""Replicate `Estimator.model_fn` over GPUs.
@@ -109,8 +108,9 @@ def _replicate_model_fn(model_fn,
On reduction algorithms:
Certain algorithms were chosen for aggregating results of computations on
multiple towers:
- - Losses from all towers are reduced according to `loss_reduction`.
- - Gradients from all towers are reduced according to `loss_reduction`
+ - Losses from all towers are reduced according to `loss_reduction` argument
+ to TowerOptimizer..
+ - Gradients from all towers are reduced according to the `loss_reduction`
for each trainable variable.
- `eval_metrics_ops` are reduced per metric using `reduce_mean`.
- `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are
@@ -134,16 +134,11 @@ def _replicate_model_fn(model_fn,
Args:
model_fn: `model_fn` as defined in `Estimator`. See the section above about
the train_op argument of `EstimatorSpec`.
- loss_reduction: controls whether losses are summed or averaged.
devices: Optional list of devices to replicate the model across. This
argument can be used to replice only on the subset of available GPUs.
If `None`, then all available GPUs are going to be used for replication.
If no GPUs are available, then the model is going to be placed on the CPU.
- Raises:
- ValueError: if there is no `loss_reduction` or if _TowerOptimizer is
- mis-used.
-
Returns:
A replicated version of the supplied `model_fn`. Returned function that
conforms to the requirements of `Estimator`'s `model_fn` and can be used
@@ -151,7 +146,6 @@ def _replicate_model_fn(model_fn,
"""
return _replicate_model_fn_with_mode(
model_fn,
- loss_reduction,
devices,
# TODO(isaprykin): Query the system configuration to choose modes other
# than `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often
@@ -186,13 +180,9 @@ class _VariableDistributionMode(object):
def _replicate_model_fn_with_mode(
model_fn,
- loss_reduction,
devices=None,
mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER):
"""A version of `replicate_model_fn` that allows to specify a `mode`."""
- if loss_reduction == losses.Reduction.NONE:
- raise ValueError('Tower losses need to be reduced in some way, yet {} '
- 'reduction is specified.'.format(loss_reduction))
if not devices:
devices = _get_local_devices('GPU') or _get_local_devices('CPU')
@@ -215,7 +205,6 @@ def _replicate_model_fn_with_mode(
features=[features],
labels=[labels],
params=params,
- loss_reduction=loss_reduction,
config=config,
devices=devices,
local_ps_devices=ps_devices)[0] # One device, so one spec is out.
@@ -230,7 +219,6 @@ def _replicate_model_fn_with_mode(
features=feature_shards,
labels=label_shards,
params=params,
- loss_reduction=loss_reduction,
config=config,
devices=devices,
local_ps_devices=ps_devices)
@@ -255,7 +243,8 @@ class _TowerOptimizer(optimizer_lib.Optimizer):
COLLECTION_FOR_GRAPH_STATES = 'replicate_model_fn_graph_states'
- def __init__(self, optimizer_or_optimizer_fn):
+ def __init__(self, optimizer_or_optimizer_fn,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE):
"""Wrap an existing optimizer for gathering gradients across towers.
Each invocation of model_fn has to call the same optimizers in the same
@@ -275,8 +264,10 @@ class _TowerOptimizer(optimizer_lib.Optimizer):
optimizer_or_optimizer_fn: an instance of optimizer to wrap. That
instance is going to be used for optimizer-specific logic. This can
also be a no-argument function that returns such an optimizer instance.
+ loss_reduction: controls whether losses are summed or averaged.
"""
self._optimizer_or_optimizer_fn = optimizer_or_optimizer_fn
+ self._loss_reduction = loss_reduction
@staticmethod
def has_been_used():
@@ -296,8 +287,9 @@ class _TowerOptimizer(optimizer_lib.Optimizer):
def compute_gradients(self, loss, *args, **kwargs):
"""Compute gradients, but first, if needed, scale the loss."""
+ _TowerOptimizer._graph_state().set_loss_reduction(self._loss_reduction)
loss = _scale_loss(loss,
- self._graph_state().loss_reduction,
+ self._loss_reduction,
self._graph_state().number_of_towers)
return self._get_optimizer().compute_gradients(loss, *args, **kwargs)
@@ -402,10 +394,12 @@ class _TowerOptimizer(optimizer_lib.Optimizer):
self._collected_grads_and_vars[tower_id][index_of_last_gradients])
return grads_and_vars
- def set_reduction_across_towers(self, loss_reduction, number_of_towers):
- self._loss_reduction = loss_reduction
+ def set_number_of_towers(self, number_of_towers):
self._number_of_towers = number_of_towers
+ def set_loss_reduction(self, loss_reduction):
+ self._loss_reduction = loss_reduction
+
@contextmanager
def tower(self, tower_id, var_scope, name_scope):
if tower_id == 0:
@@ -509,7 +503,6 @@ def _get_loss_towers(model_fn,
config,
devices,
local_ps_devices,
- loss_reduction,
name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
"""Replicate the loss computation across devices."""
tower_specs = []
@@ -524,8 +517,7 @@ def _get_loss_towers(model_fn,
# pylint: disable=protected-access
round_robin_strategy = device_setter_lib._RoundRobinStrategy(
num_tasks=len(local_ps_devices))
- _TowerOptimizer._graph_state().set_reduction_across_towers(
- loss_reduction, len(devices))
+ _TowerOptimizer._graph_state().set_number_of_towers(len(devices))
for i, device in enumerate(devices):
is_the_first_tower = (i == 0)
@@ -567,7 +559,9 @@ def _get_loss_towers(model_fn,
# Scaling the loss here doesn't actually affect gradients. Another
# instance of scaling happens inside the _TowerOptimizer.
tower_spec = _scale_tower_loss(
- tower_spec, loss_reduction, number_of_towers=len(devices))
+ tower_spec,
+ _TowerOptimizer._graph_state().loss_reduction,
+ number_of_towers=len(devices))
tower_specs.append(tower_spec)
if not _TowerOptimizer._did_towers_have_same_optimizer_calls():
@@ -607,20 +601,27 @@ def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers):
return tower_spec
estimator_spec = _asdict(tower_spec)
- estimator_spec['loss'] = _scale_loss(tower_spec.loss, loss_reduction,
- number_of_towers)
+ estimator_spec['loss'] = _scale_loss(
+ tower_spec.loss,
+ loss_reduction,
+ number_of_towers,
+ reduced_loss_name='averaged_loss')
return model_fn_lib.EstimatorSpec(**estimator_spec)
-def _scale_loss(loss, loss_reduction, number_of_towers):
+def _scale_loss(loss, loss_reduction, number_of_towers, reduced_loss_name=None):
"""If needed, scale down the loss for averaging loss by summing."""
if loss is None:
return None
if number_of_towers == 1:
return loss
+ if loss_reduction == losses.Reduction.NONE:
+ raise ValueError('Tower losses need to be reduced in some way, yet {} '
+ 'reduction is specified.'.format(loss_reduction))
+
if loss_reduction != losses.Reduction.SUM:
- return math_ops.div(loss, 1.0 * number_of_towers, name='averaged_loss')
+ return math_ops.div(loss, 1.0 * number_of_towers, name=reduced_loss_name)
else:
return loss
diff --git a/tensorflow/python/estimator/replicate_model_fn_test.py b/tensorflow/python/estimator/replicate_model_fn_test.py
index b6dd4e981f..ad1f9c02b9 100644
--- a/tensorflow/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/python/estimator/replicate_model_fn_test.py
@@ -121,8 +121,9 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
estimator = dnn.DNNClassifier(
hidden_units=(2, 2),
# Adagrad is configured with `get_optimizer_instance`, so the function
- # form of `_TowerOptimizer.__init__` is used.
- optimizer=replicate_model_fn._TowerOptimizer(optimizer_fn),
+ # form of `TowerOptimizer.__init__` is used.
+ optimizer=replicate_model_fn._TowerOptimizer(
+ optimizer_fn, loss_reduction=losses.Reduction.SUM),
feature_columns=feature_columns,
n_classes=n_classes,
model_dir=self._model_dir)
@@ -134,7 +135,6 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
model_fn = replicate_model_fn._replicate_model_fn_with_mode(
estimator.model_fn,
devices=['/gpu:0', '/gpu:1', '/gpu:2'],
- loss_reduction=losses.Reduction.SUM,
mode=mode)
estimator = estimator_lib.Estimator(
@@ -178,32 +178,39 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
class ReplicateModelTest(test_util.TensorFlowTestCase):
- def model_fn(self, mode, features, labels, params):
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(10, dtype=dtypes.float64),
- dtype=dtypes.float64)
+ def create_model_fn_with_loss_reduction(self, loss_reduction):
- predictions = math_ops.multiply(features, c)
+ def model_fn(mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(10, dtype=dtypes.float64),
+ dtype=dtypes.float64)
- loss = losses.absolute_difference(
- labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
- loss = math_ops.reduce_sum(loss)
+ predictions = math_ops.multiply(features, c)
- metrics = {
- 'accuracy': metrics_lib.accuracy(labels, predictions),
- 'auc': metrics_lib.auc(labels, predictions)
- }
+ loss = losses.absolute_difference(
+ labels=labels,
+ predictions=predictions,
+ reduction=losses.Reduction.SUM)
+ loss = math_ops.reduce_sum(loss)
- optimizer = replicate_model_fn._TowerOptimizer(
- gradient_descent.GradientDescentOptimizer(params['learning_rate']))
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions),
+ 'auc': metrics_lib.auc(labels, predictions)
+ }
- return model_fn_lib.EstimatorSpec(
- mode=mode,
- loss=loss,
- eval_metric_ops=metrics,
- predictions={'probabilities': predictions},
- train_op=optimizer.minimize(loss))
+ optimizer = replicate_model_fn._TowerOptimizer(
+ gradient_descent.GradientDescentOptimizer(params['learning_rate']),
+ loss_reduction=loss_reduction)
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ eval_metric_ops=metrics,
+ predictions={'probabilities': predictions},
+ train_op=optimizer.minimize(loss))
+
+ return model_fn
@property
def params(self):
@@ -217,8 +224,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn,
- loss_reduction=losses.Reduction.SUM,
+ self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
@@ -248,7 +254,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
dtype=dtypes.float64)
replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
+ self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN),
+ devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
session.run(variables.global_variables_initializer())
@@ -284,8 +291,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session, variable_scope.variable_scope(
'', reuse=variable_scope.AUTO_REUSE):
replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn,
- loss_reduction=losses.Reduction.SUM,
+ self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
@@ -307,8 +313,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn,
- loss_reduction=losses.Reduction.SUM,
+ self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
@@ -338,7 +343,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
+ self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN),
+ devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
session.run(variables.local_variables_initializer())
@@ -367,7 +373,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, devices=['/gpu:0', '/gpu:1'])
+ self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+ devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
session.run(variables.global_variables_initializer())
@@ -382,7 +389,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, devices=['/gpu:0'])
+ self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+ devices=['/gpu:0'])
estimator_spec = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
session.run(variables.global_variables_initializer())
@@ -404,7 +412,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, devices=['/gpu:0'])
+ self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+ devices=['/gpu:0'])
estimator_spec = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
session.run(variables.local_variables_initializer())
@@ -432,7 +441,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, devices=['/gpu:0'])
+ self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+ devices=['/gpu:0'])
estimator_spec = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
session.run(variables.global_variables_initializer())
@@ -448,15 +458,22 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(
ValueError, '.*Batch.+size.+needs.+to.+be.+divisible.+by.+GPUs.+'):
replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, devices=['/gpu:0', '/gpu:1'])
+ self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+ devices=['/gpu:0', '/gpu:1'])
_ = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
def test_unsupported_loss_reduction(self):
+ features = np.array([[1.0], [2.0], [3.0]])
+ labels = np.array([[1.0], [2.0], [3.0]])
+
with self.assertRaisesRegexp(ValueError,
'.+none.+reduction.+is.+specified.+'):
- _ = replicate_model_fn._replicate_model_fn(self.model_fn,
- losses.Reduction.NONE)
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.create_model_fn_with_loss_reduction(losses.Reduction.NONE),
+ devices=['/gpu:0', '/gpu:1', '/gpu:2'])
+ _ = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
def test_places_on_gpu_with_upper_case_spelling(self):
features = np.array([[0.01], [0.002]])
@@ -464,7 +481,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session():
replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, devices=['/GPU:0'])
+ self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+ devices=['/GPU:0'])
_ = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
@@ -478,7 +496,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
with self.test_session():
replicated_model_fn = replicate_model_fn._replicate_model_fn(
- self.model_fn, devices=['/gpu:0'])
+ self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+ devices=['/gpu:0'])
_ = replicated_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
@@ -624,7 +643,8 @@ class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase):
optimizer = training.SyncReplicasOptimizer(
optimizer, replicas_to_aggregate=1)
sync_hook = optimizer.make_session_run_hook(True)
- optimizer = replicate_model_fn._TowerOptimizer(optimizer)
+ optimizer = replicate_model_fn._TowerOptimizer(
+ optimizer, loss_reduction=losses.Reduction.SUM)
return model_fn_lib.EstimatorSpec(
mode=mode,
@@ -650,7 +670,6 @@ class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase):
model_fn = replicate_model_fn._replicate_model_fn(
self.model_fn,
- loss_reduction=losses.Reduction.SUM,
devices=['/gpu:0', '/gpu:1'])
estimator = estimator_lib.Estimator(
@@ -687,9 +706,10 @@ class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase):
}
first_optimizer = replicate_model_fn._TowerOptimizer(
- gradient_descent.GradientDescentOptimizer(1.0))
+ gradient_descent.GradientDescentOptimizer(1.0),
+ loss_reduction=losses.Reduction.SUM)
second_optimizer = replicate_model_fn._TowerOptimizer(
- adam.AdamOptimizer(1.0))
+ adam.AdamOptimizer(1.0), loss_reduction=losses.Reduction.SUM)
with ops_lib.control_dependencies([side_effects.assign_add(1.0)]):
first_grads_and_vars = first_optimizer.compute_gradients(loss)
@@ -712,7 +732,6 @@ class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase):
with self.test_session() as session:
replicated_model_fn = replicate_model_fn._replicate_model_fn(
self.model_fn,
- loss_reduction=losses.Reduction.SUM,
devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(features, labels,
model_fn_lib.ModeKeys.TRAIN, {})
@@ -787,11 +806,13 @@ class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
train_ops = []
optimizer = replicate_model_fn._TowerOptimizer(
- gradient_descent.GradientDescentOptimizer(1.0))
+ gradient_descent.GradientDescentOptimizer(1.0),
+ loss_reduction=losses.Reduction.SUM)
train_ops.append(optimizer.minimize(loss, var_list=[c]))
if not self.should_skip_optimizer():
another_optimizer = replicate_model_fn._TowerOptimizer(
- gradient_descent.GradientDescentOptimizer(1.0))
+ gradient_descent.GradientDescentOptimizer(1.0),
+ loss_reduction=losses.Reduction.SUM)
train_ops.append(another_optimizer.minimize(another_loss, var_list=[d]))
train_op = control_flow_ops.group(train_ops)
@@ -806,10 +827,9 @@ class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with ops_lib.Graph().as_default(), self.test_session() as session:
replicated_model_fn = replicate_model_fn._replicate_model_fn(
self.model_fn,
- loss_reduction=losses.Reduction.SUM,
devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(features, labels,
model_fn_lib.ModeKeys.TRAIN, {})
@@ -881,7 +901,7 @@ class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase):
with self.test_session():
with self.assertRaisesRegexp(ValueError,
- 'Please.+wrap.+with.+_TowerOptimizer'):
+ 'Please.+wrap.+with.+TowerOptimizer'):
replicated_model_fn = replicate_model_fn._replicate_model_fn(
self.model_fn, devices=['/gpu:0', '/gpu:1'])
_ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN,
@@ -890,30 +910,43 @@ class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase):
class GetLossTowersTest(test_util.TensorFlowTestCase):
- def model_fn(self, mode, features, labels, params):
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(0.25, dtype=dtypes.float64),
- dtype=dtypes.float64)
+ def create_model_fn_with_loss_reduction(self, loss_reduction):
- predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
- labels = np.array([0.1, 0.2, 0.3, labels[0]])
+ def model_fn(mode, features, labels, params):
+ del params
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(0.25, dtype=dtypes.float64),
+ dtype=dtypes.float64)
- loss = losses.absolute_difference(
- labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+ predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
+ labels = np.array([0.1, 0.2, 0.3, labels[0]])
- return model_fn_lib.EstimatorSpec(mode=mode, loss=math_ops.reduce_sum(loss))
+ loss = losses.absolute_difference(
+ labels=labels,
+ predictions=predictions,
+ reduction=losses.Reduction.SUM)
+
+ optimizer = replicate_model_fn._TowerOptimizer(
+ gradient_descent.GradientDescentOptimizer(1.0),
+ loss_reduction)
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=math_ops.reduce_sum(loss),
+ train_op=optimizer.minimize(loss))
+
+ return model_fn
def test_gradients_are_computed(self):
with self.test_session() as session:
tower_specs = replicate_model_fn._get_loss_towers(
- self.model_fn,
+ self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
mode=None,
features=[[0.6], [1.6]],
labels=[[0.6], [0.6]],
params=None,
config=None,
- loss_reduction=losses.Reduction.SUM,
devices=['/gpu:0', '/gpu:1'],
local_ps_devices=['/gpu:0'],
name_scope_pattern='test_tower_{}')
@@ -941,12 +974,11 @@ class GetLossTowersTest(test_util.TensorFlowTestCase):
def test_gradients_are_computed_with_mean_reduction(self):
with self.test_session() as session:
tower_specs = replicate_model_fn._get_loss_towers(
- self.model_fn,
+ self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN),
mode=model_fn_lib.ModeKeys.EVAL,
features=[[0.6], [1.6]],
labels=[[0.6], [0.6]],
params=None,
- loss_reduction=losses.Reduction.MEAN,
config=None,
devices=['/gpu:0', '/gpu:1'],
local_ps_devices=['/gpu:0'],
@@ -999,7 +1031,6 @@ class GetLossTowersTest(test_util.TensorFlowTestCase):
features=[[0.6], [1.6], [2.6]],
labels=[[0.6], [0.6], [2.6]],
params=None,
- loss_reduction=losses.Reduction.SUM,
config=None,
devices=['/gpu:0', '/gpu:1', '/gpu:3'],
local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'],
@@ -1296,7 +1327,6 @@ class PredictSpecTest(test_util.TensorFlowTestCase):
self.model_fn,
mode=None,
features=[[0.1], [0.2]],
- loss_reduction=losses.Reduction.SUM,
labels=[[], []],
params=None,
config=None,
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 85971c91bf..381153c66a 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -1804,6 +1804,21 @@ def _create_categorical_column_weighted_sum(
name='weighted_sum')
+class _SequenceDenseColumn(_FeatureColumn):
+ """Represents dense sequence data."""
+
+ __metaclass__ = abc.ABCMeta
+
+ TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name
+ 'TensorSequenceLengthPair', ['dense_tensor', 'sequence_length'])
+
+ @abc.abstractmethod
+ def _get_sequence_dense_tensor(
+ self, inputs, weight_collections=None, trainable=None):
+ """Returns a `TensorSequenceLengthPair`."""
+ pass
+
+
class _LazyBuilder(object):
"""Handles caching of transformations while building the model.
@@ -2152,7 +2167,7 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
class _EmbeddingColumn(
- _DenseColumn,
+ _DenseColumn, _SequenceDenseColumn,
collections.namedtuple('_EmbeddingColumn', (
'categorical_column', 'dimension', 'combiner', 'initializer',
'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'
@@ -2178,7 +2193,9 @@ class _EmbeddingColumn(
self._shape = tensor_shape.vector(self.dimension)
return self._shape
- def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ def _get_dense_tensor_internal(
+ self, inputs, weight_collections=None, trainable=None):
+ """Private method that follows the signature of _get_dense_tensor."""
# Get sparse IDs and weights.
sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access
inputs, weight_collections=weight_collections, trainable=trainable)
@@ -2210,6 +2227,43 @@ class _EmbeddingColumn(
name='%s_weights' % self.name,
max_norm=self.max_norm)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ if isinstance(self.categorical_column, _SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must not be of type _SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(
+ self.name, type(self.categorical_column),
+ self.categorical_column))
+ return self._get_dense_tensor_internal(
+ inputs=inputs, weight_collections=weight_collections,
+ trainable=trainable)
+
+ def _get_sequence_dense_tensor(
+ self, inputs, weight_collections=None, trainable=None):
+ if not isinstance(self.categorical_column, _SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must be of type _SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(
+ self.name, type(self.categorical_column),
+ self.categorical_column))
+ dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access
+ inputs=inputs,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ sequence_length = _sequence_length_from_sparse_tensor(
+ sparse_tensors.id_tensor)
+ return _SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
class _SharedEmbeddingColumn(
_DenseColumn,
@@ -2890,7 +2944,7 @@ def _prune_invalid_ids(sparse_ids, sparse_weights):
return sparse_ids, sparse_weights
-class _IndicatorColumn(_DenseColumn,
+class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn,
collections.namedtuple('_IndicatorColumn',
['categorical_column'])):
"""Represents a one-hot column for use in deep networks.
@@ -2966,15 +3020,53 @@ class _IndicatorColumn(_DenseColumn,
Returns:
Dense `Tensor` created within `_transform_feature`.
+
+ Raises:
+ ValueError: If `categorical_column` is a `_SequenceCategoricalColumn`.
"""
# Do nothing with weight_collections and trainable since no variables are
# created in this function.
del weight_collections
del trainable
+ if isinstance(self.categorical_column, _SequenceCategoricalColumn):
+ raise ValueError(
+ 'In indicator_column: {}. '
+ 'categorical_column must not be of type _SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(
+ self.name, type(self.categorical_column),
+ self.categorical_column))
# Feature has been already transformed. Return the intermediate
# representation created by _transform_feature.
return inputs.get(self)
+ def _get_sequence_dense_tensor(
+ self, inputs, weight_collections=None, trainable=None):
+ # Do nothing with weight_collections and trainable since no variables are
+ # created in this function.
+ del weight_collections
+ del trainable
+ if not isinstance(self.categorical_column, _SequenceCategoricalColumn):
+ raise ValueError(
+ 'In indicator_column: {}. '
+ 'categorical_column must be of type _SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(
+ self.name, type(self.categorical_column),
+ self.categorical_column))
+ # Feature has been already transformed. Return the intermediate
+ # representation created by _transform_feature.
+ dense_tensor = inputs.get(self)
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ sequence_length = _sequence_length_from_sparse_tensor(
+ sparse_tensors.id_tensor)
+ return _SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
def _verify_static_batch_size_equality(tensors, columns):
# bath_size is a tf.Dimension object.
@@ -2990,3 +3082,68 @@ def _verify_static_batch_size_equality(tensors, columns):
'Batch size of columns ({}, {}): ({}, {})'.format(
columns[bath_size_column_index].name, columns[i].name,
expected_batch_size, tensors[i].shape[0]))
+
+
+def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
+ """Returns a [batch_size] Tensor with per-example sequence length."""
+ with ops.name_scope(None, 'sequence_length') as name_scope:
+ row_ids = sp_tensor.indices[:, 0]
+ column_ids = sp_tensor.indices[:, 1]
+ column_ids += array_ops.ones_like(column_ids)
+ seq_length = math_ops.to_int64(
+ math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
+ # If the last n rows do not have ids, seq_length will have shape
+ # [batch_size - n]. Pad the remaining values with zeros.
+ n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
+ padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
+ return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
+
+
+class _SequenceCategoricalColumn(
+ _CategoricalColumn,
+ collections.namedtuple(
+ '_SequenceCategoricalColumn', ['categorical_column'])):
+ """Represents sequences of categorical data."""
+
+ @property
+ def name(self):
+ return self.categorical_column.name
+
+ @property
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
+
+ def _transform_feature(self, inputs):
+ return self.categorical_column._transform_feature(inputs) # pylint: disable=protected-access
+
+ @property
+ def _num_buckets(self):
+ return self.categorical_column._num_buckets # pylint: disable=protected-access
+
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ id_tensor = sparse_tensors.id_tensor
+ weight_tensor = sparse_tensors.weight_tensor
+ # Expands final dimension, so that embeddings are not combined during
+ # embedding lookup.
+ check_id_rank = check_ops.assert_equal(
+ array_ops.rank(id_tensor), 2,
+ data=[
+ 'Column {} expected ID tensor of rank 2. '.format(self.name),
+ 'id_tensor shape: ', array_ops.shape(id_tensor)])
+ with ops.control_dependencies([check_id_rank]):
+ id_tensor = sparse_ops.sparse_reshape(
+ id_tensor,
+ shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))
+ if weight_tensor is not None:
+ check_weight_rank = check_ops.assert_equal(
+ array_ops.rank(weight_tensor), 2,
+ data=[
+ 'Column {} expected weight tensor of rank 2.'.format(self.name),
+ 'weight_tensor shape:', array_ops.shape(weight_tensor)])
+ with ops.control_dependencies([check_weight_rank]):
+ weight_tensor = sparse_ops.sparse_reshape(
+ weight_tensor,
+ shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
+ return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py
index 3172f3c2c3..4bb030cb89 100644
--- a/tensorflow/python/framework/framework_lib.py
+++ b/tensorflow/python/framework/framework_lib.py
@@ -48,6 +48,7 @@
## Graph collections
@@add_to_collection
+@@add_to_collections
@@get_collection
@@get_collection_ref
@@GraphKeys
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f5dde3a358..6174d32237 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -5602,7 +5602,7 @@ def add_to_collection(name, value):
"""
get_default_graph().add_to_collection(name, value)
-
+@tf_export("add_to_collections")
def add_to_collections(names, value):
"""Wrapper for `Graph.add_to_collections()` using the default graph.
diff --git a/tensorflow/python/framework/smart_cond_test.py b/tensorflow/python/framework/smart_cond_test.py
index 582ce81e7a..1170a41c99 100644
--- a/tensorflow/python/framework/smart_cond_test.py
+++ b/tensorflow/python/framework/smart_cond_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
@@ -144,5 +145,22 @@ class SmartCaseTest(test_util.TensorFlowTestCase):
self.assertEqual(sess.run(z, feed_dict={x: 0}), 3)
+@test_util.with_c_api
+class SmartConstantValueTest(test_util.TensorFlowTestCase):
+
+ # TODO(skyewm): this is essentially a regression test for
+ # TF_TryEvaluateConstant, and is not really a valid smart_constant_value test
+ # (smart_constant_value is only supposed to return bools). Move the
+ # TF_TryEvaluateConstant call to tensor_util.constant_value and make this a
+ # constant_value test instead.
+ def testCond(self):
+ with ops.Graph().as_default():
+ pred = array_ops.placeholder_with_default(True, shape=())
+ x = control_flow_ops.cond(pred,
+ lambda: constant_op.constant(1),
+ lambda: constant_op.constant(2))
+ self.assertIsNone(smart_cond.smart_constant_value(x))
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py
index 27a9ab8c60..546c48adba 100644
--- a/tensorflow/python/framework/tensor_spec.py
+++ b/tensorflow/python/framework/tensor_spec.py
@@ -65,6 +65,11 @@ class TensorSpec(object):
else:
raise ValueError("`tensor` should be a tf.Tensor")
+ @classmethod
+ def is_bounded(cls):
+ del cls
+ return False
+
@property
def shape(self):
"""Returns the `TensorShape` that represents the shape of the tensor."""
@@ -80,6 +85,16 @@ class TensorSpec(object):
"""Returns the name of the described tensor."""
return self._name
+ @property
+ def is_discrete(self):
+ """Whether spec is discrete."""
+ return self.dtype.is_integer
+
+ @property
+ def is_continuous(self):
+ """Whether spec is continuous."""
+ return self.dtype.is_floating
+
def is_compatible_with(self, spec_or_tensor):
"""True if the shape and dtype of `spec_or_tensor` are compatible."""
return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and
@@ -164,6 +179,11 @@ class BoundedTensorSpec(TensorSpec):
self._maximum.setflags(write=False)
@classmethod
+ def is_bounded(cls):
+ del cls
+ return True
+
+ @classmethod
def from_spec(cls, spec):
dtype = dtypes.as_dtype(spec.dtype)
minimum = getattr(spec, "minimum", dtype.min)
diff --git a/tensorflow/python/framework/tensor_spec_test.py b/tensorflow/python/framework/tensor_spec_test.py
index 54ca4d9a19..b33d769d86 100644
--- a/tensorflow/python/framework/tensor_spec_test.py
+++ b/tensorflow/python/framework/tensor_spec_test.py
@@ -127,6 +127,22 @@ class TensorSpecTest(test_util.TensorFlowTestCase):
self.assertEqual(bounded_spec.dtype, spec.dtype)
self.assertEqual(bounded_spec.name, spec.name)
+ def testIsDiscrete(self):
+ discrete_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32)
+ continuous_spec = tensor_spec.TensorSpec((1, 2), dtypes.float32)
+ self.assertTrue(discrete_spec.is_discrete)
+ self.assertFalse(continuous_spec.is_discrete)
+
+ def testIsContinuous(self):
+ discrete_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32)
+ continuous_spec = tensor_spec.TensorSpec((1, 2), dtypes.float32)
+ self.assertFalse(discrete_spec.is_continuous)
+ self.assertTrue(continuous_spec.is_continuous)
+
+ def testIsBounded(self):
+ unbounded_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32)
+ self.assertFalse(unbounded_spec.is_bounded())
+
class BoundedTensorSpecTest(test_util.TensorFlowTestCase):
@@ -138,6 +154,11 @@ class BoundedTensorSpecTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, "not compatible"):
tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, 0, (1, 1, 1))
+ def testIsBounded(self):
+ bounded_spec = tensor_spec.BoundedTensorSpec(
+ (1, 2), dtypes.int32, minimum=0, maximum=1)
+ self.assertTrue(bounded_spec.is_bounded())
+
def testMinimumMaximumAttributes(self):
spec = tensor_spec.BoundedTensorSpec(
(1, 2, 3), dtypes.float32, 0, (5, 5, 5))
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 77ecc94b7f..284e264acd 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -434,6 +434,32 @@ def with_c_api(cls):
return cls
+def assert_no_new_pyobjects_executing_eagerly(f):
+ """Decorator for asserting that no new Python objects persist after a test.
+
+ Runs the test multiple times executing eagerly, first as a warmup and then
+ several times to let objects accumulate. The warmup helps ignore caches which
+ do not grow as the test is run repeatedly.
+
+ Useful for checking that there are no missing Py_DECREFs in the C exercised by
+ a bit of Python.
+ """
+ def decorator(self, **kwargs):
+ """Warms up, gets an object count, runs the test, checks for new objects."""
+ with context.eager_mode():
+ gc.disable()
+ f(self, **kwargs)
+ gc.collect()
+ previous_count = len(gc.get_objects())
+ for _ in range(3):
+ f(self, **kwargs)
+ gc.collect()
+ # There should be no new Python objects hanging around.
+ new_count = len(gc.get_objects())
+ self.assertEqual(previous_count, new_count)
+ gc.enable()
+ return decorator
+
def assert_no_new_tensors(f):
"""Decorator for asserting that no new Tensors persist after a test.
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 20d816050f..02ffa93bae 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -448,6 +448,26 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
LeakedTensorTest().test_has_no_leak()
+ def test_no_new_objects_decorator(self):
+
+ class LeakedObjectTest(object):
+
+ def __init__(inner_self): # pylint: disable=no-self-argument
+ inner_self.assertEqual = self.assertEqual # pylint: disable=invalid-name
+ inner_self.accumulation = []
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def test_has_leak(self):
+ self.accumulation.append([1.])
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def test_has_no_leak(self):
+ self.not_accumulating = [1.]
+
+ with self.assertRaises(AssertionError):
+ LeakedObjectTest().test_has_leak()
+
+ LeakedObjectTest().test_has_no_leak()
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 16e56349c4..ffbdb0e61a 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import logging_ops
@@ -180,6 +181,11 @@ class ConstantTest(test.TestCase):
shape=[2, 3, 5])
self.assertEqual(c.get_shape(), [2, 3, 5])
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testEagerMemory(self):
+ """Tests PyObject refs are managed correctly when executing eagerly."""
+ constant_op.constant([[1.]])
+
def testImplicitShapeNumPy(self):
with ops.Graph().as_default():
c = constant_op.constant(
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 7d74046caf..cf45b07637 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import collections
-import gc
import numpy as np
@@ -84,27 +83,13 @@ class DenseTest(test.TestCase):
self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
self.assertEqual(dense.bias.name, 'my_dense/bias:0')
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testNoEagerLeak(self):
# Tests that repeatedly constructing and building a Layer does not leak
# Python objects.
- def _test_fn():
- inputs = random_ops.random_uniform((5, 4), seed=1)
- core_layers.Dense(5)(inputs)
- core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')(inputs)
-
- with context.eager_mode():
- _test_fn() # warmup
- gc.disable()
- gc.collect()
- object_count = len(gc.get_objects())
- for _ in range(100):
- _test_fn()
- gc.collect()
- self.assertLessEqual(
- len(gc.get_objects()),
- # DEBUG_SAVEALL messes with this slightly.
- object_count + 1)
- gc.enable()
+ inputs = random_ops.random_uniform((5, 4), seed=1)
+ core_layers.Dense(5)(inputs)
+ core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')(inputs)
@test_util.run_in_graph_and_eager_modes()
def testCallTensorDot(self):
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index 343415b264..02eafd42b3 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -164,9 +164,9 @@ bool IsSingleNone(PyObject* obj) {
}
// Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
-void ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
- Tensor* output_tensor) {
- *output_tensor = EagerTensor_Handle(eager_tensor)->t;
+tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
+ const Tensor** output_tensor) {
+ return EagerTensor_Handle(eager_tensor)->Tensor(output_tensor);
}
// Calls the registered py function through the trampoline.
@@ -220,7 +220,9 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
if (call->eager) {
const PyObject* item = PyList_GetItem(result, i);
if (EagerTensor_CheckExact(item)) {
- ExtractTensorFromEagerTensor(item, &t);
+ const Tensor* tensor = nullptr;
+ s = ExtractTensorFromEagerTensor(item, &tensor);
+ if (s.ok()) t = *tensor;
} else {
s = errors::FailedPrecondition(
"Expected EagerTensor, found PyObject of type: ",
@@ -238,10 +240,10 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
} else if (EagerTensor_CheckExact(result) || result == Py_None) {
// result is an `EagerTensor` or `None`.
DCHECK(call->eager);
- Tensor t;
if (result != Py_None) {
- ExtractTensorFromEagerTensor(result, &t);
- call->out.push_back(t);
+ const Tensor* t = nullptr;
+ s = ExtractTensorFromEagerTensor(result, &t);
+ if (s.ok()) call->out.push_back(*t);
}
} else if (PyArray_Check(result)) {
// `result` is a NumPy array.
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 317bdc2e14..8247d354db 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -84,6 +84,7 @@ bool IsPyDimension(PyObject* obj) {
}
Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
+ std::vector<Safe_PyObjectPtr> refs_to_clean;
while (true) {
// We test strings first, in case a string is considered a sequence.
if (IsPyString(obj)) {
@@ -93,6 +94,7 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
if (length > 0) {
shape->AddDim(length);
obj = PySequence_GetItem(obj, 0);
+ refs_to_clean.push_back(make_safe(obj));
continue;
} else if (length == 0) {
shape->AddDim(length);
@@ -167,14 +169,15 @@ const char ErrorFoundFloat[] =
if (shape.dims() > 1) { \
/* Iterate over outer dim, and recursively convert each element. */ \
const int64 s = shape.dim_size(0); \
- if (TF_PREDICT_FALSE(s != PySequence_Length(obj))) { \
+ Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, "")); \
+ if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) { \
return ErrorRectangular; \
} \
TensorShape rest = shape; \
rest.RemoveDim(0); \
for (int64 i = 0; i < s; ++i) { \
- const char* error = \
- FUNCTION##Helper(PySequence_GetItem(obj, i), rest, buf); \
+ const char* error = FUNCTION##Helper( \
+ PySequence_Fast_GET_ITEM(seq.get(), i), rest, buf); \
if (TF_PREDICT_FALSE(error != nullptr)) return error; \
} \
} else { \
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index b24711c8a6..0b3509360e 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -331,7 +331,7 @@ def multiply(x, y, name=None):
return gen_math_ops.mul(x, y, name)
-multiply.__doc__ = gen_math_ops.mul.__doc__.replace("Mul", "`tf.multiply`")
+multiply.__doc__ = gen_math_ops.mul.__doc__.replace("Multiply", "`tf.multiply`")
# TODO(aselle): put deprecation in after another round of global code changes
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 21eea3db25..af9dae2aa6 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -1049,6 +1049,22 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
y_val = sess.run(y)
self.assertAllEqual(y_val, [7, 9, 3, 4])
+ def testNHWCToHWNC(self):
+ x_val = [7, 4, 9, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, [4, 9, 7, 3])
+
+ def testHWNCToNHWC(self):
+ x_val = [7, 4, 9, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, [9, 7, 4, 3])
+
def testNHWCToNCHW2D(self):
x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
x = constant_op.constant(x_val)
@@ -1057,6 +1073,22 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
y_val = sess.run(y)
self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]])
+ def testNHWCToHWNC2D(self):
+ x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]])
+
+ def testHWNCToNHWC2D(self):
+ x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC")
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]])
+
def testNCHWToNHWC2D(self):
x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
x = constant_op.constant(x_val)
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 63f16c53a2..5415881cae 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -14,6 +14,7 @@ py_library(
name = "tools_pip",
deps = [
":freeze_graph",
+ ":import_pb_to_tensorboard",
":inspect_checkpoint",
":optimize_for_inference",
":print_selective_registration_header",
diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/BUILD b/tensorflow/tools/integration_tests/gcs_smoke_test/BUILD
deleted file mode 100755
index 0acc139df9..0000000000
--- a/tensorflow/tools/integration_tests/gcs_smoke_test/BUILD
+++ /dev/null
@@ -1,67 +0,0 @@
-package(default_visibility = ["//visibility:public"])
-
-load("@rbe_integration_test//skylark:integration_tests.bzl", "sut_component", "integration_test")
-load("@rbe_integration_test//skylark:toolchains.bzl", "toolchain_container_images")
-
-sut_component(
- name = "gcs",
- docker_image = toolchain_container_images()["tensorflow"],
- setups = [{
- "program": "setup.sh",
- "args": [
- "gs://tensorflow-test-bucket/tf-gcs-test",
- ],
- "output_properties": ["gcs_path"],
- "timeout_seconds": 100,
- }],
- teardowns = [{
- "program": "teardown.sh",
- "args": ["{gcs_path}"],
- "timeout_seconds": 100,
- }],
-)
-
-py_binary(
- name = "gcs_smoke",
- srcs = ["gcs_smoke.py"],
-)
-
-sh_binary(
- name = "test_wrapper",
- srcs = ["test_wrapper.sh"],
- data = [
- "gcs_smoke",
- ],
-)
-
-integration_test(
- name = "gcs_smoke_test",
- sut_deps = {
- ":gcs": "gcs",
- },
- tags = [
- "manual",
- "notap",
- ],
- test = {
- "program": ":test_wrapper",
- "args": [
- "--gcs_bucket_url={gcs#gcs_path}",
- "--num_examples=20",
- ],
- "timeout_seconds": 250,
- },
- test_docker_image = toolchain_container_images()["tensorflow"],
- test_type = "MultiMachine",
-)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/gcs_smoke.py b/tensorflow/tools/integration_tests/gcs_smoke_test/gcs_smoke.py
deleted file mode 100755
index 8438c2156c..0000000000
--- a/tensorflow/tools/integration_tests/gcs_smoke_test/gcs_smoke.py
+++ /dev/null
@@ -1,253 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Smoke test for reading records from GCS to TensorFlow."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-import time
-
-import numpy as np
-import tensorflow as tf
-from tensorflow.core.example import example_pb2
-from tensorflow.python.lib.io import file_io
-
-flags = tf.app.flags
-flags.DEFINE_string("gcs_bucket_url", "",
- "The URL to the GCS bucket in which the temporary "
- "tfrecord file is to be written and read, e.g., "
- "gs://my-gcs-bucket/test-directory")
-flags.DEFINE_integer("num_examples", 10, "Number of examples to generate")
-
-FLAGS = flags.FLAGS
-
-
-def create_examples(num_examples, input_mean):
- """Create ExampleProto's containing data."""
- ids = np.arange(num_examples).reshape([num_examples, 1])
- inputs = np.random.randn(num_examples, 1) + input_mean
- target = inputs - input_mean
- examples = []
- for row in range(num_examples):
- ex = example_pb2.Example()
- ex.features.feature["id"].bytes_list.value.append(str(ids[row, 0]))
- ex.features.feature["target"].float_list.value.append(target[row, 0])
- ex.features.feature["inputs"].float_list.value.append(inputs[row, 0])
- examples.append(ex)
- return examples
-
-
-def create_dir_test():
- """Verifies file_io directory handling methods."""
-
- # Test directory creation.
- starttime_ms = int(round(time.time() * 1000))
- dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms)
- print("Creating dir %s" % dir_name)
- file_io.create_dir(dir_name)
- elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
- print("Created directory in: %d milliseconds" % elapsed_ms)
-
- # Check that the directory exists.
- dir_exists = file_io.is_directory(dir_name)
- assert dir_exists
- print("%s directory exists: %s" % (dir_name, dir_exists))
-
- # Test recursive directory creation.
- starttime_ms = int(round(time.time() * 1000))
- recursive_dir_name = "%s/%s/%s" % (dir_name,
- "nested_dir1",
- "nested_dir2")
- print("Creating recursive dir %s" % recursive_dir_name)
- file_io.recursive_create_dir(recursive_dir_name)
- elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
- print("Created directory recursively in: %d milliseconds" % elapsed_ms)
-
- # Check that the directory exists.
- recursive_dir_exists = file_io.is_directory(recursive_dir_name)
- assert recursive_dir_exists
- print("%s directory exists: %s" % (recursive_dir_name, recursive_dir_exists))
-
- # Create some contents in the just created directory and list the contents.
- num_files = 10
- files_to_create = ["file_%d.txt" % n for n in range(num_files)]
- for file_num in files_to_create:
- file_name = "%s/%s" % (dir_name, file_num)
- print("Creating file %s." % file_name)
- file_io.write_string_to_file(file_name, "test file.")
-
- print("Listing directory %s." % dir_name)
- starttime_ms = int(round(time.time() * 1000))
- directory_contents = file_io.list_directory(dir_name)
- print(directory_contents)
- elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
- print("Listed directory %s in %s milliseconds" % (dir_name, elapsed_ms))
- assert set(directory_contents) == set(files_to_create + ["nested_dir1/"])
-
- # Test directory renaming.
- dir_to_rename = "%s/old_dir" % dir_name
- new_dir_name = "%s/new_dir" % dir_name
- file_io.create_dir(dir_to_rename)
- assert file_io.is_directory(dir_to_rename)
- assert not file_io.is_directory(new_dir_name)
-
- starttime_ms = int(round(time.time() * 1000))
- print("Will try renaming directory %s to %s" % (dir_to_rename, new_dir_name))
- file_io.rename(dir_to_rename, new_dir_name)
- elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
- print("Renamed directory %s to %s in %s milliseconds" % (
- dir_to_rename, new_dir_name, elapsed_ms))
- assert not file_io.is_directory(dir_to_rename)
- assert file_io.is_directory(new_dir_name)
-
- # Test Delete directory recursively.
- print("Deleting directory recursively %s." % dir_name)
- starttime_ms = int(round(time.time() * 1000))
- file_io.delete_recursively(dir_name)
- elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
- dir_exists = file_io.is_directory(dir_name)
- assert not dir_exists
- print("Deleted directory recursively %s in %s milliseconds" % (
- dir_name, elapsed_ms))
-
-
-def create_object_test():
- """Verifies file_io's object manipulation methods ."""
- starttime_ms = int(round(time.time() * 1000))
- dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms)
- print("Creating dir %s." % dir_name)
- file_io.create_dir(dir_name)
-
- num_files = 5
- # Create files of 2 different patterns in this directory.
- files_pattern_1 = ["%s/test_file_%d.txt" % (dir_name, n)
- for n in range(num_files)]
- files_pattern_2 = ["%s/testfile%d.txt" % (dir_name, n)
- for n in range(num_files)]
-
- starttime_ms = int(round(time.time() * 1000))
- files_to_create = files_pattern_1 + files_pattern_2
- for file_name in files_to_create:
- print("Creating file %s." % file_name)
- file_io.write_string_to_file(file_name, "test file creation.")
- elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
- print("Created %d files in %s milliseconds" %
- (len(files_to_create), elapsed_ms))
-
- # Listing files of pattern1.
- list_files_pattern = "%s/test_file*.txt" % dir_name
- print("Getting files matching pattern %s." % list_files_pattern)
- starttime_ms = int(round(time.time() * 1000))
- files_list = file_io.get_matching_files(list_files_pattern)
- elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
- print("Listed files in %s milliseconds" % elapsed_ms)
- print(files_list)
- assert set(files_list) == set(files_pattern_1)
-
- # Listing files of pattern2.
- list_files_pattern = "%s/testfile*.txt" % dir_name
- print("Getting files matching pattern %s." % list_files_pattern)
- starttime_ms = int(round(time.time() * 1000))
- files_list = file_io.get_matching_files(list_files_pattern)
- elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
- print("Listed files in %s milliseconds" % elapsed_ms)
- print(files_list)
- assert set(files_list) == set(files_pattern_2)
-
- # Test renaming file.
- file_to_rename = "%s/oldname.txt" % dir_name
- file_new_name = "%s/newname.txt" % dir_name
- file_io.write_string_to_file(file_to_rename, "test file.")
- assert file_io.file_exists(file_to_rename)
- assert not file_io.file_exists(file_new_name)
-
- print("Will try renaming file %s to %s" % (file_to_rename, file_new_name))
- starttime_ms = int(round(time.time() * 1000))
- file_io.rename(file_to_rename, file_new_name)
- elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
- print("File %s renamed to %s in %s milliseconds" % (
- file_to_rename, file_new_name, elapsed_ms))
- assert not file_io.file_exists(file_to_rename)
- assert file_io.file_exists(file_new_name)
-
- # Delete directory.
- print("Deleting directory %s." % dir_name)
- file_io.delete_recursively(dir_name)
-
-
-def main(argv):
- del argv # Unused.
- # Sanity check on the GCS bucket URL.
- if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
- print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
- sys.exit(1)
-
- # Verify that writing to the records file in GCS works.
- print("\n=== Testing writing and reading of GCS record file... ===")
- example_data = create_examples(FLAGS.num_examples, 5)
- with tf.python_io.TFRecordWriter(FLAGS.gcs_bucket_url) as hf:
- for e in example_data:
- hf.write(e.SerializeToString())
-
- print("Data written to: %s" % FLAGS.gcs_bucket_url)
-
- # Verify that reading from the tfrecord file works and that
- # tf_record_iterator works.
- record_iter = tf.python_io.tf_record_iterator(FLAGS.gcs_bucket_url)
- read_count = 0
- for _ in record_iter:
- read_count += 1
- print("Read %d records using tf_record_iterator" % read_count)
-
- if read_count != FLAGS.num_examples:
- print("FAIL: The number of records read from tf_record_iterator (%d) "
- "differs from the expected number (%d)" % (read_count,
- FLAGS.num_examples))
- sys.exit(1)
-
- # Verify that running the read op in a session works.
- print("\n=== Testing TFRecordReader.read op in a session... ===")
- with tf.Graph().as_default() as _:
- filename_queue = tf.train.string_input_producer([FLAGS.gcs_bucket_url],
- num_epochs=1)
- reader = tf.TFRecordReader()
- _, serialized_example = reader.read(filename_queue)
-
- with tf.Session() as sess:
- sess.run(tf.global_variables_initializer())
- sess.run(tf.local_variables_initializer())
- tf.train.start_queue_runners()
- index = 0
- for _ in range(FLAGS.num_examples):
- print("Read record: %d" % index)
- sess.run(serialized_example)
- index += 1
-
- # Reading one more record should trigger an exception.
- try:
- sess.run(serialized_example)
- print("FAIL: Failed to catch the expected OutOfRangeError while "
- "reading one more record than is available")
- sys.exit(1)
- except tf.errors.OutOfRangeError:
- print("Successfully caught the expected OutOfRangeError while "
- "reading one more record than is available")
-
- create_dir_test()
- create_object_test()
-
-if __name__ == "__main__":
- tf.app.run(main)
diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/setup.sh b/tensorflow/tools/integration_tests/gcs_smoke_test/setup.sh
deleted file mode 100755
index 6553ba5e30..0000000000
--- a/tensorflow/tools/integration_tests/gcs_smoke_test/setup.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-GCS_NUMBER=$(cat /dev/urandom | tr -dc 'A-F0-9' | fold -w 8 | head -n 1)
-GCS_PATH="$1"/"$GCS_NUMBER".tfrecord
-
-echo "gcs_path=$GCS_PATH" > "$_SETUP_OUTPUT"
-touch "$_SETUP_DONE"
diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/teardown.sh b/tensorflow/tools/integration_tests/gcs_smoke_test/teardown.sh
deleted file mode 100755
index 852486d167..0000000000
--- a/tensorflow/tools/integration_tests/gcs_smoke_test/teardown.sh
+++ /dev/null
@@ -1,26 +0,0 @@
-#!/bin/bash
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-GSUTIL_BIN="/var/gcloud/google-cloud-sdk/bin/gsutil"
-
-echo "Got teardown argument $1"
-
-if "${GSUTIL_BIN}" rm "$1"
-then
- echo "Cleaned up new tfrecord file in GCS: '$1'"
-else
- echo "FAIL: Unable to clean up new tfrecord file in GCS: '$1'"
- exit 1
-fi
diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/test_wrapper.sh b/tensorflow/tools/integration_tests/gcs_smoke_test/test_wrapper.sh
deleted file mode 100755
index d4b6524a81..0000000000
--- a/tensorflow/tools/integration_tests/gcs_smoke_test/test_wrapper.sh
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/bin/bash
-# This is a python2 only test.
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# Test Tensorflow package installation.
-/usr/local/bin/pip install --user tf-nightly
-
-# Test Tensorflow interaction with GCS.
-python tensorflow/tools/integration_tests/gcs_smoke_test/gcs_smoke.py "$@"
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 9b02b2f94c..2607b9d704 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -169,6 +169,7 @@ sh_binary(
"//tensorflow:windows_msvc": [":simple_console_for_windows"],
"//conditions:default": COMMON_PIP_DEPS + [
":simple_console",
+ "//tensorflow/contrib/lite/python:interpreter_test_data",
"//tensorflow/contrib/lite/toco:toco",
"//tensorflow/contrib/lite/toco/python:toco_wrapper",
"//tensorflow/contrib/lite/toco/python:toco_from_protos",
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index 73d759eb13..b66c45ec13 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -58,6 +58,10 @@ BLACKLIST = [
# contrib
"//tensorflow/contrib/session_bundle:session_bundle_half_plus_two",
"//tensorflow/contrib/keras:testing_utils",
+ "//tensorflow/contrib/lite/python:interpreter",
+ "//tensorflow/contrib/lite/python:interpreter_test",
+ "//tensorflow/contrib/lite/python:interpreter.py",
+ "//tensorflow/contrib/lite/python:interpreter_test.py",
"//tensorflow/contrib/ffmpeg:test_data",
"//tensorflow/contrib/factorization/examples:mnist",
"//tensorflow/contrib/factorization/examples:mnist.py",
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 754e70a3f2..f5b66edfec 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -481,11 +481,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/636e2230de961637b059b9cd15799daef32544f8.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/636e2230de961637b059b9cd15799daef32544f8.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/197b6c81959a17be37035d4fe71b382023bff2f0.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/197b6c81959a17be37035d4fe71b382023bff2f0.tar.gz",
],
- sha256 = "44f08a32ac48eca545fd6eac4d5ef3a9cea4382f805b87dce38340255e7d2138",
- strip_prefix = "llvm-636e2230de961637b059b9cd15799daef32544f8",
+ sha256 = "e77a8715fbd5d3c049bc7707da236152faab50ee2b7cec5234a0737b72ddb52a",
+ strip_prefix = "llvm-197b6c81959a17be37035d4fe71b382023bff2f0",
build_file = str(Label("//third_party/llvm:llvm.BUILD")),
)
@@ -703,16 +703,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
)
tf_http_archive(
- name = "rbe_integration_test",
- urls = [
- "http://mirror.bazel.build/github.com/google/rbe-integration-test/archive/78a6194c7dda200b9522cf07707e3bc695804d1e.tar.gz",
- "https://github.com/google/rbe-integration-test/archive/78a6194c7dda200b9522cf07707e3bc695804d1e.tar.gz",
- ],
- sha256 = "66d93b3919a165d486c31f5290d312abe9fda2685242f812c110653c124e1db4",
- strip_prefix = "rbe-integration-test-78a6194c7dda200b9522cf07707e3bc695804d1e",
- )
-
- tf_http_archive(
name = "arm_neon_2_x86_sse",
sha256 = "c8d90aa4357f8079d427e87a6f4c493da1fa4140aee926c05902d7ec1533d9a5",
strip_prefix = "ARM_NEON_2_x86_SSE-0f77d9d182265259b135dad949230ecbf1a2633d",