author     xiejw <xiejw0217@gmail.com>  2018-03-26 11:19:54 -0700
committer  GitHub <noreply@github.com>  2018-03-26 11:19:54 -0700
commit     8d1392f7c75e766ce1181136c7ca90e660350a09 (patch)
tree       91369ad40fbeec8eae18a9351396257ef1de1a01
parent     90e015a863246cd022ce7257204f0d716ac2d400 (diff)
parent     fb822667ca73e0d84d31fb5b1d03917db6b73095 (diff)
Merge pull request #18002 from xiejw/branch_190479555
Branch 190479555
-rw-r--r--  SECURITY.md | 19
-rw-r--r--  tensorflow/c/eager/BUILD | 3
-rw-r--r--  tensorflow/c/eager/c_api.cc | 301
-rw-r--r--  tensorflow/c/eager/c_api_internal.h | 85
-rw-r--r--  tensorflow/compiler/xla/client/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/client/client.cc | 51
-rw-r--r--  tensorflow/compiler/xla/client/client.h | 27
-rw-r--r--  tensorflow/compiler/xla/client/local_client.cc | 18
-rw-r--r--  tensorflow/compiler/xla/client/local_client.h | 9
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/BUILD | 14
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 700
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.h | 729
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc | 92
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_computation.cc | 26
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_computation.h | 55
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc | 24
-rw-r--r--  tensorflow/compiler/xla/literal_util.h | 8
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 24
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 51
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/compiler.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 16
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc | 46
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc | 34
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cse.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cse_test.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/local_service.cc | 150
-rw-r--r--  tensorflow/compiler/xla/service/local_service.h | 13
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover.cc | 29
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover_test.cc | 20
-rw-r--r--  tensorflow/compiler/xla/service/service.cc | 119
-rw-r--r--  tensorflow/compiler/xla/service/service.h | 21
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_simplifier.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h | 2
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD | 3
-rw-r--r--  tensorflow/compiler/xla/tests/axpy_simple_test.cc | 9
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.cc | 101
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.h | 91
-rw-r--r--  tensorflow/compiler/xla/tests/reshape_test.cc | 35
-rw-r--r--  tensorflow/contrib/boosted_trees/estimator_batch/BUILD | 17
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD | 4
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/bucketing_test.py | 25
-rw-r--r--  tensorflow/contrib/data/python/ops/grouping.py | 4
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py | 141
-rw-r--r--  tensorflow/contrib/eager/python/examples/gan/mnist.py | 21
-rw-r--r--  tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py | 4
-rw-r--r--  tensorflow/contrib/eager/python/examples/resnet50/resnet50.py | 43
-rw-r--r--  tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py | 6
-rw-r--r--  tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py | 8
-rw-r--r--  tensorflow/contrib/eager/python/examples/spinn/spinn_test.py | 24
-rw-r--r--  tensorflow/contrib/estimator/BUILD | 2
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/head.py | 7
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/head_test.py | 108
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/multi_head_test.py | 50
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py | 140
-rw-r--r--  tensorflow/contrib/kfac/python/ops/BUILD | 2
-rw-r--r--  tensorflow/contrib/kfac/python/ops/estimator.py | 314
-rw-r--r--  tensorflow/contrib/kfac/python/ops/optimizer.py | 157
-rw-r--r--  tensorflow/contrib/kfac/python/ops/placement.py | 167
-rw-r--r--  tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java | 197
-rw-r--r--  tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java | 209
-rw-r--r--  tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java | 54
-rw-r--r--  tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java | 176
-rw-r--r--  tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt | 1001
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py | 6
-rw-r--r--  tensorflow/contrib/lite/toco/BUILD | 37
-rw-r--r--  tensorflow/contrib/lite/toco/args.h | 7
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc | 4
-rw-r--r--  tensorflow/contrib/lite/toco/model_cmdline_flags.cc | 6
-rw-r--r--  tensorflow/contrib/lite/toco/toco.cc | 97
-rw-r--r--  tensorflow/contrib/lite/toco/toco_cmdline_flags.cc | 98
-rw-r--r--  tensorflow/contrib/lite/toco/toco_saved_model.cc | 186
-rw-r--r--  tensorflow/contrib/lite/toco/toco_saved_model.h | 53
-rw-r--r--  tensorflow/contrib/lite/toco/toco_saved_model_test.cc | 274
-rw-r--r--  tensorflow/contrib/py2tf/converters/call_trees.py | 6
-rw-r--r--  tensorflow/contrib/py2tf/converters/for_loops.py | 30
-rw-r--r--  tensorflow/contrib/py2tf/converters/lists.py | 3
-rw-r--r--  tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py | 16
-rw-r--r--  tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py | 17
-rw-r--r--  tensorflow/contrib/py2tf/utils/BUILD | 1
-rw-r--r--  tensorflow/contrib/py2tf/utils/__init__.py | 2
-rw-r--r--  tensorflow/contrib/py2tf/utils/builtins.py | 69
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 3
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_feed.py | 19
-rw-r--r--  tensorflow/contrib/training/python/training/hparam_test.py | 42
-rw-r--r--  tensorflow/core/BUILD | 1
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt | 2
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt | 43
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt | 43
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt | 43
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt | 43
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt | 43
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt | 2
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt | 2
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt | 60
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt | 60
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt | 2
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt | 2
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt | 2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt | 4
-rw-r--r--  tensorflow/core/common_runtime/constant_folding_test.cc | 1
-rw-r--r--  tensorflow/core/common_runtime/eager/BUILD | 43
-rw-r--r--  tensorflow/core/common_runtime/eager/copy_to_device_node.h | 69
-rw-r--r--  tensorflow/core/common_runtime/eager/tensor_handle.cc | 178
-rw-r--r--  tensorflow/core/common_runtime/eager/tensor_handle.h | 133
-rw-r--r--  tensorflow/core/framework/device_base.h | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD | 36
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 7
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc | 4
-rw-r--r--  tensorflow/core/grappler/optimizers/debug_stripper.cc | 36
-rw-r--r--  tensorflow/core/grappler/optimizers/debug_stripper.h | 43
-rw-r--r--  tensorflow/core/grappler/optimizers/debug_stripper_test.cc | 44
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.cc | 14
-rw-r--r--  tensorflow/core/grappler/utils/grappler_test.cc | 1
-rw-r--r--  tensorflow/core/kernels/immutable_constant_op_test.cc | 1
-rw-r--r--  tensorflow/core/kernels/resource_variable_ops.cc | 81
-rw-r--r--  tensorflow/core/kernels/scatter_functor.cc | 27
-rw-r--r--  tensorflow/core/kernels/scatter_functor.h | 170
-rw-r--r--  tensorflow/core/kernels/scatter_functor_gpu.cu.cc | 9
-rw-r--r--  tensorflow/core/kernels/scatter_functor_gpu.cu.h | 108
-rw-r--r--  tensorflow/core/kernels/scatter_op.cc | 126
-rw-r--r--  tensorflow/core/kernels/scatter_op_gpu.cu.cc | 9
-rw-r--r--  tensorflow/core/kernels/scatter_op_test.cc | 26
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt | 359
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 359
-rw-r--r--  tensorflow/core/ops/resource_variable_ops.cc | 92
-rw-r--r--  tensorflow/core/ops/state_ops.cc | 25
-rw-r--r--  tensorflow/core/platform/env_test.cc | 1
-rw-r--r--  tensorflow/core/platform/file_system.h | 68
-rw-r--r--  tensorflow/core/platform/file_system_test.cc | 1
-rw-r--r--  tensorflow/core/platform/null_file_system.h | 98
-rw-r--r--  tensorflow/core/protobuf/rewriter_config.proto | 2
-rw-r--r--  tensorflow/docs_src/api_guides/python/state_ops.md | 2
-rw-r--r--  tensorflow/docs_src/community/index.md | 3
-rw-r--r--  tensorflow/docs_src/community/leftnav_files | 1
-rw-r--r--  tensorflow/docs_src/community/security.md | 7
-rw-r--r--  tensorflow/docs_src/get_started/get_started_for_beginners.md | 34
-rw-r--r--  tensorflow/docs_src/get_started/index.md | 7
-rw-r--r--  tensorflow/docs_src/install/install_linux.md | 11
-rw-r--r--  tensorflow/docs_src/install/install_mac.md | 12
-rw-r--r--  tensorflow/docs_src/install/install_windows.md | 14
-rw-r--r--  tensorflow/docs_src/programmers_guide/embedding.md | 3
-rw-r--r--  tensorflow/go/op/wrappers.go | 2990
-rw-r--r--  tensorflow/python/BUILD | 3
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc | 7
-rw-r--r--  tensorflow/python/estimator/canned/head.py | 3
-rw-r--r--  tensorflow/python/framework/meta_graph.py | 4
-rw-r--r--  tensorflow/python/framework/meta_graph_test.py | 15
-rw-r--r--  tensorflow/python/framework/ops.py | 6
-rw-r--r--  tensorflow/python/framework/test_file_system.cc | 1
-rwxr-xr-x  tensorflow/python/keras/BUILD | 3
-rw-r--r--  tensorflow/python/kernel_tests/resource_variable_ops_test.py | 221
-rw-r--r--  tensorflow/python/kernel_tests/scatter_ops_test.py | 145
-rw-r--r--  tensorflow/python/lib/core/ndarray_tensor.cc | 4
-rw-r--r--  tensorflow/python/lib/core/py_func.cc | 2
-rw-r--r--  tensorflow/python/ops/distributions/distribution.py | 32
-rw-r--r--  tensorflow/python/ops/standard_ops.py | 2
-rw-r--r--  tensorflow/python/ops/state_ops.py | 2
-rw-r--r--  tensorflow/python/training/device_util.py | 68
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc | 8
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.pbtxt | 8
-rwxr-xr-x  tensorflow/tools/ci_build/builds/android.sh | 3
-rwxr-xr-x  tensorflow/tools/ci_build/builds/android_full.sh | 6
-rw-r--r--  tensorflow/tools/docs/generate_lib.py | 1
-rw-r--r--  third_party/examples/eager/spinn/spinn.py | 168
176 files changed, 10170 insertions, 3021 deletions
diff --git a/SECURITY.md b/SECURITY.md
index 378e776967..5ca304404d 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -168,7 +168,18 @@ below).
Please use a descriptive subject line for your report email. After the initial
reply to your report, the security team will endeavor to keep you informed of
-the progress being made towards a fix and announcement.
+the progress being made towards a fix and announcement.
+
+In addition, please include the following information along with your report:
+
+* Your name and affiliation (if any).
+* A description of the technical details of the vulnerabilities. It is very
+ important to let us know how we can reproduce your findings.
+* An explanation of who can exploit this vulnerability, and what they gain
+ when doing so -- write an attack scenario. This will help us evaluate your
+ report quickly, especially if the issue is complex.
+* Whether this vulnerability is public or known to third parties. If it is,
+ please provide details.
If you believe that an existing (public) issue is security-related, please send
an email to `security@tensorflow.org`. The email should include the issue ID and
@@ -233,7 +244,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
### Known vulnerabilities
-| Type | Versions affected | Reported by | Additional Information |
-|-------------------|:-----------------:|--------------------|-----------------------------|
-| out of bounds read| <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
+| Type | Versions affected | Reported by | Additional Information |
+|--------------------|:-----------------:|--------------------|-----------------------------|
+| Out Of Bounds Read | <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index bea5a121b3..8df7b56623 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -31,6 +31,8 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
+ "//tensorflow/core/common_runtime/eager:tensor_handle",
+ "//tensorflow/core/common_runtime/eager:copy_to_device_node",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@@ -68,6 +70,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
+ "//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 2402a6d044..eaeb2fd07a 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -161,29 +162,32 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
DCHECK(h);
- h->Unref();
+ if (h->handle) {
+ h->handle->Unref();
+ }
+ delete h;
}
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
- return static_cast<TF_DataType>(h->dtype);
+ return static_cast<TF_DataType>(h->handle->dtype);
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
const tensorflow::Tensor* t = nullptr;
- status->status = h->Tensor(&t);
+ status->status = h->handle->Tensor(&t);
return t == nullptr ? 0 : t->dims();
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
const tensorflow::Tensor* t = nullptr;
- status->status = h->Tensor(&t);
+ status->status = h->handle->Tensor(&t);
return t == nullptr ? 0 : t->dim_size(dim_index);
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
tensorflow::Device* d = nullptr;
- status->status = h->OpDevice(&d);
+ status->status = h->handle->OpDevice(&d);
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
@@ -193,7 +197,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
tensorflow::Device* d = nullptr;
tensorflow::Device* op_device = nullptr;
const tensorflow::Tensor* t = nullptr;
- status->status = h->TensorAndDevice(&t, &d, &op_device);
+ status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
if (!status->status.ok()) return nullptr;
if (!IsCPU(d)) {
TF_SetStatus(status, TF_UNIMPLEMENTED,
@@ -210,82 +214,6 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
}
} // extern "C"
-namespace {
-
-tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h,
- TFE_Context* ctx,
- tensorflow::Device* dstd,
- TFE_TensorHandle** output) {
- const tensorflow::Tensor* src = nullptr;
- tensorflow::Device* srcd = nullptr;
- // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept
- // nullptr.
- tensorflow::Device* src_opd = nullptr;
- TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd));
- if (srcd == nullptr) srcd = ctx->context.HostCPU();
- bool is_same_device =
- (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
- const bool dst_cpu = IsCPU(dstd);
- const bool src_cpu = IsCPU(srcd);
- // both_on_cpu can be true and yet is_same_device is false, if one of src/dst
- // has device type XLA_CPU, and the other CPU.
- const bool both_on_cpu = src_cpu && dst_cpu;
- if (is_same_device || both_on_cpu) {
- dstd = dst_cpu ? nullptr : dstd;
- *output = new TFE_TensorHandle(*src, dstd, dstd);
- return tensorflow::Status::OK();
- }
- if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
- !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
- return tensorflow::errors::InvalidArgument(
- "Can't copy Tensor with type ",
- tensorflow::DataTypeString(src->dtype()), " to device ",
- DeviceName(dstd), ".");
- }
- tensorflow::AllocatorAttributes attr;
- if (src->dtype() == tensorflow::DT_VARIANT) {
- attr.set_on_host(true);
- }
- tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
- if (src->shape().num_elements() == 0) {
- dstd = dst_cpu ? nullptr : dstd;
- *output = new TFE_TensorHandle(dst, dstd, dstd);
- return tensorflow::Status::OK();
- }
- tensorflow::DeviceContext* src_device_context = nullptr;
- if (!src_cpu) {
- src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
- }
- tensorflow::DeviceContext* dst_device_context = nullptr;
- if (!dst_cpu) {
- dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
- }
- // TODO(ashankar): The Sync() call below may be more aggressive than
- // necessary. It is based on knowledge of implementation details - that
- // GPU devices are implemented using 3 streams - one for host->device copies,
- // one for device->host copies and one for sending operations to the GPU.
- // With that setup, Sync()ing across all 3 streams should be sufficient
- // but more than necessary (since it waits for operations that might have
- // nothing to do with this tensor to complete).
- TF_RETURN_IF_ERROR(srcd->Sync());
- tensorflow::Notification n;
- tensorflow::Status status;
- tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
- srcd, dstd, tensorflow::AllocatorAttributes(),
- tensorflow::AllocatorAttributes(), src, &dst,
- [&status, &n](const tensorflow::Status& s) {
- status = s;
- n.Notify();
- });
- n.WaitForNotification();
- if (status.ok()) {
- dstd = dst_cpu ? nullptr : dstd;
- *output = new TFE_TensorHandle(dst, dstd, dstd);
- }
- return status;
-}
-} // namespace
-
extern "C" {
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
@@ -335,12 +263,12 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
tensorflow::Device* d = nullptr;
// TODO(agarwal): This call may block if h is not ready. Avoid this if
// possible.
- status->status = h->Device(&d);
+ status->status = h->handle->Device(&d);
if (!status->status.ok()) return;
if (!IsCPU(d)) op->device = d;
}
- h->Ref();
- op->inputs.push_back(h);
+ h->handle->Ref();
+ op->inputs.push_back(h->handle);
op->attrs.NumInputs(op->inputs.size());
}
@@ -506,6 +434,37 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
namespace {
+// TODO(apassos) move to TensorHandle
+tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal(
+ tensorflow::TensorHandle* h, TFE_Context* ctx, const char* device_name,
+ TF_Status* status) {
+ status->status = ctx->context.GetStatus();
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+ tensorflow::Device* dstd = ctx->context.HostCPU();
+ if (device_name != nullptr && strlen(device_name) > 0) {
+ status->status =
+ ctx->context.device_mgr()->LookupDevice(device_name, &dstd);
+ if (!status->status.ok()) return nullptr;
+ }
+ if (ctx->context.Async()) {
+ // Note that `h` may not be currently ready. However execution order will
+ // make sure that `h` is ready before the copy is actually done.
+ tensorflow::CopyToDeviceNode* node =
+ new tensorflow::CopyToDeviceNode(h, dstd, &ctx->context);
+ tensorflow::TensorHandle* output = node->dst();
+ // Note that calling Add makes `node` accessible by the EagerExecutor
+ // thread. So further accesses need to be thread-safe.
+ ctx->context.ExecutorAdd(node);
+ return output;
+ } else {
+ tensorflow::TensorHandle* output = nullptr;
+ status->status = h->CopyToDevice(&ctx->context, dstd, &output);
+ return output;
+ }
+}
+
tensorflow::Status ValidateInputTypeAndPlacement(
TFE_Context* ctx, tensorflow::Device* host_device,
tensorflow::Device* op_device, TFE_Op* op,
@@ -518,7 +477,7 @@ tensorflow::Status ValidateInputTypeAndPlacement(
for (int i = 0; i < op->inputs.size(); ++i) {
const tensorflow::Device* expected_device =
memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
- TFE_TensorHandle* handle = op->inputs[i];
+ tensorflow::TensorHandle* handle = op->inputs[i];
tensorflow::Device* handle_device = nullptr;
TF_RETURN_IF_ERROR(handle->Device(&handle_device));
const tensorflow::Device* actual_device =
@@ -560,8 +519,9 @@ tensorflow::Status ValidateInputTypeAndPlacement(
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
TF_Status* s = TF_NewStatus();
- TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
- handle, ctx, expected_device->name().c_str(), s);
+ tensorflow::TensorHandle* copied_tensor =
+ TFE_TensorHandleCopyToDevice_Internal(
+ handle, ctx, expected_device->name().c_str(), s);
tensorflow::Status status = s->status;
TF_DeleteStatus(s);
if (!status.ok()) {
@@ -616,9 +576,10 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
tensorflow::Status Execute(
TFE_Context* ctx, tensorflow::Device* device,
- const tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>& op_inputs,
+ const tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>&
+ op_inputs,
tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats,
- TFE_TensorHandle** retvals, int num_retvals) {
+ tensorflow::TensorHandle** retvals, int num_retvals) {
if (!ctx->context.SoftPlacement() && device == nullptr) {
device = ctx->context.HostCPU();
}
@@ -683,7 +644,7 @@ tensorflow::Status Execute(
d = nullptr;
}
if (retvals[i] == nullptr) {
- retvals[i] = new TFE_TensorHandle(outputs[i], d, op_device);
+ retvals[i] = new tensorflow::TensorHandle(outputs[i], d, op_device);
} else {
retvals[i]->SetTensorAndDevice(outputs[i], d, op_device);
}
@@ -711,9 +672,10 @@ class ExecuteNode : public tensorflow::EagerNode {
}
TFE_Context* ctx = op->ctx;
for (int i = 0; i < num_retvals; ++i) {
- TFE_TensorHandle* h = new TFE_TensorHandle(id, output_dtypes[i], ctx);
+ tensorflow::TensorHandle* h =
+ new tensorflow::TensorHandle(id, output_dtypes[i], &ctx->context);
h->Ref();
- retvals[i] = h;
+ retvals[i] = new TFE_TensorHandle(h);
retvals_[i] = h;
}
}
@@ -745,54 +707,12 @@ class ExecuteNode : public tensorflow::EagerNode {
private:
TFE_Context* ctx_;
tensorflow::Device* op_device_;
- tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> inputs_;
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_;
tensorflow::KernelAndDevice* kernel_;
std::unique_ptr<tensorflow::NodeExecStats> maybe_stats_;
- tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals_;
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals_;
};
-class CopyToDeviceNode : public tensorflow::EagerNode {
- public:
- CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd,
- TFE_Context* ctx)
- : tensorflow::EagerNode(ctx->context.NextId()),
- src_(src),
- dstd_(dstd),
- ctx_(ctx),
- dst_(new TFE_TensorHandle(id, src_->dtype, ctx)) {
- src_->Ref();
- dst_->Ref();
- }
-
- ~CopyToDeviceNode() override {
- src_->Unref();
- dst_->Unref();
- }
-
- tensorflow::Status Run() override {
- TFE_TensorHandle* temp = nullptr;
- TF_RETURN_IF_ERROR(TensorHandleCopyToDevice(src_, ctx_, dstd_, &temp));
- const tensorflow::Tensor* tensor = nullptr;
- tensorflow::Device* device = nullptr;
- tensorflow::Device* op_device = nullptr;
- tensorflow::Status status =
- temp->TensorAndDevice(&tensor, &device, &op_device);
- // `temp` is a ready handle. So the following call should return OK.
- TF_DCHECK_OK(status) << status.error_message();
- DCHECK(tensor);
- dst_->SetTensorAndDevice(*tensor, device, op_device);
- temp->Unref();
- return tensorflow::Status::OK();
- }
-
- TFE_TensorHandle* dst() { return dst_; }
-
- private:
- TFE_TensorHandle* src_;
- tensorflow::Device* dstd_;
- TFE_Context* ctx_;
- TFE_TensorHandle* dst_;
-};
#ifdef TENSORFLOW_EAGER_USE_XLA
// Synthesizes and returns a wrapper function over `op`, which must be a
@@ -917,7 +837,7 @@ const tensorflow::FunctionDef* OpToFunction(
}
VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();
- ctx->context.AddFunctionDef(fdef);
+ status->status = ctx->context.AddFunctionDef(fdef);
if (!status->status.ok()) return nullptr;
const auto ret = ctx->context.FindFunctionDef(signature->name());
DCHECK(ret != nullptr);
@@ -965,7 +885,7 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
// Since input param reordering may have occurred between `op` and `launch_op`
// via `op_input_to_func_input`, adjust the actual inputs accordingly.
launch_op->inputs = op->inputs;
- for (TFE_TensorHandle* h : launch_op->inputs) {
+ for (tensorflow::TensorHandle* h : launch_op->inputs) {
h->Ref();
}
if (!op_input_to_func_input.empty()) {
@@ -1140,11 +1060,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
} else {
// Execute checks if retvals[i] is nullptr or not to figure if it needs to
// allocate it.
+ std::vector<tensorflow::TensorHandle*> handle_retvals(*num_retvals,
+ nullptr);
+ status->status =
+ Execute(op->ctx, op->device, op->inputs, kernel, maybe_stats.get(),
+ handle_retvals.data(), *num_retvals);
for (int i = 0; i < *num_retvals; ++i) {
- retvals[i] = nullptr;
+ retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
}
- status->status = Execute(op->ctx, op->device, op->inputs, kernel,
- maybe_stats.get(), retvals, *num_retvals);
}
}
@@ -1152,30 +1075,12 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status) {
- status->status = ctx->context.GetStatus();
- if (!status->status.ok()) {
- return nullptr;
- }
- tensorflow::Device* dstd = ctx->context.HostCPU();
- if (device_name != nullptr && strlen(device_name) > 0) {
- status->status =
- ctx->context.device_mgr()->LookupDevice(device_name, &dstd);
- if (!status->status.ok()) return nullptr;
- }
- if (ctx->context.Async()) {
- // Note that `h` may not be currently ready. However execution order will
- // make sure that `h` is ready before the copy is actually done.
- CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
- TFE_TensorHandle* output = node->dst();
- // Note that calling Add makes `node` accessible by the EagerExecutor
- // thread. So further accesses need to be thread-safe.
- ctx->context.ExecutorAdd(node);
- return output;
- } else {
- TFE_TensorHandle* output = nullptr;
- status->status = TensorHandleCopyToDevice(h, ctx, dstd, &output);
- return output;
+ tensorflow::TensorHandle* handle = TFE_TensorHandleCopyToDevice_Internal(
+ h->handle, ctx, device_name, status);
+ if (status->status.ok()) {
+ return new TFE_TensorHandle(handle);
}
+ return nullptr;
}
void TFE_ContextAddFunctionDef(TFE_Context* ctx,
@@ -1214,7 +1119,7 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
tensorflow::Device* d = nullptr;
tensorflow::Device* op_device = nullptr;
const tensorflow::Tensor* t = nullptr;
- status->status = h->TensorAndDevice(&t, &d, &op_device);
+ status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
if (!status->status.ok()) return nullptr;
if (d != nullptr) {
status->status = tensorflow::errors::FailedPrecondition(
@@ -1306,70 +1211,8 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
} // namespace tensorflow
-
-bool TFE_TensorHandle::IsReady() {
- if (node_id == 0) return true;
- tensorflow::mutex_lock l(ctx_mutex_);
- return ctx_ == nullptr;
-}
-
-tensorflow::Status TFE_TensorHandle::WaitReady() {
- if (node_id == 0) return tensorflow::Status::OK();
- tensorflow::EagerExecutor* executor = nullptr;
- {
- tensorflow::mutex_lock l(ctx_mutex_);
- if (ctx_ == nullptr) return tensorflow::Status::OK();
- executor = ctx_->context.Executor();
- }
- return executor->WaitFor(node_id);
-}
-
-tensorflow::Status TFE_TensorHandle::Tensor(const tensorflow::Tensor** t) {
- TF_RETURN_IF_ERROR(WaitReady());
- DCHECK(IsReady());
- *t = &tensor_;
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status TFE_TensorHandle::Device(tensorflow::Device** d) {
- TF_RETURN_IF_ERROR(WaitReady());
- DCHECK(IsReady());
- *d = device_;
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status TFE_TensorHandle::OpDevice(tensorflow::Device** d) {
- TF_RETURN_IF_ERROR(WaitReady());
- DCHECK(IsReady());
- *d = op_device_;
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status TFE_TensorHandle::TensorAndDevice(
- const tensorflow::Tensor** tensor, tensorflow::Device** device,
- tensorflow::Device** op_device) {
- TF_RETURN_IF_ERROR(WaitReady());
- DCHECK(IsReady());
- *tensor = &tensor_;
- *device = device_;
- *op_device = op_device_;
- return tensorflow::Status::OK();
-}
-
-void TFE_TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor,
- tensorflow::Device* device,
- tensorflow::Device* op_device) {
- tensorflow::mutex_lock l(ctx_mutex_);
- DCHECK(node_id > 0 && ctx_) << "SetTensorAndDevice should be only called "
- << "on non-ready handles.";
- ctx_ = nullptr;
- tensor_ = tensor;
- device_ = device;
- op_device_ = op_device;
-}
-
TFE_Op::~TFE_Op() {
- for (TFE_TensorHandle* h : inputs) {
+ for (tensorflow::TensorHandle* h : inputs) {
h->Unref();
}
}
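
[Editor's note] The c_api.cc hunks above turn TFE_TensorHandle into a thin C-API wrapper that owns one reference to a core tensorflow::TensorHandle: TFE_DeleteTensorHandle now unrefs the inner handle and frees only the wrapper, and every accessor forwards through h->handle. A minimal, self-contained sketch of that wrap/unref pattern, with hypothetical names standing in for the TensorFlow types:

    #include <cassert>

    // Stand-in for tensorflow::core::RefCounted (not thread-safe; sketch only).
    class RefCounted {
     public:
      void Ref() { ++refs_; }
      void Unref() {
        if (--refs_ == 0) delete this;
      }

     protected:
      virtual ~RefCounted() = default;

     private:
      int refs_ = 1;  // a new object starts with one reference
    };

    // Stand-in for tensorflow::TensorHandle.
    class CoreHandle : public RefCounted {};

    // Stand-in for TFE_TensorHandle: a plain struct, not ref-counted itself.
    struct CApiHandle {
      explicit CApiHandle(CoreHandle* h) : handle(h) {}
      CoreHandle* handle;
    };

    // Mirrors the new TFE_DeleteTensorHandle: drop the wrapper's reference on
    // the core object, then free the wrapper.
    void CApiDeleteHandle(CApiHandle* h) {
      assert(h);
      if (h->handle) h->handle->Unref();
      delete h;
    }

    int main() {
      CApiHandle* h = new CApiHandle(new CoreHandle);
      CApiDeleteHandle(h);  // releases both wrapper and core handle
      return 0;
    }

The real code layers async readiness and device bookkeeping on top of this; the ownership rule is the part these hunks change.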
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 5b29120b40..e6d2ab75ff 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
@@ -67,84 +68,18 @@ struct TFE_Context {
tensorflow::EagerContext context;
};
-struct TFE_TensorHandle : public tensorflow::core::RefCounted {
- public:
+struct TFE_TensorHandle {
TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d,
tensorflow::Device* op_device)
- : dtype(t.dtype()),
- node_id(0),
- tensor_(t),
- device_(d),
- op_device_(op_device),
- ctx_(nullptr) {}
+ : handle(new tensorflow::TensorHandle(t, d, op_device)) {}
TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype,
- TFE_Context* ctx)
- : dtype(dtype),
- node_id(node_id),
- tensor_(dtype),
- device_(nullptr),
- op_device_(nullptr),
- ctx_(ctx) {
- DCHECK_GT(node_id, 0);
- }
-
- ~TFE_TensorHandle() override {}
-
- tensorflow::Status Tensor(const tensorflow::Tensor** t);
-
- tensorflow::Status Device(tensorflow::Device** d);
-
- tensorflow::Status OpDevice(tensorflow::Device** d);
-
- tensorflow::Status TensorAndDevice(const tensorflow::Tensor** tensor,
- tensorflow::Device** device,
- tensorflow::Device** op_device);
-
- // Note that this can be called at most once, and only on non-ready handles,
- // and makes them ready.
- void SetTensorAndDevice(const tensorflow::Tensor& tensor,
- tensorflow::Device* device,
- tensorflow::Device* op_device);
-
- // dtype for the handle. It must be the same as t.dtype() once the handle is
- // ready.
- const tensorflow::DataType dtype;
-
- private:
- // If the contents of the Tensor pointed to by this handle is yet to be
-  // computed by an EagerNode, this function will block till that computation
-  // is done and the handle is "ready".
- tensorflow::Status WaitReady();
-
- bool IsReady();
-
- // Id for the EagerNode that will compute the value pointed to by this handle.
- // If the value is 0, the handle is already ready, but not vice-versa.
- const tensorflow::uint64 node_id;
-
- tensorflow::Tensor tensor_;
-
- // TODO(ashankar): device_ == nullptr iff local CPU
- // This was expedient, but perhaps worth revisiting ('device_' should always
- // be a valid pointer?)
- // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
- // provided with the appropriate TFE_Context.
- //
- // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
- // TFE_TensorHandle does not outlive the TFE_Context from which it came?
- tensorflow::Device* device_;
-
- // Device in which the op producing this tensor was executed. Equals to
- // device_ for constant tensors.
- tensorflow::Device* op_device_;
-
- tensorflow::mutex ctx_mutex_;
-
- // `ctx` is only guaranteed to be set if the handle is not "ready". This is
- // typically true when the handle was produced during async execution.
- // `ctx` object is not owned and should outlive this handle.
- TFE_Context* ctx_ GUARDED_BY(ctx_mutex_);
+ tensorflow::EagerContext* ctx)
+ : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {}
+
+ TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {}
+
+ tensorflow::TensorHandle* handle;
};
struct TFE_Op {
@@ -161,7 +96,7 @@ struct TFE_Op {
const tensorflow::string name;
tensorflow::AttrBuilder attrs;
const tensorflow::AttrTypeMap* attr_types;
- tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> inputs;
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs;
tensorflow::Device* device;
bool use_xla = false;
};
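
[Editor's note] The members deleted above (node_id, ctx_, WaitReady, IsReady, SetTensorAndDevice) implemented a handle that may not yet be "ready" under async execution: readers block until the producing EagerNode has run, and the producer publishes the tensor exactly once. That protocol now lives in tensorflow::TensorHandle (tensor_handle.h in this same change). A minimal sketch of the same ready/not-ready pattern, using a condition variable where the real code waits on the EagerExecutor by node id; all names here are hypothetical:

    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <thread>

    class AsyncHandle {
     public:
      // Producer side (the EagerNode): publish the value once, making the
      // handle ready. Mirrors SetTensorAndDevice.
      void SetValue(int v) {
        {
          std::lock_guard<std::mutex> l(mu_);
          value_ = v;
          ready_ = true;  // plays the role of "ctx_ = nullptr" in the old code
        }
        cv_.notify_all();
      }

      // Consumer side: block until the producer has run, then read.
      // Mirrors WaitReady() followed by Tensor()/Device().
      int Value() {
        std::unique_lock<std::mutex> l(mu_);
        cv_.wait(l, [this] { return ready_; });
        return value_;
      }

     private:
      std::mutex mu_;
      std::condition_variable cv_;
      bool ready_ = false;
      int value_ = 0;
    };

    int main() {
      AsyncHandle h;
      std::thread producer([&] { h.SetValue(42); });
      std::printf("%d\n", h.Value());  // blocks until the producer publishes
      producer.join();
      return 0;
    }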
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 02356699a2..5094e5ce67 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -74,6 +74,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index d15ccb0c28..5ce3c45528 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -177,6 +177,22 @@ StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
return Transfer(*data, shape_with_output_layout);
}
+StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const ExecutionOptions* execution_options,
+ ExecutionProfile* execution_profile) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<GlobalData> data,
+ Execute(computation, arguments, execution_options, execution_profile));
+
+ const Shape* shape_with_output_layout = nullptr;
+ if (execution_options && execution_options->has_shape_with_output_layout()) {
+ shape_with_output_layout = &execution_options->shape_with_output_layout();
+ }
+ return Transfer(*data, shape_with_output_layout);
+}
+
StatusOr<Computation> Client::LoadSnapshot(const SessionModule& module) {
LoadComputationSnapshotRequest request;
*request.mutable_module() = module;
@@ -231,6 +247,41 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
return MakeUnique<GlobalData>(stub_, response.output());
}
+StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const ExecutionOptions* execution_options,
+ ExecutionProfile* execution_profile) {
+ ExecuteGraphRequest request;
+ *request.mutable_computation() = computation.proto();
+
+ if (execution_options == nullptr) {
+ *request.mutable_execution_options() = CreateDefaultExecutionOptions();
+ } else {
+ *request.mutable_execution_options() = *execution_options;
+ }
+ for (GlobalData* argument : arguments) {
+ CHECK(argument != nullptr) << "Argument pointers must not be null.";
+ *request.add_arguments() = argument->handle();
+ }
+
+ ExecuteResponse response;
+ VLOG(1) << "making execute request: " << request.ShortDebugString();
+ Status s = stub_->ExecuteGraph(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (execution_profile != nullptr) {
+ *execution_profile = response.profile();
+ // TODO(b/74197823): Get execution stats for the graph and VLOG(1) them.
+ }
+
+ return MakeUnique<GlobalData>(stub_, response.output());
+}
+
StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
tensorflow::gtl::ArraySlice<ComputationInstance> computations) {
ExecuteParallelRequest request;
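
[Editor's note] The new Client::Execute overload above sends the computation as a proto via ExecuteGraphRequest/ExecuteGraph rather than the old Computation-handle path. A minimal usage sketch, assuming a connected xla::Client* and using only XlaBuilder ops this change implements (ConstantLiteral, Add, Build); the helper name is hypothetical:

    #include "tensorflow/compiler/xla/client/client.h"
    #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/compiler/xla/status_macros.h"

    // Builds a computation that adds two scalar constants, runs it through the
    // new ExecuteGraph path, and fetches the result back as a Literal.
    xla::StatusOr<std::unique_ptr<xla::Literal>> AddConstants(
        xla::Client* client) {
      xla::XlaBuilder builder("add_constants");
      xla::XlaOp x =
          builder.ConstantLiteral(*xla::Literal::CreateR0<float>(1.0f));
      xla::XlaOp y =
          builder.ConstantLiteral(*xla::Literal::CreateR0<float>(2.0f));
      builder.Add(x, y);  // the last valid instruction becomes the root
      TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, builder.Build());
      // Default execution options; no GlobalData arguments for constants.
      return client->ExecuteAndTransfer(computation, /*arguments=*/{});
    }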
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index c28380b689..ec87646ebf 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service_interface.h"
@@ -57,6 +58,21 @@ class Client {
const ExecutionOptions* execution_options = nullptr,
ExecutionProfile* execution_profile = nullptr);
+ // Executes the computation with the given arguments and returns the global
+ // data that was produced from the execution.
+ // * If execution_options is not nullptr, these options are passed to the
+ // service to affect how it compiles our computation. (The pointer does not
+ // need to live beyond this call.)
+ // * If execution_profile is not nullptr then the pointed-to ExecutionProfile
+ // will be filled with profile data from the execution.
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+ StatusOr<std::unique_ptr<GlobalData>> Execute(
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const ExecutionOptions* execution_options = nullptr,
+ ExecutionProfile* execution_profile = nullptr);
+
// A struct to represent a computation instance to be executed.
// * If execution_options.device_handles is not empty, the computation is
// executed on the devices associated with the handles by partitioning the
@@ -137,6 +153,17 @@ class Client {
const ExecutionOptions* execution_options = nullptr,
ExecutionProfile* execution_profile = nullptr);
+ // Executes the computation with the given arguments and transfers the result
+ // to the client as a literal. Parameters are defined the same as for
+ // Execute() and Transfer().
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+ StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const ExecutionOptions* execution_options = nullptr,
+ ExecutionProfile* execution_profile = nullptr);
+
// Unregister the memory for the given GlobalData on the device.
Status Unregister(const GlobalData& data);
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 91396f055f..30594243dc 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -265,6 +265,24 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
updated_options));
}
+StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
+ const XlaComputation& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const ExecutableBuildOptions& options) {
+ ExecutableBuildOptions updated_options = options;
+ if (options.device_ordinal() == -1) {
+ updated_options.set_device_ordinal(default_device_ordinal());
+ VLOG(3) << "Set device ordinal to default value of: "
+ << updated_options.device_ordinal();
+ }
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+ local_service_->CompileExecutable(
+ computation, argument_layouts, updated_options));
+ return WrapUnique(new LocalExecutable(std::move(executable),
+ local_service_->mutable_backend(),
+ updated_options));
+}
+
StatusOr<std::unique_ptr<ScopedShapedBuffer>>
LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal,
DeviceMemoryAllocator* allocator) {
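
[Editor's note] A matching ahead-of-time sketch for the new LocalClient::Compile overload above, assuming a LocalClient* obtained elsewhere (for example from ClientLibrary::LocalClientOrDie()); one argument layout is supplied per computation parameter, and the device ordinal defaults as shown in the hunk:

    #include "tensorflow/compiler/xla/client/local_client.h"
    #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"
    #include "tensorflow/compiler/xla/status_macros.h"

    // Compiles x + y over f32 scalars into a LocalExecutable.
    xla::StatusOr<std::unique_ptr<xla::LocalExecutable>> CompileScalarAdd(
        xla::LocalClient* client) {
      xla::XlaBuilder builder("scalar_add");
      const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
      xla::XlaOp x = builder.Parameter(0, scalar, "x");
      xla::XlaOp y = builder.Parameter(1, scalar, "y");
      builder.Add(x, y);
      TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, builder.Build());
      // One Shape* per parameter, in parameter-number order.
      return client->Compile(computation, {&scalar, &scalar},
                             xla::ExecutableBuildOptions());
    }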
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index de0ed13c43..98ee7c62c9 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -123,6 +123,15 @@ class LocalClient : public Client {
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const ExecutableBuildOptions& options);
+ // Build and return a LocalExecutable object. The executable is compiled using
+ // the given XlaComputation, argument layouts and options.
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+ StatusOr<std::unique_ptr<LocalExecutable>> Compile(
+ const XlaComputation& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const ExecutableBuildOptions& options);
+
// Copy the literal data to the device with the given ordinal and return as a
// ScopedShapedBuffer. If non-null the given memory allocator is used for
// device memory allocation. If null, the default memory allocator for the
diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD
index b912889e26..cc5f551c9c 100644
--- a/tensorflow/compiler/xla/client/xla_client/BUILD
+++ b/tensorflow/compiler/xla/client/xla_client/BUILD
@@ -25,12 +25,25 @@ filegroup(
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+cc_library(
+ name = "xla_computation",
+ srcs = ["xla_computation.cc"],
+ hdrs = ["xla_computation.h"],
+ deps = [
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
# TODO(b/74197823): Replace computation_builder with xla_builder.
cc_library(
name = "xla_builder",
srcs = ["xla_builder.cc"],
hdrs = ["xla_builder.h"],
deps = [
+ ":xla_computation",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -38,6 +51,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:shape_inference",
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 82b61d4d51..596f39b4fd 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include <numeric>
#include <string>
#include <utility>
@@ -80,40 +81,32 @@ void XlaBuilder::NoteError(const Status& error) {
}
}
-StatusOr<XlaComputation> XlaBuilder::Build() {
- if (!first_error_.ok()) {
- string backtrace;
- first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
- return AppendStatus(first_error_, backtrace);
- }
-
- HloComputationProto entry;
- ProgramShape* program_shape = entry.mutable_program_shape();
-
- entry.set_name(name_);
+StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) {
+ TF_RET_CHECK(root_id != nullptr);
+ ProgramShape program_shape;
// Not all instructions can be roots. Walk backwards from the last added
// instruction until a valid root is found.
- entry.set_root_id(-1);
- for (int64 i = instructions_.size() - 1; i >= 0; i--) {
+ int64 index = instructions_.size() - 1;
+ for (; index >= 0; index--) {
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
- StringToHloOpcode(instructions_[i].opcode()));
+ StringToHloOpcode(instructions_[index].opcode()));
if (CanBeRoot(opcode)) {
- entry.set_root_id(instructions_[i].id());
- *program_shape->mutable_result() = instructions_[i].shape();
break;
}
}
- if (entry.root_id() == -1) {
+ if (index < 0) {
return FailedPrecondition("no root instruction was found");
}
+ *root_id = instructions_[index].id();
+ *program_shape.mutable_result() = instructions_[index].shape();
// Check that the parameter numbers are continuous from 0, and add parameter
// shapes and names to the program shape.
const int64 param_count = parameter_numbers_.size();
for (int64 i = 0; i < param_count; i++) {
- program_shape->add_parameters();
- program_shape->add_parameter_names();
+ program_shape.add_parameters();
+ program_shape.add_parameter_names();
}
for (const HloInstructionProto& instr : instructions_) {
// Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
@@ -123,10 +116,35 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
const int64 index = instr.parameter_number();
TF_RET_CHECK(index >= 0 && index < param_count)
<< "invalid parameter number: " << index;
- *program_shape->mutable_parameters(index) = instr.shape();
- *program_shape->mutable_parameter_names(index) = instr.name();
+ *program_shape.mutable_parameters(index) = instr.shape();
+ *program_shape.mutable_parameter_names(index) = instr.name();
}
}
+ return program_shape;
+}
+
+StatusOr<ProgramShape> XlaBuilder::GetProgramShape() {
+ int64 root_id;
+ return GetProgramShape(&root_id);
+}
+
+StatusOr<XlaComputation> XlaBuilder::Build() {
+ if (!first_error_.ok()) {
+ string backtrace;
+ first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
+ return AppendStatus(first_error_, backtrace);
+ }
+
+ HloComputationProto entry;
+ entry.set_name(name_);
+
+ {
+ int64 root_id;
+ ProgramShape program_shape;
+ TF_ASSIGN_OR_RETURN(program_shape, GetProgramShape(&root_id));
+ entry.mutable_program_shape()->Swap(&program_shape);
+ entry.set_root_id(root_id);
+ }
for (auto& instruction : instructions_) {
entry.add_instructions()->Swap(&instruction);
@@ -149,31 +167,134 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
return std::move(computation);
}
-XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- auto op = [&]() -> StatusOr<XlaOp> {
+StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
+ const Shape& shape, const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = shape;
+ for (int64 dim : broadcast_dimensions) {
+ instr.add_dimensions(dim);
+ }
+ return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
+}
+
+StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
+ const XlaOp& operand) {
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+
+ CHECK(ShapeUtil::IsScalar(operand_shape) ||
+ ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape));
+ Shape broadcast_shape =
+ ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type());
+
+ // Do explicit broadcast for scalar.
+ if (ShapeUtil::IsScalar(operand_shape)) {
+ return InDimBroadcast(broadcast_shape, operand, {});
+ }
+
+ // Do explicit broadcast for degenerate broadcast.
+ std::vector<int64> broadcast_dimensions;
+ std::vector<int64> reshaped_dimensions;
+ for (int i = 0; i < ShapeUtil::Rank(operand_shape); i++) {
+ if (operand_shape.dimensions(i) == output_shape.dimensions(i)) {
+ broadcast_dimensions.push_back(i);
+ reshaped_dimensions.push_back(operand_shape.dimensions(i));
+ } else {
+ TF_RET_CHECK(operand_shape.dimensions(i) == 1)
+ << "An explicit broadcast sequence requires the broadcasted "
+ "dimensions to be trivial; operand shape: "
+ << operand_shape << "; output_shape: " << output_shape;
+ }
+ }
+ // Eliminate the size one dimensions.
+ TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand,
+ Reshape(ShapeUtil::MakeShape(operand_shape.element_type(),
+ reshaped_dimensions),
+ operand));
+ // Broadcast 'reshape' up to the larger size.
+ return InDimBroadcast(broadcast_shape, reshaped_operand,
+ broadcast_dimensions);
+}
+
+XlaOp XlaBuilder::BinaryOp(
+ HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape());
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape());
- TF_ASSIGN_OR_RETURN(
- *instr.mutable_shape(),
- ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, lhs_shape,
- rhs_shape, broadcast_dimensions));
- return AddInstruction(std::move(instr), HloOpcode::kAdd, {lhs, rhs});
- };
- return NoteErrorOrReturn(op());
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferBinaryOpShape(
+ binop, lhs_shape, rhs_shape, broadcast_dimensions));
+
+ const int64 lhs_rank = ShapeUtil::Rank(lhs_shape);
+ const int64 rhs_rank = ShapeUtil::Rank(rhs_shape);
+
+ XlaOp updated_lhs = lhs;
+ XlaOp updated_rhs = rhs;
+
+ if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
+ const bool should_broadcast_lhs = lhs_rank < rhs_rank;
+ XlaOp from = should_broadcast_lhs ? lhs : rhs;
+ const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape;
+
+ std::vector<int64> to_size;
+ for (int64 size : instr.shape().dimensions()) {
+ to_size.push_back(size);
+ }
+ for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape);
+ from_dim++) {
+ int64 to_dim = broadcast_dimensions[from_dim];
+ to_size[to_dim] = from_shape.dimensions(from_dim);
+ }
+
+ const Shape& broadcasted_shape =
+ ShapeUtil::MakeShape(from_shape.element_type(), to_size);
+ TF_ASSIGN_OR_RETURN(
+ XlaOp broadcasted_operand,
+ InDimBroadcast(broadcasted_shape, from, broadcast_dimensions));
+
+ updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs;
+ updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
+ }
+
+ TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, updated_lhs.GetShape());
+ if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) {
+ TF_ASSIGN_OR_RETURN(updated_lhs,
+ AddBroadcastSequence(instr.shape(), updated_lhs));
+ }
+ TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, updated_rhs.GetShape());
+ if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) {
+ TF_ASSIGN_OR_RETURN(updated_rhs,
+ AddBroadcastSequence(instr.shape(), updated_rhs));
+ }
+
+ return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs});
+ }());
+}
+
+XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) {
- HloInstructionProto instr;
- *instr.mutable_shape() = literal.shape();
- *instr.mutable_literal() = literal.ToProto();
- return AddInstruction(std::move(instr), HloOpcode::kConstant);
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = literal.shape();
+ *instr.mutable_literal() = literal.ToProto();
+ return AddInstruction(std::move(instr), HloOpcode::kConstant);
+ }());
}
XlaOp XlaBuilder::Call(const XlaComputation& computation,
tensorflow::gtl::ArraySlice<XlaOp> operands) {
- auto op = [&]() -> StatusOr<XlaOp> {
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
std::vector<Shape> operand_shapes;
@@ -196,13 +317,12 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation,
}
return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
- };
- return NoteErrorOrReturn(op());
+ }());
}
XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
const string& name) {
- auto op = [&]() -> StatusOr<XlaOp> {
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) {
return InvalidArgument("parameter %lld already registered",
@@ -213,12 +333,496 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
instr.set_name(name);
*instr.mutable_shape() = shape;
return AddInstruction(std::move(instr), HloOpcode::kParameter);
- };
- return NoteErrorOrReturn(op());
+ }());
+}
+
+XlaOp XlaBuilder::Broadcast(
+ const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+ TF_ASSIGN_OR_RETURN(
+ const Shape& shape,
+ ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes));
+
+ // The client-level broadcast op just appends dimensions on the left (adds
+ // lowest numbered dimensions). The HLO broadcast instruction is more
+ // flexible and can add new dimensions anywhere. The instruction's
+ // dimensions field maps operand dimensions to dimensions in the broadcast
+ // output, so to append dimensions on the left the instruction's dimensions
+ // should just be the n highest dimension numbers of the output shape where
+ // n is the number of input dimensions.
+ const int64 operand_rank = ShapeUtil::Rank(operand_shape);
+ std::vector<int64> dimensions(operand_rank);
+ for (int i = 0; i < operand_rank; ++i) {
+ dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank;
+ }
+ return InDimBroadcast(shape, operand, dimensions);
+ }());
+}
+
+StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = shape;
+ return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
+}
+
+XlaOp XlaBuilder::Slice(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
+ int64 limit_index, int64 stride, int64 dimno) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
+ const XlaOp& start_indices) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
+ int64 dimension) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
+ const PaddingConfig& padding_config) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+ TF_ASSIGN_OR_RETURN(const Shape& shape,
+ ShapeInference::InferReshapeShape(
+ operand_shape, dimensions, new_sizes));
+ XlaOp transposed = IsIdentityPermutation(dimensions)
+ ? operand
+ : Transpose(operand, dimensions);
+ return Reshape(shape, transposed);
+ }());
+}
+
+XlaOp XlaBuilder::Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(auto shape, operand.GetShape());
+ std::vector<int64> dimensions(shape.dimensions_size());
+ std::iota(dimensions.begin(), dimensions.end(), 0);
+ return Reshape(operand, dimensions, new_sizes);
+ }());
+}
+
+XlaOp XlaBuilder::Collapse(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions) {
+ return UnimplementedOp();
+}
+
+void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
+ UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
+ const XlaOp& on_false) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::ConvWithGeneralPadding(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::ConvWithGeneralDimensions(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ConvolutionDimensionNumbers& dimension_numbers) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::ConvGeneral(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::ConvGeneralDilated(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
+ const tensorflow::gtl::ArraySlice<int64> fft_length) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
+ return UnimplementedOp();
+}
+
+void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
+ const string& outfeed_config) {
+ UnimplementedOp();
+}
+
+XlaOp XlaBuilder::CustomCall(const string& call_target_name,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const Shape& shape) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const string& channel_name,
+ int64 cost_estimate_ns, const Shape& shape) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Complex(
+ const XlaOp& real, const XlaOp& imag,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Conj(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Not(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::ShiftLeft(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::ShiftRightArithmetic(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::ShiftRightLogical(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Abs(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Atan2(
+ const XlaOp& y, const XlaOp& x,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Exp(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Floor(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Ceil(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Round(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Log(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Sign(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Cos(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Sin(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Tanh(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Real(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Imag(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Transpose(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation) {
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferTransposeShape(operand_shape, permutation));
+ for (int64 dim : permutation) {
+ instr.add_dimensions(dim);
+ }
+ return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
+ }());
+}
+
+XlaOp XlaBuilder::Rev(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Sort(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
+ PrimitiveType new_element_type) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
+ PrimitiveType new_element_type) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::SquareF32(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Neg(const XlaOp& operand) { return UnimplementedOp(); }
+
+XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand,
+ const XlaOp& max) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<XlaOp> static_operands) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::RngNormal(const XlaOp& mu, const XlaOp& sigma,
+ const Shape& shape) {
+ return UnimplementedOp();
}
-XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<XlaOp> operands) {
+XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b,
+ const Shape& shape) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::While(const XlaComputation& condition,
+ const XlaComputation& body, const XlaOp& init) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
+ const GatherDimensionNumbers& dimension_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
+ const XlaComputation& true_computation,
+ const XlaOp& false_operand,
+ const XlaComputation& false_computation) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Reduce(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::ReduceWindow(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, float epsilon,
+ int64 feature_index) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, const XlaOp& mean,
+ const XlaOp& variance, float epsilon,
+ int64 feature_index) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& batch_mean, const XlaOp& batch_var,
+ const XlaOp& grad_output, float epsilon,
+ int64 feature_index) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::SelectAndScatter(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter) {
+ return UnimplementedOp();
+}
+
+XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,
+ const int mantissa_bits) {
+ return UnimplementedOp();
+}
+
+void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
+ UnimplementedOp();
+}
+
+XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
+ return UnimplementedOp();
+}
+
+StatusOr<XlaOp> XlaBuilder::AddInstruction(
+ HloInstructionProto&& instr, HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<XlaOp> operands) {
const int64 handle = instructions_.size();
instr.set_id(handle);
instr.set_opcode(HloOpcodeString(opcode));
@@ -229,7 +833,12 @@ XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
instr.set_name(StrCat(instr.name(), ".", handle));
}
for (const auto& operand : operands) {
+ TF_RET_CHECK(operand.builder_ != nullptr);
+ TF_RET_CHECK(operand.builder_ == this)
+ << "Do not add XlaOp from builder " << operand.builder_->name()
+ << " to builder " << this->name();
instr.add_operand_ids(operand.handle());
+ // TODO(b/74197823): Set metadata and sharding.
}
instructions_.push_back(instr);
@@ -246,4 +855,9 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
return &instructions_[op.handle()];
}
+XlaOp XlaBuilder::UnimplementedOp() {
+ NoteError(Unimplemented("Op not yet implemented"));
+ return {};
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index f1d10ecdb9..c19eb47165 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -24,6 +24,8 @@ limitations under the License.
#include <string>
#include <utility>
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -50,10 +52,11 @@ class XlaBuilder;
// TODO(b/74197823): Replace xla::ComputationDataHandle with this one.
class XlaOp {
public:
+ XlaOp() : handle_(0), builder_(nullptr) {}
+
StatusOr<Shape> GetShape() const;
private:
- XlaOp() : handle_(0), builder_(nullptr) {}
XlaOp(int64 handle, XlaBuilder* builder)
: handle_(handle), builder_(builder) {}
@@ -64,38 +67,6 @@ class XlaOp {
XlaBuilder* builder_; // Not owned.
};
-// The computation graph that the user builds up with the XlaBuilder.
-//
-// TODO(b/74197823): Replace xla::Computation with this one.
-class XlaComputation {
- public:
- XlaComputation(const XlaComputation&) = delete;
- XlaComputation& operator=(const XlaComputation&) = delete;
-
- XlaComputation(XlaComputation&& from) { *this = std::move(from); }
-
- XlaComputation& operator=(XlaComputation&& from) {
- proto_ = std::move(from.proto());
- unique_id_ = from.unique_id_;
- return *this;
- }
-
- // Returns the "program shape" (parameter and return shapes) for this
- // computation.
- const ProgramShape& GetProgramShape() const { return proto_.program_shape(); }
-
- const HloModuleProto& proto() const { return proto_; }
-
- private:
- // Creates a null Computation.
- XlaComputation(const int64 unique_id) : unique_id_(unique_id) {}
- HloModuleProto* mutable_proto() { return &proto_; }
- friend class XlaBuilder;
-
- int64 unique_id_;
- HloModuleProto proto_;
-};
-
// A convenient interface for building up computations.
//
// Thread-compatible.
@@ -121,14 +92,6 @@ class XlaBuilder {
die_immediately_on_error_ = enabled;
}
- // Enqueues an add instruction onto the computation.
- XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a call instruction onto the computation.
- XlaOp Call(const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands);
-
// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(int64 parameter_number, const Shape& shape,
@@ -156,17 +119,597 @@ class XlaBuilder {
// corresponding native type yet.
template <typename NativeT>
XlaOp ConstantR0(NativeT value);
+ template <typename NativeT>
+ XlaOp ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values);
+ XlaOp ConstantR1(const tensorflow::core::Bitmap& values);
+ template <typename NativeT>
+ XlaOp ConstantR2(
+ std::initializer_list<std::initializer_list<NativeT>> values);
+ template <typename NativeT>
+ XlaOp ConstantFromArrayWithLayout(const Array<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ XlaOp ConstantFromArray(const Array<NativeT>& values);
+ template <typename NativeT>
+ XlaOp ConstantR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ XlaOp ConstantR2FromArray2D(const Array2D<NativeT>& values);
+ template <typename NativeT>
+ XlaOp ConstantR3FromArray3DWithLayout(const Array3D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ XlaOp ConstantR3FromArray3D(const Array3D<NativeT>& values);
+ template <typename NativeT>
+ XlaOp ConstantR4FromArray4DWithLayout(const Array4D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ XlaOp ConstantR4FromArray4D(const Array4D<NativeT>& values);
- // Returns the shape of the given op.
- StatusOr<Shape> GetShape(const XlaOp& op) const;
+ // Enqueues a rank one constant (vector) onto the computation. The vector has
+ // size 'length' and every element has the value 'value'.
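+  //
+  // For example (illustrative): ConstantR1<float>(3, 1.0f) enqueues a
+  // constant f32[3] whose elements are all 1.0f.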
+ template <typename NativeT>
+ XlaOp ConstantR1(int64 length, NativeT value);
+
+ // Adds dimensions to an array by duplicating the data in the array.
+ //
+ // The new dimensions are inserted on the left, i.e. if
+ // broadcast_sizes has values {a0, ..., aN} and the operand shape
+ // has dimensions {b0, ..., bM} then the shape of the output has
+ // dimensions {a0, ..., aN, b0, ..., bM}.
+ //
+ // The new dimensions index into copies of the operand, i.e.
+ //
+ // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
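+  //
+  // For example (illustrative): Broadcast(operand, /*broadcast_sizes=*/{4})
+  // on an f32[2,3] operand produces an f32[4,2,3] result containing four
+  // copies of the operand.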
+ XlaOp Broadcast(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+
+ // Enqueues a pad operation onto the computation that pads the given value on
+ // the edges as well as between the elements of the input. padding_config
+ // specifies the padding amount for each dimension.
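+  //
+  // For example (a sketch using the PaddingConfig proto from xla_data.proto):
+  // padding an f32[3] operand on dimension 0 with edge_padding_low=1,
+  // edge_padding_high=2 and interior_padding=0 yields an f32[6] result whose
+  // first element and last two elements are copies of padding_value.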
+ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
+ const PaddingConfig& padding_config);
+
+ // Enqueues an operation onto the computation that flattens the operand based
+ // on the dimension order (major/slowest-varying to minor/fastest-varying)
+ // given, followed by reshaping it into the shape with the given dimension
+ // sizes (also major to minor). Conceptually, this is a limited form of
+ // "shape casting".
+ XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ // Enqueues an operation onto the computation that collapses the operand, from
+ // first to last dimension (C order), then reshapes it to the given dimension
+ // sizes. Conceptually, this is a limited form of "shape casting".
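+  //
+  // For example (illustrative): reshaping an f32[2,3,5,7] operand with
+  // new_sizes {6, 35} uses the default major-to-minor dimension order
+  // {0, 1, 2, 3}, so no transpose is needed before the reshape.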
+ XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ // Wrapper for Reshape.
+ // Enqueues an operation to collapse the provided dimensions; e.g. an
+ // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
+ // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
+ // be a consecutive, in-order subsequence of the operand dimensions.
+ //
+ // Note that collapsing a single dimension does nothing:
+ //
+ // {256} collapsing {0} => {256}
+ // {1} collapsing {0} => {1}
+ //
+ // Collapsing multiple dimensions produces a single result dimension:
+ //
+ // {256, 2} collapsing {0,1} => {512}
+ // {256, 2, 3} collapsing {0,1} => {512, 3}
+ //
+ // This could potentially cause data to be moved -- it provides a more
+ // structured form of reshaping than an arbitrary Reshape operation.
+ XlaOp Collapse(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // Enqueues a slice operation onto the computation that slices the operand
+ // from the start indices to the limit indices; e.g.
+ //
+ // x
+ // [ 0 1 2 3 ]
+ // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
+ // [ 8 9 a b ]
+ //
+ // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
+ // range notation.
+  // The strides parameter determines the stride over the slice.
+ XlaOp Slice(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
+
+ // Enqueues a slice operation in a given dimension, taking all other
+ // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
+ // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
+ // for:
+ //
+ // array[:, 2:4:1, :]
+ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
+ int64 stride, int64 dimno);
+
+ // Enqueues a slice operation onto the computation that slices the 'operand'
+ // from dynamic start indices which are passed in 'start_indices'.
+  // The size of the slice in each dimension is passed in 'slice_sizes',
+  // which specifies the exclusive slice interval in each dimension:
+  // [start, start + size).
+ // The shape of 'start_indices' must be rank == 1, with dimension size
+ // equal to the rank of the 'operand'.
+ // Slice index calculations are computed modulo input dimension sizes to
+ // prevent dynamic start indices from generating out-of-bound array accesses.
+ XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+
+ // Enqueues a dynamic update slice operation onto the computation, which
+ // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
+ // The shape of 'update' determines the shape of the slice of 'operand'
+ // which is updated.
+ // The indices specified in 'start_indices' specify the offset of the slice
+ // of 'operand' which is updated.
+ //
+ // update = {10, 11} // calculated at runtime.
+ // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ]
+ // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11]
+ // [7 8 9] [7 8 9 ]
+ //
+ // The shape of 'start_indices' must be rank == 1, with dimension size
+ // equal to the rank of the 'operand'.
+ // Slice index calculations are computed modulo update dimension sizes to
+ // prevent dynamic start indices from generating out-of-bound array accesses.
+ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
+ const XlaOp& start_indices);
+
+ // Enqueues a concatenate instruction onto the computation. 'operands' must
+ // have >= 1 entry.
+ XlaOp ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
+ int64 dimension);
+
+  // Enqueues a tracing operation onto the computation; the computation will emit
+ // a logging message with the operand.
+ void Trace(const string& tag, const XlaOp& operand);
+
+ // Enqueues a conditional-move-like select operation onto the computation;
+ // predicated on pred, selects between on_true and on_false.
+ XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
+
+ // Enqueues a tuple-creation instruction onto the computation.
+ XlaOp Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements);
+
+ // Enqueues a tuple-element-get instruction onto the computation.
+ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
+
+ // Enqueues an equal-to comparison instruction onto the computation.
+ XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a not-equal comparison instruction onto the computation.
+ XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a greater-or-equal comparison instruction onto the computation.
+ XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a greater-than comparison instruction onto the computation.
+ XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a less-than comparison instruction onto the computation.
+ XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a less-or-equal comparison instruction onto the computation.
+ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a dot instruction onto the computation.
+ XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+
+ // Enqueues a general dot instruction onto the computation.
+ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers);
+
+ // Enqueues a convolution instruction onto the computation, which uses the
+ // default convolution dimension numbers.
+ XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding);
+
+ // Enqueues a convolution instruction onto the computation, with the caller
+ // provided padding configuration in the format returned by MakePadding().
+ XlaOp ConvWithGeneralPadding(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+
+ // Enqueues a convolution instruction onto the computation, with the caller
+ // provided dimension numbers configuration.
+ XlaOp ConvWithGeneralDimensions(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+ // Enqueues a convolution instruction onto the computation, with the caller
+ // provided padding configuration as well as the dimension numbers.
+ XlaOp ConvGeneral(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+ // Enqueues a convolution instruction onto the computation, with the caller
+ // provided padding configuration, dilation factors and dimension numbers.
+ XlaOp ConvGeneralDilated(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+ // Enqueues an FFT instruction onto the computation, of the given type and
+ // with the given FFT length.
+ XlaOp Fft(const XlaOp& operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+
+ // Enqueues an infeed instruction onto the computation, which writes data of
+ // the given shape to the infeed buffer of the device.
+ XlaOp Infeed(const Shape& shape, const string& config = "");
+
+ // Enqueues an outfeed instruction onto the computation. This instruction
+ // generates outgoing data transfers for the given data.
+ //
+ // shape_with_layout communicates the laid out shape that we want to outfeed
+ // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
+ // will occur.
+ void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
+ const string& outfeed_config);
+
+ // Enqueues a call instruction onto the computation.
+ XlaOp Call(const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<XlaOp> operands);
+
+ // Enqueues a custom call instruction onto the computation.
+ // During code generation, a call instruction is emitted which targets a
+ // symbol with the name |call_target_name|. The |operands| are passed to the
+ // call instruction. |shape| is the resultant shape.
+ XlaOp CustomCall(const string& call_target_name,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const Shape& shape);
+
+ // Enqueues a pseudo-op to represent host-side computation data-dependencies.
+ // During code generation, host send and receive operations will be generated
+ // to transfer |operands| to the host and a single result of |shape| back to
+ // the device. Host send/recv operations are emitted using |channel_name|.
+ // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
+ // instruction scheduling.
+ XlaOp HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const string& channel_name, int64 cost_estimate_ns,
+ const Shape& shape);
+
+ // The following methods enqueue element-wise binary arithmetic operations
+ // onto the computation. The shapes of the operands have to match unless one
+ // of the operands is a scalar, or an explicit broadcast dimension is given
+ // (see g3doc for more details).
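+  //
+  // For example (illustrative): Add(x, y, /*broadcast_dimensions=*/{0, 1})
+  // adds an s32[2,4,6] x and an s32[2,4] y by first broadcasting y into
+  // shape s32[2,4,6].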
+
+ // Enqueues a complex compose instruction onto the computation.
+ XlaOp Complex(const XlaOp& real, const XlaOp& imag,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a complex conjugate instruction onto the computation.
+ XlaOp Conj(const XlaOp& operand);
+
+ // Enqueues an add instruction onto the computation.
+ XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a subtract instruction onto the computation.
+ XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a multiply instruction onto the computation.
+ XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a divide instruction onto the computation.
+ XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a remainder instruction onto the computation.
+ XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a max instruction onto the computation.
+ XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a min instruction onto the computation.
+ XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Element-wise logical operators
+ XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ XlaOp Not(const XlaOp& operand);
+
+ XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ XlaOp ShiftRightArithmetic(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ XlaOp ShiftRightLogical(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Reduces an array among the provided dimensions, given "computation" as a
+ // reduction operator.
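+  //
+  // For example (illustrative): reducing an f32[2,3] operand over
+  // dimensions_to_reduce={1} with an add computation and a zero init_value
+  // yields an f32[2] result of per-row sums.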
+ XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+
+ // Convenience wrapper around the above that reduces all the dimensions in the
+ // operand shape.
+ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation);
+
+ // Enqueues a windowed reduce instruction onto the computation.
+ XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding);
+
+ // As ReduceWindow(), but the padding is given in the format
+ // returned by MakePadding().
+ XlaOp ReduceWindowWithGeneralPadding(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+
+ // Returns the sum of the operand value across all replicas. All replicas
+ // supply one input to the sum and all replicas receive the resulting sum.
+ XlaOp CrossReplicaSum(const XlaOp& operand);
+
+ // Enqueues an operation that scatters the `source` array to the selected
+ // indices of each window.
+ XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding, const XlaOp& source,
+ const XlaOp& init_value,
+ const XlaComputation& scatter);
+
+ // As SelectAndScatter(), but the padding is given in the format
+ // returned by MakePadding().
+ XlaOp SelectAndScatterWithGeneralPadding(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
+
+ // Enqueues an abs instruction onto the computation.
+ XlaOp Abs(const XlaOp& operand);
+
+  // Enqueues an atan2 instruction onto the computation.
+ XlaOp Atan2(const XlaOp& y, const XlaOp& x,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues an exp instruction onto the computation.
+ XlaOp Exp(const XlaOp& operand);
+
+ // Enqueues a floor instruction onto the computation.
+ XlaOp Floor(const XlaOp& operand);
+
+ // Enqueues a ceil instruction onto the computation.
+ XlaOp Ceil(const XlaOp& operand);
+
+  // Enqueues a round instruction onto the computation, rounding to the
+  // nearest integer, with half-way cases rounding away from zero.
+ XlaOp Round(const XlaOp& operand);
+
+  // Enqueues a log instruction (natural logarithm) onto the computation.
+ XlaOp Log(const XlaOp& operand);
+
+ // Enqueues a sign instruction onto the computation.
+ XlaOp Sign(const XlaOp& operand);
+
+ // Enqueues a cosine instruction onto the computation.
+ XlaOp Cos(const XlaOp& operand);
+
+ // Enqueues a sine instruction onto the computation.
+ XlaOp Sin(const XlaOp& operand);
+
+ // Enqueues a tanh instruction onto the computation.
+ XlaOp Tanh(const XlaOp& operand);
+
+ // Enqueues a real-part instruction onto the computation.
+ XlaOp Real(const XlaOp& operand);
+
+ // Enqueues an imaginary-part instruction onto the computation.
+ XlaOp Imag(const XlaOp& operand);
+
+ // Enqueues a float32 sqrt instruction onto the computation.
+ // (float32 is specified as there is an implicit float32 0.5f constant
+ // exponent).
+ XlaOp SqrtF32(const XlaOp& operand);
+
+ // Enqueues a float32 square instruction onto the computation.
+ // (float32 is specified as there is an implicit float32 2.0f constant
+ // exponent).
+ XlaOp SquareF32(const XlaOp& operand);
+
+  // Enqueues an lhs^rhs (power) instruction onto the computation.
+ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues an operator that tests if the operand's values are finite, i.e.,
+ // not Inf or NaN. Defined only for floating-point types. Returns an array of
+ // booleans with the same shape where entries are true iff the corresponding
+  // entry was finite.
+ XlaOp IsFinite(const XlaOp& operand);
+
+ // Enqueues a convert instruction onto the computation that changes the
+  // element type of the operand array to new_element_type.
+ XlaOp ConvertElementType(const XlaOp& operand,
+ PrimitiveType new_element_type);
+
+ // Enqueues a no-op instruction onto the computation that changes
+  // the element type of the operand array to new_element_type. The
+ // bit-widths of the source and destination element types must be
+ // identical.
+ XlaOp BitcastConvertType(const XlaOp& operand,
+ PrimitiveType new_element_type);
+
+ // Enqueues a float32 reciprocal instruction onto the computation.
+ // (float32 is specified as there is an implicit float32 -1.0f constant
+ // exponent).
+ //
+ // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
+ // shape of the operand.
+ XlaOp ReciprocalF32(const XlaOp& operand);
+
+ // Enqueues a negate instruction onto the computation.
+ XlaOp Neg(const XlaOp& operand);
+
+ // Enqueues a transpose instruction onto the computation.
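+  //
+  // For example (illustrative): Transpose(x, /*permutation=*/{1, 0}) on an
+  // f32[5,7] operand yields an f32[7,5] result.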
+ XlaOp Transpose(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation);
+
+ // Enqueues a reverse instruction onto the computation. The order of the
+ // elements in the given dimensions is reversed (i.e., the element at index i
+ // is moved to index dimension_size - 1 - i).
+ XlaOp Rev(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // Enqueues a sort (as increasing order) instruction onto the computation.
+ XlaOp Sort(const XlaOp& operand);
+
+ // Enqueues a clamp instruction onto the computation.
+ XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
+
+ // Enqueues a map instruction onto the computation.
+ XlaOp Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<XlaOp> static_operands = {});
+
+  // Enqueues an N(mu, sigma) random number generation instruction onto the
+ // computation.
+ XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
+
+ // Enqueues a U(a, b) random number generation instruction onto the
+ // computation. Returns values in the semi-open interval [a, b).
+ XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
+
+ // Enqueues a while node onto the computation.
+ XlaOp While(const XlaComputation& condition, const XlaComputation& body,
+ const XlaOp& init);
+
+ // Enqueues a conditional node onto the computation.
+ XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
+ const XlaComputation& true_computation,
+ const XlaOp& false_operand,
+ const XlaComputation& false_computation);
+
+ // Enqueues a ReducePrecision node onto the computation.
+ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
+ const int mantissa_bits);
+
+ // Enqueues a Gather node onto the computation.
+ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ const GatherDimensionNumbers& dimension_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+
+ // Enqueues a Send node onto the computation, to send the given operand to
+ // a Recv instruction that shares the same channel handle.
+ void Send(const XlaOp& operand, const ChannelHandle& handle);
+
+ // Enqueues a Recv node onto the computation. The data comes from a Send
+ // instruction that shares the same channel handle and its shape must
+ // be the same as the given shape.
+ XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
+
+ // Returns true if 'operand' is a compile-time constant. A compile-time
+ // constant does not depend on parameters with index greater than or equal to
+ // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`.
+ // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a
+ // compile-time constant without evaluating the computation.
+ StatusOr<bool> IsConstant(const XlaOp& operand, int64 num_parameters = 0);
+
+ // Normalizes operand across spatial and batch dimensions for each feature.
+ //
+ // Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
+ // is the normalized result and batch_mean and batch_var are the mean and
+ // variance, respectively, across batch for the operand.
+ XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, float epsilon,
+ int64 feature_index);
+
+ // Normalizes operand across spatial and batch dimensions for each feature.
+ //
+ // `BatchNormInference` is equivalent to calling `BatchNormTraining` without
+ // computing `mean` and `variance` for each batch inside the operation. It
+ // uses the input `mean` and `variance` instead as estimated values. The
+ // purpose of this op is to reduce latency in inference, hence the name
+ // `BatchNormInference`.
+ //
+ // The output has the same shape as `operand`, and contains the normalized
+ // values for each batch.
+ XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, const XlaOp& mean,
+ const XlaOp& variance, float epsilon,
+ int64 feature_index);
+
+ // Calculates the gradients of a batch norm op.
+ //
+ // The inputs `batch_mean` and `batch_var` represent the mean and variance
+ // across the batch.
+ //
+ // Returns a tuple of three elements:
+ // - grad_operand: Gradient with respect to input `operand`
+ // - grad_offset: Gradient with respect to input `offset`
+ // - grad_scale: Gradient with respect to input `scale`
+ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& batch_mean, const XlaOp& batch_var,
+ const XlaOp& grad_output, float epsilon,
+ int64 feature_index);
// Builds the computation with the requested operations, or returns a non-ok
// status.
StatusOr<XlaComputation> Build();
+ // Returns the first error that was encountered while building the
+ // computation. When an error is encountered, by default we return a vacuous
+ // XlaOp and inform the user of the error that occurred while
+ // building the computation when they make a final call to Build().
+ //
+ // See also set_die_immediately_on_error().
+ Status first_error() const { return first_error_; }
+
+ // Returns the shape of the given op.
+ StatusOr<Shape> GetShape(const XlaOp& op) const;
+
+  // Returns the (inferred) program shape for the current computation.
+ StatusOr<ProgramShape> GetProgramShape();
+
private:
- XlaOp AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<XlaOp> operands = {});
+ StatusOr<XlaOp> AddInstruction(
+ HloInstructionProto&& instr, HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<XlaOp> operands = {});
// Notes that the error occurred by:
// * storing it internally and capturing a backtrace if it's the first error
@@ -182,8 +725,34 @@ class XlaBuilder {
return op.ConsumeValueOrDie();
}
+  // Helper method that creates an empty op and notes the error.
+ XlaOp UnimplementedOp();
+
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
+ // Internal helper method that does the building for an arbitrary binary op.
+ // broadcast_dimensions specifies which dimensions to use for broadcasting
+ // when the operation is between tensors of different ranks.
+ XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ StatusOr<XlaOp> InDimBroadcast(
+ const Shape& shape, const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ // Internal helper method that creates a sequence of instructions that
+ // performs an explicit broadcast of the operand to the target shape.
+ StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
+ const XlaOp& operand);
+
+ // Internal helper method for creating a Reshape op with the already inferred
+ // shape.
+ StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand);
+
+  // Returns the (inferred) program shape for the current computation, and
+  // fills in the id of the root instruction via the root_id pointer.
+ StatusOr<ProgramShape> GetProgramShape(int64* root_id);
+
string name_; // Name to use for the built computation.
// The first error encountered while building the computation.
@@ -213,6 +782,76 @@ XlaOp XlaBuilder::ConstantR0(NativeT value) {
return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
}
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values) {
+ return ConstantLiteral(*Literal::CreateR1<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
+ Literal literal(ShapeUtil::MakeShape(
+ primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
+ literal.PopulateWithValue(value);
+ return ConstantLiteral(literal);
+}
+
+inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
+ return ConstantLiteral(*Literal::CreateR1(values));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR2(
+ std::initializer_list<std::initializer_list<NativeT>> values) {
+ return ConstantLiteral(*Literal::CreateR2<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
+ const Layout& layout) {
+ return ConstantLiteral(
+ *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
+ return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
+ const Array2D<NativeT>& values, const Layout& layout) {
+ return ConstantLiteral(
+ *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
+ return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
+ const Array3D<NativeT>& values, const Layout& layout) {
+ return ConstantLiteral(
+ *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D<NativeT>& values) {
+ return ConstantFromArray(values);
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout(
+ const Array4D<NativeT>& values, const Layout& layout) {
+ return ConstantFromArrayWithLayout(values, layout);
+}
+
+template <typename NativeT>
+XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
+ return ConstantFromArray(values);
+}
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
index a400e4e78b..529287a57a 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
@@ -57,16 +57,16 @@ TEST_F(XlaBuilderTest, OnePlusTwo) {
EXPECT_THAT(root, op::Add(op::Constant(), op::Constant()));
}
-TEST_F(XlaBuilderTest, ParamPlusConstant) {
+TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) {
XlaBuilder b(TestName());
auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x");
b.Add(x, b.ConstantR0<float>(1.0));
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root, op::Add(op::Parameter(), op::Constant()));
+ EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant())));
}
-TEST_F(XlaBuilderTest, ParamPlusParam) {
+TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) {
XlaBuilder b(TestName());
const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6});
const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4});
@@ -79,7 +79,7 @@ TEST_F(XlaBuilderTest, ParamPlusParam) {
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(1)));
+ EXPECT_THAT(root, op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1))));
}
TEST_F(XlaBuilderTest, XPlusX) {
@@ -133,5 +133,89 @@ TEST_F(XlaBuilderTest, Call) {
op::Call(op::Constant(), op::Constant())));
}
+TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) {
+ XlaBuilder b(TestName());
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x");
+ auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y");
+ b.Add(x, y);
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+
+ // Expected:
+ //
+ // x: f32[1,2,3] y: f32[1,2,1]
+ // | |
+ // | reshape: f32[1,2]
+ // | |
+ // | broadcast: f32[1,2,3]
+ // \ /
+ // add
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Add(op::Parameter(0),
+ op::Broadcast(op::Reshape(op::Parameter(1)))));
+}
+
+TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) {
+ XlaBuilder b(TestName());
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
+ auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y");
+ b.Add(x, y, /*broadcast_dimensions=*/{0, 1});
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+
+  // The binary operation has both an in-dim broadcast and a degenerate
+  // broadcast; the builder should first perform the in-dim broadcast and then
+  // convert the degenerate broadcast into a reshape and a broadcast.
+ //
+ // Expected:
+ //
+ // x: f32[2,3] y: f32[2,1,4]
+ // | |
+ // broadcast: f32[2,3,4] reshape: f32[2,4]
+ // | |
+ // | broadcast: f32[2,3,4]
+ // \ /
+ // add
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Add(op::Broadcast(op::Parameter(0)),
+ op::Broadcast(op::Reshape(op::Parameter(1)))));
+}
+
+TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {
+ XlaBuilder b1("b1");
+ auto p0 = b1.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0");
+ XlaBuilder builder("main");
+ builder.Add(p0, p0);
+ auto statusor = builder.Build();
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Do not add XlaOp from builder b1 to builder main"));
+}
+
+TEST_F(XlaBuilderTest, ReshapeDefaultOrder) {
+ XlaBuilder b(TestName());
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x");
+ b.Reshape(x, /*new_sizes=*/{6, 35});
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Reshape(op::Parameter()));
+}
+
+TEST_F(XlaBuilderTest, ReshapeHasTranspose) {
+ XlaBuilder b(TestName());
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x");
+ b.Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35});
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter())));
+}
+
+TEST_F(XlaBuilderTest, Transpose) {
+ XlaBuilder b(TestName());
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
+ b.Transpose(x, /*permutation=*/{1, 0});
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Transpose(op::Parameter()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc
new file mode 100644
index 0000000000..3681792eee
--- /dev/null
+++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+
+#include <utility>
+
+namespace xla {
+
+const ProgramShape& XlaComputation::GetProgramShape() const {
+ return proto_.program_shape();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h
new file mode 100644
index 0000000000..5b89747fdd
--- /dev/null
+++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+// The computation graph that the user builds up with the XlaBuilder.
+//
+// TODO(b/74197823): Replace xla::Computation with this one.
+class XlaComputation {
+ public:
+ XlaComputation(const XlaComputation&) = delete;
+ XlaComputation& operator=(const XlaComputation&) = delete;
+
+ XlaComputation(XlaComputation&& from) = default;
+
+ XlaComputation& operator=(XlaComputation&& from) = default;
+
+ // Returns the "program shape" (parameter and return shapes) for this
+ // computation.
+ const ProgramShape& GetProgramShape() const;
+ const HloModuleProto& proto() const { return proto_; }
+
+ private:
+ XlaComputation(const int64 unique_id) : unique_id_(unique_id) {}
+ HloModuleProto* mutable_proto() { return &proto_; }
+ friend class XlaBuilder;
+
+ int64 unique_id_;
+ HloModuleProto proto_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 214c2030cd..13675b7d00 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -1385,8 +1385,9 @@ void Literal::EachCellAsString(
}
namespace {
-template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
+template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
+std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
+ const Literal& src_literal, const ConverterType& converter) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
src_literal.shape(),
@@ -1396,11 +1397,18 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
- dest_data[i] = static_cast<NativeDestT>(src_data[i]);
+ dest_data[i] = converter(src_data[i]);
}
return result_literal;
}
+template <typename NativeSrcT, typename NativeDestT>
+std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
+ auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
+ return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
+ src_literal, converter);
+}
+
template <PrimitiveType primitive_src_type>
std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
@@ -1492,8 +1500,16 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
}
StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
- const Shape& dest_shape) const {
+ const Shape& dest_shape, bool round_f32_to_bf16) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
+ if (round_f32_to_bf16 && shape().element_type() == F32 &&
+ dest_shape.element_type() == BF16) {
+ auto converter = [](float src) {
+ return tensorflow::bfloat16::round_to_bfloat16(src);
+ };
+ return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this,
+ converter);
+ }
return Convert(dest_shape.element_type());
}
std::vector<Literal> elements;
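
[Editorial note] The refactor above splits the conversion loop from the per-element operation: ConvertBetweenNativeTypesWithConverter runs the loop over any callable, and the old static_cast path becomes one particular converter passed into it. The same pattern reduced to a self-contained sketch (names here are illustrative, not the XLA API):

  #include <vector>

  // Generic loop: applies `converter` to every element.
  template <typename SrcT, typename DestT, typename ConverterT>
  std::vector<DestT> ConvertWithConverter(const std::vector<SrcT>& src,
                                          const ConverterT& converter) {
    std::vector<DestT> dest;
    dest.reserve(src.size());
    for (const SrcT& v : src) dest.push_back(converter(v));
    return dest;
  }

  // The default conversion is just static_cast, expressed as a converter.
  template <typename SrcT, typename DestT>
  std::vector<DestT> Convert(const std::vector<SrcT>& src) {
    return ConvertWithConverter<SrcT, DestT>(
        src, [](SrcT v) { return static_cast<DestT>(v); });
  }
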
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index e24f5285d9..a96a76fbb4 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -340,8 +340,14 @@ class Literal {
  // Converts this literal to the given shape. Returns an error if the
  // conversion is not possible.
+ //
+ // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
+ // instead of truncation; otherwise, truncation is used.
+ //
+  // TODO(b/69266521): remove the round_f32_to_bf16 flag when rounding becomes
+  // the default behavior.
StatusOr<std::unique_ptr<Literal>> ConvertToShape(
- const Shape& dest_shape) const;
+ const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
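
[Editorial note] For intuition on the new round_f32_to_bf16 flag: BF16 is the top 16 bits of an F32 bit pattern, so truncation simply drops the low mantissa bits, while round-to-nearest-even first biases by the kept lsb. A self-contained sketch of both (NaN handling omitted; this shows the standard rounding trick, not necessarily tensorflow::bfloat16's exact implementation):

  #include <cstdint>
  #include <cstring>

  // Truncation: keep the high 16 bits of the F32 representation.
  uint16_t TruncateF32ToBF16(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
  }

  // Round to nearest even: add 0x7FFF plus the lsb of the kept half,
  // then truncate; the carry, when it occurs, performs the round-up.
  uint16_t RoundF32ToBF16(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    bits += 0x7FFFu + ((bits >> 16) & 1u);
    return static_cast<uint16_t>(bits >> 16);
  }
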
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index d4d67872cf..da16976d06 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -623,6 +623,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:executable_build_options",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 971c2935c8..f9fabd8a35 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -302,7 +302,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Disable dot strength reduction on platforms where it causes a slowdown.
bool enable_dot_strength_reduction_;
- // Disable convolution simplication on platforms where it causes a slowdown.
+ // Disable convolution simplification on platforms where it causes a slowdown.
bool enable_conv_simplification_;
};
@@ -1121,10 +1121,10 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
auto operand = broadcast->mutable_operand(0);
+ auto dims = broadcast->dimensions();
// A degenerate broadcast of a reshape that does not change the number of
// elements can be replaced by a reshape.
- if (std::is_sorted(broadcast->dimensions().begin(),
- broadcast->dimensions().end()) &&
+ if (std::is_sorted(dims.begin(), dims.end()) &&
ShapeUtil::ElementsIn(broadcast->shape()) ==
ShapeUtil::ElementsIn(operand->shape())) {
VLOG(10) << "transform broadcast(X) -> reshape(X) where "
@@ -1142,8 +1142,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
VLOG(10) << "transform broadcast(X) -> transpose(X) where "
"n(broadcast(X)) == n(X)";
return ReplaceWithNewInstruction(
- broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand,
- broadcast->dimensions()));
+ broadcast,
+ HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
}
// A broadcast of a reshape which merely inserts 1-sized dimensions can
@@ -1157,7 +1157,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
if (merely_inserts_or_deletes_1_sized_dimensions &&
deleted_indices.empty()) {
std::reverse(inserted_indices.begin(), inserted_indices.end());
- auto dims = broadcast->dimensions();
for (auto inserted_index : inserted_indices) {
dims.erase(dims.begin() + inserted_index);
}
@@ -1201,6 +1200,19 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
return user->ReplaceAllUsesWith(new_broadcast);
}
}
+ return Status::OK();
+ }
+
+ // Merge two consecutive broadcasts into a single one.
+ if (operand->opcode() == HloOpcode::kBroadcast) {
+ std::vector<int64> new_dimensions;
+ for (auto dim : operand->dimensions()) {
+ new_dimensions.push_back(dims[dim]);
+ }
+ return ReplaceWithNewInstruction(
+ broadcast,
+ HloInstruction::CreateBroadcast(
+ broadcast->shape(), operand->mutable_operand(0), new_dimensions));
}
return Status::OK();
}
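
[Editorial note] The new broadcast-merge rule is a composition of dimension maps: if the inner broadcast sends operand dimension i to position inner[i], and the outer broadcast sends its operand's dimension d to position outer[d], the merged broadcast sends i directly to outer[inner[i]]. A tiny sketch of that composition:

  #include <cstdint>
  #include <vector>

  // merged[i] = outer[inner[i]] for each original operand dimension i.
  std::vector<int64_t> MergeBroadcastDims(const std::vector<int64_t>& inner,
                                          const std::vector<int64_t>& outer) {
    std::vector<int64_t> merged;
    merged.reserve(inner.size());
    for (int64_t dim : inner) merged.push_back(outer[dim]);
    return merged;
  }
  // E.g. inner = {0, 2}, outer = {1, 2, 3} yields {1, 3}, matching the
  // MergeBroadcasts2 test below.
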
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index f0590943be..c48196e861 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -57,10 +57,10 @@ class AlgebraicSimplifier : public HloPassInterface {
bool is_layout_sensitive_;
ValidBitcastCallback valid_bitcast_callback_;
- // Enable dot simplication on platforms where it is profitable.
+ // Enable dot simplification on platforms where it is profitable.
bool enable_dot_strength_reduction_;
- // Enable convolution simplication on platforms where it is profitable.
+ // Enable convolution simplification on platforms where it is profitable.
bool enable_conv_simplification_;
};
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 451294ef5d..3b80a827bf 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -35,6 +35,8 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
+using ::testing::ElementsAre;
+
namespace xla {
namespace {
@@ -2462,6 +2464,55 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
op::DynamicSlice(op::Parameter(), op::Parameter()));
}
+// Test that two consecutive broadcasts can be merged to one.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
+ HloComputation::Builder builder(TestName());
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
+ HloInstruction* input_array = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<float>({3, 4})));
+ HloInstruction* inner_bcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r2f32, input_array, {1}));
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
+ builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2}));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Broadcast(op::Constant()));
+ EXPECT_THAT(root->dimensions(), ElementsAre(2));
+}
+
+// Test that two consecutive broadcasts can be merged to one.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
+ HloComputation::Builder builder(TestName());
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3});
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r2f32, "param0"));
+  // The initial dimensions go to places 0 and 2 in the 3-dim array,
+  // and to places 1 and 3 in the 4-dim array.
+ HloInstruction* inner_bcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r3f32, param0, {0, 2}));
+ Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
+ builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3}));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Broadcast(op::Parameter(0)));
+ EXPECT_THAT(root->dimensions(), ElementsAre(1, 3));
+}
+
struct PadReduceWindowEffectiveBroadcastCase {
std::vector<int64> input_spatials;
std::vector<int64> symmetric_pad_spatials;
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 7195c31d9c..c26d2feef5 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -606,8 +606,10 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
continue;
}
if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) {
- TF_ASSIGN_OR_RETURN(auto converted_literal,
- hlo->literal().ConvertToShape(hlo->shape()));
+ TF_ASSIGN_OR_RETURN(
+ auto converted_literal,
+ hlo->literal().ConvertToShape(hlo->shape(),
+ /*round_f32_to_bf16=*/true));
auto new_constant = computation->AddInstruction(
HloInstruction::CreateConstant(std::move(converted_literal)));
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index 6664496ab6..c83da9eddc 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -100,7 +100,7 @@ CompileOnlyService::CompileAheadOfTime(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(*program_shape, instance.argument_layouts,
- &execution_options, *user_computation));
+ &execution_options, user_computation));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
computation_tracker_.BuildHloModule(
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 33e19efc72..b4b53ae2ed 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -127,7 +127,7 @@ class Compiler {
// Compiles the HLO module for execution on a device given by the executor,
// and returns an executable object or an error status. No HLO passes are
// applied to module. Generally a module should be passed through RunHloPasses
- // prior to calling this method because the some HLO passes are required for
+ // prior to calling this method because some HLO passes are required for
// correctness. Takes ownership of the HLO module and is free to transform it.
//
// The compiler may optionally specialize to the individual device
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 093db020c0..0faa9e9c41 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -670,6 +670,22 @@ cc_library(
],
)
+tf_cc_test(
+ name = "ir_emission_utils_test",
+ srcs = ["ir_emission_utils_test.cc"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ ],
+)
+
cc_library(
name = "cpu_layout_assignment",
srcs = ["cpu_layout_assignment.cc"],
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
index 788217aab6..f209a69e3c 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
@@ -34,14 +34,16 @@ bool PotentiallyImplementedAsEigenConvolution(
//
// To be sufficient, certain layout constraints need to be satisfied as well.
const Shape& input_shape = convolution.operand(0)->shape();
- const Shape& kernel_shape = convolution.operand(0)->shape();
+ const Shape& kernel_shape = convolution.operand(1)->shape();
if (ShapeUtil::HasZeroElements(input_shape) ||
ShapeUtil::HasZeroElements(kernel_shape)) {
return false;
}
+  // Make sure input and kernel have the same data type.
+ CHECK(
+ ShapeUtil::SameElementTypeIgnoringFpPrecision(input_shape, kernel_shape));
// TODO(b/65408531): Explore using Eigen dot for complex64 type.
- if (ShapeUtil::ElementIsComplex(input_shape) ||
- ShapeUtil::ElementIsComplex(kernel_shape)) {
+ if (ShapeUtil::ElementIsComplex(input_shape)) {
return false;
}
if (window_util::HasWindowReversal(convolution.window())) {
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc
new file mode 100644
index 0000000000..215f48c4cc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace {
+
+TEST(IrEmitterTest, ConvWithZeroSizedKernelNotImplementedAsEigen) {
+ const char* const hlo_string = R"(
+HloModule ModuleWithConv
+
+ENTRY Conv {
+ input = f32[32,50,28,28]{3,2,1,0} parameter(0)
+ kernel = f32[0,32,5,5]{3,2,1,0} parameter(1)
+ ROOT convolution = f32[64,50,24,24]{3,2,1,0} convolution(input, kernel),
+ window={size=5x5},
+ dim_labels=b01f_01io->b01f
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+
+ HloComputation* entry_computation = module->entry_computation();
+
+ HloInstruction* conv_instr = entry_computation->root_instruction();
+ EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution(*conv_instr));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 3b8056d505..3405277d44 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -438,12 +438,14 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
if (kind == XfeedKind::kInfeed) {
// Copy to the program buffer address from the acquired buffer.
- ir_builder_.CreateMemCpy(program_buffer_address, acquired_pointer,
- length_32, 1);
+ ir_builder_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1,
+ acquired_pointer,
+ /*SrcAlign=*/1, length_32);
} else {
// Outfeed -- copy from the in-program address to the acquired buffer.
- ir_builder_.CreateMemCpy(acquired_pointer, program_buffer_address,
- length_32, 1);
+ ir_builder_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1,
+ program_buffer_address,
+ /*SrcAlign=*/1, length_32);
}
ir_builder_.CreateCall(release_func,
@@ -2441,7 +2443,8 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
} else {
auto* memcpy_instruction = ir_builder_.CreateMemCpy(
- target, source, element_count * primitive_type_size, element_alignment);
+ target, /*DstAlign=*/element_alignment, source,
+ /*SrcAlign=*/element_alignment, element_count * primitive_type_size);
// The memcpy does the load and the store internally. The aliasing related
// metadata has to reflect that.
@@ -2905,7 +2908,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source,
llvm::Value* destination_value = GetEmittedValueFor(&destination);
int64 source_size = ByteSizeOf(source.shape());
// TODO(b/63762267): Be more aggressive about specifying alignment.
- ir_builder_.CreateMemCpy(destination_value, source_value, source_size, 1);
+ ir_builder_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value,
+ /*SrcAlign=*/1, source_size);
return Status::OK();
}
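
[Editorial note] These edits track an LLVM IRBuilder API change: CreateMemCpy now takes a separate alignment for each side, with the size last, rather than a single trailing alignment. Schematically, against the LLVM headers of that era (newer LLVM wraps the alignments in llvm::MaybeAlign):

  #include "llvm/IR/IRBuilder.h"

  void EmitCopy(llvm::IRBuilder<>* b, llvm::Value* dst, llvm::Value* src,
                llvm::Value* size) {
    // was: b->CreateMemCpy(dst, src, size, /*Align=*/1);
    b->CreateMemCpy(dst, /*DstAlign=*/1, src, /*SrcAlign=*/1, size);
  }
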
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index 86e8be8461..fb28280fad 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -128,6 +128,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
// one of the following properties:
// *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall).
// *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot).
+  // *) Operations that are not thread-safe (like infeed and rng).
// *) Tuple-shaped.
// TODO(b/27458679) Parallelize instructions which are skipped here.
auto opcode = instruction->opcode();
@@ -135,7 +136,8 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall ||
opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter ||
opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
- opcode == HloOpcode::kFft ||
+ opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
+ opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
(opcode == HloOpcode::kConvolution &&
PotentiallyImplementedAsEigenConvolution(*instruction)) ||
PotentiallyImplementedAsEigenDot(*instruction) ||
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index 90191221eb..13eb75a572 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -80,5 +80,39 @@ TEST_F(ParallelTaskAssignmentTest,
EXPECT_FALSE(changed);
}
+TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) {
+ const string hlo_string = R"(
+ HloModule TestTaskParallel_rng
+ ENTRY Rng {
+ src0 = f32[] parameter(0)
+ src1 = f32[] parameter(1)
+ ROOT rng0 = f32[1234567,2]{1,0} rng(f32[] src0, f32[] src1),
+ distribution=rng_uniform
+ }
+ )";
+
+ ParseAndVerifyModule(hlo_string);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner(
+ max_parallelism_, shape_size_func_)
+ .Run(&module()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) {
+ const string hlo_string = R"(
+ HloModule TestTaskParallel_infeed_outfeed
+ ENTRY InfeedOutfeed {
+ infeed0 = u32[12345678,2]{1,0} infeed()
+ ROOT outfeed0 = u32[12345678,2]{1,0} outfeed(infeed0)
+ }
+ )";
+
+ ParseAndVerifyModule(hlo_string);
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner(
+ max_parallelism_, shape_size_func_)
+ .Run(&module()));
+ EXPECT_FALSE(changed);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index 279edd4ba8..cd7cbbdd71 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -109,6 +109,11 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
continue;
}
+ // Skip instructions which have side effects.
+ if (instruction->HasSideEffect()) {
+ continue;
+ }
+
// An instruction is considered to be equivalent to another only if they
// share the exact same set of operands. So to find equivalent
// instructions, we just search among instructions which share operand(0)
@@ -118,7 +123,7 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
tensorflow::gtl::InlinedVector<HloInstruction*, 8>
equivalent_instructions;
for (HloInstruction* user : operand->users()) {
- if (user != instruction &&
+ if (user != instruction && !user->HasSideEffect() &&
user->Identical(*instruction, eq_instructions, eq_computations,
is_layout_sensitive_)) {
equivalent_instructions.push_back(user);
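
[Editorial note] Both guards exist because side-effecting instructions (rng, infeed, outfeed, ...) are not interchangeable even when textually identical, so neither the instruction under consideration nor a candidate duplicate may carry a side effect. The scalar analogue of the hazard:

  #include <cstdlib>

  int main() {
    int a = std::rand();
    int b = std::rand();  // same expression as `a`'s initializer...
    // ...but CSE'ing `b` into `a` would change behavior: each call is a
    // distinct side-effecting event with its own result.
    return a == b;  // almost always 0
  }
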
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 3601a790c4..df8853f34f 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -414,8 +414,7 @@ TEST_F(HloCseTest, DoNotCombineRng) {
EXPECT_THAT(root, op::Add(rng1, rng2));
}
-// TODO(b/28245743): Handle impure functions correctly in CSE.
-TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) {
+TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
// Test that two calls to an impure function are not commoned. RNG
// is the source of the impurity.
@@ -458,14 +457,16 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) {
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Add(op::Map(), op::Map()));
+ VLOG(3) << "before: " << module->ToString();
+
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+
+ VLOG(3) << "after: " << module->ToString();
EXPECT_EQ(4, computation->instruction_count());
root = computation->root_instruction();
- auto operand = root->operand(0)->operand(0);
- EXPECT_THAT(operand, op::Map());
- EXPECT_THAT(root, op::Add(operand, operand));
+ EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant())));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 2037764dae..595c531ccf 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -237,8 +237,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
const Shape& parameter_shape =
module_config.entry_computation_layout().parameter_layout(i).shape();
- TF_RET_CHECK(
- ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape))
+ TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
+ parameter_shape))
<< "HloModuleConfig has different shape for parameter " << i
<< " than the HLO module. Expected: "
<< ShapeUtil::HumanStringWithLayout(
@@ -247,7 +247,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
const Shape& result_shape =
module_config.entry_computation_layout().result_layout().shape();
- TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape))
+ TF_RET_CHECK(
+ ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
<< "HloModuleConfig has different result shape than the HLO module. "
"Expected: "
<< ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 5690a89909..499f280211 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -69,6 +69,68 @@ LocalService::LocalService(const ServiceOptions& options,
std::unique_ptr<Backend> execute_backend)
: Service(options, std::move(execute_backend)) {}
+namespace {
+
+// Retrieves the parameter metadata for the given computation and parameter
+// number.
+//
+// If the parameter number is invalid for this computation, nullopt is
+// returned. When the return value has_value(), nullptr will never be
+// the held value.
+tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
+ const XlaComputation& computation, int parameter_number) {
+ for (const HloComputationProto& comp : computation.proto().computations()) {
+ if (comp.id() == computation.proto().entry_computation_id()) {
+ for (const HloInstructionProto& instr : comp.instructions()) {
+ if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
+ instr.parameter_number() == parameter_number) {
+ if (!instr.has_metadata()) {
+ return tensorflow::gtl::nullopt;
+ }
+ return &instr.metadata();
+ }
+ }
+ }
+ }
+ return tensorflow::gtl::nullopt;
+}
+
+ExecutionOptions CreateExecutionOptions(
+ const ExecutableBuildOptions& build_options,
+ const ProgramShape* program_shape) {
+ ExecutionOptions execution_options = CreateDefaultExecutionOptions();
+ if (build_options.hlo_profile().has_value()) {
+ execution_options.mutable_debug_options()->set_xla_hlo_profile(
+ *build_options.hlo_profile());
+ }
+ if (build_options.generate_hlo_graph().has_value()) {
+ execution_options.mutable_debug_options()->set_xla_generate_hlo_graph(
+ build_options.generate_hlo_graph().value());
+ }
+ if (build_options.dump_optimized_hlo_proto_to().has_value()) {
+ execution_options.mutable_debug_options()
+ ->set_xla_dump_optimized_hlo_proto_to(
+ build_options.dump_optimized_hlo_proto_to().value());
+ }
+ if (build_options.dump_per_pass_hlo_proto_to().has_value()) {
+ execution_options.mutable_debug_options()
+ ->set_xla_dump_per_pass_hlo_proto_to(
+ build_options.dump_per_pass_hlo_proto_to().value());
+ }
+ if (build_options.result_layout() != nullptr) {
+ *execution_options.mutable_shape_with_output_layout() =
+ *build_options.result_layout();
+ } else {
+ *execution_options.mutable_shape_with_output_layout() =
+ program_shape->result();
+ LayoutUtil::SetToDefaultLayout(
+ execution_options.mutable_shape_with_output_layout());
+ }
+ return execution_options;
+}
+
+} // namespace
+
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
const ComputationHandle& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
@@ -118,44 +180,78 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
*build_options.result_layout(), program_shape->result()));
}
- ExecutionOptions execution_options = CreateDefaultExecutionOptions();
- if (build_options.hlo_profile().has_value()) {
- execution_options.mutable_debug_options()->set_xla_hlo_profile(
- *build_options.hlo_profile());
- }
- if (build_options.generate_hlo_graph().has_value()) {
- execution_options.mutable_debug_options()->set_xla_generate_hlo_graph(
- build_options.generate_hlo_graph().value());
- }
- if (build_options.dump_optimized_hlo_proto_to().has_value()) {
- execution_options.mutable_debug_options()
- ->set_xla_dump_optimized_hlo_proto_to(
- build_options.dump_optimized_hlo_proto_to().value());
+ ExecutionOptions execution_options =
+ CreateExecutionOptions(build_options, program_shape.get());
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
+ CreateModuleConfig(*program_shape, argument_layouts,
+ &execution_options, user_computation));
+
+ TF_ASSIGN_OR_RETURN(
+ se::StreamExecutor * executor,
+ execute_backend_->stream_executor(build_options.device_ordinal()));
+
+ return BuildExecutable(versioned_handle, std::move(module_config),
+ execute_backend_.get(), executor,
+ build_options.device_allocator());
+}
+
+StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
+ const XlaComputation& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const ExecutableBuildOptions& build_options) {
+ const HloModuleProto& proto = computation.proto();
+ TF_RET_CHECK(proto.has_program_shape());
+ const ProgramShape& program_shape = proto.program_shape();
+
+ // Validate incoming layouts.
+ if (argument_layouts.size() != program_shape.parameters_size()) {
+ return InvalidArgument(
+ "Invalid number of arguments for computation: expected %d, got %zu.",
+ program_shape.parameters_size(), argument_layouts.size());
}
- if (build_options.dump_per_pass_hlo_proto_to().has_value()) {
- execution_options.mutable_debug_options()
- ->set_xla_dump_per_pass_hlo_proto_to(
- build_options.dump_per_pass_hlo_proto_to().value());
+
+ for (int i = 0; i < argument_layouts.size(); ++i) {
+ const Shape& argument_shape = *argument_layouts[i];
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape));
+ if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
+ tensorflow::gtl::optional<const OpMetadata*> metadata =
+ ParameterMetadata(computation, /*parameter_number=*/i);
+ auto metadata_string = [&metadata]() -> string {
+ if (!metadata.has_value()) {
+ return "";
+ }
+ CHECK(metadata.value() != nullptr);
+ const OpMetadata& m = *metadata.value();
+ if (!m.source_file().empty()) {
+ return tensorflow::strings::Printf(
+ " (%s:%d)", m.source_file().c_str(), m.source_line());
+ }
+ return "";
+ };
+ return InvalidArgument(
+ "Invalid argument shape for argument %d%s, expected %s, got %s.", i,
+ metadata_string().c_str(),
+ ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
+ ShapeUtil::HumanString(argument_shape).c_str());
+ }
}
if (build_options.result_layout() != nullptr) {
- *execution_options.mutable_shape_with_output_layout() =
- *build_options.result_layout();
- } else {
- *execution_options.mutable_shape_with_output_layout() =
- program_shape->result();
- LayoutUtil::SetToDefaultLayout(
- execution_options.mutable_shape_with_output_layout());
+ TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(
+ *build_options.result_layout(), program_shape.result()));
}
+
+ ExecutionOptions execution_options =
+ CreateExecutionOptions(build_options, &program_shape);
+
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, argument_layouts, &execution_options,
- *user_computation));
+ CreateModuleConfig(program_shape, argument_layouts, &execution_options));
TF_ASSIGN_OR_RETURN(
se::StreamExecutor * executor,
execute_backend_->stream_executor(build_options.device_ordinal()));
- return BuildExecutable(versioned_handle, std::move(module_config),
+ return BuildExecutable(proto, std::move(module_config),
execute_backend_.get(), executor,
build_options.device_allocator());
}
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index 15e120685e..06567cabd6 100644
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/client/executable_build_options.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -50,6 +51,18 @@ class LocalService : public Service {
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const ExecutableBuildOptions& options);
+ // Builds an Executable with the given XlaComputation, argument layouts and
+ // options. If result_layout is non-null, then the executable is compiled to
+ // produce a result of the given layout. If device_allocator is non-null,
+ // then the compiler may use it to allocate temp space on the device. The
+ // compiler is responsible for freeing any memory it allocates this way.
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+ StatusOr<std::unique_ptr<Executable>> CompileExecutable(
+ const XlaComputation& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const ExecutableBuildOptions& build_options);
+
// Returns the device ordinal that corresponds to the given replica number.
//
// This returns an error if there is not a one-to-one correspondence of
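
[Editorial note] A hypothetical call site for the new XlaComputation overload, under the same not-yet-ready caveat (the local_service pointer and BuildMyComputation helper are assumed for illustration):

  // Sketch: compile an XlaComputation into an Executable directly.
  xla::XlaComputation computation = BuildMyComputation();
  xla::Shape arg_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
  std::vector<const xla::Shape*> argument_layouts = {&arg_shape};
  xla::ExecutableBuildOptions build_options;
  xla::StatusOr<std::unique_ptr<xla::Executable>> executable =
      local_service->CompileExecutable(computation, argument_layouts,
                                       build_options);
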
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index e62bafc50b..f15117f45c 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -53,6 +53,14 @@ bool IsReshapeOrTranspose(const HloInstruction* instruction) {
instruction->opcode() == HloOpcode::kTranspose;
}
+// Returns true if `a` is a broadcast instruction to target shape `shape` and
+// its operand is a scalar.
+bool IsBroadcastScalarToShape(const HloInstruction* a, const Shape& shape) {
+ return a->opcode() == HloOpcode::kBroadcast &&
+ ShapeUtil::SameDimensions(a->shape(), shape) &&
+ ShapeUtil::IsScalar(a->operand(0)->shape());
+}
+
// Returns true iff `instruction` can change its shape simply by adjusting
// metadata.
bool CanTriviallyChangeShape(const HloInstruction* instruction) {
@@ -88,6 +96,7 @@ bool CanTriviallyChangeShape(const HloInstruction* instruction) {
instruction->user_count() == 1) {
return true;
}
+
return false;
}
@@ -148,6 +157,8 @@ bool AllOperandsHaveEasyShapeChanges(
// or
// 2. Are one of kConstant, kRng, and scalars that can change shape
// trivially,
+ // or
+  //      3. Are broadcasts with a scalar operand.
for (const HloInstruction* operand : instruction->operands()) {
if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
VLOG(5) << "Operand shape differs from output shape; may be "
@@ -158,6 +169,12 @@ bool AllOperandsHaveEasyShapeChanges(
return false;
}
+    // Skip the remaining checks if the current operand is
+    // first_reshape_operand itself.
+ if (first_reshape_operand == operand) {
+ continue;
+ }
+
if (AreEquivalentReshapes(first_reshape_operand, operand)) {
VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: "
<< first_reshape_operand->ToString(print_no_metadata)
@@ -171,6 +188,12 @@ bool AllOperandsHaveEasyShapeChanges(
continue;
}
+ if (IsBroadcastScalarToShape(operand, first_reshape_operand->shape())) {
+ VLOG(5) << "Broadcast scalar to shape: "
+ << operand->ToString(print_no_metadata);
+ continue;
+ }
+
// TODO(someone): Look into supporting general ops for the operands as
// well.
VLOG(5) << "Operand is neither equalivant to the first Reshape operand"
@@ -222,6 +245,12 @@ HloInstruction* UpdateOperand(HloComputation* computation,
VLOG(5) << "Using existing operand of kReshape or kTranspose";
return operand->mutable_operand(0);
}
+ case HloOpcode::kBroadcast:
+ CHECK(IsBroadcastScalarToShape(operand, first_reshape_operand->shape()));
+ VLOG(5) << "Changing broadcast";
+ return computation->AddInstruction(
+ operand->CloneWithNewOperands(new_shape, operand->operands()));
+
default:
LOG(FATAL) << "Unexpected operand opcode during update: " << operand;
}
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index aac8638a54..4e0a0a8832 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -560,5 +560,25 @@ TEST_F(ReshapeMoverTest, MultiplePasses) {
op::Reshape(op::Add(param2, op::Reshape(op::Add(param0, param1)))));
}
+TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) {
+ const string hlo_string = R"(
+ HloModule TransposeMulInversedTransposeModule
+ ENTRY TransposeMulInversedTranspose {
+ src0 = f32[1,20,8,32]{3,2,1,0} parameter(0)
+ transpose0 = f32[1,8,20,32]{3,2,1,0} transpose(src0), dimensions={0,2,1,3}
+ src1 = f32[] parameter(1)
+ broadcast0 = f32[1,8,20,32]{3,2,1,0} broadcast(src1), dimensions={}
+ ROOT multiply0 = f32[1,8,20,32]{3,2,1,0} multiply(transpose0, broadcast0)
+ }
+ )";
+
+ ParseAndVerifyModule(hlo_string.c_str());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module()));
+ EXPECT_TRUE(changed);
+
+ EXPECT_THAT(module().entry_computation()->root_instruction(),
+ op::Transpose(op::Multiply()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 0becc9d8f8..1d379f0d03 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -272,7 +272,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
const ExecutionOptions* execution_options,
- const UserComputation& user_computation) {
+ const UserComputation* user_computation) {
auto config = MakeUnique<HloModuleConfig>(program_shape);
auto* computation_layout = config->mutable_entry_computation_layout();
@@ -286,8 +286,15 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
// ProgramShape.
if (!ShapeUtil::Compatible(*argument_shapes[i],
program_shape.parameters(i))) {
+ if (user_computation == nullptr) {
+ return InvalidArgument(
+ "Argument does not match shape of computation parameter %d: want "
+ "%s, got %s",
+ i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
+ ShapeUtil::HumanString(*argument_shapes[i]).c_str());
+ }
return InvalidParameterArgument(
- *user_computation.ParameterMetadata(i).value(),
+ *user_computation->ParameterMetadata(i).value(),
"Argument does not match shape of computation parameter %d: want %s, "
"got %s",
i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
@@ -330,7 +337,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutionOptions& execution_options,
- const UserComputation& user_computation) {
+ const UserComputation* user_computation) {
std::vector<const Shape*> argument_shapes;
for (const auto* arg : arguments) {
argument_shapes.push_back(&arg->on_host_shape());
@@ -778,7 +785,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(*program_shape, replicated_arguments.front(),
- request.execution_options(), *user_computation));
+ request.execution_options(), user_computation));
VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -854,6 +861,33 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
return tensorflow::Status::OK();
}
+tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg,
+ ExecuteResponse* result) {
+ ExecuteParallelRequest parallel_arg;
+ *parallel_arg.add_requests() = *arg;
+ ExecuteParallelResponse parallel_result;
+ TF_RETURN_IF_ERROR(ExecuteParallel(&parallel_arg, &parallel_result));
+ // The "result device" selection is a bit hacky, but better than assuming it
+ // is device 0. We have b/76035356 for restructuring the client API to clean
+  // up the current asymmetries and support more functionality.
+ for (int64 i = 0; i < parallel_result.responses_size(); ++i) {
+ TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
+ allocation_tracker_.ResolveForReplica(
+ parallel_result.responses(i).output(), 0));
+ const Shape& shape = buffer->on_host_shape();
+ if (!ShapeUtil::IsEmptyTuple(shape)) {
+ *result = parallel_result.responses(i);
+ VLOG(3) << "Fetching result from device " << i << ": "
+ << ShapeUtil::HumanString(shape);
+ return Status::OK();
+ }
+ }
+ TF_RET_CHECK(parallel_result.responses_size() > 0);
+ *result = parallel_result.responses(0);
+ VLOG(1) << "Defaulting to device 0 result";
+ return Status::OK();
+}
+
tensorflow::Status Service::Execute(const ExecuteRequest* arg,
ExecuteResponse* result) {
VLOG(1) << "running execute request: " << arg->ShortDebugString();
@@ -870,13 +904,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
// If we received multiple device handles, we must partition the module.
if (arg->execution_options().device_handles_size() > 1) {
- ExecuteParallelRequest parallel_arg;
- *parallel_arg.add_requests() = *arg;
- ExecuteParallelResponse parallel_result;
- TF_RETURN_IF_ERROR(ExecuteParallel(&parallel_arg, &parallel_result));
- TF_RET_CHECK(parallel_result.responses_size() > 0);
- *result = parallel_result.responses(0);
- return Status::OK();
+ return ExecuteOneToN(arg, result);
}
TF_ASSIGN_OR_RETURN(
@@ -894,7 +922,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(*program_shape, replicated_arguments.front(),
- arg->execution_options(), *user_computation));
+ arg->execution_options(), user_computation));
VLOG(3) << "Execute created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -935,9 +963,66 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
return tensorflow::Status::OK();
}
-tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* /*arg*/,
- ExecuteResponse* /*result*/) {
- return Unimplemented("execute-graph is not yet implemented");
+StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
+ const HloModuleProto& module_proto,
+ std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
+ se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) {
+ VLOG(1) << Printf(
+ "BuildExecutable on service %p with serialized module proto: %s", this,
+ module_proto.name().c_str());
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+ HloModule::CreateFromProto(module_proto, *module_config));
+
+ TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
+
+ TF_ASSIGN_OR_RETURN(
+ module, backend->compiler()->RunHloPasses(std::move(module), executor,
+ device_allocator));
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+ backend->compiler()->RunBackend(
+ std::move(module), executor, device_allocator));
+
+ return std::move(executable);
+}
+
+tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
+ ExecuteResponse* result) {
+ VLOG(1) << "running execute-graph request";
+
+ if (!arg->has_computation()) {
+ return InvalidArgument("computations may not be empty");
+ }
+
+ // TODO(b/74197823): Handle partitioning.
+
+ TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
+ SingleComputationDeviceHandle()));
+ TF_ASSIGN_OR_RETURN(
+ std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
+ ResolveAndValidateArguments(arg->arguments(), replicas));
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
+ CreateModuleConfig(arg->computation().program_shape(),
+ replicated_arguments.front(),
+ arg->execution_options()));
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Executable> executable,
+ BuildExecutable(arg->computation(), std::move(module_config),
+ execute_backend_.get(),
+ execute_backend_->default_stream_executor(),
+ /*device_allocator=*/nullptr));
+
+ TF_ASSIGN_OR_RETURN(
+ *result->mutable_output(),
+ ExecuteAndRegisterResult(
+ executable.get(), replicated_arguments, execute_backend_.get(),
+ "result of " + arg->computation().name(), result->mutable_profile()));
+
+ VLOG(1) << "successfully completed 'execute-graph' request";
+ return tensorflow::Status::OK();
}
tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
@@ -967,7 +1052,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(*program_shape, replicated_arguments.front(),
- arg->execution_options(), *user_computation));
+ arg->execution_options(), user_computation));
VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -1268,7 +1353,7 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(program_shape, {}, execution_options,
- *user_computation));
+ user_computation));
// Exclude dead parameter instructions for the purpose of computing constants.
TF_ASSIGN_OR_RETURN(
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 96352d9096..773f0a642d 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -115,6 +115,8 @@ class Service : public ServiceInterface {
// Executes a computation with the provided global data passed as
// immutable arguments. The request contains the whole computation graph.
// Returns global data output and execution timing.
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg,
ExecuteResponse* result) override;
@@ -258,7 +260,7 @@ class Service : public ServiceInterface {
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutionOptions& execution_options,
- const UserComputation& user_computation);
+ const UserComputation* user_computation = nullptr);
protected:
friend class LocalExecutable;
@@ -286,7 +288,7 @@ class Service : public ServiceInterface {
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
const ExecutionOptions* execution_options,
- const UserComputation& user_computation);
+ const UserComputation* user_computation = nullptr);
// Builds an Executable for the given parameters.
//
@@ -299,6 +301,15 @@ class Service : public ServiceInterface {
perftools::gputools::StreamExecutor* executor,
DeviceMemoryAllocator* device_allocator = nullptr);
+ // Builds an Executable for the given HLO module proto.
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+ StatusOr<std::unique_ptr<Executable>> BuildExecutable(
+ const HloModuleProto& module_proto,
+ std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
+ perftools::gputools::StreamExecutor* executor,
+ DeviceMemoryAllocator* device_allocator = nullptr);
+
// Same as BuildExecutable() above, but builds a list of Executables for the
// given computations that may interact with each other.
StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
@@ -346,6 +357,12 @@ class Service : public ServiceInterface {
const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>&
adder);
+  // Executes a single computation that has more than one target device.
+  // All but one of the N devices are expected to return an empty tuple; the
+  // non-empty output is the result of this computation.
+ tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg,
+ ExecuteResponse* result);
+
// Convenience function which checks whether the given shape_with_layout
// (presumably passed by the client to set the result layout) is valid for the
// given computation result shape.
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h
index d3d55634c9..3d3e1d60f2 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.h
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h
@@ -25,7 +25,7 @@ namespace xla {
// HLO pass that makes the following transformations on while loops:
//
// - A while loop with static trip count of 0 is deleted.
-// - A while loops with static trip count of 1 is replaced by its body (sans
+// - A while loop with static trip count of 1 is replaced by its body (sans
// loop).
// - Elements of a while loop's tuple that the loop doesn't use are removed
// from the tuple.
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
index 063e312df6..8763e588c4 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
-// HLO pass that replaces zero sized Hlos with an zero sized constant literal.
+// HLO pass that replaces zero sized Hlos with a zero sized constant literal.
namespace xla {
class ZeroSizedHloElimination : public HloPassInterface {
public:
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 7fb7919674..26022278e5 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -190,6 +190,7 @@ cc_library(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
@@ -386,6 +387,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1372,6 +1374,7 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc
index 3f6fd7c65d..ec3b46acfe 100644
--- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -28,11 +29,11 @@ namespace {
class AxpySimpleTest : public ClientLibraryTestBase {};
TEST_F(AxpySimpleTest, AxTenValues) {
- ComputationBuilder builder(client_, "ax_10");
+ XlaBuilder builder("ax_10");
auto alpha = builder.ConstantR0<float>(3.1415926535);
auto x = builder.ConstantR1<float>(
{-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
- auto ax = builder.Mul(alpha, x);
+ builder.Mul(alpha, x);
std::vector<float> expected = {
-3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796,
@@ -46,7 +47,7 @@ XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) {
auto x = builder.ConstantR1<float>({});
auto y = builder.ConstantR1<float>({});
auto ax = builder.Mul(alpha, x);
- auto axpy = builder.Add(ax, y);
+ builder.Add(ax, y);
std::vector<float> expected = {};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -60,7 +61,7 @@ TEST_F(AxpySimpleTest, AxpyTenValues) {
auto y = builder.ConstantR1<float>(
{5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0});
auto ax = builder.Mul(alpha, x);
- auto axpy = builder.Add(ax, y);
+ builder.Add(ax, y);
TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape());
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index a677986cd9..d9bd1ce6eb 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -96,6 +96,20 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
}
StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout) {
+ ExecutionOptions execution_options = execution_options_;
+ if (shape_with_output_layout != nullptr) {
+ *execution_options.mutable_shape_with_output_layout() =
+ *shape_with_output_layout;
+ }
+ return client_->ExecuteAndTransfer(computation, arguments,
+ &execution_options);
+}
+
+template <>
+StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
ComputationBuilder* builder,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_output_layout) {
@@ -104,6 +118,15 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
}
+template <>
+StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+ XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout) {
+ // Build the computation, as a convenience.
+ TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
+ return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
+}
+
std::unique_ptr<GlobalData> ClientLibraryTestBase::ExecuteOrDie(
ComputationBuilder* builder,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
@@ -142,16 +165,18 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
arguments);
}
+template <typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareLiteral(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_layout) {
EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
shape_with_layout));
}
+template <typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareLiteral(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
const Shape* shape_with_layout) {
EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
@@ -249,8 +274,28 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
return choose(0);
}
+tensorflow::Status
+ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
+ const xla::XlaComputation& /*computation*/, const Literal& /*expected*/,
+ tensorflow::gtl::ArraySlice<GlobalData*> /*arguments*/,
+ const std::function<void(const Literal& actual,
+ const string& error_message)>& /*verify_output*/) {
+ return Unimplemented("not yet implemented for XlaComputation");
+}
+
+tensorflow::Status
+ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
+ const xla::XlaComputation& /*computation*/, const Literal& /*expected*/,
+ tensorflow::gtl::ArraySlice<GlobalData*> /*arguments*/,
+ const std::function<void(const Literal& actual,
+ const string& error_message)>& /*verify_output*/,
+ const Shape* /*output_with_layout*/) {
+ return Unimplemented("not yet implemented for XlaComputation");
+}
+
+template <typename BuilderT>
tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
const Shape* shape_with_layout) {
std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
@@ -307,8 +352,9 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
return tensorflow::Status::OK();
}
+template <typename BuilderT>
tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
ErrorSpec error, const Shape* shape_with_layout) {
std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
@@ -522,33 +568,6 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
return array;
}
-std::unique_ptr<GlobalData>
-ClientLibraryTestBase::CreateParameterAndTransferLiteral(
- int64 parameter_number, const Literal& literal, const string& name,
- ComputationBuilder* builder, ComputationDataHandle* data_handle) {
- return CreateParameterAndTransferLiteral(parameter_number, literal, name,
- nullptr, builder, data_handle);
-}
-
-std::unique_ptr<GlobalData>
-ClientLibraryTestBase::CreateParameterAndTransferLiteral(
- int64 parameter_number, const Literal& literal, const string& name,
- const DeviceHandle* device_handle, ComputationBuilder* builder,
- ComputationDataHandle* data_handle) {
- const Literal* param_literal = &literal;
- std::unique_ptr<Literal> converted_literal;
- if (use_bfloat16_) {
- converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
- param_literal = converted_literal.get();
- }
- std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*param_literal, device_handle)
- .ConsumeValueOrDie();
- *data_handle =
- builder->Parameter(parameter_number, param_literal->shape(), name);
- return data;
-}
-
ComputationDataHandle ClientLibraryTestBase::AddParam(
const Literal& argument, ComputationBuilder* builder) {
ComputationDataHandle data_handle;
@@ -563,4 +582,24 @@ ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral(
use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
}
+template void ClientLibraryTestBase::ComputeAndCompareLiteral(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_layout);
+
+template void ClientLibraryTestBase::ComputeAndCompareLiteral(
+ XlaBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_layout);
+
+template void ClientLibraryTestBase::ComputeAndCompareLiteral(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ const Shape* shape_with_layout);
+
+template void ClientLibraryTestBase::ComputeAndCompareLiteral(
+ XlaBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ const Shape* shape_with_layout);
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index ba0319990b..01aa6c756f 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -94,15 +95,22 @@ class ClientLibraryTestBase : public ::testing::Test {
StatusOr<std::unique_ptr<GlobalData>> Execute(
ComputationBuilder* builder,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+ template <typename BuilderT>
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
- ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_output_layout = nullptr);
+
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
const Computation& computation,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_output_layout = nullptr);
+ StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout = nullptr);
+
// Convenience OrDie variants of above methods.
std::unique_ptr<GlobalData> ExecuteOrDie(
ComputationBuilder* builder,
@@ -130,12 +138,12 @@ class ClientLibraryTestBase : public ::testing::Test {
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
ErrorSpec error);
- template <typename NativeT>
- void ComputeAndCompareR1(ComputationBuilder* builder,
+ template <typename NativeT, typename BuilderT>
+ void ComputeAndCompareR1(BuilderT* builder,
tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- template <typename NativeT>
- void ComputeAndCompareR1(ComputationBuilder* builder,
+ template <typename NativeT, typename BuilderT>
+ void ComputeAndCompareR1(BuilderT* builder,
tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
ErrorSpec error);
@@ -179,22 +187,26 @@ class ClientLibraryTestBase : public ::testing::Test {
// Build and run the computation and compare the result with the given
// literal. shape_with_layout indicates the result layout to request when
// calling Execute.
+ template <typename BuilderT>
void ComputeAndCompareLiteral(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_layout = nullptr);
+ template <typename BuilderT>
void ComputeAndCompareLiteral(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
const Shape* shape_with_layout = nullptr);
// ComputeAndCompare variant which returns an error status.
+ template <typename BuilderT>
tensorflow::Status ComputeAndCompareLiteralWithStatus(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_layout = nullptr);
+ template <typename BuilderT>
tensorflow::Status ComputeAndCompareLiteralWithStatus(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
const Shape* shape_with_layout = nullptr);
@@ -266,17 +278,19 @@ class ClientLibraryTestBase : public ::testing::Test {
// server, then stores into "data_handle" the global handle for that
// parameter. When the use_bfloat16 flag is set but the literal has F32
// elements, the literal will be converted to BF16 before being transferred.
+ template <typename BuilderT, typename HandleT>
std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
int64 parameter_number, const Literal& literal, const string& name,
- ComputationBuilder* builder, ComputationDataHandle* data_handle);
+ BuilderT* builder, HandleT* data_handle);
// As above, but the caller can specify the device that the literal is
// transferred to. If device_handle is nullptr, the literal will be
// transferred to the default device.
+ template <typename BuilderT, typename HandleT>
std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
int64 parameter_number, const Literal& literal, const string& name,
- const DeviceHandle* device_handle, ComputationBuilder* builder,
- ComputationDataHandle* data_handle);
+ const DeviceHandle* device_handle, BuilderT* builder,
+ HandleT* data_handle);
// Creates a parameter instruction and sets the value that will be passed to
// the computation as specified. This function must be used for all parameters
@@ -399,6 +413,18 @@ class ClientLibraryTestBase : public ::testing::Test {
const string& error_message)>& verify_output,
const Shape* output_with_layout = nullptr);
+ tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts(
+ const xla::XlaComputation& computation, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const std::function<void(const Literal& actual,
+ const string& error_message)>& verify_output);
+ tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts(
+ const xla::XlaComputation& computation, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const std::function<void(const Literal& actual,
+ const string& error_message)>& verify_output,
+ const Shape* output_with_layout = nullptr);
+
// Executes the computation and calculates the expected reference value using
  // the HloEvaluator. Returns two literals in the order of (expected, actual).
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
@@ -440,9 +466,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
arguments, error);
}
-template <typename NativeT>
+template <typename NativeT, typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareR1(
- ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
+ BuilderT* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
Literal::CreateR1<NativeT>(expected);
@@ -450,9 +476,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
arguments);
}
-template <typename NativeT>
+template <typename NativeT, typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareR1(
- ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
+ BuilderT* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
@@ -628,6 +654,37 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
return result;
}
+template <typename BuilderT, typename HandleT>
+std::unique_ptr<GlobalData>
+ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
+ const Literal& literal,
+ const string& name,
+ BuilderT* builder,
+ HandleT* data_handle) {
+ return CreateParameterAndTransferLiteral(parameter_number, literal, name,
+ nullptr, builder, data_handle);
+}
+
+template <typename BuilderT, typename HandleT>
+std::unique_ptr<GlobalData>
+ClientLibraryTestBase::CreateParameterAndTransferLiteral(
+ int64 parameter_number, const Literal& literal, const string& name,
+ const DeviceHandle* device_handle, BuilderT* builder,
+ HandleT* data_handle) {
+ const Literal* param_literal = &literal;
+ std::unique_ptr<Literal> converted_literal;
+ if (use_bfloat16_) {
+ converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
+ param_literal = converted_literal.get();
+ }
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*param_literal, device_handle)
+ .ConsumeValueOrDie();
+ *data_handle =
+ builder->Parameter(parameter_number, param_literal->shape(), name);
+ return data;
+}
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index f7b04debd4..02272d6017 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
@@ -207,9 +208,9 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) {
//
// Splits an empty vector into an empty matrix.
XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateR1<float>({});
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0},
@@ -221,10 +222,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) {
// Splits a vector into a matrix.
XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal =
Literal::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0},
@@ -241,9 +242,9 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) {
//
// Transposes a 2x0 array to a 0x2 array.
XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 2));
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
@@ -255,10 +256,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) {
// Transposes a 2-dimensional row vector to a column vector.
XLA_TEST_P(ReshapeTest, ReshapeRowToCol) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3);
auto input_literal = Literal::CreateFromArray(*simple);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
@@ -272,10 +273,10 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) {
// Transposes a 2-dimensional array.
XLA_TEST_P(ReshapeTest, TransposeAsReshape) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = Literal::CreateFromArray(*a4x3);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
@@ -291,11 +292,11 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) {
// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
// with an incorrect result rank.
//
-// Transposes a 0x4 array with ComputationBuilder::Trans.
+// Transposes a 0x4 array with XlaBuilder::Transpose.
XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 4));
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Transpose(parameter, {1, 0});
@@ -306,10 +307,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) {
-// Transposes a 2-dimensional array with ComputationBuilder::Trans.
+// Transposes a 2-dimensional array with XlaBuilder::Transpose.
XLA_TEST_P(ReshapeTest, Transpose4x3) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = Literal::CreateFromArray(*a4x3);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Transpose(parameter, {1, 0});
@@ -327,9 +328,9 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) {
// Reshapes an empty 2-dimensional array with dimensions that are not just a
// rearrangement of the originals (split), but no reordering (no shuffle).
XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(Array2D<float>(6, 0));
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
index dae402204f..dcd235f876 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
@@ -13,20 +13,23 @@ load("//tensorflow:tensorflow.bzl", "py_test")
filegroup(
name = "all_files",
srcs = glob(
- ["**/*"],
- exclude = [
- "**/OWNERS",
- ],
+ include = ["**/*"],
+ exclude = ["**/OWNERS"],
),
visibility = ["//tensorflow:__subpackages__"],
)
py_library(
name = "init_py",
- srcs = [
- "__init__.py",
- ],
+ srcs = ["__init__.py"],
srcs_version = "PY2AND3",
+ deps = [
+ "custom_export_strategy",
+ ":custom_loss_head",
+ ":estimator",
+ ":model",
+ ":trainer_hooks",
+ ],
)
py_library(
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index f70b29c43b..8cfe4a727a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -479,6 +479,10 @@ py_test(
size = "small",
srcs = ["prefetching_ops_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "manual",
+ "no_oss",
+ ],
deps = [
"//tensorflow/contrib/data/python/ops:prefetching_ops",
"//tensorflow/core:protos_all_py",
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 94f800e8a5..d0131896a1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -468,6 +468,31 @@ class BucketBySequenceLength(test.TestCase):
self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
self.assertEqual(sorted(boundaries), sorted(lengths_val))
+ def testTupleElements(self):
+
+ def elements_gen():
+ text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
+ label = [1, 2, 1, 2]
+ for x, y in zip(text, label):
+ yield (x, y)
+
+ def element_length_fn(x, y):
+ del y
+ return array_ops.shape(x)[0]
+
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=elements_gen,
+ output_shapes=(tensor_shape.TensorShape([None]),
+ tensor_shape.TensorShape([])),
+ output_types=(dtypes.int32, dtypes.int32))
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ element_length_func=element_length_fn,
+ bucket_batch_sizes=[2, 2, 2],
+ bucket_boundaries=[0, 8]))
+ shapes = dataset.output_shapes
+ self.assertEqual([None, None], shapes[0].as_list())
+ self.assertEqual([None], shapes[1].as_list())
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index ae10d2eb22..36591c055a 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -140,9 +140,9 @@ def bucket_by_sequence_length(element_length_func,
batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
- def element_to_bucket_id(element):
+ def element_to_bucket_id(*args):
"""Return int64 id of the length bucket for this element."""
- seq_length = element_length_func(element)
+ seq_length = element_length_func(*args)
boundaries = list(bucket_boundaries)
buckets_min = [np.iinfo(np.int32).min] + boundaries
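The change to `element_to_bucket_id(*args)` matters because the dataset machinery splats the components of a tuple element into separate positional arguments, so a single `element` parameter breaks for datasets of (text, label) pairs like the new bucketing test above. A pure-Python sketch of the dispatch (hypothetical names, not part of this change):

    def make_bucket_id_fn(element_length_func):
      def element_to_bucket_id(*args):
        # A plain dataset yields one positional argument; a tuple dataset
        # yields one per component. Forwarding *args handles both cases.
        return element_length_func(*args)
      return element_to_bucket_id

    # Length function that ignores the label, as in the bucketing test.
    bucket_id_fn = make_bucket_id_fn(lambda text, label: len(text))
    assert bucket_id_fn([1, 2, 3], 7) == 3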
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
index 507ceb3585..68e0d9cb82 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
@@ -16,6 +16,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib import distributions
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -25,23 +27,23 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-ds = distributions
+tfd = distributions
class DistributionTest(test.TestCase):
def testParamShapesAndFromParams(self):
classes = [
- ds.Normal,
- ds.Bernoulli,
- ds.Beta,
- ds.Chi2,
- ds.Exponential,
- ds.Gamma,
- ds.InverseGamma,
- ds.Laplace,
- ds.StudentT,
- ds.Uniform,
+ tfd.Normal,
+ tfd.Bernoulli,
+ tfd.Beta,
+ tfd.Chi2,
+ tfd.Exponential,
+ tfd.Gamma,
+ tfd.InverseGamma,
+ tfd.Laplace,
+ tfd.StudentT,
+ tfd.Uniform,
]
sample_shapes = [(), (10,), (10, 20, 30)]
@@ -63,15 +65,15 @@ class DistributionTest(test.TestCase):
with self.test_session():
# Note: we cannot easily test all distributions since each requires
# different initialization arguments. We therefore spot test a few.
- normal = ds.Normal(loc=1., scale=2., validate_args=True)
+ normal = tfd.Normal(loc=1., scale=2., validate_args=True)
self.assertEqual(normal.parameters, normal.copy().parameters)
- wishart = ds.WishartFull(df=2, scale=[[1., 2], [2, 5]],
- validate_args=True)
+ wishart = tfd.WishartFull(df=2, scale=[[1., 2], [2, 5]],
+ validate_args=True)
self.assertEqual(wishart.parameters, wishart.copy().parameters)
def testCopyOverride(self):
with self.test_session():
- normal = ds.Normal(loc=1., scale=2., validate_args=True)
+ normal = tfd.Normal(loc=1., scale=2., validate_args=True)
unused_normal_copy = normal.copy(validate_args=False)
base_params = normal.parameters.copy()
copy_params = normal.copy(validate_args=False).parameters.copy()
@@ -84,19 +86,19 @@ class DistributionTest(test.TestCase):
mu = 1.
sigma = 2.
- normal = ds.Normal(mu, sigma, validate_args=True)
+ normal = tfd.Normal(mu, sigma, validate_args=True)
self.assertTrue(tensor_util.constant_value(normal.is_scalar_event()))
self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch()))
- normal = ds.Normal([mu], [sigma], validate_args=True)
+ normal = tfd.Normal([mu], [sigma], validate_args=True)
self.assertTrue(tensor_util.constant_value(normal.is_scalar_event()))
self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch()))
- mvn = ds.MultivariateNormalDiag([mu], [sigma], validate_args=True)
+ mvn = tfd.MultivariateNormalDiag([mu], [sigma], validate_args=True)
self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event()))
self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch()))
- mvn = ds.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True)
+ mvn = tfd.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True)
self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event()))
self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch()))
@@ -126,7 +128,7 @@ class DistributionTest(test.TestCase):
self.assertFalse(is_scalar.eval(feed_dict={x: [1]}))
def _GetFakeDistribution(self):
- class FakeDistribution(ds.Distribution):
+ class FakeDistribution(tfd.Distribution):
"""Fake Distribution for testing _set_sample_static_shape."""
def __init__(self, batch_shape=None, event_shape=None):
@@ -188,6 +190,105 @@ class DistributionTest(test.TestCase):
y = dist._set_sample_static_shape(x, sample_shape)
self.assertTrue(y.get_shape().ndims is None)
+ def testStrWorksCorrectlyScalar(self):
+ normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
+ self.assertEqual(
+ ("tf.distributions.Normal("
+ "\"Normal\", "
+ "batch_shape=(), "
+ "event_shape=(), "
+ "dtype=float16)"), # Got the dtype right.
+ str(normal))
+
+ chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
+ self.assertEqual(
+ ("tf.distributions.Chi2("
+ "\"silly\", " # What a silly name that is!
+ "batch_shape=(2,), "
+ "event_shape=(), "
+ "dtype=float32)"),
+ str(chi2))
+
+ exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
+ self.assertEqual(
+ ("tf.distributions.Exponential(\"Exponential\", "
+ # No batch shape.
+ "event_shape=(), "
+ "dtype=float32)"),
+ str(exp))
+
+ def testStrWorksCorrectlyMultivariate(self):
+ mvn_static = tfd.MultivariateNormalDiag(
+ loc=np.zeros([2, 2]), name="MVN")
+ self.assertEqual(
+ ("tf.distributions.MultivariateNormalDiag("
+ "\"MVN\", "
+ "batch_shape=(2,), "
+ "event_shape=(2,), "
+ "dtype=float64)"),
+ str(mvn_static))
+
+ mvn_dynamic = tfd.MultivariateNormalDiag(
+ loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32),
+ name="MVN2")
+ self.assertEqual(
+ ("tf.distributions.MultivariateNormalDiag("
+ "\"MVN2\", "
+ "batch_shape=(?,), " # Partially known.
+ "event_shape=(3,), "
+ "dtype=float32)"),
+ str(mvn_dynamic))
+
+ def testReprWorksCorrectlyScalar(self):
+ normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
+ self.assertEqual(
+ ("<tf.distributions.Normal"
+ " 'Normal'"
+ " batch_shape=()"
+ " event_shape=()"
+ " dtype=float16>"), # Got the dtype right.
+ repr(normal))
+
+ chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
+ self.assertEqual(
+ ("<tf.distributions.Chi2"
+ " 'silly'" # What a silly name that is!
+ " batch_shape=(2,)"
+ " event_shape=()"
+ " dtype=float32>"),
+ repr(chi2))
+
+ exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
+ self.assertEqual(
+ ("<tf.distributions.Exponential"
+ " 'Exponential'"
+ " batch_shape=<unknown>"
+ " event_shape=()"
+ " dtype=float32>"),
+ repr(exp))
+
+ def testReprWorksCorrectlyMultivariate(self):
+ mvn_static = tfd.MultivariateNormalDiag(
+ loc=np.zeros([2, 2]), name="MVN")
+ self.assertEqual(
+ ("<tf.distributions.MultivariateNormalDiag"
+ " 'MVN'"
+ " batch_shape=(2,)"
+ " event_shape=(2,)"
+ " dtype=float64>"),
+ repr(mvn_static))
+
+ mvn_dynamic = tfd.MultivariateNormalDiag(
+ loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32),
+ name="MVN2")
+ self.assertEqual(
+ ("<tf.distributions.MultivariateNormalDiag"
+ " 'MVN2'"
+ " batch_shape=(?,)" # Partially known.
+ " event_shape=(3,)"
+ " dtype=float32>"),
+ repr(mvn_dynamic))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py
index 2b7e199fad..b80c909023 100644
--- a/tensorflow/contrib/eager/python/examples/gan/mnist.py
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py
@@ -32,6 +32,7 @@ import tensorflow as tf
import tensorflow.contrib.eager as tfe
from tensorflow.examples.tutorials.mnist import input_data
+layers = tf.keras.layers
FLAGS = None
@@ -56,15 +57,15 @@ class Discriminator(tf.keras.Model):
else:
assert data_format == 'channels_last'
self._input_shape = [-1, 28, 28, 1]
- self.conv1 = tf.layers.Conv2D(
+ self.conv1 = layers.Conv2D(
64, 5, padding='SAME', data_format=data_format, activation=tf.tanh)
- self.pool1 = tf.layers.AveragePooling2D(2, 2, data_format=data_format)
- self.conv2 = tf.layers.Conv2D(
+ self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format)
+ self.conv2 = layers.Conv2D(
128, 5, data_format=data_format, activation=tf.tanh)
- self.pool2 = tf.layers.AveragePooling2D(2, 2, data_format=data_format)
- self.flatten = tf.layers.Flatten()
- self.fc1 = tf.layers.Dense(1024, activation=tf.tanh)
- self.fc2 = tf.layers.Dense(1, activation=None)
+ self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format)
+ self.flatten = layers.Flatten()
+ self.fc1 = layers.Dense(1024, activation=tf.tanh)
+ self.fc2 = layers.Dense(1, activation=None)
def call(self, inputs):
"""Return two logits per image estimating input authenticity.
@@ -112,16 +113,16 @@ class Generator(tf.keras.Model):
else:
assert data_format == 'channels_last'
self._pre_conv_shape = [-1, 6, 6, 128]
- self.fc1 = tf.layers.Dense(6 * 6 * 128, activation=tf.tanh)
+ self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh)
# In call(), we reshape the output of fc1 to _pre_conv_shape
# Deconvolution layer. Resulting image shape: (batch, 14, 14, 64)
- self.conv1 = tf.layers.Conv2DTranspose(
+ self.conv1 = layers.Conv2DTranspose(
64, 4, strides=2, activation=None, data_format=data_format)
# Deconvolution layer. Resulting image shape: (batch, 28, 28, 1)
- self.conv2 = tf.layers.Conv2DTranspose(
+ self.conv2 = layers.Conv2DTranspose(
1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)
def call(self, inputs):
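The move from `tf.layers.*` to the module-level `layers = tf.keras.layers` alias is what lets these `tf.keras.Model` subclasses track their layers' variables automatically once the layers are assigned as attributes. A minimal sketch of the pattern (hypothetical `TwoLayer` model, assuming only documented `tf.keras` behavior):

    import tensorflow as tf

    layers = tf.keras.layers  # same alias convention as the examples above


    class TwoLayer(tf.keras.Model):

      def __init__(self):
        super(TwoLayer, self).__init__()
        # Keras layers assigned as attributes are tracked by the model,
        # so model.variables picks them up after the first call.
        self.dense1 = layers.Dense(16, activation=tf.nn.relu)
        self.dense2 = layers.Dense(1)

      def call(self, inputs):
        return self.dense2(self.dense1(inputs))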
diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
index 6ab847cb78..4e1380afb2 100644
--- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
+++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
@@ -32,6 +32,8 @@ import tensorflow as tf
import tensorflow.contrib.eager as tfe
+layers = tf.keras.layers
+
class LinearModel(tf.keras.Model):
"""A TensorFlow linear regression model."""
@@ -39,7 +41,7 @@ class LinearModel(tf.keras.Model):
def __init__(self):
"""Constructs a LinearModel object."""
super(LinearModel, self).__init__()
- self._hidden_layer = tf.layers.Dense(1)
+ self._hidden_layer = layers.Dense(1)
def call(self, xs):
"""Invoke the linear model.
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
index 6b59413141..a28bc8a43d 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
@@ -28,6 +28,8 @@ import functools
import tensorflow as tf
+layers = tf.keras.layers
+
class _IdentityBlock(tf.keras.Model):
"""_IdentityBlock is the block that has no conv layer at shortcut.
@@ -49,23 +51,23 @@ class _IdentityBlock(tf.keras.Model):
bn_name_base = 'bn' + str(stage) + block + '_branch'
bn_axis = 1 if data_format == 'channels_first' else 3
- self.conv2a = tf.layers.Conv2D(
+ self.conv2a = layers.Conv2D(
filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format)
- self.bn2a = tf.layers.BatchNormalization(
+ self.bn2a = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '2a')
- self.conv2b = tf.layers.Conv2D(
+ self.conv2b = layers.Conv2D(
filters2,
kernel_size,
padding='same',
data_format=data_format,
name=conv_name_base + '2b')
- self.bn2b = tf.layers.BatchNormalization(
+ self.bn2b = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '2b')
- self.conv2c = tf.layers.Conv2D(
+ self.conv2c = layers.Conv2D(
filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
- self.bn2c = tf.layers.BatchNormalization(
+ self.bn2c = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '2c')
def call(self, input_tensor, training=False):
@@ -113,34 +115,34 @@ class _ConvBlock(tf.keras.Model):
bn_name_base = 'bn' + str(stage) + block + '_branch'
bn_axis = 1 if data_format == 'channels_first' else 3
- self.conv2a = tf.layers.Conv2D(
+ self.conv2a = layers.Conv2D(
filters1, (1, 1),
strides=strides,
name=conv_name_base + '2a',
data_format=data_format)
- self.bn2a = tf.layers.BatchNormalization(
+ self.bn2a = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '2a')
- self.conv2b = tf.layers.Conv2D(
+ self.conv2b = layers.Conv2D(
filters2,
kernel_size,
padding='same',
name=conv_name_base + '2b',
data_format=data_format)
- self.bn2b = tf.layers.BatchNormalization(
+ self.bn2b = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '2b')
- self.conv2c = tf.layers.Conv2D(
+ self.conv2c = layers.Conv2D(
filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
- self.bn2c = tf.layers.BatchNormalization(
+ self.bn2c = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '2c')
- self.conv_shortcut = tf.layers.Conv2D(
+ self.conv_shortcut = layers.Conv2D(
filters3, (1, 1),
strides=strides,
name=conv_name_base + '1',
data_format=data_format)
- self.bn_shortcut = tf.layers.BatchNormalization(
+ self.bn_shortcut = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '1')
def call(self, input_tensor, training=False):
@@ -219,15 +221,15 @@ class ResNet50(tf.keras.Model):
return _IdentityBlock(
3, filters, stage=stage, block=block, data_format=data_format)
- self.conv1 = tf.layers.Conv2D(
+ self.conv1 = layers.Conv2D(
64, (7, 7),
strides=(2, 2),
data_format=data_format,
padding='same',
name='conv1')
bn_axis = 1 if data_format == 'channels_first' else 3
- self.bn_conv1 = tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1')
- self.max_pool = tf.layers.MaxPooling2D(
+ self.bn_conv1 = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')
+ self.max_pool = layers.MaxPooling2D(
(3, 3), strides=(2, 2), data_format=data_format)
self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1))
@@ -250,11 +252,12 @@ class ResNet50(tf.keras.Model):
self.l5b = id_block([512, 512, 2048], stage=5, block='b')
self.l5c = id_block([512, 512, 2048], stage=5, block='c')
- self.avg_pool = tf.layers.AveragePooling2D(
+ self.avg_pool = layers.AveragePooling2D(
(7, 7), strides=(7, 7), data_format=data_format)
if self.include_top:
- self.fc1000 = tf.layers.Dense(classes, name='fc1000')
+ self.flatten = layers.Flatten()
+ self.fc1000 = layers.Dense(classes, name='fc1000')
else:
reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3]
reduction_indices = tf.constant(reduction_indices)
@@ -298,7 +301,7 @@ class ResNet50(tf.keras.Model):
x = self.avg_pool(x)
if self.include_top:
- return self.fc1000(tf.layers.flatten(x))
+ return self.fc1000(self.flatten(x))
elif self.global_pooling:
return self.global_pooling(x)
else:
diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
index 88fffc962f..492adbe1d8 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
@@ -73,6 +73,8 @@ try:
except ImportError:
HAS_MATPLOTLIB = False
+layers = tf.keras.layers
+
def parse(line):
"""Parse a line from the colors dataset."""
@@ -152,7 +154,7 @@ class RNNColorbot(tf.keras.Model):
self.cells = self._add_cells(
[tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes])
- self.relu = tf.layers.Dense(
+ self.relu = layers.Dense(
label_dimension, activation=tf.nn.relu, name="relu")
def call(self, inputs, training=False):
@@ -204,7 +206,7 @@ class RNNColorbot(tf.keras.Model):
def _add_cells(self, cells):
# "Magic" required for keras.Model classes to track all the variables in
- # a list of tf.layers.Layer objects.
+ # a list of layers.Layer objects.
# TODO(ashankar): Figure out API so user code doesn't have to do this.
for i, c in enumerate(cells):
setattr(self, "cell-%d" % i, c)
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
index 69cd16d12c..a90048d813 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
@@ -38,6 +38,8 @@ import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
from tensorflow.contrib.eager.python import tfe
+layers = tf.keras.layers
+
class RNN(tf.keras.Model):
"""A static RNN.
@@ -74,14 +76,14 @@ class RNN(tf.keras.Model):
def _add_cells(self, cells):
# "Magic" required for keras.Model classes to track all the variables in
- # a list of tf.layers.Layer objects.
+ # a list of Layer objects.
# TODO(ashankar): Figure out API so user code doesn't have to do this.
for i, c in enumerate(cells):
setattr(self, "cell-%d" % i, c)
return cells
-class Embedding(tf.layers.Layer):
+class Embedding(layers.Layer):
"""An Embedding layer."""
def __init__(self, vocab_size, embedding_dim, **kwargs):
@@ -132,7 +134,7 @@ class PTBModel(tf.keras.Model):
else:
self.rnn = RNN(hidden_dim, num_layers, self.keep_ratio)
- self.linear = tf.layers.Dense(
+ self.linear = layers.Dense(
vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1))
self._output_shape = [-1, embedding_dim]
diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
index 3f9a7818a5..6673653418 100644
--- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
+++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
@@ -33,6 +33,7 @@ import tensorflow as tf
import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.spinn import data
from third_party.examples.eager.spinn import spinn
+from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
@@ -172,7 +173,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
right_in.append(tf.random_normal((1, size * 2)))
tracking.append(tf.random_normal((1, tracker_size * 2)))
- out = reducer(left_in, right_in, tracking=tracking)
+ out = reducer(left_in, right_in=right_in, tracking=tracking)
self.assertEqual(batch_size, len(out))
self.assertEqual(tf.float32, out[0].dtype)
self.assertEqual((1, size * 2), out[0].shape)
@@ -226,7 +227,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
self.assertEqual((batch_size, size * 2), stacks[0][0].shape)
for _ in range(2):
- out1, out2 = tracker(bufs, stacks)
+ out1, out2 = tracker(bufs, stacks=stacks)
self.assertIsNone(out2)
self.assertEqual(batch_size, len(out1))
self.assertEqual(tf.float32, out1[0].dtype)
@@ -259,7 +260,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
self.assertEqual(tf.int64, transitions.dtype)
self.assertEqual((num_transitions, 1), transitions.shape)
- out = s(buffers, transitions, training=True)
+ out = s(buffers, transitions=transitions, training=True)
self.assertEqual(tf.float32, out.dtype)
self.assertEqual((1, embedding_dims), out.shape)
@@ -285,12 +286,15 @@ class SpinnTest(test_util.TensorFlowTestCase):
vocab_size)
# Invoke model under non-training mode.
- logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
+ logits = model(
+ prem, premise_transition=prem_trans, hypothesis=hypo,
+ hypothesis_transition=hypo_trans, training=False)
self.assertEqual(tf.float32, logits.dtype)
self.assertEqual((batch_size, d_out), logits.shape)
# Invoke model under training model.
- logits = model(prem, prem_trans, hypo, hypo_trans, training=True)
+ logits = model(prem, premise_transition=prem_trans, hypothesis=hypo,
+ hypothesis_transition=hypo_trans, training=True)
self.assertEqual(tf.float32, logits.dtype)
self.assertEqual((batch_size, d_out), logits.shape)
@@ -420,8 +424,14 @@ class SpinnTest(test_util.TensorFlowTestCase):
# 5. Verify that checkpoints exist and contains all the expected variables.
self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
- ckpt_variable_names = [
- item[0] for item in checkpoint_utils.list_variables(config.logdir)]
+ object_graph_string = checkpoint_utils.load_variable(
+ config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH")
+ object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
+ object_graph.ParseFromString(object_graph_string)
+ ckpt_variable_names = set()
+ for node in object_graph.nodes:
+ for attribute in node.attributes:
+ ckpt_variable_names.add(attribute.full_name)
self.assertIn("global_step", ckpt_variable_names)
for v in trainer.variables:
variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 24374266dc..c846343d6d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -358,7 +358,7 @@ cuda_py_test(
size = "medium",
srcs = ["python/estimator/replicate_model_fn_test.py"],
additional_deps = [
- "//third_party/py/absl/testing:parameterized",
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:dnn",
"//tensorflow/python/estimator:export_export",
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index 42e1b7b68c..74da2cbb3f 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -304,7 +304,7 @@ def multi_label_head(n_classes,
weight_column=None,
thresholds=None,
label_vocabulary=None,
- loss_reduction=losses.Reduction.SUM,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
loss_fn=None,
name=None):
"""Creates a `_Head` for multi-label classification.
@@ -355,7 +355,8 @@ def multi_label_head(n_classes,
string type and have any value in `label_vocabulary`. Also there will be
errors if vocabulary is not provided and labels are string.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
- reduce training loss over batch. Defaults to `SUM`.
+ reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely
+ the weighted sum of losses divided by batch size. See `tf.losses.Reduction`.
loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -404,7 +405,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
weight_column=None,
thresholds=None,
label_vocabulary=None,
- loss_reduction=losses.Reduction.SUM,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
loss_fn=None,
name=None):
self._n_classes = n_classes
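Every expected-loss update in the tests below follows from this default change: `SUM` adds the class-averaged per-example losses, while `SUM_OVER_BATCH_SIZE` divides that sum by the batch size. A numpy sketch of the arithmetic on the `[[10., 10.], [15., 0.]]` example the tests reuse (a check of the numbers only, not of the head internals):

    import numpy as np

    # Unweighted per-class losses for a batch of 2 examples, 2 classes.
    unweighted = np.array([[10., 10.], [15., 0.]], dtype=np.float32)
    per_example = unweighted.mean(axis=1)  # average over classes: [10., 7.5]

    loss_sum = per_example.sum()       # 17.5, the old SUM default
    loss_avg = per_example.sum() / 2.  # 8.75, SUM_OVER_BATCH_SIZE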
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index 776f0ee341..8837dfdc6c 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -272,9 +272,9 @@ class MultiLabelHead(test.TestCase):
logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
- # loss = labels * -log(sigmoid(logits)) +
- # (1 - labels) * -log(1 - sigmoid(logits))
- expected_training_loss = np.sum(
+ # loss = (labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))) / 2
+ expected_training_loss = 0.5 * np.sum(
_sigmoid_cross_entropy(labels=labels, logits=logits))
actual_training_loss = head.create_loss(
features={'x': np.array(((42,),), dtype=np.int32)},
@@ -298,7 +298,7 @@ class MultiLabelHead(test.TestCase):
# For large logits, this is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits
- expected_training_loss = np.sum(
+ expected_training_loss = 0.5 * np.sum(
np.array([[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32))
actual_training_loss = head.create_loss(
features={'x': np.array(((42,),), dtype=np.int32)},
@@ -361,7 +361,7 @@ class MultiLabelHead(test.TestCase):
labels=labels_input)[0]
with self.test_session():
_initialize_variables(self, monitored_session.Scaffold())
- self.assertAllClose(np.sum(loss), actual_training_loss.eval())
+ self.assertAllClose(np.sum(loss) / 2., actual_training_loss.eval())
def test_eval_create_loss_loss_fn_wrong_shape(self):
"""Tests custom loss_fn that returns Tensor of unexpected shape."""
@@ -438,12 +438,13 @@ class MultiLabelHead(test.TestCase):
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
- # Sum over examples.
- expected_loss = np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits))
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels, logits=logits))
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
- keys.LOSS_MEAN: expected_loss / 2,
+ keys.LOSS_MEAN: expected_loss,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
@@ -468,14 +469,13 @@ class MultiLabelHead(test.TestCase):
labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
- # Sum over examples.
- expected_loss = (
- np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
- )
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
- keys.LOSS_MEAN: expected_loss / 2,
+ keys.LOSS_MEAN: expected_loss,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
@@ -533,14 +533,13 @@ class MultiLabelHead(test.TestCase):
labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
- # Sum over examples.
- expected_loss = (
- np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
- )
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
- keys.LOSS_MEAN: expected_loss / 2,
+ keys.LOSS_MEAN: expected_loss,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
@@ -562,15 +561,14 @@ class MultiLabelHead(test.TestCase):
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
- # Sum over examples.
- expected_loss = (
- np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits))
- )
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels, logits=logits))
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
- keys.LOSS_MEAN: expected_loss / 2,
+ keys.LOSS_MEAN: expected_loss,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
@@ -603,8 +601,9 @@ class MultiLabelHead(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, weighted sum over examples.
- expected_loss = 25.
+ # Average over classes, weighted sum over examples, divide by batch_size.
+ # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2 ) / 2
+ expected_loss = 12.5
spec = head.create_estimator_spec(
features={
@@ -617,8 +616,8 @@ class MultiLabelHead(test.TestCase):
keys = metric_keys.MetricKeys
expected_metrics = {
- # Average loss over weighted examples.
- keys.LOSS_MEAN: expected_loss / 3,
+ # Average loss over weighted examples (denominator is sum(weights)).
+ keys.LOSS_MEAN: expected_loss * (2. / 3.),
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.2000,
@@ -663,7 +662,7 @@ class MultiLabelHead(test.TestCase):
# (1 - labels) * (logits > 0) * logits
expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]]
expected_weights = [[1.], [2.]]
- expected_training_loss = 1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.
+ expected_training_loss = (1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.) / 2.
training_loss, unreduced_loss, actual_weights, _ = head.create_loss(
features={
'x': np.array(((42,),), dtype=np.int32),
@@ -809,11 +808,8 @@ class MultiLabelHead(test.TestCase):
self.assertEqual(
six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
train_result)
- _assert_simple_summaries(self, {
- metric_keys.MetricKeys.LOSS: expected_loss,
- # Average loss over examples.
- metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
- }, summary_str, tol)
+ _assert_simple_summaries(
+ self, {metric_keys.MetricKeys.LOSS: expected_loss}, summary_str, tol)
def test_train(self):
head = head_lib.multi_label_head(n_classes=2)
@@ -823,8 +819,9 @@ class MultiLabelHead(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, sum over weights.
- expected_loss = 17.5
+ # Average over classes, sum over examples, divide by batch_size.
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2
+ expected_loss = 8.75
self._test_train(
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
@@ -840,8 +837,9 @@ class MultiLabelHead(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, sum over weights.
- expected_loss = 17.5
+ # Average over classes, sum over examples, divide by batch_size.
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2
+ expected_loss = 8.75
self._test_train(
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
@@ -858,8 +856,9 @@ class MultiLabelHead(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, sum over weights.
- expected_loss = 17.5
+ # Average over classes, sum over examples, divide by batch_size.
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2
+ expected_loss = 8.75
self._test_train(
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
@@ -871,8 +870,9 @@ class MultiLabelHead(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, sum over weights.
- expected_loss = 17.5
+ # Average over classes, sum over examples, divide by batch_size.
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2
+ expected_loss = 8.75
expected_train_result = 'my_train_op'
class _Optimizer(object):
@@ -952,8 +952,9 @@ class MultiLabelHead(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, weighted sum over examples.
- expected_loss = 25.
+ # Average over classes, weighted sum over examples, divide by batch_size.
+ # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2 ) / 2
+ expected_loss = 12.5
expected_train_result = 'my_train_op'
def _train_op_fn(loss):
return string_ops.string_join(
@@ -987,11 +988,8 @@ class MultiLabelHead(test.TestCase):
self.assertEqual(
six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
train_result)
- _assert_simple_summaries(self, {
- metric_keys.MetricKeys.LOSS: expected_loss,
- # Average loss over weighted examples.
- metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3,
- }, summary_str, tol)
+ _assert_simple_summaries(
+ self, {metric_keys.MetricKeys.LOSS: expected_loss}, summary_str, tol)
def test_multi_dim_weighted_train_create_loss(self):
"""Logits and labels of shape [2, 2, 3], weights [2, 2]."""
@@ -1008,8 +1006,8 @@ class MultiLabelHead(test.TestCase):
expected_unreduced_loss = [[[20./3.], [10./3.]], [[4.], [8.]]]
# weights are reshaped to [2, 2, 1] to match logits.
expected_weights = [[[1.], [1.5]], [[2.], [2.5]]]
- # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
- expected_training_loss = 39.6667
+ # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167
+ expected_training_loss = 9.9167
training_loss, unreduced_loss, actual_weights, _ = head.create_loss(
features={'weights': weights},
mode=model_fn.ModeKeys.TRAIN,
@@ -1035,8 +1033,8 @@ class MultiLabelHead(test.TestCase):
weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
# loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
# = [[20/3, 10/3], [4, 8]]
- # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
- expected_loss = 39.6667
+ # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167
+ expected_loss = 9.9167
expected_train_result = 'my_train_op'
def _train_op_fn(loss):
return string_ops.string_join(
@@ -1124,11 +1122,11 @@ class MultiLabelHead(test.TestCase):
weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
# loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
# = [[20/3, 10/3], [4, 8]]
- # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
- expected_loss = 39.6667
+ # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167
+ expected_loss = 9.9167
keys = metric_keys.MetricKeys
expected_metrics = {
- keys.LOSS_MEAN: expected_loss / np.sum(weights),
+ keys.LOSS_MEAN: expected_loss * (4. / np.sum(weights)),
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.4977,
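The `39.6667 -> 9.9167` change above is the same reduction applied to a weighted [2, 2] batch: the weighted sum is now divided by the number of weighted elements (4), while the LOSS_MEAN metric keeps sum(weights) as its denominator. A numpy check of those numbers (arithmetic only, independent of the head implementation):

    import numpy as np

    # Class-averaged losses and weights from the multi-dim weighted test.
    unreduced = np.array([[20. / 3., 10. / 3.], [4., 8.]])
    weights = np.array([[1., 1.5], [2., 2.5]])

    weighted_sum = (weights * unreduced).sum()     # 39.6667, old SUM loss
    training_loss = weighted_sum / unreduced.size  # 9.9167, new default
    loss_mean = weighted_sum / weights.sum()       # LOSS_MEAN stays sum/weights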
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
index 43cc157a1f..74d3d6d728 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
@@ -299,10 +299,11 @@ class MultiHeadTest(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75
# head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]
- # Average over classes, weighted sum over batch and heads.
- expected_loss_head1 = 17.5
- expected_loss_head2 = 30.0
+ # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15
+ expected_loss_head1 = 8.75
+ expected_loss_head2 = 15.
expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2
spec = multi_head.create_estimator_spec(
@@ -316,8 +317,8 @@ class MultiHeadTest(test.TestCase):
keys.LOSS + '/head1': expected_loss_head1,
keys.LOSS + '/head2': expected_loss_head2,
# Average loss over examples.
- keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2,
- keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2,
+ keys.LOSS_MEAN + '/head1': expected_loss_head1,
+ keys.LOSS_MEAN + '/head2': expected_loss_head2,
# auc and auc_pr cannot be reliably calculated for only 4-6 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC + '/head1': 0.1667,
@@ -363,8 +364,8 @@ class MultiHeadTest(test.TestCase):
tol = 1e-3
with self.test_session():
      # Unreduced loss of the head is [[(10 + 10) / 2], [(15 + 0) / 2]]
- # (averaged over classes, sum-reduced over examples).
- self.assertAllClose(17.5, loss.eval(), rtol=tol, atol=tol)
+ # (averaged over classes, averaged over examples).
+ self.assertAllClose(8.75, loss.eval(), rtol=tol, atol=tol)
def test_train_create_loss_two_heads_with_weights(self):
# Use different example weighting for each head weighting.
@@ -399,18 +400,18 @@ class MultiHeadTest(test.TestCase):
with self.test_session():
# loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]
# = [10, 7.5]
- # training_loss = 1 * 10 + 2 * 7.5 = 25
+ # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5
# head-weighted unreduced_loss = 1 * [10, 7.5]
self.assertAllClose(
[[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol)
# loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]]
# = [20, 10]
- # training_loss = 2 * 20 + 3 * 10 = 70
+ # training_loss = (2 * 20 + 3 * 10) / 2 = 35
# head-weighted unreduced_loss = 2 * [20, 10]
self.assertAllClose(
[[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol)
- # head-weighted training_loss = 1 * 25 + 2 * 70 = 165
- self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol)
+ # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5
+ self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol)
# head-weighted example weights
self.assertAllClose(
[[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol)
@@ -447,18 +448,18 @@ class MultiHeadTest(test.TestCase):
with self.test_session():
# loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]
# = [10, 7.5]
- # training_loss = 1 * 10 + 2 * 7.5 = 25
+ # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5
# head-weighted unreduced_loss = 1 * [10, 7.5]
self.assertAllClose(
[[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol)
# loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]]
# = [20, 10]
- # training_loss = 2 * 20 + 3 * 10 = 70
+ # training_loss = (2 * 20 + 3 * 10) / 2 = 35
# head-weighted unreduced_loss = 2 * [20, 10]
self.assertAllClose(
[[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol)
- # head-weighted training_loss = 1 * 25 + 2 * 70 = 165
- self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol)
+ # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5
+ self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol)
# head-weighted example weights
self.assertAllClose(
[[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol)
@@ -511,8 +512,8 @@ class MultiHeadTest(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, sum over weights.
- expected_loss = 17.5
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75
+ expected_loss = 8.75
expected_train_result = 'my_train_op'
def _train_op_fn(loss):
return string_ops.string_join(
@@ -546,8 +547,6 @@ class MultiHeadTest(test.TestCase):
_assert_simple_summaries(self, {
metric_keys.MetricKeys.LOSS: expected_loss,
metric_keys.MetricKeys.LOSS + '/head1': expected_loss,
- # Average loss over examples.
- metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2,
}, summary_str, tol)
def test_train_one_head_with_optimizer(self):
@@ -560,8 +559,8 @@ class MultiHeadTest(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, sum over weights.
- expected_loss = 17.5
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75
+ expected_loss = 8.75
expected_train_result = 'my_train_op'
class _Optimizer(object):
@@ -607,10 +606,12 @@ class MultiHeadTest(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75
# head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]
+ # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15
# Average over classes and examples, weighted sum over heads.
- expected_loss_head1 = 17.5
- expected_loss_head2 = 30.0
+ expected_loss_head1 = 8.75
+ expected_loss_head2 = 15.0
expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2
expected_train_result = 'my_train_op'
def _train_op_fn(loss):
@@ -646,9 +647,6 @@ class MultiHeadTest(test.TestCase):
metric_keys.MetricKeys.LOSS: expected_loss,
metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1,
metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2,
- # Average loss over examples.
- metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss_head1 / 2,
- metric_keys.MetricKeys.LOSS_MEAN + '/head2': expected_loss_head2 / 2,
}, summary_str, tol)
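
All of the updated assertions above stem from one behavioral change: each head's training loss is now averaged over examples rather than summed over them (the reduction name is inferred here from the new expected values, not stated in the diff). A minimal sketch of the arithmetic, using the per-example losses from the test comments:

```python
# Per-example, per-head losses taken from the test comments above.
head1 = [10.0, 7.5]    # [(10 + 10) / 2, (15 + 0) / 2]
head2 = [20.0, 10.0]   # [(20 + 20 + 20) / 3, (30 + 0 + 0) / 3]

# Old reduction: sum over examples.
old_loss_head1 = sum(head1)                # 17.5

# New reduction: mean over examples.
new_loss_head1 = sum(head1) / len(head1)   # 8.75
new_loss_head2 = sum(head2) / len(head2)   # 15.0

# Head-weighted total, matching the updated expected_loss in the test.
total = 1.0 * new_loss_head1 + 2.0 * new_loss_head2
assert total == 38.75
```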
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
index 30c5404e03..f22dbcf215 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
@@ -23,7 +23,6 @@ import numpy as np
from tensorflow.contrib.kfac.python.ops import estimator
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -40,30 +39,6 @@ from tensorflow.python.training import training_util
_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"]
-class DeviceContextGeneratorTest(test.TestCase):
-
- def testNoDevice(self):
- device_context_generator = estimator._DeviceContextGenerator(None)
- with ops.device("/device:CPU:0"): # This is what will be used
- with device_context_generator(): # Does nothing
- a = constant_op.constant([2.0], name="a")
- self.assertEqual("/device:CPU:0", a.op.device)
-
- def testTwoDevices(self):
- device_context_generator = estimator._DeviceContextGenerator(
- ["/device:GPU:0", "/device:GPU:1"])
- with ops.device("/device:CPU:0"): # Will be over-ridden by the inner scopes
- with device_context_generator():
- a = constant_op.constant([2.0], name="a")
- with device_context_generator():
- b = constant_op.constant([2.0], name="b")
- with device_context_generator():
- c = constant_op.constant([2.0], name="c")
- self.assertEqual("/device:GPU:0", a.op.device)
- self.assertEqual("/device:GPU:1", b.op.device)
- self.assertEqual("/device:GPU:0", c.op.device)
-
-
class EstimatorTest(test.TestCase):
def setUp(self):
@@ -90,68 +65,98 @@ class EstimatorTest(test.TestCase):
def testEstimatorInitManualRegistration(self):
with self._graph.as_default():
# We should be able to build an estimator for only the registered vars.
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection)
+ estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection
+ )
# Check that we throw an error if we try to build an estimator for vars
# that were not manually registered.
with self.assertRaises(ValueError):
- est = estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
- self.layer_collection)
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights, self.bias],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection
+ )
est.make_ops_and_vars()
# Check that we throw an error if we don't include registered variables,
# i.e. self.weights
with self.assertRaises(ValueError):
- est = estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection)
est.make_ops_and_vars()
@test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
def testVariableWrongNumberOfUses(self, mock_uses):
with self.assertRaises(ValueError):
- est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection)
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection)
est.make_ops_and_vars()
def testInvalidEstimationMode(self):
with self.assertRaises(ValueError):
- est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="not_a_real_mode")
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="not_a_real_mode")
est.make_ops_and_vars()
def testGradientsModeBuild(self):
with self._graph.as_default():
- est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="gradients")
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="gradients")
est.make_ops_and_vars()
def testEmpiricalModeBuild(self):
with self._graph.as_default():
- est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="empirical")
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="empirical")
est.make_ops_and_vars()
def testCurvaturePropModeBuild(self):
with self._graph.as_default():
- est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="curvature_prop")
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="curvature_prop")
est.make_ops_and_vars()
def testExactModeBuild(self):
with self._graph.as_default():
- est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
- self.layer_collection,
- estimation_mode="exact")
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="exact")
est.make_ops_and_vars()
def test_cov_update_thunks(self):
"""Ensures covariance update ops run once per global_step."""
with self._graph.as_default(), self.test_session() as sess:
- fisher_estimator = estimator.FisherEstimator(
+ fisher_estimator = estimator.FisherEstimatorRoundRobin(
variables=[self.weights],
layer_collection=self.layer_collection,
damping=0.2,
@@ -159,8 +164,8 @@ class EstimatorTest(test.TestCase):
# Construct an op that executes one covariance update per step.
global_step = training_util.get_or_create_global_step()
- (cov_variable_thunks, cov_update_op_thunks,
- _, _) = fisher_estimator.create_ops_and_vars_thunks()
+ (cov_variable_thunks, cov_update_op_thunks, _,
+ _) = fisher_estimator.create_ops_and_vars_thunks()
for thunk in cov_variable_thunks:
thunk()
cov_matrices = [
@@ -198,10 +203,43 @@ class EstimatorTest(test.TestCase):
sess.run(cov_update_op)
sess.run(increment_global_step)
+ def test_round_robin_placement(self):
+ """Check if the ops and variables are placed on devices correctly."""
+ with self._graph.as_default():
+ fisher_estimator = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ layer_collection=self.layer_collection,
+ damping=0.2,
+ cov_ema_decay=0.0,
+ cov_devices=["/cpu:{}".format(i) for i in range(2)],
+ inv_devices=["/cpu:{}".format(i) for i in range(2)])
+
+    # Construct the ops and variables and check their device placement.
+ (cov_update_ops, _, inv_update_ops, _, _,
+ _) = fisher_estimator.make_ops_and_vars(scope="test")
+ self.assertEqual(cov_update_ops[0].device, "/device:CPU:0")
+ self.assertEqual(cov_update_ops[1].device, "/device:CPU:1")
+ self.assertEqual(inv_update_ops[0].device, "/device:CPU:0")
+ self.assertEqual(inv_update_ops[1].device, "/device:CPU:1")
+ cov_matrices = [
+ fisher_factor.get_cov()
+ for fisher_factor in self.layer_collection.get_factors()
+ ]
+ inv_matrices = [
+ matrix
+ for fisher_factor in self.layer_collection.get_factors()
+ for matrix in fisher_factor._matpower_by_exp_and_damping.values()
+ ]
+ self.assertEqual(cov_matrices[0].device, "/device:CPU:0")
+ self.assertEqual(cov_matrices[1].device, "/device:CPU:1")
+ # Inverse matrices need to be explicitly placed.
+ self.assertEqual(inv_matrices[0].device, "")
+ self.assertEqual(inv_matrices[1].device, "")
+
def test_inv_update_thunks(self):
"""Ensures inverse update ops run once per global_step."""
with self._graph.as_default(), self.test_session() as sess:
- fisher_estimator = estimator.FisherEstimator(
+ fisher_estimator = estimator.FisherEstimatorRoundRobin(
variables=[self.weights],
layer_collection=self.layer_collection,
damping=0.2,
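
The device expectations in `test_round_robin_placement` follow directly from cycling: factor k lands on device k mod len(devices). A one-line sketch of the expected mapping, using the device list from the test:

```python
devices = ["/cpu:{}".format(i) for i in range(2)]
expected = [devices[k % len(devices)] for k in range(4)]
print(expected)  # ['/cpu:0', '/cpu:1', '/cpu:0', '/cpu:1']
```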
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index c26230c2a8..d721ad08af 100644
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -171,6 +171,7 @@ py_library(
name = "fisher_estimator",
srcs = [
"estimator.py",
+ "placement.py",
],
srcs_version = "PY2AND3",
deps = [
@@ -180,6 +181,7 @@ py_library(
"//tensorflow/python:gradients",
"//tensorflow/python:util",
"//third_party/py/numpy",
+ "@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index 64755be65c..ced1110676 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import contextlib
-import itertools
-
+import abc
import numpy as np
+import six
+from tensorflow.contrib.kfac.python.ops import placement
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import control_flow_ops
@@ -31,63 +31,46 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
-class _DeviceContextGenerator(object):
- """Class for generating device contexts in a round-robin fashion."""
-
- def __init__(self, devices):
- """Creates a _DeviceContextGenerator object.
-
- Example usage:
+# The linter is confused.
+# pylint: disable=abstract-class-instantiated
+def make_fisher_estimator(placement_strategy=None, **kwargs):
+ """Creates Fisher estimator instances based on the placement strategy.
- ```python
- dcg = _DeviceContextGenerator(['/gpu:0', 'gpu:1'])
- with dcg():
- # All operations in this context will be placed on GPU 0
- ...
- with dcg():
- # All operations in this context will be placed on GPU 1
- ...
- ```
-
- Args:
- devices: An iterable of device strings (or None). Successive calls to
- __call__ will give contexts which place devices on these devices in
- a round-robin fashion.
- """
- self._cycle = None if devices is None else itertools.cycle(devices)
+  For example, if the `placement_strategy` is 'round_robin', then a
+  `FisherEstimatorRoundRobin` instance is returned.
- @contextlib.contextmanager
- def __call__(self):
- """Returns a context manager specifying the default device."""
- if self._cycle is None:
- yield
- else:
- with tf_ops.device(next(self._cycle)):
- yield
+ Args:
+    placement_strategy: `string`. Strategy to be used for placing covariance
+      variables, covariance ops, and inverse ops. See
+      `placement.RoundRobinPlacementMixin` for a concrete example.
+    **kwargs: Arguments to be passed into the `FisherEstimator` initializer.
+ Returns:
+    An instance of a class that inherits from `FisherEstimator` and the mixin
+    that implements the specific placement strategy. See, for example,
+    `FisherEstimatorRoundRobin`, which inherits from `FisherEstimator` and
+    `RoundRobinPlacementMixin`.
-def _make_thunk_on_device(func, device):
- def thunk():
- with tf_ops.device(device):
- return func()
- return thunk
+ Raises:
+    ValueError: If `placement_strategy` is neither `None` nor 'round_robin'.
+ """
+ if placement_strategy in [None, "round_robin"]:
+ return FisherEstimatorRoundRobin(**kwargs)
+ else:
+    raise ValueError("Unimplemented vars and ops placement strategy: %s"
+                     % placement_strategy)
+# pylint: enable=abstract-class-instantiated
+@six.add_metaclass(abc.ABCMeta)
class FisherEstimator(object):
"""Fisher estimator class supporting various approximations of the Fisher.
- Attributes:
- cov_update_thunks: list of no-arg functions. Executing a function adds
- covariance update ops for a single FisherFactor to the graph.
- cov_update_ops: List of Ops. Running an op updates covariance matrices for a
- single FisherFactor.
- cov_update_op: Op. Running updates covariance matrices for all
- FisherFactors.
- inv_update_thunks: list of no-arg functions. Executing a function adds
- inverse update ops for a single FisherFactor to the graph.
- inv_update_ops: List of Ops. Running an op updates inverse matrices for a
- single FisherFactor.
- inv_update_op: Op. Running updates inverse matrices for all FisherFactors.
+ This is an abstract base class which does not implement a strategy for
+ placing covariance variables, covariance update ops and inverse update ops.
+ The placement strategies are implemented in `placement.py`. See
+  `FisherEstimatorRoundRobin` for an example of a concrete subclass with
+  a round-robin placement strategy.
"""
def __init__(self,
@@ -184,6 +167,77 @@ class FisherEstimator(object):
def name(self):
return self._name
+ @abc.abstractmethod
+ def make_ops_and_vars(self, scope=None):
+ """Make ops and vars with a specific placement strategy.
+
+ For each factor, all of that factor's cov variables and their associated
+    update ops will be placed on a particular device. For example, in the case
+    of round-robin placement, a new device is chosen for each factor by
+    cycling through the list of devices in the cov_devices argument. If
+    cov_devices is None then no explicit device placement occurs.
+
+ An analogous strategy is followed for inverse update ops, with the list of
+ devices being given by the inv_devices argument.
+
+    Inverse variables, on the other hand, are not placed on any specific device
+    (they will just use the current device placement context, whatever
+    that happens to be). The idea is that the inverse variables belong where
+    they will be accessed most often, which is the device that actually applies
+    the preconditioner to the gradient. The user is responsible for setting
+    the device context for this.
+
+ Args:
+ scope: A string or None. If None it will be set to the name of this
+ estimator (given by the name property). All variables will be created,
+ and all ops will execute, inside of a variable scope of the given
+ name. (Default: None)
+
+ Returns:
+ cov_update_ops: List of ops that compute the cov updates. Corresponds
+ one-to-one with the list of factors given by the "factors" property.
+ cov_update_op: cov_update_ops grouped into a single op.
+ inv_update_ops: List of ops that compute the inv updates. Corresponds
+ one-to-one with the list of factors given by the "factors" property.
+ inv_update_op: inv_update_ops grouped into a single op.
+ cov_update_thunks: Thunks that make the ops in cov_update_ops.
+ inv_update_thunks: Thunks that make the ops in inv_update_ops.
+ """
+ pass
+
+ @abc.abstractmethod
+ def make_vars_and_create_op_thunks(self, scope=None):
+ """Make vars and create op thunks with a specific placement strategy.
+
+ For each factor, all of that factor's cov variables and their associated
+ update ops will be placed on a particular device. A new device is chosen
+    for each factor by cycling through the list of devices in the cov_devices
+ argument. If cov_devices is None then no explicit device placement occurs.
+
+ An analogous strategy is followed for inverse update ops, with the list of
+ devices being given by the inv_devices argument.
+
+    Inverse variables, on the other hand, are not placed on any specific device
+    (they will just use the current device placement context, whatever
+    that happens to be). The idea is that the inverse variables belong where
+    they will be accessed most often, which is the device that actually applies
+    the preconditioner to the gradient. The user is responsible for setting
+    the device context for this.
+
+ Args:
+ scope: A string or None. If None it will be set to the name of this
+ estimator (given by the name property). All variables will be created,
+ and all thunks will execute, inside of a variable scope of the given
+ name. (Default: None)
+
+ Returns:
+ cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ """
+ pass
+
def _apply_transformation(self, vecs_and_vars, transform):
"""Applies an block-wise transformation to the corresponding vectors.
@@ -286,158 +340,6 @@ class FisherEstimator(object):
self._instantiate_factors()
self._register_matrix_functions()
- def make_ops_and_vars(self, scope=None):
- """Make ops and vars with no specific device placement.
-
- See make_ops_and_vars_round_robin for further details.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All variables will be created,
- and all ops will execute, inside of a variable scope of the given
- name. (Default: None)
- Returns:
- cov_update_ops: List of ops that compute the cov updates. Corresponds
- one-to-one with the list of factors given by the "factors" property.
- cov_update_op: cov_update_ops grouped into a single op.
- inv_update_ops: List of ops that compute the inv updates. Corresponds
- one-to-one with the list of factors given by the "factors" property.
- inv_update_op: inv_update_ops grouped into a single op.
- cov_update_thunks: Thunks that make the ops in cov_update_ops.
- inv_update_thunks: Thunks that make the ops in inv_update_ops.
- """
- return self.make_ops_and_vars_round_robin(scope=scope)
-
- # TODO(b/70674513): Factor device placement outside of this class.
- def make_ops_and_vars_round_robin(self, scope=None, cov_devices=None,
- inv_devices=None):
- """Make ops and vars with a round-robin device placement strategy.
-
- For each factor, all of that factor's cov variables and their associated
- update ops will be placed on a particular device. A new device is chosen
- for each factor by cycling through list of devices in the cov_devices
- argument. If cov_devices is None then no explicit device placement occurs.
-
- An analogous strategy is followed for inverse update ops, with the list of
- devices being given by the inv_devices argument.
-
- Inverse variables on the other hand are not placed on any specific device
- (they will just use the current the device placement context, whatever
- that happens to be). The idea is that the inverse variable belong where
- they will be accessed most often, which is the device that actually applies
- the preconditioner to the gradient. The user will be responsible for setting
- the device context for this.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All variables will be created,
- and all ops will execute, inside of a variable scope of the given
- name. (Default: None)
- cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
- inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
-
- Returns:
- cov_update_ops: List of ops that compute the cov updates. Corresponds
- one-to-one with the list of factors given by the "factors" property.
- cov_update_op: cov_update_ops grouped into a single op.
- inv_update_ops: List of ops that compute the inv updates. Corresponds
- one-to-one with the list of factors given by the "factors" property.
- inv_update_op: inv_update_ops grouped into a single op.
- cov_update_thunks: Thunks that make the ops in cov_update_ops.
- inv_update_thunks: Thunks that make the ops in inv_update_ops.
- """
- (cov_update_thunks,
- inv_update_thunks) = self.make_vars_and_create_op_thunks_round_robin(
- scope=scope,
- cov_devices=cov_devices,
- inv_devices=inv_devices)
- cov_update_ops = [thunk() for thunk in cov_update_thunks]
- inv_update_ops = [thunk() for thunk in inv_update_thunks]
-
- scope = self.name if scope is None else scope
- with variable_scope.variable_scope(scope):
- cov_update_op = control_flow_ops.group(cov_update_ops,
- name="cov_update_op")
- inv_update_op = control_flow_ops.group(inv_update_ops,
- name="inv_update_op")
-
- return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op,
- cov_update_thunks, inv_update_thunks)
-
- def make_vars_and_create_op_thunks_round_robin(self,
- scope=None,
- cov_devices=None,
- inv_devices=None):
- """Make vars and create op thunks w/ a round-robin device placement strat.
-
- For each factor, all of that factor's cov variables and their associated
- update ops will be placed on a particular device. A new device is chosen
- for each factor by cycling through list of devices in the cov_devices
- argument. If cov_devices is None then no explicit device placement occurs.
-
- An analogous strategy is followed for inverse update ops, with the list of
- devices being given by the inv_devices argument.
-
- Inverse variables on the other hand are not placed on any specific device
- (they will just use the current the device placement context, whatever
- that happens to be). The idea is that the inverse variable belong where
- they will be accessed most often, which is the device that actually applies
- the preconditioner to the gradient. The user will be responsible for setting
- the device context for this.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All variables will be created,
- and all thunks will execute, inside of a variable scope of the given
- name. (Default: None)
- cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
- inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
- Returns:
- cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- """
-
- (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,
- inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope)
-
- if cov_devices:
- cov_update_thunks = []
- for cov_variable_thunk, cov_update_thunk, device in zip(
- cov_variable_thunks_raw, cov_update_thunks_raw,
- itertools.cycle(cov_devices)):
- with tf_ops.device(device):
- cov_variable_thunk()
- cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk,
- device))
- else:
- for cov_variable_thunk in cov_variable_thunks_raw:
- cov_variable_thunk()
- cov_update_thunks = cov_update_thunks_raw
-
- for inv_variable_thunk in inv_variable_thunks_raw:
- inv_variable_thunk()
-
- if inv_devices:
- inv_update_thunks = []
- for inv_update_thunk, device in zip(inv_update_thunks_raw,
- itertools.cycle(inv_devices)):
- inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk,
- device))
- else:
- inv_update_thunks = inv_update_thunks_raw
-
- return cov_update_thunks, inv_update_thunks
-
def create_ops_and_vars_thunks(self, scope=None):
"""Create thunks that make the ops and vars on demand.
@@ -582,3 +484,9 @@ class FisherEstimator(object):
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
return zip(*grads_all)
+
+
+class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin,
+ FisherEstimator):
+ """Fisher estimator which provides round robin device placement strategy."""
+ pass
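
The refactor replaces strategy-specific methods with a factory plus mixin composition: `make_fisher_estimator` picks a concrete subclass, and `FisherEstimatorRoundRobin` inherits its placement behavior from `RoundRobinPlacementMixin`, which must come first in the base list so its `__init__` can consume the placement kwargs before delegating. A stripped-down sketch of the same pattern (class and argument names simplified for illustration, not the actual KFAC API):

```python
import abc

import six


@six.add_metaclass(abc.ABCMeta)
class Estimator(object):
  """Abstract base: declares the placement-dependent API."""

  def __init__(self, damping):
    self._damping = damping

  @abc.abstractmethod
  def make_ops_and_vars(self, scope=None):
    pass


class RoundRobinMixin(object):
  """Mixin supplying the strategy; consumes its kwargs, forwards the rest."""

  def __init__(self, cov_devices=None, *args, **kwargs):
    super(RoundRobinMixin, self).__init__(*args, **kwargs)
    self._cov_devices = cov_devices

  def make_ops_and_vars(self, scope=None):
    return "ops placed round-robin on %s" % (self._cov_devices,)


class EstimatorRoundRobin(RoundRobinMixin, Estimator):
  """Concrete class; mixin first so its __init__ runs before the base's."""
  pass


def make_estimator(placement_strategy=None, **kwargs):
  if placement_strategy in (None, "round_robin"):
    return EstimatorRoundRobin(**kwargs)
  raise ValueError("Unimplemented placement strategy: %s" % placement_strategy)


est = make_estimator(damping=0.2, cov_devices=["/gpu:0", "/gpu:1"])
print(est.make_ops_and_vars())  # ops placed round-robin on ['/gpu:0', '/gpu:1']
```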
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
index 083da768ec..843aeef7d8 100644
--- a/tensorflow/contrib/kfac/python/ops/optimizer.py
+++ b/tensorflow/contrib/kfac/python/ops/optimizer.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import warnings
-
# pylint: disable=line-too-long
from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
from tensorflow.contrib.kfac.python.ops import estimator as est
@@ -53,8 +52,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
estimation_mode="gradients",
colocate_gradients_with_ops=True,
batch_size=None,
- cov_devices=None,
- inv_devices=None):
+ placement_strategy=None,
+ **kwargs):
"""Initializes the KFAC optimizer with the given settings.
Args:
@@ -96,14 +95,11 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
(Default: True)
batch_size: The size of the mini-batch. Only needed when momentum_type
== 'qmodel' or when automatic adjustment is used. (Default: None)
- cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified. Only used
- with (soon-to-be-depcrecated "convenience" properties).
- inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified. Only used
- with (soon-to-be-depcrecated "convenience" properties).
+ placement_strategy: string, Device placement strategy used when creating
+ covariance variables, covariance ops, and inverse ops.
+ (Default: `None`)
+      **kwargs: Arguments to be passed to the specific placement
+        strategy mixin. See `placement.RoundRobinPlacementMixin` for an example.
Raises:
ValueError: If the momentum type is unsupported.
@@ -123,8 +119,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
self._layers = layer_collection
self._estimation_mode = estimation_mode
self._colocate_gradients_with_ops = colocate_gradients_with_ops
- self._cov_devices = cov_devices
- self._inv_devices = inv_devices
# The parameters below are required only if damping needs to be adapted.
# These parameters can be set by calling
@@ -164,16 +158,19 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
self._momentum_type = momentum_type
self._norm_constraint = norm_constraint
self._batch_size = batch_size
+ self._placement_strategy = placement_strategy
with variable_scope.variable_scope(name):
- self._fisher_est = est.FisherEstimator(
- self._variables,
- self._cov_ema_decay,
- self.damping,
- self._layers,
+ self._fisher_est = est.make_fisher_estimator(
+ placement_strategy=placement_strategy,
+ variables=self._variables,
+ cov_ema_decay=self._cov_ema_decay,
+ damping=self.damping,
+ layer_collection=self._layers,
exps=(-1,),
estimation_mode=self._estimation_mode,
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
+ colocate_gradients_with_ops=self._colocate_gradients_with_ops,
+ **kwargs)
super(KfacOptimizer, self).__init__(learning_rate, name=name)
@@ -237,6 +234,21 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
"damping", initializer=self._damping_constant, trainable=False)
@property
+ def variables(self):
+ return self._variables
+
+ @property
+ def damping(self):
+ if self._damping:
+ return self._damping
+ else:
+ return self._damping_constant
+
+ @property
+ def damping_adaptation_interval(self):
+ return self._damping_adaptation_interval
+
+ @property
def cov_update_thunks(self):
self._maybe_make_and_save_everything()
return self._cov_update_thunks
@@ -266,37 +278,20 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
self._maybe_make_and_save_everything()
return self._inv_update_op
- @property
- def variables(self):
- return self._variables
-
- @property
- def damping(self):
- if self._damping:
- return self._damping
- else:
- return self._damping_constant
-
- @property
- def damping_adaptation_interval(self):
- return self._damping_adaptation_interval
-
def _maybe_make_and_save_everything(self):
if not self._fisher_est.made_vars():
warnings.warn("These convenience properties will be depcrecated soon. "
"Please use explicit op/thunk creation methods instead "
- "(e.g. make_ops_and_vars_round_robin, etc).",
+ "(e.g. make_ops_and_vars, etc).",
DeprecationWarning)
(self._cov_update_ops, self._cov_update_op, self._inv_update_ops,
self._inv_update_op, self._cov_update_thunks,
- self._inv_update_thunks) = self.make_ops_and_vars_round_robin(
- cov_devices=self._cov_devices,
- inv_devices=self._inv_devices)
+ self._inv_update_thunks) = self.make_ops_and_vars()
def make_ops_and_vars(self):
- """Make ops and vars with no specific device placement.
+ """Make ops and vars with device placement `self._placement_strategy`.
- See make_ops_and_vars_round_robin for details.
+ See `FisherEstimator.make_ops_and_vars` for details.
Returns:
cov_update_ops: List of ops that compute the cov updates. Corresponds
@@ -307,88 +302,21 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
cov_update_op: cov_update_ops grouped into a single op.
inv_update_op: inv_update_ops grouped into a single op.
"""
- with variable_scope.variable_scope(self.get_name()):
- return self._fisher_est.make_ops_and_vars()
-
- def make_ops_and_vars_round_robin(self, cov_devices=None, inv_devices=None):
- """Make ops and vars with a round-robin device placement strategy.
-
- For each factor, all of that factor's cov variables and their associated
- update ops will be placed on a particular device. A new device is chosen
- for each factor by cycling through list of devices in the cov_devices
- argument. If cov_devices is None then no explicit device placement occurs.
-
- An analogous strategy is followed for inverse update ops, with the list of
- devices being given by the inv_devices argument.
+ return self._fisher_est.make_ops_and_vars(scope=self.get_name())
- Inverse variables on the other hand are not placed on any specific device
- (they will just use the current the device placement context, whatever
- that happens to be). The idea is that the inverse variable belong where
- they will be accessed most often, which is the device that actually applies
- the preconditioner to the gradient. The user will be responsible for setting
- the device context for this.
-
- Args:
- cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
- inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
+ def make_vars_and_create_op_thunks(self):
+ """Make vars and create op thunks.
Returns:
- cov_update_ops: List of ops that compute the cov updates. Corresponds
- one-to-one with the list of factors given by the "factors" property.
- cov_update_op: cov_update_ops grouped into a single op.
- inv_update_ops: List of ops that compute the inv updates. Corresponds
- one-to-one with the list of factors given by the "factors" property.
- cov_update_op: cov_update_ops grouped into a single op.
- inv_update_op: inv_update_ops grouped into a single op.
- cov_update_thunks: Thunks that make the ops in cov_update_ops.
- inv_update_thunks: Thunks that make the ops in inv_update_ops.
- """
- with variable_scope.variable_scope(self.get_name()):
- return self._fisher_est.make_ops_and_vars_round_robin(
- cov_devices=cov_devices, inv_devices=inv_devices)
-
- def make_vars_and_create_op_thunks_round_robin(self,
- cov_devices=None,
- inv_devices=None):
- """Make vars and create op thunks w/ a round-robin device placement strat.
-
- For each factor, all of that factor's cov variables and their associated
- update ops will be placed on a particular device. A new device is chosen
- for each factor by cycling through list of devices in the cov_devices
- argument. If cov_devices is None then no explicit device placement occurs.
-
- An analogous strategy is followed for inverse update ops, with the list of
- devices being given by the inv_devices argument.
-
- Inverse variables on the other hand are not placed on any specific device
- (they will just use the current the device placement context, whatever
- that happens to be). The idea is that the inverse variable belong where
- they will be accessed most often, which is the device that actually applies
- the preconditioner to the gradient. The user will be responsible for setting
- the device context for this.
-
- Args:
- cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
- inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
- Returns:
cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
the list of factors given by the "factors" property.
inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
the list of factors given by the "factors" property.
"""
scope = self.get_name() + "/" + self._fisher_est.name
- return self._fisher_est.make_vars_and_create_op_thunks_round_robin(
- scope=scope, cov_devices=cov_devices, inv_devices=inv_devices)
+ return self._fisher_est.make_vars_and_create_op_thunks(scope=scope)
- def ops_and_vars_thunks(self):
+ def create_ops_and_vars_thunks(self):
"""Create thunks that make the ops and vars on demand.
This function returns 4 lists of thunks: cov_variable_thunks,
@@ -413,7 +341,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
inv_update_thunks: A list of thunks that make the inv update ops.
"""
scope = self.get_name() + "/" + self._fisher_est.name
- return self._fisher_est.ops_and_vars_thunks(scope=scope)
+ return self._fisher_est.create_ops_and_vars_thunks(scope=scope)
def minimize(self, *args, **kwargs):
# Should this variable scope encompass everything below? Or will the super-
@@ -462,7 +390,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
An `Operation` that applies the specified gradients.
"""
self._maybe_make_and_save_everything()
-
# In Python 3, grads_and_vars can be a zip() object which can only be
# iterated over once. By converting it to a list, we ensure that it can be
# iterated over more than once.
@@ -618,7 +545,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
# compute the matrix-vector products with the transposed Fisher factor
fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads)
fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates)
-
batch_size = math_ops.cast(
self._batch_size, dtype=fft_precon_grads[0].dtype)
@@ -802,7 +728,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
# Go through variable and update its associated part of the velocity vector.
return [_update_velocity(vec, var) for vec, var in vecs_and_vars]
- # TODO(b/73448937): Move all update damping code to a separate class/function.
def _update_damping(self, prev_batch, global_step):
"""Adapts damping parameter. Check KFAC (Section 6.5) for the details.
diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py
new file mode 100644
index 0000000000..bf12dbaa9a
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/placement.py
@@ -0,0 +1,167 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implements placement strategies for cov and inv ops, cov variables."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variable_scope
+
+
+def _make_thunk_on_device(func, device):
+ def thunk():
+ with tf_ops.device(device):
+ return func()
+ return thunk
+
+
+class RoundRobinPlacementMixin(object):
+ """Implements round robin placement strategy for ops and variables."""
+
+ def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs):
+ """Initializes the RoundRobinPlacementMixin class.
+
+ Args:
+ cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
+ computations will be placed on these devices in a round-robin fashion.
+ Can be None, which means that no devices are specified.
+ inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
+ computations will be placed on these devices in a round-robin fashion.
+ Can be None, which means that no devices are specified.
+      *args: Passed through to the superclass initializer.
+      **kwargs: Passed through to the superclass initializer.
+
+ """
+ super(RoundRobinPlacementMixin, self).__init__(*args, **kwargs)
+ self._cov_devices = cov_devices
+ self._inv_devices = inv_devices
+
+ def make_ops_and_vars(self, scope=None):
+ """Make ops and vars with a round-robin device placement strategy.
+
+ For each factor, all of that factor's cov variables and their associated
+ update ops will be placed on a particular device. A new device is chosen
+    for each factor by cycling through the list of devices in the
+ `self._cov_devices` attribute. If `self._cov_devices` is `None` then no
+ explicit device placement occurs.
+
+ An analogous strategy is followed for inverse update ops, with the list of
+ devices being given by the `self._inv_devices` attribute.
+
+    Inverse variables, on the other hand, are not placed on any specific device
+    (they will just use the current device placement context, whatever
+    that happens to be). The idea is that the inverse variables belong where
+    they will be accessed most often, which is the device that actually applies
+    the preconditioner to the gradient. The user is responsible for setting
+    the device context for this.
+
+ Args:
+ scope: A string or None. If None it will be set to the name of this
+ estimator (given by the name property). All variables will be created,
+ and all ops will execute, inside of a variable scope of the given
+ name. (Default: None)
+
+ Returns:
+ cov_update_ops: List of ops that compute the cov updates. Corresponds
+ one-to-one with the list of factors given by the "factors" property.
+ cov_update_op: cov_update_ops grouped into a single op.
+ inv_update_ops: List of ops that compute the inv updates. Corresponds
+ one-to-one with the list of factors given by the "factors" property.
+ inv_update_op: inv_update_ops grouped into a single op.
+ cov_update_thunks: Thunks that make the ops in cov_update_ops.
+ inv_update_thunks: Thunks that make the ops in inv_update_ops.
+ """
+ (cov_update_thunks,
+ inv_update_thunks) = self.make_vars_and_create_op_thunks(scope=scope)
+ cov_update_ops = [thunk() for thunk in cov_update_thunks]
+ inv_update_ops = [thunk() for thunk in inv_update_thunks]
+
+ scope = self.name if scope is None else scope
+ with variable_scope.variable_scope(scope):
+ cov_update_op = control_flow_ops.group(cov_update_ops,
+ name="cov_update_op")
+ inv_update_op = control_flow_ops.group(inv_update_ops,
+ name="inv_update_op")
+
+ return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op,
+ cov_update_thunks, inv_update_thunks)
+
+ def make_vars_and_create_op_thunks(self, scope=None):
+ """Make vars and create op thunks w/ a round-robin device placement strat.
+
+ For each factor, all of that factor's cov variables and their associated
+ update ops will be placed on a particular device. A new device is chosen
+    for each factor by cycling through the list of devices in the
+    `self._cov_devices` attribute. If `self._cov_devices` is `None` then no
+ explicit device placement occurs.
+
+ An analogous strategy is followed for inverse update ops, with the list of
+ devices being given by the `self._inv_devices` attribute.
+
+    Inverse variables, on the other hand, are not placed on any specific device
+    (they will just use the current device placement context, whatever
+    that happens to be). The idea is that the inverse variables belong where
+    they will be accessed most often, which is the device that actually applies
+    the preconditioner to the gradient. The user is responsible for setting
+    the device context for this.
+
+ Args:
+ scope: A string or None. If None it will be set to the name of this
+ estimator (given by the name property). All variables will be created,
+ and all thunks will execute, inside of a variable scope of the given
+ name. (Default: None)
+
+ Returns:
+ cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ """
+ # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`.
+ (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,
+ inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope)
+
+ if self._cov_devices:
+ cov_update_thunks = []
+ for cov_variable_thunk, cov_update_thunk, device in zip(
+ cov_variable_thunks_raw, cov_update_thunks_raw,
+ itertools.cycle(self._cov_devices)):
+ with tf_ops.device(device):
+ cov_variable_thunk()
+ cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk,
+ device))
+ else:
+ for cov_variable_thunk in cov_variable_thunks_raw:
+ cov_variable_thunk()
+ cov_update_thunks = cov_update_thunks_raw
+
+ for inv_variable_thunk in inv_variable_thunks_raw:
+ inv_variable_thunk()
+
+ if self._inv_devices:
+ inv_update_thunks = []
+ for inv_update_thunk, device in zip(inv_update_thunks_raw,
+ itertools.cycle(self._inv_devices)):
+ inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk,
+ device))
+ else:
+ inv_update_thunks = inv_update_thunks_raw
+
+ return cov_update_thunks, inv_update_thunks
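
The heart of the mixin is just `itertools.cycle` plus a device-scoped closure: each factor's variable thunk runs (and each update thunk is wrapped) under the next device in the cycle. A self-contained sketch of that mechanism, with plain strings standing in for TF device contexts:

```python
import itertools


def make_thunk_on_device(func, device):
  # Mirrors placement._make_thunk_on_device, minus TF: the closure
  # remembers which device it should execute under.
  def thunk():
    return (device, func())
  return thunk


update_thunks = [lambda i=i: "update factor %d" % i for i in range(5)]
devices = ["/gpu:0", "/gpu:1"]

placed = [make_thunk_on_device(thunk, device)
          for thunk, device in zip(update_thunks, itertools.cycle(devices))]

for thunk in placed:
  print(thunk())
# ('/gpu:0', 'update factor 0'), ('/gpu:1', 'update factor 1'),
# ('/gpu:0', 'update factor 2'), ('/gpu:1', 'update factor 3'), ...
```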
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
new file mode 100644
index 0000000000..d0102883e6
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
@@ -0,0 +1,197 @@
+/*Copyright 2018 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import android.graphics.Bitmap;
+import android.os.SystemClock;
+import android.util.Log;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+
+/**
+ * Class that benchmarks image classifier models.
+ *
+ * <p>===================== General workflow =======================
+ *
+ * <pre>{@code
+ * benchmarker = new OvicBenchmarker();
+ * benchmarker.getReadyToTest(labelInputStream, model);
+ * while (!benchmarker.shouldStop()) {
+ * Bitmap bitmap = ...
+ * benchmarker.doTestIteration(bitmap);
+ * }
+ * }</pre>
+ */
+public class OvicBenchmarker {
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "OvicBenchmarker";
+
+ /** Evaluation transformation parameters. */
+ private static final float CENTRAL_FRACTION = 0.875f;
+
+ /** Dimensions of inputs. */
+ private static final int DIM_BATCH_SIZE = 1;
+ private static final int DIM_PIXEL_SIZE = 3;
+ private int imgHeight = 224;
+ private int imgWidth = 224;
+
+ /* Preallocated buffers for storing image data in. */
+ private int[] intValues = null;
+
+  /** A ByteBuffer to hold image data, to be fed into the classifier as input. */
+ private ByteBuffer imgData = null;
+
+ private OvicClassifier classifier;
+
+ /** Total runtime in ms. */
+ private double totalRuntime = 0.0;
+ /** Total allowed runtime in ms. */
+ private double wallTime = 20000 * 30.0;
+
+ private Boolean benchmarkStarted = null;
+
+ /**
+   * Initializes an {@link OvicBenchmarker}.
+   *
+   * @param wallTime: a double specifying the total wall time (in ms) allowed for benchmarking.
+ */
+ public OvicBenchmarker(double wallTime) {
+ benchmarkStarted = false;
+ totalRuntime = 0.0;
+ this.wallTime = wallTime;
+ }
+
+ /** Check whether the benchmarker should stop. */
+ public Boolean shouldStop() {
+ if (totalRuntime >= wallTime) {
+ Log.e(
+ TAG,
+ "Total runtime "
+ + Double.toString(totalRuntime)
+ + " exceeded walltime "
+ + Double.toString(wallTime));
+ return true;
+ }
+ return false;
+ }
+
+ /** Check whether the benchmarker is ready to start classifying images. */
+ public Boolean readyToTest() {
+ return (classifier != null);
+ }
+
+ /**
+   * Gets the benchmarker ready for classifying images.
+ *
+ * @param labelInputStream: an {@link InputStream} specifying where the list of labels should be
+ * read from.
+ * @param model: a {@link MappedByteBuffer} model to benchmark.
+ */
+ public void getReadyToTest(InputStream labelInputStream, MappedByteBuffer model) {
+ try {
+ Log.i(TAG, "Creating classifier.");
+ classifier = new OvicClassifier(labelInputStream, model);
+      int[] inputDims = classifier.getInputDims();
+ imgHeight = inputDims[1];
+ imgWidth = inputDims[2];
+ // Only accept QUANTIZED_UINT8 input.
+ imgData = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * imgHeight * imgWidth * DIM_PIXEL_SIZE);
+ imgData.order(ByteOrder.nativeOrder());
+ intValues = new int[imgHeight * imgWidth];
+ } catch (Exception e) {
+ Log.e(TAG, e.getMessage());
+ Log.e(TAG, "Failed to initialize ImageNet classifier for the benchmarker.");
+ }
+ }
+
+ /** Return how many classes are predicted per image. */
+ public int getNumPredictions() {
+ return classifier.getNumPredictions();
+ }
+
+ /**
+ * Perform test on a single bitmap image.
+ *
+ * @param bitmap: a {@link Bitmap} image to classify.
+ */
+ public OvicSingleImageResult doTestIteration(Bitmap bitmap)
+ throws IOException, InterruptedException {
+ if (shouldStop() || !readyToTest()) {
+ return null;
+ }
+ OvicSingleImageResult iterResult = null;
+ try {
+ Log.i(TAG, "Converting bitmap.");
+ convertBitmapToInput(bitmap);
+ Log.i(TAG, "Classifying image.");
+ iterResult = classifier.classifyByteBuffer(imgData);
+ } catch (RuntimeException e) {
+ Log.e(TAG, e.getMessage());
+ Log.e(TAG, "Failed to classify image.");
+ }
+ if (iterResult == null || iterResult.latency == null) {
+ throw new RuntimeException("Classification result or timing is invalid.");
+ }
+ Log.d(TAG, "Native inference latency: " + iterResult.latency);
+ Log.i(TAG, iterResult.toString());
+
+    if (!benchmarkStarted) { // Skip the first image to discount warm-up time.
+ benchmarkStarted = true;
+ } else {
+ totalRuntime += (double) iterResult.latency;
+ }
+ return iterResult;
+ }
+
+ /**
+   * Writes image data into a {@link ByteBuffer}.
+ *
+ * @param bitmap: a {@link Bitmap} source image.
+ */
+ private void convertBitmapToInput(Bitmap bitmap) throws RuntimeException {
+ if (imgData == null) {
+ throw new RuntimeException("Benchmarker is not yet ready to test.");
+ }
+ imgData.rewind();
+ // Perform transformations corresponding to evaluation mode.
+ float width = (float) bitmap.getWidth();
+ float height = (float) bitmap.getHeight();
+ int stWidth = Math.round((width - width * CENTRAL_FRACTION) / 2);
+ int stHeight = Math.round((height - height * CENTRAL_FRACTION) / 2);
+ int newWidth = Math.round(width - stWidth * 2);
+ int newHeight = Math.round(height - stHeight * 2);
+ bitmap = Bitmap.createBitmap(bitmap, stWidth, stHeight, newWidth, newHeight);
+ bitmap = Bitmap.createScaledBitmap(bitmap, imgWidth, imgHeight, true);
+ bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+
+ // Convert the image to ByteBuffer.
+ int pixel = 0;
+ long startTime = SystemClock.uptimeMillis();
+
+ for (int i = 0; i < imgHeight; ++i) {
+ for (int j = 0; j < imgWidth; ++j) {
+ final int val = intValues[pixel++];
+ imgData.put((byte) ((val >> 16) & 0xFF));
+ imgData.put((byte) ((val >> 8) & 0xFF));
+ imgData.put((byte) (val & 0xFF));
+ }
+ }
+ long endTime = SystemClock.uptimeMillis();
+ Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime));
+ }
+}
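
`convertBitmapToInput` does two separable things: a central crop controlled by `CENTRAL_FRACTION`, and a per-pixel unpack of packed ARGB ints into R, G, B bytes via shifts and masks. The same arithmetic in a short Python sketch (NumPy is used only for illustration; the Java code works on Android `Bitmap` pixels):

```python
import numpy as np

CENTRAL_FRACTION = 0.875
height, width = 480, 640

# Central crop: trim (1 - fraction) / 2 of each dimension from each side.
st_h = int(round((height - height * CENTRAL_FRACTION) / 2))  # 30
st_w = int(round((width - width * CENTRAL_FRACTION) / 2))    # 40
new_h, new_w = height - 2 * st_h, width - 2 * st_w           # 420 x 560

# Unpack one packed ARGB pixel into R, G, B bytes via shifts and masks.
val = 0xFF336699  # A=0xFF, R=0x33, G=0x66, B=0x99
r, g, b = (val >> 16) & 0xFF, (val >> 8) & 0xFF, val & 0xFF
assert (r, g, b) == (0x33, 0x66, 0x99)

# Whole image at once: H x W packed ints -> H x W x 3 uint8 buffer.
pixels = np.random.randint(0, 2**31 - 1, size=(new_h, new_w), dtype=np.int64)
rgb = np.stack([(pixels >> 16) & 0xFF,
                (pixels >> 8) & 0xFF,
                pixels & 0xFF], axis=-1).astype(np.uint8)
```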
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
new file mode 100644
index 0000000000..b2dfd8f2e7
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
@@ -0,0 +1,209 @@
+/*Copyright 2018 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import org.tensorflow.lite.Interpreter;
+import org.tensorflow.lite.TestHelper;
+
+/** Benchmark ImageNet Classifier with Tensorflow Lite. */
+public class OvicClassifier {
+
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "OvicClassifier";
+
+ /** Number of results to show (i.e. the "K" in top-K predictions). */
+ private static final int RESULTS_TO_SHOW = 5;
+
+ /** An instance of the driver class to run model inference with Tensorflow Lite. */
+ private Interpreter tflite;
+
+ /** Labels corresponding to the output of the vision model. */
+ private List<String> labelList;
+
+  /** An array to hold inference results, to be fed into TensorFlow Lite as output. */
+ private byte[][] inferenceOutputArray = null;
+ /** An array to hold final prediction probabilities. */
+ private float[][] labelProbArray = null;
+
+  /** Input resolution. */
+ private int[] inputDims = null;
+ /** Whether the model runs as float or quantized. */
+ private Boolean outputIsFloat = null;
+
+ private PriorityQueue<Map.Entry<Integer, Float>> sortedLabels =
+ new PriorityQueue<>(
+ RESULTS_TO_SHOW,
+ new Comparator<Map.Entry<Integer, Float>>() {
+ @Override
+ public int compare(Map.Entry<Integer, Float> o1, Map.Entry<Integer, Float> o2) {
+ return (o1.getValue()).compareTo(o2.getValue());
+ }
+ });
+
+ /** Initializes an {@code OvicClassifier}. */
+ OvicClassifier(InputStream labelInputStream, MappedByteBuffer model)
+ throws IOException, RuntimeException {
+ if (model == null) {
+ throw new RuntimeException("Input model is empty.");
+ }
+ labelList = loadLabelList(labelInputStream);
+ // OVIC uses one thread for CPU inference.
+ tflite = new Interpreter(model, 1);
+ inputDims = TestHelper.getInputDims(tflite, 0);
+ if (inputDims.length != 4) {
+ throw new RuntimeException("The model's input dimensions must be 4 (BWHC).");
+ }
+ if (inputDims[0] != 1) {
+ throw new RuntimeException("The model must have a batch size of 1, got "
+ + inputDims[0] + " instead.");
+ }
+ if (inputDims[3] != 3) {
+ throw new RuntimeException("The model must have three color channels, got "
+ + inputDims[3] + " instead.");
+ }
+ int minSide = Math.min(inputDims[1], inputDims[2]);
+ int maxSide = Math.max(inputDims[1], inputDims[2]);
+ if (minSide <= 0 || maxSide > 1000) {
+ throw new RuntimeException("The model's resolution must be between (0, 1000].");
+ }
+ String outputDataType = TestHelper.getOutputDataType(tflite, 0);
+ if (outputDataType.equals("float")) {
+ outputIsFloat = true;
+ } else if (outputDataType.equals("byte")) {
+ outputIsFloat = false;
+ } else {
+ throw new RuntimeException("Cannot process output type: " + outputDataType);
+ }
+ inferenceOutputArray = new byte[1][labelList.size()];
+ labelProbArray = new float[1][labelList.size()];
+ }
+
+ /** Classifies a {@link ByteBuffer} image. */
+ // @throws RuntimeException if model is uninitialized.
+ OvicSingleImageResult classifyByteBuffer(ByteBuffer imgData) throws RuntimeException {
+ if (tflite == null) {
+ throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed.");
+ }
+ if (outputIsFloat == null) {
+ throw new RuntimeException(TAG + ": Classifier output type has not been resolved.");
+ }
+ if (outputIsFloat) {
+ tflite.run(imgData, labelProbArray);
+ } else {
+ tflite.run(imgData, inferenceOutputArray);
+      // Convert the quantized (uint8) results to floats in [0, 1].
+ for (int i = 0; i < inferenceOutputArray[0].length; i++) {
+ labelProbArray[0][i] = (inferenceOutputArray[0][i] & 0xff) / 255.0f;
+ }
+ }
+ OvicSingleImageResult iterResult = computeTopKLabels();
+ iterResult.latency = getLastNativeInferenceLatencyMilliseconds();
+ return iterResult;
+ }
+
+ /** Return the probability array of all classes. */
+ public float[][] getlabelProbArray() {
+ return labelProbArray;
+ }
+
+ /** Return the number of top labels predicted by the classifier. */
+ public int getNumPredictions() {
+ return RESULTS_TO_SHOW;
+ }
+
+ /** Return the four dimensions of the input image. */
+ public int[] getInputDims() {
+ return inputDims;
+ }
+
+  /**
+   * Gets the native inference latency of the last image classification run.
+ * @throws RuntimeException if model is uninitialized.
+ */
+ public Long getLastNativeInferenceLatencyMilliseconds() {
+ if (tflite == null) {
+ throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed.");
+ }
+ Long latency = tflite.getLastNativeInferenceDurationNanoseconds();
+ return (latency == null) ? null : (Long) (latency / 1000000);
+ }
+
+ /** Closes tflite to release resources. */
+ public void close() {
+ tflite.close();
+ tflite = null;
+ }
+
+ /** Reads the label list from the given input stream. */
+ private static List<String> loadLabelList(InputStream labelInputStream) throws IOException {
+ List<String> labelList = new ArrayList<String>();
+ try (BufferedReader reader =
+ new BufferedReader(new InputStreamReader(labelInputStream, StandardCharsets.UTF_8))) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ labelList.add(line);
+ }
+ }
+ return labelList;
+ }
+
+ /** Computes top-K labels. */
+ private OvicSingleImageResult computeTopKLabels() {
+ if (labelList == null) {
+ throw new RuntimeException("Label file has not been loaded.");
+ }
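+ // Maintain a min-heap of at most RESULTS_TO_SHOW entries: add each (index, probability)
+ // pair, then evict the lowest-probability entry whenever the heap overflows.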
+ for (int i = 0; i < labelList.size(); ++i) {
+ sortedLabels.add(new AbstractMap.SimpleEntry<>(i, labelProbArray[0][i]));
+ if (sortedLabels.size() > RESULTS_TO_SHOW) {
+ sortedLabels.poll();
+ }
+ }
+ OvicSingleImageResult singleImageResult = new OvicSingleImageResult();
+ if (sortedLabels.size() != RESULTS_TO_SHOW) {
+ throw new RuntimeException(
+ "Number of returned labels does not match requirement: "
+ + sortedLabels.size()
+ + " returned, but "
+ + RESULTS_TO_SHOW
+ + " required.");
+ }
+ for (int i = 0; i < RESULTS_TO_SHOW; ++i) {
+ Map.Entry<Integer, Float> label = sortedLabels.poll();
+ // ImageNet model prediction indices are 0-based.
+ singleImageResult.topKIndices.add(label.getKey());
+ singleImageResult.topKClasses.add(labelList.get(label.getKey()));
+ singleImageResult.topKProbs.add(label.getValue());
+ }
+ // The min-heap yields the lowest-probability labels first, so reverse the lists into
+ // descending order.
+ Collections.reverse(singleImageResult.topKIndices);
+ Collections.reverse(singleImageResult.topKClasses);
+ Collections.reverse(singleImageResult.topKProbs);
+ return singleImageResult;
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java
new file mode 100644
index 0000000000..4af9a65c2f
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java
@@ -0,0 +1,54 @@
+/*Copyright 2018 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import java.util.ArrayList;
+
+/** Result class for inference run on a single image. */
+public class OvicSingleImageResult {
+
+ /** Top K classes and probabilities. */
+ public ArrayList<String> topKClasses;
+ public ArrayList<Float> topKProbs;
+ public ArrayList<Integer> topKIndices;
+
+ /** Latency (ms). */
+ public Long latency;
+
+ OvicSingleImageResult() {
+ topKClasses = new ArrayList<>();
+ topKProbs = new ArrayList<>();
+ topKIndices = new ArrayList<>();
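+ // A latency of -1 indicates that no inference has been run yet.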
+ latency = -1L;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder textToShow = new StringBuilder(latency + "ms");
+ for (int k = 0; k < topKProbs.size(); ++k) {
+ textToShow
+ .append("\nPrediction [")
+ .append(k)
+ .append("] = Class ")
+ .append(topKIndices.get(k))
+ .append(" (")
+ .append(topKClasses.get(k))
+ .append(") : ")
+ .append(topKProbs.get(k));
+ }
+ return textToShow.toString();
+ }
+
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
new file mode 100644
index 0000000000..4fd23a99d2
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
@@ -0,0 +1,176 @@
+/*Copyright 2018 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.fail;
+
+import java.awt.image.BufferedImage;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Paths;
+import javax.imageio.ImageIO;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link org.tensorflow.ovic.OvicClassifier}. */
+@RunWith(JUnit4.class)
+public final class OvicClassifierTest {
+
+ private OvicClassifier classifier;
+ private InputStream labelsInputStream = null;
+ private MappedByteBuffer quantizedModel = null;
+ private MappedByteBuffer floatModel = null;
+ private MappedByteBuffer lowResModel = null;
+ private ByteBuffer testImage = null;
+ private ByteBuffer lowResTestImage = null;
+ private OvicSingleImageResult testResult = null;
+ private static final String LABELS_PATH = "testdata/labels.txt";
+ private static final String QUANTIZED_MODEL_PATH = "testdata/quantized_model.lite";
+ private static final String LOW_RES_MODEL_PATH = "testdata/low_res_model.lite";
+ private static final String FLOAT_MODEL_PATH = "testdata/float_model.lite";
+ private static final String TEST_IMAGE_PATH = "testdata/test_image_224.jpg";
+ private static final String TEST_LOW_RES_IMAGE_PATH = "testdata/test_image_128.jpg";
+ private static final int TEST_IMAGE_GROUNDTRUTH = 653; // "military uniform"
+
+ @Before
+ public void setUp() {
+ try {
+ File labelsFile = new File(getTestDir(LABELS_PATH));
+ labelsInputStream = new FileInputStream(labelsFile);
+ quantizedModel = loadModelFile(getTestDir(QUANTIZED_MODEL_PATH));
+ floatModel = loadModelFile(getTestDir(FLOAT_MODEL_PATH));
+ lowResModel = loadModelFile(getTestDir(LOW_RES_MODEL_PATH));
+ File imageFile = new File(getTestDir(TEST_IMAGE_PATH));
+ BufferedImage img = ImageIO.read(imageFile);
+ testImage = toByteBuffer(img);
+ // Load the low resolution image.
+ imageFile = new File(getTestDir(TEST_LOW_RES_IMAGE_PATH));
+ img = ImageIO.read(imageFile);
+ lowResTestImage = toByteBuffer(img);
+ System.out.println("Successful setup");
+ } catch (IOException e) {
+ System.out.println("Test setup failed: " + e.getMessage());
+ }
+ }
+
+ private static String getTestDir(String testfile) throws IOException {
+ return Paths.get("third_party/tensorflow/contrib/lite/java/ovic/src/", testfile).toString();
+ }
+
+ @Test
+ public void ovicClassifier_quantizedModelCreateSuccess() throws Exception {
+ classifier = new OvicClassifier(labelsInputStream, quantizedModel);
+ assertThat(classifier).isNotNull();
+ }
+
+ @Test
+ public void ovicClassifier_floatModelCreateSuccess() throws Exception {
+ classifier = new OvicClassifier(labelsInputStream, floatModel);
+ assertThat(classifier).isNotNull();
+ }
+
+ @Test
+ public void ovicClassifier_quantizedModelClassifySuccess() throws Exception {
+ classifier = new OvicClassifier(labelsInputStream, quantizedModel);
+ testResult = classifier.classifyByteBuffer(testImage);
+ assertCorrectTopK(testResult);
+ }
+
+ @Test
+ public void ovicClassifier_floatModelClassifySuccess() throws Exception {
+ classifier = new OvicClassifier(labelsInputStream, floatModel);
+ testResult = classifier.classifyByteBuffer(testImage);
+ assertCorrectTopK(testResult);
+ }
+
+ @Test
+ public void ovicClassifier_lowResModelClassifySuccess() throws Exception {
+ classifier = new OvicClassifier(labelsInputStream, lowResModel);
+ testResult = classifier.classifyByteBuffer(lowResTestImage);
+ assertCorrectTopK(testResult);
+ }
+
+ @Test
+ public void ovicClassifier_latencyNotNull() throws Exception {
+ classifier = new OvicClassifier(labelsInputStream, floatModel);
+ testResult = classifier.classifyByteBuffer(testImage);
+ assertThat(testResult.latency).isNotNull();
+ }
+
+ @Test
+ public void ovicClassifier_mismatchedInputResolutionFails() throws Exception {
+ classifier = new OvicClassifier(labelsInputStream, lowResModel);
+ int[] inputDims = classifier.getInputDims();
+ assertThat(inputDims[1]).isEqualTo(128);
+ assertThat(inputDims[2]).isEqualTo(128);
+ try {
+ testResult = classifier.classifyByteBuffer(testImage);
+ fail();
+ } catch (RuntimeException e) {
+ assertThat(e)
+ .hasMessageThat()
+ .contains(
+ "Failed to get input dimensions. 0-th input should have 49152 bytes, "
+ + "but found 150528 bytes.");
+ }
+ }
+
+ private static ByteBuffer toByteBuffer(BufferedImage image) {
+ ByteBuffer imgData = ByteBuffer.allocateDirect(
+ image.getHeight() * image.getWidth() * 3);
+ imgData.order(ByteOrder.nativeOrder());
+ for (int y = 0; y < image.getHeight(); y++) {
+ for (int x = 0; x < image.getWidth(); x++) {
+ int val = image.getRGB(x, y);
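+ // Unpack the ARGB pixel into R, G, B bytes; the alpha channel is dropped.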
+ imgData.put((byte) ((val >> 16) & 0xFF));
+ imgData.put((byte) ((val >> 8) & 0xFF));
+ imgData.put((byte) (val & 0xFF));
+ }
+ }
+ return imgData;
+ }
+
+ private static void assertCorrectTopK(OvicSingleImageResult testResult) {
+ assertThat(testResult.topKClasses).isNotEmpty();
+ boolean topKAccurate = false;
+ // Assert that the correct class is in the top K.
+ for (int i = 0; i < testResult.topKIndices.size(); i++) {
+ if (testResult.topKIndices.get(i) == TEST_IMAGE_GROUNDTRUTH) {
+ topKAccurate = true;
+ break;
+ }
+ }
+ System.out.println(testResult.toString());
+ System.out.flush();
+ assertThat(topKAccurate).isTrue();
+ }
+
+ private static MappedByteBuffer loadModelFile(String modelFilePath) throws IOException {
+ File modelFile = new File(modelFilePath);
+ // Use try-with-resources so the stream is closed once the file has been mapped.
+ try (FileInputStream inputStream = new FileInputStream(modelFile)) {
+ FileChannel fileChannel = inputStream.getChannel();
+ return fileChannel.map(FileChannel.MapMode.READ_ONLY, 0L, fileChannel.size());
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt b/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt
new file mode 100644
index 0000000000..fe811239d8
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt
@@ -0,0 +1,1001 @@
+background
+tench
+goldfish
+great white shark
+tiger shark
+hammerhead
+electric ray
+stingray
+cock
+hen
+ostrich
+brambling
+goldfinch
+house finch
+junco
+indigo bunting
+robin
+bulbul
+jay
+magpie
+chickadee
+water ouzel
+kite
+bald eagle
+vulture
+great grey owl
+European fire salamander
+common newt
+eft
+spotted salamander
+axolotl
+bullfrog
+tree frog
+tailed frog
+loggerhead
+leatherback turtle
+mud turtle
+terrapin
+box turtle
+banded gecko
+common iguana
+American chameleon
+whiptail
+agama
+frilled lizard
+alligator lizard
+Gila monster
+green lizard
+African chameleon
+Komodo dragon
+African crocodile
+American alligator
+triceratops
+thunder snake
+ringneck snake
+hognose snake
+green snake
+king snake
+garter snake
+water snake
+vine snake
+night snake
+boa constrictor
+rock python
+Indian cobra
+green mamba
+sea snake
+horned viper
+diamondback
+sidewinder
+trilobite
+harvestman
+scorpion
+black and gold garden spider
+barn spider
+garden spider
+black widow
+tarantula
+wolf spider
+tick
+centipede
+black grouse
+ptarmigan
+ruffed grouse
+prairie chicken
+peacock
+quail
+partridge
+African grey
+macaw
+sulphur-crested cockatoo
+lorikeet
+coucal
+bee eater
+hornbill
+hummingbird
+jacamar
+toucan
+drake
+red-breasted merganser
+goose
+black swan
+tusker
+echidna
+platypus
+wallaby
+koala
+wombat
+jellyfish
+sea anemone
+brain coral
+flatworm
+nematode
+conch
+snail
+slug
+sea slug
+chiton
+chambered nautilus
+Dungeness crab
+rock crab
+fiddler crab
+king crab
+American lobster
+spiny lobster
+crayfish
+hermit crab
+isopod
+white stork
+black stork
+spoonbill
+flamingo
+little blue heron
+American egret
+bittern
+crane
+limpkin
+European gallinule
+American coot
+bustard
+ruddy turnstone
+red-backed sandpiper
+redshank
+dowitcher
+oystercatcher
+pelican
+king penguin
+albatross
+grey whale
+killer whale
+dugong
+sea lion
+Chihuahua
+Japanese spaniel
+Maltese dog
+Pekinese
+Shih-Tzu
+Blenheim spaniel
+papillon
+toy terrier
+Rhodesian ridgeback
+Afghan hound
+basset
+beagle
+bloodhound
+bluetick
+black-and-tan coonhound
+Walker hound
+English foxhound
+redbone
+borzoi
+Irish wolfhound
+Italian greyhound
+whippet
+Ibizan hound
+Norwegian elkhound
+otterhound
+Saluki
+Scottish deerhound
+Weimaraner
+Staffordshire bullterrier
+American Staffordshire terrier
+Bedlington terrier
+Border terrier
+Kerry blue terrier
+Irish terrier
+Norfolk terrier
+Norwich terrier
+Yorkshire terrier
+wire-haired fox terrier
+Lakeland terrier
+Sealyham terrier
+Airedale
+cairn
+Australian terrier
+Dandie Dinmont
+Boston bull
+miniature schnauzer
+giant schnauzer
+standard schnauzer
+Scotch terrier
+Tibetan terrier
+silky terrier
+soft-coated wheaten terrier
+West Highland white terrier
+Lhasa
+flat-coated retriever
+curly-coated retriever
+golden retriever
+Labrador retriever
+Chesapeake Bay retriever
+German short-haired pointer
+vizsla
+English setter
+Irish setter
+Gordon setter
+Brittany spaniel
+clumber
+English springer
+Welsh springer spaniel
+cocker spaniel
+Sussex spaniel
+Irish water spaniel
+kuvasz
+schipperke
+groenendael
+malinois
+briard
+kelpie
+komondor
+Old English sheepdog
+Shetland sheepdog
+collie
+Border collie
+Bouvier des Flandres
+Rottweiler
+German shepherd
+Doberman
+miniature pinscher
+Greater Swiss Mountain dog
+Bernese mountain dog
+Appenzeller
+EntleBucher
+boxer
+bull mastiff
+Tibetan mastiff
+French bulldog
+Great Dane
+Saint Bernard
+Eskimo dog
+malamute
+Siberian husky
+dalmatian
+affenpinscher
+basenji
+pug
+Leonberg
+Newfoundland
+Great Pyrenees
+Samoyed
+Pomeranian
+chow
+keeshond
+Brabancon griffon
+Pembroke
+Cardigan
+toy poodle
+miniature poodle
+standard poodle
+Mexican hairless
+timber wolf
+white wolf
+red wolf
+coyote
+dingo
+dhole
+African hunting dog
+hyena
+red fox
+kit fox
+Arctic fox
+grey fox
+tabby
+tiger cat
+Persian cat
+Siamese cat
+Egyptian cat
+cougar
+lynx
+leopard
+snow leopard
+jaguar
+lion
+tiger
+cheetah
+brown bear
+American black bear
+ice bear
+sloth bear
+mongoose
+meerkat
+tiger beetle
+ladybug
+ground beetle
+long-horned beetle
+leaf beetle
+dung beetle
+rhinoceros beetle
+weevil
+fly
+bee
+ant
+grasshopper
+cricket
+walking stick
+cockroach
+mantis
+cicada
+leafhopper
+lacewing
+dragonfly
+damselfly
+admiral
+ringlet
+monarch
+cabbage butterfly
+sulphur butterfly
+lycaenid
+starfish
+sea urchin
+sea cucumber
+wood rabbit
+hare
+Angora
+hamster
+porcupine
+fox squirrel
+marmot
+beaver
+guinea pig
+sorrel
+zebra
+hog
+wild boar
+warthog
+hippopotamus
+ox
+water buffalo
+bison
+ram
+bighorn
+ibex
+hartebeest
+impala
+gazelle
+Arabian camel
+llama
+weasel
+mink
+polecat
+black-footed ferret
+otter
+skunk
+badger
+armadillo
+three-toed sloth
+orangutan
+gorilla
+chimpanzee
+gibbon
+siamang
+guenon
+patas
+baboon
+macaque
+langur
+colobus
+proboscis monkey
+marmoset
+capuchin
+howler monkey
+titi
+spider monkey
+squirrel monkey
+Madagascar cat
+indri
+Indian elephant
+African elephant
+lesser panda
+giant panda
+barracouta
+eel
+coho
+rock beauty
+anemone fish
+sturgeon
+gar
+lionfish
+puffer
+abacus
+abaya
+academic gown
+accordion
+acoustic guitar
+aircraft carrier
+airliner
+airship
+altar
+ambulance
+amphibian
+analog clock
+apiary
+apron
+ashcan
+assault rifle
+backpack
+bakery
+balance beam
+balloon
+ballpoint
+Band Aid
+banjo
+bannister
+barbell
+barber chair
+barbershop
+barn
+barometer
+barrel
+barrow
+baseball
+basketball
+bassinet
+bassoon
+bathing cap
+bath towel
+bathtub
+beach wagon
+beacon
+beaker
+bearskin
+beer bottle
+beer glass
+bell cote
+bib
+bicycle-built-for-two
+bikini
+binder
+binoculars
+birdhouse
+boathouse
+bobsled
+bolo tie
+bonnet
+bookcase
+bookshop
+bottlecap
+bow
+bow tie
+brass
+brassiere
+breakwater
+breastplate
+broom
+bucket
+buckle
+bulletproof vest
+bullet train
+butcher shop
+cab
+caldron
+candle
+cannon
+canoe
+can opener
+cardigan
+car mirror
+carousel
+carpenter's kit
+carton
+car wheel
+cash machine
+cassette
+cassette player
+castle
+catamaran
+CD player
+cello
+cellular telephone
+chain
+chainlink fence
+chain mail
+chain saw
+chest
+chiffonier
+chime
+china cabinet
+Christmas stocking
+church
+cinema
+cleaver
+cliff dwelling
+cloak
+clog
+cocktail shaker
+coffee mug
+coffeepot
+coil
+combination lock
+computer keyboard
+confectionery
+container ship
+convertible
+corkscrew
+cornet
+cowboy boot
+cowboy hat
+cradle
+crane
+crash helmet
+crate
+crib
+Crock Pot
+croquet ball
+crutch
+cuirass
+dam
+desk
+desktop computer
+dial telephone
+diaper
+digital clock
+digital watch
+dining table
+dishrag
+dishwasher
+disk brake
+dock
+dogsled
+dome
+doormat
+drilling platform
+drum
+drumstick
+dumbbell
+Dutch oven
+electric fan
+electric guitar
+electric locomotive
+entertainment center
+envelope
+espresso maker
+face powder
+feather boa
+file
+fireboat
+fire engine
+fire screen
+flagpole
+flute
+folding chair
+football helmet
+forklift
+fountain
+fountain pen
+four-poster
+freight car
+French horn
+frying pan
+fur coat
+garbage truck
+gasmask
+gas pump
+goblet
+go-kart
+golf ball
+golfcart
+gondola
+gong
+gown
+grand piano
+greenhouse
+grille
+grocery store
+guillotine
+hair slide
+hair spray
+half track
+hammer
+hamper
+hand blower
+hand-held computer
+handkerchief
+hard disc
+harmonica
+harp
+harvester
+hatchet
+holster
+home theater
+honeycomb
+hook
+hoopskirt
+horizontal bar
+horse cart
+hourglass
+iPod
+iron
+jack-o'-lantern
+jean
+jeep
+jersey
+jigsaw puzzle
+jinrikisha
+joystick
+kimono
+knee pad
+knot
+lab coat
+ladle
+lampshade
+laptop
+lawn mower
+lens cap
+letter opener
+library
+lifeboat
+lighter
+limousine
+liner
+lipstick
+Loafer
+lotion
+loudspeaker
+loupe
+lumbermill
+magnetic compass
+mailbag
+mailbox
+maillot
+maillot
+manhole cover
+maraca
+marimba
+mask
+matchstick
+maypole
+maze
+measuring cup
+medicine chest
+megalith
+microphone
+microwave
+military uniform
+milk can
+minibus
+miniskirt
+minivan
+missile
+mitten
+mixing bowl
+mobile home
+Model T
+modem
+monastery
+monitor
+moped
+mortar
+mortarboard
+mosque
+mosquito net
+motor scooter
+mountain bike
+mountain tent
+mouse
+mousetrap
+moving van
+muzzle
+nail
+neck brace
+necklace
+nipple
+notebook
+obelisk
+oboe
+ocarina
+odometer
+oil filter
+organ
+oscilloscope
+overskirt
+oxcart
+oxygen mask
+packet
+paddle
+paddlewheel
+padlock
+paintbrush
+pajama
+palace
+panpipe
+paper towel
+parachute
+parallel bars
+park bench
+parking meter
+passenger car
+patio
+pay-phone
+pedestal
+pencil box
+pencil sharpener
+perfume
+Petri dish
+photocopier
+pick
+pickelhaube
+picket fence
+pickup
+pier
+piggy bank
+pill bottle
+pillow
+ping-pong ball
+pinwheel
+pirate
+pitcher
+plane
+planetarium
+plastic bag
+plate rack
+plow
+plunger
+Polaroid camera
+pole
+police van
+poncho
+pool table
+pop bottle
+pot
+potter's wheel
+power drill
+prayer rug
+printer
+prison
+projectile
+projector
+puck
+punching bag
+purse
+quill
+quilt
+racer
+racket
+radiator
+radio
+radio telescope
+rain barrel
+recreational vehicle
+reel
+reflex camera
+refrigerator
+remote control
+restaurant
+revolver
+rifle
+rocking chair
+rotisserie
+rubber eraser
+rugby ball
+rule
+running shoe
+safe
+safety pin
+saltshaker
+sandal
+sarong
+sax
+scabbard
+scale
+school bus
+schooner
+scoreboard
+screen
+screw
+screwdriver
+seat belt
+sewing machine
+shield
+shoe shop
+shoji
+shopping basket
+shopping cart
+shovel
+shower cap
+shower curtain
+ski
+ski mask
+sleeping bag
+slide rule
+sliding door
+slot
+snorkel
+snowmobile
+snowplow
+soap dispenser
+soccer ball
+sock
+solar dish
+sombrero
+soup bowl
+space bar
+space heater
+space shuttle
+spatula
+speedboat
+spider web
+spindle
+sports car
+spotlight
+stage
+steam locomotive
+steel arch bridge
+steel drum
+stethoscope
+stole
+stone wall
+stopwatch
+stove
+strainer
+streetcar
+stretcher
+studio couch
+stupa
+submarine
+suit
+sundial
+sunglass
+sunglasses
+sunscreen
+suspension bridge
+swab
+sweatshirt
+swimming trunks
+swing
+switch
+syringe
+table lamp
+tank
+tape player
+teapot
+teddy
+television
+tennis ball
+thatch
+theater curtain
+thimble
+thresher
+throne
+tile roof
+toaster
+tobacco shop
+toilet seat
+torch
+totem pole
+tow truck
+toyshop
+tractor
+trailer truck
+tray
+trench coat
+tricycle
+trimaran
+tripod
+triumphal arch
+trolleybus
+trombone
+tub
+turnstile
+typewriter keyboard
+umbrella
+unicycle
+upright
+vacuum
+vase
+vault
+velvet
+vending machine
+vestment
+viaduct
+violin
+volleyball
+waffle iron
+wall clock
+wallet
+wardrobe
+warplane
+washbasin
+washer
+water bottle
+water jug
+water tower
+whiskey jug
+whistle
+wig
+window screen
+window shade
+Windsor tie
+wine bottle
+wing
+wok
+wooden spoon
+wool
+worm fence
+wreck
+yawl
+yurt
+web site
+comic book
+crossword puzzle
+street sign
+traffic light
+book jacket
+menu
+plate
+guacamole
+consomme
+hot pot
+trifle
+ice cream
+ice lolly
+French loaf
+bagel
+pretzel
+cheeseburger
+hotdog
+mashed potato
+head cabbage
+broccoli
+cauliflower
+zucchini
+spaghetti squash
+acorn squash
+butternut squash
+cucumber
+artichoke
+bell pepper
+cardoon
+mushroom
+Granny Smith
+strawberry
+orange
+lemon
+fig
+pineapple
+banana
+jackfruit
+custard apple
+pomegranate
+hay
+carbonara
+chocolate sauce
+dough
+meat loaf
+pizza
+potpie
+burrito
+red wine
+espresso
+cup
+eggnog
+alp
+bubble
+cliff
+coral reef
+geyser
+lakeside
+promontory
+sandbar
+seashore
+valley
+volcano
+ballplayer
+groom
+scuba diver
+rapeseed
+daisy
+yellow lady's slipper
+corn
+acorn
+hip
+buckeye
+coral fungus
+agaric
+gyromitra
+stinkhorn
+earthstar
+hen-of-the-woods
+bolete
+ear
+toilet tissue
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index deab5a91d2..68bce19aa3 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -754,7 +754,7 @@ def make_mean_tests(zip_path):
[-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3]
],
"const_axis": [True, False],
- "keep_dims": [True, False],
+ "keepdims": [True, False],
}, {
"input_dtype": [tf.float32, tf.int32, tf.int64],
"input_shape": [[1, 224, 224, 3]],
@@ -765,7 +765,7 @@ def make_mean_tests(zip_path):
[2, 2, 3], [-3, -3, -4], [-3, 2, 1]
],
"const_axis": [True, False],
- "keep_dims": [True, False],
+ "keepdims": [True, False],
}]
def build_graph(parameters):
@@ -788,7 +788,7 @@ def make_mean_tests(zip_path):
input_tensors = [input_tensor, axis]
out = tf.reduce_mean(
- input_tensor, axis=axis, keep_dims=parameters["keep_dims"])
+ input_tensor, axis=axis, keepdims=parameters["keepdims"])
return input_tensors, [out]
def build_inputs(parameters, sess, inputs, outputs):
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 486ff1edcd..102740ee47 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -124,6 +124,7 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -168,6 +169,41 @@ cc_library(
)
cc_library(
+ name = "toco_saved_model",
+ srcs = [
+ "toco_saved_model.cc",
+ ],
+ hdrs = [
+ "toco_saved_model.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model_cmdline_flags",
+ ":model_flags_proto_cc",
+ ":toco_flags_proto_cc",
+ ":types_proto_cc",
+ "//tensorflow/cc/tools:freeze_saved_model",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+tf_cc_test(
+ name = "toco_saved_model_test",
+ srcs = ["toco_saved_model_test.cc"],
+ deps = [
+ ":model_cmdline_flags",
+ ":toco_cmdline_flags",
+ ":toco_saved_model",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
name = "graph_transformations",
srcs = [
"graph_transformations/convert_expanddims_to_reshape.cc",
@@ -363,6 +399,7 @@ tf_cc_binary(
":toco_cmdline_flags",
":toco_flags_proto_cc",
":toco_port",
+ ":toco_saved_model",
":toco_tooling",
":types_proto_cc",
"//tensorflow/core:lib",
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 59a6115920..7b71792ff7 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -190,6 +190,7 @@ struct ParsedModelFlags {
Arg<string> output_array;
Arg<string> output_arrays;
Arg<string> input_shapes;
+ Arg<int> batch_size = Arg<int>(1);
Arg<float> mean_value = Arg<float>(0.f);
Arg<string> mean_values;
Arg<float> std_value = Arg<float>(1.f);
@@ -215,9 +216,11 @@ struct ParsedModelFlags {
// you want). See toco_cmdline_flags.cc for details.
struct ParsedTocoFlags {
Arg<string> input_file;
+ Arg<string> savedmodel_directory;
Arg<string> output_file;
- Arg<string> input_format;
- Arg<string> output_format;
+ Arg<string> input_format = Arg<string>("TENSORFLOW_GRAPHDEF");
+ Arg<string> output_format = Arg<string>("TFLITE");
+ Arg<string> savedmodel_tagset;
// TODO(aselle): command_line_flags doesn't support doubles
Arg<float> default_ranges_min = Arg<float>(0.);
Arg<float> default_ranges_max = Arg<float>(0.);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index a7a50e6fc9..b844e0b948 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1541,7 +1541,9 @@ void ConvertMeanOperator(const NodeDef& node,
op->inputs.push_back(node.input(1));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
- if (HasAttr(node, "keep_dims")) {
+ if (HasAttr(node, "keepdims")) {
+ op->keep_dims = GetBoolAttr(node, "keepdims");
+ } else if (HasAttr(node, "keep_dims")) {
op->keep_dims = GetBoolAttr(node, "keep_dims");
}
}
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index 4e2dec15a5..4264f21c76 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -72,6 +72,12 @@ bool ParseModelFlagsFromCommandLineFlags(
"Shapes corresponding to --input_arrays, colon-separated. For "
"many models each shape takes the form batch size, input array "
"height, input array width, input array depth."),
+ Flag("batch_size", parsed_flags.batch_size.bind(),
+ parsed_flags.batch_size.default_value(),
+ "Batch size for the model. Replaces the first dimension of an "
+ "input size array if undefined. Use only with SavedModels when "
+ "--input_shapes flag is not specified. Always use --input_shapes "
+ "flag with frozen graphs."),
Flag("input_data_type", parsed_flags.input_data_type.bind(),
parsed_flags.input_data_type.default_value(),
"Deprecated: use --input_data_types instead. Input array type, if "
diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc
index f01ec0ec61..8041aa9e7f 100644
--- a/tensorflow/contrib/lite/toco/toco.cc
+++ b/tensorflow/contrib/lite/toco/toco.cc
@@ -23,40 +23,70 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/toco_saved_model.h"
#include "tensorflow/contrib/lite/toco/toco_tooling.h"
#include "tensorflow/contrib/lite/toco/toco_types.h"
#include "tensorflow/core/platform/logging.h"
-#ifndef CHECK_OK
-#define CHECK_OK(val) CHECK_EQ((val).ok(), true)
-#define QCHECK_OK(val) QCHECK_EQ((val).ok(), true)
-#endif
-
namespace toco {
namespace {
-#define QCHECK_REQUIRE_TOCO_FLAG(arg) \
- QCHECK(parsed_toco_flags.arg.specified()) << "Missing required flag: " #arg;
-
-void CheckFilePermissions(const ParsedTocoFlags& parsed_toco_flags,
- const ParsedModelFlags& parsed_model_flags,
- const TocoFlags& toco_flags) {
- port::CheckInitGoogleIsDone("InitGoogle is not done yet");
-
- QCHECK_REQUIRE_TOCO_FLAG(input_file)
- QCHECK_OK(port::file::Exists(parsed_toco_flags.input_file.value(),
- port::file::Defaults()))
- << "Specified input_file does not exist: "
- << parsed_toco_flags.input_file.value();
- QCHECK_OK(port::file::Readable(parsed_toco_flags.input_file.value(),
- port::file::Defaults()))
+// Checks the permissions of the output file to ensure it is writable.
+void CheckOutputFilePermissions(const Arg<string>& output_file) {
+ QCHECK(output_file.specified()) << "Missing required flag --output_file.\n";
+ QCHECK(port::file::Writable(output_file.value()).ok())
+ << "Specified output_file is not writable: " << output_file.value()
+ << ".\n";
+}
+
+// Checks the permissions of the frozen model file.
+void CheckFrozenModelPermissions(const Arg<string>& input_file) {
+ QCHECK(input_file.specified()) << "Missing required flag --input_file.\n";
+ QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok())
+ << "Specified input_file does not exist: " << input_file.value() << ".\n";
+ QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok())
<< "Specified input_file exists, but is not readable: "
- << parsed_toco_flags.input_file.value();
+ << input_file.value() << ".\n";
+}
- QCHECK_REQUIRE_TOCO_FLAG(output_file);
- QCHECK_OK(port::file::Writable(parsed_toco_flags.output_file.value()))
- << "parsed_toco_flags.input_file.value() output_file is not writable: "
- << parsed_toco_flags.output_file.value();
+// Checks the permissions of the SavedModel directory.
+void CheckSavedModelPermissions(const Arg<string>& savedmodel_directory) {
+ QCHECK(savedmodel_directory.specified())
+ << "Missing required flag --savedmodel_directory.\n";
+ QCHECK(
+ port::file::Exists(savedmodel_directory.value(), port::file::Defaults())
+ .ok())
+ << "Specified savedmodel_directory does not exist: "
+ << savedmodel_directory.value() << ".\n";
+}
+
+// Reads the contents of the GraphDef from either the frozen graph file or the
+// SavedModel directory. If it reads the SavedModel directory, it updates the
+// ModelFlags and TocoFlags accordingly.
+void ReadInputData(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ TocoFlags* toco_flags, ModelFlags* model_flags,
+ string* graph_def_contents) {
+ port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n");
+
+ bool has_input_file = parsed_toco_flags.input_file.specified();
+ bool has_savedmodel_dir = parsed_toco_flags.savedmodel_directory.specified();
+
+ // Ensure either input_file or savedmodel_directory flag has been set.
+ QCHECK_NE(has_input_file, has_savedmodel_dir)
+ << "Specify either input_file or savedmodel_directory flag.\n";
+
+ // Checks the input file permissions and reads the contents.
+ if (has_input_file) {
+ CheckFrozenModelPermissions(parsed_toco_flags.input_file);
+ CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(),
+ graph_def_contents, port::file::Defaults())
+ .ok());
+ } else {
+ CheckSavedModelPermissions(parsed_toco_flags.savedmodel_directory);
+ GetSavedModelContents(parsed_toco_flags, parsed_model_flags, toco_flags,
+ model_flags, graph_def_contents);
+ }
}
void ToolMain(const ParsedTocoFlags& parsed_toco_flags,
@@ -67,21 +97,20 @@ void ToolMain(const ParsedTocoFlags& parsed_toco_flags,
TocoFlags toco_flags;
ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags);
- CheckFilePermissions(parsed_toco_flags, parsed_model_flags, toco_flags);
+ string graph_def_contents;
+ ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags,
+ &model_flags, &graph_def_contents);
+ CheckOutputFilePermissions(parsed_toco_flags.output_file);
- string input_file_contents;
- CHECK_OK(port::file::GetContents(parsed_toco_flags.input_file.value(),
- &input_file_contents,
- port::file::Defaults()));
std::unique_ptr<Model> model =
- Import(toco_flags, model_flags, input_file_contents);
+ Import(toco_flags, model_flags, graph_def_contents);
Transform(toco_flags, model.get());
string output_file_contents;
Export(toco_flags, *model, toco_flags.allow_custom_ops(),
&output_file_contents);
- CHECK_OK(port::file::SetContents(parsed_toco_flags.output_file.value(),
- output_file_contents,
- port::file::Defaults()));
+ CHECK(port::file::SetContents(parsed_toco_flags.output_file.value(),
+ output_file_contents, port::file::Defaults())
+ .ok());
}
} // namespace
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index 0f67c2de72..cc7803dd86 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
+#include "absl/types/optional.h"
#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/core/platform/logging.h"
@@ -38,6 +39,9 @@ bool ParseTocoFlagsFromCommandLineFlags(
"Input file (model of any supported format). For Protobuf "
"formats, both text and binary are supported regardless of file "
"extension."),
+ Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(),
+ parsed_flags.savedmodel_directory.default_value(),
+ "Full path to the directory containing the SavedModel."),
Flag("output_file", parsed_flags.output_file.bind(),
parsed_flags.output_file.default_value(),
"Output file. "
@@ -49,6 +53,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.output_format.default_value(),
"Output file format. "
"One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."),
+ Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(),
+ parsed_flags.savedmodel_tagset.default_value(),
+ "Comma-separated set of tags identifying the MetaGraphDef within "
+ "the SavedModel to analyze. All tags in the tag set must be "
+ "specified."),
Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
parsed_flags.default_ranges_min.default_value(),
"If defined, will be used as the default value for the min bound "
@@ -128,47 +137,72 @@ bool ParseTocoFlagsFromCommandLineFlags(
}
}
+namespace {
+
+// Defines the requirements for a given flag. kUseDefault means the default
+// should be used in cases where the value isn't specified by the user.
+enum class FlagRequirement {
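+ // No requirement; the flag is optional.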
+ kNone,
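+ // The flag must be explicitly specified on the command line.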
+ kMustBeSpecified,
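+ // The flag must not be specified, given the other flags.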
+ kMustNotBeSpecified,
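+ // Fall back to the flag's default value when none is specified.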
+ kUseDefault,
+};
+
+// Enforces the FlagRequirements are met for a given flag.
+template <typename T>
+void EnforceFlagRequirement(const T& flag, const string& flag_name,
+ FlagRequirement requirement) {
+ if (requirement == FlagRequirement::kMustBeSpecified) {
+ QCHECK(flag.specified()) << "Missing required flag " << flag_name;
+ }
+ if (requirement == FlagRequirement::kMustNotBeSpecified) {
+ QCHECK(!flag.specified())
+ << "Given other flags, this flag should not have been specified: "
+ << flag_name;
+ }
+}
+
+// Gets the value from the flag if specified. Returns default if the
+// FlagRequirement is kUseDefault.
+template <typename T>
+absl::optional<T> GetFlagValue(const Arg<T>& flag,
+ FlagRequirement requirement) {
+ if (flag.specified()) return flag.value();
+ if (requirement == FlagRequirement::kUseDefault) return flag.default_value();
+ return absl::optional<T>();
+}
+
+} // namespace
+
void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
TocoFlags* toco_flags) {
namespace port = toco::port;
port::CheckInitGoogleIsDone("InitGoogle is not done yet");
- enum class FlagRequirement { kNone, kMustBeSpecified, kMustNotBeSpecified };
-
-#define ENFORCE_FLAG_REQUIREMENT(name, requirement) \
- do { \
- if (requirement == FlagRequirement::kMustBeSpecified) { \
- QCHECK(parsed_toco_flags.name.specified()) \
- << "Missing required flag: " << #name; \
- } \
- if (requirement == FlagRequirement::kMustNotBeSpecified) { \
- QCHECK(!parsed_toco_flags.name.specified()) \
- << "Given other flags, this flag should not have been specified: " \
- << #name; \
- } \
- } while (false)
-#define READ_TOCO_FLAG(name, requirement) \
- ENFORCE_FLAG_REQUIREMENT(name, requirement); \
- do { \
- if (parsed_toco_flags.name.specified()) { \
- toco_flags->set_##name(parsed_toco_flags.name.value()); \
- } \
+#define READ_TOCO_FLAG(name, requirement) \
+ do { \
+ EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \
+ auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
+ if (flag_value.has_value()) { \
+ toco_flags->set_##name(flag_value.value()); \
+ } \
} while (false)
-#define PARSE_TOCO_FLAG(Type, name, requirement) \
- ENFORCE_FLAG_REQUIREMENT(name, requirement); \
- do { \
- if (parsed_toco_flags.name.specified()) { \
- Type x; \
- QCHECK(Type##_Parse(parsed_toco_flags.name.value(), &x)) \
- << "Unrecognized " << #Type << " value " \
- << parsed_toco_flags.name.value(); \
- toco_flags->set_##name(x); \
- } \
+#define PARSE_TOCO_FLAG(Type, name, requirement) \
+ do { \
+ EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \
+ auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
+ if (flag_value.has_value()) { \
+ Type x; \
+ QCHECK(Type##_Parse(flag_value.value(), &x)) \
+ << "Unrecognized " << #Type << " value " \
+ << parsed_toco_flags.name.value(); \
+ toco_flags->set_##name(x); \
+ } \
} while (false)
- PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified);
- PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kMustBeSpecified);
+ PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kUseDefault);
+ PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kUseDefault);
PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone);
PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone);
READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone);
diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.cc b/tensorflow/contrib/lite/toco/toco_saved_model.cc
new file mode 100644
index 0000000000..91a742b9e0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_saved_model.cc
@@ -0,0 +1,186 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/numbers.h"
+#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
+#include "tensorflow/contrib/lite/toco/toco_saved_model.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+namespace toco {
+namespace {
+
+// Loads a SavedModel from the directory specified in parsed_toco_flags.
+// Returns a SavedModelBundle with the requested MetaGraphDef, owned by the
+// caller via std::unique_ptr so the bundle is not leaked.
+std::unique_ptr<tensorflow::SavedModelBundle> LoadSavedModel(
+ const ParsedTocoFlags& parsed_toco_flags) {
+ const string model_path = parsed_toco_flags.savedmodel_directory.value();
+ QCHECK(tensorflow::MaybeSavedModelDirectory(model_path))
+ << "Model is not saved in the supported SavedModel format.\n";
+
+ // Gets the tags identifying the MetaGraphDef from the command line arguments.
+ QCHECK(parsed_toco_flags.savedmodel_tagset.specified())
+ << "Missing required flag --savedmodel_tagset.\n";
+ const string tags_str = parsed_toco_flags.savedmodel_tagset.value();
+ auto tags = absl::StrSplit(tags_str, ',');
+
+ // Loads the MetaGraphDef.
+ std::unique_ptr<tensorflow::SavedModelBundle> bundle(
+ new tensorflow::SavedModelBundle);
+ TF_CHECK_OK(tensorflow::LoadSavedModel(tensorflow::SessionOptions(),
+ tensorflow::RunOptions(), model_path,
+ tags, bundle.get()))
+ << "Failed to load exported model from " << model_path
+ << ". Ensure the model contains the required tags '" << tags_str
+ << "'.\n";
+ return bundle;
+}
+
+// Returns the array name without the tensor index suffix,
+// e.g. reduces "input:0" to "input".
+string GetArrayName(const string& name) {
+ const std::vector<string>& names = absl::StrSplit(name, ':');
+ return names[0];
+}
+
+// Returns the array names, without their tensor index suffixes, sorted alphabetically.
+std::set<string> GetSortedNames(const std::unordered_set<string>& names) {
+ std::vector<string> final_names;
+ final_names.reserve(names.size());
+ for (const auto& name : names) {
+ final_names.push_back(GetArrayName(name));
+ }
+ return std::set<string>(final_names.begin(), final_names.end());
+}
+
+// Computes the final shape, replacing the first dimension with the batch size
+// if it is undefined (i.e. set to -1). Returns false if any dimension other
+// than the first is undefined.
+bool ReplaceShapeBatchSize(const tensorflow::TensorShapeProto& shape,
+ int batch_size,
+ tensorflow::TensorShapeProto* final_shape) {
+ for (int idx = 0; idx < shape.dim().size(); ++idx) {
+ int64 final_dim = shape.dim()[idx].size();
+ if (final_dim == -1) {
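+ // Only the first (batch) dimension may be undefined.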
+ if (idx > 0) return false;
+ final_dim = batch_size;
+ }
+ final_shape->add_dim()->set_size(final_dim);
+ }
+ return true;
+}
+
+// Updates the input arrays in ModelFlags to contain the shape of the array.
+void ProcessInputShapes(const tensorflow::GraphDef& graph_def, int batch_size,
+ ModelFlags* model_flags) {
+ // Build map of input array names to input arrays.
+ std::unordered_map<string, InputArray*> input_data_map;
+ for (auto& input : *model_flags->mutable_input_arrays()) {
+ input_data_map[input.name()] = &input;
+ }
+
+ // Adds shapes to the input arrays if the shape is valid.
+ for (const tensorflow::NodeDef& node_def : graph_def.node()) {
+ if (input_data_map.find(node_def.name()) != input_data_map.end()) {
+ const auto shape_it = node_def.attr().find("shape");
+ if (shape_it != node_def.attr().end()) {
+ tensorflow::TensorShapeProto final_shape;
+ bool is_valid = ReplaceShapeBatchSize(shape_it->second.shape(),
+ batch_size, &final_shape);
+
+ if (is_valid) {
+ auto* shape = input_data_map.at(node_def.name())->mutable_shape();
+ QCHECK_EQ(shape->dims_size(), 0)
+ << "The shape for the input '" << node_def.name()
+ << "' was previously defined. For clarity please define inputs "
+ << "via --input_arrays and input_shapes flags.\n";
+ for (const auto& dim : final_shape.dim()) {
+ shape->add_dims(dim.size());
+ }
+ }
+ }
+ }
+ }
+
+ // Checks all input arrays have a shape.
+ for (auto const& input : model_flags->input_arrays()) {
+ QCHECK(input.shape().dims_size() > 0)
+ << "A valid input shape was not found for input '" << input.name()
+ << "'. Please define via --input_arrays and --input_shapes flags.\n";
+ }
+}
+
+} // namespace
+
+void ParseMetaData(const tensorflow::GraphDef& graph_def,
+ const std::unordered_set<string>& inputs,
+ const std::unordered_set<string>& outputs,
+ const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ TocoFlags* toco_flags, ModelFlags* model_flags) {
+ if (!parsed_model_flags.input_arrays.specified()) {
+ const std::set<string> sorted_inputs = GetSortedNames(inputs);
+ for (const auto& input_name : sorted_inputs) {
+ model_flags->add_input_arrays()->set_name(input_name);
+ }
+ }
+
+ if (!parsed_model_flags.output_arrays.specified()) {
+ const std::set<string> sorted_outputs = GetSortedNames(outputs);
+ for (const auto& output_name : sorted_outputs) {
+ model_flags->add_output_arrays(GetArrayName(output_name));
+ }
+ }
+
+ if (!parsed_model_flags.input_shapes.specified()) {
+ int batch_size = parsed_model_flags.batch_size.value();
+ ProcessInputShapes(graph_def, batch_size, model_flags);
+ }
+
+ if (!parsed_toco_flags.inference_type.specified()) {
+ toco_flags->set_inference_type(IODataType::FLOAT);
+ }
+}
+
+// TODO(nupurgarg): Add top level tests.
+void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ TocoFlags* toco_flags, ModelFlags* model_flags,
+ string* graph_def_contents) {
+ // Loads the MetaGraphDef within a SavedModelBundle.
+ auto bundle = LoadSavedModel(parsed_toco_flags);
+
+ // Converts the MetaGraphDef to frozen GraphDef.
+ tensorflow::GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_CHECK_OK(tensorflow::FreezeSavedModel(*bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ // Serializes the frozen GraphDef into a string.
+ QCHECK(frozen_graph_def.SerializeToString(graph_def_contents))
+ << "Unable to generate serialized GraphDef.\n";
+
+ // Processes the inputs, outputs, and metadata within the GraphDef.
+ const tensorflow::GraphDef graph_def = bundle->meta_graph_def.graph_def();
+ ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags,
+ parsed_model_flags, toco_flags, model_flags);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.h b/tensorflow/contrib/lite/toco/toco_saved_model.h
new file mode 100644
index 0000000000..7a0fabd82d
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_saved_model.h
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/cc/tools/freeze_saved_model.h"
+#include "tensorflow/contrib/lite/toco/args.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/types.pb.h"
+
+namespace toco {
+
+// Parses metadata into `toco_flags` and `model_flags`.
+//
+// Stores `inputs` as input_arrays and `outputs` as output_arrays in
+// `model_flags`. Infers input_shapes from the GraphDef and stores them in
+// `model_flags` as part of the input_arrays. Assumes inference_type is FLOAT
+// and stores it in `toco_flags`.
+void ParseMetaData(const tensorflow::GraphDef& graph_def,
+ const std::unordered_set<string>& inputs,
+ const std::unordered_set<string>& outputs,
+ const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ TocoFlags* toco_flags, ModelFlags* model_flags);
+
+// Generates a frozen graph from the SavedModel in the directory specified in
+// `parsed_toco_flags`. Reads the frozen graph contents into
+// `graph_def_contents`. Parses metadata relating to the GraphDef into
+// `toco_flags` and `model_flags`.
+void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ TocoFlags* toco_flags, ModelFlags* model_flags,
+ string* graph_def_contents);
+
+} // namespace toco
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_
diff --git a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc
new file mode 100644
index 0000000000..5e122afe65
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc
@@ -0,0 +1,274 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/toco/toco_saved_model.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
+#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace toco {
+namespace {
+
+using tensorflow::ops::Add;
+using tensorflow::ops::Const;
+using tensorflow::ops::FakeQuantWithMinMaxArgs;
+using tensorflow::ops::Placeholder;
+
+class TocoSavedModelTest : public ::testing::Test {
+ protected:
+ // Processes the cmdline arguments and calls ParseMetaData. ParseMetaData
+ // parses input_arrays and output_arrays, and gets metadata from the
+ // SavedModel if it is not defined in the cmdline arguments.
+ void ProcessGraphDefMetadata(const std::unordered_set<string>& inputs,
+ const std::unordered_set<string>& outputs,
+ const tensorflow::GraphDef& graph_def) {
+ ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags_, &toco_flags_);
+ ReadModelFlagsFromCommandLineFlags(parsed_model_flags_, &model_flags_);
+ ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags_,
+ parsed_model_flags_, &toco_flags_, &model_flags_);
+ }
+
+ // Gets the GraphDef from the SavedModelBundle and processes metadata.
+ void ProcessSavedModelMetadata(const std::unordered_set<string>& inputs,
+ const std::unordered_set<string>& outputs) {
+ const tensorflow::GraphDef graph_def = bundle_.meta_graph_def.graph_def();
+ ProcessGraphDefMetadata(inputs, outputs, graph_def);
+ }
+
+ // Returns a GraphDef representing a simple float model with a single input.
+ tensorflow::GraphDef GetFloatGraphDef(const std::vector<int64>& shape) {
+ tensorflow::GraphDef graph_def;
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ tensorflow::Output input =
+ Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT,
+ Placeholder::Shape(tensorflow::PartialTensorShape(shape)));
+ tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {});
+ tensorflow::Output add = Add(scope.WithOpName("add"), input, zero);
+
+ TF_EXPECT_OK(scope.ToGraphDef(&graph_def));
+ return graph_def;
+ }
+
+ // Returns a GraphDef representing a simple float model with two inputs.
+ tensorflow::GraphDef GetComplexFloatGraphDef() {
+ tensorflow::GraphDef graph_def;
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ tensorflow::Output inputA =
+ Placeholder(scope.WithOpName("inputA"), tensorflow::DT_FLOAT,
+ Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1})));
+ tensorflow::Output inputB =
+ Placeholder(scope.WithOpName("inputB"), tensorflow::DT_FLOAT,
+ Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1})));
+ tensorflow::Output add = Add(scope.WithOpName("add"), inputB, inputA);
+
+ TF_EXPECT_OK(scope.ToGraphDef(&graph_def));
+ return graph_def;
+ }
+
+ // Returns a GraphDef representing a simple quantized model.
+ tensorflow::GraphDef GetQuantizedGraphDef() {
+ tensorflow::GraphDef graph_def;
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ tensorflow::Output input =
+ Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT,
+ Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1})));
+ tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {});
+ tensorflow::Output fake_quant =
+ FakeQuantWithMinMaxArgs(scope.WithOpName("quant"), zero);
+ tensorflow::Output add = Add(scope.WithOpName("add"), input, fake_quant);
+
+ TF_EXPECT_OK(scope.ToGraphDef(&graph_def));
+ return graph_def;
+ }
+
+ // Gets the values in the input_arrays flag.
+ std::vector<string> GetInputArrays() {
+ std::vector<string> actual;
+ for (const auto& input : model_flags_.input_arrays()) {
+ actual.push_back(input.name());
+ }
+ return actual;
+ }
+
+ // Gets the values in the output_arrays flag.
+ std::vector<string> GetOutputArrays() {
+ std::vector<string> actual(model_flags_.output_arrays().begin(),
+ model_flags_.output_arrays().end());
+ return actual;
+ }
+
+ // Gets the shape of the given input array.
+ string GetInputShape(const string& input_array) {
+ for (const auto& input : model_flags_.input_arrays()) {
+ if (input.name() == input_array) {
+ std::vector<string> dims;
+ for (int idx = 0; idx < input.shape().dims_size(); ++idx) {
+ dims.push_back(std::to_string(input.shape().dims(idx)));
+ }
+ return absl::StrJoin(dims, ",");
+ }
+ }
+ return "";
+ }
+
+ tensorflow::SavedModelBundle bundle_;
+ ParsedTocoFlags parsed_toco_flags_;
+ ParsedModelFlags parsed_model_flags_;
+ TocoFlags toco_flags_;
+ ModelFlags model_flags_;
+};
+
+// Tests if input_arrays, input_shapes, inference_type, and output_arrays are
+// set correctly if they are not specified in cmdline arguments.
+// Tests if the default batch size replaces a -1 in the first dimension.
+TEST_F(TocoSavedModelTest, NoCmdLine) {
+ tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1});
+
+ ProcessGraphDefMetadata({"input"}, {"add"}, graph_def);
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
+ EXPECT_EQ(GetInputShape("input"), "1,3,3,1");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Tests if the order of input_arrays and output_arrays is deterministic when
+// they are taken from the SavedModel.
+TEST_F(TocoSavedModelTest, NoCmdLineMultipleArrays) {
+ tensorflow::GraphDef graph_def = GetComplexFloatGraphDef();
+
+ // Note: The model does not have two outputs. However, the function does not
+ // need an accurate output_array list. This is only meant to test order.
+ ProcessGraphDefMetadata({"inputB", "inputA"}, {"add", "invalid"}, graph_def);
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add", "invalid"}));
+ EXPECT_EQ(GetInputShape("inputA"), "1,3,3,1");
+ EXPECT_EQ(GetInputShape("inputB"), "1,3,3,1");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Tests if input_shapes is inferred when input_arrays is passed in via cmdline
+// arguments.
+TEST_F(TocoSavedModelTest, InputNameWithoutInputShape) {
+ parsed_model_flags_.input_arrays.bind()("input");
+ tensorflow::GraphDef graph_def = GetFloatGraphDef({2, 3, 3, 1});
+
+ ProcessGraphDefMetadata({"not_used_input"}, {"add"}, graph_def);
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
+ EXPECT_EQ(GetInputShape("input"), "2,3,3,1");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Ensures a failure occurs when input_shapes is defined without input_arrays.
+TEST_F(TocoSavedModelTest, InputShapeWithoutInputName) {
+ parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12");
+ tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1});
+
+ EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def),
+ "failed: input_shapes.size\\(\\) == "
+ "model_flags->input_arrays_size\\(\\)");
+}
+
+// Tests if the cmdline values of input_arrays, input_shapes are used when
+// specified with an empty GraphDef.
+TEST_F(TocoSavedModelTest, InputArraysCmdLine) {
+ parsed_model_flags_.input_arrays.bind()("inputA,inputB");
+ parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12");
+
+ ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"});
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"output0", "output1"}));
+ EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1");
+ EXPECT_EQ(GetInputShape("inputB"), "9,12");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Tests if the cmdline values of input_arrays, input_shapes are used when
+// specified even if values exist within the GraphDef.
+TEST_F(TocoSavedModelTest, InputArraysCmdLineWithGraphDef) {
+ parsed_model_flags_.input_arrays.bind()("inputA");
+ parsed_model_flags_.input_shapes.bind()("1,224,224,1");
+ tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1});
+
+ ProcessGraphDefMetadata({"inputA"}, {"add"}, graph_def);
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
+ EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Tests if the cmdline values of input_arrays, input_shapes, inference_type,
+// and output_arrays are used when specified with an empty GraphDef.
+TEST_F(TocoSavedModelTest, AllParamsCmdLine) {
+ parsed_model_flags_.input_arrays.bind()("inputA,inputB");
+ parsed_model_flags_.output_arrays.bind()("outputA,outputB");
+ parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12");
+ parsed_toco_flags_.inference_type.bind()("FLOAT");
+
+ ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"});
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"outputA", "outputB"}));
+ EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1");
+ EXPECT_EQ(GetInputShape("inputB"), "9,12");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Tests if a quantized graph gives the correct values assuming type is passed
+// in via command line.
+TEST_F(TocoSavedModelTest, QuantizedNoCmdLine) {
+ parsed_toco_flags_.inference_type.bind()("QUANTIZED_UINT8");
+ tensorflow::GraphDef graph_def = GetQuantizedGraphDef();
+
+ ProcessGraphDefMetadata({"input"}, {"add"}, graph_def);
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
+ EXPECT_EQ(GetInputShape("input"), "1,3,3,1");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::QUANTIZED_UINT8);
+}
+
+// Tests if the provided batch size replaces a -1 in the first dimension of
+// input shape.
+TEST_F(TocoSavedModelTest, MissingShapeParameterValid) {
+ parsed_model_flags_.batch_size.bind()(3);
+ tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1});
+
+ ProcessGraphDefMetadata({"input"}, {"add"}, graph_def);
+ EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"}));
+ EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
+ EXPECT_EQ(GetInputShape("input"), "3,3,3,1");
+ EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
+}
+
+// Ensures a failure occurs if there is a -1 in any dimension other than the
+// first dimension of the input shape.
+TEST_F(TocoSavedModelTest, MissingShapeParameterInvalid) {
+ parsed_model_flags_.batch_size.bind()(3);
+ tensorflow::GraphDef graph_def = GetFloatGraphDef({1, -1, 3, 1});
+
+ EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def),
+ "A valid input shape was not found for input 'input'.");
+}
+
+} // namespace
+} // namespace toco
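The two MissingShapeParameter tests above pin down the batch-size rule: a -1 is only legal in the leading (batch) dimension, where the default or user-provided batch size replaces it, and a -1 anywhere else is fatal. A minimal sketch of that rule, using a hypothetical helper name rather than TOCO's actual code:

    def resolve_input_shape(dims, batch_size=1):
        # Hypothetical sketch of the rule the tests above exercise, not the
        # TOCO implementation: only a leading -1 is replaced by batch_size.
        if dims and dims[0] == -1:
            dims = [batch_size] + list(dims[1:])
        if any(d < 0 for d in dims):
            raise ValueError("A valid input shape was not found for input.")
        return dims

    assert resolve_input_shape([-1, 3, 3, 1], batch_size=3) == [3, 3, 3, 1]
    assert resolve_input_shape([2, 3, 3, 1]) == [2, 3, 3, 1]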
diff --git a/tensorflow/contrib/py2tf/converters/call_trees.py b/tensorflow/contrib/py2tf/converters/call_trees.py
index e3040f09e4..f498b814bf 100644
--- a/tensorflow/contrib/py2tf/converters/call_trees.py
+++ b/tensorflow/contrib/py2tf/converters/call_trees.py
@@ -118,6 +118,12 @@ class CallTreeTransformer(transformer.Base):
def _should_compile(self, node, fqn):
"""Determines whether an entity should be compiled in the context."""
+ # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
+ module_name = fqn[0]
+ for mod in self.uncompiled_modules:
+ if module_name.startswith(mod[0] + '.'):
+ return False
+
for i in range(1, len(fqn)):
if fqn[:i] in self.uncompiled_modules:
return False
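The added block makes the whitelist prefix-aware: a fully qualified name whose root module lives under any uncompiled package is skipped, not just exact tuple matches. A self-contained sketch of the new check in isolation (the exact-fqn loop that follows it is omitted):

    uncompiled_modules = {('tensorflow',), ('numpy',)}

    def should_compile(fqn):
        # New prefix test: 'tensorflow.python' is under 'tensorflow'.
        module_name = fqn[0]
        for mod in uncompiled_modules:
            if module_name.startswith(mod[0] + '.'):
                return False
        return True

    assert not should_compile(('tensorflow.python', 'ops'))
    assert should_compile(('my_module', 'f'))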
diff --git a/tensorflow/contrib/py2tf/converters/for_loops.py b/tensorflow/contrib/py2tf/converters/for_loops.py
index 4297c1cf2a..8d28b149a8 100644
--- a/tensorflow/contrib/py2tf/converters/for_loops.py
+++ b/tensorflow/contrib/py2tf/converters/for_loops.py
@@ -38,19 +38,19 @@ class ForLoopCanonicalizationTransformer(transformer.Base):
self.generic_visit(node)
body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
i_var = self.context.namer.new_symbol('i', body_scope.referenced)
- n_var = self.context.namer.new_symbol('n', body_scope.referenced)
- iterated_var = self.context.namer.new_symbol('iterated',
- body_scope.referenced)
+ smart_loop_iter_var = self.context.namer.new_symbol('smart_loop_iter',
+ body_scope.referenced)
+ cont_var = self.context.namer.new_symbol('cont', body_scope.referenced)
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
if anno.hasanno(node, 'extra_cond'):
template = """
i = 0
- iterated = loop_iter
- n = len(iterated)
- while i < n and extra_cond:
- target = iterated[i]
+ smart_loop_iter = py2tf_utils.dynamic_dataset(loop_iter)
+ cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
+ while cont and extra_cond:
body
i += 1
+ cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
"""
return templates.replace(
template,
@@ -58,18 +58,18 @@ class ForLoopCanonicalizationTransformer(transformer.Base):
target=node.target,
body=node.body,
i=i_var,
- n=n_var,
- iterated=iterated_var,
+ smart_loop_iter=smart_loop_iter_var,
+ cont=cont_var,
extra_cond=anno.getanno(node, 'extra_cond'))
else:
template = """
i = 0
- iterated = loop_iter
- n = len(iterated)
- while i < n:
- target = iterated[i]
+ smart_loop_iter = py2tf_utils.dynamic_dataset(loop_iter)
+ cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
+ while cont:
body
i += 1
+ cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
"""
repl = templates.replace(
template,
@@ -77,8 +77,8 @@ class ForLoopCanonicalizationTransformer(transformer.Base):
target=node.target,
body=node.body,
i=i_var,
- n=n_var,
- iterated=iterated_var)
+ smart_loop_iter=smart_loop_iter_var,
+ cont=cont_var)
return repl
def visit_Continue(self, node):
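For a concrete input, the rewritten template expands as below. This runs as plain Python because dynamic_for_cond falls back to len()/indexing for ordinary sequences (see the builtins.py changes later in this diff); i, smart_loop_iter, and cont stand for the fresh symbols the namer generates:

    from tensorflow.contrib.py2tf import utils as py2tf_utils

    xs = [1, 2, 3]
    total = 0
    # Original loop:  for x in xs: total += x
    i = 0
    smart_loop_iter = py2tf_utils.dynamic_dataset(xs)  # non-Dataset passes through
    cont, x = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
    while cont:
        total += x
        i += 1
        cont, x = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
    assert total == 6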
diff --git a/tensorflow/contrib/py2tf/converters/lists.py b/tensorflow/contrib/py2tf/converters/lists.py
index 12ebd00062..3e62037a50 100644
--- a/tensorflow/contrib/py2tf/converters/lists.py
+++ b/tensorflow/contrib/py2tf/converters/lists.py
@@ -67,6 +67,9 @@ class ListTransformer(transformer.Base):
node = self.generic_visit(node)
if isinstance(node.value, gast.Call):
call_node = node.value
+
+ if not anno.hasanno(call_node.func, anno.Basic.QN):
+ return node
qn = anno.getanno(call_node.func, anno.Basic.QN)
if qn.qn[-1] == 'append' and (len(call_node.args) == 1):
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
index 5556a58c02..a969adbeca 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
@@ -168,6 +168,15 @@ class TypeInfoResolver(transformer.Base):
anno.getanno(definition, 'element_type'))
return node
+ def _process_tuple_assignment(self, source, t):
+ for i, e in enumerate(t.elts):
+ if isinstance(e, gast.Tuple):
+ self._process_tuple_assignment(source, e)
+ else:
+ self.scope.setval(
+ anno.getanno(e, anno.Basic.QN),
+ gast.Subscript(source, gast.Index(i), ctx=gast.Store()))
+
def _process_variable_assignment(self, source, targets):
if isinstance(source, gast.Call):
func = source.func
@@ -183,10 +192,9 @@ class TypeInfoResolver(transformer.Base):
for t in targets:
if isinstance(t, gast.Tuple):
- for i, e in enumerate(t.elts):
- self.scope.setval(
- anno.getanno(e, anno.Basic.QN),
- gast.Subscript(source, gast.Index(i), ctx=gast.Store()))
+ # Need to recurse in the case of nested tuple assignments,
+ # e.g. a, (b, c) = f().
+ self._process_tuple_assignment(source, t)
elif isinstance(t, (gast.Name, gast.Attribute)):
self.scope.setval(anno.getanno(t, anno.Basic.QN), source)
else:
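Independent of the converter's annotation machinery, the decomposition that nested tuple targets denote can be shown on plain values; each name maps to a chain of subscripts into the source (a sketch, not the converter's code):

    def tuple_assignment_paths(target, path=()):
        # Maps each name in a (possibly nested) tuple target to the index
        # chain that extracts its value from the assignment source.
        paths = {}
        for i, elt in enumerate(target):
            if isinstance(elt, tuple):
                paths.update(tuple_assignment_paths(elt, path + (i,)))
            else:
                paths[elt] = path + (i,)
        return paths

    # a, (b, c) = foo  -->  a = foo[0]; b = foo[1][0]; c = foo[1][1]
    assert tuple_assignment_paths(('a', ('b', 'c'))) == {
        'a': (0,), 'b': (1, 0), 'c': (1, 1)}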
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
index 0d9d5a85f0..8a8956197d 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
@@ -196,6 +196,23 @@ class TypeInfoResolverTest(test.TestCase):
f_ref = node.body[0].body[1].value
self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
+ def test_nested_assignment(self):
+
+ def test_fn(foo):
+ a, (b, c) = foo
+ return a, b, c
+
+ node = self._parse_and_analyze(test_fn, {'foo': (1, 2, 3)})
+ lhs = node.body[0].body[1].value.elts
+ a = lhs[0]
+ b = lhs[1]
+ c = lhs[2]
+ # TODO(mdan): change these once we have the live values propagating
+ # correctly
+ self.assertFalse(anno.hasanno(a, 'live_val'))
+ self.assertFalse(anno.hasanno(b, 'live_val'))
+ self.assertFalse(anno.hasanno(c, 'live_val'))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/py2tf/utils/BUILD
index d029289f5a..b53fbb5c18 100644
--- a/tensorflow/contrib/py2tf/utils/BUILD
+++ b/tensorflow/contrib/py2tf/utils/BUILD
@@ -35,6 +35,7 @@ py_library(
deps = [
"//tensorflow/python:list_ops",
"//tensorflow/python:script_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/py2tf/utils/__init__.py b/tensorflow/contrib/py2tf/utils/__init__.py
index d9d8e34689..4e6003c852 100644
--- a/tensorflow/contrib/py2tf/utils/__init__.py
+++ b/tensorflow/contrib/py2tf/utils/__init__.py
@@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.py2tf.utils.builtins import dynamic_builtin
+from tensorflow.contrib.py2tf.utils.builtins import dynamic_dataset
+from tensorflow.contrib.py2tf.utils.builtins import dynamic_for_cond
from tensorflow.contrib.py2tf.utils.builtins import dynamic_print
from tensorflow.contrib.py2tf.utils.builtins import dynamic_range
from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns
diff --git a/tensorflow/contrib/py2tf/utils/builtins.py b/tensorflow/contrib/py2tf/utils/builtins.py
index 3cb62b55d4..251b4ed8ee 100644
--- a/tensorflow/contrib/py2tf/utils/builtins.py
+++ b/tensorflow/contrib/py2tf/utils/builtins.py
@@ -22,8 +22,10 @@ import six
from tensorflow.contrib.py2tf.utils import py_func
from tensorflow.contrib.py2tf.utils import type_check
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_inspect
@@ -54,7 +56,6 @@ def dynamic_len(list_or_tensor):
raise ValueError(
'len requires non-zero rank for tensor "%s"' % list_or_tensor)
return array_ops.shape(list_or_tensor)[0]
-
return len(list_or_tensor)
@@ -97,3 +98,69 @@ def dynamic_print(*values):
if all(map(is_tf_print_compatible, values)):
return logging_ops.Print(1, values)
return py_func.wrap_py_func(print, None, values, use_dummy_return=True)
+
+
+def dynamic_dataset(iterated):
+ """Implementartion of smart tf.data.Dataset epoch wrapping.
+
+ The function checks whether the input is a tf.data.Dataset and, if so, wraps
+ it so that alongside each element it also returns the epoch the dataset
+ iteration is currently in, for two epochs. If the input is not a
+ tf.data.Dataset, it is returned untouched.
+
+ Args:
+ iterated: The iterable or tf.data.Dataset that is being iterated over.
+ Returns:
+ Either the untouched input or, if the input is a tf.data.Dataset, a wrapped
+ tf.data.Dataset that yields, alongside each element, the epoch the dataset
+ iteration is currently in.
+ """
+ if not isinstance(iterated, dataset_ops.Dataset):
+ return iterated
+
+ def epoch_dataset_number_helper(i):
+ return dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(i).repeat(), iterated))
+
+ epoch_numbers = dataset_ops.Dataset.range(2)
+ return epoch_numbers.flat_map(epoch_dataset_number_helper)
+
+
+def dynamic_for_cond(iteration, iterated):
+ """Implementartion of smart while-loop condition using dynamic dispatch.
+
+ The function checks whether it is iterating over a tf.data.Dataset. If not,
+ it simply returns whether the iteration is still in range of the iterated
+ object, together with the next element. If it is iterating over a dataset,
+ it iterates for a single epoch only.
+
+ Args:
+ iteration: The current iteration of the loop.
+ iterated: The iterable or tf.data.Dataset that is being iterated over.
+ Returns:
+ A tuple of a bool that indicates whether the loop should continue, and the
+ next element in iterated.
+ """
+ # TODO(znado): Clean up.
+ # TODO(znado): This won't work for unpacked iterates. Fix.
+ if isinstance(iterated, dataset_ops.Dataset):
+ curr_epoch, next_elem = iterated.make_one_shot_iterator().get_next()
+ return math_ops.less(curr_epoch, 1), next_elem
+ elif tensor_util.is_tensor(iterated):
+ if iterated.shape.ndims > 1:
+ elem_shape = array_ops.shape(iterated)[1:]
+ else:
+ elem_shape = ()
+ if iterated.shape.ndims == 0 or iterated.shape[0] == 0:
+ return False, array_ops.zeros(elem_shape, iterated.dtype)
+ return control_flow_ops.cond(
+ math_ops.less(iteration, dynamic_len(iterated)),
+ lambda: (True, iterated[iteration]),
+ lambda: (False, array_ops.zeros(elem_shape, iterated.dtype)))
+ elif hasattr(iterated, '__len__'):
+ if iteration < len(iterated):
+ return True, iterated[iteration]
+ return False, None
+ else:
+ raise NotImplementedError('Python iterators not yet supported.')
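A usage sketch of the epoch wrapping, assuming a TF 1.x graph session (make_one_shot_iterator is graph-mode API): the wrapped dataset yields (epoch, element) pairs for two passes over the data.

    import tensorflow as tf
    from tensorflow.contrib.py2tf import utils as py2tf_utils

    ds = tf.data.Dataset.range(3)
    wrapped = py2tf_utils.dynamic_dataset(ds)  # yields (epoch, element)
    next_elem = wrapped.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        # Two epochs of [0, 1, 2]:
        # (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)
        for _ in range(6):
            print(sess.run(next_elem))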
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index aaa6f3c2c1..152f8c8c69 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -931,8 +931,7 @@ class _InputPipeline(object):
# In the model-parallel case, both the host-side and device-side
# computations must agree on the core on which infeed takes place. We
# choose to perform infeed on logical core 0 of each replica.
- with ops.device(tpu.core(0)):
- values = self._infeed_queue.generate_dequeue_op()
+ values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
# The unflatten process uses the structure information recorded above.
return self._inputs_structure_recorder.unflatten_features_and_labels(
values)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
index 42ac6eb680..604e6600c8 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
@@ -23,6 +23,7 @@ from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_sharding
from tensorflow.python.framework import dtypes
@@ -368,13 +369,20 @@ class InfeedQueue(object):
policy.freeze()
self._validate()
- def generate_dequeue_op(self):
+ def generate_dequeue_op(self, tpu_device=0):
"""Generates the device-side Op to dequeue a tuple from the queue.
Implicitly freezes the queue configuration if it is not already
frozen, which will raise errors if the shapes and types have not
been fully specified.
+ Args:
+ tpu_device: The TPU device ordinal where the infeed instruction should be
+ placed. If None, no explicit placement will be performed, and it is up
+ to the user to call this API from within a proper TPU device scope.
+ The XLA code will fail if the TPU dequeue instruction is not bound to
+ any device.
+
Returns:
A list of Outputs corresponding to a shard of infeed dequeued
into XLA, suitable for use within a replicated block.
@@ -392,8 +400,13 @@ class InfeedQueue(object):
policy.get_sharded_shape(shape)
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
]
- return tpu_ops.infeed_dequeue_tuple(
- dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
+ if tpu_device is not None:
+ with ops.device(tpu.core(tpu_device)):
+ return tpu_ops.infeed_dequeue_tuple(
+ dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
+ else:
+ return tpu_ops.infeed_dequeue_tuple(
+ dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
def _generate_enqueue_op(self,
inputs,
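The net effect of the new argument is that placement moves inside generate_dequeue_op, reproducing the device scope tpu_estimator.py previously opened by hand. A sketch, with names as imported in tpu_feed.py:

    # Default: the dequeue is pinned to logical core 0 internally.
    values = infeed_queue.generate_dequeue_op(tpu_device=0)

    # Equivalent caller-managed placement, as tpu_estimator.py used to do:
    with ops.device(tpu.core(0)):
        values = infeed_queue.generate_dequeue_op(tpu_device=None)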
diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py
index 16397622ed..96eff86d8d 100644
--- a/tensorflow/contrib/training/python/training/hparam_test.py
+++ b/tensorflow/contrib/training/python/training/hparam_test.py
@@ -38,40 +38,60 @@ class HParamsTest(test.TestCase):
self.assertFalse('bar' in hparams)
def testSomeValues(self):
- hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6')
- self.assertDictEqual({'aaa': 1, 'b': 2.0, 'c_c': 'relu6'}, hparams.values())
- expected_str = '[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\')]'
+ hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d')
+ self.assertDictEqual(
+ {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'},
+ hparams.values())
+ expected_str = ('[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\'), '
+ '(\'d\', \'/a/b=c/d\')]')
self.assertEqual(expected_str, str(hparams.__str__()))
self.assertEqual(expected_str, str(hparams))
self.assertEqual(1, hparams.aaa)
self.assertEqual(2.0, hparams.b)
self.assertEqual('relu6', hparams.c_c)
+ self.assertEqual('/a/b=c/d', hparams.d)
hparams.parse('aaa=12')
self.assertDictEqual({
'aaa': 12,
'b': 2.0,
- 'c_c': 'relu6'
+ 'c_c': 'relu6',
+ 'd': '/a/b=c/d'
}, hparams.values())
self.assertEqual(12, hparams.aaa)
self.assertEqual(2.0, hparams.b)
self.assertEqual('relu6', hparams.c_c)
+ self.assertEqual('/a/b=c/d', hparams.d)
hparams.parse('c_c=relu4, b=-2.0e10')
self.assertDictEqual({
'aaa': 12,
'b': -2.0e10,
- 'c_c': 'relu4'
+ 'c_c': 'relu4',
+ 'd': '/a/b=c/d'
}, hparams.values())
self.assertEqual(12, hparams.aaa)
self.assertEqual(-2.0e10, hparams.b)
self.assertEqual('relu4', hparams.c_c)
+ self.assertEqual('/a/b=c/d', hparams.d)
hparams.parse('c_c=,b=0,')
- self.assertDictEqual({'aaa': 12, 'b': 0, 'c_c': ''}, hparams.values())
+ self.assertDictEqual({'aaa': 12, 'b': 0, 'c_c': '', 'd': '/a/b=c/d'},
+ hparams.values())
self.assertEqual(12, hparams.aaa)
self.assertEqual(0.0, hparams.b)
self.assertEqual('', hparams.c_c)
+ self.assertEqual('/a/b=c/d', hparams.d)
hparams.parse('c_c=2.3",b=+2,')
self.assertEqual(2.0, hparams.b)
self.assertEqual('2.3"', hparams.c_c)
+ hparams.parse('d=/a/b/c/d,aaa=11,')
+ self.assertEqual(11, hparams.aaa)
+ self.assertEqual(2.0, hparams.b)
+ self.assertEqual('2.3"', hparams.c_c)
+ self.assertEqual('/a/b/c/d', hparams.d)
+ hparams.parse('b=1.5,d=/a=b/c/d,aaa=10,')
+ self.assertEqual(10, hparams.aaa)
+ self.assertEqual(1.5, hparams.b)
+ self.assertEqual('2.3"', hparams.c_c)
+ self.assertEqual('/a=b/c/d', hparams.d)
with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter'):
hparams.parse('x=123')
with self.assertRaisesRegexp(ValueError, 'Could not parse'):
@@ -84,17 +104,19 @@ class HParamsTest(test.TestCase):
hparams.parse('b=relu')
with self.assertRaisesRegexp(ValueError, 'Must not pass a list'):
hparams.parse('aaa=[123]')
- self.assertEqual(12, hparams.aaa)
- self.assertEqual(2.0, hparams.b)
+ self.assertEqual(10, hparams.aaa)
+ self.assertEqual(1.5, hparams.b)
self.assertEqual('2.3"', hparams.c_c)
+ self.assertEqual('/a=b/c/d', hparams.d)
# Exports to proto.
hparam_def = hparams.to_proto()
# Imports from proto.
hparams2 = hparam.HParams(hparam_def=hparam_def)
# Verifies that all hparams are restored.
- self.assertEqual(12, hparams2.aaa)
- self.assertEqual(2.0, hparams2.b)
+ self.assertEqual(10, hparams2.aaa)
+ self.assertEqual(1.5, hparams2.b)
self.assertEqual('2.3"', hparams2.c_c)
+ self.assertEqual('/a=b/c/d', hparams2.d)
def testSetFromMap(self):
hparams = hparam.HParams(a=1, b=2.0, c='tanh')
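The new assertions above pin down that parse() splits each name=value pair on the first '=' only, so values may themselves contain '='. A minimal sketch, assuming the test's own import path for the hparam module:

    from tensorflow.contrib.training.python.training import hparam

    hparams = hparam.HParams(d='unset')
    hparams.parse('d=/a=b/c/d')  # split on the first '=' only
    assert hparams.d == '/a=b/c/d'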
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 2885a9f823..1d11410332 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -354,6 +354,7 @@ cc_library(
"platform/mutex.h",
"platform/net.h",
"platform/notification.h",
+ "platform/null_file_system.h",
"platform/prefetch.h",
"platform/profile_utils/clock_cycle_profiler.h",
"platform/profile_utils/cpu_utils.h",
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt
index 9e0de08267..4eb6eb4e4d 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt
@@ -34,7 +34,7 @@ This operation computes
Duplicate entries are handled correctly: if multiple `indices` reference
the same location, their contributions add.
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt
new file mode 100644
index 0000000000..47148f7b03
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "ResourceScatterDiv"
+ in_arg {
+ name: "resource"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of values that `ref` is divided by.
+END
+ }
+ summary: "Divides sparse updates into the variable referenced by `resource`."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] /= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] /= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions divide.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
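The relaxed requirement `updates.shape = []` means a scalar update broadcasts across every indexed row. A rough NumPy model of the div variant (ordering of duplicate indices and locking aside; not the kernel itself):

    import numpy as np

    ref = np.array([[2., 4.], [6., 8.]])
    indices = np.array([0, 1])
    updates = 2.  # updates.shape == [], broadcast to each indexed row
    for idx in np.atleast_1d(indices):
        ref[idx] /= updates
    print(ref)  # [[1. 2.] [3. 4.]]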
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt
new file mode 100644
index 0000000000..71f06d9a43
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "ResourceScatterMax"
+ in_arg {
+ name: "resource"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to reduce into `ref`.
+END
+ }
+ summary: "Reduces sparse updates into the variable referenced by `resource` using the `max` operation."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] = max(ref[indices, ...], updates[...])
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...])
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions are combined.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt
new file mode 100644
index 0000000000..08e40ee2a8
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "ResourceScatterMin"
+ in_arg {
+ name: "resource"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to reduce into `ref`.
+END
+ }
+ summary: "Reduces sparse updates into the variable referenced by `resource` using the `min` operation."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] = min(ref[indices, ...], updates[...])
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...])
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions are combined.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt
new file mode 100644
index 0000000000..5c63549d81
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "ResourceScatterMul"
+ in_arg {
+ name: "resource"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to multiply into `ref`.
+END
+ }
+ summary: "Multiplies sparse updates into the variable referenced by `resource`."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] *= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] *= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions multiply.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt
new file mode 100644
index 0000000000..e71e60cbee
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "ResourceScatterSub"
+ in_arg {
+ name: "resource"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to subtract from `ref`.
+END
+ }
+ summary: "Subtracts sparse updates from the variable referenced by `resource`."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] -= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] -= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions add.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt
index 4b5201f025..9da9d09ea6 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt
@@ -51,7 +51,7 @@ This makes it easier to chain operations that need to use the reset value.
Duplicate entries are handled correctly: if multiple `indices` reference
the same location, their contributions add.
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt>
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt
index 771cf0b591..8e99718c7e 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt
@@ -53,6 +53,6 @@ This makes it easier to chain operations that need to use the reset value.
Duplicate entries are handled correctly: if multiple `indices` reference
the same location, their contributions divide.
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt
new file mode 100644
index 0000000000..7b52dad4a1
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt
@@ -0,0 +1,60 @@
+op {
+ graph_op_name: "ScatterMax"
+ in_arg {
+ name: "ref"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to reduce into `ref`.
+END
+ }
+ out_arg {
+ name: "output_ref"
+ description: <<END
+= Same as `ref`. Returned as a convenience for operations that want
+to use the updated values after the update is done.
+END
+ }
+ attr {
+ name: "use_locking"
+ description: <<END
+If True, the update will be protected by a lock;
+otherwise the behavior is undefined, but may exhibit less contention.
+END
+ }
+ summary: "Reduces sparse updates into a variable reference using the `max` operation."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] = max(ref[indices, ...], updates[...])
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...])
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions combine.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt
new file mode 100644
index 0000000000..721ac0ff35
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt
@@ -0,0 +1,60 @@
+op {
+ graph_op_name: "ScatterMin"
+ in_arg {
+ name: "ref"
+ description: <<END
+Should be from a `Variable` node.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A tensor of updated values to reduce into `ref`.
+END
+ }
+ out_arg {
+ name: "output_ref"
+ description: <<END
+= Same as `ref`. Returned as a convenience for operations that want
+to use the updated values after the update is done.
+END
+ }
+ attr {
+ name: "use_locking"
+ description: <<END
+If True, the update will be protected by a lock;
+otherwise the behavior is undefined, but may exhibit less contention.
+END
+ }
+ summary: "Reduces sparse updates into a variable reference using the `min` operation."
+ description: <<END
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] = min(ref[indices, ...], updates[...])
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...])
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions combine.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt
index a51f571b00..b9e293ba9e 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt
@@ -53,6 +53,6 @@ This makes it easier to chain operations that need to use the reset value.
Duplicate entries are handled correctly: if multiple `indices` reference
the same location, their contributions multiply.
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt
index c0d3a4a133..d12b3e68c2 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt
@@ -51,7 +51,7 @@ This makes it easier to chain operations that need to use the reset value.
Duplicate entries are handled correctly: if multiple `indices` reference
the same location, their (negated) contributions add.
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterSub.png" alt>
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
index c44dbbd233..4804908afc 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
@@ -54,7 +54,7 @@ If values in `ref` is to be updated more than once, because there are
duplicate entries in `indices`, the order at which the updates happen
for each value is undefined.
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt
new file mode 100644
index 0000000000..56b5a46d10
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ResourceScatterDiv"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt
new file mode 100644
index 0000000000..8119bcc6c6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ResourceScatterMax"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt
new file mode 100644
index 0000000000..d874aef3fe
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ResourceScatterMin"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt
new file mode 100644
index 0000000000..365a37fa0d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ResourceScatterMul"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt
new file mode 100644
index 0000000000..72dc5bf889
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ResourceScatterSub"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc
index 6ac9319ad1..16b61315f2 100644
--- a/tensorflow/core/common_runtime/constant_folding_test.cc
+++ b/tensorflow/core/common_runtime/constant_folding_test.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/null_file_system.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index de10b10b7e..a619cac9a4 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -55,6 +55,49 @@ tf_cuda_library(
)
tf_cuda_library(
+ name = "tensor_handle",
+ srcs = [
+ "tensor_handle.cc",
+ ],
+ hdrs = [
+ "tensor_handle.h",
+ ],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":context",
+ ":eager_executor",
+ ":kernel_and_device",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ ],
+)
+
+tf_cuda_library(
+ name = "copy_to_device_node",
+ hdrs = [
+ "copy_to_device_node.h",
+ ],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":context",
+ ":eager_executor",
+ ":tensor_handle",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ ],
+)
+
+tf_cuda_library(
name = "kernel_and_device",
srcs = [
"kernel_and_device.cc",
diff --git a/tensorflow/core/common_runtime/eager/copy_to_device_node.h b/tensorflow/core/common_runtime/eager/copy_to_device_node.h
new file mode 100644
index 0000000000..8a887540b0
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/copy_to_device_node.h
@@ -0,0 +1,69 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_COPY_TO_DEVICE_NODE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_COPY_TO_DEVICE_NODE_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class CopyToDeviceNode : public EagerNode {
+ public:
+ CopyToDeviceNode(TensorHandle* src, Device* dstd, EagerContext* ctx)
+ : EagerNode(ctx->NextId()),
+ src_(src),
+ dstd_(dstd),
+ ctx_(ctx),
+ dst_(new TensorHandle(id, src_->dtype, ctx)) {
+ src_->Ref();
+ dst_->Ref();
+ }
+
+ ~CopyToDeviceNode() override {
+ src_->Unref();
+ dst_->Unref();
+ }
+
+ Status Run() override {
+ TensorHandle* temp = nullptr;
+ TF_RETURN_IF_ERROR(src_->CopyToDevice(ctx_, dstd_, &temp));
+ const Tensor* tensor = nullptr;
+ Device* device = nullptr;
+ Device* op_device = nullptr;
+ Status status = temp->TensorAndDevice(&tensor, &device, &op_device);
+ // `temp` is a ready handle. So the following call should return OK.
+ TF_DCHECK_OK(status) << status.error_message();
+ DCHECK(tensor);
+ dst_->SetTensorAndDevice(*tensor, device, op_device);
+ temp->Unref();
+ return Status::OK();
+ }
+
+ TensorHandle* dst() { return dst_; }
+
+ private:
+ TensorHandle* src_;
+ Device* dstd_;
+ EagerContext* ctx_;
+ TensorHandle* dst_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_COPY_TO_DEVICE_NODE_H_
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
new file mode 100644
index 0000000000..328cd5dd5c
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -0,0 +1,178 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <map>
+#include <memory>
+#include <queue>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+bool TensorHandle::IsReady() {
+ if (node_id == 0) return true;
+ mutex_lock l(ctx_mutex_);
+ return ctx_ == nullptr;
+}
+
+Status TensorHandle::WaitReady() {
+ if (node_id == 0) return Status::OK();
+ EagerExecutor* executor = nullptr;
+ {
+ mutex_lock l(ctx_mutex_);
+ if (ctx_ == nullptr) return Status::OK();
+ executor = ctx_->Executor();
+ }
+ return executor->WaitFor(node_id);
+}
+
+Status TensorHandle::Tensor(const tensorflow::Tensor** t) {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *t = &tensor_;
+ return Status::OK();
+}
+
+Status TensorHandle::Device(tensorflow::Device** d) {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *d = device_;
+ return Status::OK();
+}
+
+Status TensorHandle::OpDevice(tensorflow::Device** d) {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *d = op_device_;
+ return Status::OK();
+}
+
+Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor,
+ tensorflow::Device** device,
+ tensorflow::Device** op_device) {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *tensor = &tensor_;
+ *device = device_;
+ *op_device = op_device_;
+ return Status::OK();
+}
+
+void TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor,
+ tensorflow::Device* device,
+ tensorflow::Device* op_device) {
+ mutex_lock l(ctx_mutex_);
+ DCHECK(node_id > 0 && ctx_) << "SetTensorAndDevice should be only called "
+ << "on non-ready handles.";
+ ctx_ = nullptr;
+ tensor_ = tensor;
+ device_ = device;
+ op_device_ = op_device;
+}
+
+Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
+ TensorHandle** output) {
+ const tensorflow::Tensor* src = nullptr;
+ tensorflow::Device* srcd = nullptr;
+ // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept
+ // nullptr.
+ tensorflow::Device* src_opd = nullptr;
+ TF_RETURN_IF_ERROR(TensorAndDevice(&src, &srcd, &src_opd));
+ if (srcd == nullptr) srcd = ctx->HostCPU();
+ bool is_same_device = (srcd == dstd) || (srcd->name() == dstd->name());
+ const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
+ const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
+ // both_on_cpu can be true and yet is_same_device is false, if one of src/dst
+ // has device type XLA_CPU, and the other CPU.
+ const bool both_on_cpu = src_cpu && dst_cpu;
+ if (is_same_device || both_on_cpu) {
+ dstd = dst_cpu ? nullptr : dstd;
+ *output = new tensorflow::TensorHandle(*src, dstd, dstd);
+ return tensorflow::Status::OK();
+ }
+ if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
+ !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
+ return tensorflow::errors::InvalidArgument(
+ "Can't copy Tensor with type ",
+ tensorflow::DataTypeString(src->dtype()), " to device ", dstd->name(),
+ ".");
+ }
+ tensorflow::AllocatorAttributes attr;
+ if (src->dtype() == tensorflow::DT_VARIANT) {
+ attr.set_on_host(true);
+ }
+ tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
+ if (src->shape().num_elements() == 0) {
+ dstd = dst_cpu ? nullptr : dstd;
+ *output = new tensorflow::TensorHandle(dst, dstd, dstd);
+ return tensorflow::Status::OK();
+ }
+ tensorflow::DeviceContext* src_device_context = nullptr;
+ if (!src_cpu) {
+ src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
+ }
+ tensorflow::DeviceContext* dst_device_context = nullptr;
+ if (!dst_cpu) {
+ dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
+ }
+ // TODO(ashankar): The Sync() call below may be more aggressive than
+ // necessary. It is based on knowledge of implementation details - that
+ // GPU devices are implemented using 3 streams - one for host->device copies,
+ // one for device->host copies and one for sending operations to the GPU.
+ // With that setup, Sync()ing across all 3 streams should be sufficient
+ // but more than necessary (since it waits for operations that might have
+ // nothing to do with this tensor to complete).
+ TF_RETURN_IF_ERROR(srcd->Sync());
+ tensorflow::Notification n;
+ tensorflow::Status status;
+ tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
+ srcd, dstd, tensorflow::AllocatorAttributes(),
+ tensorflow::AllocatorAttributes(), src, &dst,
+ [&status, &n](const tensorflow::Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ if (status.ok()) {
+ dstd = dst_cpu ? nullptr : dstd;
+ *output = new tensorflow::TensorHandle(dst, dstd, dstd);
+ }
+ return status;
+}
+
+} // namespace tensorflow
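CopyToDevice picks among three paths: aliasing when source and destination are effectively the same device (or both CPU), rejecting dtypes that cannot be memcpy'd to a device, and otherwise a DMA copy preceded by a conservative Sync(). A toy, runnable Python model of that branch structure (plain objects, simulated copy; not TensorFlow API):

    class Dev(object):
        def __init__(self, name, is_cpu):
            self.name, self.is_cpu = name, is_cpu

    MEMCPY_OK = {'float32', 'int32'}  # stand-in for DataTypeCanUseMemcpy

    def copy_to_device(tensor, dtype, src, dst):
        same = src is dst or src.name == dst.name
        if same or (src.is_cpu and dst.is_cpu):
            return tensor  # alias; no copy needed
        if not dst.is_cpu and dtype != 'variant' and dtype not in MEMCPY_OK:
            raise ValueError("Can't copy type %s to %s" % (dtype, dst.name))
        # Here the real code drains the source device's streams (srcd->Sync())
        # and then blocks on the DMA completion notification.
        return list(tensor)  # simulated device copy

    cpu, gpu = Dev('CPU:0', True), Dev('GPU:0', False)
    assert copy_to_device([1, 2], 'float32', cpu, gpu) == [1, 2]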
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
new file mode 100644
index 0000000000..eb69a13c06
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
+
+#include <algorithm>
+#include <cstddef>
+#include <map>
+#include <memory>
+#include <queue>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+// Associates a Tensor and a Device, used in the eager runtime. Internal version
+// of the TFE_TensorHandle struct and the python EagerTensor class
+// (unrelated to python TensorHandle).
+class TensorHandle : public core::RefCounted {
+ public:
+ TensorHandle(const Tensor& t, Device* d, Device* op_device)
+ : dtype(t.dtype()),
+ node_id(0),
+ tensor_(t),
+ device_(d),
+ op_device_(op_device),
+ ctx_(nullptr) {}
+
+ TensorHandle(uint64 node_id, DataType dtype, EagerContext* ctx)
+ : dtype(dtype),
+ node_id(node_id),
+ tensor_(dtype),
+ device_(nullptr),
+ op_device_(nullptr),
+ ctx_(ctx) {
+ DCHECK_GT(node_id, 0);
+ }
+
+ ~TensorHandle() override {}
+
+ Status Tensor(const tensorflow::Tensor** t);
+
+ Status Device(tensorflow::Device** d);
+
+ Status OpDevice(tensorflow::Device** d);
+
+ Status TensorAndDevice(const tensorflow::Tensor** tensor,
+ tensorflow::Device** device,
+ tensorflow::Device** op_device);
+
+ // Note that this can be called at most once, and only on non-ready handles,
+ // and makes them ready.
+ void SetTensorAndDevice(const tensorflow::Tensor& tensor,
+ tensorflow::Device* device,
+ tensorflow::Device* op_device);
+
+ Status CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
+ TensorHandle** output);
+
+ // dtype for the handle. It must be the same as t.dtype() once the handle is
+ // ready.
+ const DataType dtype;
+
+ private:
+ // If the contents of the Tensor pointed to by this handle are yet to be
+ // computed by an EagerNode, this function will block till that computation
+ // is done and the handle is "ready".
+ Status WaitReady();
+
+ bool IsReady();
+
+ // Id for the EagerNode that will compute the value pointed to by this handle.
+ // If the value is 0, the handle is already ready, but not vice-versa.
+ const uint64 node_id;
+
+ tensorflow::Tensor tensor_;
+
+ // TODO(ashankar): device_ == nullptr iff local CPU
+ // This was expedient, but perhaps worth revisiting ('device_' should always
+ // be a valid pointer?)
+ // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
+ // provided with the appropriate TFE_Context.
+ //
+ // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
+ // TFE_TensorHandle does not outlive the TFE_Context from which it came?
+ tensorflow::Device* device_;
+
+ // Device in which the op producing this tensor was executed. Equals to
+ // device_ for constant tensors.
+ tensorflow::Device* op_device_;
+
+ mutex ctx_mutex_;
+
+ // `ctx` is only guaranteed to be set if the handle is not "ready". This is
+ // typically true when the handle was produced during async execution.
+ // `ctx` object is not owned and should outlive this handle.
+ EagerContext* ctx_ GUARDED_BY(ctx_mutex_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
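The readiness protocol is compact: node_id == 0 means the handle was born ready; otherwise it becomes ready exactly when SetTensorAndDevice clears ctx_. A toy Python stand-in for the state machine (not the C++ class):

    class Handle(object):
        def __init__(self, node_id, ctx):
            self.node_id = node_id            # 0 => born ready
            self.ctx = ctx if node_id else None

        def is_ready(self):
            return self.node_id == 0 or self.ctx is None

        def set_tensor(self, tensor):
            assert self.node_id > 0 and self.ctx is not None
            self.tensor = tensor
            self.ctx = None                   # flips the handle to ready

    h = Handle(node_id=7, ctx=object())
    assert not h.is_ready()
    h.set_tensor('value')
    assert h.is_ready()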
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 52b9077d8c..8473b228d3 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -185,6 +185,7 @@ class DeviceBase {
virtual Allocator* GetScopedAllocator(AllocatorAttributes attr,
int64 step_id) {
LOG(FATAL) << "Device does not implement GetScopedAllocator()";
+ return nullptr;
}
virtual ScopedAllocatorMgr* GetScopedAllocatorMgr() const { return nullptr; }
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index bd9608b369..601984fcfd 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
# Platform specific build config
@@ -286,6 +287,7 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
@@ -422,7 +424,7 @@ cc_library(
]),
)
-tf_cuda_cc_test(
+tf_cuda_only_cc_test(
name = "memory_optimizer_test",
srcs = ["memory_optimizer_test.cc"],
tags = ["no_cuda_on_cpu_tap"], # Do not re-enable again without actually testing.
@@ -498,6 +500,7 @@ cc_library(
":constant_folding",
":custom_graph_optimizer",
":custom_graph_optimizer_registry",
+ ":debug_stripper",
":dependency_optimizer",
":function_optimizer",
":graph_optimizer",
@@ -616,3 +619,34 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+cc_library(
+ name = "debug_stripper",
+ srcs = ["debug_stripper.cc"],
+ hdrs = [
+ "debug_stripper.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:graph_optimizer",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "debug_stripper_test",
+ size = "small",
+ srcs = ["debug_stripper_test.cc"],
+ deps = [
+ ":debug_stripper",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 3876486d80..792f675043 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
@@ -157,6 +158,8 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
ArithmeticOptimizer optimizer;
GraphDef output;
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
// Run the optimizer twice to make sure the rewrite is idempotent.
@@ -172,6 +175,10 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
EXPECT_EQ(2, new_div.input_size());
EXPECT_EQ("c1", new_div.input(0));
EXPECT_EQ("c1", new_div.input(1));
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 914a9257ee..6340565bcd 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -1922,6 +1922,8 @@ TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
"concat5", "concat6", "concat7", "concat8", "concat9"};
+ auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
+ EXPECT_EQ(1, tensors_expected.size());
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
@@ -1971,9 +1973,7 @@ TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
}
}
- auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
auto tensors = EvaluateNodes(output, {"concat0"});
- EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc
new file mode 100644
index 0000000000..461f1aa2fb
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+
+namespace tensorflow {
+namespace grappler {
+
+Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ // TODO(haoliang): Strip assertions and other debug nodes here.
+ *output = item.graph;
+ return Status::OK();
+}
+
+void DebugStripper::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) {
+ // Takes no feedback.
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
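`Optimize()` above is still a pass-through: it copies `item.graph` verbatim and leaves the actual stripping for the TODO. A hedged sketch of what the filtering step might eventually look like, reusing only the `GraphDef`/`NodeDef` accessors already visible in this patch (hypothetical `DebugStripperSketch` helper; a complete implementation would also have to rewire the data and control edges of pass-through ops such as `CheckNumerics` and `Print`, which this sketch ignores):

    // Sketch only, assuming the includes and namespaces of
    // debug_stripper.cc above.
    Status DebugStripperSketch(const GrapplerItem& item, GraphDef* output) {
      output->Clear();
      for (const NodeDef& node : item.graph.node()) {
        const std::string& op = node.op();
        if (op == "Assert" || op == "CheckNumerics" || op == "Print") {
          continue;  // debug-only node: drop it from the output graph
        }
        *output->add_node() = node;  // everything else is copied unchanged
      }
      return Status::OK();
    }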
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.h b/tensorflow/core/grappler/optimizers/debug_stripper.h
new file mode 100644
index 0000000000..1fe25aa1c3
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/debug_stripper.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEBUG_STRIPPER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEBUG_STRIPPER_H_
+
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// DebugStripper strips off debug-related nodes (e.g.
+// Assert, CheckNumerics, Print) from the graph.
+class DebugStripper : public GraphOptimizer {
+ public:
+ DebugStripper() {}
+ ~DebugStripper() override {}
+
+ string name() const override { return "debug_stripper"; }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEBUG_STRIPPER_H_
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
new file mode 100644
index 0000000000..d2cabc0798
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class DebugStripperTest : public GrapplerTest {};
+
+// TODO(haoliang): Add tests for different removal operations.
+TEST_F(DebugStripperTest, OutputEqualToInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto c = ops::Const(s.WithOpName("c"), 0, {});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DebugStripper optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 6eb2bbc547..47ec16226b 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
@@ -84,6 +85,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
graph_optimizer.reset(
new DependencyOptimizer(cfg_.dependency_optimization()));
}
+ if (optimizer == "debug_stripper") {
+ graph_optimizer.reset(new DebugStripper());
+ }
return graph_optimizer;
}
@@ -134,10 +138,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
new AutoParallel(cfg_.auto_parallel().num_replicas())));
}
+ if (cfg_.debug_stripper() == RewriterConfig::ON) {
+ optimizers.push_back(
+ std::unique_ptr<GraphOptimizer>(new DebugStripper()));
+ }
} else {
const std::set<string> available_optimizers = {
- "pruning", "function", "constfold", "layout", "memory",
- "autoparallel", "arithmetic", "loop", "dependency"};
+ "pruning", "function", "constfold", "layout",
+ "memory", "autoparallel", "arithmetic", "loop",
+ "dependency", "debug_stripper"};
std::vector<string> custom_optimizer_names;
for (const auto& optimizer_name : cfg_.optimizers()) {
if (available_optimizers.find(optimizer_name) !=
@@ -238,6 +247,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.dependency_optimization() != RewriterConfig::OFF ||
cfg.auto_parallel().enable() ||
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
+ cfg.debug_stripper() == RewriterConfig::ON ||
!cfg.optimizers().empty();
}
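With this wiring, the new pass can be requested in two equivalent ways: through the dedicated `RewriterConfig` field checked in `MetaOptimizer::Optimize()`, or by naming it in the ordered optimizer list, which routes through `NewOptimizer("debug_stripper")`. A minimal sketch (field and enum names as used in this patch; session plumbing omitted):

    #include "tensorflow/core/protobuf/rewriter_config.pb.h"

    void EnableDebugStripper(tensorflow::RewriterConfig* cfg) {
      // 1) The dedicated toggle.
      cfg->set_debug_stripper(tensorflow::RewriterConfig::ON);
      // 2) Or list it explicitly among the optimizers to run, in order.
      cfg->add_optimizers("debug_stripper");
    }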
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index 1c15ea65b8..ee126f4955 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -36,6 +36,7 @@ GrapplerTest::GrapplerTest() {
cfg->set_loop_optimization(RewriterConfig::OFF);
cfg->set_function_optimization(RewriterConfig::OFF);
cfg->set_layout_optimizer(RewriterConfig::OFF);
+ cfg->set_debug_stripper(RewriterConfig::OFF);
}
std::vector<Tensor> GrapplerTest::EvaluateNodes(
diff --git a/tensorflow/core/kernels/immutable_constant_op_test.cc b/tensorflow/core/kernels/immutable_constant_op_test.cc
index b3814331ee..b2dc16d5d7 100644
--- a/tensorflow/core/kernels/immutable_constant_op_test.cc
+++ b/tensorflow/core/kernels/immutable_constant_op_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/null_file_system.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index aecad0185f..e134e476f6 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -619,22 +619,35 @@ class ResourceScatterUpdateOp : public OpKernel {
if (N > 0) {
auto indices_flat = indices.flat<Index>();
auto params_flat = params->flat_outer_dims<T>();
- int64 num_updates = updates.NumElements();
- OP_REQUIRES(c, num_updates % N == 0,
- errors::InvalidArgument(
- "shape of indices (", indices.shape().DebugString(),
- ") is not compatible with the shape of updates (",
- updates.shape().DebugString(), ")"));
- auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});
-
- functor::ScatterFunctor<Device, T, Index, op> functor;
- const Index bad_i = functor(c, c->template eigen_device<Device>(),
- params_flat, updates_flat, indices_flat);
- OP_REQUIRES(c, bad_i < 0,
- errors::InvalidArgument(
- "indices", SliceDebugString(indices.shape(), bad_i),
- " = ", indices_flat(bad_i), " is not in [0, ",
- params->dim_size(0), ")"));
+ if (TensorShapeUtils::IsScalar(updates.shape())) {
+ const auto update = updates.scalar<T>();
+
+ functor::ScatterScalarFunctor<Device, T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<Device>(),
+ params_flat, update, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params->dim_size(0), ")"));
+ } else {
+ int64 num_updates = updates.NumElements();
+ OP_REQUIRES(c, num_updates % N == 0,
+ errors::InvalidArgument(
+ "shape of indices (", indices.shape().DebugString(),
+ ") is not compatible with the shape of updates (",
+ updates.shape().DebugString(), ")"));
+ auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});
+
+ functor::ScatterFunctor<Device, T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<Device>(),
+ params_flat, updates_flat, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params->dim_size(0), ")"));
+ }
}
}
};
@@ -652,35 +665,51 @@ class ResourceScatterUpdateOp : public OpKernel {
REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
-// TODO(apassos) add the other types here.
-#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
+#define REGISTER_SCATTER_ARITHMETIC(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \
scatter_op::UpdateOp::ADD); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub", \
+ scatter_op::UpdateOp::SUB); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul", \
+ scatter_op::UpdateOp::MUL); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv", \
+ scatter_op::UpdateOp::DIV); \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
scatter_op::UpdateOp::ASSIGN);
+#define REGISTER_SCATTER_MINMAX(type, dev) \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
+ scatter_op::UpdateOp::MIN); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
+ scatter_op::UpdateOp::MAX);
// Registers CPU kernels.
-#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, CPU);
+#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, CPU);
+#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
// Registers GPU kernels.
#if GOOGLE_CUDA
-#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, GPU);
+#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, GPU);
+#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
#endif // GOOGLE_CUDA
-#undef REGISTER_SCATTER_ARITHEMTIC
-#undef REGISTER_SCATTER_ARITHEMTIC_CPU
+#undef REGISTER_SCATTER_ARITHMETIC
+#undef REGISTER_SCATTER_ARITHMETIC_CPU
+#undef REGISTER_SCATTER_MINMAX
+#undef REGISTER_SCATTER_MINMAX_CPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX
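The new branch in `ResourceScatterUpdateOp` lets a rank-0 `updates` tensor stand in for a full updates matrix: every row of `params` selected by `indices` is combined with the same scalar under the op's update rule. A plain C++ sketch of the semantics for the ADD case (hypothetical helper over a row-major buffer):

    #include <iostream>
    #include <vector>

    // params is a [num_rows x row_len] matrix stored row-major; each indexed
    // row receives the broadcast scalar, mirroring ScatterScalarFunctor.
    void ScatterAddScalarSketch(std::vector<float>& params, int row_len,
                                const std::vector<int>& indices, float update) {
      for (int index : indices) {
        for (int j = 0; j < row_len; ++j) {
          params[index * row_len + j] += update;
        }
      }
    }

    int main() {
      std::vector<float> params(3 * 2, 1.0f);  // three rows of [1, 1]
      ScatterAddScalarSketch(params, /*row_len=*/2, /*indices=*/{0, 2}, 5.0f);
      for (float v : params) std::cout << v << ' ';  // 6 6 1 1 6 6
      std::cout << '\n';
    }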
diff --git a/tensorflow/core/kernels/scatter_functor.cc b/tensorflow/core/kernels/scatter_functor.cc
index 7eba82899f..cf5408123f 100644
--- a/tensorflow/core/kernels/scatter_functor.cc
+++ b/tensorflow/core/kernels/scatter_functor.cc
@@ -26,21 +26,30 @@ typedef Eigen::GpuDevice GPUDevice;
namespace functor {
// Forward declarations of the functor specializations for GPU.
-#define DECLARE_GPU_SPECS_OP(T, Index, op) \
- template <> \
- Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
- OpKernelContext* c, const GPUDevice& d, \
- typename TTypes<T>::Matrix params, \
- typename TTypes<T>::ConstMatrix updates, \
- typename TTypes<Index>::ConstFlat indices); \
- extern template struct ScatterFunctor<GPUDevice, T, Index, op>;
+#define DECLARE_GPU_SPECS_OP(T, Index, op) \
+ template <> \
+ Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
+ OpKernelContext* c, const GPUDevice& d, \
+ typename TTypes<T>::Matrix params, \
+ typename TTypes<T>::ConstMatrix updates, \
+ typename TTypes<Index>::ConstFlat indices); \
+ extern template struct ScatterFunctor<GPUDevice, T, Index, op>; \
+ template <> \
+ Index ScatterScalarFunctor<GPUDevice, T, Index, op>::operator()( \
+ OpKernelContext* c, const GPUDevice& d, \
+ typename TTypes<T>::Matrix params, \
+ const typename TTypes<T>::ConstScalar update, \
+ typename TTypes<Index>::ConstFlat indices); \
+ extern template struct ScatterScalarFunctor<GPUDevice, T, Index, op>;
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
- DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX);
#define DECLARE_GPU_SPECS(T) \
DECLARE_GPU_SPECS_INDEX(T, int32); \
diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h
index 079f15e101..52666645bf 100644
--- a/tensorflow/core/kernels/scatter_functor.h
+++ b/tensorflow/core/kernels/scatter_functor.h
@@ -18,6 +18,8 @@ limitations under the License.
#include <type_traits>
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/types.h"
@@ -33,7 +35,7 @@ typedef Eigen::SyclDevice SYCLDevice;
namespace scatter_op {
-enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV };
+enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX };
namespace internal {
@@ -45,6 +47,10 @@ struct Assign<scatter_op::UpdateOp::ASSIGN> {
static void Run(Params p, Update u) {
p = u;
}
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p.setConstant(u);
+ }
};
template <>
struct Assign<scatter_op::UpdateOp::ADD> {
@@ -52,6 +58,10 @@ struct Assign<scatter_op::UpdateOp::ADD> {
static void Run(Params p, Update u) {
p += u;
}
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p = p + u;
+ }
};
template <>
struct Assign<scatter_op::UpdateOp::SUB> {
@@ -59,6 +69,10 @@ struct Assign<scatter_op::UpdateOp::SUB> {
static void Run(Params p, Update u) {
p -= u;
}
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p = p + static_cast<Update>(-u);
+ }
};
template <>
struct Assign<scatter_op::UpdateOp::MUL> {
@@ -66,6 +80,10 @@ struct Assign<scatter_op::UpdateOp::MUL> {
static void Run(Params p, Update u) {
p *= u;
}
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p = p * u;
+ }
};
template <>
struct Assign<scatter_op::UpdateOp::DIV> {
@@ -73,6 +91,34 @@ struct Assign<scatter_op::UpdateOp::DIV> {
static void Run(Params p, Update u) {
p /= u;
}
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p = p / u;
+ }
+};
+template <>
+struct Assign<scatter_op::UpdateOp::MIN> {
+ // This method requires that Params and Update are tensor types.
+ template <typename Params, typename Update>
+ static void Run(Params p, Update u) {
+ p = p.cwiseMin(u);
+ }
+ // Same thing, but for Update being a scalar type.
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p = p.cwiseMin(u);
+ }
+};
+template <>
+struct Assign<scatter_op::UpdateOp::MAX> {
+ template <typename Params, typename Update>
+ static void Run(Params p, Update u) {
+ p = p.cwiseMax(u);
+ }
+ template <typename Params, typename Update>
+ static void RunScalar(Params p, Update u) {
+ p = p.cwiseMax(u);
+ }
};
#ifdef TENSORFLOW_USE_SYCL
@@ -117,6 +163,22 @@ struct AssignSYCL<scatter_op::UpdateOp::DIV> {
p.device(d) = p / u;
}
};
+
+template <>
+struct AssignSYCL<scatter_op::UpdateOp::MIN> {
+ template <typename Device, typename Params, typename Update>
+ static void Run(Device d, Params p, Update u) {
+ p.device(d) = p.cwiseMin(u);
+ }
+};
+
+template <>
+struct AssignSYCL<scatter_op::UpdateOp::MAX> {
+ template <typename Device, typename Params, typename Update>
+ static void Run(Device d, Params p, Update u) {
+ p.device(d) = p.cwiseMax(u);
+ }
+};
#endif // TENSORFLOW_USE_SYCL
} // namespace internal
@@ -241,6 +303,112 @@ struct ScatterFunctorSYCL {
};
#endif // TENSORFLOW_USE_SYCL
+template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctor {
+ Index operator()(OpKernelContext* c, const Device& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices);
+};
+
+template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctorBase {
+ Index operator()(OpKernelContext* c, const Device& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices) {
+ // indices and params sizes were validated in DoCompute().
+ const Index N = static_cast<Index>(indices.size());
+ const Index limit = static_cast<Index>(params.dimension(0));
+ for (Index i = 0; i < N; i++) {
+ // Grab the index and check its validity. An earlier version of the
+ // code checked it and then grabbed it from memory a second time, which
+ // was a security risk since it could have changed in between.
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Broadcast update to params[index]
+ scatter_op::internal::Assign<op>::RunScalar(
+ params.template chip<0>(index), update());
+ }
+ return -1;
+ }
+};
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctorBase<SYCLDevice, T, Index, op> {
+ Index operator()(OpKernelContext* c, const SYCLDevice& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices) {
+ // indices and params sizes were validated in DoCompute().
+ const Index N = static_cast<Index>(indices.size());
+ const Index limit = static_cast<Index>(params.dimension(0));
+ for (Index i = 0; i < N; i++) {
+ // Grab the index and check its validity. An earlier version of the
+ // code checked it and then grabbed it from memory a second time, which
+ // was a security risk since it could have changed in between.
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Broadcast update to params[index]
+ scatter_op::internal::AssignSYCL<op>::RunScalar(
+ d, params.template chip<0>(index), update);
+ }
+ return -1;
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
+template <typename T, typename Index>
+struct ScatterScalarFunctorBase<CPUDevice, T, Index,
+ scatter_op::UpdateOp::ASSIGN> {
+ Index operator()(OpKernelContext* c, const CPUDevice& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices) {
+ // indices and params sizes were validated in DoCompute().
+ const Index N = static_cast<Index>(indices.size());
+ const Index limit = static_cast<Index>(params.dimension(0));
+ for (Index i = 0; i < N; i++) {
+ // Grab the index and check its validity. An earlier version of the
+ // code checked it and then grabbed it from memory a second time, which
+ // was a security risk since it could have changed in between.
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Broadcast update to params[index]
+ scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar(
+ params.template chip<0>(index), update());
+ }
+ return -1;
+ }
+};
+
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctor<CPUDevice, T, Index, op>
+ : ScatterScalarFunctorBase<CPUDevice, T, Index, op> {};
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctorSYCL {
+ Index operator()(OpKernelContext* c, const SYCLDevice& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::Flat indices) {
+ // indices and params sizes were validated in DoCompute().
+ const Index N = static_cast<Index>(indices.size());
+ const Index limit = static_cast<Index>(params.dimension(0));
+ for (Index i = 0; i < N; i++) {
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Broadcast update to params[index]
+ scatter_op::internal::AssignSYCL<op>::Run(
+ d, params.template chip<0>(index), update());
+ }
+ return -1;
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow
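`RunScalar` mirrors `Run` for a scalar right-hand side: the functor receives one row of `params` (a chip along dimension 0) and broadcasts the scalar into it. A standalone Eigen sketch of the ASSIGN and MIN cases (include path and shapes are illustrative; the patch itself relies on the same `setConstant` and `cwiseMin` expressions):

    #include <iostream>
    #include <unsupported/Eigen/CXX11/Tensor>

    int main() {
      Eigen::Tensor<float, 2> params(3, 4);
      params.setConstant(1.0f);

      // Assign<ASSIGN>::RunScalar: p.setConstant(u) on the selected row.
      params.chip<0>(1).setConstant(7.0f);

      // Assign<MIN>::RunScalar: p = p.cwiseMin(u) on the selected row.
      params.chip<0>(2) = params.chip<0>(2).cwiseMin(0.5f);

      std::cout << params << "\n";  // row 1 all 7.0, row 2 all 0.5
      return 0;
    }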
diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
index 52972997cc..59911bf0d2 100644
--- a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
@@ -23,15 +23,18 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
-#define DEFINE_GPU_SPECS_OP(T, Index, op) \
- template struct functor::ScatterFunctor<GPUDevice, T, Index, op>;
+#define DEFINE_GPU_SPECS_OP(T, Index, op) \
+ template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; \
+ template struct functor::ScatterScalarFunctor<GPUDevice, T, Index, op>;
#define DEFINE_GPU_SPECS_INDEX(T, Index) \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
- DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX);
#define DEFINE_GPU_SPECS(T) \
DEFINE_GPU_SPECS_INDEX(T, int32); \
diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/tensorflow/core/kernels/scatter_functor_gpu.cu.h
index be18658543..70809e4dcf 100644
--- a/tensorflow/core/kernels/scatter_functor_gpu.cu.h
+++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.h
@@ -29,12 +29,53 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
+namespace scatter_op_gpu {
+
+template <typename T, scatter_op::UpdateOp op>
+struct ScatterOpKernelBody;
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ASSIGN> {
+ __device__ void operator()(T* dest, T src) const { *dest = src; }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ADD> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicAdd(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::SUB> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicSub(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MUL> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicMul(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::DIV> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicDiv(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MIN> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicMin(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MAX> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicMax(dest, src); }
+};
+
template <typename T, typename Index, scatter_op::UpdateOp op>
__global__ void ScatterOpCustomKernel(T* params, const T* updates,
const Index* indices,
Index first_dim_size, Index updates_size,
Index indices_size) {
Index update_block = updates_size / indices_size;
+ ScatterOpKernelBody<T, op> body;
CUDA_1D_KERNEL_LOOP(i, updates_size) {
int indices_i = i / update_block;
int updates_i = i;
@@ -44,31 +85,33 @@ __global__ void ScatterOpCustomKernel(T* params, const T* updates,
continue;
}
int params_i = param_first_index * update_block + (i % update_block);
- switch (op) {
- case scatter_op::UpdateOp::ASSIGN: {
- params[params_i] = ldg(updates + updates_i);
- break;
- }
- case scatter_op::UpdateOp::ADD: {
- CudaAtomicAdd(params + params_i, ldg(updates + updates_i));
- break;
- }
- case scatter_op::UpdateOp::SUB: {
- CudaAtomicSub(params + params_i, ldg(updates + updates_i));
- break;
- }
- case scatter_op::UpdateOp::MUL: {
- CudaAtomicMul(params + params_i, ldg(updates + updates_i));
- break;
- }
- case scatter_op::UpdateOp::DIV: {
- CudaAtomicDiv(params + params_i, ldg(updates + updates_i));
- break;
- }
+ body(&params[params_i], ldg(updates + updates_i));
+ }
+}
+
+template <typename T, typename Index, scatter_op::UpdateOp op>
+__global__ void ScatterScalarOpCustomKernel(T* params, const T* update,
+ const Index* indices,
+ Index first_dim_size,
+ Index indices_size,
+ Index synthesized_updates_size) {
+ Index update_block = synthesized_updates_size / indices_size;
+ ScatterOpKernelBody<T, op> body;
+ CUDA_1D_KERNEL_LOOP(i, synthesized_updates_size) {
+ int indices_i = i / update_block;
+ int param_first_index = indices[indices_i];
+ const T update_val = *update;
+ if (!(param_first_index >= 0 && param_first_index < first_dim_size)) {
+ // Ignore indices that are out of range.
+ continue;
}
+ int params_i = param_first_index * update_block + (i % update_block);
+ body(&params[params_i], update_val);
}
}
+} // namespace scatter_op_gpu
+
namespace functor {
// Specialization for a GPU device.
template <typename T, typename Index, scatter_op::UpdateOp op>
@@ -84,7 +127,7 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
const Index indices_size = indices.size();
const Index updates_size = updates.size();
CudaLaunchConfig config = GetCudaLaunchConfig(updates_size, d);
- ScatterOpCustomKernel<T, Index, op>
+ scatter_op_gpu::ScatterOpCustomKernel<T, Index, op>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
params.data(), updates.data(), indices.data(), first_dim_size,
updates_size, indices_size);
@@ -92,6 +135,27 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
}
};
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctor<GPUDevice, T, Index, op> {
+ Index operator()(OpKernelContext* c, const GPUDevice& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices) {
+ // TODO(b/31801742): Implement an indices range check. The hard part is
+ // returning a value after the check, as we do not want to do a
+ // device-to-host memcpy while the stream is running.
+ const Index first_dim_size = params.dimension(0);
+ const Index indices_size = indices.size();
+ const Index synthesized_updates_size = indices_size * params.dimension(1);
+ CudaLaunchConfig config = GetCudaLaunchConfig(synthesized_updates_size, d);
+ scatter_op_gpu::ScatterScalarOpCustomKernel<T, Index, op>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ params.data(), update.data(), indices.data(), first_dim_size,
+ indices_size, synthesized_updates_size);
+ return -1;
+ }
+};
+
} // namespace functor
} // namespace tensorflow
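The GPU refactor replaces the per-element `switch (op)` with `ScatterOpKernelBody<T, op>`: the update operation is now chosen at compile time when the kernel template is instantiated, so no branch survives in the inner loop, and the new scalar kernel reuses the same bodies. The same dispatch pattern in host-side C++ (illustrative names):

    #include <iostream>

    enum class UpdateOp { ASSIGN, ADD };

    // Primary template left undefined; each op gets a specialization,
    // exactly as ScatterOpKernelBody does above.
    template <typename T, UpdateOp op>
    struct KernelBody;

    template <typename T>
    struct KernelBody<T, UpdateOp::ASSIGN> {
      void operator()(T* dest, T src) const { *dest = src; }
    };

    template <typename T>
    struct KernelBody<T, UpdateOp::ADD> {
      void operator()(T* dest, T src) const { *dest += src; }
    };

    // op is a template parameter, so this loop compiles with no runtime switch.
    template <typename T, UpdateOp op>
    void Apply(T* data, int n, T src) {
      KernelBody<T, op> body;
      for (int i = 0; i < n; ++i) body(&data[i], src);
    }

    int main() {
      float v[3] = {1, 2, 3};
      Apply<float, UpdateOp::ADD>(v, 3, 10.0f);
      std::cout << v[0] << ' ' << v[1] << ' ' << v[2] << '\n';  // 11 12 13
    }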
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 282165349f..0fbde764d5 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -38,6 +38,7 @@ typedef Eigen::SyclDevice SYCLDevice;
// Check whether updates.shape = indices.shape + params.shape[1:]
static bool ValidShapes(const Tensor& params, const Tensor& updates,
const Tensor& indices) {
+ if (updates.dims() == 0) return true;
if (updates.dims() != indices.dims() + params.dims() - 1) return false;
for (int d = 0; d < indices.dims(); d++) {
if (updates.dim_size(d) != indices.dim_size(d)) {
@@ -61,11 +62,11 @@ static void DoValidationChecking(OpKernelContext* c, const Tensor& params,
params.shape().DebugString()));
OP_REQUIRES(
c, ValidShapes(params, updates, indices),
- errors::InvalidArgument(
- "Must have updates.shape = indices.shape + params.shape[1:], got ",
- "updates.shape ", updates.shape().DebugString(), ", indices.shape ",
- indices.shape().DebugString(), ", params.shape ",
- params.shape().DebugString()));
+ errors::InvalidArgument("Must have updates.shape = indices.shape + "
+ "params.shape[1:] or updates.shape = [], got ",
+ "updates.shape ", updates.shape().DebugString(),
+ ", indices.shape ", indices.shape().DebugString(),
+ ", params.shape ", params.shape().DebugString()));
}
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
@@ -122,16 +123,31 @@ class ScatterUpdateOp : public OpKernel {
if (N > 0) {
auto indices_flat = indices.flat<Index>();
auto params_flat = params.flat_outer_dims<T>();
- auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
-
- functor::ScatterFunctor<Device, T, Index, op> functor;
- const Index bad_i = functor(c, c->template eigen_device<Device>(),
- params_flat, updates_flat, indices_flat);
- OP_REQUIRES(
- c, bad_i < 0,
- errors::InvalidArgument(
- "indices", SliceDebugString(indices.shape(), bad_i), " = ",
- indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
+
+ if (TensorShapeUtils::IsScalar(updates.shape()) ||
+ IsLegacyScalar(updates.shape())) {
+ const auto update = updates.scalar<T>();
+ functor::ScatterScalarFunctor<Device, T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<Device>(),
+ params_flat, update, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params.dim_size(0), ")"));
+ } else {
+ auto updates_flat =
+ updates.shaped<T, 2>({N, updates.NumElements() / N});
+
+ functor::ScatterFunctor<Device, T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<Device>(),
+ params_flat, updates_flat, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params.dim_size(0), ")"));
+ }
}
}
};
@@ -195,16 +211,31 @@ class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel {
auto indices_flat = indices_host.flat<Index>();
auto params_flat = params.flat_outer_dims<T>();
- auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
-
- functor::ScatterFunctorSYCL<T, Index, op> functor;
- const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
- params_flat, updates_flat, indices_flat);
- OP_REQUIRES(
- c, bad_i < 0,
- errors::InvalidArgument(
- "indices", SliceDebugString(indices.shape(), bad_i), " = ",
- indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
+
+ if (TensorShapeUtils::IsScalar(updates.shape())) {
+ const auto update = updates.scalar<T>();
+
+ functor::ScatterScalarFunctorSYCL<T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
+ params_flat, update, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params.dim_size(0), ")"));
+ } else {
+ auto updates_flat =
+ updates.shaped<T, 2>({N, updates.NumElements() / N});
+
+ functor::ScatterFunctorSYCL<T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
+ params_flat, updates_flat, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params.dim_size(0), ")"));
+ }
}
}
};
@@ -221,54 +252,71 @@ class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel {
REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
-#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
+#define REGISTER_SCATTER_ARITHMETIC(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);
+#define REGISTER_SCATTER_MINMAX(type, dev) \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterMin", scatter_op::UpdateOp::MIN); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterMax", scatter_op::UpdateOp::MAX);
+
#define REGISTER_SCATTER_UPDATE(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \
scatter_op::UpdateOp::ASSIGN);
// Registers CPU kernels.
-#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, CPU);
+#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, CPU);
+
+#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
#define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);
// Registers GPU kernels.
#if GOOGLE_CUDA
-#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, GPU);
+#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, GPU);
+
+#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);
#endif // GOOGLE_CUDA
// Registers GPU kernels.
#if TENSORFLOW_USE_SYCL
-#define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, SYCL);
+#define REGISTER_SCATTER_ARITHMETIC_SYCL(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, SYCL);
+
+#define REGISTER_SCATTER_MINMAX_SYCL(type) REGISTER_SCATTER_MINMAX(type, SYCL);
#define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_SYCL);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_SYCL);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL);
-#undef REGISTER_SCATTER_ARITHEMTIC_SYCL
+#undef REGISTER_SCATTER_ARITHMETIC_SYCL
+#undef REGISTER_SCATTER_MINMAX_SYCL
#undef REGISTER_SCATTER_UPDATE_SYCL
#endif // TENSORFLOW_USE_SYCL
-#undef REGISTER_SCATTER_ARITHEMTIC
-#undef REGISTER_SCATTER_ARITHEMTIC_CPU
-#undef REGISTER_SCATTER_ARITHEMTIC_GPU
+#undef REGISTER_SCATTER_ARITHMETIC
+#undef REGISTER_SCATTER_ARITHMETIC_CPU
+#undef REGISTER_SCATTER_ARITHMETIC_GPU
+#undef REGISTER_SCATTER_MINMAX
+#undef REGISTER_SCATTER_MINMAX_CPU
+#undef REGISTER_SCATTER_MINMAX_GPU
#undef REGISTER_SCATTER_UPDATE
#undef REGISTER_SCATTER_UPDATE_CPU
#undef REGISTER_SCATTER_UPDATE_GPU
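`ValidShapes()` now accepts either the full rule `updates.shape = indices.shape + params.shape[1:]` or a rank-0 `updates`, matching the scalar fast path added to both kernels. A standalone restatement of the check over plain dimension vectors (hypothetical helper mirroring the logic above):

    #include <cassert>
    #include <vector>

    bool ValidShapesSketch(const std::vector<int>& params,
                           const std::vector<int>& indices,
                           const std::vector<int>& updates) {
      if (updates.empty()) return true;  // rank-0 updates broadcast
      if (updates.size() != indices.size() + params.size() - 1) return false;
      for (size_t d = 0; d < indices.size(); ++d)
        if (updates[d] != indices[d]) return false;
      for (size_t d = 1; d < params.size(); ++d)
        if (updates[indices.size() + d - 1] != params[d]) return false;
      return true;
    }

    int main() {
      assert(ValidShapesSketch({10, 4}, {3}, {3, 4}));   // classic rule
      assert(ValidShapesSketch({10, 4}, {3}, {}));       // new scalar case
      assert(!ValidShapesSketch({10, 4}, {3}, {3, 5}));  // mismatched tail
    }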
diff --git a/tensorflow/core/kernels/scatter_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
index 0b43704846..0df329310f 100644
--- a/tensorflow/core/kernels/scatter_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
@@ -24,15 +24,18 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
// Instantiates functor specializations for GPU.
-#define DEFINE_GPU_SPECS_OP(T, Index, op) \
- template struct functor::ScatterFunctor<GPUDevice, T, Index, op>;
+#define DEFINE_GPU_SPECS_OP(T, Index, op) \
+ template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; \
+ template struct functor::ScatterScalarFunctor<GPUDevice, T, Index, op>;
#define DEFINE_GPU_SPECS_INDEX(T, Index) \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
- DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX);
#define DEFINE_GPU_SPECS(T) \
DEFINE_GPU_SPECS_INDEX(T, int32); \
diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc
index 0b8645a2ae..5b3537b94c 100644
--- a/tensorflow/core/kernels/scatter_op_test.cc
+++ b/tensorflow/core/kernels/scatter_op_test.cc
@@ -185,7 +185,7 @@ TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) {
Status s = RunOpKernel();
EXPECT_TRUE(StringPiece(s.ToString())
.contains("Must have updates.shape = indices.shape + "
- "params.shape[1:], got "))
+ "params.shape[1:] or updates.shape = [], got "))
<< s;
}
@@ -202,7 +202,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
Status s = RunOpKernel();
EXPECT_TRUE(StringPiece(s.ToString())
.contains("Must have updates.shape = indices.shape + "
- "params.shape[1:], got "))
+ "params.shape[1:] or updates.shape = [], got "))
<< s;
}
@@ -219,7 +219,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
Status s = RunOpKernel();
EXPECT_TRUE(StringPiece(s.ToString())
.contains("Must have updates.shape = indices.shape + "
- "params.shape[1:], got "))
+ "params.shape[1:] or updates.shape = [], got "))
<< s;
}
@@ -300,6 +300,20 @@ static void BM_ScatterDivInt64(int iters, int embedding_size) {
BM_ScatterHelper<int64>(iters, embedding_size, "ScatterDiv");
}
+static void BM_ScatterMinInt32(int iters, int embedding_size) {
+ BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMin");
+}
+static void BM_ScatterMinInt64(int iters, int embedding_size) {
+ BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMin");
+}
+
+static void BM_ScatterMaxInt32(int iters, int embedding_size) {
+ BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMax");
+}
+static void BM_ScatterMaxInt64(int iters, int embedding_size) {
+ BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMax");
+}
+
BENCHMARK(BM_ScatterUpdateInt32)
->Arg(1)
->Arg(10)
@@ -332,5 +346,11 @@ BENCHMARK(BM_ScatterMulInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterDivInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterDivInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterMinInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterMinInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+
+BENCHMARK(BM_ScatterMaxInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterMaxInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index b41826d6eb..05d6e02281 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -43706,6 +43706,210 @@ op {
is_stateful: true
}
op {
+ name: "ResourceScatterDiv"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ResourceScatterMax"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ResourceScatterMin"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ResourceScatterMul"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ResourceScatterNdUpdate"
input_arg {
name: "ref"
@@ -43743,6 +43947,57 @@ op {
is_stateful: true
}
op {
+ name: "ResourceScatterSub"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ResourceScatterUpdate"
input_arg {
name: "resource"
@@ -48902,6 +49157,110 @@ op {
}
}
op {
+ name: "ScatterMax"
+ input_arg {
+ name: "ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output_ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
+ name: "ScatterMin"
+ input_arg {
+ name: "ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output_ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "ScatterMul"
input_arg {
name: "ref"
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index af2c563489..274a7fbf75 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -21659,6 +21659,210 @@ op {
is_stateful: true
}
op {
+ name: "ResourceScatterDiv"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ResourceScatterMax"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ResourceScatterMin"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ResourceScatterMul"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ResourceScatterNdUpdate"
input_arg {
name: "ref"
@@ -21696,6 +21900,57 @@ op {
is_stateful: true
}
op {
+ name: "ResourceScatterSub"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ResourceScatterUpdate"
input_arg {
name: "resource"
@@ -23435,6 +23690,110 @@ op {
}
}
op {
+ name: "ScatterMax"
+ input_arg {
+ name: "ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output_ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
+ name: "ScatterMin"
+ input_arg {
+ name: "ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output_ref"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "ScatterMul"
input_arg {
name: "ref"
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 0d8cf78cc2..3d0a6c2157 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -167,27 +167,75 @@ REGISTER_OP("ResourceGather")
return Status::OK();
});
+namespace {
+
+Status ResourceScatterUpdateShape(InferenceContext* c) {
+ ShapeAndType handle_shape_and_type;
+ TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &handle_shape_and_type));
+ ShapeHandle var_shape = handle_shape_and_type.shape;
+ ShapeHandle indices_shape = c->input(1);
+
+ ShapeHandle unused_updates_shape;
+ ShapeHandle concat;
+ ShapeHandle var_subshape;
+ TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
+ TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
+ TF_RETURN_IF_ERROR(
+ InferenceContext::Rank(c->input(2)) == 0
+ ? Status::OK()
+ : c->Merge(c->input(2), concat, &unused_updates_shape));
+ return Status::OK();
+}
+
+} // namespace
+
REGISTER_OP("ResourceScatterAdd")
.Input("resource: resource")
.Input("indices: Tindices")
.Input("updates: dtype")
.Attr("dtype: numbertype")
.Attr("Tindices: {int32, int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeAndType handle_shape_and_type;
- TF_RETURN_IF_ERROR(
- ValidateVariableResourceHandle(c, &handle_shape_and_type));
- ShapeHandle var_shape = handle_shape_and_type.shape;
- ShapeHandle indices_shape = c->input(1);
+ .SetShapeFn(ResourceScatterUpdateShape);
- ShapeHandle unused_updates_shape;
- ShapeHandle concat;
- ShapeHandle var_subshape;
- TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
- TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
- TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
- return Status::OK();
- });
+REGISTER_OP("ResourceScatterSub")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterMul")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterDiv")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterMin")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterMax")
+ .Input("resource: resource")
+ .Input("indices: Tindices")
+ .Input("updates: dtype")
+ .Attr("dtype: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(ResourceScatterUpdateShape);
REGISTER_OP("ResourceScatterUpdate")
.Input("resource: resource")
@@ -195,21 +243,7 @@ REGISTER_OP("ResourceScatterUpdate")
.Input("updates: dtype")
.Attr("dtype: type")
.Attr("Tindices: {int32, int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeAndType handle_shape_and_type;
- TF_RETURN_IF_ERROR(
- ValidateVariableResourceHandle(c, &handle_shape_and_type));
- ShapeHandle var_shape = handle_shape_and_type.shape;
- ShapeHandle indices_shape = c->input(1);
-
- ShapeHandle unused_updates_shape;
- ShapeHandle concat;
- ShapeHandle var_subshape;
- TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
- TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
- TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
- return Status::OK();
- });
+ .SetShapeFn(ResourceScatterUpdateShape);
REGISTER_OP("MutexV2")
.Attr("container: string = ''")
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index 7a524b60c0..664f52452e 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -122,7 +122,10 @@ Status ScatterUpdateShape(InferenceContext* c) {
ShapeHandle var_subshape;
TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
- TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
+ TF_RETURN_IF_ERROR(
+ InferenceContext::Rank(c->input(2)) == 0
+ ? Status::OK()
+ : c->Merge(c->input(2), concat, &unused_updates_shape));
c->set_output(0, var_shape);
return Status::OK();
@@ -180,6 +183,26 @@ REGISTER_OP("ScatterDiv")
.Attr("use_locking: bool = false")
.SetShapeFn(ScatterUpdateShape);
+REGISTER_OP("ScatterMin")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: {half, bfloat16, float, double, int32, int64}")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .SetShapeFn(ScatterUpdateShape);
+
+REGISTER_OP("ScatterMax")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: {half, bfloat16, float, double, int32, int64}")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .SetShapeFn(ScatterUpdateShape);
+
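The new `ScatterMin`/`ScatterMax` registrations surface in Python as `tf.scatter_min` and `tf.scatter_max` (the API-guide change further down adds them to the docs). A minimal sketch of the element-wise semantics, with illustrative values:

```python
import tensorflow as tf

ref = tf.Variable([1, 5, 3, 7])
# At each index, keeps the smaller of the current value and the update.
result = tf.scatter_min(ref, indices=[1, 3], updates=[2, 9])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(result))  # [1, 2, 3, 7]
```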
REGISTER_OP("ScatterNdUpdate")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 47ddf0ccb9..9a6ff48069 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/null_file_system.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 03c0c5ab51..8f99766e15 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -305,74 +305,6 @@ class ReadOnlyMemoryRegion {
virtual uint64 length() = 0;
};
-// START_SKIP_DOXYGEN
-
-#ifndef SWIG
-// Degenerate file system that provides no implementations.
-class NullFileSystem : public FileSystem {
- public:
- NullFileSystem() {}
-
- ~NullFileSystem() override = default;
-
- Status NewRandomAccessFile(
- const string& fname, std::unique_ptr<RandomAccessFile>* result) override {
- return errors::Unimplemented("NewRandomAccessFile unimplemented");
- }
-
- Status NewWritableFile(const string& fname,
- std::unique_ptr<WritableFile>* result) override {
- return errors::Unimplemented("NewWritableFile unimplemented");
- }
-
- Status NewAppendableFile(const string& fname,
- std::unique_ptr<WritableFile>* result) override {
- return errors::Unimplemented("NewAppendableFile unimplemented");
- }
-
- Status NewReadOnlyMemoryRegionFromFile(
- const string& fname,
- std::unique_ptr<ReadOnlyMemoryRegion>* result) override {
- return errors::Unimplemented(
- "NewReadOnlyMemoryRegionFromFile unimplemented");
- }
-
- Status FileExists(const string& fname) override {
- return errors::Unimplemented("FileExists unimplemented");
- }
-
- Status GetChildren(const string& dir, std::vector<string>* result) override {
- return errors::Unimplemented("GetChildren unimplemented");
- }
-
- Status DeleteFile(const string& fname) override {
- return errors::Unimplemented("DeleteFile unimplemented");
- }
-
- Status CreateDir(const string& dirname) override {
- return errors::Unimplemented("CreateDir unimplemented");
- }
-
- Status DeleteDir(const string& dirname) override {
- return errors::Unimplemented("DeleteDir unimplemented");
- }
-
- Status GetFileSize(const string& fname, uint64* file_size) override {
- return errors::Unimplemented("GetFileSize unimplemented");
- }
-
- Status RenameFile(const string& src, const string& target) override {
- return errors::Unimplemented("RenameFile unimplemented");
- }
-
- Status Stat(const string& fname, FileStatistics* stat) override {
- return errors::Unimplemented("Stat unimplemented");
- }
-};
-#endif
-
-// END_SKIP_DOXYGEN
-
/// \brief A registry for file system implementations.
///
/// Filenames are specified as an URI, which is of the form
diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc
index abe88ab6c7..e07aad55cb 100644
--- a/tensorflow/core/platform/file_system_test.cc
+++ b/tensorflow/core/platform/file_system_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/null_file_system.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/core/platform/null_file_system.h b/tensorflow/core/platform/null_file_system.h
new file mode 100644
index 0000000000..008e6d54d0
--- /dev/null
+++ b/tensorflow/core/platform/null_file_system.h
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_NULL_FILE_SYSTEM_H_
+#define TENSORFLOW_CORE_PLATFORM_NULL_FILE_SYSTEM_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/file_system.h"
+
+namespace tensorflow {
+
+// START_SKIP_DOXYGEN
+
+#ifndef SWIG
+// Degenerate file system that provides no implementations.
+class NullFileSystem : public FileSystem {
+ public:
+ NullFileSystem() {}
+
+ ~NullFileSystem() override = default;
+
+ Status NewRandomAccessFile(
+ const string& fname, std::unique_ptr<RandomAccessFile>* result) override {
+ return errors::Unimplemented("NewRandomAccessFile unimplemented");
+ }
+
+ Status NewWritableFile(const string& fname,
+ std::unique_ptr<WritableFile>* result) override {
+ return errors::Unimplemented("NewWritableFile unimplemented");
+ }
+
+ Status NewAppendableFile(const string& fname,
+ std::unique_ptr<WritableFile>* result) override {
+ return errors::Unimplemented("NewAppendableFile unimplemented");
+ }
+
+ Status NewReadOnlyMemoryRegionFromFile(
+ const string& fname,
+ std::unique_ptr<ReadOnlyMemoryRegion>* result) override {
+ return errors::Unimplemented(
+ "NewReadOnlyMemoryRegionFromFile unimplemented");
+ }
+
+ Status FileExists(const string& fname) override {
+ return errors::Unimplemented("FileExists unimplemented");
+ }
+
+ Status GetChildren(const string& dir, std::vector<string>* result) override {
+ return errors::Unimplemented("GetChildren unimplemented");
+ }
+
+ Status DeleteFile(const string& fname) override {
+ return errors::Unimplemented("DeleteFile unimplemented");
+ }
+
+ Status CreateDir(const string& dirname) override {
+ return errors::Unimplemented("CreateDir unimplemented");
+ }
+
+ Status DeleteDir(const string& dirname) override {
+ return errors::Unimplemented("DeleteDir unimplemented");
+ }
+
+ Status GetFileSize(const string& fname, uint64* file_size) override {
+ return errors::Unimplemented("GetFileSize unimplemented");
+ }
+
+ Status RenameFile(const string& src, const string& target) override {
+ return errors::Unimplemented("RenameFile unimplemented");
+ }
+
+ Status Stat(const string& fname, FileStatistics* stat) override {
+ return errors::Unimplemented("Stat unimplemented");
+ }
+};
+#endif
+
+// END_SKIP_DOXYGEN
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_NULL_FILE_SYSTEM_H_
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index fdf16aa1da..bb772460b0 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -46,6 +46,8 @@ message RewriterConfig {
Toggle loop_optimization = 9;
// Function optimizations (default is ON).
Toggle function_optimization = 10;
+ // Strips debug-related nodes from the graph (off by default).
+ Toggle debug_stripper = 11;
// If true, don't remove unnecessary ops from the graph
bool disable_model_pruning = 2;
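A sketch of switching the new pass on from Python, assuming the usual `rewriter_config_pb2` import path and `Toggle` enum; this is illustrative, not part of this change:

```python
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2

rewrite_options = rewriter_config_pb2.RewriterConfig(
    debug_stripper=rewriter_config_pb2.RewriterConfig.ON)
graph_options = tf.GraphOptions(rewrite_options=rewrite_options)

with tf.Session(config=tf.ConfigProto(graph_options=graph_options)) as sess:
    pass  # Debug nodes (e.g. Assert) are stripped before execution.
```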
diff --git a/tensorflow/docs_src/api_guides/python/state_ops.md b/tensorflow/docs_src/api_guides/python/state_ops.md
index 0d612ee0c7..ec2d877386 100644
--- a/tensorflow/docs_src/api_guides/python/state_ops.md
+++ b/tensorflow/docs_src/api_guides/python/state_ops.md
@@ -83,6 +83,8 @@ automatically by the optimizers in most cases.
* @{tf.scatter_sub}
* @{tf.scatter_mul}
* @{tf.scatter_div}
+* @{tf.scatter_min}
+* @{tf.scatter_max}
* @{tf.scatter_nd_update}
* @{tf.scatter_nd_add}
* @{tf.scatter_nd_sub}
diff --git a/tensorflow/docs_src/community/index.md b/tensorflow/docs_src/community/index.md
index b706d9b204..ebeff8493b 100644
--- a/tensorflow/docs_src/community/index.md
+++ b/tensorflow/docs_src/community/index.md
@@ -13,3 +13,6 @@ This section contains the following documents:
conventions that TensorFlow developers and users should follow.
* @{$community/benchmarks$Benchmarks}, Benchmarks, a guide for defining and
running a TensorFlow benchmark.
+ * @{$security$Using TensorFlow Securely}, which explains TensorFlow's security
+   model, lists recent security reports, and describes how you can report a
+   security vulnerability to the TensorFlow team.
diff --git a/tensorflow/docs_src/community/leftnav_files b/tensorflow/docs_src/community/leftnav_files
index fab35024ad..af344506c7 100644
--- a/tensorflow/docs_src/community/leftnav_files
+++ b/tensorflow/docs_src/community/leftnav_files
@@ -4,3 +4,4 @@ roadmap.md
documentation.md
style_guide.md
benchmarks.md
+security.md
diff --git a/tensorflow/docs_src/community/security.md b/tensorflow/docs_src/community/security.md
new file mode 100644
index 0000000000..8d13c7a1ea
--- /dev/null
+++ b/tensorflow/docs_src/community/security.md
@@ -0,0 +1,7 @@
+# Using TensorFlow Securely
+
+Before using TensorFlow, please take a look at our security model, list of
+recent security announcements, and ways you can report security issues to the
+TensorFlow team at the
+[Using TensorFlow Securely](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md)
+page on GitHub.
diff --git a/tensorflow/docs_src/get_started/get_started_for_beginners.md b/tensorflow/docs_src/get_started/get_started_for_beginners.md
index b88483be69..f59cebe6c4 100644
--- a/tensorflow/docs_src/get_started/get_started_for_beginners.md
+++ b/tensorflow/docs_src/get_started/get_started_for_beginners.md
@@ -14,6 +14,11 @@ If you are already familiar with basic machine learning concepts
but are new to TensorFlow, read
@{$premade_estimators$Getting Started with TensorFlow: for ML Experts}.
+If you'd like to learn more about the basics of machine learning,
+consider taking
+[Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/).
+
+
## The Iris classification problem
Imagine you are a botanist seeking an automated way to classify each
@@ -86,6 +91,9 @@ a number. Here's the representation scheme:
* 1 represents versicolor
* 2 represents virginica
+For more on the terms *label* and *example*, see the
+[ML Terminology section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/framing/ml-terminology).
+
## Models and training
@@ -371,7 +379,7 @@ There are several categories of neural networks.
We'll be using a [**fully connected neural
network**](https://developers.google.com/machine-learning/glossary/#fully_connected_layer),
which means that the neurons in one layer take inputs from *every* neuron in
-the previous layer. For example, the following figure illustrates a
+the previous layer. For example, the following figure illustrates a
fully connected neural network consisting of three hidden layers:
* The first hidden layer contains four neurons.
@@ -385,6 +393,9 @@ fully connected neural network consisting of three hidden layers:
**A neural network with three hidden layers.**
<p>&nbsp;</p>
+For a more detailed introduction to neural networks, see the
+[Introduction to Neural Nets section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/introduction-to-neural-networks/anatomy).
+
To specify a model type, instantiate an
[**Estimator**](https://developers.google.com/machine-learning/glossary/#Estimators)
class. TensorFlow provides two categories of Estimators:
@@ -448,9 +459,9 @@ will become very important.
### Train the model
-Instantiating a `tf.Estimator.DNNClassifier` creates a framework for learning
-the model. Basically, we've wired a network but haven't yet let data flow
-through it. To train the neural network, call the Estimator object's `train`
+Instantiating a `tf.estimator.DNNClassifier` creates a framework for learning
+the model. Basically, we've wired a network but haven't yet let data flow
+through it. To train the neural network, call the Estimator object's `train`
method. For example:
```python
@@ -559,15 +570,15 @@ of 0.5. The following suggests a more effective model:
<th colspan="1">Label</th>
<th colspan="1">Prediction</th>
</tr>
- <tr> <td>5.9</td> <td>3.0</td> <td>4.3</td> <td>1.5</td> <td>1</td>
+ <tr> <td>5.9</td> <td>3.0</td> <td>4.3</td> <td>1.5</td> <td>1</td>
<td style="background-color:green">1</td></tr>
- <tr> <td>6.9</td> <td>3.1</td> <td>5.4</td> <td>2.1</td> <td>2</td>
+ <tr> <td>6.9</td> <td>3.1</td> <td>5.4</td> <td>2.1</td> <td>2</td>
<td style="background-color:green">2</td></tr>
- <tr> <td>5.1</td> <td>3.3</td> <td>1.7</td> <td>0.5</td> <td>0</td>
+ <tr> <td>5.1</td> <td>3.3</td> <td>1.7</td> <td>0.5</td> <td>0</td>
<td style="background-color:green">0</td></tr>
- <tr> <td>6.0</td> <td>3.4</td> <td>4.5</td> <td>1.6</td> <td>1</td>
+ <tr> <td>6.0</td> <td>3.4</td> <td>4.5</td> <td>1.6</td> <td>1</td>
<td style="background-color:red">2</td></tr>
- <tr> <td>5.5</td> <td>2.5</td> <td>4.0</td> <td>1.3</td> <td>1</td>
+ <tr> <td>5.5</td> <td>2.5</td> <td>4.0</td> <td>1.3</td> <td>1</td>
<td style="background-color:green">1</td></tr>
</table>
@@ -631,6 +642,10 @@ Test set accuracy: 0.967
An accuracy of 0.967 implies that our trained model correctly classified 29
out of the 30 Iris species in the test set.
+To get a deeper understanding of different metrics for evaluating
+models, see the
+[Classification section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/classification).
+
### Predicting
@@ -723,7 +738,6 @@ Prediction is "Virginica" (97.9%), expected "Virginica"
## Summary
-<!--TODO(barryr): When MLCC is released, add pointers to relevant sections.-->
This document provides a short introduction to machine learning.
Because `premade_estimators.py` relies on high-level APIs, much of the
diff --git a/tensorflow/docs_src/get_started/index.md b/tensorflow/docs_src/get_started/index.md
index b7bd1286e3..fb83a770a5 100644
--- a/tensorflow/docs_src/get_started/index.md
+++ b/tensorflow/docs_src/get_started/index.md
@@ -1,5 +1,12 @@
# Getting Started
+If you are new to machine learning, we recommend taking the following online
+course before diving into the TensorFlow documentation:
+
+ * [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/),
+ which introduces machine learning concepts and encourages experimentation
+ with existing TensorFlow code.
+
TensorFlow is a tool for machine learning. While it contains a wide range of
functionality, TensorFlow is mainly designed for deep neural network models.
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index 8e46c9ee20..27b696696d 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -506,11 +506,18 @@ TensorFlow programs:
<pre>Hello, TensorFlow!</pre>
-If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with TensorFlow}.
-
If the system outputs an error message instead of a greeting, see [Common
installation problems](#common_installation_problems).
+If you are new to machine learning, we recommend the following:
+
+* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course)
+* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners}
+
+If you are experienced with machine learning but new to TensorFlow, see
+@{$get_started/premade_estimators$Getting Started with TensorFlow}.
+
+
## Common installation problems
We are relying on Stack Overflow to document TensorFlow installation problems
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index cb7250a16e..7060ef43da 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -400,12 +400,18 @@ writing TensorFlow programs:
<pre>Hello, TensorFlow!</pre>
-If you are new to TensorFlow, see
-@{$get_started/premade_estimators$Getting Started with TensorFlow}.
-
If the system outputs an error message instead of a greeting, see
[Common installation problems](#common_installation_problems).
+If you are new to machine learning, we recommend the following:
+
+* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course)
+* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners}
+
+If you are experienced with machine learning but new to TensorFlow, see
+@{$get_started/premade_estimators$Getting Started with TensorFlow}.
+
+
## Common installation problems
We are relying on Stack Overflow to document TensorFlow installation problems
diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md
index 2413bc9cfb..86add74da1 100644
--- a/tensorflow/docs_src/install/install_windows.md
+++ b/tensorflow/docs_src/install/install_windows.md
@@ -17,7 +17,7 @@ You must choose one of the following types of TensorFlow to install:
NVIDIA® GPU, you must install this version. Note that this version of
TensorFlow is typically much easier to install (typically,
in 5 or 10 minutes), so even if you have an NVIDIA GPU, we recommend
- installing this version first. Prebuilt binaries will use AVX instructions.
+ installing this version first. Prebuilt binaries will use AVX instructions.
* **TensorFlow with GPU support**. TensorFlow programs typically run
significantly faster on a GPU than on a CPU. Therefore, if your
system has a NVIDIA® GPU meeting the prerequisites shown below
@@ -154,13 +154,17 @@ TensorFlow programs:
<pre>Hello, TensorFlow!</pre>
-If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with TensorFlow}.
-
If the system outputs an error message instead of a greeting, see [Common
installation problems](#common_installation_problems).
-There is also a helpful [script](https://gist.github.com/mrry/ee5dbcfdd045fa48a27d56664411d41c)
-for Windows TensorFlow installation issues.
+If you are new to machine learning, we recommend the following:
+
+* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course)
+* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners}
+
+If you are experienced with machine learning but new to TensorFlow, see
+@{$get_started/premade_estimators$Getting Started with TensorFlow}.
+
## Common installation problems
diff --git a/tensorflow/docs_src/programmers_guide/embedding.md b/tensorflow/docs_src/programmers_guide/embedding.md
index e8027fc12b..d5703e0737 100644
--- a/tensorflow/docs_src/programmers_guide/embedding.md
+++ b/tensorflow/docs_src/programmers_guide/embedding.md
@@ -7,6 +7,9 @@ with the TensorBoard Embedding Projector
newcomers to machine learning or TensorFlow, and the Embedding Projector how-to
is for users at all levels.
+An alternative tutorial on these concepts is available in the
+[Embeddings section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/embeddings/video-lecture).
+
[TOC]
An **embedding** is a mapping from discrete objects, such as words, to vectors
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 5ddd32ed48..838f4f2301 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -1089,186 +1089,232 @@ func ExpandDims(scope *Scope, input tf.Output, axis tf.Output) (output tf.Output
return op.Output(0)
}
-// Returns (x - y)(x - y) element-wise.
+// A placeholder op that passes through `input` when its output is not fed.
//
-// *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+// Arguments:
+// input: The default value to produce when `output` is not fed.
+// shape: The (possibly partial) shape of the tensor.
+//
+// Returns A placeholder tensor that defaults to `input` if it is not fed.
+func PlaceholderWithDefault(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"shape": shape}
opspec := tf.OpSpec{
- Type: "SquaredDifference",
+ Type: "PlaceholderWithDefault",
Input: []tf.Input{
- x, y,
+ input,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
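The relocated `PlaceholderWithDefault` wrapper corresponds to `tf.placeholder_with_default` in Python; a minimal sketch:

```python
import tensorflow as tf

x = tf.placeholder_with_default(tf.constant([1, 2, 3]), shape=[3])
y = x * 2

with tf.Session() as sess:
    print(sess.run(y))                            # [2 4 6], default used
    print(sess.run(y, feed_dict={x: [7, 8, 9]}))  # [14 16 18], fed value wins
```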
-// Forwards the input to the output.
+// A placeholder op for a value that will be fed into the computation.
//
-// This operator represents the loop termination condition used by the
-// "pivot" switches of a loop.
+// DEPRECATED at GraphDef version 23: Placeholder now behaves the same as PlaceholderV2.
+//
+// N.B. This operation will fail with an error if it is executed. It is
+// intended as a way to represent a value that will always be fed, and to
+// provide attrs that enable the fed value to be checked at runtime.
//
// Arguments:
-// input: A boolean scalar, representing the branch predicate of the Switch op.
+// dtype: The type of elements in the tensor.
+// shape: The shape of the tensor. The shape can be any partially-specified
+// shape. To be unconstrained, pass in a shape with unknown rank.
//
-// Returns The same tensor as `input`.
-func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
+// Returns A placeholder tensor that must be replaced using the feed mechanism.
+func PlaceholderV2(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"dtype": dtype, "shape": shape}
opspec := tf.OpSpec{
- Type: "LoopCond",
- Input: []tf.Input{
- input,
- },
+ Type: "PlaceholderV2",
+
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// QuantizedMulAttr is an optional argument to QuantizedMul.
-type QuantizedMulAttr func(optionalAttr)
+// PlaceholderAttr is an optional argument to Placeholder.
+type PlaceholderAttr func(optionalAttr)
-// QuantizedMulToutput sets the optional Toutput attribute to value.
-// If not specified, defaults to DT_QINT32
-func QuantizedMulToutput(value tf.DataType) QuantizedMulAttr {
+// PlaceholderShape sets the optional shape attribute to value.
+//
+// value: (Optional) The shape of the tensor. If the shape has 0 dimensions, the
+// shape is unconstrained.
+// If not specified, defaults to <unknown_rank:true >
+func PlaceholderShape(value tf.Shape) PlaceholderAttr {
return func(m optionalAttr) {
- m["Toutput"] = value
+ m["shape"] = value
}
}
-// Returns x * y element-wise, working on quantized buffers.
-//
-// Arguments:
-//
+// A placeholder op for a value that will be fed into the computation.
//
-// min_x: The float value that the lowest quantized `x` value represents.
-// max_x: The float value that the highest quantized `x` value represents.
-// min_y: The float value that the lowest quantized `y` value represents.
-// max_y: The float value that the highest quantized `y` value represents.
+// N.B. This operation will fail with an error if it is executed. It is
+// intended as a way to represent a value that will always be fed, and to
+// provide attrs that enable the fed value to be checked at runtime.
//
-// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents.
+// Arguments:
+// dtype: The type of elements in the tensor.
//
-// *NOTE*: `QuantizedMul` supports limited forms of broadcasting. More about
-// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedMulAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) {
+// Returns A placeholder tensor that must be replaced using the feed mechanism.
+func Placeholder(scope *Scope, dtype tf.DataType, optional ...PlaceholderAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
+ attrs := map[string]interface{}{"dtype": dtype}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
- Type: "QuantizedMul",
- Input: []tf.Input{
- x, y, min_x, max_x, min_y, max_y,
- },
+ Type: "Placeholder",
+
Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
-// QuantizedMatMulAttr is an optional argument to QuantizedMatMul.
-type QuantizedMatMulAttr func(optionalAttr)
-
-// QuantizedMatMulToutput sets the optional Toutput attribute to value.
-// If not specified, defaults to DT_QINT32
-func QuantizedMatMulToutput(value tf.DataType) QuantizedMatMulAttr {
- return func(m optionalAttr) {
- m["Toutput"] = value
- }
+ return op.Output(0)
}
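As the doc comment says, a plain `Placeholder` fails unless fed; a minimal Python sketch of the feed mechanism with a partially specified shape:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 3])  # any batch size, 3 features
total = tf.reduce_sum(x)

with tf.Session() as sess:
    print(sess.run(total, feed_dict={x: [[1., 2., 3.]]}))  # 6.0
```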
-// QuantizedMatMulTransposeA sets the optional transpose_a attribute to value.
+// Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor.
//
-// value: If true, `a` is transposed before multiplication.
-// If not specified, defaults to false
-func QuantizedMatMulTransposeA(value bool) QuantizedMatMulAttr {
- return func(m optionalAttr) {
- m["transpose_a"] = value
- }
-}
-
-// QuantizedMatMulTransposeB sets the optional transpose_b attribute to value.
+// This operation folds the padded areas of `input` by `MirrorPad` according to the
+// `paddings` you specify. `paddings` must be the same as `paddings` argument
+// given to the corresponding `MirrorPad` op.
//
-// value: If true, `b` is transposed before multiplication.
-// If not specified, defaults to false
-func QuantizedMatMulTransposeB(value bool) QuantizedMatMulAttr {
- return func(m optionalAttr) {
- m["transpose_b"] = value
- }
-}
-
-// QuantizedMatMulTactivation sets the optional Tactivation attribute to value.
+// The folded size of each dimension D of the output is:
//
-// value: The type of output produced by activation function
-// following this operation.
-// If not specified, defaults to DT_QUINT8
-func QuantizedMatMulTactivation(value tf.DataType) QuantizedMatMulAttr {
- return func(m optionalAttr) {
- m["Tactivation"] = value
+// `input.dim_size(D) - paddings(D, 0) - paddings(D, 1)`
+//
+// For example:
+//
+// ```
+// # 't' is [[1, 2, 3], [4, 5, 6], [7, 8, 9]].
+// # 'paddings' is [[0, 1], [0, 1]].
+// # 'mode' is SYMMETRIC.
+// # rank of 't' is 2.
+// pad(t, paddings) ==> [[ 1, 5]
+// [11, 28]]
+// ```
+//
+// Arguments:
+// input: The input tensor to be folded.
+// paddings: A two-column matrix specifying the padding sizes. The number of
+// rows must be the same as the rank of `input`.
+// mode: The mode used in the `MirrorPad` op.
+//
+// Returns The folded tensor.
+func MirrorPadGrad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) {
+ if scope.Err() != nil {
+ return
}
+ attrs := map[string]interface{}{"mode": mode}
+ opspec := tf.OpSpec{
+ Type: "MirrorPadGrad",
+ Input: []tf.Input{
+ input, paddings,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
-// Perform a quantized matrix multiplication of `a` by the matrix `b`.
+// Pads a tensor with mirrored values.
//
-// The inputs must be two-dimensional matrices and the inner dimension of
-// `a` (after being transposed if `transpose_a` is non-zero) must match the
-// outer dimension of `b` (after being transposed if `transposed_b` is
-// non-zero).
+// This operation pads a `input` with mirrored values according to the `paddings`
+// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is
+// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
+// how many values to add before the contents of `input` in that dimension, and
+// `paddings[D, 1]` indicates how many values to add after the contents of `input`
+// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater
+// than `input.dim_size(D)` in `SYMMETRIC` mode, and no greater than
+// `input.dim_size(D) - 1` in `REFLECT` mode.
+//
+// The padded size of each dimension D of the output is:
+//
+// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
+//
+// For example:
+//
+// ```
+// # 't' is [[1, 2, 3], [4, 5, 6]].
+// # 'paddings' is [[1, 1], [2, 2]].
+// # 'mode' is SYMMETRIC.
+// # rank of 't' is 2.
+// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2]
+// [2, 1, 1, 2, 3, 3, 2]
+// [5, 4, 4, 5, 6, 6, 5]
+// [5, 4, 4, 5, 6, 6, 5]]
+// ```
//
// Arguments:
-// a: Must be a two-dimensional tensor.
-// b: Must be a two-dimensional tensor.
-// min_a: The float value that the lowest quantized `a` value represents.
-// max_a: The float value that the highest quantized `a` value represents.
-// min_b: The float value that the lowest quantized `b` value represents.
-// max_b: The float value that the highest quantized `b` value represents.
+// input: The input tensor to be padded.
+// paddings: A two-column matrix specifying the padding sizes. The number of
+// rows must be the same as the rank of `input`.
+// mode: Either `REFLECT` or `SYMMETRIC`. In reflect mode the padded regions
+// do not include the borders, while in symmetric mode the padded regions
+// do include the borders. For example, if `input` is `[1, 2, 3]` and `paddings`
+// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and
+// it is `[1, 2, 3, 3, 2]` in symmetric mode.
//
-// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents.
-func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) {
+// Returns The padded tensor.
+func MirrorPad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
+ attrs := map[string]interface{}{"mode": mode}
opspec := tf.OpSpec{
- Type: "QuantizedMatMul",
+ Type: "MirrorPad",
Input: []tf.Input{
- a, b, min_a, max_a, min_b, max_b,
+ input, paddings,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
+ return op.Output(0)
}
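`MirrorPad` (and its gradient above) is reached from Python through `tf.pad`'s `mode` argument; a sketch of the two modes using the values from the doc comment:

```python
import tensorflow as tf

t = tf.constant([1, 2, 3])
paddings = tf.constant([[0, 2]])

reflect = tf.pad(t, paddings, mode='REFLECT')      # [1, 2, 3, 2, 1]
symmetric = tf.pad(t, paddings, mode='SYMMETRIC')  # [1, 2, 3, 3, 2]

with tf.Session() as sess:
    print(sess.run([reflect, symmetric]))
```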
-// A placeholder op that passes through `input` when its output is not fed.
+// Pads a tensor.
//
-// Arguments:
-// input: The default value to produce when `output` is not fed.
-// shape: The (possibly partial) shape of the tensor.
+// This operation pads `input` according to the `paddings` and `constant_values`
+// you specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is
+// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
+// how many padding values to add before the contents of `input` in that dimension,
+// and `paddings[D, 1]` indicates how many padding values to add after the contents
+// of `input` in that dimension. `constant_values` is a scalar tensor of the same
+// type as `input` that indicates the value to use for padding `input`.
//
-// Returns A placeholder tensor that defaults to `input` if it is not fed.
-func PlaceholderWithDefault(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) {
+// The padded size of each dimension D of the output is:
+//
+// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
+//
+// For example:
+//
+// ```
+// # 't' is [[1, 1], [2, 2]]
+// # 'paddings' is [[1, 1], [2, 2]]
+// # 'constant_values' is 0
+// # rank of 't' is 2
+// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
+// [0, 0, 1, 1, 0, 0]
+// [0, 0, 2, 2, 0, 0]
+// [0, 0, 0, 0, 0, 0]]
+// ```
+func PadV2(scope *Scope, input tf.Output, paddings tf.Output, constant_values tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"shape": shape}
opspec := tf.OpSpec{
- Type: "PlaceholderWithDefault",
+ Type: "PadV2",
Input: []tf.Input{
- input,
+ input, paddings, constant_values,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
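`PadV2` is `Pad` with a caller-chosen fill value, exposed via `tf.pad`'s `constant_values` keyword; a minimal sketch mirroring the example above but filling with 9:

```python
import tensorflow as tf

t = tf.constant([[1, 1], [2, 2]])
paddings = tf.constant([[1, 1], [2, 2]])
padded = tf.pad(t, paddings, mode='CONSTANT', constant_values=9)

with tf.Session() as sess:
    print(sess.run(padded))  # the doc's example output with 9s instead of 0s
```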
@@ -2063,6 +2109,47 @@ func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true i
return op.Output(0), op.Output(1), op.Output(2)
}
+// Returns (x - y)(x - y) element-wise.
+//
+// *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SquaredDifference",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Forwards the input to the output.
+//
+// This operator represents the loop termination condition used by the
+// "pivot" switches of a loop.
+//
+// Arguments:
+// input: A boolean scalar, representing the branch predicate of the Switch op.
+//
+// Returns The same tensor as `input`.
+func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "LoopCond",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
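The relocated `SquaredDifference` maps to `tf.squared_difference`; a one-line sketch of the broadcasting it documents:

```python
import tensorflow as tf

z = tf.squared_difference([2., 4., 6.], 1.)  # scalar broadcasts: [1., 9., 25.]
with tf.Session() as sess:
    print(sess.run(z))
```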
// ApproximateEqualAttr is an optional argument to ApproximateEqual.
type ApproximateEqualAttr func(optionalAttr)
@@ -2391,50 +2478,6 @@ func Sign(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
-// QuantizedAddAttr is an optional argument to QuantizedAdd.
-type QuantizedAddAttr func(optionalAttr)
-
-// QuantizedAddToutput sets the optional Toutput attribute to value.
-// If not specified, defaults to DT_QINT32
-func QuantizedAddToutput(value tf.DataType) QuantizedAddAttr {
- return func(m optionalAttr) {
- m["Toutput"] = value
- }
-}
-
-// Returns x + y element-wise, working on quantized buffers.
-//
-// Arguments:
-//
-//
-// min_x: The float value that the lowest quantized `x` value represents.
-// max_x: The float value that the highest quantized `x` value represents.
-// min_y: The float value that the lowest quantized `y` value represents.
-// max_y: The float value that the highest quantized `y` value represents.
-//
-// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents.
-//
-// *NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about
-// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedAddAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "QuantizedAdd",
- Input: []tf.Input{
- x, y, min_x, max_x, min_y, max_y,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// ArgMinAttr is an optional argument to ArgMin.
type ArgMinAttr func(optionalAttr)
@@ -3741,32 +3784,6 @@ func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) {
return op.Output(0)
}
-// Given a quantized tensor described by (input, input_min, input_max), outputs a
-//
-// range that covers the actual values present in that tensor. This op is
-// typically used to produce the requested_output_min and requested_output_max for
-// Requantize.
-//
-// Arguments:
-//
-// input_min: The float value that the minimum quantized input value represents.
-// input_max: The float value that the maximum quantized input value represents.
-//
-// Returns The computed min output.the computed max output.
-func RequantizationRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output) (output_min tf.Output, output_max tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "RequantizationRange",
- Input: []tf.Input{
- input, input_min, input_max,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
// Returns the truth value of (x <= y) element-wise.
//
// *NOTE*: `LessEqual` supports broadcasting. More about broadcasting
@@ -3943,46 +3960,6 @@ func BatchMatMul(scope *Scope, x tf.Output, y tf.Output, optional ...BatchMatMul
return op.Output(0)
}
-// Pads a tensor.
-//
-// This operation pads `input` according to the `paddings` and `constant_values`
-// you specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is
-// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
-// how many padding values to add before the contents of `input` in that dimension,
-// and `paddings[D, 1]` indicates how many padding values to add after the contents
-// of `input` in that dimension. `constant_values` is a scalar tensor of the same
-// type as `input` that indicates the value to use for padding `input`.
-//
-// The padded size of each dimension D of the output is:
-//
-// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
-//
-// For example:
-//
-// ```
-// # 't' is [[1, 1], [2, 2]]
-// # 'paddings' is [[1, 1], [2, 2]]
-// # 'constant_values' is 0
-// # rank of 't' is 2
-// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
-// [0, 0, 1, 1, 0, 0]
-// [0, 0, 2, 2, 0, 0]
-// [0, 0, 0, 0, 0, 0]]
-// ```
-func PadV2(scope *Scope, input tf.Output, paddings tf.Output, constant_values tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "PadV2",
- Input: []tf.Input{
- input, paddings, constant_values,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns which elements of x are NaN.
//
// @compatibility(numpy)
@@ -4292,52 +4269,6 @@ func MaxPoolGradGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output
return op.Output(0)
}
-// MaxPoolAttr is an optional argument to MaxPool.
-type MaxPoolAttr func(optionalAttr)
-
-// MaxPoolDataFormat sets the optional data_format attribute to value.
-//
-// value: Specify the data format of the input and output data. With the
-// default format "NHWC", the data is stored in the order of:
-// [batch, in_height, in_width, in_channels].
-// Alternatively, the format could be "NCHW", the data storage order of:
-// [batch, in_channels, in_height, in_width].
-// If not specified, defaults to "NHWC"
-func MaxPoolDataFormat(value string) MaxPoolAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// Performs max pooling on the input.
-//
-// Arguments:
-// input: 4-D input to pool over.
-// ksize: The size of the window for each dimension of the input tensor.
-// strides: The stride of the sliding window for each dimension of the
-// input tensor.
-// padding: The type of padding algorithm to use.
-//
-// Returns The max pooled output tensor.
-func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MaxPool",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes gradients of the maxpooling function.
//
// Arguments:
@@ -5247,50 +5178,6 @@ func InvertPermutation(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
-// Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor.
-//
-// This operation folds the padded areas of `input` by `MirrorPad` according to the
-// `paddings` you specify. `paddings` must be the same as `paddings` argument
-// given to the corresponding `MirrorPad` op.
-//
-// The folded size of each dimension D of the output is:
-//
-// `input.dim_size(D) - paddings(D, 0) - paddings(D, 1)`
-//
-// For example:
-//
-// ```
-// # 't' is [[1, 2, 3], [4, 5, 6], [7, 8, 9]].
-// # 'paddings' is [[0, 1]], [0, 1]].
-// # 'mode' is SYMMETRIC.
-// # rank of 't' is 2.
-// pad(t, paddings) ==> [[ 1, 5]
-// [11, 28]]
-// ```
-//
-// Arguments:
-// input: The input tensor to be folded.
-// paddings: A two-column matrix specifying the padding sizes. The number of
-// rows must be the same as the rank of `input`.
-// mode: The mode used in the `MirrorPad` op.
-//
-// Returns The folded tensor.
-func MirrorPadGrad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"mode": mode}
- opspec := tf.OpSpec{
- Type: "MirrorPadGrad",
- Input: []tf.Input{
- input, paddings,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// BiasAddGradAttr is an optional argument to BiasAddGrad.
type BiasAddGradAttr func(optionalAttr)
@@ -5411,239 +5298,95 @@ func FusedBatchNormV2(scope *Scope, x tf.Output, scale tf.Output, offset tf.Outp
return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
}
-// AvgPoolGradAttr is an optional argument to AvgPoolGrad.
-type AvgPoolGradAttr func(optionalAttr)
-
-// AvgPoolGradDataFormat sets the optional data_format attribute to value.
+// Returns the rank of a tensor.
//
-// value: Specify the data format of the input and output data. With the
-// default format "NHWC", the data is stored in the order of:
-// [batch, in_height, in_width, in_channels].
-// Alternatively, the format could be "NCHW", the data storage order of:
-// [batch, in_channels, in_height, in_width].
-// If not specified, defaults to "NHWC"
-func AvgPoolGradDataFormat(value string) AvgPoolGradAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// Computes gradients of the average pooling function.
+// This operation returns an integer representing the rank of `input`.
//
-// Arguments:
-// orig_input_shape: 1-D. Shape of the original input to `avg_pool`.
-// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t.
-// the output of `avg_pool`.
-// ksize: The size of the sliding window for each dimension of the input.
-// strides: The stride of the sliding window for each dimension of the input.
-// padding: The type of padding algorithm to use.
+// For example:
//
-// Returns 4-D. Gradients w.r.t. the input of `avg_pool`.
-func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) {
+// ```
+// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+// # shape of tensor 't' is [2, 2, 3]
+// rank(t) ==> 3
+// ```
+//
+// **Note**: The rank of a tensor is not the same as the rank of a matrix. The rank
+// of a tensor is the number of indices required to uniquely select each element
+// of the tensor. Rank is also known as "order", "degree", or "ndims."
+func Rank(scope *Scope, input tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
opspec := tf.OpSpec{
- Type: "AvgPoolGrad",
+ Type: "Rank",
Input: []tf.Input{
- orig_input_shape, grad,
+ input,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
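A minimal Python sketch of the `Rank` example in the doc comment above:

```python
import tensorflow as tf

t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])  # shape [2, 2, 3]
with tf.Session() as sess:
    print(sess.run(tf.rank(t)))  # 3 (number of indices, not matrix rank)
```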
-// StageClearAttr is an optional argument to StageClear.
-type StageClearAttr func(optionalAttr)
-
-// StageClearCapacity sets the optional capacity attribute to value.
-// If not specified, defaults to 0
-//
-// REQUIRES: value >= 0
-func StageClearCapacity(value int64) StageClearAttr {
- return func(m optionalAttr) {
- m["capacity"] = value
- }
-}
-
-// StageClearMemoryLimit sets the optional memory_limit attribute to value.
-// If not specified, defaults to 0
-//
-// REQUIRES: value >= 0
-func StageClearMemoryLimit(value int64) StageClearAttr {
- return func(m optionalAttr) {
- m["memory_limit"] = value
- }
-}
-
-// StageClearContainer sets the optional container attribute to value.
-// If not specified, defaults to ""
-func StageClearContainer(value string) StageClearAttr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// StageClearSharedName sets the optional shared_name attribute to value.
-// If not specified, defaults to ""
-func StageClearSharedName(value string) StageClearAttr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// Op removes all elements in the underlying container.
-//
-// Returns the created operation.
-func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtypes": dtypes}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StageClear",
-
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits.
-type ComputeAccidentalHitsAttr func(optionalAttr)
-
-// ComputeAccidentalHitsSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Computes the ids of the positions in sampled_candidates that match true_labels.
-//
-// When doing log-odds NCE, the result of this op should be passed through a
-// SparseToDense op, then added to the logits of the sampled candidates. This has
-// the effect of 'removing' the sampled labels that match the true labels by
-// making the classifier sure that they are sampled labels.
+// Transforms a Tensor into a serialized TensorProto proto.
//
// Arguments:
-// true_classes: The true_classes output of UnpackSparseLabels.
-// sampled_candidates: The sampled_candidates output of CandidateSampler.
-// num_true: Number of true labels per context.
+// tensor: A Tensor of type `T`.
//
-// Returns A vector of indices corresponding to rows of true_candidates.A vector of IDs of positions in sampled_candidates that match a true_label
-// for the row with the corresponding index in indices.A vector of the same length as indices and ids, in which each element
-// is -FLOAT_MAX.
-func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) {
+// Returns A serialized TensorProto proto of the input tensor.
+func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"num_true": num_true}
- for _, a := range optional {
- a(attrs)
- }
opspec := tf.OpSpec{
- Type: "ComputeAccidentalHits",
+ Type: "SerializeTensor",
Input: []tf.Input{
- true_classes, sampled_candidates,
+ tensor,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
+ return op.Output(0)
}
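A sketch of a round trip through this op, assuming the `tf.serialize_tensor` and `tf.parse_tensor` Python endpoints of this release line:

```python
import tensorflow as tf

t = tf.constant([[1., 2.], [3., 4.]])
serialized = tf.serialize_tensor(t)  # scalar DT_STRING holding a TensorProto
restored = tf.parse_tensor(serialized, out_type=tf.float32)

with tf.Session() as sess:
    print(sess.run(restored))  # [[1. 2.] [3. 4.]]
```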
-// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3.
-type TensorArrayGatherV3Attr func(optionalAttr)
+// MatrixSolveAttr is an optional argument to MatrixSolve.
+type MatrixSolveAttr func(optionalAttr)
-// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value.
+// MatrixSolveAdjoint sets the optional adjoint attribute to value.
//
-// value: The expected shape of an element, if known. Used to
-// validate the shapes of TensorArray elements. If this shape is not
-// fully specified, gathering zero-size TensorArrays is an error.
-// If not specified, defaults to <unknown_rank:true >
-func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr {
+// value: Boolean indicating whether to solve with `matrix` or its (block-wise)
+// adjoint.
+// If not specified, defaults to false
+func MatrixSolveAdjoint(value bool) MatrixSolveAttr {
return func(m optionalAttr) {
- m["element_shape"] = value
+ m["adjoint"] = value
}
}
-// Gather specific elements from the TensorArray into output `value`.
+// Solves systems of linear equations.
//
-// All elements selected by `indices` must have the same shape.
+// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is
+// a tensor of shape `[..., M, K]`. If `adjoint` is `False` then each output matrix
+// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
+// If `adjoint` is `True` then each output matrix satisfies
+// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`.
//
// Arguments:
-// handle: The handle to a TensorArray.
-// indices: The locations in the TensorArray from which to read tensor elements.
-// flow_in: A float scalar that enforces proper chaining of operations.
-// dtype: The type of the elem that is returned.
+// matrix: Shape is `[..., M, M]`.
+// rhs: Shape is `[..., M, K]`.
//
-// Returns All of the elements in the TensorArray, concatenated along a new
-// axis (the new dimension 0).
-func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) {
+// Returns Shape is `[..., M, K]`.
+func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"dtype": dtype}
+ attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
- Type: "TensorArrayGatherV3",
- Input: []tf.Input{
- handle, indices, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Converts each string in the input Tensor to its hash mod by a number of buckets.
-//
-// The hash function is deterministic on the content of the string within the
-// process and will never change. However, it is not suitable for cryptography.
-// This function may be used when CPU time is scarce and inputs are trusted or
-// unimportant. There is a risk of adversaries constructing inputs that all hash
-// to the same bucket. To prevent this problem, use a strong hash function with
-// `tf.string_to_hash_bucket_strong`.
-//
-// Arguments:
-// input: The strings to assign a hash bucket.
-// num_buckets: The number of buckets.
-//
-// Returns A Tensor of the same shape as the input `string_tensor`.
-func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_buckets": num_buckets}
- opspec := tf.OpSpec{
- Type: "StringToHashBucketFast",
+ Type: "MatrixSolve",
Input: []tf.Input{
- input,
+ matrix, rhs,
},
Attrs: attrs,
}
@@ -5651,18 +5394,15 @@ func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (o
return op.Output(0)
}
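`MatrixSolve` maps to `tf.matrix_solve`; a minimal sketch solving one 2x2 system:

```python
import tensorflow as tf

matrix = tf.constant([[3., 1.], [1., 2.]])
rhs = tf.constant([[9.], [8.]])
x = tf.matrix_solve(matrix, rhs)  # solves matrix @ x = rhs

with tf.Session() as sess:
    print(sess.run(x))  # [[2.], [3.]]
```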
-// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
-//
-// *NOTE*: `Maximum` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func Maximum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+// Computes acos of x element-wise.
+func Acos(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "Maximum",
+ Type: "Acos",
Input: []tf.Input{
- x, y,
+ x,
},
}
op := scope.AddOperation(opspec)
@@ -5707,6 +5447,76 @@ func RFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output
return op.Output(0)
}
+// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter.
+type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr)
+
+// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value.
+//
+// value: Specify the data format of the input and output data. With the
+// default format "NHWC", the data is stored in the order of:
+// [batch, height, width, channels].
+// Alternatively, the format could be "NCHW", the data storage order of:
+// [batch, channels, height, width].
+// If not specified, defaults to "NHWC"
+func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value.
+//
+// value: 1-D tensor of length 4. The dilation factor for each dimension of
+// `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+// element on that dimension. The dimension order is determined by the value of
+// `data_format`, see above for details. Dilations in the batch and depth
+// dimensions must be 1.
+// If not specified, defaults to <i:1 i:1 i:1 i:1 >
+func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr {
+ return func(m optionalAttr) {
+ m["dilations"] = value
+ }
+}
+
+// Computes the gradients of depthwise convolution with respect to the filter.
+//
+// Arguments:
+// input: 4-D with shape based on `data_format`. For example, if
+// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height,
+// in_width, in_channels]` tensor.
+// filter_sizes: An integer vector representing the tensor shape of `filter`,
+// where `filter` is a 4-D
+// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor.
+// out_backprop: 4-D with shape based on `data_format`.
+// For example, if `data_format` is 'NHWC' then
+// out_backprop shape is `[batch, out_height, out_width, out_channels]`.
+// Gradients w.r.t. the output of the convolution.
+// strides: The stride of the sliding window for each dimension of the input
+// of the convolution.
+// padding: The type of padding algorithm to use.
+//
+// Returns 4-D with shape
+// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t.
+// the `filter` input of the convolution.
+func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DepthwiseConv2dNativeBackpropFilter",
+ Input: []tf.Input{
+ input, filter_sizes, out_backprop,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
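To make the shape contract above concrete, a hedged sketch (scope `s` as in the MatrixSolve example above; the placeholder shapes are illustrative):

```go
// 3x3 depthwise filter over 8 input channels with multiplier 1, stride 1 and
// SAME padding, so the output spatial dims match the input.
input := op.Placeholder(s, tf.Float, op.PlaceholderShape(tf.MakeShape(32, 64, 64, 8)))
outBackprop := op.Placeholder(s, tf.Float, op.PlaceholderShape(tf.MakeShape(32, 64, 64, 8)))
filterSizes := op.Const(s, []int32{3, 3, 8, 1})
filterGrad := op.DepthwiseConv2dNativeBackpropFilter(s, input, filterSizes, outBackprop,
	[]int64{1, 1, 1, 1}, "SAME") // result shape: [3, 3, 8, 1]
```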
// LRNGradAttr is an optional argument to LRNGrad.
type LRNGradAttr func(optionalAttr)
@@ -6236,6 +6046,79 @@ func Tan(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl.
+type ResourceSparseApplyFtrlAttr func(optionalAttr)
+
+// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var and accum tensors will be protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update relevant entries in '*var' according to the Ftrl-proximal scheme.
+//
+// That is, for the rows we have grad for, we update var, accum and linear as follows:
+// accum_new = accum + grad * grad
+// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
+// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
+// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
+// accum = accum_new
+//
+// Arguments:
+// var_: Should be from a Variable().
+// accum: Should be from a Variable().
+// linear: Should be from a Variable().
+// grad: The gradient.
+// indices: A vector of indices into the first dimension of var and accum.
+// lr: Scaling factor. Must be a scalar.
+// l1: L1 regularization. Must be a scalar.
+// l2: L2 regularization. Must be a scalar.
+// lr_power: Scaling factor. Must be a scalar.
+//
+// Returns the created operation.
+func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceSparseApplyFtrl",
+ Input: []tf.Input{
+ var_, accum, linear, grad, indices, lr, l1, l2, lr_power,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
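For readers who prefer display math, the per-row update in the comment above transcribes as follows (with η = `lr`, p = `lr_power`, λ₁ = `l1`, λ₂ = `l2`, g = `grad`, w = `var`, a = `accum`, ℓ = `linear`; this is a transcription of the comment, not an independent derivation):

```latex
\begin{aligned}
a_{\text{new}} &= a + g^{2} \\
\ell &\leftarrow \ell + g + \frac{a_{\text{new}}^{-p} - a^{-p}}{\eta}\, w \\
q &= \frac{1}{\eta\, a_{\text{new}}^{p}} + 2\lambda_{2} \\
w &\leftarrow \begin{cases}
  \bigl(\operatorname{sign}(\ell)\,\lambda_{1} - \ell\bigr) / q & \text{if } |\ell| > \lambda_{1} \\
  0 & \text{otherwise}
\end{cases} \\
a &\leftarrow a_{\text{new}}
\end{aligned}
```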
+// Returns which elements of x are Inf.
+//
+// @compatibility(numpy)
+// Equivalent to np.isinf
+// @end_compatibility
+func IsInf(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "IsInf",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the sum along sparse segments of a tensor divided by the sqrt of N.
//
// N is the size of the segment being reduced.
@@ -6918,6 +6801,170 @@ func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output,
return scope.AddOperation(opspec)
}
+// AvgPoolGradAttr is an optional argument to AvgPoolGrad.
+type AvgPoolGradAttr func(optionalAttr)
+
+// AvgPoolGradDataFormat sets the optional data_format attribute to value.
+//
+// value: Specify the data format of the input and output data. With the
+// default format "NHWC", the data is stored in the order of:
+// [batch, in_height, in_width, in_channels].
+// Alternatively, the format could be "NCHW", the data storage order of:
+// [batch, in_channels, in_height, in_width].
+// If not specified, defaults to "NHWC"
+func AvgPoolGradDataFormat(value string) AvgPoolGradAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// Computes gradients of the average pooling function.
+//
+// Arguments:
+// orig_input_shape: 1-D. Shape of the original input to `avg_pool`.
+// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t.
+// the output of `avg_pool`.
+// ksize: The size of the sliding window for each dimension of the input.
+// strides: The stride of the sliding window for each dimension of the input.
+// padding: The type of padding algorithm to use.
+//
+// Returns 4-D. Gradients w.r.t. the input of `avg_pool`.
+func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "AvgPoolGrad",
+ Input: []tf.Input{
+ orig_input_shape, grad,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// StageClearAttr is an optional argument to StageClear.
+type StageClearAttr func(optionalAttr)
+
+// StageClearCapacity sets the optional capacity attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func StageClearCapacity(value int64) StageClearAttr {
+ return func(m optionalAttr) {
+ m["capacity"] = value
+ }
+}
+
+// StageClearMemoryLimit sets the optional memory_limit attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func StageClearMemoryLimit(value int64) StageClearAttr {
+ return func(m optionalAttr) {
+ m["memory_limit"] = value
+ }
+}
+
+// StageClearContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func StageClearContainer(value string) StageClearAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// StageClearSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func StageClearSharedName(value string) StageClearAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// Op removes all elements in the underlying container.
+//
+// Returns the created operation.
+func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtypes": dtypes}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StageClear",
+
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits.
+type ComputeAccidentalHitsAttr func(optionalAttr)
+
+// ComputeAccidentalHitsSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Computes the ids of the positions in sampled_candidates that match true_labels.
+//
+// When doing log-odds NCE, the result of this op should be passed through a
+// SparseToDense op, then added to the logits of the sampled candidates. This has
+// the effect of 'removing' the sampled labels that match the true labels by
+// making the classifier sure that they are sampled labels.
+//
+// Arguments:
+// true_classes: The true_classes output of UnpackSparseLabels.
+// sampled_candidates: The sampled_candidates output of CandidateSampler.
+// num_true: Number of true labels per context.
+//
+// Returns:
+// indices: A vector of indices corresponding to rows of true_candidates.
+// ids: A vector of IDs of positions in sampled_candidates that match a true_label
+// for the row with the corresponding index in indices.
+// weights: A vector of the same length as indices and ids, in which each element
+// is -FLOAT_MAX.
+func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_true": num_true}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ComputeAccidentalHits",
+ Input: []tf.Input{
+ true_classes, sampled_candidates,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
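A hedged sketch of the masking pattern the comment above describes (scope `s` as before; `trueClasses` and `sampled` are assumed to come from a candidate-sampling op, and `batch`, `numSampled`, and `sampledLogits` are hypothetical names for the batch size, sample count, and [batch, numSampled] logits):

```go
indices, ids, weights := op.ComputeAccidentalHits(s, trueClasses, sampled, 1)

// Scatter the -FLT_MAX weights into a dense [batch, numSampled] mask and add
// it to the sampled logits, pushing accidental-hit logits toward -infinity.
coords := op.Pack(s, []tf.Output{indices, op.Cast(s, ids, tf.Int32)}, op.PackAxis(1))
shape := op.Const(s, []int32{int32(batch), int32(numSampled)})
mask := op.SparseToDense(s, coords, shape, weights, op.Const(s, float32(0)))
maskedLogits := op.Add(s, sampledLogits, mask)
```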
// CumsumAttr is an optional argument to Cumsum.
type CumsumAttr func(optionalAttr)
@@ -7314,79 +7361,6 @@ func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64,
return op.Output(0)
}
-// Generates values in an interval.
-//
-// A sequence of `num` evenly-spaced values are generated beginning at `start`.
-// If `num > 1`, the values in the sequence increase by `(stop - start) / (num - 1)`,
-// so that the last one is exactly `stop`.
-//
-// For example:
-//
-// ```
-// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0]
-// ```
-//
-// Arguments:
-// start: First entry in the range.
-// stop: Last entry in the range.
-// num: Number of values to generate.
-//
-// Returns 1-D. The generated values.
-func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "LinSpace",
- Input: []tf.Input{
- start, stop, num,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// DestroyResourceOpAttr is an optional argument to DestroyResourceOp.
-type DestroyResourceOpAttr func(optionalAttr)
-
-// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value.
-//
-// value: whether to ignore the error when the resource
-// doesn't exist.
-// If not specified, defaults to true
-func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr {
- return func(m optionalAttr) {
- m["ignore_lookup_error"] = value
- }
-}
-
-// Deletes the resource specified by the handle.
-//
-// All subsequent operations using the resource will result in a NotFound
-// error status.
-//
-// Arguments:
-// resource: handle to the resource to delete.
-//
-// Returns the created operation.
-func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DestroyResourceOp",
- Input: []tf.Input{
- resource,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
// Applies softmax to a batched N-D `SparseTensor`.
//
// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]`
@@ -7822,6 +7796,79 @@ func IFFT(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
+// Generates values in an interval.
+//
+// A sequence of `num` evenly-spaced values are generated beginning at `start`.
+// If `num > 1`, the values in the sequence increase by `(stop - start) / (num - 1)`,
+// so that the last one is exactly `stop`.
+//
+// For example:
+//
+// ```
+// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0]
+// ```
+//
+// Arguments:
+// start: First entry in the range.
+// stop: Last entry in the range.
+// num: Number of values to generate.
+//
+// Returns 1-D. The generated values.
+func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "LinSpace",
+ Input: []tf.Input{
+ start, stop, num,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
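Mirroring the doc example above in Go (scope `s` as before):

```go
start := op.Const(s, float32(10))
stop := op.Const(s, float32(12))
num := op.Const(s, int32(3))
vals := op.LinSpace(s, start, stop, num) // => [10.0 11.0 12.0]
```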
+// DestroyResourceOpAttr is an optional argument to DestroyResourceOp.
+type DestroyResourceOpAttr func(optionalAttr)
+
+// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value.
+//
+// value: whether to ignore the error when the resource
+// doesn't exist.
+// If not specified, defaults to true
+func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr {
+ return func(m optionalAttr) {
+ m["ignore_lookup_error"] = value
+ }
+}
+
+// Deletes the resource specified by the handle.
+//
+// All subsequent operations using the resource will result in a NotFound
+// error status.
+//
+// Arguments:
+// resource: handle to the resource to delete.
+//
+// Returns the created operation.
+func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DestroyResourceOp",
+ Input: []tf.Input{
+ resource,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// LRNAttr is an optional argument to LRN.
type LRNAttr func(optionalAttr)
@@ -8054,6 +8101,65 @@ func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...Resi
return op.Output(0)
}
+// Pads a tensor with zeros.
+//
+// This operation pads `input` with zeros according to the `paddings` you
+// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the
+// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
+// how many zeros to add before the contents of `input` in that dimension, and
+// `paddings[D, 1]` indicates how many zeros to add after the contents of `input`
+// in that dimension.
+//
+// The padded size of each dimension D of the output is:
+//
+// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
+//
+// For example:
+//
+// ```
+// # 't' is [[1, 1], [2, 2]]
+// # 'paddings' is [[1, 1], [2, 2]]
+// # rank of 't' is 2
+// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
+// [0, 0, 1, 1, 0, 0]
+// [0, 0, 2, 2, 0, 0]
+// [0, 0, 0, 0, 0, 0]]
+// ```
+func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Pad",
+ Input: []tf.Input{
+ input, paddings,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
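The same example as the comment above, expressed with this wrapper (scope `s` as before):

```go
t := op.Const(s, [][]int32{{1, 1}, {2, 2}})        // shape [2, 2]
paddings := op.Const(s, [][]int32{{1, 1}, {2, 2}}) // one [before, after] pair per dim
padded := op.Pad(s, t, paddings)                   // shape [1+2+1, 2+2+2] = [4, 6]
```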
+// Checks whether a resource handle-based variable has been initialized.
+//
+// Arguments:
+// resource: the input resource handle.
+//
+// Returns a scalar boolean which is true if the variable has been
+// initialized.
+func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "VarIsInitializedOp",
+ Input: []tf.Input{
+ resource,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform.
type StatelessRandomUniformAttr func(optionalAttr)
@@ -8098,6 +8204,38 @@ func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optio
return op.Output(0)
}
+// Makes its input available to the next iteration.
+//
+// Arguments:
+// data: The tensor to be made available to the next iteration.
+//
+// Returns The same tensor as `data`.
+func NextIteration(scope *Scope, data tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "NextIteration",
+ Input: []tf.Input{
+ data,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Output a fact about factorials.
+func Fact(scope *Scope) (fact tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Fact",
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// AngleAttr is an optional argument to Angle.
type AngleAttr func(optionalAttr)
@@ -8672,79 +8810,6 @@ func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output
return op.Output(0)
}
-// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl.
-type ResourceSparseApplyFtrlAttr func(optionalAttr)
-
-// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var and accum tensors will be protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update relevant entries in '*var' according to the Ftrl-proximal scheme.
-//
-// That is for rows we have grad for, we update var, accum and linear as follows:
-// accum_new = accum + grad * grad
-// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
-// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
-// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
-// accum = accum_new
-//
-// Arguments:
-// var_: Should be from a Variable().
-// accum: Should be from a Variable().
-// linear: Should be from a Variable().
-// grad: The gradient.
-// indices: A vector of indices into the first dimension of var and accum.
-// lr: Scaling factor. Must be a scalar.
-// l1: L1 regularization. Must be a scalar.
-// l2: L2 regularization. Must be a scalar.
-// lr_power: Scaling factor. Must be a scalar.
-//
-// Returns the created operation.
-func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceSparseApplyFtrl",
- Input: []tf.Input{
- var_, accum, linear, grad, indices, lr, l1, l2, lr_power,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// Returns which elements of x are Inf.
-//
-// @compatibility(numpy)
-// Equivalent to np.isinf
-// @end_compatibility
-func IsInf(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "IsInf",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp.
type ResourceSparseApplyRMSPropAttr func(optionalAttr)
@@ -8974,6 +9039,100 @@ func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_box
return op.Output(0), op.Output(1), op.Output(2)
}
+// Converts each string in the input Tensor to its hash mod by a number of buckets.
+//
+// The hash function is deterministic on the content of the string within the
+// process and will never change. However, it is not suitable for cryptography.
+// This function may be used when CPU time is scarce and inputs are trusted or
+// unimportant. There is a risk of adversaries constructing inputs that all hash
+// to the same bucket. To prevent this problem, use a strong hash function with
+// `tf.string_to_hash_bucket_strong`.
+//
+// Arguments:
+// input: The strings to assign a hash bucket.
+// num_buckets: The number of buckets.
+//
+// Returns A Tensor of the same shape as the input.
+func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_buckets": num_buckets}
+ opspec := tf.OpSpec{
+ Type: "StringToHashBucketFast",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
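A one-liner usage sketch (scope `s` as before):

```go
strs := op.Const(s, []string{"Hello", "TensorFlow"})
buckets := op.StringToHashBucketFast(s, strs, 1000) // int64 bucket ids in [0, 1000)
```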
+// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
+//
+// *NOTE*: `Maximum` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func Maximum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Maximum",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3.
+type TensorArrayGatherV3Attr func(optionalAttr)
+
+// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value.
+//
+// value: The expected shape of an element, if known. Used to
+// validate the shapes of TensorArray elements. If this shape is not
+// fully specified, gathering zero-size TensorArrays is an error.
+// If not specified, defaults to <unknown_rank:true >
+func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr {
+ return func(m optionalAttr) {
+ m["element_shape"] = value
+ }
+}
+
+// Gather specific elements from the TensorArray into output `value`.
+//
+// All elements selected by `indices` must have the same shape.
+//
+// Arguments:
+// handle: The handle to a TensorArray.
+// indices: The locations in the TensorArray from which to read tensor elements.
+// flow_in: A float scalar that enforces proper chaining of operations.
+// dtype: The type of the elem that is returned.
+//
+// Returns All of the elements in the TensorArray, concatenated along a new
+// axis (the new dimension 0).
+func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "TensorArrayGatherV3",
+ Input: []tf.Input{
+ handle, indices, flow_in,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Returns x / y element-wise for integer types.
//
// Truncation designates that negative numbers will round fractional quantities
@@ -9048,6 +9207,30 @@ func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and
return tensors
}
+// Creates a dataset that skips `count` elements from the `input_dataset`.
+//
+// Arguments:
+//
+// count: A scalar representing the number of elements from the `input_dataset`
+// that should be skipped. If count is -1, skips everything.
+//
+//
+func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "SkipDataset",
+ Input: []tf.Input{
+ input_dataset, count,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the maximum along segments of a tensor.
//
// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
@@ -9084,30 +9267,6 @@ func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.
return op.Output(0)
}
-// Creates a dataset that skips `count` elements from the `input_dataset`.
-//
-// Arguments:
-//
-// count: A scalar representing the number of elements from the `input_dataset`
-// that should be skipped. If count is -1, skips everything.
-//
-//
-func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
- opspec := tf.OpSpec{
- Type: "SkipDataset",
- Input: []tf.Input{
- input_dataset, count,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes hyperbolic tangent of `x` element-wise.
func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
@@ -9861,6 +10020,79 @@ func FFT(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
+// Transforms a serialized tensorflow.TensorProto proto into a Tensor.
+//
+// Arguments:
+// serialized: A scalar string containing a serialized TensorProto proto.
+// out_type: The type of the serialized tensor. The provided type must match the
+// type of the serialized tensor and no implicit conversion will take place.
+//
+// Returns A Tensor of type `out_type`.
+func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"out_type": out_type}
+ opspec := tf.OpSpec{
+ Type: "ParseTensor",
+ Input: []tf.Input{
+ serialized,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
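ParseTensor is the inverse of SerializeTensor (which this diff relocates elsewhere in the file); a round-trip sketch (scope `s` as before):

```go
t := op.Const(s, []float32{1, 2, 3})
serialized := op.SerializeTensor(s, t)              // scalar string TensorProto
restored := op.ParseTensor(s, serialized, tf.Float) // out_type must match t's dtype
```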
+// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax.
+type MaxPoolWithArgmaxAttr func(optionalAttr)
+
+// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value.
+// If not specified, defaults to DT_INT64
+func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr {
+ return func(m optionalAttr) {
+ m["Targmax"] = value
+ }
+}
+
+// Performs max pooling on the input and outputs both max values and indices.
+//
+// The indices in `argmax` are flattened, so that a maximum value at position
+// `[b, y, x, c]` becomes flattened index
+// `((b * height + y) * width + x) * channels + c`.
+//
+// The indices returned are always in `[0, height) x [0, width)` before flattening,
+// even if padding is involved and the mathematically correct answer is outside
+// (either negative or too large). This is a bug, but fixing it is difficult to do
+// in a safe backwards compatible way, especially due to flattening.
+//
+// Arguments:
+// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over.
+// ksize: The size of the window for each dimension of the input tensor.
+// strides: The stride of the sliding window for each dimension of the
+// input tensor.
+// padding: The type of padding algorithm to use.
+//
+// Returns:
+// output: The max pooled output tensor.
+// argmax: 4-D. The flattened indices of the max values chosen for each output.
+func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MaxPoolWithArgmax",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
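To make the flattening formula above concrete (scope `s` as before; the placeholder shape is illustrative):

```go
// 2x2, stride-2 pooling over a [1, 4, 4, 1] image. A max at (b=0, y=2, x=1, c=0)
// flattens to ((0*4 + 2)*4 + 1)*1 + 0 = 9.
img := op.Placeholder(s, tf.Float, op.PlaceholderShape(tf.MakeShape(1, 4, 4, 1)))
out, argmax := op.MaxPoolWithArgmax(s, img,
	[]int64{1, 2, 2, 1}, []int64{1, 2, 2, 1}, "VALID")
```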
// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA.
type ResourceSparseApplyAdagradDAAttr func(optionalAttr)
@@ -11004,73 +11236,6 @@ func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, padd
return op.Output(0)
}
-// Transforms a Tensor into a serialized TensorProto proto.
-//
-// Arguments:
-// tensor: A Tensor of type `T`.
-//
-// Returns A serialized TensorProto proto of the input tensor.
-func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SerializeTensor",
- Input: []tf.Input{
- tensor,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// MatrixSolveAttr is an optional argument to MatrixSolve.
-type MatrixSolveAttr func(optionalAttr)
-
-// MatrixSolveAdjoint sets the optional adjoint attribute to value.
-//
-// value: Boolean indicating whether to solve with `matrix` or its (block-wise)
-// adjoint.
-// If not specified, defaults to false
-func MatrixSolveAdjoint(value bool) MatrixSolveAttr {
- return func(m optionalAttr) {
- m["adjoint"] = value
- }
-}
-
-// Solves systems of linear equations.
-//
-// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is
-// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix
-// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
-// If `adjoint` is `True` then each output matrix satisfies
-// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`.
-//
-// Arguments:
-// matrix: Shape is `[..., M, M]`.
-// rhs: Shape is `[..., M, K]`.
-//
-// Returns Shape is `[..., M, K]`.
-func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MatrixSolve",
- Input: []tf.Input{
- matrix, rhs,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Inverse 3D fast Fourier transform.
//
// Computes the inverse 3-dimensional discrete Fourier transform over the
@@ -12025,6 +12190,46 @@ func AddManySparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_va
return op.Output(0)
}
+// Concatenates tensors along one dimension.
+//
+// Arguments:
+// values: List of `N` Tensors to concatenate. Their ranks and types must match,
+// and their sizes must match in all dimensions except `concat_dim`.
+// axis: 0-D. The dimension along which to concatenate. Must be in the
+// range [-rank(values), rank(values)).
+//
+// Returns A `Tensor` with the concatenation of values stacked along the
+// `concat_dim` dimension. This tensor's shape matches that of `values` except
+// in `concat_dim` where it has the sum of the sizes.
+func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ConcatV2",
+ Input: []tf.Input{
+ tf.OutputList(values), axis,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
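A small usage sketch (scope `s` as before):

```go
a := op.Const(s, [][]float32{{1, 2}})
b := op.Const(s, [][]float32{{3, 4}})
axis := op.Const(s, int32(0))
c := op.ConcatV2(s, []tf.Output{a, b}, axis) // shape [2, 2]: [[1 2] [3 4]]
```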
+// Reads and outputs the entire contents of the input filename.
+func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ReadFile",
+ Input: []tf.Input{
+ filename,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// MinAttr is an optional argument to Min.
type MinAttr func(optionalAttr)
@@ -12088,76 +12293,6 @@ func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
return op.Output(0)
}
-// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter.
-type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr)
-
-// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value.
-//
-// value: Specify the data format of the input and output data. With the
-// default format "NHWC", the data is stored in the order of:
-// [batch, height, width, channels].
-// Alternatively, the format could be "NCHW", the data storage order of:
-// [batch, channels, height, width].
-// If not specified, defaults to "NHWC"
-func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value.
-//
-// value: 1-D tensor of length 4. The dilation factor for each dimension of
-// `input`. If set to k > 1, there will be k-1 skipped cells between each filter
-// element on that dimension. The dimension order is determined by the value of
-// `data_format`, see above for details. Dilations in the batch and depth
-// dimensions must be 1.
-// If not specified, defaults to <i:1 i:1 i:1 i:1 >
-func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr {
- return func(m optionalAttr) {
- m["dilations"] = value
- }
-}
-
-// Computes the gradients of depthwise convolution with respect to the filter.
-//
-// Arguments:
-// input: 4-D with shape based on `data_format`. For example, if
-// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height,
-// in_width, in_channels]` tensor.
-// filter_sizes: An integer vector representing the tensor shape of `filter`,
-// where `filter` is a 4-D
-// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor.
-// out_backprop: 4-D with shape based on `data_format`.
-// For example, if `data_format` is 'NHWC' then
-// out_backprop shape is `[batch, out_height, out_width, out_channels]`.
-// Gradients w.r.t. the output of the convolution.
-// strides: The stride of the sliding window for each dimension of the input
-// of the convolution.
-// padding: The type of padding algorithm to use.
-//
-// Returns 4-D with shape
-// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t.
-// the `filter` input of the convolution.
-func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DepthwiseConv2dNativeBackpropFilter",
- Input: []tf.Input{
- input, filter_sizes, out_backprop,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes sigmoid of `x` element-wise.
//
// Specifically, `y = 1 / (1 + exp(-x))`.
@@ -12888,6 +13023,140 @@ func TensorArrayV2(scope *Scope, size tf.Output, dtype tf.DataType, optional ...
return op.Output(0)
}
+// DecodeCSVAttr is an optional argument to DecodeCSV.
+type DecodeCSVAttr func(optionalAttr)
+
+// DecodeCSVFieldDelim sets the optional field_delim attribute to value.
+//
+// value: char delimiter to separate fields in a record.
+// If not specified, defaults to ","
+func DecodeCSVFieldDelim(value string) DecodeCSVAttr {
+ return func(m optionalAttr) {
+ m["field_delim"] = value
+ }
+}
+
+// DecodeCSVUseQuoteDelim sets the optional use_quote_delim attribute to value.
+//
+// value: If false, treats double quotation marks as regular
+// characters inside of the string fields (ignoring RFC 4180, Section 2,
+// Bullet 5).
+// If not specified, defaults to true
+func DecodeCSVUseQuoteDelim(value bool) DecodeCSVAttr {
+ return func(m optionalAttr) {
+ m["use_quote_delim"] = value
+ }
+}
+
+// DecodeCSVNaValue sets the optional na_value attribute to value.
+//
+// value: Additional string to recognize as NA/NaN.
+// If not specified, defaults to ""
+func DecodeCSVNaValue(value string) DecodeCSVAttr {
+ return func(m optionalAttr) {
+ m["na_value"] = value
+ }
+}
+
+// Convert CSV records to tensors. Each column maps to one tensor.
+//
+// RFC 4180 format is expected for the CSV records.
+// (https://tools.ietf.org/html/rfc4180)
+// Note that leading and trailing spaces are allowed in int and float fields.
+//
+// Arguments:
+// records: Each string is a record/row in the csv and all records should have
+// the same format.
+// record_defaults: One tensor per column of the input record, with either a
+// scalar default value for that column or empty if the column is required.
+//
+// Returns Each tensor will have the same shape as records.
+func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DecodeCSV",
+ Input: []tf.Input{
+ records, tf.OutputList(record_defaults),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
+ scope.UpdateErr("DecodeCSV", err)
+ return
+ }
+ return output
+}
+
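A sketch of the records/record_defaults contract (scope `s` as before; each default here is a one-element tensor, and an empty tensor would mark the column as required):

```go
records := op.Const(s, []string{"1,foo", "2,", "3,baz"})
defaults := []tf.Output{
	op.Const(s, []int32{0}),    // column 0: int32, defaults to 0
	op.Const(s, []string{"?"}), // column 1: string; "2," falls back to "?"
}
cols := op.DecodeCSV(s, records, defaults) // cols[0] int32, cols[1] string; shape [3]
```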
+// MapClearAttr is an optional argument to MapClear.
+type MapClearAttr func(optionalAttr)
+
+// MapClearCapacity sets the optional capacity attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func MapClearCapacity(value int64) MapClearAttr {
+ return func(m optionalAttr) {
+ m["capacity"] = value
+ }
+}
+
+// MapClearMemoryLimit sets the optional memory_limit attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func MapClearMemoryLimit(value int64) MapClearAttr {
+ return func(m optionalAttr) {
+ m["memory_limit"] = value
+ }
+}
+
+// MapClearContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func MapClearContainer(value string) MapClearAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// MapClearSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func MapClearSharedName(value string) MapClearAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// Op removes all elements in the underlying container.
+//
+// Returns the created operation.
+func MapClear(scope *Scope, dtypes []tf.DataType, optional ...MapClearAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtypes": dtypes}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MapClear",
+
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// ThreadUnsafeUnigramCandidateSamplerAttr is an optional argument to ThreadUnsafeUnigramCandidateSampler.
type ThreadUnsafeUnigramCandidateSamplerAttr func(optionalAttr)
@@ -13007,78 +13276,6 @@ func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output
return op.Output(0)
}
-// Deprecated. Use TensorArrayReadV3
-//
-// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3
-func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- opspec := tf.OpSpec{
- Type: "TensorArrayReadV2",
- Input: []tf.Input{
- handle, index, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Does nothing. Serves as a control trigger for scheduling.
-//
-// Only useful as a placeholder for control edges.
-//
-// Returns the created operation.
-func ControlTrigger(scope *Scope) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ControlTrigger",
- }
- return scope.AddOperation(opspec)
-}
-
-// Batch normalization.
-//
-// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization()
-//
-// This op is deprecated. Prefer `tf.nn.batch_normalization`.
-//
-// Arguments:
-// t: A 4D input Tensor.
-// m: A 1D mean Tensor with size matching the last dimension of t.
-// This is the first output from tf.nn.moments,
-// or a saved moving average thereof.
-// v: A 1D variance Tensor with size matching the last dimension of t.
-// This is the second output from tf.nn.moments,
-// or a saved moving average thereof.
-// beta: A 1D beta Tensor with size matching the last dimension of t.
-// An offset to be added to the normalized tensor.
-// gamma: A 1D gamma Tensor with size matching the last dimension of t.
-// If "scale_after_normalization" is true, this tensor will be multiplied
-// with the normalized tensor.
-// variance_epsilon: A small float number to avoid dividing by 0.
-// scale_after_normalization: A bool indicating whether the resulted tensor
-// needs to be multiplied with gamma.
-func BatchNormWithGlobalNormalization(scope *Scope, t tf.Output, m tf.Output, v tf.Output, beta tf.Output, gamma tf.Output, variance_epsilon float32, scale_after_normalization bool) (result tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization}
- opspec := tf.OpSpec{
- Type: "BatchNormWithGlobalNormalization",
- Input: []tf.Input{
- t, m, v, beta, gamma,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// MutableDenseHashTableV2Attr is an optional argument to MutableDenseHashTableV2.
type MutableDenseHashTableV2Attr func(optionalAttr)
@@ -13375,65 +13572,6 @@ func TextLineDataset(scope *Scope, filenames tf.Output, compression_type tf.Outp
return op.Output(0)
}
-// Checks whether a resource handle-based variable has been initialized.
-//
-// Arguments:
-// resource: the input resource handle.
-//
-// Returns a scalar boolean which is true if the variable has been
-// initialized.
-func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "VarIsInitializedOp",
- Input: []tf.Input{
- resource,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Pads a tensor with zeros.
-//
-// This operation pads `input` with zeros according to the `paddings` you
-// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the
-// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
-// how many zeros to add before the contents of `input` in that dimension, and
-// `paddings[D, 1]` indicates how many zeros to add after the contents of `input`
-// in that dimension.
-//
-// The padded size of each dimension D of the output is:
-//
-// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
-//
-// For example:
-//
-// ```
-// # 't' is [[1, 1], [2, 2]]
-// # 'paddings' is [[1, 1], [2, 2]]
-// # rank of 't' is 2
-// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
-// [0, 0, 1, 1, 0, 0]
-// [0, 0, 2, 2, 0, 0]
-// [0, 0, 0, 0, 0, 0]]
-// ```
-func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Pad",
- Input: []tf.Input{
- input, paddings,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes gradients for SparseSegmentMean.
//
// Returns tensor "output" with same shape as grad, except for dimension 0 whose
@@ -14876,6 +15014,101 @@ func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) {
return op.Output(0)
}
+// MatrixSolveLsAttr is an optional argument to MatrixSolveLs.
+type MatrixSolveLsAttr func(optionalAttr)
+
+// MatrixSolveLsFast sets the optional fast attribute to value.
+// If not specified, defaults to true
+func MatrixSolveLsFast(value bool) MatrixSolveLsAttr {
+ return func(m optionalAttr) {
+ m["fast"] = value
+ }
+}
+
+// Solves one or more linear least-squares problems.
+//
+// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same
+// type as `matrix` and shape `[..., M, K]`.
+// The output is a tensor of shape `[..., N, K]` where each output matrix solves
+// each of the equations
+// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]`
+// in the least squares sense.
+//
+// We use the following notation for (complex) matrix and right-hand sides
+// in the batch:
+//
+// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\),
+// `rhs`=\\(B \in \mathbb{C}^{m \times k}\\),
+// `output`=\\(X \in \mathbb{C}^{n \times k}\\),
+// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\).
+//
+// If `fast` is `True`, then the solution is computed by solving the normal
+// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
+// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares
+// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 +
+// \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
+// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
+// minimum-norm solution to the under-determined linear system, i.e.
+// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\),
+// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable
+// when \\(A\\) is numerically full rank and has a condition number
+// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is
+// sufficiently large.
+//
+// If `fast` is `False` an algorithm based on the numerically robust complete
+// orthogonal decomposition is used. This computes the minimum-norm
+// least-squares solution, even when \\(A\\) is rank deficient. This path is
+// typically 6-7 times slower than the fast path. If `fast` is `False` then
+// `l2_regularizer` is ignored.
+//
+// Arguments:
+// matrix: Shape is `[..., M, N]`.
+// rhs: Shape is `[..., M, K]`.
+// l2_regularizer: Scalar tensor.
+//
+// @compatibility(numpy)
+// Equivalent to np.linalg.lstsq
+// @end_compatibility
+//
+// Returns Shape is `[..., N, K]`.
+func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MatrixSolveLs",
+ Input: []tf.Input{
+ matrix, rhs, l2_regularizer,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
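A least-squares line fit using the wrapper above (scope `s` as before; note that `l2_regularizer` is a float64 scalar):

```go
// Fit y ≈ c0 + c1*x through (1,1), (2,2), (3,2) in the least-squares sense.
design := op.Const(s, [][]float32{{1, 1}, {1, 2}, {1, 3}}) // [3, 2] design matrix
obs := op.Const(s, [][]float32{{1}, {2}, {2}})             // [3, 1] observations
l2 := op.Const(s, float64(0))                              // no regularization
coeffs := op.MatrixSolveLs(s, design, obs, l2)             // [2, 1] = [c0, c1]
```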
+// Elementwise computes the bitwise OR of `x` and `y`.
+//
+// The result will have those bits set that are set in `x`, `y`, or both. The
+// computation is performed on the underlying representations of `x` and `y`.
+func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BitwiseOr",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// SparseToSparseSetOperationAttr is an optional argument to SparseToSparseSetOperation.
type SparseToSparseSetOperationAttr func(optionalAttr)
@@ -15174,6 +15407,52 @@ func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype
return op.Output(0), op.Output(1), op.Output(2)
}
+// MaxPoolAttr is an optional argument to MaxPool.
+type MaxPoolAttr func(optionalAttr)
+
+// MaxPoolDataFormat sets the optional data_format attribute to value.
+//
+// value: Specify the data format of the input and output data. With the
+// default format "NHWC", the data is stored in the order of:
+// [batch, in_height, in_width, in_channels].
+// Alternatively, the format could be "NCHW", the data storage order of:
+// [batch, in_channels, in_height, in_width].
+// If not specified, defaults to "NHWC"
+func MaxPoolDataFormat(value string) MaxPoolAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// Performs max pooling on the input.
+//
+// Arguments:
+// input: 4-D input to pool over.
+// ksize: The size of the window for each dimension of the input tensor.
+// strides: The stride of the sliding window for each dimension of the
+// input tensor.
+// padding: The type of padding algorithm to use.
+//
+// Returns The max pooled output tensor.
+func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MaxPool",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Says whether the targets are in the top `K` predictions.
//
// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
@@ -16313,83 +16592,6 @@ func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// MfccAttr is an optional argument to Mfcc.
-type MfccAttr func(optionalAttr)
-
-// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value.
-//
-// value: The highest frequency to use when calculating the
-// cepstrum.
-// If not specified, defaults to 4000
-func MfccUpperFrequencyLimit(value float32) MfccAttr {
- return func(m optionalAttr) {
- m["upper_frequency_limit"] = value
- }
-}
-
-// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value.
-//
-// value: The lowest frequency to use when calculating the
-// cepstrum.
-// If not specified, defaults to 20
-func MfccLowerFrequencyLimit(value float32) MfccAttr {
- return func(m optionalAttr) {
- m["lower_frequency_limit"] = value
- }
-}
-
-// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value.
-//
-// value: Resolution of the Mel bank used internally.
-// If not specified, defaults to 40
-func MfccFilterbankChannelCount(value int64) MfccAttr {
- return func(m optionalAttr) {
- m["filterbank_channel_count"] = value
- }
-}
-
-// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value.
-//
-// value: How many output channels to produce per time slice.
-// If not specified, defaults to 13
-func MfccDctCoefficientCount(value int64) MfccAttr {
- return func(m optionalAttr) {
- m["dct_coefficient_count"] = value
- }
-}
-
-// Transforms a spectrogram into a form that's useful for speech recognition.
-//
-// Mel Frequency Cepstral Coefficients are a way of representing audio data that's
-// been effective as an input feature for machine learning. They are created by
-// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the
-// higher frequencies that are less significant to the human ear. They have a long
-// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
-// is a good resource to learn more.
-//
-// Arguments:
-// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared
-// set to true.
-// sample_rate: How many samples per second the source audio used.
-func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Mfcc",
- Input: []tf.Input{
- spectrogram, sample_rate,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns the element-wise sum of a list of tensors.
//
// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
@@ -17022,87 +17224,129 @@ func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_uppe
return op.Output(0)
}
-// Computes acos of x element-wise.
-func Acos(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Acos",
- Input: []tf.Input{
- x,
- },
+// QuantizedMatMulAttr is an optional argument to QuantizedMatMul.
+type QuantizedMatMulAttr func(optionalAttr)
+
+// QuantizedMatMulToutput sets the optional Toutput attribute to value.
+// If not specified, defaults to DT_QINT32
+func QuantizedMatMulToutput(value tf.DataType) QuantizedMatMulAttr {
+ return func(m optionalAttr) {
+ m["Toutput"] = value
}
- op := scope.AddOperation(opspec)
- return op.Output(0)
}
-// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax.
-type MaxPoolWithArgmaxAttr func(optionalAttr)
+// QuantizedMatMulTransposeA sets the optional transpose_a attribute to value.
+//
+// value: If true, `a` is transposed before multiplication.
+// If not specified, defaults to false
+func QuantizedMatMulTransposeA(value bool) QuantizedMatMulAttr {
+ return func(m optionalAttr) {
+ m["transpose_a"] = value
+ }
+}
-// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value.
-// If not specified, defaults to DT_INT64
-func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr {
+// QuantizedMatMulTransposeB sets the optional transpose_b attribute to value.
+//
+// value: If true, `b` is transposed before multiplication.
+// If not specified, defaults to false
+func QuantizedMatMulTransposeB(value bool) QuantizedMatMulAttr {
return func(m optionalAttr) {
- m["Targmax"] = value
+ m["transpose_b"] = value
}
}
-// Performs max pooling on the input and outputs both max values and indices.
+// QuantizedMatMulTactivation sets the optional Tactivation attribute to value.
//
-// The indices in `argmax` are flattened, so that a maximum value at position
-// `[b, y, x, c]` becomes flattened index
-// `((b * height + y) * width + x) * channels + c`.
+// value: The type of output produced by the activation function
+// following this operation.
+// If not specified, defaults to DT_QUINT8
+func QuantizedMatMulTactivation(value tf.DataType) QuantizedMatMulAttr {
+ return func(m optionalAttr) {
+ m["Tactivation"] = value
+ }
+}
+
+// Perform a quantized matrix multiplication of `a` by the matrix `b`.
//
-// The indices returned are always in `[0, height) x [0, width)` before flattening,
-// even if padding is involved and the mathematically correct answer is outside
-// (either negative or too large). This is a bug, but fixing it is difficult to do
-// in a safe backwards compatible way, especially due to flattening.
+// The inputs must be two-dimensional matrices and the inner dimension of
+// `a` (after being transposed if `transpose_a` is non-zero) must match the
+// outer dimension of `b` (after being transposed if `transpose_b` is
+// non-zero).
//
// Arguments:
-// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over.
-// ksize: The size of the window for each dimension of the input tensor.
-// strides: The stride of the sliding window for each dimension of the
-// input tensor.
-// padding: The type of padding algorithm to use.
+// a: Must be a two-dimensional tensor.
+// b: Must be a two-dimensional tensor.
+// min_a: The float value that the lowest quantized `a` value represents.
+// max_a: The float value that the highest quantized `a` value represents.
+// min_b: The float value that the lowest quantized `b` value represents.
+// max_b: The float value that the highest quantized `b` value represents.
//
-// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output.
-func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) {
+// Returns The float value that the lowest quantized output value represents. The float value that the highest quantized output value represents.
+func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
- Type: "MaxPoolWithArgmax",
+ Type: "QuantizedMatMul",
Input: []tf.Input{
- input,
+ a, b, min_a, max_a, min_b, max_b,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
+ return op.Output(0), op.Output(1), op.Output(2)
}
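A minimal sketch of driving this wrapper from Go, assuming `a` and `b` are
existing DT_QUINT8 outputs and the four min/max arguments are scalar float32
outputs already present in the graph; only the wrapper and attribute helpers
defined above are used.

```
package opsketch

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// quantizedMatMul multiplies two quantized matrices, accumulating into qint32.
func quantizedMatMul(s *op.Scope, a, b, minA, maxA, minB, maxB tf.Output) (out, minOut, maxOut tf.Output) {
	return op.QuantizedMatMul(s, a, b, minA, maxA, minB, maxB,
		op.QuantizedMatMulToutput(tf.Qint32), // qint32 is also the default
		op.QuantizedMatMulTransposeB(true))   // multiply by b transposed
}
```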
-// Transforms a serialized tensorflow.TensorProto proto into a Tensor.
+// Does nothing. Serves as a control trigger for scheduling.
//
-// Arguments:
-// serialized: A scalar string containing a serialized TensorProto proto.
-// out_type: The type of the serialized tensor. The provided type must match the
-// type of the serialized tensor and no implicit conversion will take place.
+// Only useful as a placeholder for control edges.
//
-// Returns A Tensor of type `out_type`.
-func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) {
+// Returns the created operation.
+func ControlTrigger(scope *Scope) (o *tf.Operation) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"out_type": out_type}
opspec := tf.OpSpec{
- Type: "ParseTensor",
+ Type: "ControlTrigger",
+ }
+ return scope.AddOperation(opspec)
+}
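A sketch of what a placeholder for control edges looks like from the Go API:
the returned *tf.Operation is handed to a control-dependency helper so the
downstream op runs only after the trigger. Scope.WithControlDependencies is
assumed to exist in this revision of the Go API; the trigger itself needs only
the wrapper above.

```
// Same package and imports as the quantizedMatMul sketch above.

// gatedConst returns a constant produced only after trigger has run.
// Assumption: this Go API revision provides Scope.WithControlDependencies.
func gatedConst(s *op.Scope) tf.Output {
	trigger := op.ControlTrigger(s)
	return op.Const(s.WithControlDependencies(trigger), int32(1))
}
```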
+
+// Batch normalization.
+//
+// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization()
+//
+// This op is deprecated. Prefer `tf.nn.batch_normalization`.
+//
+// Arguments:
+// t: A 4D input Tensor.
+// m: A 1D mean Tensor with size matching the last dimension of t.
+// This is the first output from tf.nn.moments,
+// or a saved moving average thereof.
+// v: A 1D variance Tensor with size matching the last dimension of t.
+// This is the second output from tf.nn.moments,
+// or a saved moving average thereof.
+// beta: A 1D beta Tensor with size matching the last dimension of t.
+// An offset to be added to the normalized tensor.
+// gamma: A 1D gamma Tensor with size matching the last dimension of t.
+// If "scale_after_normalization" is true, this tensor will be multiplied
+// with the normalized tensor.
+// variance_epsilon: A small float number to avoid dividing by 0.
+// scale_after_normalization: A bool indicating whether the resulting tensor
+// needs to be multiplied by gamma.
+func BatchNormWithGlobalNormalization(scope *Scope, t tf.Output, m tf.Output, v tf.Output, beta tf.Output, gamma tf.Output, variance_epsilon float32, scale_after_normalization bool) (result tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization}
+ opspec := tf.OpSpec{
+ Type: "BatchNormWithGlobalNormalization",
Input: []tf.Input{
- serialized,
+ t, m, v, beta, gamma,
},
Attrs: attrs,
}
@@ -17110,113 +17354,95 @@ func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (outp
return op.Output(0)
}
-// MapClearAttr is an optional argument to MapClear.
-type MapClearAttr func(optionalAttr)
-
-// MapClearCapacity sets the optional capacity attribute to value.
-// If not specified, defaults to 0
+// Deprecated. Use TensorArrayReadV3
//
-// REQUIRES: value >= 0
-func MapClearCapacity(value int64) MapClearAttr {
- return func(m optionalAttr) {
- m["capacity"] = value
+// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3
+func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) {
+ if scope.Err() != nil {
+ return
}
-}
-
-// MapClearMemoryLimit sets the optional memory_limit attribute to value.
-// If not specified, defaults to 0
-//
-// REQUIRES: value >= 0
-func MapClearMemoryLimit(value int64) MapClearAttr {
- return func(m optionalAttr) {
- m["memory_limit"] = value
+ attrs := map[string]interface{}{"dtype": dtype}
+ opspec := tf.OpSpec{
+ Type: "TensorArrayReadV2",
+ Input: []tf.Input{
+ handle, index, flow_in,
+ },
+ Attrs: attrs,
}
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
-// MapClearContainer sets the optional container attribute to value.
-// If not specified, defaults to ""
-func MapClearContainer(value string) MapClearAttr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
+// QuantizedMulAttr is an optional argument to QuantizedMul.
+type QuantizedMulAttr func(optionalAttr)
-// MapClearSharedName sets the optional shared_name attribute to value.
-// If not specified, defaults to ""
-func MapClearSharedName(value string) MapClearAttr {
+// QuantizedMulToutput sets the optional Toutput attribute to value.
+// If not specified, defaults to DT_QINT32
+func QuantizedMulToutput(value tf.DataType) QuantizedMulAttr {
return func(m optionalAttr) {
- m["shared_name"] = value
+ m["Toutput"] = value
}
}
-// Op removes all elements in the underlying container.
+// Returns x * y element-wise, working on quantized buffers.
//
-// Returns the created operation.
-func MapClear(scope *Scope, dtypes []tf.DataType, optional ...MapClearAttr) (o *tf.Operation) {
+// Arguments:
+//
+//
+// min_x: The float value that the lowest quantized `x` value represents.
+// max_x: The float value that the highest quantized `x` value represents.
+// min_y: The float value that the lowest quantized `y` value represents.
+// max_y: The float value that the highest quantized `y` value represents.
+//
+// Returns The float value that the lowest quantized output value represents. The float value that the highest quantized output value represents.
+//
+// *NOTE*: `QuantizedMul` supports limited forms of broadcasting. More about
+// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedMulAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"dtypes": dtypes}
+ attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
- Type: "MapClear",
-
+ Type: "QuantizedMul",
+ Input: []tf.Input{
+ x, y, min_x, max_x, min_y, max_y,
+ },
Attrs: attrs,
}
- return scope.AddOperation(opspec)
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
}
-// DecodeCSVAttr is an optional argument to DecodeCSV.
-type DecodeCSVAttr func(optionalAttr)
+// QuantizedAddAttr is an optional argument to QuantizedAdd.
+type QuantizedAddAttr func(optionalAttr)
-// DecodeCSVFieldDelim sets the optional field_delim attribute to value.
-//
-// value: char delimiter to separate fields in a record.
-// If not specified, defaults to ","
-func DecodeCSVFieldDelim(value string) DecodeCSVAttr {
+// QuantizedAddToutput sets the optional Toutput attribute to value.
+// If not specified, defaults to DT_QINT32
+func QuantizedAddToutput(value tf.DataType) QuantizedAddAttr {
return func(m optionalAttr) {
- m["field_delim"] = value
+ m["Toutput"] = value
}
}
-// DecodeCSVUseQuoteDelim sets the optional use_quote_delim attribute to value.
+// Returns x + y element-wise, working on quantized buffers.
//
-// value: If false, treats double quotation marks as regular
-// characters inside of the string fields (ignoring RFC 4180, Section 2,
-// Bullet 5).
-// If not specified, defaults to true
-func DecodeCSVUseQuoteDelim(value bool) DecodeCSVAttr {
- return func(m optionalAttr) {
- m["use_quote_delim"] = value
- }
-}
-
-// DecodeCSVNaValue sets the optional na_value attribute to value.
+// Arguments:
//
-// value: Additional string to recognize as NA/NaN.
-// If not specified, defaults to ""
-func DecodeCSVNaValue(value string) DecodeCSVAttr {
- return func(m optionalAttr) {
- m["na_value"] = value
- }
-}
-
-// Convert CSV records to tensors. Each column maps to one tensor.
//
-// RFC 4180 format is expected for the CSV records.
-// (https://tools.ietf.org/html/rfc4180)
-// Note that we allow leading and trailing spaces with int or float field.
+// min_x: The float value that the lowest quantized `x` value represents.
+// max_x: The float value that the highest quantized `x` value represents.
+// min_y: The float value that the lowest quantized `y` value represents.
+// max_y: The float value that the highest quantized `y` value represents.
//
-// Arguments:
-// records: Each string is a record/row in the csv and all records should have
-// the same format.
-// record_defaults: One tensor per column of the input record, with either a
-// scalar default value for that column or empty if the column is required.
+// Returns The float value that the lowest quantized output value represents. The float value that the highest quantized output value represents.
//
-// Returns Each tensor will have the same shape as records.
-func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) {
+// *NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about
+// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedAddAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) {
if scope.Err() != nil {
return
}
@@ -17225,84 +17451,117 @@ func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, opt
a(attrs)
}
opspec := tf.OpSpec{
- Type: "DecodeCSV",
+ Type: "QuantizedAdd",
Input: []tf.Input{
- records, tf.OutputList(record_defaults),
+ x, y, min_x, max_x, min_y, max_y,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
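QuantizedAdd and QuantizedMul take identical inputs, so one sketch covers
both; swap the wrapper and its attribute constructor to get the other. As
before, the min/max arguments are assumed to be scalar float32 outputs.

```
// Same package and imports as the quantizedMatMul sketch above.

// quantizedAdd adds two quantized tensors elementwise; op.QuantizedMul is
// called the same way, with op.QuantizedMulToutput as the attribute helper.
func quantizedAdd(s *op.Scope, x, y, minX, maxX, minY, maxY tf.Output) (z, minZ, maxZ tf.Output) {
	return op.QuantizedAdd(s, x, y, minX, maxX, minY, maxY,
		op.QuantizedAddToutput(tf.Qint32)) // the default, spelled out
}
```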
+// MfccAttr is an optional argument to Mfcc.
+type MfccAttr func(optionalAttr)
+
+// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value.
+//
+// value: The highest frequency to use when calculating the
+// cepstrum.
+// If not specified, defaults to 4000
+func MfccUpperFrequencyLimit(value float32) MfccAttr {
+ return func(m optionalAttr) {
+ m["upper_frequency_limit"] = value
}
- var idx int
- var err error
- if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
- scope.UpdateErr("DecodeCSV", err)
- return
+}
+
+// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value.
+//
+// value: The lowest frequency to use when calculating the
+// cepstrum.
+// If not specified, defaults to 20
+func MfccLowerFrequencyLimit(value float32) MfccAttr {
+ return func(m optionalAttr) {
+ m["lower_frequency_limit"] = value
}
- return output
}
-// Returns the rank of a tensor.
+// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value.
//
-// This operation returns an integer representing the rank of `input`.
+// value: Resolution of the Mel bank used internally.
+// If not specified, defaults to 40
+func MfccFilterbankChannelCount(value int64) MfccAttr {
+ return func(m optionalAttr) {
+ m["filterbank_channel_count"] = value
+ }
+}
+
+// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value.
//
-// For example:
+// value: How many output channels to produce per time slice.
+// If not specified, defaults to 13
+func MfccDctCoefficientCount(value int64) MfccAttr {
+ return func(m optionalAttr) {
+ m["dct_coefficient_count"] = value
+ }
+}
+
+// Transforms a spectrogram into a form that's useful for speech recognition.
//
-// ```
-// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
-// # shape of tensor 't' is [2, 2, 3]
-// rank(t) ==> 3
-// ```
+// Mel Frequency Cepstral Coefficients are a way of representing audio data
+// that has proven effective as an input feature for machine learning. They are
+// created by taking the spectrum of a spectrogram (a 'cepstrum'), and
+// discarding some of the higher frequencies that are less significant to the
+// human ear. They have a long history in the speech recognition world, and
+// https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
+// is a good resource to learn more.
//
-// **Note**: The rank of a tensor is not the same as the rank of a matrix. The rank
-// of a tensor is the number of indices required to uniquely select each element
-// of the tensor. Rank is also known as "order", "degree", or "ndims."
-func Rank(scope *Scope, input tf.Output) (output tf.Output) {
+// Arguments:
+// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared
+// set to true.
+// sample_rate: The number of samples per second in the source audio.
+func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "Rank",
+ Type: "Mfcc",
Input: []tf.Input{
- input,
+ spectrogram, sample_rate,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
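A sketch of building the MFCC node with each optional attribute pinned to its
documented default; `spectrogram` is assumed to come from an upstream
spectrogram op with magnitude_squared set to true, and `sampleRate` is assumed
to be a scalar int32 output.

```
// Same package and imports as the quantizedMatMul sketch above.

// mfccFeatures turns a magnitude-squared spectrogram into MFCC features.
func mfccFeatures(s *op.Scope, spectrogram, sampleRate tf.Output) tf.Output {
	return op.Mfcc(s, spectrogram, sampleRate,
		op.MfccLowerFrequencyLimit(20),   // documented defaults,
		op.MfccUpperFrequencyLimit(4000), // spelled out for clarity
		op.MfccFilterbankChannelCount(40),
		op.MfccDctCoefficientCount(13))
}
```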
-// Output a fact about factorials.
-func Fact(scope *Scope) (fact tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Fact",
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Makes its input available to the next iteration.
+// Computes a range that covers the actual values present in a quantized tensor.
+//
+// Given a quantized tensor described by (input, input_min, input_max), this op
+// outputs a range that covers the values actually present in that tensor. It is
+// typically used to produce the requested_output_min and requested_output_max
+// for Requantize.
//
// Arguments:
-// data: The tensor to be made available to the next iteration.
//
-// Returns The same tensor as `data`.
-func NextIteration(scope *Scope, data tf.Output) (output tf.Output) {
+// input_min: The float value that the minimum quantized input value represents.
+// input_max: The float value that the maximum quantized input value represents.
+//
+// Returns The computed min output. The computed max output.
+func RequantizationRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output) (output_min tf.Output, output_max tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "NextIteration",
+ Type: "RequantizationRange",
Input: []tf.Input{
- data,
+ input, input_min, input_max,
},
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1)
}
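The typical pairing described above, sketched end to end: compute the range a
qint32 accumulator actually uses, then requantize into it. The Requantize
wrapper and its signature are an assumption here (it is not part of this
hunk); only RequantizationRange is taken from the code above.

```
// Same package and imports as the quantizedMatMul sketch above.

// shrinkToUsedRange narrows a qint32 tensor to the range it actually uses and
// requantizes it to quint8. Assumption: op.Requantize has the signature shown.
func shrinkToUsedRange(s *op.Scope, acc, accMin, accMax tf.Output) (out, outMin, outMax tf.Output) {
	usedMin, usedMax := op.RequantizationRange(s, acc, accMin, accMax)
	return op.Requantize(s, acc, accMin, accMax, usedMin, usedMax, tf.Quint8)
}
```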
// MapPeekAttr is an optional argument to MapPeek.
@@ -18911,101 +19170,6 @@ func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output t
return op.Output(0)
}
-// Elementwise computes the bitwise OR of `x` and `y`.
-//
-// The result will have those bits set, that are set in `x`, `y` or both. The
-// computation is performed on the underlying representations of `x` and `y`.
-func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BitwiseOr",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// MatrixSolveLsAttr is an optional argument to MatrixSolveLs.
-type MatrixSolveLsAttr func(optionalAttr)
-
-// MatrixSolveLsFast sets the optional fast attribute to value.
-// If not specified, defaults to true
-func MatrixSolveLsFast(value bool) MatrixSolveLsAttr {
- return func(m optionalAttr) {
- m["fast"] = value
- }
-}
-
-// Solves one or more linear least-squares problems.
-//
-// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
-// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same
-// type as `matrix` and shape `[..., M, K]`.
-// The output is a tensor shape `[..., N, K]` where each output matrix solves
-// each of the equations
-// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]`
-// in the least squares sense.
-//
-// We use the following notation for (complex) matrix and right-hand sides
-// in the batch:
-//
-// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\),
-// `rhs`=\\(B \in \mathbb{C}^{m \times k}\\),
-// `output`=\\(X \in \mathbb{C}^{n \times k}\\),
-// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\).
-//
-// If `fast` is `True`, then the solution is computed by solving the normal
-// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
-// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares
-// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 +
-// \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
-// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
-// minimum-norm solution to the under-determined linear system, i.e.
-// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\),
-// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable
-// when \\(A\\) is numerically full rank and has a condition number
-// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or\\(\lambda\\) is
-// sufficiently large.
-//
-// If `fast` is `False` an algorithm based on the numerically robust complete
-// orthogonal decomposition is used. This computes the minimum-norm
-// least-squares solution, even when \\(A\\) is rank deficient. This path is
-// typically 6-7 times slower than the fast path. If `fast` is `False` then
-// `l2_regularizer` is ignored.
-//
-// Arguments:
-// matrix: Shape is `[..., M, N]`.
-// rhs: Shape is `[..., M, K]`.
-// l2_regularizer: Scalar tensor.
-//
-// @compatibility(numpy)
-// Equivalent to np.linalg.lstsq
-// @end_compatibility
-//
-// Returns Shape is `[..., N, K]`.
-func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MatrixSolveLs",
- Input: []tf.Input{
- matrix, rhs, l2_regularizer,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// SvdAttr is an optional argument to Svd.
type SvdAttr func(optionalAttr)
@@ -20850,6 +21014,61 @@ func Iterator(scope *Scope, shared_name string, container string, output_types [
return op.Output(0)
}
+// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage.
+type CropAndResizeGradImageAttr func(optionalAttr)
+
+// CropAndResizeGradImageMethod sets the optional method attribute to value.
+//
+// value: A string specifying the interpolation method. Only 'bilinear' is
+// supported for now.
+// If not specified, defaults to "bilinear"
+func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr {
+ return func(m optionalAttr) {
+ m["method"] = value
+ }
+}
+
+// Computes the gradient of the crop_and_resize op wrt the input image tensor.
+//
+// Arguments:
+// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`.
+// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor
+// specifies the coordinates of a box in the `box_ind[i]` image and is specified
+// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of
+// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so that
+// the `[0, 1]` interval of normalized image height is mapped to
+// `[0, image_height - 1]` in image height coordinates. We do allow y1 > y2, in
+// which case the sampled crop is an up-down flipped version of the original
+// image. The width dimension is treated similarly. Normalized coordinates
+// outside the `[0, 1]` range are allowed, in which case we use
+// `extrapolation_value` to extrapolate the input image values.
+// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`.
+// The value of `box_ind[i]` specifies the image that the `i`-th box refers to.
+// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]`
+// containing the original image size. Both `image_height` and `image_width` need
+// to be positive.
+//
+//
+// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`.
+func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"T": T}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "CropAndResizeGradImage",
+ Input: []tf.Input{
+ grads, boxes, box_ind, image_size,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
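A sketch of constructing the gradient node by hand; in practice the gradient
machinery usually creates it. The required T attribute names the dtype of the
original image tensor (float32 is assumed here), and 'bilinear' is the only
method the op documents.

```
// Same package and imports as the quantizedMatMul sketch above.

// cropGradWrtImage builds the image-gradient node for CropAndResize.
func cropGradWrtImage(s *op.Scope, grads, boxes, boxInd, imageSize tf.Output) tf.Output {
	return op.CropAndResizeGradImage(s, grads, boxes, boxInd, imageSize, tf.Float,
		op.CropAndResizeGradImageMethod("bilinear")) // the default and only option
}
```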
// ShuffleDatasetAttr is an optional argument to ShuffleDataset.
type ShuffleDatasetAttr func(optionalAttr)
@@ -21717,47 +21936,6 @@ func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, out
return op.Output(0)
}
-// PlaceholderAttr is an optional argument to Placeholder.
-type PlaceholderAttr func(optionalAttr)
-
-// PlaceholderShape sets the optional shape attribute to value.
-//
-// value: (Optional) The shape of the tensor. If the shape has 0 dimensions, the
-// shape is unconstrained.
-// If not specified, defaults to <unknown_rank:true >
-func PlaceholderShape(value tf.Shape) PlaceholderAttr {
- return func(m optionalAttr) {
- m["shape"] = value
- }
-}
-
-// A placeholder op for a value that will be fed into the computation.
-//
-// N.B. This operation will fail with an error if it is executed. It is
-// intended as a way to represent a value that will always be fed, and to
-// provide attrs that enable the fed value to be checked at runtime.
-//
-// Arguments:
-// dtype: The type of elements in the tensor.
-//
-// Returns A placeholder tensor that must be replaced using the feed mechanism.
-func Placeholder(scope *Scope, dtype tf.DataType, optional ...PlaceholderAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Placeholder",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Creates a dataset that executes a SQL query and emits rows of the result set.
//
// Arguments:
@@ -23339,101 +23517,6 @@ func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) {
return scope.AddOperation(opspec)
}
-// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage.
-type CropAndResizeGradImageAttr func(optionalAttr)
-
-// CropAndResizeGradImageMethod sets the optional method attribute to value.
-//
-// value: A string specifying the interpolation method. Only 'bilinear' is
-// supported for now.
-// If not specified, defaults to "bilinear"
-func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr {
- return func(m optionalAttr) {
- m["method"] = value
- }
-}
-
-// Computes the gradient of the crop_and_resize op wrt the input image tensor.
-//
-// Arguments:
-// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`.
-// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor
-// specifies the coordinates of a box in the `box_ind[i]` image and is specified
-// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of
-// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the
-// `[0, 1]` interval of normalized image height is mapped to
-// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in
-// which case the sampled crop is an up-down flipped version of the original
-// image. The width dimension is treated similarly. Normalized coordinates
-// outside the `[0, 1]` range are allowed, in which case we use
-// `extrapolation_value` to extrapolate the input image values.
-// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`.
-// The value of `box_ind[i]` specifies the image that the `i`-th box refers to.
-// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]`
-// containing the original image size. Both `image_height` and `image_width` need
-// to be positive.
-//
-//
-// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`.
-func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"T": T}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "CropAndResizeGradImage",
- Input: []tf.Input{
- grads, boxes, box_ind, image_size,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Reads and outputs the entire contents of the input filename.
-func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ReadFile",
- Input: []tf.Input{
- filename,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Concatenates tensors along one dimension.
-//
-// Arguments:
-// values: List of `N` Tensors to concatenate. Their ranks and types must match,
-// and their sizes must match in all dimensions except `concat_dim`.
-// axis: 0-D. The dimension along which to concatenate. Must be in the
-// range [-rank(values), rank(values)).
-//
-// Returns A `Tensor` with the concatenation of values stacked along the
-// `concat_dim` dimension. This tensor's shape matches that of `values` except
-// in `concat_dim` where it has the sum of the sizes.
-func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ConcatV2",
- Input: []tf.Input{
- tf.OutputList(values), axis,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Forwards the value of an available tensor from `inputs` to `output`.
//
// `Merge` waits for at least one of the tensors in `inputs` to become available.
@@ -27804,86 +27887,3 @@ func BroadcastGradientArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Outp
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1)
}
-
-// Pads a tensor with mirrored values.
-//
-// This operation pads a `input` with mirrored values according to the `paddings`
-// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is
-// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
-// how many values to add before the contents of `input` in that dimension, and
-// `paddings[D, 1]` indicates how many values to add after the contents of `input`
-// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater
-// than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true
-// (if false, respectively).
-//
-// The padded size of each dimension D of the output is:
-//
-// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
-//
-// For example:
-//
-// ```
-// # 't' is [[1, 2, 3], [4, 5, 6]].
-// # 'paddings' is [[1, 1]], [2, 2]].
-// # 'mode' is SYMMETRIC.
-// # rank of 't' is 2.
-// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2]
-// [2, 1, 1, 2, 3, 3, 2]
-// [5, 4, 4, 5, 6, 6, 5]
-// [5, 4, 4, 5, 6, 6, 5]]
-// ```
-//
-// Arguments:
-// input: The input tensor to be padded.
-// paddings: A two-column matrix specifying the padding sizes. The number of
-// rows must be the same as the rank of `input`.
-// mode: Either `REFLECT` or `SYMMETRIC`. In reflect mode the padded regions
-// do not include the borders, while in symmetric mode the padded regions
-// do include the borders. For example, if `input` is `[1, 2, 3]` and `paddings`
-// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and
-// it is `[1, 2, 3, 3, 2]` in symmetric mode.
-//
-// Returns The padded tensor.
-func MirrorPad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"mode": mode}
- opspec := tf.OpSpec{
- Type: "MirrorPad",
- Input: []tf.Input{
- input, paddings,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// A placeholder op for a value that will be fed into the computation.
-//
-// DEPRECATED at GraphDef version 23: Placeholder now behaves the same as PlaceholderV2.
-//
-// N.B. This operation will fail with an error if it is executed. It is
-// intended as a way to represent a value that will always be fed, and to
-// provide attrs that enable the fed value to be checked at runtime.
-//
-// Arguments:
-// dtype: The type of elements in the tensor.
-// shape: The shape of the tensor. The shape can be any partially-specified
-// shape. To be unconstrained, pass in a shape with unknown rank.
-//
-// Returns A placeholder tensor that must be replaced using the feed mechanism.
-func PlaceholderV2(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype, "shape": shape}
- opspec := tf.OpSpec{
- Type: "PlaceholderV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7bd05bb6e0..94a6355e9a 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2884,9 +2884,11 @@ py_library(
":client",
":control_flow_ops",
":data_flow_ops",
+ ":device",
":errors",
":framework",
":framework_for_generated_wrappers",
+ ":framework_ops",
":gradients",
":init_ops",
":io_ops",
@@ -2911,6 +2913,7 @@ py_library(
":variable_scope",
":variables",
"//tensorflow/python/eager:backprop",
+ "//tensorflow/python/eager:context",
"//third_party/py/numpy",
"@six_archive//:six",
],
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 701f68b8f7..55ba509065 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1013,12 +1013,13 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
tensorflow::int64 id = EagerTensor_id(tensor);
const tensorflow::Tensor* tensor = nullptr;
- const tensorflow::Status status = t->Tensor(&tensor);
+ const tensorflow::Status status = t->handle->Tensor(&tensor);
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- return tensorflow::eager::TapeTensor{id, t->dtype,
+ return tensorflow::eager::TapeTensor{id, t->handle->dtype,
tensorflow::TensorShape({})};
} else {
- return tensorflow::eager::TapeTensor{id, t->dtype, tensor->shape()};
+ return tensorflow::eager::TapeTensor{id, t->handle->dtype,
+ tensor->shape()};
}
}
tensorflow::int64 id = FastTensorId(tensor);
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index c9635a9c27..bb033d3495 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -887,11 +887,12 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
Raises:
ValueError: If `thresholds` contains a value outside of `(0, 1)`.
ValueError: If `loss_reduction` is invalid.
+ TypeError: If `label_vocabulary` has an invalid type.
"""
thresholds = tuple(thresholds) if thresholds else tuple()
if label_vocabulary is not None and not isinstance(label_vocabulary,
(list, tuple)):
- raise ValueError(
+ raise TypeError(
'label_vocabulary should be a list or tuple. Given type: {}'.format(
type(label_vocabulary)))
diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py
index 4bb9941bb7..391b17720c 100644
--- a/tensorflow/python/framework/meta_graph.py
+++ b/tensorflow/python/framework/meta_graph.py
@@ -737,7 +737,9 @@ def import_scoped_meta_graph(meta_graph_or_file,
import_scope or "", mark_as_used=False)
importer.import_graph_def(
- input_graph_def, name=(import_scope or ""), input_map=input_map,
+ input_graph_def,
+ name=(import_scope or scope_to_prepend_to_names),
+ input_map=input_map,
producer_op_list=producer_op_list)
# Restores all the other collections.
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index 21963d0bee..5d5fb037fc 100644
--- a/tensorflow/python/framework/meta_graph_test.py
+++ b/tensorflow/python/framework/meta_graph_test.py
@@ -537,6 +537,21 @@ class ScopedMetaGraphTest(test.TestCase):
self.assertEqual(list(imported_variables.values())[0].name,
"foo/bar/myvar:0")
+ def testScopedImportUnderNameScopeNoVarScope(self):
+ graph = ops.Graph()
+ with graph.as_default():
+ variables.Variable(initial_value=1.0, trainable=True, name="myvar")
+ meta_graph_def, _ = meta_graph.export_scoped_meta_graph(graph=graph)
+
+ graph = ops.Graph()
+ with graph.as_default():
+ with ops.name_scope("foo"):
+ imported_variables = meta_graph.import_scoped_meta_graph(
+ meta_graph_def)
+ self.assertEqual(len(imported_variables), 1)
+ self.assertEqual(list(imported_variables.values())[0].name,
+ "foo/myvar:0")
+
def testImportsUsingSameScopeName(self):
with ops.Graph().as_default():
variables.Variable(0, name="v")
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 93edaa0cf0..e579289a8d 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -5863,6 +5863,9 @@ def strip_name_scope(name, export_scope):
is None.
"""
if export_scope:
+ if export_scope[-1] == "/":
+ export_scope = export_scope[:-1]
+
try:
# Strips export_scope/, export_scope///,
# ^export_scope/, loc:@export_scope/.
@@ -5888,6 +5891,9 @@ def prepend_name_scope(name, import_scope):
is None.
"""
if import_scope:
+ if import_scope[-1] == "/":
+ import_scope = import_scope[:-1]
+
try:
str_to_replace = r"([\^]|loc:@|^)(.*)"
return re.sub(str_to_replace, r"\1" + import_scope + r"/\2",
diff --git a/tensorflow/python/framework/test_file_system.cc b/tensorflow/python/framework/test_file_system.cc
index 094ea6f658..6e9915adbb 100644
--- a/tensorflow/python/framework/test_file_system.cc
+++ b/tensorflow/python/framework/test_file_system.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/null_file_system.h"
namespace tensorflow {
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 711106d2db..16033e9b8f 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -402,11 +402,10 @@ py_test(
py_test(
name = "convolutional_recurrent_test",
- size = "medium",
+ size = "large",
srcs = ["_impl/keras/layers/convolutional_recurrent_test.py"],
shard_count = 2,
srcs_version = "PY2AND3",
- tags = ["noasan"], # times out b/63678675
deps = [
":keras",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 2dc993f811..742564f9bf 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -103,6 +103,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
v = resource_variable_ops.ResourceVariable(False, name="bool_test")
self.assertAllEqual(bool(v), False)
+ def testFetchHandle(self):
+ with self.test_session():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1], name="foo")
+ self.assertGreater(len(handle.eval()), 0)
+
def testAssignVariableDtypeMismatchEager(self):
with context.eager_mode():
handle = resource_variable_ops.var_handle_op(
@@ -179,6 +185,204 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterSub(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[1]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_sub(handle, [0],
+ constant_op.constant(
+ [[2]],
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[-1]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterMul(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[1]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_mul(handle, [0],
+ constant_op.constant(
+ [[5]],
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[5]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterDiv(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_div(handle, [0],
+ constant_op.constant(
+ [[3]],
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[2]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterMin(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_min(handle, [0],
+ constant_op.constant(
+ [[3]],
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[3]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterMax(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_max(handle, [0],
+ constant_op.constant(
+ [[3]],
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[6]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterAddScalar(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[1]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_add(handle, [0],
+ constant_op.constant(
+ 2,
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[3]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterSubScalar(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[1]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_sub(handle, [0],
+ constant_op.constant(
+ 2,
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[-1]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterMulScalar(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[1]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_mul(handle, [0],
+ constant_op.constant(
+ 5,
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[5]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterDivScalar(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_div(handle, [0],
+ constant_op.constant(
+ 3,
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[2]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterMinScalar(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_min(handle, [0],
+ constant_op.constant(
+ 3,
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[3]])
+
+ @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ def testScatterMaxScalar(self):
+ with ops.device("cpu:0"):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[6]],
+ dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_max(handle, [0],
+ constant_op.constant(
+ 3,
+ dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[6]])
+
def testScatterUpdateString(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.string, shape=[1, 1])
@@ -190,6 +394,23 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]),
compat.as_bytes("b"))
+ def testScatterUpdateStringScalar(self):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.string, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [["a"]],
+ dtype=dtypes.string)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_update(handle, [0],
+ constant_op.constant(
+ "b",
+ dtype=dtypes.string)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string)
+ self.assertEqual(
+ compat.as_bytes(self.evaluate(read)[0][0]), compat.as_bytes("b"))
+
# TODO(alive): get this to work in Eager mode.
def testGPU(self):
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index 7cdf11d884..c70a4ffce7 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -38,38 +38,100 @@ def _NumpyAdd(ref, indices, updates):
ref[indx] += updates[i]
+def _NumpyAddScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] += update
+
+
def _NumpySub(ref, indices, updates):
for i, indx in np.ndenumerate(indices):
ref[indx] -= updates[i]
+def _NumpySubScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] -= update
+
+
def _NumpyMul(ref, indices, updates):
for i, indx in np.ndenumerate(indices):
ref[indx] *= updates[i]
+def _NumpyMulScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] *= update
+
+
def _NumpyDiv(ref, indices, updates):
for i, indx in np.ndenumerate(indices):
ref[indx] /= updates[i]
+def _NumpyDivScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] /= update
+
+
+def _NumpyMin(ref, indices, updates):
+ for i, indx in np.ndenumerate(indices):
+ ref[indx] = np.minimum(ref[indx], updates[i])
+
+
+def _NumpyMinScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] = np.minimum(ref[indx], update)
+
+
+def _NumpyMax(ref, indices, updates):
+ for i, indx in np.ndenumerate(indices):
+ ref[indx] = np.maximum(ref[indx], updates[i])
+
+
+def _NumpyMaxScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] = np.maximum(ref[indx], update)
+
+
def _NumpyUpdate(ref, indices, updates):
for i, indx in np.ndenumerate(indices):
ref[indx] = updates[i]
+def _NumpyUpdateScalar(ref, indices, update):
+ for _, indx in np.ndenumerate(indices):
+ ref[indx] = update
+
+
_TF_OPS_TO_NUMPY = {
state_ops.scatter_update: _NumpyUpdate,
state_ops.scatter_add: _NumpyAdd,
state_ops.scatter_sub: _NumpySub,
state_ops.scatter_mul: _NumpyMul,
state_ops.scatter_div: _NumpyDiv,
+ state_ops.scatter_min: _NumpyMin,
+ state_ops.scatter_max: _NumpyMax,
+}
+
+_TF_OPS_TO_NUMPY_SCALAR = {
+ state_ops.scatter_update: _NumpyUpdateScalar,
+ state_ops.scatter_add: _NumpyAddScalar,
+ state_ops.scatter_sub: _NumpySubScalar,
+ state_ops.scatter_mul: _NumpyMulScalar,
+ state_ops.scatter_div: _NumpyDivScalar,
+ state_ops.scatter_min: _NumpyMinScalar,
+ state_ops.scatter_max: _NumpyMaxScalar,
}
class ScatterTest(test.TestCase):
- def _VariableRankTest(self, tf_scatter, vtype, itype, repeat_indices=False):
+ def _VariableRankTest(self,
+ tf_scatter,
+ vtype,
+ itype,
+ repeat_indices=False,
+ updates_are_scalar=False):
np.random.seed(8)
with self.test_session(use_gpu=True):
for indices_shape in (), (2,), (3, 7), (3, 4, 7):
@@ -89,8 +151,11 @@ class ScatterTest(test.TestCase):
indices[np.random.randint(size // 2)])
np.random.shuffle(indices)
indices = indices.reshape(indices_shape)
- updates = _AsType(
- np.random.randn(*(indices_shape + extra_shape)), vtype)
+ if updates_are_scalar:
+ updates = _AsType(np.random.randn(), vtype)
+ else:
+ updates = _AsType(
+ np.random.randn(*(indices_shape + extra_shape)), vtype)
# Clips small values to avoid division by zero.
def clip_small_values(x):
@@ -101,7 +166,10 @@ class ScatterTest(test.TestCase):
# Scatter via numpy
new = old.copy()
- np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
+ if updates_are_scalar:
+ np_scatter = _TF_OPS_TO_NUMPY_SCALAR[tf_scatter]
+ else:
+ np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
np_scatter(new, indices, updates)
# Scatter via tensorflow
ref = variables.Variable(old)
@@ -109,25 +177,35 @@ class ScatterTest(test.TestCase):
tf_scatter(ref, indices, updates).eval()
self.assertAllClose(ref.eval(), new)
- def _VariableRankTests(self, tf_scatter, repeat_indices=False):
+ def _VariableRankTests(self,
+ tf_scatter,
+ repeat_indices=False,
+ updates_are_scalar=False):
for vtype in (np.float32, np.float64):
for itype in (np.int32, np.int64):
- self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices)
+ self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices,
+ updates_are_scalar)
def testVariableRankUpdate(self):
- self._VariableRankTests(state_ops.scatter_update)
+ self._VariableRankTests(state_ops.scatter_update, False)
def testVariableRankAdd(self):
- self._VariableRankTests(state_ops.scatter_add)
+ self._VariableRankTests(state_ops.scatter_add, False)
def testVariableRankSub(self):
- self._VariableRankTests(state_ops.scatter_sub)
+ self._VariableRankTests(state_ops.scatter_sub, False)
def testVariableRankMul(self):
- self._VariableRankTests(state_ops.scatter_mul)
+ self._VariableRankTests(state_ops.scatter_mul, False)
def testVariableRankDiv(self):
- self._VariableRankTests(state_ops.scatter_div)
+ self._VariableRankTests(state_ops.scatter_div, False)
+
+ def testVariableRankMin(self):
+ self._VariableRankTests(state_ops.scatter_min, False)
+
+ def testVariableRankMax(self):
+ self._VariableRankTests(state_ops.scatter_max, False)
def testRepeatIndicesAdd(self):
self._VariableRankTests(state_ops.scatter_add, True)
@@ -141,6 +219,51 @@ class ScatterTest(test.TestCase):
def testRepeatIndicesDiv(self):
self._VariableRankTests(state_ops.scatter_div, True)
+ def testRepeatIndicesMin(self):
+ self._VariableRankTests(state_ops.scatter_min, True)
+
+ def testRepeatIndicesMax(self):
+ self._VariableRankTests(state_ops.scatter_max, True)
+
+ def testVariableRankUpdateScalar(self):
+ self._VariableRankTests(state_ops.scatter_update, False, True)
+
+ def testVariableRankAddScalar(self):
+ self._VariableRankTests(state_ops.scatter_add, False, True)
+
+ def testVariableRankSubScalar(self):
+ self._VariableRankTests(state_ops.scatter_sub, False, True)
+
+ def testVariableRankMulScalar(self):
+ self._VariableRankTests(state_ops.scatter_mul, False, True)
+
+ def testVariableRankDivScalar(self):
+ self._VariableRankTests(state_ops.scatter_div, False, True)
+
+ def testVariableRankMinScalar(self):
+ self._VariableRankTests(state_ops.scatter_min, False, True)
+
+ def testVariableRankMaxScalar(self):
+ self._VariableRankTests(state_ops.scatter_max, False, True)
+
+ def testRepeatIndicesAddScalar(self):
+ self._VariableRankTests(state_ops.scatter_add, True, True)
+
+ def testRepeatIndicesSubScalar(self):
+ self._VariableRankTests(state_ops.scatter_sub, True, True)
+
+ def testRepeatIndicesMulScalar(self):
+ self._VariableRankTests(state_ops.scatter_mul, True, True)
+
+ def testRepeatIndicesDivScalar(self):
+ self._VariableRankTests(state_ops.scatter_div, True, True)
+
+ def testRepeatIndicesMinScalar(self):
+ self._VariableRankTests(state_ops.scatter_min, True, True)
+
+ def testRepeatIndicesMaxScalar(self):
+ self._VariableRankTests(state_ops.scatter_max, True, True)
+
def testBooleanScatterUpdate(self):
if not test.is_gpu_available():
with self.test_session(use_gpu=False) as session:
diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc
index 994af69386..a07e305ffb 100644
--- a/tensorflow/python/lib/core/ndarray_tensor.cc
+++ b/tensorflow/python/lib/core/ndarray_tensor.cc
@@ -267,7 +267,9 @@ gtl::InlinedVector<npy_intp, 4> GetPyArrayDimensionsForTensor(
const int ndims = TF_NumDims(tensor);
gtl::InlinedVector<npy_intp, 4> dims(ndims);
if (TF_TensorType(tensor) == TF_RESOURCE) {
- dims[0] = TF_TensorByteSize(tensor);
+ CHECK_EQ(ndims, 0)
+ << "Fetching of non-scalar resource tensors is not supported.";
+ dims.push_back(TF_TensorByteSize(tensor));
*nelems = dims[0];
} else {
*nelems = 1;
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index 02eafd42b3..22317a348c 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -166,7 +166,7 @@ bool IsSingleNone(PyObject* obj) {
// Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
const Tensor** output_tensor) {
- return EagerTensor_Handle(eager_tensor)->Tensor(output_tensor);
+ return EagerTensor_Handle(eager_tensor)->handle->Tensor(output_tensor);
}
// Calls the registered py function through the trampoline.
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 4071e50e81..0866fa8b0b 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -593,7 +593,7 @@ class Distribution(_BaseDistribution):
Returns:
batch_shape: `TensorShape`, possibly unknown.
"""
- return self._batch_shape()
+ return tensor_shape.as_shape(self._batch_shape())
def _event_shape_tensor(self):
raise NotImplementedError("event_shape_tensor is not implemented")
@@ -626,7 +626,7 @@ class Distribution(_BaseDistribution):
Returns:
event_shape: `TensorShape`, possibly unknown.
"""
- return self._event_shape()
+ return tensor_shape.as_shape(self._event_shape())
def is_scalar_event(self, name="is_scalar_event"):
"""Indicates that `event_shape == []`.
@@ -1105,6 +1105,34 @@ class Distribution(_BaseDistribution):
with self._name_scope(name):
return self._kl_divergence(other)
+ def __str__(self):
+ return ("tf.distributions.{type_name}("
+ "\"{self_name}\""
+ "{maybe_batch_shape}"
+ "{maybe_event_shape}"
+ ", dtype={dtype})".format(
+ type_name=type(self).__name__,
+ self_name=self.name,
+ maybe_batch_shape=(", batch_shape={}".format(self.batch_shape)
+ if self.batch_shape.ndims is not None
+ else ""),
+ maybe_event_shape=(", event_shape={}".format(self.event_shape)
+ if self.event_shape.ndims is not None
+ else ""),
+ dtype=self.dtype.name))
+
+ def __repr__(self):
+ return ("<tf.distributions.{type_name} "
+ "'{self_name}'"
+ " batch_shape={batch_shape}"
+ " event_shape={event_shape}"
+ " dtype={dtype}>".format(
+ type_name=type(self).__name__,
+ self_name=self.name,
+ batch_shape=self.batch_shape,
+ event_shape=self.event_shape,
+ dtype=self.dtype.name))
+
@contextlib.contextmanager
def _name_scope(self, name=None, values=None):
"""Helper function to standardize op scope."""
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 230b7ef937..e90ff0746a 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -80,6 +80,8 @@ from tensorflow.python.ops.state_ops import scatter_add
from tensorflow.python.ops.state_ops import scatter_div
from tensorflow.python.ops.state_ops import scatter_mul
from tensorflow.python.ops.state_ops import scatter_sub
+from tensorflow.python.ops.state_ops import scatter_min
+from tensorflow.python.ops.state_ops import scatter_max
from tensorflow.python.ops.state_ops import scatter_update
from tensorflow.python.ops.state_ops import scatter_nd_add
from tensorflow.python.ops.state_ops import scatter_nd_sub
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index c3ad5831b4..01fc3182bc 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -63,6 +63,8 @@
@@scatter_nd_update
@@scatter_sub
@@scatter_update
+@@scatter_min
+@@scatter_max
@@sparse_mask
@@tables_initializer
@@trainable_variables
diff --git a/tensorflow/python/training/device_util.py b/tensorflow/python/training/device_util.py
new file mode 100644
index 0000000000..f1137e80ab
--- /dev/null
+++ b/tensorflow/python/training/device_util.py
@@ -0,0 +1,68 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Device-related support functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import device as tf_device
+from tensorflow.python.framework import ops
+
+
+def canonicalize(d):
+ d = tf_device.DeviceSpec.from_string(d)
+ assert d.device_type is None or d.device_type == d.device_type.upper(), (
+ "Device type '%s' must be all-caps." % (d.device_type,))
+ # Fill in missing device fields using defaults.
+ result = tf_device.DeviceSpec(
+ job="localhost", replica=0, task=0, device_type="CPU", device_index=0)
+ result.merge_from(d)
+ return result.to_string()
+
+
+class _FakeNodeDef(object):
+ """A fake NodeDef for _FakeOperation."""
+
+ def __init__(self):
+ self.op = ""
+ self.name = ""
+
+
+class _FakeOperation(object):
+ """A fake Operation object to pass to device functions."""
+
+ def __init__(self):
+ self.device = ""
+ self.type = ""
+ self.name = ""
+ self.node_def = _FakeNodeDef()
+
+ def _set_device(self, device):
+ self.device = ops._device_string(device) # pylint: disable=protected-access
+
+
+def current():
+ """Return a string (not canonicalized) for the current device."""
+ # TODO(josh11b): Work out how this function interacts with ops.colocate_with.
+ ctx = context.context()
+ if ctx.executing_eagerly():
+ d = ctx.device_name
+ else:
+ op = _FakeOperation()
+ ops.get_default_graph()._apply_device_functions(op) # pylint: disable=protected-access
+ d = op.device
+ return d
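Based on the defaults merged in canonicalize above, the expected behavior is roughly the following (the output strings follow DeviceSpec.to_string, so treat this as a sketch rather than a guarantee):

    from tensorflow.python.training import device_util

    # Missing fields fall back to job=localhost, replica=0, task=0, CPU:0.
    device_util.canonicalize("/gpu:0")
    # -> "/job:localhost/replica:0/task:0/device:GPU:0"
    device_util.canonicalize("/job:worker/task:1")
    # -> "/job:worker/replica:0/task:1/device:CPU:0"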
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 03e3e0857f..ab5e6590e0 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -3157,12 +3157,18 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
dnn::DataType output_type, float scale,
DeviceMemoryBase* output_data) {
mutex_lock lock{dnn_handle_mutex_};
+ cudnnStatus_t status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+ AsCUDAStreamValue(stream));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
+ }
+
float beta = 0.0f;
ScopedTensorDescriptor input_tensor_desc(
parent_, input_desc, ToCudnnDataType(input_type, input_desc.layout()));
ScopedTensorDescriptor output_tensor_desc(
parent_, output_desc, ToCudnnDataType(output_type, output_desc.layout()));
- cudnnStatus_t status = wrap::cudnnTransformTensor(
+ status = wrap::cudnnTransformTensor(
parent_, ToHandle(dnn_handle_), &scale, input_tensor_desc.handle(),
input_data.opaque(), &beta, output_tensor_desc.handle(),
output_data->opaque());
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 55b82dd765..937044aece 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -1689,6 +1689,14 @@ tf_module {
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "scatter_max"
+ argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "scatter_min"
+ argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
name: "scatter_mul"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
diff --git a/tensorflow/tools/ci_build/builds/android.sh b/tensorflow/tools/ci_build/builds/android.sh
index 564c5aa148..d81793efe0 100755
--- a/tensorflow/tools/ci_build/builds/android.sh
+++ b/tensorflow/tools/ci_build/builds/android.sh
@@ -29,7 +29,8 @@ echo "========== TensorFlow Demo Build Test =========="
# Enable sandboxing so that zip archives don't get incorrectly packaged
# in assets/ dir (see https://github.com/bazelbuild/bazel/issues/2334)
# TODO(gunan): remove extra flags once sandboxing is enabled for all builds.
-bazel --bazelrc=/dev/null build -c opt --fat_apk_cpu=x86_64 \
+bazel --bazelrc=/dev/null build \
+ --compilation_mode=opt --cxxopt=-std=c++11 --fat_apk_cpu=x86_64 \
--spawn_strategy=sandboxed --genrule_strategy=sandboxed \
//tensorflow/examples/android:tensorflow_demo
diff --git a/tensorflow/tools/ci_build/builds/android_full.sh b/tensorflow/tools/ci_build/builds/android_full.sh
index 9d449241e8..41dc66dd54 100755
--- a/tensorflow/tools/ci_build/builds/android_full.sh
+++ b/tensorflow/tools/ci_build/builds/android_full.sh
@@ -40,7 +40,8 @@ rm -rf ${AAR_LIB_TMP}
for CPU in ${CPUS//,/ }
do
echo "========== Building native libs for Android ${CPU} =========="
- bazel build -c opt --config=monolithic --cpu=${CPU} \
+ bazel build --config=monolithic --cpu=${CPU} \
+ --compilation_mode=opt --cxxopt=-std=c++11 \
--crosstool_top=//external:android/crosstool \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
//tensorflow/core:android_tensorflow_lib \
@@ -62,7 +63,8 @@ done
# in assets/ dir (see https://github.com/bazelbuild/bazel/issues/2334)
# TODO(gunan): remove extra flags once sandboxing is enabled for all builds.
echo "========== Building TensorFlow Android Jar and Demo =========="
-bazel --bazelrc=/dev/null build -c opt --config=monolithic --fat_apk_cpu=${CPUS} \
+bazel --bazelrc=/dev/null build --config=monolithic --fat_apk_cpu=${CPUS} \
+ --compilation_mode=opt --cxxopt=-std=c++11 \
--spawn_strategy=sandboxed --genrule_strategy=sandboxed \
//tensorflow/contrib/android:android_tensorflow_inference_java \
//tensorflow/contrib/android:android_tensorflow_inference_java.aar \
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 34dd419f15..d22a465376 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -211,6 +211,7 @@ def _get_default_do_not_descend_map():
'tf': ['cli', 'lib', 'wrappers'],
'tf.contrib': [
'compiler',
+ 'distribute',
'grid_rnn',
# Block contrib.keras to de-clutter the docs
'keras',
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index 8a1c7db2ea..f8fb6ecb0c 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -51,6 +51,9 @@ import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.spinn import data
+layers = tf.keras.layers
+
+
def _bundle(lstm_iter):
"""Concatenate a list of Tensors along 1st axis and split result into two.
@@ -78,17 +81,16 @@ def _unbundle(state):
return tf.split(tf.concat(state, 1), state[0].shape[0], axis=0)
-class Reducer(tfe.Network):
+# pylint: disable=not-callable
+class Reducer(tf.keras.Model):
"""A module that applies reduce operation on left and right vectors."""
def __init__(self, size, tracker_size=None):
super(Reducer, self).__init__()
- self.left = self.track_layer(tf.layers.Dense(5 * size, activation=None))
- self.right = self.track_layer(
- tf.layers.Dense(5 * size, activation=None, use_bias=False))
+ self.left = layers.Dense(5 * size, activation=None)
+ self.right = layers.Dense(5 * size, activation=None, use_bias=False)
if tracker_size is not None:
- self.track = self.track_layer(
- tf.layers.Dense(5 * size, activation=None, use_bias=False))
+ self.track = layers.Dense(5 * size, activation=None, use_bias=False)
else:
self.track = None
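The switch from tfe.Network/track_layer to tf.keras.Model relies on keras Models tracking any Layer assigned as an attribute, which is why the explicit track_layer calls can be dropped. A minimal standalone sketch (toy layer sizes, not taken from this file):

    import tensorflow as tf

    class TwoDense(tf.keras.Model):
      def __init__(self):
        super(TwoDense, self).__init__()
        # Attribute assignment is enough; no track_layer() needed.
        self.first = tf.keras.layers.Dense(4)
        self.second = tf.keras.layers.Dense(2)

      def call(self, inputs):
        return self.second(self.first(inputs))

    model = TwoDense()
    model(tf.zeros([1, 3]))       # first call builds the variables
    print(len(model.variables))   # 4: kernel + bias for each Dense layer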
@@ -123,7 +125,7 @@ class Reducer(tfe.Network):
return h, c
-class Tracker(tfe.Network):
+class Tracker(tf.keras.Model):
"""A module that tracks the history of the sentence with an LSTM."""
def __init__(self, tracker_size, predict):
@@ -134,10 +136,10 @@ class Tracker(tfe.Network):
predict: (`bool`) Whether prediction mode is enabled.
"""
super(Tracker, self).__init__()
- self._rnn = self.track_layer(tf.nn.rnn_cell.LSTMCell(tracker_size))
+ self._rnn = tf.nn.rnn_cell.LSTMCell(tracker_size)
self._state_size = tracker_size
if predict:
- self._transition = self.track_layer(tf.layers.Dense(4))
+ self._transition = layers.Dense(4)
else:
self._transition = None
@@ -182,7 +184,7 @@ class Tracker(tfe.Network):
return unbundled, None
-class SPINN(tfe.Network):
+class SPINN(tf.keras.Model):
"""Stack-augmented Parser-Interpreter Neural Network.
See https://arxiv.org/abs/1603.06021 for more details.
@@ -204,9 +206,9 @@ class SPINN(tfe.Network):
"""
super(SPINN, self).__init__()
self.config = config
- self.reducer = self.track_layer(Reducer(config.d_hidden, config.d_tracker))
+ self.reducer = Reducer(config.d_hidden, config.d_tracker)
if config.d_tracker is not None:
- self.tracker = self.track_layer(Tracker(config.d_tracker, config.predict))
+ self.tracker = Tracker(config.d_tracker, config.predict)
else:
self.tracker = None
@@ -248,7 +250,7 @@ class SPINN(tfe.Network):
trans = transitions[i]
if self.tracker:
# Invoke tracker to obtain the current tracker states for the sentences.
- tracker_states, trans_hypothesis = self.tracker(buffers, stacks)
+ tracker_states, trans_hypothesis = self.tracker(buffers, stacks=stacks)
if trans_hypothesis:
trans = tf.argmax(trans_hypothesis, axis=-1)
else:
@@ -264,7 +266,8 @@ class SPINN(tfe.Network):
trackings.append(tracking)
if rights:
- reducer_output = self.reducer(lefts, rights, trackings)
+ reducer_output = self.reducer(
+ lefts, right_in=rights, tracking=trackings)
reduced = iter(reducer_output)
for transition, stack in zip(trans, stacks):
@@ -273,7 +276,27 @@ class SPINN(tfe.Network):
return _bundle([stack.pop() for stack in stacks])[0]
-class SNLIClassifier(tfe.Network):
+class Perceptron(tf.keras.Model):
+ """One layer of the SNLIClassifier multi-layer perceptron."""
+
+ def __init__(self, dimension, dropout_rate, previous_layer):
+ """Configure the Perceptron."""
+ super(Perceptron, self).__init__()
+ self.dense = tf.keras.layers.Dense(dimension, activation=tf.nn.elu)
+ self.batchnorm = layers.BatchNormalization()
+ self.dropout = layers.Dropout(rate=dropout_rate)
+ self.previous_layer = previous_layer
+
+ def call(self, x, training):
+ """Run previous Perceptron layers, then this one."""
+ x = self.previous_layer(x, training=training)
+ x = self.dense(x)
+ x = self.batchnorm(x, training=training)
+ x = self.dropout(x, training=training)
+ return x
+
+
+class SNLIClassifier(tf.keras.Model):
"""SNLI Classifier Model.
A model aimed at solving the SNLI (Stanford Natural Language Inference)
@@ -304,29 +327,24 @@ class SNLIClassifier(tfe.Network):
self.config = config
self.embed = tf.constant(embed)
- self.projection = self.track_layer(tf.layers.Dense(config.d_proj))
- self.embed_bn = self.track_layer(tf.layers.BatchNormalization())
- self.embed_dropout = self.track_layer(
- tf.layers.Dropout(rate=config.embed_dropout))
- self.encoder = self.track_layer(SPINN(config))
-
- self.feature_bn = self.track_layer(tf.layers.BatchNormalization())
- self.feature_dropout = self.track_layer(
- tf.layers.Dropout(rate=config.mlp_dropout))
-
- self.mlp_dense = []
- self.mlp_bn = []
- self.mlp_dropout = []
- for _ in xrange(config.n_mlp_layers):
- self.mlp_dense.append(self.track_layer(tf.layers.Dense(config.d_mlp)))
- self.mlp_bn.append(
- self.track_layer(tf.layers.BatchNormalization()))
- self.mlp_dropout.append(
- self.track_layer(tf.layers.Dropout(rate=config.mlp_dropout)))
- self.mlp_output = self.track_layer(tf.layers.Dense(
+ self.projection = layers.Dense(config.d_proj)
+ self.embed_bn = layers.BatchNormalization()
+ self.embed_dropout = layers.Dropout(rate=config.embed_dropout)
+ self.encoder = SPINN(config)
+
+ self.feature_bn = layers.BatchNormalization()
+ self.feature_dropout = layers.Dropout(rate=config.mlp_dropout)
+
+ current_mlp = lambda result, training: result
+ for _ in range(config.n_mlp_layers):
+ current_mlp = Perceptron(dimension=config.d_mlp,
+ dropout_rate=config.mlp_dropout,
+ previous_layer=current_mlp)
+ self.mlp = current_mlp
+ self.mlp_output = layers.Dense(
config.d_out,
kernel_initializer=tf.random_uniform_initializer(minval=-5e-3,
- maxval=5e-3)))
+ maxval=5e-3))
def call(self,
premise,
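The loop above replaces the three parallel lists (mlp_dense, mlp_bn, mlp_dropout) with a chain of Perceptron models, each holding its predecessor in previous_layer, so a single call runs the whole stack. A sketch of the composition with placeholder sizes (n_mlp_layers=2, d_mlp=16, mlp_dropout=0.1):

    current_mlp = lambda result, training: result        # innermost no-op
    for _ in range(2):
      current_mlp = Perceptron(dimension=16, dropout_rate=0.1,
                               previous_layer=current_mlp)
    # current_mlp(x, training=True) now applies, innermost first:
    #   identity -> dense/batchnorm/dropout -> dense/batchnorm/dropout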
@@ -370,10 +388,10 @@ class SNLIClassifier(tfe.Network):
# Run the batch-normalized and dropout-processed word vectors through the
# SPINN encoder.
- premise = self.encoder(premise_embed, premise_transition,
- training=training)
- hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
- training=training)
+ premise = self.encoder(
+ premise_embed, transitions=premise_transition, training=training)
+ hypothesis = self.encoder(
+ hypothesis_embed, transitions=hypothesis_transition, training=training)
# Combine encoder outputs for premises and hypotheses into logits.
# Then apply batch normalization and dropout on the logits.
@@ -383,15 +401,12 @@ class SNLIClassifier(tfe.Network):
self.feature_bn(logits, training=training), training=training)
# Apply the multi-layer perceptron on the logits.
- for dense, bn, dropout in zip(
- self.mlp_dense, self.mlp_bn, self.mlp_dropout):
- logits = tf.nn.elu(dense(logits))
- logits = dropout(bn(logits, training=training), training=training)
+ logits = self.mlp(logits, training=training)
logits = self.mlp_output(logits)
return logits
-class SNLIClassifierTrainer(object):
+class SNLIClassifierTrainer(tfe.Checkpointable):
"""A class that coordinates the training of an SNLIClassifier."""
def __init__(self, snli_classifier, lr):
@@ -450,10 +465,11 @@ class SNLIClassifierTrainer(object):
"""
with tfe.GradientTape() as tape:
tape.watch(self._model.variables)
+ # TODO(allenl): Allow passing Layer inputs as positional arguments.
logits = self._model(premise,
- premise_transition,
- hypothesis,
- hypothesis_transition,
+ premise_transition=premise_transition,
+ hypothesis=hypothesis,
+ hypothesis_transition=hypothesis_transition,
training=True)
loss = self.loss(labels, logits)
gradients = tape.gradient(loss, self._model.variables)
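For context, the tape pattern used in train_batch, reduced to a minimal eager-mode sketch with a hypothetical scalar variable:

    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()
    w = tfe.Variable(3.0)
    with tfe.GradientTape() as tape:
      tape.watch(w)              # mirror the explicit watch used above
      loss = w * w
    grad = tape.gradient(loss, [w])
    print(grad[0].numpy())       # 6.0 = d(w^2)/dw at w=3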
@@ -517,7 +533,9 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
snli_data, batch_size):
if use_gpu:
label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
- logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
+ logits = trainer.model(
+ prem, premise_transition=prem_trans, hypothesis=hypo,
+ hypothesis_transition=hypo_trans, training=False)
loss_val = trainer.loss(label, logits)
batch_size = tf.shape(label)[0]
mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -609,29 +627,30 @@ def train_or_infer_spinn(embed,
with tf.device(device), \
summary_writer.as_default(), \
tf.contrib.summary.always_record_summaries():
- with tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(config.logdir)):
- model = SNLIClassifier(config, embed)
- global_step = tf.train.get_or_create_global_step()
- trainer = SNLIClassifierTrainer(model, config.lr)
+ model = SNLIClassifier(config, embed)
+ global_step = tf.train.get_or_create_global_step()
+ trainer = SNLIClassifierTrainer(model, config.lr)
+ checkpoint = tfe.Checkpoint(trainer=trainer, global_step=global_step)
+ checkpoint.restore(tf.train.latest_checkpoint(config.logdir))
if inference_sentence_pair:
# Inference mode.
- with tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(config.logdir)):
- prem, prem_trans = inference_sentence_pair[0]
- hypo, hypo_trans = inference_sentence_pair[1]
- hypo_trans = inference_sentence_pair[1][1]
- inference_logits = model( # pylint: disable=not-callable
- tf.constant(prem), tf.constant(prem_trans),
- tf.constant(hypo), tf.constant(hypo_trans), training=False)
- inference_logits = inference_logits[0][1:]
- max_index = tf.argmax(inference_logits)
- print("\nInference logits:")
- for i, (label, logit) in enumerate(
- zip(data.POSSIBLE_LABELS, inference_logits)):
- winner_tag = " (winner)" if max_index == i else ""
- print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
+ prem, prem_trans = inference_sentence_pair[0]
+ hypo, hypo_trans = inference_sentence_pair[1]
+ hypo_trans = inference_sentence_pair[1][1]
+ inference_logits = model(
+ tf.constant(prem),
+ premise_transition=tf.constant(prem_trans),
+ hypothesis=tf.constant(hypo),
+ hypothesis_transition=tf.constant(hypo_trans),
+ training=False)
+ inference_logits = inference_logits[0][1:]
+ max_index = tf.argmax(inference_logits)
+ print("\nInference logits:")
+ for i, (label, logit) in enumerate(
+ zip(data.POSSIBLE_LABELS, inference_logits)):
+ winner_tag = " (winner)" if max_index == i else ""
+ print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
return inference_logits
train_len = train_data.num_batches(config.batch_size)
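The move from restore_variables_on_create/tfe.Saver to tfe.Checkpoint is a switch to object-based checkpointing: dependencies are declared by constructor keyword, restore() is a no-op when latest_checkpoint() returns None, and save() writes everything reachable from the tracked objects. A minimal sketch with a hypothetical directory:

    import os
    import tensorflow as tf
    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()
    step = tf.train.get_or_create_global_step()
    ckpt = tfe.Checkpoint(global_step=step)   # tracked under the name "global_step"
    ckpt.restore(tf.train.latest_checkpoint("/tmp/spinn_demo"))  # no-op on first run
    step.assign_add(1)
    ckpt.save(os.path.join("/tmp/spinn_demo", "ckpt"))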
@@ -650,20 +669,15 @@ def train_or_infer_spinn(embed,
# remain on CPU. Same in _evaluate_on_dataset().
iterations += 1
- with tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(config.logdir)):
- batch_train_loss, batch_train_logits = trainer.train_batch(
- label, prem, prem_trans, hypo, hypo_trans)
+ batch_train_loss, batch_train_logits = trainer.train_batch(
+ label, prem, prem_trans, hypo, hypo_trans)
batch_size = tf.shape(label)[0]
mean_loss(batch_train_loss.numpy(),
weights=batch_size.gpu() if use_gpu else batch_size)
accuracy(tf.argmax(batch_train_logits, axis=1), label)
if iterations % config.save_every == 0:
- all_variables = trainer.variables + [global_step]
- saver = tfe.Saver(all_variables)
- saver.save(os.path.join(config.logdir, "ckpt"),
- global_step=global_step)
+ checkpoint.save(os.path.join(config.logdir, "ckpt"))
if iterations % config.dev_every == 0:
dev_loss, dev_frac_correct = _evaluate_on_dataset(